diff --git a/.github/workflows/translate-i18n-base-on-english.yml b/.github/workflows/translate-i18n-base-on-english.yml new file mode 100644 index 0000000000..3f51b3b2c7 --- /dev/null +++ b/.github/workflows/translate-i18n-base-on-english.yml @@ -0,0 +1,54 @@ +name: Check i18n Files and Create PR + +on: + pull_request: + types: [closed] + branches: [main] + +jobs: + check-and-update: + if: github.event.pull_request.merged == true + runs-on: ubuntu-latest + defaults: + run: + working-directory: web + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 2 # last 2 commits + + - name: Check for file changes in i18n/en-US + id: check_files + run: | + recent_commit_sha=$(git rev-parse HEAD) + second_recent_commit_sha=$(git rev-parse HEAD~1) + changed_files=$(git diff --name-only $recent_commit_sha $second_recent_commit_sha -- 'i18n/en-US/*.ts') + echo "Changed files: $changed_files" + if [ -n "$changed_files" ]; then + echo "FILES_CHANGED=true" >> $GITHUB_ENV + else + echo "FILES_CHANGED=false" >> $GITHUB_ENV + fi + + - name: Set up Node.js + if: env.FILES_CHANGED == 'true' + uses: actions/setup-node@v2 + with: + node-version: 'lts/*' + + - name: Install dependencies + if: env.FILES_CHANGED == 'true' + run: yarn install --frozen-lockfile + + - name: Run npm script + if: env.FILES_CHANGED == 'true' + run: npm run auto-gen-i18n + + - name: Create Pull Request + if: env.FILES_CHANGED == 'true' + uses: peter-evans/create-pull-request@v6 + with: + commit-message: Update i18n files based on en-US changes + title: 'chore: translate i18n files' + body: This PR was automatically created to update i18n files based on changes in en-US locale. + branch: chore/automated-i18n-updates diff --git a/.gitignore b/.gitignore index c52b9d8bbf..22aba65364 100644 --- a/.gitignore +++ b/.gitignore @@ -178,3 +178,4 @@ pyrightconfig.json api/.vscode .idea/ +.vscode \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f810584f24..8f57cd545e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -8,7 +8,7 @@ In terms of licensing, please take a minute to read our short [License and Contr ## Before you jump in -[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types: +[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:open) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types: ### Feature requests: diff --git a/CONTRIBUTING_CN.md b/CONTRIBUTING_CN.md index 303c2513f5..7cd2bb60eb 100644 --- a/CONTRIBUTING_CN.md +++ b/CONTRIBUTING_CN.md @@ -8,7 +8,7 @@ ## 在开始之前 -[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:closed)现有问题,或 [创建](https://github.com/langgenius/dify/issues/new/choose) 一个新问题。我们将问题分为两类: +[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:open)现有问题,或 [创建](https://github.com/langgenius/dify/issues/new/choose) 一个新问题。我们将问题分为两类: ### 功能请求: diff --git a/CONTRIBUTING_JA.md b/CONTRIBUTING_JA.md index 6d5bfb205c..a68bdeddbc 100644 --- a/CONTRIBUTING_JA.md +++ b/CONTRIBUTING_JA.md @@ -10,7 +10,7 @@ Dify にコントリビュートしたいとお考えなのですね。それは ## 飛び込む前に -[既存の Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) を探すか、[新しい Issue](https://github.com/langgenius/dify/issues/new/choose) を作成してください。私たちは Issue を 2 つのタイプに分類しています。 +[既存の Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:open) を探すか、[新しい Issue](https://github.com/langgenius/dify/issues/new/choose) を作成してください。私たちは Issue を 2 つのタイプに分類しています。 ### 機能リクエスト diff --git a/CONTRIBUTING_VI.md b/CONTRIBUTING_VI.md new file mode 100644 index 0000000000..80e68a046e --- /dev/null +++ b/CONTRIBUTING_VI.md @@ -0,0 +1,156 @@ +Thật tuyệt vời khi bạn muốn đóng góp cho Dify! Chúng tôi rất mong chờ được thấy những gì bạn sẽ làm. Là một startup với nguồn nhân lực và tài chính hạn chế, chúng tôi có tham vọng lớn là thiết kế quy trình trực quan nhất để xây dựng và quản lý các ứng dụng LLM. Mọi sự giúp đỡ từ cộng đồng đều rất quý giá đối với chúng tôi. + +Chúng tôi cần linh hoạt và làm việc nhanh chóng, nhưng đồng thời cũng muốn đảm bảo các cộng tác viên như bạn có trải nghiệm đóng góp thuận lợi nhất có thể. Chúng tôi đã tạo ra hướng dẫn đóng góp này nhằm giúp bạn làm quen với codebase và cách chúng tôi làm việc với các cộng tác viên, để bạn có thể nhanh chóng bắt tay vào phần thú vị. + +Hướng dẫn này, cũng như bản thân Dify, đang trong quá trình cải tiến liên tục. Chúng tôi rất cảm kích sự thông cảm của bạn nếu đôi khi nó không theo kịp dự án thực tế, và chúng tôi luôn hoan nghênh mọi phản hồi để cải thiện. + +Về vấn đề cấp phép, xin vui lòng dành chút thời gian đọc qua [Thỏa thuận Cấp phép và Đóng góp](./LICENSE) ngắn gọn của chúng tôi. Cộng đồng cũng tuân thủ [quy tắc ứng xử](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md). + +## Trước khi bắt đầu + +[Tìm kiếm](https://github.com/langgenius/dify/issues?q=is:issue+is:open) một vấn đề hiện có, hoặc [tạo mới](https://github.com/langgenius/dify/issues/new/choose) một vấn đề. Chúng tôi phân loại các vấn đề thành 2 loại: + +### Yêu cầu tính năng: + +* Nếu bạn đang tạo một yêu cầu tính năng mới, chúng tôi muốn bạn giải thích tính năng đề xuất sẽ đạt được điều gì và cung cấp càng nhiều thông tin chi tiết càng tốt. [@perzeusss](https://github.com/perzeuss) đã tạo một [Trợ lý Yêu cầu Tính năng](https://udify.app/chat/MK2kVSnw1gakVwMX) rất hữu ích để giúp bạn soạn thảo nhu cầu của mình. Hãy thử dùng nó nhé. + +* Nếu bạn muốn chọn một vấn đề từ danh sách hiện có, chỉ cần để lại bình luận dưới vấn đề đó nói rằng bạn sẽ làm. + + Một thành viên trong nhóm làm việc trong lĩnh vực liên quan sẽ được thông báo. Nếu mọi thứ ổn, họ sẽ cho phép bạn bắt đầu code. Chúng tôi yêu cầu bạn chờ đợi cho đến lúc đó trước khi bắt tay vào làm tính năng, để không lãng phí công sức của bạn nếu chúng tôi đề xuất thay đổi. + + Tùy thuộc vào lĩnh vực mà tính năng đề xuất thuộc về, bạn có thể nói chuyện với các thành viên khác nhau trong nhóm. Dưới đây là danh sách các lĩnh vực mà các thành viên trong nhóm chúng tôi đang làm việc hiện tại: + + | Thành viên | Phạm vi | + | ------------------------------------------------------------ | ---------------------------------------------------- | + | [@yeuoly](https://github.com/Yeuoly) | Thiết kế kiến trúc Agents | + | [@jyong](https://github.com/JohnJyong) | Thiết kế quy trình RAG | + | [@GarfieldDai](https://github.com/GarfieldDai) | Xây dựng quy trình làm việc | + | [@iamjoel](https://github.com/iamjoel) & [@zxhlyh](https://github.com/zxhlyh) | Làm cho giao diện người dùng dễ sử dụng | + | [@guchenhe](https://github.com/guchenhe) & [@crazywoola](https://github.com/crazywoola) | Trải nghiệm nhà phát triển, đầu mối liên hệ cho mọi vấn đề | + | [@takatost](https://github.com/takatost) | Định hướng và kiến trúc tổng thể sản phẩm | + + Cách chúng tôi ưu tiên: + + | Loại tính năng | Mức độ ưu tiên | + | ------------------------------------------------------------ | -------------- | + | Tính năng ưu tiên cao được gắn nhãn bởi thành viên trong nhóm | Ưu tiên cao | + | Yêu cầu tính năng phổ biến từ [bảng phản hồi cộng đồng](https://github.com/langgenius/dify/discussions/categories/feedbacks) của chúng tôi | Ưu tiên trung bình | + | Tính năng không quan trọng và cải tiến nhỏ | Ưu tiên thấp | + | Có giá trị nhưng không cấp bách | Tính năng tương lai | + +### Những vấn đề khác (ví dụ: báo cáo lỗi, tối ưu hiệu suất, sửa lỗi chính tả): + +* Bắt đầu code ngay lập tức. + + Cách chúng tôi ưu tiên: + + | Loại vấn đề | Mức độ ưu tiên | + | ------------------------------------------------------------ | -------------- | + | Lỗi trong các chức năng chính (không thể đăng nhập, ứng dụng không hoạt động, lỗ hổng bảo mật) | Nghiêm trọng | + | Lỗi không quan trọng, cải thiện hiệu suất | Ưu tiên trung bình | + | Sửa lỗi nhỏ (lỗi chính tả, giao diện người dùng gây nhầm lẫn nhưng vẫn hoạt động) | Ưu tiên thấp | + + +## Cài đặt + +Dưới đây là các bước để thiết lập Dify cho việc phát triển: + +### 1. Fork repository này + +### 2. Clone repository + + Clone repository đã fork từ terminal của bạn: + +``` +git clone git@github.com:/dify.git +``` + +### 3. Kiểm tra các phụ thuộc + +Dify yêu cầu các phụ thuộc sau để build, hãy đảm bảo chúng đã được cài đặt trên hệ thống của bạn: + +- [Docker](https://www.docker.com/) +- [Docker Compose](https://docs.docker.com/compose/install/) +- [Node.js v18.x (LTS)](http://nodejs.org) +- [npm](https://www.npmjs.com/) phiên bản 8.x.x hoặc [Yarn](https://yarnpkg.com/) +- [Python](https://www.python.org/) phiên bản 3.10.x + +### 4. Cài đặt + +Dify bao gồm một backend và một frontend. Đi đến thư mục backend bằng lệnh `cd api/`, sau đó làm theo hướng dẫn trong [README của Backend](api/README.md) để cài đặt. Trong một terminal khác, đi đến thư mục frontend bằng lệnh `cd web/`, sau đó làm theo hướng dẫn trong [README của Frontend](web/README.md) để cài đặt. + +Kiểm tra [FAQ về cài đặt](https://docs.dify.ai/learn-more/faq/self-host-faq) để xem danh sách các vấn đề thường gặp và các bước khắc phục. + +### 5. Truy cập Dify trong trình duyệt của bạn + +Để xác nhận cài đặt của bạn, hãy truy cập [http://localhost:3000](http://localhost:3000) (địa chỉ mặc định, hoặc URL và cổng bạn đã cấu hình) trong trình duyệt. Bạn sẽ thấy Dify đang chạy. + +## Phát triển + +Nếu bạn đang thêm một nhà cung cấp mô hình, [hướng dẫn này](https://github.com/langgenius/dify/blob/main/api/core/model_runtime/README.md) dành cho bạn. + +Nếu bạn đang thêm một nhà cung cấp công cụ cho Agent hoặc Workflow, [hướng dẫn này](./api/core/tools/README.md) dành cho bạn. + +Để giúp bạn nhanh chóng định hướng phần đóng góp của mình, dưới đây là một bản phác thảo ngắn gọn về cấu trúc backend & frontend của Dify: + +### Backend + +Backend của Dify được viết bằng Python sử dụng [Flask](https://flask.palletsprojects.com/en/3.0.x/). Nó sử dụng [SQLAlchemy](https://www.sqlalchemy.org/) cho ORM và [Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html) cho hàng đợi tác vụ. Logic xác thực được thực hiện thông qua Flask-login. + +``` +[api/] +├── constants // Các cài đặt hằng số được sử dụng trong toàn bộ codebase. +├── controllers // Định nghĩa các route API và logic xử lý yêu cầu. +├── core // Điều phối ứng dụng cốt lõi, tích hợp mô hình và công cụ. +├── docker // Cấu hình liên quan đến Docker & containerization. +├── events // Xử lý và xử lý sự kiện +├── extensions // Mở rộng với các framework/nền tảng bên thứ 3. +├── fields // Định nghĩa trường cho serialization/marshalling. +├── libs // Thư viện và tiện ích có thể tái sử dụng. +├── migrations // Script cho việc di chuyển cơ sở dữ liệu. +├── models // Mô hình cơ sở dữ liệu & định nghĩa schema. +├── services // Xác định logic nghiệp vụ. +├── storage // Lưu trữ khóa riêng tư. +├── tasks // Xử lý các tác vụ bất đồng bộ và công việc nền. +└── tests +``` + +### Frontend + +Website được khởi tạo trên boilerplate [Next.js](https://nextjs.org/) bằng Typescript và sử dụng [Tailwind CSS](https://tailwindcss.com/) cho styling. [React-i18next](https://react.i18next.com/) được sử dụng cho việc quốc tế hóa. + +``` +[web/] +├── app // layouts, pages và components +│ ├── (commonLayout) // layout chung được sử dụng trong toàn bộ ứng dụng +│ ├── (shareLayout) // layouts được chia sẻ cụ thể cho các phiên dựa trên token +│ ├── activate // trang kích hoạt +│ ├── components // được chia sẻ bởi các trang và layouts +│ ├── install // trang cài đặt +│ ├── signin // trang đăng nhập +│ └── styles // styles được chia sẻ toàn cục +├── assets // Tài nguyên tĩnh +├── bin // scripts chạy ở bước build +├── config // cài đặt và tùy chọn có thể điều chỉnh +├── context // contexts được chia sẻ bởi các phần khác nhau của ứng dụng +├── dictionaries // File dịch cho từng ngôn ngữ +├── docker // cấu hình container +├── hooks // Hooks có thể tái sử dụng +├── i18n // Cấu hình quốc tế hóa +├── models // mô tả các mô hình dữ liệu & hình dạng của phản hồi API +├── public // tài nguyên meta như favicon +├── service // xác định hình dạng của các hành động API +├── test +├── types // mô tả các tham số hàm và giá trị trả về +└── utils // Các hàm tiện ích được chia sẻ +``` + +## Gửi PR của bạn + +Cuối cùng, đã đến lúc mở một pull request (PR) đến repository của chúng tôi. Đối với các tính năng lớn, chúng tôi sẽ merge chúng vào nhánh `deploy/dev` để kiểm tra trước khi đưa vào nhánh `main`. Nếu bạn gặp vấn đề như xung đột merge hoặc không biết cách mở pull request, hãy xem [hướng dẫn về pull request của GitHub](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests). + +Và thế là xong! Khi PR của bạn được merge, bạn sẽ được giới thiệu là một người đóng góp trong [README](https://github.com/langgenius/dify/blob/main/README.md) của chúng tôi. + +## Nhận trợ giúp + +Nếu bạn gặp khó khăn hoặc có câu hỏi cấp bách trong quá trình đóng góp, hãy đặt câu hỏi của bạn trong vấn đề GitHub liên quan, hoặc tham gia [Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi để trò chuyện nhanh chóng. \ No newline at end of file diff --git a/api/.env.example b/api/.env.example index 775149f8fd..e41e2271d5 100644 --- a/api/.env.example +++ b/api/.env.example @@ -60,7 +60,8 @@ ALIYUN_OSS_SECRET_KEY=your-secret-key ALIYUN_OSS_ENDPOINT=your-endpoint ALIYUN_OSS_AUTH_VERSION=v1 ALIYUN_OSS_REGION=your-region - +# Don't start with '/'. OSS doesn't support leading slash in object names. +ALIYUN_OSS_PATH=your-path # Google Storage configuration GOOGLE_STORAGE_BUCKET_NAME=yout-bucket-name GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string @@ -247,8 +248,8 @@ API_TOOL_DEFAULT_READ_TIMEOUT=60 HTTP_REQUEST_MAX_CONNECT_TIMEOUT=300 HTTP_REQUEST_MAX_READ_TIMEOUT=600 HTTP_REQUEST_MAX_WRITE_TIMEOUT=600 -HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 # 10MB -HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 # 1MB +HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 +HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 # Log file path LOG_FILE= @@ -267,4 +268,13 @@ APP_MAX_ACTIVE_REQUESTS=0 # Celery beat configuration -CELERY_BEAT_SCHEDULER_TIME=1 \ No newline at end of file +CELERY_BEAT_SCHEDULER_TIME=1 + +# Position configuration +POSITION_TOOL_PINS= +POSITION_TOOL_INCLUDES= +POSITION_TOOL_EXCLUDES= + +POSITION_PROVIDER_PINS= +POSITION_PROVIDER_INCLUDES= +POSITION_PROVIDER_EXCLUDES= diff --git a/.idea/icon.png b/api/.idea/icon.png similarity index 100% rename from .idea/icon.png rename to api/.idea/icon.png diff --git a/.idea/vcs.xml b/api/.idea/vcs.xml similarity index 88% rename from .idea/vcs.xml rename to api/.idea/vcs.xml index ae8b1755c5..b7af618884 100644 --- a/.idea/vcs.xml +++ b/api/.idea/vcs.xml @@ -12,5 +12,6 @@ + - \ No newline at end of file + diff --git a/.vscode/launch.json b/api/.vscode/launch.json.example similarity index 83% rename from .vscode/launch.json rename to api/.vscode/launch.json.example index e4eb6aef93..e9f8e42dd5 100644 --- a/.vscode/launch.json +++ b/api/.vscode/launch.json.example @@ -5,8 +5,8 @@ "name": "Python: Flask", "type": "debugpy", "request": "launch", - "python": "${workspaceFolder}/api/.venv/bin/python", - "cwd": "${workspaceFolder}/api", + "python": "${workspaceFolder}/.venv/bin/python", + "cwd": "${workspaceFolder}", "envFile": ".env", "module": "flask", "justMyCode": true, @@ -18,15 +18,15 @@ "args": [ "run", "--host=0.0.0.0", - "--port=5001", + "--port=5001" ] }, { "name": "Python: Celery", "type": "debugpy", "request": "launch", - "python": "${workspaceFolder}/api/.venv/bin/python", - "cwd": "${workspaceFolder}/api", + "python": "${workspaceFolder}/.venv/bin/python", + "cwd": "${workspaceFolder}", "module": "celery", "justMyCode": true, "envFile": ".env", diff --git a/api/Dockerfile b/api/Dockerfile index 06a6f43631..ccdec017ab 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -5,6 +5,10 @@ WORKDIR /app/api # Install Poetry ENV POETRY_VERSION=1.8.3 + +# if you located in China, you can use aliyun mirror to speed up +# RUN pip install --no-cache-dir poetry==${POETRY_VERSION} -i https://mirrors.aliyun.com/pypi/simple/ + RUN pip install --no-cache-dir poetry==${POETRY_VERSION} # Configure Poetry @@ -16,6 +20,9 @@ ENV POETRY_REQUESTS_TIMEOUT=15 FROM base AS packages +# if you located in China, you can use aliyun mirror to speed up +# RUN sed -i 's@deb.debian.org@mirrors.aliyun.com@g' /etc/apt/sources.list.d/debian.sources + RUN apt-get update \ && apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev @@ -43,10 +50,12 @@ WORKDIR /app/api RUN apt-get update \ && apt-get install -y --no-install-recommends curl nodejs libgmp-dev libmpfr-dev libmpc-dev \ + # if you located in China, you can use aliyun mirror to speed up + # && echo "deb http://mirrors.aliyun.com/debian testing main" > /etc/apt/sources.list \ && echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \ && apt-get update \ # For Security - && apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.2-1 libldap-2.5-0=2.5.18+dfsg-2 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \ + && apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.2-2 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \ && apt-get autoremove -y \ && rm -rf /var/lib/apt/lists/* @@ -56,7 +65,7 @@ COPY --from=packages ${VIRTUAL_ENV} ${VIRTUAL_ENV} ENV PATH="${VIRTUAL_ENV}/bin:${PATH}" # Download nltk data -RUN python -c "import nltk; nltk.download('punkt')" +RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')" # Copy source code COPY . /app/api/ diff --git a/api/commands.py b/api/commands.py index 41f1a6444c..3bf8bc0ecc 100644 --- a/api/commands.py +++ b/api/commands.py @@ -559,8 +559,9 @@ def add_qdrant_doc_id_index(field: str): @click.command("create-tenant", help="Create account and tenant.") @click.option("--email", prompt=True, help="The email address of the tenant account.") +@click.option("--name", prompt=True, help="The workspace name of the tenant account.") @click.option("--language", prompt=True, help="Account language, default: en-US.") -def create_tenant(email: str, language: Optional[str] = None): +def create_tenant(email: str, language: Optional[str] = None, name: Optional[str] = None): """ Create tenant account """ @@ -580,13 +581,15 @@ def create_tenant(email: str, language: Optional[str] = None): if language not in languages: language = "en-US" + name = name.strip() + # generate random password new_password = secrets.token_urlsafe(16) # register account account = RegisterService.register(email=email, name=account_name, password=new_password, language=language) - TenantService.create_owner_tenant_if_not_exist(account) + TenantService.create_owner_tenant_if_not_exist(account, name) click.echo( click.style( diff --git a/api/configs/__init__.py b/api/configs/__init__.py index c0e28c34e1..3a172601c9 100644 --- a/api/configs/__init__.py +++ b/api/configs/__init__.py @@ -1,3 +1,3 @@ from .app_config import DifyConfig -dify_config = DifyConfig() \ No newline at end of file +dify_config = DifyConfig() diff --git a/api/configs/app_config.py b/api/configs/app_config.py index b277760edd..61de73c868 100644 --- a/api/configs/app_config.py +++ b/api/configs/app_config.py @@ -1,4 +1,3 @@ -from pydantic import Field, computed_field from pydantic_settings import SettingsConfigDict from configs.deploy import DeploymentConfig @@ -24,42 +23,16 @@ class DifyConfig( # **Before using, please contact business@dify.ai by email to inquire about licensing matters.** EnterpriseFeatureConfig, ): - DEBUG: bool = Field(default=False, description='whether to enable debug mode.') - model_config = SettingsConfigDict( # read from dotenv format config file - env_file='.env', - env_file_encoding='utf-8', + env_file=".env", + env_file_encoding="utf-8", frozen=True, # ignore extra attributes - extra='ignore', + extra="ignore", ) - CODE_MAX_NUMBER: int = 9223372036854775807 - CODE_MIN_NUMBER: int = -9223372036854775808 - CODE_MAX_STRING_LENGTH: int = 80000 - CODE_MAX_STRING_ARRAY_LENGTH: int = 30 - CODE_MAX_OBJECT_ARRAY_LENGTH: int = 30 - CODE_MAX_NUMBER_ARRAY_LENGTH: int = 1000 - - HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = 300 - HTTP_REQUEST_MAX_READ_TIMEOUT: int = 600 - HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = 600 - HTTP_REQUEST_NODE_MAX_BINARY_SIZE: int = 1024 * 1024 * 10 - - @computed_field - def HTTP_REQUEST_NODE_READABLE_MAX_BINARY_SIZE(self) -> str: - return f'{self.HTTP_REQUEST_NODE_MAX_BINARY_SIZE / 1024 / 1024:.2f}MB' - - HTTP_REQUEST_NODE_MAX_TEXT_SIZE: int = 1024 * 1024 - - @computed_field - def HTTP_REQUEST_NODE_READABLE_MAX_TEXT_SIZE(self) -> str: - return f'{self.HTTP_REQUEST_NODE_MAX_TEXT_SIZE / 1024 / 1024:.2f}MB' - - SSRF_PROXY_HTTP_URL: str | None = None - SSRF_PROXY_HTTPS_URL: str | None = None - - MODERATION_BUFFER_SIZE: int = Field(default=300, description='The buffer size for moderation.') - - MAX_VARIABLE_SIZE: int = Field(default=5 * 1024, description='The maximum size of a variable. default is 5KB.') + # Before adding any config, + # please consider to arrange it in the proper config group of existed or added + # for better readability and maintainability. + # Thanks for your concentration and consideration. diff --git a/api/configs/deploy/__init__.py b/api/configs/deploy/__init__.py index 219b315784..10271483c4 100644 --- a/api/configs/deploy/__init__.py +++ b/api/configs/deploy/__init__.py @@ -6,22 +6,28 @@ class DeploymentConfig(BaseSettings): """ Deployment configs """ + APPLICATION_NAME: str = Field( - description='application name', - default='langgenius/dify', + description="application name", + default="langgenius/dify", + ) + + DEBUG: bool = Field( + description="whether to enable debug mode.", + default=False, ) TESTING: bool = Field( - description='', + description="", default=False, ) EDITION: str = Field( - description='deployment edition', - default='SELF_HOSTED', + description="deployment edition", + default="SELF_HOSTED", ) DEPLOY_ENV: str = Field( - description='deployment environment, default to PRODUCTION.', - default='PRODUCTION', + description="deployment environment, default to PRODUCTION.", + default="PRODUCTION", ) diff --git a/api/configs/enterprise/__init__.py b/api/configs/enterprise/__init__.py index b5d884e10e..c661593a44 100644 --- a/api/configs/enterprise/__init__.py +++ b/api/configs/enterprise/__init__.py @@ -7,13 +7,14 @@ class EnterpriseFeatureConfig(BaseSettings): Enterprise feature configs. **Before using, please contact business@dify.ai by email to inquire about licensing matters.** """ + ENTERPRISE_ENABLED: bool = Field( - description='whether to enable enterprise features.' - 'Before using, please contact business@dify.ai by email to inquire about licensing matters.', + description="whether to enable enterprise features." + "Before using, please contact business@dify.ai by email to inquire about licensing matters.", default=False, ) CAN_REPLACE_LOGO: bool = Field( - description='whether to allow replacing enterprise logo.', + description="whether to allow replacing enterprise logo.", default=False, ) diff --git a/api/configs/extra/notion_config.py b/api/configs/extra/notion_config.py index b77e8adaae..bd1268fa45 100644 --- a/api/configs/extra/notion_config.py +++ b/api/configs/extra/notion_config.py @@ -8,27 +8,28 @@ class NotionConfig(BaseSettings): """ Notion integration configs """ + NOTION_CLIENT_ID: Optional[str] = Field( - description='Notion client ID', + description="Notion client ID", default=None, ) NOTION_CLIENT_SECRET: Optional[str] = Field( - description='Notion client secret key', + description="Notion client secret key", default=None, ) NOTION_INTEGRATION_TYPE: Optional[str] = Field( - description='Notion integration type, default to None, available values: internal.', + description="Notion integration type, default to None, available values: internal.", default=None, ) NOTION_INTERNAL_SECRET: Optional[str] = Field( - description='Notion internal secret key', + description="Notion internal secret key", default=None, ) NOTION_INTEGRATION_TOKEN: Optional[str] = Field( - description='Notion integration token', + description="Notion integration token", default=None, ) diff --git a/api/configs/extra/sentry_config.py b/api/configs/extra/sentry_config.py index e6517f730a..ea9ea60ffb 100644 --- a/api/configs/extra/sentry_config.py +++ b/api/configs/extra/sentry_config.py @@ -8,17 +8,18 @@ class SentryConfig(BaseSettings): """ Sentry configs """ + SENTRY_DSN: Optional[str] = Field( - description='Sentry DSN', + description="Sentry DSN", default=None, ) SENTRY_TRACES_SAMPLE_RATE: NonNegativeFloat = Field( - description='Sentry trace sample rate', + description="Sentry trace sample rate", default=1.0, ) SENTRY_PROFILES_SAMPLE_RATE: NonNegativeFloat = Field( - description='Sentry profiles sample rate', + description="Sentry profiles sample rate", default=1.0, ) diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 369b25d788..303bce2aa5 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -1,6 +1,6 @@ -from typing import Optional +from typing import Annotated, Optional -from pydantic import AliasChoices, Field, NonNegativeInt, PositiveInt, computed_field +from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field from pydantic_settings import BaseSettings from configs.feature.hosted_service import HostedServiceConfig @@ -10,16 +10,17 @@ class SecurityConfig(BaseSettings): """ Secret Key configs """ + SECRET_KEY: Optional[str] = Field( - description='Your App secret key will be used for securely signing the session cookie' - 'Make sure you are changing this key for your deployment with a strong key.' - 'You can generate a strong key using `openssl rand -base64 42`.' - 'Alternatively you can set it with `SECRET_KEY` environment variable.', + description="Your App secret key will be used for securely signing the session cookie" + "Make sure you are changing this key for your deployment with a strong key." + "You can generate a strong key using `openssl rand -base64 42`." + "Alternatively you can set it with `SECRET_KEY` environment variable.", default=None, ) RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field( - description='Expiry time in hours for reset token', + description="Expiry time in hours for reset token", default=24, ) @@ -28,12 +29,13 @@ class AppExecutionConfig(BaseSettings): """ App Execution configs """ + APP_MAX_EXECUTION_TIME: PositiveInt = Field( - description='execution timeout in seconds for app execution', + description="execution timeout in seconds for app execution", default=1200, ) APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field( - description='max active request per app, 0 means unlimited', + description="max active request per app, 0 means unlimited", default=0, ) @@ -42,14 +44,70 @@ class CodeExecutionSandboxConfig(BaseSettings): """ Code Execution Sandbox configs """ - CODE_EXECUTION_ENDPOINT: str = Field( - description='endpoint URL of code execution servcie', - default='http://sandbox:8194', + + CODE_EXECUTION_ENDPOINT: HttpUrl = Field( + description="endpoint URL of code execution servcie", + default="http://sandbox:8194", ) CODE_EXECUTION_API_KEY: str = Field( - description='API key for code execution service', - default='dify-sandbox', + description="API key for code execution service", + default="dify-sandbox", + ) + + CODE_EXECUTION_CONNECT_TIMEOUT: Optional[float] = Field( + description="connect timeout in seconds for code execution request", + default=10.0, + ) + + CODE_EXECUTION_READ_TIMEOUT: Optional[float] = Field( + description="read timeout in seconds for code execution request", + default=60.0, + ) + + CODE_EXECUTION_WRITE_TIMEOUT: Optional[float] = Field( + description="write timeout in seconds for code execution request", + default=10.0, + ) + + CODE_MAX_NUMBER: PositiveInt = Field( + description="max depth for code execution", + default=9223372036854775807, + ) + + CODE_MIN_NUMBER: NegativeInt = Field( + description="", + default=-9223372036854775807, + ) + + CODE_MAX_DEPTH: PositiveInt = Field( + description="max depth for code execution", + default=5, + ) + + CODE_MAX_PRECISION: PositiveInt = Field( + description="max precision digits for float type in code execution", + default=20, + ) + + CODE_MAX_STRING_LENGTH: PositiveInt = Field( + description="max string length for code execution", + default=80000, + ) + + CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field( + description="", + default=30, + ) + + CODE_MAX_OBJECT_ARRAY_LENGTH: PositiveInt = Field( + description="", + default=30, + ) + + CODE_MAX_NUMBER_ARRAY_LENGTH: PositiveInt = Field( + description="", + default=1000, ) @@ -57,28 +115,27 @@ class EndpointConfig(BaseSettings): """ Module URL configs """ + CONSOLE_API_URL: str = Field( - description='The backend URL prefix of the console API.' - 'used to concatenate the login authorization callback or notion integration callback.', - default='', + description="The backend URL prefix of the console API." + "used to concatenate the login authorization callback or notion integration callback.", + default="", ) CONSOLE_WEB_URL: str = Field( - description='The front-end URL prefix of the console web.' - 'used to concatenate some front-end addresses and for CORS configuration use.', - default='', + description="The front-end URL prefix of the console web." + "used to concatenate some front-end addresses and for CORS configuration use.", + default="", ) SERVICE_API_URL: str = Field( - description='Service API Url prefix.' - 'used to display Service API Base Url to the front-end.', - default='', + description="Service API Url prefix." "used to display Service API Base Url to the front-end.", + default="", ) APP_WEB_URL: str = Field( - description='WebApp Url prefix.' - 'used to display WebAPP API Base Url to the front-end.', - default='', + description="WebApp Url prefix." "used to display WebAPP API Base Url to the front-end.", + default="", ) @@ -86,17 +143,18 @@ class FileAccessConfig(BaseSettings): """ File Access configs """ + FILES_URL: str = Field( - description='File preview or download Url prefix.' - ' used to display File preview or download Url to the front-end or as Multi-model inputs;' - 'Url is signed and has expiration time.', - validation_alias=AliasChoices('FILES_URL', 'CONSOLE_API_URL'), + description="File preview or download Url prefix." + " used to display File preview or download Url to the front-end or as Multi-model inputs;" + "Url is signed and has expiration time.", + validation_alias=AliasChoices("FILES_URL", "CONSOLE_API_URL"), alias_priority=1, - default='', + default="", ) FILES_ACCESS_TIMEOUT: int = Field( - description='timeout in seconds for file accessing', + description="timeout in seconds for file accessing", default=300, ) @@ -105,23 +163,24 @@ class FileUploadConfig(BaseSettings): """ File Uploading configs """ + UPLOAD_FILE_SIZE_LIMIT: NonNegativeInt = Field( - description='size limit in Megabytes for uploading files', + description="size limit in Megabytes for uploading files", default=15, ) UPLOAD_FILE_BATCH_LIMIT: NonNegativeInt = Field( - description='batch size limit for uploading files', + description="batch size limit for uploading files", default=5, ) UPLOAD_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field( - description='image file size limit in Megabytes for uploading files', + description="image file size limit in Megabytes for uploading files", default=10, ) BATCH_UPLOAD_LIMIT: NonNegativeInt = Field( - description='', # todo: to be clarified + description="", # todo: to be clarified default=20, ) @@ -130,45 +189,79 @@ class HttpConfig(BaseSettings): """ HTTP configs """ + API_COMPRESSION_ENABLED: bool = Field( - description='whether to enable HTTP response compression of gzip', + description="whether to enable HTTP response compression of gzip", default=False, ) inner_CONSOLE_CORS_ALLOW_ORIGINS: str = Field( - description='', - validation_alias=AliasChoices('CONSOLE_CORS_ALLOW_ORIGINS', 'CONSOLE_WEB_URL'), - default='', + description="", + validation_alias=AliasChoices("CONSOLE_CORS_ALLOW_ORIGINS", "CONSOLE_WEB_URL"), + default="", ) @computed_field @property def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]: - return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(',') + return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(",") inner_WEB_API_CORS_ALLOW_ORIGINS: str = Field( - description='', - validation_alias=AliasChoices('WEB_API_CORS_ALLOW_ORIGINS'), - default='*', + description="", + validation_alias=AliasChoices("WEB_API_CORS_ALLOW_ORIGINS"), + default="*", ) @computed_field @property def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]: - return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(',') + return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",") + + HTTP_REQUEST_MAX_CONNECT_TIMEOUT: Annotated[ + PositiveInt, Field(ge=10, description="connect timeout in seconds for HTTP request") + ] = 10 + + HTTP_REQUEST_MAX_READ_TIMEOUT: Annotated[ + PositiveInt, Field(ge=60, description="read timeout in seconds for HTTP request") + ] = 60 + + HTTP_REQUEST_MAX_WRITE_TIMEOUT: Annotated[ + PositiveInt, Field(ge=10, description="read timeout in seconds for HTTP request") + ] = 20 + + HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field( + description="", + default=10 * 1024 * 1024, + ) + + HTTP_REQUEST_NODE_MAX_TEXT_SIZE: PositiveInt = Field( + description="", + default=1 * 1024 * 1024, + ) + + SSRF_PROXY_HTTP_URL: Optional[str] = Field( + description="HTTP URL for SSRF proxy", + default=None, + ) + + SSRF_PROXY_HTTPS_URL: Optional[str] = Field( + description="HTTPS URL for SSRF proxy", + default=None, + ) class InnerAPIConfig(BaseSettings): """ Inner API configs """ + INNER_API: bool = Field( - description='whether to enable the inner API', + description="whether to enable the inner API", default=False, ) INNER_API_KEY: Optional[str] = Field( - description='The inner API key is used to authenticate the inner API', + description="The inner API key is used to authenticate the inner API", default=None, ) @@ -179,28 +272,27 @@ class LoggingConfig(BaseSettings): """ LOG_LEVEL: str = Field( - description='Log output level, default to INFO.' - 'It is recommended to set it to ERROR for production.', - default='INFO', + description="Log output level, default to INFO." "It is recommended to set it to ERROR for production.", + default="INFO", ) LOG_FILE: Optional[str] = Field( - description='logging output file path', + description="logging output file path", default=None, ) LOG_FORMAT: str = Field( - description='log format', - default='%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s', + description="log format", + default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s", ) LOG_DATEFORMAT: Optional[str] = Field( - description='log date format', + description="log date format", default=None, ) LOG_TZ: Optional[str] = Field( - description='specify log timezone, eg: America/New_York', + description="specify log timezone, eg: America/New_York", default=None, ) @@ -209,8 +301,9 @@ class ModelLoadBalanceConfig(BaseSettings): """ Model load balance configs """ + MODEL_LB_ENABLED: bool = Field( - description='whether to enable model load balancing', + description="whether to enable model load balancing", default=False, ) @@ -219,8 +312,9 @@ class BillingConfig(BaseSettings): """ Platform Billing Configurations """ + BILLING_ENABLED: bool = Field( - description='whether to enable billing', + description="whether to enable billing", default=False, ) @@ -229,9 +323,10 @@ class UpdateConfig(BaseSettings): """ Update configs """ + CHECK_UPDATE_URL: str = Field( - description='url for checking updates', - default='https://updates.dify.ai', + description="url for checking updates", + default="https://updates.dify.ai", ) @@ -241,47 +336,53 @@ class WorkflowConfig(BaseSettings): """ WORKFLOW_MAX_EXECUTION_STEPS: PositiveInt = Field( - description='max execution steps in single workflow execution', + description="max execution steps in single workflow execution", default=500, ) WORKFLOW_MAX_EXECUTION_TIME: PositiveInt = Field( - description='max execution time in seconds in single workflow execution', + description="max execution time in seconds in single workflow execution", default=1200, ) WORKFLOW_CALL_MAX_DEPTH: PositiveInt = Field( - description='max depth of calling in single workflow execution', + description="max depth of calling in single workflow execution", default=5, ) + MAX_VARIABLE_SIZE: PositiveInt = Field( + description="The maximum size in bytes of a variable. default to 5KB.", + default=5 * 1024, + ) + class OAuthConfig(BaseSettings): """ oauth configs """ + OAUTH_REDIRECT_PATH: str = Field( - description='redirect path for OAuth', - default='/console/api/oauth/authorize', + description="redirect path for OAuth", + default="/console/api/oauth/authorize", ) GITHUB_CLIENT_ID: Optional[str] = Field( - description='GitHub client id for OAuth', + description="GitHub client id for OAuth", default=None, ) GITHUB_CLIENT_SECRET: Optional[str] = Field( - description='GitHub client secret key for OAuth', + description="GitHub client secret key for OAuth", default=None, ) GOOGLE_CLIENT_ID: Optional[str] = Field( - description='Google client id for OAuth', + description="Google client id for OAuth", default=None, ) GOOGLE_CLIENT_SECRET: Optional[str] = Field( - description='Google client secret key for OAuth', + description="Google client secret key for OAuth", default=None, ) @@ -291,9 +392,8 @@ class ModerationConfig(BaseSettings): Moderation in app configs. """ - # todo: to be clarified in usage and unit - OUTPUT_MODERATION_BUFFER_SIZE: PositiveInt = Field( - description='buffer size for moderation', + MODERATION_BUFFER_SIZE: PositiveInt = Field( + description="buffer size for moderation", default=300, ) @@ -304,7 +404,7 @@ class ToolConfig(BaseSettings): """ TOOL_ICON_CACHE_MAX_AGE: PositiveInt = Field( - description='max age in seconds for tool icon caching', + description="max age in seconds for tool icon caching", default=3600, ) @@ -315,52 +415,52 @@ class MailConfig(BaseSettings): """ MAIL_TYPE: Optional[str] = Field( - description='Mail provider type name, default to None, availabile values are `smtp` and `resend`.', + description="Mail provider type name, default to None, availabile values are `smtp` and `resend`.", default=None, ) MAIL_DEFAULT_SEND_FROM: Optional[str] = Field( - description='default email address for sending from ', + description="default email address for sending from ", default=None, ) RESEND_API_KEY: Optional[str] = Field( - description='API key for Resend', + description="API key for Resend", default=None, ) RESEND_API_URL: Optional[str] = Field( - description='API URL for Resend', + description="API URL for Resend", default=None, ) SMTP_SERVER: Optional[str] = Field( - description='smtp server host', + description="smtp server host", default=None, ) SMTP_PORT: Optional[int] = Field( - description='smtp server port', + description="smtp server port", default=465, ) SMTP_USERNAME: Optional[str] = Field( - description='smtp server username', + description="smtp server username", default=None, ) SMTP_PASSWORD: Optional[str] = Field( - description='smtp server password', + description="smtp server password", default=None, ) SMTP_USE_TLS: bool = Field( - description='whether to use TLS connection to smtp server', + description="whether to use TLS connection to smtp server", default=False, ) SMTP_OPPORTUNISTIC_TLS: bool = Field( - description='whether to use opportunistic TLS connection to smtp server', + description="whether to use opportunistic TLS connection to smtp server", default=False, ) @@ -371,22 +471,22 @@ class RagEtlConfig(BaseSettings): """ ETL_TYPE: str = Field( - description='RAG ETL type name, default to `dify`, available values are `dify` and `Unstructured`. ', - default='dify', + description="RAG ETL type name, default to `dify`, available values are `dify` and `Unstructured`. ", + default="dify", ) KEYWORD_DATA_SOURCE_TYPE: str = Field( - description='source type for keyword data, default to `database`, available values are `database` .', - default='database', + description="source type for keyword data, default to `database`, available values are `database` .", + default="database", ) UNSTRUCTURED_API_URL: Optional[str] = Field( - description='API URL for Unstructured', + description="API URL for Unstructured", default=None, ) UNSTRUCTURED_API_KEY: Optional[str] = Field( - description='API key for Unstructured', + description="API key for Unstructured", default=None, ) @@ -397,22 +497,23 @@ class DataSetConfig(BaseSettings): """ CLEAN_DAY_SETTING: PositiveInt = Field( - description='interval in days for cleaning up dataset', + description="interval in days for cleaning up dataset", default=30, ) DATASET_OPERATOR_ENABLED: bool = Field( - description='whether to enable dataset operator', + description="whether to enable dataset operator", default=False, ) + class WorkspaceConfig(BaseSettings): """ Workspace configs """ INVITE_EXPIRY_HOURS: PositiveInt = Field( - description='workspaces invitation expiration in hours', + description="workspaces invitation expiration in hours", default=72, ) @@ -423,25 +524,81 @@ class IndexingConfig(BaseSettings): """ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: PositiveInt = Field( - description='max segmentation token length for indexing', + description="max segmentation token length for indexing", default=1000, ) class ImageFormatConfig(BaseSettings): MULTIMODAL_SEND_IMAGE_FORMAT: str = Field( - description='multi model send image format, support base64, url, default is base64', - default='base64', + description="multi model send image format, support base64, url, default is base64", + default="base64", ) class CeleryBeatConfig(BaseSettings): CELERY_BEAT_SCHEDULER_TIME: int = Field( - description='the time of the celery scheduler, default to 1 day', + description="the time of the celery scheduler, default to 1 day", default=1, ) +class PositionConfig(BaseSettings): + POSITION_PROVIDER_PINS: str = Field( + description="The heads of model providers", + default="", + ) + + POSITION_PROVIDER_INCLUDES: str = Field( + description="The included model providers", + default="", + ) + + POSITION_PROVIDER_EXCLUDES: str = Field( + description="The excluded model providers", + default="", + ) + + POSITION_TOOL_PINS: str = Field( + description="The heads of tools", + default="", + ) + + POSITION_TOOL_INCLUDES: str = Field( + description="The included tools", + default="", + ) + + POSITION_TOOL_EXCLUDES: str = Field( + description="The excluded tools", + default="", + ) + + @computed_field + def POSITION_PROVIDER_PINS_LIST(self) -> list[str]: + return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(",") if item.strip() != ""] + + @computed_field + def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]: + return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(",") if item.strip() != ""} + + @computed_field + def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]: + return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(",") if item.strip() != ""} + + @computed_field + def POSITION_TOOL_PINS_LIST(self) -> list[str]: + return [item.strip() for item in self.POSITION_TOOL_PINS.split(",") if item.strip() != ""] + + @computed_field + def POSITION_TOOL_INCLUDES_SET(self) -> set[str]: + return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(",") if item.strip() != ""} + + @computed_field + def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]: + return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""} + + class FeatureConfig( # place the configs in alphabet order AppExecutionConfig, @@ -466,7 +623,7 @@ class FeatureConfig( UpdateConfig, WorkflowConfig, WorkspaceConfig, - + PositionConfig, # hosted services config HostedServiceConfig, CeleryBeatConfig, diff --git a/api/configs/feature/hosted_service/__init__.py b/api/configs/feature/hosted_service/__init__.py index 88fe188587..f269d0ab9c 100644 --- a/api/configs/feature/hosted_service/__init__.py +++ b/api/configs/feature/hosted_service/__init__.py @@ -10,62 +10,62 @@ class HostedOpenAiConfig(BaseSettings): """ HOSTED_OPENAI_API_KEY: Optional[str] = Field( - description='', + description="", default=None, ) HOSTED_OPENAI_API_BASE: Optional[str] = Field( - description='', + description="", default=None, ) HOSTED_OPENAI_API_ORGANIZATION: Optional[str] = Field( - description='', + description="", default=None, ) HOSTED_OPENAI_TRIAL_ENABLED: bool = Field( - description='', + description="", default=False, ) HOSTED_OPENAI_TRIAL_MODELS: str = Field( - description='', - default='gpt-3.5-turbo,' - 'gpt-3.5-turbo-1106,' - 'gpt-3.5-turbo-instruct,' - 'gpt-3.5-turbo-16k,' - 'gpt-3.5-turbo-16k-0613,' - 'gpt-3.5-turbo-0613,' - 'gpt-3.5-turbo-0125,' - 'text-davinci-003', + description="", + default="gpt-3.5-turbo," + "gpt-3.5-turbo-1106," + "gpt-3.5-turbo-instruct," + "gpt-3.5-turbo-16k," + "gpt-3.5-turbo-16k-0613," + "gpt-3.5-turbo-0613," + "gpt-3.5-turbo-0125," + "text-davinci-003", ) HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field( - description='', + description="", default=200, ) HOSTED_OPENAI_PAID_ENABLED: bool = Field( - description='', + description="", default=False, ) HOSTED_OPENAI_PAID_MODELS: str = Field( - description='', - default='gpt-4,' - 'gpt-4-turbo-preview,' - 'gpt-4-turbo-2024-04-09,' - 'gpt-4-1106-preview,' - 'gpt-4-0125-preview,' - 'gpt-3.5-turbo,' - 'gpt-3.5-turbo-16k,' - 'gpt-3.5-turbo-16k-0613,' - 'gpt-3.5-turbo-1106,' - 'gpt-3.5-turbo-0613,' - 'gpt-3.5-turbo-0125,' - 'gpt-3.5-turbo-instruct,' - 'text-davinci-003', + description="", + default="gpt-4," + "gpt-4-turbo-preview," + "gpt-4-turbo-2024-04-09," + "gpt-4-1106-preview," + "gpt-4-0125-preview," + "gpt-3.5-turbo," + "gpt-3.5-turbo-16k," + "gpt-3.5-turbo-16k-0613," + "gpt-3.5-turbo-1106," + "gpt-3.5-turbo-0613," + "gpt-3.5-turbo-0125," + "gpt-3.5-turbo-instruct," + "text-davinci-003", ) @@ -75,22 +75,22 @@ class HostedAzureOpenAiConfig(BaseSettings): """ HOSTED_AZURE_OPENAI_ENABLED: bool = Field( - description='', + description="", default=False, ) HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field( - description='', + description="", default=None, ) HOSTED_AZURE_OPENAI_API_BASE: Optional[str] = Field( - description='', + description="", default=None, ) HOSTED_AZURE_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field( - description='', + description="", default=200, ) @@ -101,27 +101,27 @@ class HostedAnthropicConfig(BaseSettings): """ HOSTED_ANTHROPIC_API_BASE: Optional[str] = Field( - description='', + description="", default=None, ) HOSTED_ANTHROPIC_API_KEY: Optional[str] = Field( - description='', + description="", default=None, ) HOSTED_ANTHROPIC_TRIAL_ENABLED: bool = Field( - description='', + description="", default=False, ) HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field( - description='', + description="", default=600000, ) HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field( - description='', + description="", default=False, ) @@ -132,7 +132,7 @@ class HostedMinmaxConfig(BaseSettings): """ HOSTED_MINIMAX_ENABLED: bool = Field( - description='', + description="", default=False, ) @@ -143,7 +143,7 @@ class HostedSparkConfig(BaseSettings): """ HOSTED_SPARK_ENABLED: bool = Field( - description='', + description="", default=False, ) @@ -154,7 +154,7 @@ class HostedZhipuAIConfig(BaseSettings): """ HOSTED_ZHIPUAI_ENABLED: bool = Field( - description='', + description="", default=False, ) @@ -165,13 +165,13 @@ class HostedModerationConfig(BaseSettings): """ HOSTED_MODERATION_ENABLED: bool = Field( - description='', + description="", default=False, ) HOSTED_MODERATION_PROVIDERS: str = Field( - description='', - default='', + description="", + default="", ) @@ -181,15 +181,15 @@ class HostedFetchAppTemplateConfig(BaseSettings): """ HOSTED_FETCH_APP_TEMPLATES_MODE: str = Field( - description='the mode for fetching app templates,' - ' default to remote,' - ' available values: remote, db, builtin', - default='remote', + description="the mode for fetching app templates," + " default to remote," + " available values: remote, db, builtin", + default="remote", ) HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN: str = Field( - description='the domain for fetching remote app templates', - default='https://tmpl.dify.ai', + description="the domain for fetching remote app templates", + default="https://tmpl.dify.ai", ) @@ -202,7 +202,6 @@ class HostedServiceConfig( HostedOpenAiConfig, HostedSparkConfig, HostedZhipuAIConfig, - # moderation HostedModerationConfig, ): diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 07688e9aeb..f25979e5d8 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -13,6 +13,7 @@ from configs.middleware.storage.oci_storage_config import OCIStorageConfig from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig from configs.middleware.vdb.chroma_config import ChromaConfig +from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig from configs.middleware.vdb.milvus_config import MilvusConfig from configs.middleware.vdb.myscale_config import MyScaleConfig from configs.middleware.vdb.opensearch_config import OpenSearchConfig @@ -28,108 +29,108 @@ from configs.middleware.vdb.weaviate_config import WeaviateConfig class StorageConfig(BaseSettings): STORAGE_TYPE: str = Field( - description='storage type,' - ' default to `local`,' - ' available values are `local`, `s3`, `azure-blob`, `aliyun-oss`, `google-storage`.', - default='local', + description="storage type," + " default to `local`," + " available values are `local`, `s3`, `azure-blob`, `aliyun-oss`, `google-storage`.", + default="local", ) STORAGE_LOCAL_PATH: str = Field( - description='local storage path', - default='storage', + description="local storage path", + default="storage", ) class VectorStoreConfig(BaseSettings): VECTOR_STORE: Optional[str] = Field( - description='vector store type', + description="vector store type", default=None, ) class KeywordStoreConfig(BaseSettings): KEYWORD_STORE: str = Field( - description='keyword store type', - default='jieba', + description="keyword store type", + default="jieba", ) class DatabaseConfig: DB_HOST: str = Field( - description='db host', - default='localhost', + description="db host", + default="localhost", ) DB_PORT: PositiveInt = Field( - description='db port', + description="db port", default=5432, ) DB_USERNAME: str = Field( - description='db username', - default='postgres', + description="db username", + default="postgres", ) DB_PASSWORD: str = Field( - description='db password', - default='', + description="db password", + default="", ) DB_DATABASE: str = Field( - description='db database', - default='dify', + description="db database", + default="dify", ) DB_CHARSET: str = Field( - description='db charset', - default='', + description="db charset", + default="", ) DB_EXTRAS: str = Field( - description='db extras options. Example: keepalives_idle=60&keepalives=1', - default='', + description="db extras options. Example: keepalives_idle=60&keepalives=1", + default="", ) SQLALCHEMY_DATABASE_URI_SCHEME: str = Field( - description='db uri scheme', - default='postgresql', + description="db uri scheme", + default="postgresql", ) @computed_field @property def SQLALCHEMY_DATABASE_URI(self) -> str: db_extras = ( - f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}" - if self.DB_CHARSET - else self.DB_EXTRAS + f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}" if self.DB_CHARSET else self.DB_EXTRAS ).strip("&") db_extras = f"?{db_extras}" if db_extras else "" - return (f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://" - f"{quote_plus(self.DB_USERNAME)}:{quote_plus(self.DB_PASSWORD)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}" - f"{db_extras}") + return ( + f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://" + f"{quote_plus(self.DB_USERNAME)}:{quote_plus(self.DB_PASSWORD)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}" + f"{db_extras}" + ) SQLALCHEMY_POOL_SIZE: NonNegativeInt = Field( - description='pool size of SqlAlchemy', + description="pool size of SqlAlchemy", default=30, ) SQLALCHEMY_MAX_OVERFLOW: NonNegativeInt = Field( - description='max overflows for SqlAlchemy', + description="max overflows for SqlAlchemy", default=10, ) SQLALCHEMY_POOL_RECYCLE: NonNegativeInt = Field( - description='SqlAlchemy pool recycle', + description="SqlAlchemy pool recycle", default=3600, ) SQLALCHEMY_POOL_PRE_PING: bool = Field( - description='whether to enable pool pre-ping in SqlAlchemy', + description="whether to enable pool pre-ping in SqlAlchemy", default=False, ) SQLALCHEMY_ECHO: bool | str = Field( - description='whether to enable SqlAlchemy echo', + description="whether to enable SqlAlchemy echo", default=False, ) @@ -137,35 +138,38 @@ class DatabaseConfig: @property def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]: return { - 'pool_size': self.SQLALCHEMY_POOL_SIZE, - 'max_overflow': self.SQLALCHEMY_MAX_OVERFLOW, - 'pool_recycle': self.SQLALCHEMY_POOL_RECYCLE, - 'pool_pre_ping': self.SQLALCHEMY_POOL_PRE_PING, - 'connect_args': {'options': '-c timezone=UTC'}, + "pool_size": self.SQLALCHEMY_POOL_SIZE, + "max_overflow": self.SQLALCHEMY_MAX_OVERFLOW, + "pool_recycle": self.SQLALCHEMY_POOL_RECYCLE, + "pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING, + "connect_args": {"options": "-c timezone=UTC"}, } class CeleryConfig(DatabaseConfig): CELERY_BACKEND: str = Field( - description='Celery backend, available values are `database`, `redis`', - default='database', + description="Celery backend, available values are `database`, `redis`", + default="database", ) CELERY_BROKER_URL: Optional[str] = Field( - description='CELERY_BROKER_URL', + description="CELERY_BROKER_URL", default=None, ) @computed_field @property def CELERY_RESULT_BACKEND(self) -> str | None: - return 'db+{}'.format(self.SQLALCHEMY_DATABASE_URI) \ - if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL + return ( + "db+{}".format(self.SQLALCHEMY_DATABASE_URI) + if self.CELERY_BACKEND == "database" + else self.CELERY_BROKER_URL + ) @computed_field @property def BROKER_USE_SSL(self) -> bool: - return self.CELERY_BROKER_URL.startswith('rediss://') if self.CELERY_BROKER_URL else False + return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False class MiddlewareConfig( @@ -174,7 +178,6 @@ class MiddlewareConfig( DatabaseConfig, KeywordStoreConfig, RedisConfig, - # configs of storage and storage providers StorageConfig, AliyunOSSStorageConfig, @@ -183,7 +186,6 @@ class MiddlewareConfig( TencentCloudCOSStorageConfig, S3StorageConfig, OCIStorageConfig, - # configs of vdb and vdb providers VectorStoreConfig, AnalyticdbConfig, @@ -199,5 +201,6 @@ class MiddlewareConfig( TencentVectorDBConfig, TiDBVectorConfig, WeaviateConfig, + ElasticsearchConfig, ): pass diff --git a/api/configs/middleware/cache/redis_config.py b/api/configs/middleware/cache/redis_config.py index 436ba5d4c0..cacdaf6fb6 100644 --- a/api/configs/middleware/cache/redis_config.py +++ b/api/configs/middleware/cache/redis_config.py @@ -8,32 +8,33 @@ class RedisConfig(BaseSettings): """ Redis configs """ + REDIS_HOST: str = Field( - description='Redis host', - default='localhost', + description="Redis host", + default="localhost", ) REDIS_PORT: PositiveInt = Field( - description='Redis port', + description="Redis port", default=6379, ) REDIS_USERNAME: Optional[str] = Field( - description='Redis username', + description="Redis username", default=None, ) REDIS_PASSWORD: Optional[str] = Field( - description='Redis password', + description="Redis password", default=None, ) REDIS_DB: NonNegativeInt = Field( - description='Redis database id, default to 0', + description="Redis database id, default to 0", default=0, ) REDIS_USE_SSL: bool = Field( - description='whether to use SSL for Redis connection', + description="whether to use SSL for Redis connection", default=False, ) diff --git a/api/configs/middleware/storage/aliyun_oss_storage_config.py b/api/configs/middleware/storage/aliyun_oss_storage_config.py index 19e6cafb12..c1843dc26c 100644 --- a/api/configs/middleware/storage/aliyun_oss_storage_config.py +++ b/api/configs/middleware/storage/aliyun_oss_storage_config.py @@ -10,31 +10,36 @@ class AliyunOSSStorageConfig(BaseSettings): """ ALIYUN_OSS_BUCKET_NAME: Optional[str] = Field( - description='Aliyun OSS bucket name', + description="Aliyun OSS bucket name", default=None, ) ALIYUN_OSS_ACCESS_KEY: Optional[str] = Field( - description='Aliyun OSS access key', + description="Aliyun OSS access key", default=None, ) ALIYUN_OSS_SECRET_KEY: Optional[str] = Field( - description='Aliyun OSS secret key', + description="Aliyun OSS secret key", default=None, ) ALIYUN_OSS_ENDPOINT: Optional[str] = Field( - description='Aliyun OSS endpoint URL', + description="Aliyun OSS endpoint URL", default=None, ) ALIYUN_OSS_REGION: Optional[str] = Field( - description='Aliyun OSS region', + description="Aliyun OSS region", default=None, ) ALIYUN_OSS_AUTH_VERSION: Optional[str] = Field( - description='Aliyun OSS authentication version', + description="Aliyun OSS authentication version", + default=None, + ) + + ALIYUN_OSS_PATH: Optional[str] = Field( + description="Aliyun OSS path", default=None, ) diff --git a/api/configs/middleware/storage/amazon_s3_storage_config.py b/api/configs/middleware/storage/amazon_s3_storage_config.py index 2566fbd5da..bef9326108 100644 --- a/api/configs/middleware/storage/amazon_s3_storage_config.py +++ b/api/configs/middleware/storage/amazon_s3_storage_config.py @@ -10,36 +10,36 @@ class S3StorageConfig(BaseSettings): """ S3_ENDPOINT: Optional[str] = Field( - description='S3 storage endpoint', + description="S3 storage endpoint", default=None, ) S3_REGION: Optional[str] = Field( - description='S3 storage region', + description="S3 storage region", default=None, ) S3_BUCKET_NAME: Optional[str] = Field( - description='S3 storage bucket name', + description="S3 storage bucket name", default=None, ) S3_ACCESS_KEY: Optional[str] = Field( - description='S3 storage access key', + description="S3 storage access key", default=None, ) S3_SECRET_KEY: Optional[str] = Field( - description='S3 storage secret key', + description="S3 storage secret key", default=None, ) S3_ADDRESS_STYLE: str = Field( - description='S3 storage address style', - default='auto', + description="S3 storage address style", + default="auto", ) S3_USE_AWS_MANAGED_IAM: bool = Field( - description='whether to use aws managed IAM for S3', + description="whether to use aws managed IAM for S3", default=False, ) diff --git a/api/configs/middleware/storage/azure_blob_storage_config.py b/api/configs/middleware/storage/azure_blob_storage_config.py index 26e441c89b..10944b58ed 100644 --- a/api/configs/middleware/storage/azure_blob_storage_config.py +++ b/api/configs/middleware/storage/azure_blob_storage_config.py @@ -10,21 +10,21 @@ class AzureBlobStorageConfig(BaseSettings): """ AZURE_BLOB_ACCOUNT_NAME: Optional[str] = Field( - description='Azure Blob account name', + description="Azure Blob account name", default=None, ) AZURE_BLOB_ACCOUNT_KEY: Optional[str] = Field( - description='Azure Blob account key', + description="Azure Blob account key", default=None, ) AZURE_BLOB_CONTAINER_NAME: Optional[str] = Field( - description='Azure Blob container name', + description="Azure Blob container name", default=None, ) AZURE_BLOB_ACCOUNT_URL: Optional[str] = Field( - description='Azure Blob account URL', + description="Azure Blob account URL", default=None, ) diff --git a/api/configs/middleware/storage/google_cloud_storage_config.py b/api/configs/middleware/storage/google_cloud_storage_config.py index e1b0e34e0c..10a2d97e8d 100644 --- a/api/configs/middleware/storage/google_cloud_storage_config.py +++ b/api/configs/middleware/storage/google_cloud_storage_config.py @@ -10,11 +10,11 @@ class GoogleCloudStorageConfig(BaseSettings): """ GOOGLE_STORAGE_BUCKET_NAME: Optional[str] = Field( - description='Google Cloud storage bucket name', + description="Google Cloud storage bucket name", default=None, ) GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: Optional[str] = Field( - description='Google Cloud storage service account json base64', + description="Google Cloud storage service account json base64", default=None, ) diff --git a/api/configs/middleware/storage/oci_storage_config.py b/api/configs/middleware/storage/oci_storage_config.py index 6c0c067469..f8993496c9 100644 --- a/api/configs/middleware/storage/oci_storage_config.py +++ b/api/configs/middleware/storage/oci_storage_config.py @@ -10,27 +10,26 @@ class OCIStorageConfig(BaseSettings): """ OCI_ENDPOINT: Optional[str] = Field( - description='OCI storage endpoint', + description="OCI storage endpoint", default=None, ) OCI_REGION: Optional[str] = Field( - description='OCI storage region', + description="OCI storage region", default=None, ) OCI_BUCKET_NAME: Optional[str] = Field( - description='OCI storage bucket name', + description="OCI storage bucket name", default=None, ) OCI_ACCESS_KEY: Optional[str] = Field( - description='OCI storage access key', + description="OCI storage access key", default=None, ) OCI_SECRET_KEY: Optional[str] = Field( - description='OCI storage secret key', + description="OCI storage secret key", default=None, ) - diff --git a/api/configs/middleware/storage/tencent_cos_storage_config.py b/api/configs/middleware/storage/tencent_cos_storage_config.py index 1060c7b93e..765ac08f3e 100644 --- a/api/configs/middleware/storage/tencent_cos_storage_config.py +++ b/api/configs/middleware/storage/tencent_cos_storage_config.py @@ -10,26 +10,26 @@ class TencentCloudCOSStorageConfig(BaseSettings): """ TENCENT_COS_BUCKET_NAME: Optional[str] = Field( - description='Tencent Cloud COS bucket name', + description="Tencent Cloud COS bucket name", default=None, ) TENCENT_COS_REGION: Optional[str] = Field( - description='Tencent Cloud COS region', + description="Tencent Cloud COS region", default=None, ) TENCENT_COS_SECRET_ID: Optional[str] = Field( - description='Tencent Cloud COS secret id', + description="Tencent Cloud COS secret id", default=None, ) TENCENT_COS_SECRET_KEY: Optional[str] = Field( - description='Tencent Cloud COS secret key', + description="Tencent Cloud COS secret key", default=None, ) TENCENT_COS_SCHEME: Optional[str] = Field( - description='Tencent Cloud COS scheme', + description="Tencent Cloud COS scheme", default=None, ) diff --git a/api/configs/middleware/vdb/analyticdb_config.py b/api/configs/middleware/vdb/analyticdb_config.py index db2899265e..04f5b0e5bf 100644 --- a/api/configs/middleware/vdb/analyticdb_config.py +++ b/api/configs/middleware/vdb/analyticdb_config.py @@ -10,35 +10,28 @@ class AnalyticdbConfig(BaseModel): https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled """ - ANALYTICDB_KEY_ID : Optional[str] = Field( - default=None, - description="The Access Key ID provided by Alibaba Cloud for authentication." + ANALYTICDB_KEY_ID: Optional[str] = Field( + default=None, description="The Access Key ID provided by Alibaba Cloud for authentication." ) - ANALYTICDB_KEY_SECRET : Optional[str] = Field( - default=None, - description="The Secret Access Key corresponding to the Access Key ID for secure access." + ANALYTICDB_KEY_SECRET: Optional[str] = Field( + default=None, description="The Secret Access Key corresponding to the Access Key ID for secure access." ) - ANALYTICDB_REGION_ID : Optional[str] = Field( - default=None, - description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')." + ANALYTICDB_REGION_ID: Optional[str] = Field( + default=None, description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')." ) - ANALYTICDB_INSTANCE_ID : Optional[str] = Field( + ANALYTICDB_INSTANCE_ID: Optional[str] = Field( default=None, - description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456').." + description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456')..", ) - ANALYTICDB_ACCOUNT : Optional[str] = Field( - default=None, - description="The account name used to log in to the AnalyticDB instance." + ANALYTICDB_ACCOUNT: Optional[str] = Field( + default=None, description="The account name used to log in to the AnalyticDB instance." ) - ANALYTICDB_PASSWORD : Optional[str] = Field( - default=None, - description="The password associated with the AnalyticDB account for authentication." + ANALYTICDB_PASSWORD: Optional[str] = Field( + default=None, description="The password associated with the AnalyticDB account for authentication." ) - ANALYTICDB_NAMESPACE : Optional[str] = Field( - default=None, - description="The namespace within AnalyticDB for schema isolation." + ANALYTICDB_NAMESPACE: Optional[str] = Field( + default=None, description="The namespace within AnalyticDB for schema isolation." ) - ANALYTICDB_NAMESPACE_PASSWORD : Optional[str] = Field( - default=None, - description="The password for accessing the specified namespace within the AnalyticDB instance." + ANALYTICDB_NAMESPACE_PASSWORD: Optional[str] = Field( + default=None, description="The password for accessing the specified namespace within the AnalyticDB instance." ) diff --git a/api/configs/middleware/vdb/chroma_config.py b/api/configs/middleware/vdb/chroma_config.py index f365879efb..d386623a56 100644 --- a/api/configs/middleware/vdb/chroma_config.py +++ b/api/configs/middleware/vdb/chroma_config.py @@ -10,31 +10,31 @@ class ChromaConfig(BaseSettings): """ CHROMA_HOST: Optional[str] = Field( - description='Chroma host', + description="Chroma host", default=None, ) CHROMA_PORT: PositiveInt = Field( - description='Chroma port', + description="Chroma port", default=8000, ) CHROMA_TENANT: Optional[str] = Field( - description='Chroma database', + description="Chroma database", default=None, ) CHROMA_DATABASE: Optional[str] = Field( - description='Chroma database', + description="Chroma database", default=None, ) CHROMA_AUTH_PROVIDER: Optional[str] = Field( - description='Chroma authentication provider', + description="Chroma authentication provider", default=None, ) CHROMA_AUTH_CREDENTIALS: Optional[str] = Field( - description='Chroma authentication credentials', + description="Chroma authentication credentials", default=None, ) diff --git a/api/configs/middleware/vdb/elasticsearch_config.py b/api/configs/middleware/vdb/elasticsearch_config.py new file mode 100644 index 0000000000..5b6a8fd939 --- /dev/null +++ b/api/configs/middleware/vdb/elasticsearch_config.py @@ -0,0 +1,30 @@ +from typing import Optional + +from pydantic import Field, PositiveInt +from pydantic_settings import BaseSettings + + +class ElasticsearchConfig(BaseSettings): + """ + Elasticsearch configs + """ + + ELASTICSEARCH_HOST: Optional[str] = Field( + description="Elasticsearch host", + default="127.0.0.1", + ) + + ELASTICSEARCH_PORT: PositiveInt = Field( + description="Elasticsearch port", + default=9200, + ) + + ELASTICSEARCH_USERNAME: Optional[str] = Field( + description="Elasticsearch username", + default="elastic", + ) + + ELASTICSEARCH_PASSWORD: Optional[str] = Field( + description="Elasticsearch password", + default="elastic", + ) diff --git a/api/configs/middleware/vdb/milvus_config.py b/api/configs/middleware/vdb/milvus_config.py index 01502d4590..85466cd5cc 100644 --- a/api/configs/middleware/vdb/milvus_config.py +++ b/api/configs/middleware/vdb/milvus_config.py @@ -10,31 +10,31 @@ class MilvusConfig(BaseSettings): """ MILVUS_HOST: Optional[str] = Field( - description='Milvus host', + description="Milvus host", default=None, ) MILVUS_PORT: PositiveInt = Field( - description='Milvus RestFul API port', + description="Milvus RestFul API port", default=9091, ) MILVUS_USER: Optional[str] = Field( - description='Milvus user', + description="Milvus user", default=None, ) MILVUS_PASSWORD: Optional[str] = Field( - description='Milvus password', + description="Milvus password", default=None, ) MILVUS_SECURE: bool = Field( - description='whether to use SSL connection for Milvus', + description="whether to use SSL connection for Milvus", default=False, ) MILVUS_DATABASE: str = Field( - description='Milvus database, default to `default`', - default='default', + description="Milvus database, default to `default`", + default="default", ) diff --git a/api/configs/middleware/vdb/myscale_config.py b/api/configs/middleware/vdb/myscale_config.py index 895cd6f176..6451d26e1c 100644 --- a/api/configs/middleware/vdb/myscale_config.py +++ b/api/configs/middleware/vdb/myscale_config.py @@ -1,4 +1,3 @@ - from pydantic import BaseModel, Field, PositiveInt @@ -8,31 +7,31 @@ class MyScaleConfig(BaseModel): """ MYSCALE_HOST: str = Field( - description='MyScale host', - default='localhost', + description="MyScale host", + default="localhost", ) MYSCALE_PORT: PositiveInt = Field( - description='MyScale port', + description="MyScale port", default=8123, ) MYSCALE_USER: str = Field( - description='MyScale user', - default='default', + description="MyScale user", + default="default", ) MYSCALE_PASSWORD: str = Field( - description='MyScale password', - default='', + description="MyScale password", + default="", ) MYSCALE_DATABASE: str = Field( - description='MyScale database name', - default='default', + description="MyScale database name", + default="default", ) MYSCALE_FTS_PARAMS: str = Field( - description='MyScale fts index parameters', - default='', + description="MyScale fts index parameters", + default="", ) diff --git a/api/configs/middleware/vdb/opensearch_config.py b/api/configs/middleware/vdb/opensearch_config.py index 15d6f5b6a9..5823dc1433 100644 --- a/api/configs/middleware/vdb/opensearch_config.py +++ b/api/configs/middleware/vdb/opensearch_config.py @@ -10,26 +10,26 @@ class OpenSearchConfig(BaseSettings): """ OPENSEARCH_HOST: Optional[str] = Field( - description='OpenSearch host', + description="OpenSearch host", default=None, ) OPENSEARCH_PORT: PositiveInt = Field( - description='OpenSearch port', + description="OpenSearch port", default=9200, ) OPENSEARCH_USER: Optional[str] = Field( - description='OpenSearch user', + description="OpenSearch user", default=None, ) OPENSEARCH_PASSWORD: Optional[str] = Field( - description='OpenSearch password', + description="OpenSearch password", default=None, ) OPENSEARCH_SECURE: bool = Field( - description='whether to use SSL connection for OpenSearch', + description="whether to use SSL connection for OpenSearch", default=False, ) diff --git a/api/configs/middleware/vdb/oracle_config.py b/api/configs/middleware/vdb/oracle_config.py index 888fc19492..62614ae870 100644 --- a/api/configs/middleware/vdb/oracle_config.py +++ b/api/configs/middleware/vdb/oracle_config.py @@ -10,26 +10,26 @@ class OracleConfig(BaseSettings): """ ORACLE_HOST: Optional[str] = Field( - description='ORACLE host', + description="ORACLE host", default=None, ) ORACLE_PORT: Optional[PositiveInt] = Field( - description='ORACLE port', + description="ORACLE port", default=1521, ) ORACLE_USER: Optional[str] = Field( - description='ORACLE user', + description="ORACLE user", default=None, ) ORACLE_PASSWORD: Optional[str] = Field( - description='ORACLE password', + description="ORACLE password", default=None, ) ORACLE_DATABASE: Optional[str] = Field( - description='ORACLE database', + description="ORACLE database", default=None, ) diff --git a/api/configs/middleware/vdb/pgvector_config.py b/api/configs/middleware/vdb/pgvector_config.py index 8a677f60a3..39a7c1d8d5 100644 --- a/api/configs/middleware/vdb/pgvector_config.py +++ b/api/configs/middleware/vdb/pgvector_config.py @@ -10,26 +10,26 @@ class PGVectorConfig(BaseSettings): """ PGVECTOR_HOST: Optional[str] = Field( - description='PGVector host', + description="PGVector host", default=None, ) PGVECTOR_PORT: Optional[PositiveInt] = Field( - description='PGVector port', + description="PGVector port", default=5433, ) PGVECTOR_USER: Optional[str] = Field( - description='PGVector user', + description="PGVector user", default=None, ) PGVECTOR_PASSWORD: Optional[str] = Field( - description='PGVector password', + description="PGVector password", default=None, ) PGVECTOR_DATABASE: Optional[str] = Field( - description='PGVector database', + description="PGVector database", default=None, ) diff --git a/api/configs/middleware/vdb/pgvectors_config.py b/api/configs/middleware/vdb/pgvectors_config.py index 39f52f22ff..c40e5ff921 100644 --- a/api/configs/middleware/vdb/pgvectors_config.py +++ b/api/configs/middleware/vdb/pgvectors_config.py @@ -10,26 +10,26 @@ class PGVectoRSConfig(BaseSettings): """ PGVECTO_RS_HOST: Optional[str] = Field( - description='PGVectoRS host', + description="PGVectoRS host", default=None, ) PGVECTO_RS_PORT: Optional[PositiveInt] = Field( - description='PGVectoRS port', + description="PGVectoRS port", default=5431, ) PGVECTO_RS_USER: Optional[str] = Field( - description='PGVectoRS user', + description="PGVectoRS user", default=None, ) PGVECTO_RS_PASSWORD: Optional[str] = Field( - description='PGVectoRS password', + description="PGVectoRS password", default=None, ) PGVECTO_RS_DATABASE: Optional[str] = Field( - description='PGVectoRS database', + description="PGVectoRS database", default=None, ) diff --git a/api/configs/middleware/vdb/qdrant_config.py b/api/configs/middleware/vdb/qdrant_config.py index c85bf9c7dc..27f75491c9 100644 --- a/api/configs/middleware/vdb/qdrant_config.py +++ b/api/configs/middleware/vdb/qdrant_config.py @@ -10,26 +10,26 @@ class QdrantConfig(BaseSettings): """ QDRANT_URL: Optional[str] = Field( - description='Qdrant url', + description="Qdrant url", default=None, ) QDRANT_API_KEY: Optional[str] = Field( - description='Qdrant api key', + description="Qdrant api key", default=None, ) QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field( - description='Qdrant client timeout in seconds', + description="Qdrant client timeout in seconds", default=20, ) QDRANT_GRPC_ENABLED: bool = Field( - description='whether enable grpc support for Qdrant connection', + description="whether enable grpc support for Qdrant connection", default=False, ) QDRANT_GRPC_PORT: PositiveInt = Field( - description='Qdrant grpc port', + description="Qdrant grpc port", default=6334, ) diff --git a/api/configs/middleware/vdb/relyt_config.py b/api/configs/middleware/vdb/relyt_config.py index be93185f3c..66b9ecc03f 100644 --- a/api/configs/middleware/vdb/relyt_config.py +++ b/api/configs/middleware/vdb/relyt_config.py @@ -10,26 +10,26 @@ class RelytConfig(BaseSettings): """ RELYT_HOST: Optional[str] = Field( - description='Relyt host', + description="Relyt host", default=None, ) RELYT_PORT: PositiveInt = Field( - description='Relyt port', + description="Relyt port", default=9200, ) RELYT_USER: Optional[str] = Field( - description='Relyt user', + description="Relyt user", default=None, ) RELYT_PASSWORD: Optional[str] = Field( - description='Relyt password', + description="Relyt password", default=None, ) RELYT_DATABASE: Optional[str] = Field( - description='Relyt database', - default='default', + description="Relyt database", + default="default", ) diff --git a/api/configs/middleware/vdb/tencent_vector_config.py b/api/configs/middleware/vdb/tencent_vector_config.py index 531ec84068..46b4cb6a24 100644 --- a/api/configs/middleware/vdb/tencent_vector_config.py +++ b/api/configs/middleware/vdb/tencent_vector_config.py @@ -10,41 +10,41 @@ class TencentVectorDBConfig(BaseSettings): """ TENCENT_VECTOR_DB_URL: Optional[str] = Field( - description='Tencent Vector URL', + description="Tencent Vector URL", default=None, ) TENCENT_VECTOR_DB_API_KEY: Optional[str] = Field( - description='Tencent Vector API key', + description="Tencent Vector API key", default=None, ) TENCENT_VECTOR_DB_TIMEOUT: PositiveInt = Field( - description='Tencent Vector timeout in seconds', + description="Tencent Vector timeout in seconds", default=30, ) TENCENT_VECTOR_DB_USERNAME: Optional[str] = Field( - description='Tencent Vector username', + description="Tencent Vector username", default=None, ) TENCENT_VECTOR_DB_PASSWORD: Optional[str] = Field( - description='Tencent Vector password', + description="Tencent Vector password", default=None, ) TENCENT_VECTOR_DB_SHARD: PositiveInt = Field( - description='Tencent Vector sharding number', + description="Tencent Vector sharding number", default=1, ) TENCENT_VECTOR_DB_REPLICAS: NonNegativeInt = Field( - description='Tencent Vector replicas', + description="Tencent Vector replicas", default=2, ) TENCENT_VECTOR_DB_DATABASE: Optional[str] = Field( - description='Tencent Vector Database', + description="Tencent Vector Database", default=None, ) diff --git a/api/configs/middleware/vdb/tidb_vector_config.py b/api/configs/middleware/vdb/tidb_vector_config.py index 8d459691a8..dbcb276c01 100644 --- a/api/configs/middleware/vdb/tidb_vector_config.py +++ b/api/configs/middleware/vdb/tidb_vector_config.py @@ -10,26 +10,26 @@ class TiDBVectorConfig(BaseSettings): """ TIDB_VECTOR_HOST: Optional[str] = Field( - description='TiDB Vector host', + description="TiDB Vector host", default=None, ) TIDB_VECTOR_PORT: Optional[PositiveInt] = Field( - description='TiDB Vector port', + description="TiDB Vector port", default=4000, ) TIDB_VECTOR_USER: Optional[str] = Field( - description='TiDB Vector user', + description="TiDB Vector user", default=None, ) TIDB_VECTOR_PASSWORD: Optional[str] = Field( - description='TiDB Vector password', + description="TiDB Vector password", default=None, ) TIDB_VECTOR_DATABASE: Optional[str] = Field( - description='TiDB Vector database', + description="TiDB Vector database", default=None, ) diff --git a/api/configs/middleware/vdb/weaviate_config.py b/api/configs/middleware/vdb/weaviate_config.py index b985ecea12..63d1022f6a 100644 --- a/api/configs/middleware/vdb/weaviate_config.py +++ b/api/configs/middleware/vdb/weaviate_config.py @@ -10,21 +10,21 @@ class WeaviateConfig(BaseSettings): """ WEAVIATE_ENDPOINT: Optional[str] = Field( - description='Weaviate endpoint URL', + description="Weaviate endpoint URL", default=None, ) WEAVIATE_API_KEY: Optional[str] = Field( - description='Weaviate API key', + description="Weaviate API key", default=None, ) WEAVIATE_GRPC_ENABLED: bool = Field( - description='whether to enable gRPC for Weaviate connection', + description="whether to enable gRPC for Weaviate connection", default=True, ) WEAVIATE_BATCH_SIZE: PositiveInt = Field( - description='Weaviate batch size', + description="Weaviate batch size", default=100, ) diff --git a/api/configs/packaging/__init__.py b/api/configs/packaging/__init__.py index a7c5eb15a3..2d540ca584 100644 --- a/api/configs/packaging/__init__.py +++ b/api/configs/packaging/__init__.py @@ -8,11 +8,11 @@ class PackagingInfo(BaseSettings): """ CURRENT_VERSION: str = Field( - description='Dify version', - default='0.7.1', + description="Dify version", + default="0.7.3", ) COMMIT_SHA: str = Field( description="SHA-1 checksum of the git commit used to build the app", - default='', + default="", ) diff --git a/api/controllers/__init__.py b/api/controllers/__init__.py index b28b04f643..8b13789179 100644 --- a/api/controllers/__init__.py +++ b/api/controllers/__init__.py @@ -1,3 +1 @@ - - diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index b2b9d8d496..eb7c1464d3 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -2,7 +2,7 @@ from flask import Blueprint from libs.external_api import ExternalApi -bp = Blueprint('console', __name__, url_prefix='/console/api') +bp = Blueprint("console", __name__, url_prefix="/console/api") api = ExternalApi(bp) # Import other controllers diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index 028be5de54..a4ceec2662 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -15,24 +15,24 @@ from models.model import App, InstalledApp, RecommendedApp def admin_required(view): @wraps(view) def decorated(*args, **kwargs): - if not os.getenv('ADMIN_API_KEY'): - raise Unauthorized('API key is invalid.') + if not os.getenv("ADMIN_API_KEY"): + raise Unauthorized("API key is invalid.") - auth_header = request.headers.get('Authorization') + auth_header = request.headers.get("Authorization") if auth_header is None: - raise Unauthorized('Authorization header is missing.') + raise Unauthorized("Authorization header is missing.") - if ' ' not in auth_header: - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if " " not in auth_header: + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") auth_scheme, auth_token = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() - if auth_scheme != 'bearer': - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if auth_scheme != "bearer": + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - if os.getenv('ADMIN_API_KEY') != auth_token: - raise Unauthorized('API key is invalid.') + if os.getenv("ADMIN_API_KEY") != auth_token: + raise Unauthorized("API key is invalid.") return view(*args, **kwargs) @@ -44,37 +44,41 @@ class InsertExploreAppListApi(Resource): @admin_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('app_id', type=str, required=True, nullable=False, location='json') - parser.add_argument('desc', type=str, location='json') - parser.add_argument('copyright', type=str, location='json') - parser.add_argument('privacy_policy', type=str, location='json') - parser.add_argument('custom_disclaimer', type=str, location='json') - parser.add_argument('language', type=supported_language, required=True, nullable=False, location='json') - parser.add_argument('category', type=str, required=True, nullable=False, location='json') - parser.add_argument('position', type=int, required=True, nullable=False, location='json') + parser.add_argument("app_id", type=str, required=True, nullable=False, location="json") + parser.add_argument("desc", type=str, location="json") + parser.add_argument("copyright", type=str, location="json") + parser.add_argument("privacy_policy", type=str, location="json") + parser.add_argument("custom_disclaimer", type=str, location="json") + parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json") + parser.add_argument("category", type=str, required=True, nullable=False, location="json") + parser.add_argument("position", type=int, required=True, nullable=False, location="json") args = parser.parse_args() - app = App.query.filter(App.id == args['app_id']).first() + app = App.query.filter(App.id == args["app_id"]).first() if not app: raise NotFound(f'App \'{args["app_id"]}\' is not found') site = app.site if not site: - desc = args['desc'] if args['desc'] else '' - copy_right = args['copyright'] if args['copyright'] else '' - privacy_policy = args['privacy_policy'] if args['privacy_policy'] else '' - custom_disclaimer = args['custom_disclaimer'] if args['custom_disclaimer'] else '' + desc = args["desc"] if args["desc"] else "" + copy_right = args["copyright"] if args["copyright"] else "" + privacy_policy = args["privacy_policy"] if args["privacy_policy"] else "" + custom_disclaimer = args["custom_disclaimer"] if args["custom_disclaimer"] else "" else: - desc = site.description if site.description else \ - args['desc'] if args['desc'] else '' - copy_right = site.copyright if site.copyright else \ - args['copyright'] if args['copyright'] else '' - privacy_policy = site.privacy_policy if site.privacy_policy else \ - args['privacy_policy'] if args['privacy_policy'] else '' - custom_disclaimer = site.custom_disclaimer if site.custom_disclaimer else \ - args['custom_disclaimer'] if args['custom_disclaimer'] else '' + desc = site.description if site.description else args["desc"] if args["desc"] else "" + copy_right = site.copyright if site.copyright else args["copyright"] if args["copyright"] else "" + privacy_policy = ( + site.privacy_policy if site.privacy_policy else args["privacy_policy"] if args["privacy_policy"] else "" + ) + custom_disclaimer = ( + site.custom_disclaimer + if site.custom_disclaimer + else args["custom_disclaimer"] + if args["custom_disclaimer"] + else "" + ) - recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first() + recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first() if not recommended_app: recommended_app = RecommendedApp( @@ -83,9 +87,9 @@ class InsertExploreAppListApi(Resource): copyright=copy_right, privacy_policy=privacy_policy, custom_disclaimer=custom_disclaimer, - language=args['language'], - category=args['category'], - position=args['position'] + language=args["language"], + category=args["category"], + position=args["position"], ) db.session.add(recommended_app) @@ -93,21 +97,21 @@ class InsertExploreAppListApi(Resource): app.is_public = True db.session.commit() - return {'result': 'success'}, 201 + return {"result": "success"}, 201 else: recommended_app.description = desc recommended_app.copyright = copy_right recommended_app.privacy_policy = privacy_policy recommended_app.custom_disclaimer = custom_disclaimer - recommended_app.language = args['language'] - recommended_app.category = args['category'] - recommended_app.position = args['position'] + recommended_app.language = args["language"] + recommended_app.category = args["category"] + recommended_app.position = args["position"] app.is_public = True db.session.commit() - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class InsertExploreAppApi(Resource): @@ -116,15 +120,14 @@ class InsertExploreAppApi(Resource): def delete(self, app_id): recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first() if not recommended_app: - return {'result': 'success'}, 204 + return {"result": "success"}, 204 app = App.query.filter(App.id == recommended_app.app_id).first() if app: app.is_public = False installed_apps = InstalledApp.query.filter( - InstalledApp.app_id == recommended_app.app_id, - InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id + InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id ).all() for installed_app in installed_apps: @@ -133,8 +136,8 @@ class InsertExploreAppApi(Resource): db.session.delete(recommended_app) db.session.commit() - return {'result': 'success'}, 204 + return {"result": "success"}, 204 -api.add_resource(InsertExploreAppListApi, '/admin/insert-explore-apps') -api.add_resource(InsertExploreAppApi, '/admin/insert-explore-apps/') +api.add_resource(InsertExploreAppListApi, "/admin/insert-explore-apps") +api.add_resource(InsertExploreAppApi, "/admin/insert-explore-apps/") diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 324b831175..3f5e1adca2 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -14,26 +14,21 @@ from .setup import setup_required from .wraps import account_initialization_required api_key_fields = { - 'id': fields.String, - 'type': fields.String, - 'token': fields.String, - 'last_used_at': TimestampField, - 'created_at': TimestampField + "id": fields.String, + "type": fields.String, + "token": fields.String, + "last_used_at": TimestampField, + "created_at": TimestampField, } -api_key_list = { - 'data': fields.List(fields.Nested(api_key_fields), attribute="items") -} +api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")} def _get_resource(resource_id, tenant_id, resource_model): - resource = resource_model.query.filter_by( - id=resource_id, tenant_id=tenant_id - ).first() + resource = resource_model.query.filter_by(id=resource_id, tenant_id=tenant_id).first() if resource is None: - flask_restful.abort( - 404, message=f"{resource_model.__name__} not found.") + flask_restful.abort(404, message=f"{resource_model.__name__} not found.") return resource @@ -50,30 +45,32 @@ class BaseApiKeyListResource(Resource): @marshal_with(api_key_list) def get(self, resource_id): resource_id = str(resource_id) - _get_resource(resource_id, current_user.current_tenant_id, - self.resource_model) - keys = db.session.query(ApiToken). \ - filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id). \ - all() + _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) + keys = ( + db.session.query(ApiToken) + .filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) + .all() + ) return {"items": keys} @marshal_with(api_key_fields) def post(self, resource_id): resource_id = str(resource_id) - _get_resource(resource_id, current_user.current_tenant_id, - self.resource_model) + _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) if not current_user.is_admin_or_owner: raise Forbidden() - current_key_count = db.session.query(ApiToken). \ - filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id). \ - count() + current_key_count = ( + db.session.query(ApiToken) + .filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) + .count() + ) if current_key_count >= self.max_keys: flask_restful.abort( 400, message=f"Cannot create more than {self.max_keys} API keys for this resource type.", - code='max_keys_exceeded' + code="max_keys_exceeded", ) key = ApiToken.generate_api_key(self.token_prefix, 24) @@ -97,79 +94,78 @@ class BaseApiKeyResource(Resource): def delete(self, resource_id, api_key_id): resource_id = str(resource_id) api_key_id = str(api_key_id) - _get_resource(resource_id, current_user.current_tenant_id, - self.resource_model) + _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() - key = db.session.query(ApiToken). \ - filter(getattr(ApiToken, self.resource_id_field) == resource_id, ApiToken.type == self.resource_type, ApiToken.id == api_key_id). \ - first() + key = ( + db.session.query(ApiToken) + .filter( + getattr(ApiToken, self.resource_id_field) == resource_id, + ApiToken.type == self.resource_type, + ApiToken.id == api_key_id, + ) + .first() + ) if key is None: - flask_restful.abort(404, message='API key not found') + flask_restful.abort(404, message="API key not found") db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete() db.session.commit() - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class AppApiKeyListResource(BaseApiKeyListResource): - def after_request(self, resp): - resp.headers['Access-Control-Allow-Origin'] = '*' - resp.headers['Access-Control-Allow-Credentials'] = 'true' + resp.headers["Access-Control-Allow-Origin"] = "*" + resp.headers["Access-Control-Allow-Credentials"] = "true" return resp - resource_type = 'app' + resource_type = "app" resource_model = App - resource_id_field = 'app_id' - token_prefix = 'app-' + resource_id_field = "app_id" + token_prefix = "app-" class AppApiKeyResource(BaseApiKeyResource): - def after_request(self, resp): - resp.headers['Access-Control-Allow-Origin'] = '*' - resp.headers['Access-Control-Allow-Credentials'] = 'true' + resp.headers["Access-Control-Allow-Origin"] = "*" + resp.headers["Access-Control-Allow-Credentials"] = "true" return resp - resource_type = 'app' + resource_type = "app" resource_model = App - resource_id_field = 'app_id' + resource_id_field = "app_id" class DatasetApiKeyListResource(BaseApiKeyListResource): - def after_request(self, resp): - resp.headers['Access-Control-Allow-Origin'] = '*' - resp.headers['Access-Control-Allow-Credentials'] = 'true' + resp.headers["Access-Control-Allow-Origin"] = "*" + resp.headers["Access-Control-Allow-Credentials"] = "true" return resp - resource_type = 'dataset' + resource_type = "dataset" resource_model = Dataset - resource_id_field = 'dataset_id' - token_prefix = 'ds-' + resource_id_field = "dataset_id" + token_prefix = "ds-" class DatasetApiKeyResource(BaseApiKeyResource): - def after_request(self, resp): - resp.headers['Access-Control-Allow-Origin'] = '*' - resp.headers['Access-Control-Allow-Credentials'] = 'true' + resp.headers["Access-Control-Allow-Origin"] = "*" + resp.headers["Access-Control-Allow-Credentials"] = "true" return resp - resource_type = 'dataset' + + resource_type = "dataset" resource_model = Dataset - resource_id_field = 'dataset_id' + resource_id_field = "dataset_id" -api.add_resource(AppApiKeyListResource, '/apps//api-keys') -api.add_resource(AppApiKeyResource, - '/apps//api-keys/') -api.add_resource(DatasetApiKeyListResource, - '/datasets//api-keys') -api.add_resource(DatasetApiKeyResource, - '/datasets//api-keys/') +api.add_resource(AppApiKeyListResource, "/apps//api-keys") +api.add_resource(AppApiKeyResource, "/apps//api-keys/") +api.add_resource(DatasetApiKeyListResource, "/datasets//api-keys") +api.add_resource(DatasetApiKeyResource, "/datasets//api-keys/") diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py index fa2b3807e8..e7346bdf1d 100644 --- a/api/controllers/console/app/advanced_prompt_template.py +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -8,19 +8,18 @@ from services.advanced_prompt_template_service import AdvancedPromptTemplateServ class AdvancedPromptTemplateList(Resource): - @setup_required @login_required @account_initialization_required def get(self): - parser = reqparse.RequestParser() - parser.add_argument('app_mode', type=str, required=True, location='args') - parser.add_argument('model_mode', type=str, required=True, location='args') - parser.add_argument('has_context', type=str, required=False, default='true', location='args') - parser.add_argument('model_name', type=str, required=True, location='args') + parser.add_argument("app_mode", type=str, required=True, location="args") + parser.add_argument("model_mode", type=str, required=True, location="args") + parser.add_argument("has_context", type=str, required=False, default="true", location="args") + parser.add_argument("model_name", type=str, required=True, location="args") args = parser.parse_args() return AdvancedPromptTemplateService.get_prompt(args) -api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates') \ No newline at end of file + +api.add_resource(AdvancedPromptTemplateList, "/app/prompt-templates") diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py index aee367276c..51899da705 100644 --- a/api/controllers/console/app/agent.py +++ b/api/controllers/console/app/agent.py @@ -18,15 +18,12 @@ class AgentLogApi(Resource): def get(self, app_model): """Get agent logs""" parser = reqparse.RequestParser() - parser.add_argument('message_id', type=uuid_value, required=True, location='args') - parser.add_argument('conversation_id', type=uuid_value, required=True, location='args') + parser.add_argument("message_id", type=uuid_value, required=True, location="args") + parser.add_argument("conversation_id", type=uuid_value, required=True, location="args") args = parser.parse_args() - return AgentService.get_agent_logs( - app_model, - args['conversation_id'], - args['message_id'] - ) - -api.add_resource(AgentLogApi, '/apps//agent/logs') \ No newline at end of file + return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"]) + + +api.add_resource(AgentLogApi, "/apps//agent/logs") diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index bc15919a99..1ea1c82679 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -21,23 +21,23 @@ class AnnotationReplyActionApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") def post(self, app_id, action): if not current_user.is_editor: raise Forbidden() app_id = str(app_id) parser = reqparse.RequestParser() - parser.add_argument('score_threshold', required=True, type=float, location='json') - parser.add_argument('embedding_provider_name', required=True, type=str, location='json') - parser.add_argument('embedding_model_name', required=True, type=str, location='json') + parser.add_argument("score_threshold", required=True, type=float, location="json") + parser.add_argument("embedding_provider_name", required=True, type=str, location="json") + parser.add_argument("embedding_model_name", required=True, type=str, location="json") args = parser.parse_args() - if action == 'enable': + if action == "enable": result = AppAnnotationService.enable_app_annotation(args, app_id) - elif action == 'disable': + elif action == "disable": result = AppAnnotationService.disable_app_annotation(app_id) else: - raise ValueError('Unsupported annotation reply action') + raise ValueError("Unsupported annotation reply action") return result, 200 @@ -66,7 +66,7 @@ class AppAnnotationSettingUpdateApi(Resource): annotation_setting_id = str(annotation_setting_id) parser = reqparse.RequestParser() - parser.add_argument('score_threshold', required=True, type=float, location='json') + parser.add_argument("score_threshold", required=True, type=float, location="json") args = parser.parse_args() result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args) @@ -77,28 +77,24 @@ class AnnotationReplyActionStatusApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") def get(self, app_id, job_id, action): if not current_user.is_editor: raise Forbidden() job_id = str(job_id) - app_annotation_job_key = '{}_app_annotation_job_{}'.format(action, str(job_id)) + app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id)) cache_result = redis_client.get(app_annotation_job_key) if cache_result is None: raise ValueError("The job is not exist.") job_status = cache_result.decode() - error_msg = '' - if job_status == 'error': - app_annotation_error_key = '{}_app_annotation_error_{}'.format(action, str(job_id)) + error_msg = "" + if job_status == "error": + app_annotation_error_key = "{}_app_annotation_error_{}".format(action, str(job_id)) error_msg = redis_client.get(app_annotation_error_key).decode() - return { - 'job_id': job_id, - 'job_status': job_status, - 'error_msg': error_msg - }, 200 + return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 class AnnotationListApi(Resource): @@ -109,18 +105,18 @@ class AnnotationListApi(Resource): if not current_user.is_editor: raise Forbidden() - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) - keyword = request.args.get('keyword', default=None, type=str) + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) + keyword = request.args.get("keyword", default=None, type=str) app_id = str(app_id) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) response = { - 'data': marshal(annotation_list, annotation_fields), - 'has_more': len(annotation_list) == limit, - 'limit': limit, - 'total': total, - 'page': page + "data": marshal(annotation_list, annotation_fields), + "has_more": len(annotation_list) == limit, + "limit": limit, + "total": total, + "page": page, } return response, 200 @@ -135,9 +131,7 @@ class AnnotationExportApi(Resource): app_id = str(app_id) annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) - response = { - 'data': marshal(annotation_list, annotation_fields) - } + response = {"data": marshal(annotation_list, annotation_fields)} return response, 200 @@ -145,7 +139,7 @@ class AnnotationCreateApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") @marshal_with(annotation_fields) def post(self, app_id): if not current_user.is_editor: @@ -153,8 +147,8 @@ class AnnotationCreateApi(Resource): app_id = str(app_id) parser = reqparse.RequestParser() - parser.add_argument('question', required=True, type=str, location='json') - parser.add_argument('answer', required=True, type=str, location='json') + parser.add_argument("question", required=True, type=str, location="json") + parser.add_argument("answer", required=True, type=str, location="json") args = parser.parse_args() annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id) return annotation @@ -164,7 +158,7 @@ class AnnotationUpdateDeleteApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") @marshal_with(annotation_fields) def post(self, app_id, annotation_id): if not current_user.is_editor: @@ -173,8 +167,8 @@ class AnnotationUpdateDeleteApi(Resource): app_id = str(app_id) annotation_id = str(annotation_id) parser = reqparse.RequestParser() - parser.add_argument('question', required=True, type=str, location='json') - parser.add_argument('answer', required=True, type=str, location='json') + parser.add_argument("question", required=True, type=str, location="json") + parser.add_argument("answer", required=True, type=str, location="json") args = parser.parse_args() annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id) return annotation @@ -189,29 +183,29 @@ class AnnotationUpdateDeleteApi(Resource): app_id = str(app_id) annotation_id = str(annotation_id) AppAnnotationService.delete_app_annotation(app_id, annotation_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class AnnotationBatchImportApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") def post(self, app_id): if not current_user.is_editor: raise Forbidden() app_id = str(app_id) # get file from request - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() # check file type - if not file.filename.endswith('.csv'): + if not file.filename.endswith(".csv"): raise ValueError("Invalid file type. Only CSV files are allowed") return AppAnnotationService.batch_import_app_annotations(app_id, file) @@ -220,27 +214,23 @@ class AnnotationBatchImportStatusApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") def get(self, app_id, job_id): if not current_user.is_editor: raise Forbidden() job_id = str(job_id) - indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id)) + indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id)) cache_result = redis_client.get(indexing_cache_key) if cache_result is None: raise ValueError("The job is not exist.") job_status = cache_result.decode() - error_msg = '' - if job_status == 'error': - indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id)) + error_msg = "" + if job_status == "error": + indexing_error_msg_key = "app_annotation_batch_import_error_msg_{}".format(str(job_id)) error_msg = redis_client.get(indexing_error_msg_key).decode() - return { - 'job_id': job_id, - 'job_status': job_status, - 'error_msg': error_msg - }, 200 + return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 class AnnotationHitHistoryListApi(Resource): @@ -251,30 +241,32 @@ class AnnotationHitHistoryListApi(Resource): if not current_user.is_editor: raise Forbidden() - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) app_id = str(app_id) annotation_id = str(annotation_id) - annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(app_id, annotation_id, - page, limit) + annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories( + app_id, annotation_id, page, limit + ) response = { - 'data': marshal(annotation_hit_history_list, annotation_hit_history_fields), - 'has_more': len(annotation_hit_history_list) == limit, - 'limit': limit, - 'total': total, - 'page': page + "data": marshal(annotation_hit_history_list, annotation_hit_history_fields), + "has_more": len(annotation_hit_history_list) == limit, + "limit": limit, + "total": total, + "page": page, } return response -api.add_resource(AnnotationReplyActionApi, '/apps//annotation-reply/') -api.add_resource(AnnotationReplyActionStatusApi, - '/apps//annotation-reply//status/') -api.add_resource(AnnotationListApi, '/apps//annotations') -api.add_resource(AnnotationExportApi, '/apps//annotations/export') -api.add_resource(AnnotationUpdateDeleteApi, '/apps//annotations/') -api.add_resource(AnnotationBatchImportApi, '/apps//annotations/batch-import') -api.add_resource(AnnotationBatchImportStatusApi, '/apps//annotations/batch-import-status/') -api.add_resource(AnnotationHitHistoryListApi, '/apps//annotations//hit-histories') -api.add_resource(AppAnnotationSettingDetailApi, '/apps//annotation-setting') -api.add_resource(AppAnnotationSettingUpdateApi, '/apps//annotation-settings/') +api.add_resource(AnnotationReplyActionApi, "/apps//annotation-reply/") +api.add_resource( + AnnotationReplyActionStatusApi, "/apps//annotation-reply//status/" +) +api.add_resource(AnnotationListApi, "/apps//annotations") +api.add_resource(AnnotationExportApi, "/apps//annotations/export") +api.add_resource(AnnotationUpdateDeleteApi, "/apps//annotations/") +api.add_resource(AnnotationBatchImportApi, "/apps//annotations/batch-import") +api.add_resource(AnnotationBatchImportStatusApi, "/apps//annotations/batch-import-status/") +api.add_resource(AnnotationHitHistoryListApi, "/apps//annotations//hit-histories") +api.add_resource(AppAnnotationSettingDetailApi, "/apps//annotation-setting") +api.add_resource(AppAnnotationSettingUpdateApi, "/apps//annotation-settings/") diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 8651597fd7..1b46a3a7d3 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -18,27 +18,35 @@ from libs.login import login_required from services.app_dsl_service import AppDslService from services.app_service import AppService -ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion'] +ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] class AppListApi(Resource): - @setup_required @login_required @account_initialization_required def get(self): """Get app list""" + def uuid_list(value): try: - return [str(uuid.UUID(v)) for v in value.split(',')] + return [str(uuid.UUID(v)) for v in value.split(",")] except ValueError: abort(400, message="Invalid UUID format in tag_ids.") + parser = reqparse.RequestParser() - parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args') - parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args') - parser.add_argument('mode', type=str, choices=['chat', 'workflow', 'agent-chat', 'channel', 'all'], default='all', location='args', required=False) - parser.add_argument('name', type=str, location='args', required=False) - parser.add_argument('tag_ids', type=uuid_list, location='args', required=False) + parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") + parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") + parser.add_argument( + "mode", + type=str, + choices=["chat", "workflow", "agent-chat", "channel", "all"], + default="all", + location="args", + required=False, + ) + parser.add_argument("name", type=str, location="args", required=False) + parser.add_argument("tag_ids", type=uuid_list, location="args", required=False) args = parser.parse_args() @@ -46,7 +54,7 @@ class AppListApi(Resource): app_service = AppService() app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args) if not app_pagination: - return {'data': [], 'total': 0, 'page': 1, 'limit': 20, 'has_more': False} + return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} return marshal(app_pagination, app_pagination_fields) @@ -54,23 +62,23 @@ class AppListApi(Resource): @login_required @account_initialization_required @marshal_with(app_detail_fields) - @cloud_edition_billing_resource_check('apps') + @cloud_edition_billing_resource_check("apps") def post(self): """Create app""" parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') - parser.add_argument('description', type=str, location='json') - parser.add_argument('mode', type=str, choices=ALLOW_CREATE_APP_MODES, location='json') - parser.add_argument('icon_type', type=str, location='json') - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') + parser.add_argument("name", type=str, required=True, location="json") + parser.add_argument("description", type=str, location="json") + parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json") + parser.add_argument("icon_type", type=str, location="json") + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") args = parser.parse_args() # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - if 'mode' not in args or args['mode'] is None: + if "mode" not in args or args["mode"] is None: raise BadRequest("mode is required") app_service = AppService() @@ -84,7 +92,7 @@ class AppImportApi(Resource): @login_required @account_initialization_required @marshal_with(app_detail_fields_with_site) - @cloud_edition_billing_resource_check('apps') + @cloud_edition_billing_resource_check("apps") def post(self): """Import app""" # The role of the current user in the ta table must be admin, owner, or editor @@ -92,19 +100,16 @@ class AppImportApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('data', type=str, required=True, nullable=False, location='json') - parser.add_argument('name', type=str, location='json') - parser.add_argument('description', type=str, location='json') - parser.add_argument('icon_type', type=str, location='json') - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') + parser.add_argument("data", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, location="json") + parser.add_argument("description", type=str, location="json") + parser.add_argument("icon_type", type=str, location="json") + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") args = parser.parse_args() app = AppDslService.import_and_create_new_app( - tenant_id=current_user.current_tenant_id, - data=args['data'], - args=args, - account=current_user + tenant_id=current_user.current_tenant_id, data=args["data"], args=args, account=current_user ) return app, 201 @@ -115,7 +120,7 @@ class AppImportFromUrlApi(Resource): @login_required @account_initialization_required @marshal_with(app_detail_fields_with_site) - @cloud_edition_billing_resource_check('apps') + @cloud_edition_billing_resource_check("apps") def post(self): """Import app from url""" # The role of the current user in the ta table must be admin, owner, or editor @@ -123,25 +128,21 @@ class AppImportFromUrlApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('url', type=str, required=True, nullable=False, location='json') - parser.add_argument('name', type=str, location='json') - parser.add_argument('description', type=str, location='json') - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') + parser.add_argument("url", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, location="json") + parser.add_argument("description", type=str, location="json") + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") args = parser.parse_args() app = AppDslService.import_and_create_new_app_from_url( - tenant_id=current_user.current_tenant_id, - url=args['url'], - args=args, - account=current_user + tenant_id=current_user.current_tenant_id, url=args["url"], args=args, account=current_user ) return app, 201 class AppApi(Resource): - @setup_required @login_required @account_initialization_required @@ -165,14 +166,15 @@ class AppApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, nullable=False, location='json') - parser.add_argument('description', type=str, location='json') - parser.add_argument('icon_type', type=str, location='json') - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') - parser.add_argument('max_active_requests', type=int, location='json') + parser.add_argument("name", type=str, required=True, nullable=False, location="json") + parser.add_argument("description", type=str, location="json") + parser.add_argument("icon_type", type=str, location="json") + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") + parser.add_argument("max_active_requests", type=int, location="json") + parser.add_argument("use_icon_as_answer_icon", type=bool, location="json") args = parser.parse_args() app_service = AppService() @@ -193,7 +195,7 @@ class AppApi(Resource): app_service = AppService() app_service.delete_app(app_model) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class AppCopyApi(Resource): @@ -209,19 +211,16 @@ class AppCopyApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('name', type=str, location='json') - parser.add_argument('description', type=str, location='json') - parser.add_argument('icon_type', type=str, location='json') - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') + parser.add_argument("name", type=str, location="json") + parser.add_argument("description", type=str, location="json") + parser.add_argument("icon_type", type=str, location="json") + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") args = parser.parse_args() data = AppDslService.export_dsl(app_model=app_model, include_secret=True) app = AppDslService.import_and_create_new_app( - tenant_id=current_user.current_tenant_id, - data=data, - args=args, - account=current_user + tenant_id=current_user.current_tenant_id, data=data, args=args, account=current_user ) return app, 201 @@ -240,12 +239,10 @@ class AppExportApi(Resource): # Add include_secret params parser = reqparse.RequestParser() - parser.add_argument('include_secret', type=inputs.boolean, default=False, location='args') + parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args") args = parser.parse_args() - return { - "data": AppDslService.export_dsl(app_model=app_model, include_secret=args['include_secret']) - } + return {"data": AppDslService.export_dsl(app_model=app_model, include_secret=args["include_secret"])} class AppNameApi(Resource): @@ -258,13 +255,13 @@ class AppNameApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') + parser.add_argument("name", type=str, required=True, location="json") args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_name(app_model, args.get('name')) + app_model = app_service.update_app_name(app_model, args.get("name")) return app_model @@ -279,14 +276,14 @@ class AppIconApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_icon(app_model, args.get('icon'), args.get('icon_background')) + app_model = app_service.update_app_icon(app_model, args.get("icon"), args.get("icon_background")) return app_model @@ -301,13 +298,13 @@ class AppSiteStatus(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('enable_site', type=bool, required=True, location='json') + parser.add_argument("enable_site", type=bool, required=True, location="json") args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_site_status(app_model, args.get('enable_site')) + app_model = app_service.update_app_site_status(app_model, args.get("enable_site")) return app_model @@ -322,13 +319,13 @@ class AppApiStatus(Resource): # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('enable_api', type=bool, required=True, location='json') + parser.add_argument("enable_api", type=bool, required=True, location="json") args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_api_status(app_model, args.get('enable_api')) + app_model = app_service.update_app_api_status(app_model, args.get("enable_api")) return app_model @@ -339,9 +336,7 @@ class AppTraceApi(Resource): @account_initialization_required def get(self, app_id): """Get app trace""" - app_trace_config = OpsTraceManager.get_app_tracing_config( - app_id=app_id - ) + app_trace_config = OpsTraceManager.get_app_tracing_config(app_id=app_id) return app_trace_config @@ -353,27 +348,27 @@ class AppTraceApi(Resource): if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('enabled', type=bool, required=True, location='json') - parser.add_argument('tracing_provider', type=str, required=True, location='json') + parser.add_argument("enabled", type=bool, required=True, location="json") + parser.add_argument("tracing_provider", type=str, required=True, location="json") args = parser.parse_args() OpsTraceManager.update_app_tracing_config( app_id=app_id, - enabled=args['enabled'], - tracing_provider=args['tracing_provider'], + enabled=args["enabled"], + tracing_provider=args["tracing_provider"], ) return {"result": "success"} -api.add_resource(AppListApi, '/apps') -api.add_resource(AppImportApi, '/apps/import') -api.add_resource(AppImportFromUrlApi, '/apps/import/url') -api.add_resource(AppApi, '/apps/') -api.add_resource(AppCopyApi, '/apps//copy') -api.add_resource(AppExportApi, '/apps//export') -api.add_resource(AppNameApi, '/apps//name') -api.add_resource(AppIconApi, '/apps//icon') -api.add_resource(AppSiteStatus, '/apps//site-enable') -api.add_resource(AppApiStatus, '/apps//api-enable') -api.add_resource(AppTraceApi, '/apps//trace') +api.add_resource(AppListApi, "/apps") +api.add_resource(AppImportApi, "/apps/import") +api.add_resource(AppImportFromUrlApi, "/apps/import/url") +api.add_resource(AppApi, "/apps/") +api.add_resource(AppCopyApi, "/apps//copy") +api.add_resource(AppExportApi, "/apps//export") +api.add_resource(AppNameApi, "/apps//name") +api.add_resource(AppIconApi, "/apps//icon") +api.add_resource(AppSiteStatus, "/apps//site-enable") +api.add_resource(AppApiStatus, "/apps//api-enable") +api.add_resource(AppTraceApi, "/apps//trace") diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 1de08afa4e..437a6a7b38 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -39,7 +39,7 @@ class ChatMessageAudioApi(Resource): @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) def post(self, app_model): - file = request.files['file'] + file = request.files["file"] try: response = AudioService.transcript_asr( @@ -85,31 +85,31 @@ class ChatMessageTextApi(Resource): try: parser = reqparse.RequestParser() - parser.add_argument('message_id', type=str, location='json') - parser.add_argument('text', type=str, location='json') - parser.add_argument('voice', type=str, location='json') - parser.add_argument('streaming', type=bool, location='json') + parser.add_argument("message_id", type=str, location="json") + parser.add_argument("text", type=str, location="json") + parser.add_argument("voice", type=str, location="json") + parser.add_argument("streaming", type=bool, location="json") args = parser.parse_args() - message_id = args.get('message_id', None) - text = args.get('text', None) - if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] - and app_model.workflow - and app_model.workflow.features_dict): - text_to_speech = app_model.workflow.features_dict.get('text_to_speech') - voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice') + message_id = args.get("message_id", None) + text = args.get("text", None) + if ( + app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + and app_model.workflow + and app_model.workflow.features_dict + ): + text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") else: try: - voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get( - 'voice') + voice = ( + args.get("voice") + if args.get("voice") + else app_model.app_model_config.text_to_speech_dict.get("voice") + ) except Exception: voice = None - response = AudioService.transcript_tts( - app_model=app_model, - text=text, - message_id=message_id, - voice=voice - ) + response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice) return response except services.errors.app_model_config.AppModelConfigBrokenError: logging.exception("App model config broken.") @@ -145,12 +145,12 @@ class TextModesApi(Resource): def get(self, app_model): try: parser = reqparse.RequestParser() - parser.add_argument('language', type=str, required=True, location='args') + parser.add_argument("language", type=str, required=True, location="args") args = parser.parse_args() response = AudioService.transcript_tts_voices( tenant_id=app_model.tenant_id, - language=args['language'], + language=args["language"], ) return response @@ -179,6 +179,6 @@ class TextModesApi(Resource): raise InternalServerError() -api.add_resource(ChatMessageAudioApi, '/apps//audio-to-text') -api.add_resource(ChatMessageTextApi, '/apps//text-to-audio') -api.add_resource(TextModesApi, '/apps//text-to-audio/voices') +api.add_resource(ChatMessageAudioApi, "/apps//audio-to-text") +api.add_resource(ChatMessageTextApi, "/apps//text-to-audio") +api.add_resource(TextModesApi, "/apps//text-to-audio/voices") diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 61582536fd..53de51c24d 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -17,6 +17,7 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required +from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ( @@ -31,37 +32,33 @@ from libs.helper import uuid_value from libs.login import login_required from models.model import AppMode from services.app_generate_service import AppGenerateService +from services.errors.llm import InvokeRateLimitError # define completion message api for user class CompletionMessageApi(Resource): - @setup_required @login_required @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) def post(self, app_model): parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, location='json', default='') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('model_config', type=dict, required=True, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, location="json", default="") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("model_config", type=dict, required=True, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") args = parser.parse_args() - streaming = args['response_mode'] != 'blocking' - args['auto_generate_name'] = False + streaming = args["response_mode"] != "blocking" + args["auto_generate_name"] = False account = flask_login.current_user try: response = AppGenerateService.generate( - app_model=app_model, - user=account, - args=args, - invoke_from=InvokeFrom.DEBUGGER, - streaming=streaming + app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming ) return helper.compact_generate_response(response) @@ -97,7 +94,7 @@ class CompletionMessageStopApi(Resource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ChatMessageApi(Resource): @@ -107,27 +104,23 @@ class ChatMessageApi(Resource): @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) def post(self, app_model): parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, required=True, location='json') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('model_config', type=dict, required=True, location='json') - parser.add_argument('conversation_id', type=uuid_value, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, required=True, location="json") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("model_config", type=dict, required=True, location="json") + parser.add_argument("conversation_id", type=uuid_value, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") args = parser.parse_args() - streaming = args['response_mode'] != 'blocking' - args['auto_generate_name'] = False + streaming = args["response_mode"] != "blocking" + args["auto_generate_name"] = False account = flask_login.current_user try: response = AppGenerateService.generate( - app_model=app_model, - user=account, - args=args, - invoke_from=InvokeFrom.DEBUGGER, - streaming=streaming + app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming ) return helper.compact_generate_response(response) @@ -144,6 +137,8 @@ class ChatMessageApi(Resource): raise ProviderQuotaExceededError() except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() + except InvokeRateLimitError as ex: + raise InvokeRateLimitHttpError(ex.description) except InvokeError as e: raise CompletionRequestError(e.description) except (ValueError, AppInvokeQuotaExceededError) as e: @@ -163,10 +158,10 @@ class ChatMessageStopApi(Resource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(CompletionMessageApi, '/apps//completion-messages') -api.add_resource(CompletionMessageStopApi, '/apps//completion-messages//stop') -api.add_resource(ChatMessageApi, '/apps//chat-messages') -api.add_resource(ChatMessageStopApi, '/apps//chat-messages//stop') +api.add_resource(CompletionMessageApi, "/apps//completion-messages") +api.add_resource(CompletionMessageStopApi, "/apps//completion-messages//stop") +api.add_resource(ChatMessageApi, "/apps//chat-messages") +api.add_resource(ChatMessageStopApi, "/apps//chat-messages//stop") diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index eb61c83d46..753a6be20c 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -26,7 +26,6 @@ from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotat class CompletionConversationApi(Resource): - @setup_required @login_required @account_initialization_required @@ -36,24 +35,23 @@ class CompletionConversationApi(Resource): if not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('keyword', type=str, location='args') - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('annotation_status', type=str, - choices=['annotated', 'not_annotated', 'all'], default='all', location='args') - parser.add_argument('page', type=int_range(1, 99999), default=1, location='args') - parser.add_argument('limit', type=int_range(1, 100), default=20, location='args') + parser.add_argument("keyword", type=str, location="args") + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument( + "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" + ) + parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") + parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") args = parser.parse_args() - query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == 'completion') + query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion") - if args['keyword']: - query = query.join( - Message, Message.conversation_id == Conversation.id - ).filter( + if args["keyword"]: + query = query.join(Message, Message.conversation_id == Conversation.id).filter( or_( - Message.query.ilike('%{}%'.format(args['keyword'])), - Message.answer.ilike('%{}%'.format(args['keyword'])) + Message.query.ilike("%{}%".format(args["keyword"])), + Message.answer.ilike("%{}%".format(args["keyword"])), ) ) @@ -61,8 +59,8 @@ class CompletionConversationApi(Resource): timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) @@ -70,8 +68,8 @@ class CompletionConversationApi(Resource): query = query.where(Conversation.created_at >= start_datetime_utc) - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=59) end_datetime_timezone = timezone.localize(end_datetime) @@ -79,29 +77,25 @@ class CompletionConversationApi(Resource): query = query.where(Conversation.created_at < end_datetime_utc) - if args['annotation_status'] == "annotated": + if args["annotation_status"] == "annotated": query = query.options(joinedload(Conversation.message_annotations)).join( MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id ) - elif args['annotation_status'] == "not_annotated": - query = query.outerjoin( - MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id - ).group_by(Conversation.id).having(func.count(MessageAnnotation.id) == 0) + elif args["annotation_status"] == "not_annotated": + query = ( + query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) + .group_by(Conversation.id) + .having(func.count(MessageAnnotation.id) == 0) + ) query = query.order_by(Conversation.created_at.desc()) - conversations = db.paginate( - query, - page=args['page'], - per_page=args['limit'], - error_out=False - ) + conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False) return conversations class CompletionConversationDetailApi(Resource): - @setup_required @login_required @account_initialization_required @@ -123,8 +117,11 @@ class CompletionConversationDetailApi(Resource): raise Forbidden() conversation_id = str(conversation_id) - conversation = db.session.query(Conversation) \ - .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first() + conversation = ( + db.session.query(Conversation) + .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) + .first() + ) if not conversation: raise NotFound("Conversation Not Exists.") @@ -132,11 +129,10 @@ class CompletionConversationDetailApi(Resource): conversation.is_deleted = True db.session.commit() - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class ChatConversationApi(Resource): - @setup_required @login_required @account_initialization_required @@ -146,20 +142,28 @@ class ChatConversationApi(Resource): if not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('keyword', type=str, location='args') - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('annotation_status', type=str, - choices=['annotated', 'not_annotated', 'all'], default='all', location='args') - parser.add_argument('message_count_gte', type=int_range(1, 99999), required=False, location='args') - parser.add_argument('page', type=int_range(1, 99999), required=False, default=1, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("keyword", type=str, location="args") + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument( + "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" + ) + parser.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args") + parser.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + parser.add_argument( + "sort_by", + type=str, + choices=["created_at", "-created_at", "updated_at", "-updated_at"], + required=False, + default="-updated_at", + location="args", + ) args = parser.parse_args() subquery = ( db.session.query( - Conversation.id.label('conversation_id'), - EndUser.session_id.label('from_end_user_session_id') + Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id") ) .outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id) .subquery() @@ -167,28 +171,31 @@ class ChatConversationApi(Resource): query = db.select(Conversation).where(Conversation.app_id == app_model.id) - if args['keyword']: - keyword_filter = '%{}%'.format(args['keyword']) - query = query.join( - Message, Message.conversation_id == Conversation.id, - ).join( - subquery, subquery.c.conversation_id == Conversation.id - ).filter( - or_( - Message.query.ilike(keyword_filter), - Message.answer.ilike(keyword_filter), - Conversation.name.ilike(keyword_filter), - Conversation.introduction.ilike(keyword_filter), - subquery.c.from_end_user_session_id.ilike(keyword_filter) - ), + if args["keyword"]: + keyword_filter = "%{}%".format(args["keyword"]) + query = ( + query.join( + Message, + Message.conversation_id == Conversation.id, + ) + .join(subquery, subquery.c.conversation_id == Conversation.id) + .filter( + or_( + Message.query.ilike(keyword_filter), + Message.answer.ilike(keyword_filter), + Conversation.name.ilike(keyword_filter), + Conversation.introduction.ilike(keyword_filter), + subquery.c.from_end_user_session_id.ilike(keyword_filter), + ), + ) ) account = current_user timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) @@ -196,8 +203,8 @@ class ChatConversationApi(Resource): query = query.where(Conversation.created_at >= start_datetime_utc) - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=59) end_datetime_timezone = timezone.localize(end_datetime) @@ -205,40 +212,46 @@ class ChatConversationApi(Resource): query = query.where(Conversation.created_at < end_datetime_utc) - if args['annotation_status'] == "annotated": + if args["annotation_status"] == "annotated": query = query.options(joinedload(Conversation.message_annotations)).join( MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id ) - elif args['annotation_status'] == "not_annotated": - query = query.outerjoin( - MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id - ).group_by(Conversation.id).having(func.count(MessageAnnotation.id) == 0) + elif args["annotation_status"] == "not_annotated": + query = ( + query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) + .group_by(Conversation.id) + .having(func.count(MessageAnnotation.id) == 0) + ) - if args['message_count_gte'] and args['message_count_gte'] >= 1: + if args["message_count_gte"] and args["message_count_gte"] >= 1: query = ( query.options(joinedload(Conversation.messages)) .join(Message, Message.conversation_id == Conversation.id) .group_by(Conversation.id) - .having(func.count(Message.id) >= args['message_count_gte']) + .having(func.count(Message.id) >= args["message_count_gte"]) ) if app_model.mode == AppMode.ADVANCED_CHAT.value: query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value) - query = query.order_by(Conversation.created_at.desc()) + match args["sort_by"]: + case "created_at": + query = query.order_by(Conversation.created_at.asc()) + case "-created_at": + query = query.order_by(Conversation.created_at.desc()) + case "updated_at": + query = query.order_by(Conversation.updated_at.asc()) + case "-updated_at": + query = query.order_by(Conversation.updated_at.desc()) + case _: + query = query.order_by(Conversation.created_at.desc()) - conversations = db.paginate( - query, - page=args['page'], - per_page=args['limit'], - error_out=False - ) + conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False) return conversations class ChatConversationDetailApi(Resource): - @setup_required @login_required @account_initialization_required @@ -260,8 +273,11 @@ class ChatConversationDetailApi(Resource): raise Forbidden() conversation_id = str(conversation_id) - conversation = db.session.query(Conversation) \ - .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first() + conversation = ( + db.session.query(Conversation) + .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) + .first() + ) if not conversation: raise NotFound("Conversation Not Exists.") @@ -269,18 +285,21 @@ class ChatConversationDetailApi(Resource): conversation.is_deleted = True db.session.commit() - return {'result': 'success'}, 204 + return {"result": "success"}, 204 -api.add_resource(CompletionConversationApi, '/apps//completion-conversations') -api.add_resource(CompletionConversationDetailApi, '/apps//completion-conversations/') -api.add_resource(ChatConversationApi, '/apps//chat-conversations') -api.add_resource(ChatConversationDetailApi, '/apps//chat-conversations/') +api.add_resource(CompletionConversationApi, "/apps//completion-conversations") +api.add_resource(CompletionConversationDetailApi, "/apps//completion-conversations/") +api.add_resource(ChatConversationApi, "/apps//chat-conversations") +api.add_resource(ChatConversationDetailApi, "/apps//chat-conversations/") def _get_conversation(app_model, conversation_id): - conversation = db.session.query(Conversation) \ - .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first() + conversation = ( + db.session.query(Conversation) + .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) + .first() + ) if not conversation: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index aa0722ea35..23b234dac9 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -21,7 +21,7 @@ class ConversationVariablesApi(Resource): @marshal_with(paginated_conversation_variable_fields) def get(self, app_model): parser = reqparse.RequestParser() - parser.add_argument('conversation_id', type=str, location='args') + parser.add_argument("conversation_id", type=str, location="args") args = parser.parse_args() stmt = ( @@ -29,10 +29,10 @@ class ConversationVariablesApi(Resource): .where(ConversationVariable.app_id == app_model.id) .order_by(ConversationVariable.created_at) ) - if args['conversation_id']: - stmt = stmt.where(ConversationVariable.conversation_id == args['conversation_id']) + if args["conversation_id"]: + stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"]) else: - raise ValueError('conversation_id is required') + raise ValueError("conversation_id is required") # NOTE: This is a temporary solution to avoid performance issues. page = 1 @@ -43,14 +43,14 @@ class ConversationVariablesApi(Resource): rows = session.scalars(stmt).all() return { - 'page': page, - 'limit': page_size, - 'total': len(rows), - 'has_more': False, - 'data': [ + "page": page, + "limit": page_size, + "total": len(rows), + "has_more": False, + "data": [ { - 'created_at': row.created_at, - 'updated_at': row.updated_at, + "created_at": row.created_at, + "updated_at": row.updated_at, **row.to_variable().model_dump(), } for row in rows @@ -58,4 +58,4 @@ class ConversationVariablesApi(Resource): } -api.add_resource(ConversationVariablesApi, '/apps//conversation-variables') +api.add_resource(ConversationVariablesApi, "/apps//conversation-variables") diff --git a/api/controllers/console/app/error.py b/api/controllers/console/app/error.py index f6feed1221..1559f82d6e 100644 --- a/api/controllers/console/app/error.py +++ b/api/controllers/console/app/error.py @@ -2,116 +2,128 @@ from libs.exception import BaseHTTPException class AppNotFoundError(BaseHTTPException): - error_code = 'app_not_found' + error_code = "app_not_found" description = "App not found." code = 404 class ProviderNotInitializeError(BaseHTTPException): - error_code = 'provider_not_initialize' - description = "No valid model provider credentials found. " \ - "Please go to Settings -> Model Provider to complete your provider credentials." + error_code = "provider_not_initialize" + description = ( + "No valid model provider credentials found. " + "Please go to Settings -> Model Provider to complete your provider credentials." + ) code = 400 class ProviderQuotaExceededError(BaseHTTPException): - error_code = 'provider_quota_exceeded' - description = "Your quota for Dify Hosted Model Provider has been exhausted. " \ - "Please go to Settings -> Model Provider to complete your own provider credentials." + error_code = "provider_quota_exceeded" + description = ( + "Your quota for Dify Hosted Model Provider has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials." + ) code = 400 class ProviderModelCurrentlyNotSupportError(BaseHTTPException): - error_code = 'model_currently_not_support' + error_code = "model_currently_not_support" description = "Dify Hosted OpenAI trial currently not support the GPT-4 model." code = 400 class ConversationCompletedError(BaseHTTPException): - error_code = 'conversation_completed' + error_code = "conversation_completed" description = "The conversation has ended. Please start a new conversation." code = 400 class AppUnavailableError(BaseHTTPException): - error_code = 'app_unavailable' + error_code = "app_unavailable" description = "App unavailable, please check your app configurations." code = 400 class CompletionRequestError(BaseHTTPException): - error_code = 'completion_request_error' + error_code = "completion_request_error" description = "Completion request failed." code = 400 class AppMoreLikeThisDisabledError(BaseHTTPException): - error_code = 'app_more_like_this_disabled' + error_code = "app_more_like_this_disabled" description = "The 'More like this' feature is disabled. Please refresh your page." code = 403 class NoAudioUploadedError(BaseHTTPException): - error_code = 'no_audio_uploaded' + error_code = "no_audio_uploaded" description = "Please upload your audio." code = 400 class AudioTooLargeError(BaseHTTPException): - error_code = 'audio_too_large' + error_code = "audio_too_large" description = "Audio size exceeded. {message}" code = 413 class UnsupportedAudioTypeError(BaseHTTPException): - error_code = 'unsupported_audio_type' + error_code = "unsupported_audio_type" description = "Audio type not allowed." code = 415 class ProviderNotSupportSpeechToTextError(BaseHTTPException): - error_code = 'provider_not_support_speech_to_text' + error_code = "provider_not_support_speech_to_text" description = "Provider not support speech to text." code = 400 class NoFileUploadedError(BaseHTTPException): - error_code = 'no_file_uploaded' + error_code = "no_file_uploaded" description = "Please upload your file." code = 400 class TooManyFilesError(BaseHTTPException): - error_code = 'too_many_files' + error_code = "too_many_files" description = "Only one file is allowed." code = 400 class DraftWorkflowNotExist(BaseHTTPException): - error_code = 'draft_workflow_not_exist' + error_code = "draft_workflow_not_exist" description = "Draft workflow need to be initialized." code = 400 class DraftWorkflowNotSync(BaseHTTPException): - error_code = 'draft_workflow_not_sync' + error_code = "draft_workflow_not_sync" description = "Workflow graph might have been modified, please refresh and resubmit." code = 400 class TracingConfigNotExist(BaseHTTPException): - error_code = 'trace_config_not_exist' + error_code = "trace_config_not_exist" description = "Trace config not exist." code = 400 class TracingConfigIsExist(BaseHTTPException): - error_code = 'trace_config_is_exist' + error_code = "trace_config_is_exist" description = "Trace config is exist." code = 400 class TracingConfigCheckError(BaseHTTPException): - error_code = 'trace_config_check_error' + error_code = "trace_config_check_error" description = "Invalid Credentials." code = 400 + + +class InvokeRateLimitError(BaseHTTPException): + """Raised when the Invoke returns rate limit error.""" + + error_code = "rate_limit_error" + description = "Rate Limit Error" + code = 429 diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 6803775e20..3d1e6b7a37 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -24,21 +24,21 @@ class RuleGenerateApi(Resource): @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('instruction', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_config', type=dict, required=True, nullable=False, location='json') - parser.add_argument('no_variable', type=bool, required=True, default=False, location='json') + parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") + parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") + parser.add_argument("no_variable", type=bool, required=True, default=False, location="json") args = parser.parse_args() account = current_user - PROMPT_GENERATION_MAX_TOKENS = int(os.getenv('PROMPT_GENERATION_MAX_TOKENS', '512')) + PROMPT_GENERATION_MAX_TOKENS = int(os.getenv("PROMPT_GENERATION_MAX_TOKENS", "512")) try: rules = LLMGenerator.generate_rule_config( tenant_id=account.current_tenant_id, - instruction=args['instruction'], - model_config=args['model_config'], - no_variable=args['no_variable'], - rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS + instruction=args["instruction"], + model_config=args["model_config"], + no_variable=args["no_variable"], + rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS, ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -52,4 +52,4 @@ class RuleGenerateApi(Resource): return rules -api.add_resource(RuleGenerateApi, '/rule-generate') +api.add_resource(RuleGenerateApi, "/rule-generate") diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 056415f19a..fe06201982 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -33,9 +33,9 @@ from services.message_service import MessageService class ChatMessageListApi(Resource): message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_detail_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(message_detail_fields)), } @setup_required @@ -45,55 +45,69 @@ class ChatMessageListApi(Resource): @marshal_with(message_infinite_scroll_pagination_fields) def get(self, app_model): parser = reqparse.RequestParser() - parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') - parser.add_argument('first_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") + parser.add_argument("first_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - conversation = db.session.query(Conversation).filter( - Conversation.id == args['conversation_id'], - Conversation.app_id == app_model.id - ).first() + conversation = ( + db.session.query(Conversation) + .filter(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id) + .first() + ) if not conversation: raise NotFound("Conversation Not Exists.") - if args['first_id']: - first_message = db.session.query(Message) \ - .filter(Message.conversation_id == conversation.id, Message.id == args['first_id']).first() + if args["first_id"]: + first_message = ( + db.session.query(Message) + .filter(Message.conversation_id == conversation.id, Message.id == args["first_id"]) + .first() + ) if not first_message: raise NotFound("First message not found") - history_messages = db.session.query(Message).filter( - Message.conversation_id == conversation.id, - Message.created_at < first_message.created_at, - Message.id != first_message.id - ) \ - .order_by(Message.created_at.desc()).limit(args['limit']).all() + history_messages = ( + db.session.query(Message) + .filter( + Message.conversation_id == conversation.id, + Message.created_at < first_message.created_at, + Message.id != first_message.id, + ) + .order_by(Message.created_at.desc()) + .limit(args["limit"]) + .all() + ) else: - history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \ - .order_by(Message.created_at.desc()).limit(args['limit']).all() + history_messages = ( + db.session.query(Message) + .filter(Message.conversation_id == conversation.id) + .order_by(Message.created_at.desc()) + .limit(args["limit"]) + .all() + ) has_more = False - if len(history_messages) == args['limit']: + if len(history_messages) == args["limit"]: current_page_first_message = history_messages[-1] - rest_count = db.session.query(Message).filter( - Message.conversation_id == conversation.id, - Message.created_at < current_page_first_message.created_at, - Message.id != current_page_first_message.id - ).count() + rest_count = ( + db.session.query(Message) + .filter( + Message.conversation_id == conversation.id, + Message.created_at < current_page_first_message.created_at, + Message.id != current_page_first_message.id, + ) + .count() + ) if rest_count > 0: has_more = True history_messages = list(reversed(history_messages)) - return InfiniteScrollPagination( - data=history_messages, - limit=args['limit'], - has_more=has_more - ) + return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more) class MessageFeedbackApi(Resource): @@ -103,49 +117,46 @@ class MessageFeedbackApi(Resource): @get_app_model def post(self, app_model): parser = reqparse.RequestParser() - parser.add_argument('message_id', required=True, type=uuid_value, location='json') - parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') + parser.add_argument("message_id", required=True, type=uuid_value, location="json") + parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") args = parser.parse_args() - message_id = str(args['message_id']) + message_id = str(args["message_id"]) - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app_model.id - ).first() + message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first() if not message: raise NotFound("Message Not Exists.") feedback = message.admin_feedback - if not args['rating'] and feedback: + if not args["rating"] and feedback: db.session.delete(feedback) - elif args['rating'] and feedback: - feedback.rating = args['rating'] - elif not args['rating'] and not feedback: - raise ValueError('rating cannot be None when feedback not exists') + elif args["rating"] and feedback: + feedback.rating = args["rating"] + elif not args["rating"] and not feedback: + raise ValueError("rating cannot be None when feedback not exists") else: feedback = MessageFeedback( app_id=app_model.id, conversation_id=message.conversation_id, message_id=message.id, - rating=args['rating'], - from_source='admin', - from_account_id=current_user.id + rating=args["rating"], + from_source="admin", + from_account_id=current_user.id, ) db.session.add(feedback) db.session.commit() - return {'result': 'success'} + return {"result": "success"} class MessageAnnotationApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") @get_app_model @marshal_with(annotation_fields) def post(self, app_model): @@ -153,10 +164,10 @@ class MessageAnnotationApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('message_id', required=False, type=uuid_value, location='json') - parser.add_argument('question', required=True, type=str, location='json') - parser.add_argument('answer', required=True, type=str, location='json') - parser.add_argument('annotation_reply', required=False, type=dict, location='json') + parser.add_argument("message_id", required=False, type=uuid_value, location="json") + parser.add_argument("question", required=True, type=str, location="json") + parser.add_argument("answer", required=True, type=str, location="json") + parser.add_argument("annotation_reply", required=False, type=dict, location="json") args = parser.parse_args() annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id) @@ -169,11 +180,9 @@ class MessageAnnotationCountApi(Resource): @account_initialization_required @get_app_model def get(self, app_model): - count = db.session.query(MessageAnnotation).filter( - MessageAnnotation.app_id == app_model.id - ).count() + count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_model.id).count() - return {'count': count} + return {"count": count} class MessageSuggestedQuestionApi(Resource): @@ -186,10 +195,7 @@ class MessageSuggestedQuestionApi(Resource): try: questions = MessageService.get_suggested_questions_after_answer( - app_model=app_model, - message_id=message_id, - user=current_user, - invoke_from=InvokeFrom.DEBUGGER + app_model=app_model, message_id=message_id, user=current_user, invoke_from=InvokeFrom.DEBUGGER ) except MessageNotExistsError: raise NotFound("Message not found") @@ -209,7 +215,7 @@ class MessageSuggestedQuestionApi(Resource): logging.exception("internal server error.") raise InternalServerError() - return {'data': questions} + return {"data": questions} class MessageApi(Resource): @@ -221,10 +227,7 @@ class MessageApi(Resource): def get(self, app_model, message_id): message_id = str(message_id) - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app_model.id - ).first() + message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first() if not message: raise NotFound("Message Not Exists.") @@ -232,9 +235,9 @@ class MessageApi(Resource): return message -api.add_resource(MessageSuggestedQuestionApi, '/apps//chat-messages//suggested-questions') -api.add_resource(ChatMessageListApi, '/apps//chat-messages', endpoint='console_chat_messages') -api.add_resource(MessageFeedbackApi, '/apps//feedbacks') -api.add_resource(MessageAnnotationApi, '/apps//annotations') -api.add_resource(MessageAnnotationCountApi, '/apps//annotations/count') -api.add_resource(MessageApi, '/apps//messages/', endpoint='console_message') +api.add_resource(MessageSuggestedQuestionApi, "/apps//chat-messages//suggested-questions") +api.add_resource(ChatMessageListApi, "/apps//chat-messages", endpoint="console_chat_messages") +api.add_resource(MessageFeedbackApi, "/apps//feedbacks") +api.add_resource(MessageAnnotationApi, "/apps//annotations") +api.add_resource(MessageAnnotationCountApi, "/apps//annotations/count") +api.add_resource(MessageApi, "/apps//messages/", endpoint="console_message") diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index c8df879a29..f5068a4cd8 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -19,37 +19,35 @@ from services.app_model_config_service import AppModelConfigService class ModelConfigResource(Resource): - @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]) def post(self, app_model): - """Modify app model config""" # validate config model_configuration = AppModelConfigService.validate_configuration( - tenant_id=current_user.current_tenant_id, - config=request.json, - app_mode=AppMode.value_of(app_model.mode) + tenant_id=current_user.current_tenant_id, config=request.json, app_mode=AppMode.value_of(app_model.mode) ) new_app_model_config = AppModelConfig( app_id=app_model.id, + created_by=current_user.id, + updated_by=current_user.id, ) new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration) if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: # get original app model config - original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter( - AppModelConfig.id == app_model.app_model_config_id - ).first() + original_app_model_config: AppModelConfig = ( + db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() + ) agent_mode = original_app_model_config.agent_mode_dict # decrypt agent tool parameters if it's secret-input parameter_map = {} masked_parameter_map = {} tool_map = {} - for tool in agent_mode.get('tools') or []: + for tool in agent_mode.get("tools") or []: if not isinstance(tool, dict) or len(tool.keys()) <= 3: continue @@ -66,7 +64,7 @@ class ModelConfigResource(Resource): tool_runtime=tool_runtime, provider_name=agent_tool_entity.provider_id, provider_type=agent_tool_entity.provider_type, - identity_id=f'AGENT.{app_model.id}' + identity_id=f"AGENT.{app_model.id}", ) except Exception as e: continue @@ -79,18 +77,18 @@ class ModelConfigResource(Resource): parameters = {} masked_parameter = {} - key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' + key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}" masked_parameter_map[key] = masked_parameter parameter_map[key] = parameters tool_map[key] = tool_runtime # encrypt agent tool parameters if it's secret-input agent_mode = new_app_model_config.agent_mode_dict - for tool in agent_mode.get('tools') or []: + for tool in agent_mode.get("tools") or []: agent_tool_entity = AgentToolEntity(**tool) # get tool - key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' + key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}" if key in tool_map: tool_runtime = tool_map[key] else: @@ -108,7 +106,7 @@ class ModelConfigResource(Resource): tool_runtime=tool_runtime, provider_name=agent_tool_entity.provider_id, provider_type=agent_tool_entity.provider_type, - identity_id=f'AGENT.{app_model.id}' + identity_id=f"AGENT.{app_model.id}", ) manager.delete_tool_parameters_cache() @@ -116,15 +114,17 @@ class ModelConfigResource(Resource): if agent_tool_entity.tool_parameters: if key not in masked_parameter_map: continue - + for masked_key, masked_value in masked_parameter_map[key].items(): - if masked_key in agent_tool_entity.tool_parameters and \ - agent_tool_entity.tool_parameters[masked_key] == masked_value: + if ( + masked_key in agent_tool_entity.tool_parameters + and agent_tool_entity.tool_parameters[masked_key] == masked_value + ): agent_tool_entity.tool_parameters[masked_key] = parameter_map[key].get(masked_key) # encrypt parameters if agent_tool_entity.tool_parameters: - tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) + tool["tool_parameters"] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) # update app model config new_app_model_config.agent_mode = json.dumps(agent_mode) @@ -135,12 +135,9 @@ class ModelConfigResource(Resource): app_model.app_model_config_id = new_app_model_config.id db.session.commit() - app_model_config_was_updated.send( - app_model, - app_model_config=new_app_model_config - ) + app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config) - return {'result': 'success'} + return {"result": "success"} -api.add_resource(ModelConfigResource, '/apps//model-config') +api.add_resource(ModelConfigResource, "/apps//model-config") diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index c0cf7b9e33..374bd2b815 100644 --- a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -18,13 +18,11 @@ class TraceAppConfigApi(Resource): @account_initialization_required def get(self, app_id): parser = reqparse.RequestParser() - parser.add_argument('tracing_provider', type=str, required=True, location='args') + parser.add_argument("tracing_provider", type=str, required=True, location="args") args = parser.parse_args() try: - trace_config = OpsService.get_tracing_app_config( - app_id=app_id, tracing_provider=args['tracing_provider'] - ) + trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"]) if not trace_config: return {"has_not_configured": True} return trace_config @@ -37,19 +35,17 @@ class TraceAppConfigApi(Resource): def post(self, app_id): """Create a new trace app configuration""" parser = reqparse.RequestParser() - parser.add_argument('tracing_provider', type=str, required=True, location='json') - parser.add_argument('tracing_config', type=dict, required=True, location='json') + parser.add_argument("tracing_provider", type=str, required=True, location="json") + parser.add_argument("tracing_config", type=dict, required=True, location="json") args = parser.parse_args() try: result = OpsService.create_tracing_app_config( - app_id=app_id, - tracing_provider=args['tracing_provider'], - tracing_config=args['tracing_config'] + app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"] ) if not result: raise TracingConfigIsExist() - if result.get('error'): + if result.get("error"): raise TracingConfigCheckError() return result except Exception as e: @@ -61,15 +57,13 @@ class TraceAppConfigApi(Resource): def patch(self, app_id): """Update an existing trace app configuration""" parser = reqparse.RequestParser() - parser.add_argument('tracing_provider', type=str, required=True, location='json') - parser.add_argument('tracing_config', type=dict, required=True, location='json') + parser.add_argument("tracing_provider", type=str, required=True, location="json") + parser.add_argument("tracing_config", type=dict, required=True, location="json") args = parser.parse_args() try: result = OpsService.update_tracing_app_config( - app_id=app_id, - tracing_provider=args['tracing_provider'], - tracing_config=args['tracing_config'] + app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"] ) if not result: raise TracingConfigNotExist() @@ -83,14 +77,11 @@ class TraceAppConfigApi(Resource): def delete(self, app_id): """Delete an existing trace app configuration""" parser = reqparse.RequestParser() - parser.add_argument('tracing_provider', type=str, required=True, location='args') + parser.add_argument("tracing_provider", type=str, required=True, location="args") args = parser.parse_args() try: - result = OpsService.delete_tracing_app_config( - app_id=app_id, - tracing_provider=args['tracing_provider'] - ) + result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"]) if not result: raise TracingConfigNotExist() return {"result": "success"} @@ -98,4 +89,4 @@ class TraceAppConfigApi(Resource): raise e -api.add_resource(TraceAppConfigApi, '/apps//trace-config') +api.add_resource(TraceAppConfigApi, "/apps//trace-config") diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 7db58c048a..26da1ef26d 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -1,3 +1,5 @@ +from datetime import datetime, timezone + from flask_login import current_user from flask_restful import Resource, marshal_with, reqparse from werkzeug.exceptions import Forbidden, NotFound @@ -15,23 +17,24 @@ from models.model import Site def parse_app_site_args(): parser = reqparse.RequestParser() - parser.add_argument('title', type=str, required=False, location='json') - parser.add_argument('icon_type', type=str, required=False, location='json') - parser.add_argument('icon', type=str, required=False, location='json') - parser.add_argument('icon_background', type=str, required=False, location='json') - parser.add_argument('description', type=str, required=False, location='json') - parser.add_argument('default_language', type=supported_language, required=False, location='json') - parser.add_argument('chat_color_theme', type=str, required=False, location='json') - parser.add_argument('chat_color_theme_inverted', type=bool, required=False, location='json') - parser.add_argument('customize_domain', type=str, required=False, location='json') - parser.add_argument('copyright', type=str, required=False, location='json') - parser.add_argument('privacy_policy', type=str, required=False, location='json') - parser.add_argument('custom_disclaimer', type=str, required=False, location='json') - parser.add_argument('customize_token_strategy', type=str, choices=['must', 'allow', 'not_allow'], - required=False, - location='json') - parser.add_argument('prompt_public', type=bool, required=False, location='json') - parser.add_argument('show_workflow_steps', type=bool, required=False, location='json') + parser.add_argument("title", type=str, required=False, location="json") + parser.add_argument("icon_type", type=str, required=False, location="json") + parser.add_argument("icon", type=str, required=False, location="json") + parser.add_argument("icon_background", type=str, required=False, location="json") + parser.add_argument("description", type=str, required=False, location="json") + parser.add_argument("default_language", type=supported_language, required=False, location="json") + parser.add_argument("chat_color_theme", type=str, required=False, location="json") + parser.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json") + parser.add_argument("customize_domain", type=str, required=False, location="json") + parser.add_argument("copyright", type=str, required=False, location="json") + parser.add_argument("privacy_policy", type=str, required=False, location="json") + parser.add_argument("custom_disclaimer", type=str, required=False, location="json") + parser.add_argument( + "customize_token_strategy", type=str, choices=["must", "allow", "not_allow"], required=False, location="json" + ) + parser.add_argument("prompt_public", type=bool, required=False, location="json") + parser.add_argument("show_workflow_steps", type=bool, required=False, location="json") + parser.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json") return parser.parse_args() @@ -48,38 +51,38 @@ class AppSite(Resource): if not current_user.is_editor: raise Forbidden() - site = db.session.query(Site). \ - filter(Site.app_id == app_model.id). \ - one_or_404() + site = db.session.query(Site).filter(Site.app_id == app_model.id).one_or_404() for attr_name in [ - 'title', - 'icon_type', - 'icon', - 'icon_background', - 'description', - 'default_language', - 'chat_color_theme', - 'chat_color_theme_inverted', - 'customize_domain', - 'copyright', - 'privacy_policy', - 'custom_disclaimer', - 'customize_token_strategy', - 'prompt_public', - 'show_workflow_steps' + "title", + "icon_type", + "icon", + "icon_background", + "description", + "default_language", + "chat_color_theme", + "chat_color_theme_inverted", + "customize_domain", + "copyright", + "privacy_policy", + "custom_disclaimer", + "customize_token_strategy", + "prompt_public", + "show_workflow_steps", + "use_icon_as_answer_icon", ]: value = args.get(attr_name) if value is not None: setattr(site, attr_name, value) + site.updated_by = current_user.id + site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() return site class AppSiteAccessTokenReset(Resource): - @setup_required @login_required @account_initialization_required @@ -96,10 +99,12 @@ class AppSiteAccessTokenReset(Resource): raise NotFound site.code = Site.generate_code(16) + site.updated_by = current_user.id + site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() return site -api.add_resource(AppSite, '/apps//site') -api.add_resource(AppSiteAccessTokenReset, '/apps//site/access-token-reset') +api.add_resource(AppSite, "/apps//site") +api.add_resource(AppSiteAccessTokenReset, "/apps//site/access-token-reset") diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index b882ffef34..81826a20d0 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -16,8 +16,7 @@ from libs.login import login_required from models.model import AppMode -class DailyConversationStatistic(Resource): - +class DailyMessageStatistic(Resource): @setup_required @login_required @account_initialization_required @@ -26,58 +25,52 @@ class DailyConversationStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' - SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count + sql_query = """ + SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(*) AS message_count FROM messages where app_id = :app_id - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + """ + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'conversation_count': i.conversation_count - }) + response_data.append({"date": str(i.date), "message_count": i.message_count}) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) -class DailyTerminalsStatistic(Resource): - +class DailyConversationStatistic(Resource): @setup_required @login_required @account_initialization_required @@ -86,54 +79,103 @@ class DailyTerminalsStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' - SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count - FROM messages where app_id = :app_id - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + sql_query = """ + SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count + FROM messages where app_id = :app_id + """ + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'terminal_count': i.terminal_count - }) + response_data.append({"date": str(i.date), "conversation_count": i.conversation_count}) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) + + +class DailyTerminalsStatistic(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model + def get(self, app_model): + account = current_user + + parser = reqparse.RequestParser() + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + args = parser.parse_args() + + sql_query = """ + SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count + FROM messages where app_id = :app_id + """ + arg_dict = {"tz": account.timezone, "app_id": app_model.id} + + timezone = pytz.timezone(account.timezone) + utc_timezone = pytz.utc + + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") + start_datetime = start_datetime.replace(second=0) + + start_datetime_timezone = timezone.localize(start_datetime) + start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) + + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc + + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") + end_datetime = end_datetime.replace(second=0) + + end_datetime_timezone = timezone.localize(end_datetime) + end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) + + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc + + sql_query += " GROUP BY date order by date" + + response_data = [] + + with db.engine.begin() as conn: + rs = conn.execute(db.text(sql_query), arg_dict) + for i in rs: + response_data.append({"date": str(i.date), "terminal_count": i.terminal_count}) + + return jsonify({"data": response_data}) class DailyTokenCostStatistic(Resource): @@ -145,58 +187,53 @@ class DailyTokenCostStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, (sum(messages.message_tokens) + sum(messages.answer_tokens)) as token_count, sum(total_price) as total_price FROM messages where app_id = :app_id - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + """ + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'token_count': i.token_count, - 'total_price': i.total_price, - 'currency': 'USD' - }) + response_data.append( + {"date": str(i.date), "token_count": i.token_count, "total_price": i.total_price, "currency": "USD"} + ) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) class AverageSessionInteractionStatistic(Resource): @@ -208,8 +245,8 @@ class AverageSessionInteractionStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, @@ -218,30 +255,30 @@ FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count FROM conversations c JOIN messages m ON c.id = m.conversation_id WHERE c.override_model_configs IS NULL AND c.app_id = :app_id""" - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and c.created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and c.created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and c.created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and c.created_at < :end" + arg_dict["end"] = end_datetime_utc sql_query += """ GROUP BY m.conversation_id) subquery @@ -250,18 +287,15 @@ GROUP BY date ORDER BY date""" response_data = [] - + with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'interactions': float(i.interactions.quantize(Decimal('0.01'))) - }) + response_data.append( + {"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))} + ) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) class UserSatisfactionRateStatistic(Resource): @@ -273,57 +307,57 @@ class UserSatisfactionRateStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, COUNT(m.id) as message_count, COUNT(mf.id) as feedback_count FROM messages m LEFT JOIN message_feedbacks mf on mf.message_id=m.id and mf.rating='like' WHERE m.app_id = :app_id - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + """ + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and m.created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and m.created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and m.created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and m.created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'rate': round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2), - }) + response_data.append( + { + "date": str(i.date), + "rate": round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2), + } + ) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) class AverageResponseTimeStatistic(Resource): @@ -335,56 +369,51 @@ class AverageResponseTimeStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, AVG(provider_response_latency) as latency FROM messages WHERE app_id = :app_id - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + """ + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'latency': round(i.latency * 1000, 4) - }) + response_data.append({"date": str(i.date), "latency": round(i.latency * 1000, 4)}) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) class TokensPerSecondStatistic(Resource): @@ -396,63 +425,59 @@ class TokensPerSecondStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = '''SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, CASE WHEN SUM(provider_response_latency) = 0 THEN 0 ELSE (SUM(answer_tokens) / SUM(provider_response_latency)) END as tokens_per_second FROM messages -WHERE app_id = :app_id''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} +WHERE app_id = :app_id""" + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'tps': round(i.tokens_per_second, 4) - }) + response_data.append({"date": str(i.date), "tps": round(i.tokens_per_second, 4)}) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) -api.add_resource(DailyConversationStatistic, '/apps//statistics/daily-conversations') -api.add_resource(DailyTerminalsStatistic, '/apps//statistics/daily-end-users') -api.add_resource(DailyTokenCostStatistic, '/apps//statistics/token-costs') -api.add_resource(AverageSessionInteractionStatistic, '/apps//statistics/average-session-interactions') -api.add_resource(UserSatisfactionRateStatistic, '/apps//statistics/user-satisfaction-rate') -api.add_resource(AverageResponseTimeStatistic, '/apps//statistics/average-response-time') -api.add_resource(TokensPerSecondStatistic, '/apps//statistics/tokens-per-second') +api.add_resource(DailyMessageStatistic, "/apps//statistics/daily-messages") +api.add_resource(DailyConversationStatistic, "/apps//statistics/daily-conversations") +api.add_resource(DailyTerminalsStatistic, "/apps//statistics/daily-end-users") +api.add_resource(DailyTokenCostStatistic, "/apps//statistics/token-costs") +api.add_resource(AverageSessionInteractionStatistic, "/apps//statistics/average-session-interactions") +api.add_resource(UserSatisfactionRateStatistic, "/apps//statistics/user-satisfaction-rate") +api.add_resource(AverageResponseTimeStatistic, "/apps//statistics/average-response-time") +api.add_resource(TokensPerSecondStatistic, "/apps//statistics/tokens-per-second") diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index a2052b9764..e44820f634 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -64,51 +64,51 @@ class DraftWorkflowApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - - content_type = request.headers.get('Content-Type', '') - if 'application/json' in content_type: + content_type = request.headers.get("Content-Type", "") + + if "application/json" in content_type: parser = reqparse.RequestParser() - parser.add_argument('graph', type=dict, required=True, nullable=False, location='json') - parser.add_argument('features', type=dict, required=True, nullable=False, location='json') - parser.add_argument('hash', type=str, required=False, location='json') + parser.add_argument("graph", type=dict, required=True, nullable=False, location="json") + parser.add_argument("features", type=dict, required=True, nullable=False, location="json") + parser.add_argument("hash", type=str, required=False, location="json") # TODO: set this to required=True after frontend is updated - parser.add_argument('environment_variables', type=list, required=False, location='json') - parser.add_argument('conversation_variables', type=list, required=False, location='json') + parser.add_argument("environment_variables", type=list, required=False, location="json") + parser.add_argument("conversation_variables", type=list, required=False, location="json") args = parser.parse_args() - elif 'text/plain' in content_type: + elif "text/plain" in content_type: try: - data = json.loads(request.data.decode('utf-8')) - if 'graph' not in data or 'features' not in data: - raise ValueError('graph or features not found in data') + data = json.loads(request.data.decode("utf-8")) + if "graph" not in data or "features" not in data: + raise ValueError("graph or features not found in data") - if not isinstance(data.get('graph'), dict) or not isinstance(data.get('features'), dict): - raise ValueError('graph or features is not a dict') + if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict): + raise ValueError("graph or features is not a dict") args = { - 'graph': data.get('graph'), - 'features': data.get('features'), - 'hash': data.get('hash'), - 'environment_variables': data.get('environment_variables'), - 'conversation_variables': data.get('conversation_variables'), + "graph": data.get("graph"), + "features": data.get("features"), + "hash": data.get("hash"), + "environment_variables": data.get("environment_variables"), + "conversation_variables": data.get("conversation_variables"), } except json.JSONDecodeError: - return {'message': 'Invalid JSON data'}, 400 + return {"message": "Invalid JSON data"}, 400 else: abort(415) workflow_service = WorkflowService() try: - environment_variables_list = args.get('environment_variables') or [] + environment_variables_list = args.get("environment_variables") or [] environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] - conversation_variables_list = args.get('conversation_variables') or [] + conversation_variables_list = args.get("conversation_variables") or [] conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] workflow = workflow_service.sync_draft_workflow( app_model=app_model, - graph=args['graph'], - features=args['features'], - unique_hash=args.get('hash'), + graph=args["graph"], + features=args["features"], + unique_hash=args.get("hash"), account=current_user, environment_variables=environment_variables, conversation_variables=conversation_variables, @@ -119,7 +119,7 @@ class DraftWorkflowApi(Resource): return { "result": "success", "hash": workflow.unique_hash, - "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at) + "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at), } @@ -138,13 +138,11 @@ class DraftWorkflowImportApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('data', type=str, required=True, nullable=False, location='json') + parser.add_argument("data", type=str, required=True, nullable=False, location="json") args = parser.parse_args() workflow = AppDslService.import_and_overwrite_workflow( - app_model=app_model, - data=args['data'], - account=current_user + app_model=app_model, data=args["data"], account=current_user ) return workflow @@ -162,21 +160,17 @@ class AdvancedChatDraftWorkflowRunApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, location='json') - parser.add_argument('query', type=str, required=True, location='json', default='') - parser.add_argument('files', type=list, location='json') - parser.add_argument('conversation_id', type=uuid_value, location='json') + parser.add_argument("inputs", type=dict, location="json") + parser.add_argument("query", type=str, required=True, location="json", default="") + parser.add_argument("files", type=list, location="json") + parser.add_argument("conversation_id", type=uuid_value, location="json") args = parser.parse_args() try: response = AppGenerateService.generate( - app_model=app_model, - user=current_user, - args=args, - invoke_from=InvokeFrom.DEBUGGER, - streaming=True + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True ) return helper.compact_generate_response(response) @@ -190,6 +184,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource): logging.exception("internal server error.") raise InternalServerError() + class AdvancedChatDraftRunIterationNodeApi(Resource): @setup_required @login_required @@ -202,18 +197,14 @@ class AdvancedChatDraftRunIterationNodeApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, location='json') + parser.add_argument("inputs", type=dict, location="json") args = parser.parse_args() try: response = AppGenerateService.generate_single_iteration( - app_model=app_model, - user=current_user, - node_id=node_id, - args=args, - streaming=True + app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True ) return helper.compact_generate_response(response) @@ -227,6 +218,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource): logging.exception("internal server error.") raise InternalServerError() + class WorkflowDraftRunIterationNodeApi(Resource): @setup_required @login_required @@ -239,18 +231,14 @@ class WorkflowDraftRunIterationNodeApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, location='json') + parser.add_argument("inputs", type=dict, location="json") args = parser.parse_args() try: response = AppGenerateService.generate_single_iteration( - app_model=app_model, - user=current_user, - node_id=node_id, - args=args, - streaming=True + app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True ) return helper.compact_generate_response(response) @@ -264,6 +252,7 @@ class WorkflowDraftRunIterationNodeApi(Resource): logging.exception("internal server error.") raise InternalServerError() + class DraftWorkflowRunApi(Resource): @setup_required @login_required @@ -276,19 +265,15 @@ class DraftWorkflowRunApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') - parser.add_argument('files', type=list, required=False, location='json') + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("files", type=list, required=False, location="json") args = parser.parse_args() try: response = AppGenerateService.generate( - app_model=app_model, - user=current_user, - args=args, - invoke_from=InvokeFrom.DEBUGGER, - streaming=True + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True ) return helper.compact_generate_response(response) @@ -311,12 +296,10 @@ class WorkflowTaskStopApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) - return { - "result": "success" - } + return {"result": "success"} class DraftWorkflowNodeRunApi(Resource): @@ -332,24 +315,20 @@ class DraftWorkflowNodeRunApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() workflow_service = WorkflowService() workflow_node_execution = workflow_service.run_draft_workflow_node( - app_model=app_model, - node_id=node_id, - user_inputs=args.get('inputs'), - account=current_user + app_model=app_model, node_id=node_id, user_inputs=args.get("inputs"), account=current_user ) return workflow_node_execution class PublishedWorkflowApi(Resource): - @setup_required @login_required @account_initialization_required @@ -362,7 +341,7 @@ class PublishedWorkflowApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + # fetch published workflow by app_model workflow_service = WorkflowService() workflow = workflow_service.get_published_workflow(app_model=app_model) @@ -381,14 +360,11 @@ class PublishedWorkflowApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + workflow_service = WorkflowService() workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user) - return { - "result": "success", - "created_at": TimestampField().format(workflow.created_at) - } + return {"result": "success", "created_at": TimestampField().format(workflow.created_at)} class DefaultBlockConfigsApi(Resource): @@ -403,7 +379,7 @@ class DefaultBlockConfigsApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + # Get default block configs workflow_service = WorkflowService() return workflow_service.get_default_block_configs() @@ -421,24 +397,21 @@ class DefaultBlockConfigApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('q', type=str, location='args') + parser.add_argument("q", type=str, location="args") args = parser.parse_args() filters = None - if args.get('q'): + if args.get("q"): try: - filters = json.loads(args.get('q')) + filters = json.loads(args.get("q")) except json.JSONDecodeError: - raise ValueError('Invalid filters') + raise ValueError("Invalid filters") # Get default block configs workflow_service = WorkflowService() - return workflow_service.get_default_block_config( - node_type=block_type, - filters=filters - ) + return workflow_service.get_default_block_config(node_type=block_type, filters=filters) class ConvertToWorkflowApi(Resource): @@ -455,41 +428,43 @@ class ConvertToWorkflowApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + if request.data: parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=False, nullable=True, location='json') - parser.add_argument('icon_type', type=str, required=False, nullable=True, location='json') - parser.add_argument('icon', type=str, required=False, nullable=True, location='json') - parser.add_argument('icon_background', type=str, required=False, nullable=True, location='json') + parser.add_argument("name", type=str, required=False, nullable=True, location="json") + parser.add_argument("icon_type", type=str, required=False, nullable=True, location="json") + parser.add_argument("icon", type=str, required=False, nullable=True, location="json") + parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") args = parser.parse_args() else: args = {} # convert to workflow mode workflow_service = WorkflowService() - new_app_model = workflow_service.convert_to_workflow( - app_model=app_model, - account=current_user, - args=args - ) + new_app_model = workflow_service.convert_to_workflow(app_model=app_model, account=current_user, args=args) # return app id return { - 'new_app_id': new_app_model.id, + "new_app_id": new_app_model.id, } -api.add_resource(DraftWorkflowApi, '/apps//workflows/draft') -api.add_resource(DraftWorkflowImportApi, '/apps//workflows/draft/import') -api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps//advanced-chat/workflows/draft/run') -api.add_resource(DraftWorkflowRunApi, '/apps//workflows/draft/run') -api.add_resource(WorkflowTaskStopApi, '/apps//workflow-runs/tasks//stop') -api.add_resource(DraftWorkflowNodeRunApi, '/apps//workflows/draft/nodes//run') -api.add_resource(AdvancedChatDraftRunIterationNodeApi, '/apps//advanced-chat/workflows/draft/iteration/nodes//run') -api.add_resource(WorkflowDraftRunIterationNodeApi, '/apps//workflows/draft/iteration/nodes//run') -api.add_resource(PublishedWorkflowApi, '/apps//workflows/publish') -api.add_resource(DefaultBlockConfigsApi, '/apps//workflows/default-workflow-block-configs') -api.add_resource(DefaultBlockConfigApi, '/apps//workflows/default-workflow-block-configs' - '/') -api.add_resource(ConvertToWorkflowApi, '/apps//convert-to-workflow') +api.add_resource(DraftWorkflowApi, "/apps//workflows/draft") +api.add_resource(DraftWorkflowImportApi, "/apps//workflows/draft/import") +api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps//advanced-chat/workflows/draft/run") +api.add_resource(DraftWorkflowRunApi, "/apps//workflows/draft/run") +api.add_resource(WorkflowTaskStopApi, "/apps//workflow-runs/tasks//stop") +api.add_resource(DraftWorkflowNodeRunApi, "/apps//workflows/draft/nodes//run") +api.add_resource( + AdvancedChatDraftRunIterationNodeApi, + "/apps//advanced-chat/workflows/draft/iteration/nodes//run", +) +api.add_resource( + WorkflowDraftRunIterationNodeApi, "/apps//workflows/draft/iteration/nodes//run" +) +api.add_resource(PublishedWorkflowApi, "/apps//workflows/publish") +api.add_resource(DefaultBlockConfigsApi, "/apps//workflows/default-workflow-block-configs") +api.add_resource( + DefaultBlockConfigApi, "/apps//workflows/default-workflow-block-configs" "/" +) +api.add_resource(ConvertToWorkflowApi, "/apps//convert-to-workflow") diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 6d1709ed8e..dc962409cc 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -22,20 +22,19 @@ class WorkflowAppLogApi(Resource): Get workflow app logs """ parser = reqparse.RequestParser() - parser.add_argument('keyword', type=str, location='args') - parser.add_argument('status', type=str, choices=['succeeded', 'failed', 'stopped'], location='args') - parser.add_argument('page', type=int_range(1, 99999), default=1, location='args') - parser.add_argument('limit', type=int_range(1, 100), default=20, location='args') + parser.add_argument("keyword", type=str, location="args") + parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args") + parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") + parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") args = parser.parse_args() # get paginate workflow app logs workflow_app_service = WorkflowAppService() workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs( - app_model=app_model, - args=args + app_model=app_model, args=args ) return workflow_app_log_pagination -api.add_resource(WorkflowAppLogApi, '/apps//workflow-app-logs') +api.add_resource(WorkflowAppLogApi, "/apps//workflow-app-logs") diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 35d982e37c..a055d03deb 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -28,15 +28,12 @@ class AdvancedChatAppWorkflowRunListApi(Resource): Get advanced chat app workflow run list """ parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() workflow_run_service = WorkflowRunService() - result = workflow_run_service.get_paginate_advanced_chat_workflow_runs( - app_model=app_model, - args=args - ) + result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args=args) return result @@ -52,15 +49,12 @@ class WorkflowRunListApi(Resource): Get workflow run list """ parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() workflow_run_service = WorkflowRunService() - result = workflow_run_service.get_paginate_workflow_runs( - app_model=app_model, - args=args - ) + result = workflow_run_service.get_paginate_workflow_runs(app_model=app_model, args=args) return result @@ -98,12 +92,10 @@ class WorkflowRunNodeExecutionListApi(Resource): workflow_run_service = WorkflowRunService() node_executions = workflow_run_service.get_workflow_run_node_executions(app_model=app_model, run_id=run_id) - return { - 'data': node_executions - } + return {"data": node_executions} -api.add_resource(AdvancedChatAppWorkflowRunListApi, '/apps//advanced-chat/workflow-runs') -api.add_resource(WorkflowRunListApi, '/apps//workflow-runs') -api.add_resource(WorkflowRunDetailApi, '/apps//workflow-runs/') -api.add_resource(WorkflowRunNodeExecutionListApi, '/apps//workflow-runs//node-executions') +api.add_resource(AdvancedChatAppWorkflowRunListApi, "/apps//advanced-chat/workflow-runs") +api.add_resource(WorkflowRunListApi, "/apps//workflow-runs") +api.add_resource(WorkflowRunDetailApi, "/apps//workflow-runs/") +api.add_resource(WorkflowRunNodeExecutionListApi, "/apps//workflow-runs//node-executions") diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index 1d7dc395ff..db2f683589 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -26,56 +26,56 @@ class WorkflowDailyRunsStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(id) AS runs FROM workflow_runs WHERE app_id = :app_id AND triggered_from = :triggered_from - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value} + """ + arg_dict = { + "tz": account.timezone, + "app_id": app_model.id, + "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + } timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'runs': i.runs - }) + response_data.append({"date": str(i.date), "runs": i.runs}) + + return jsonify({"data": response_data}) - return jsonify({ - 'data': response_data - }) class WorkflowDailyTerminalsStatistic(Resource): @setup_required @@ -86,56 +86,56 @@ class WorkflowDailyTerminalsStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct workflow_runs.created_by) AS terminal_count FROM workflow_runs WHERE app_id = :app_id AND triggered_from = :triggered_from - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value} + """ + arg_dict = { + "tz": account.timezone, + "app_id": app_model.id, + "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + } timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'terminal_count': i.terminal_count - }) + response_data.append({"date": str(i.date), "terminal_count": i.terminal_count}) + + return jsonify({"data": response_data}) - return jsonify({ - 'data': response_data - }) class WorkflowDailyTokenCostStatistic(Resource): @setup_required @@ -146,58 +146,63 @@ class WorkflowDailyTokenCostStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, SUM(workflow_runs.total_tokens) as token_count FROM workflow_runs WHERE app_id = :app_id AND triggered_from = :triggered_from - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value} + """ + arg_dict = { + "tz": account.timezone, + "app_id": app_model.id, + "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + } timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'token_count': i.token_count, - }) + response_data.append( + { + "date": str(i.date), + "token_count": i.token_count, + } + ) + + return jsonify({"data": response_data}) - return jsonify({ - 'data': response_data - }) class WorkflowAverageAppInteractionStatistic(Resource): @setup_required @@ -208,8 +213,8 @@ class WorkflowAverageAppInteractionStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -229,50 +234,54 @@ class WorkflowAverageAppInteractionStatistic(Resource): GROUP BY date, c.created_by) sub GROUP BY sub.date """ - arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value} + arg_dict = { + "tz": account.timezone, + "app_id": app_model.id, + "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + } timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query = sql_query.replace('{{start}}', ' AND c.created_at >= :start') - arg_dict['start'] = start_datetime_utc + sql_query = sql_query.replace("{{start}}", " AND c.created_at >= :start") + arg_dict["start"] = start_datetime_utc else: - sql_query = sql_query.replace('{{start}}', '') + sql_query = sql_query.replace("{{start}}", "") - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query = sql_query.replace('{{end}}', ' and c.created_at < :end') - arg_dict['end'] = end_datetime_utc + sql_query = sql_query.replace("{{end}}", " and c.created_at < :end") + arg_dict["end"] = end_datetime_utc else: - sql_query = sql_query.replace('{{end}}', '') + sql_query = sql_query.replace("{{end}}", "") response_data = [] - + with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'interactions': float(i.interactions.quantize(Decimal('0.01'))) - }) + response_data.append( + {"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))} + ) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) -api.add_resource(WorkflowDailyRunsStatistic, '/apps//workflow/statistics/daily-conversations') -api.add_resource(WorkflowDailyTerminalsStatistic, '/apps//workflow/statistics/daily-terminals') -api.add_resource(WorkflowDailyTokenCostStatistic, '/apps//workflow/statistics/token-costs') -api.add_resource(WorkflowAverageAppInteractionStatistic, '/apps//workflow/statistics/average-app-interactions') + +api.add_resource(WorkflowDailyRunsStatistic, "/apps//workflow/statistics/daily-conversations") +api.add_resource(WorkflowDailyTerminalsStatistic, "/apps//workflow/statistics/daily-terminals") +api.add_resource(WorkflowDailyTokenCostStatistic, "/apps//workflow/statistics/token-costs") +api.add_resource( + WorkflowAverageAppInteractionStatistic, "/apps//workflow/statistics/average-app-interactions" +) diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index d61ab6d6ae..5e0a4bc814 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -8,24 +8,23 @@ from libs.login import current_user from models.model import App, AppMode -def get_app_model(view: Optional[Callable] = None, *, - mode: Union[AppMode, list[AppMode]] = None): +def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None): def decorator(view_func): @wraps(view_func) def decorated_view(*args, **kwargs): - if not kwargs.get('app_id'): - raise ValueError('missing app_id in path parameters') + if not kwargs.get("app_id"): + raise ValueError("missing app_id in path parameters") - app_id = kwargs.get('app_id') + app_id = kwargs.get("app_id") app_id = str(app_id) - del kwargs['app_id'] + del kwargs["app_id"] - app_model = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app_model = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app_model: raise AppNotFoundError() @@ -44,9 +43,10 @@ def get_app_model(view: Optional[Callable] = None, *, mode_values = {m.value for m in modes} raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}") - kwargs['app_model'] = app_model + kwargs["app_model"] = app_model return view_func(*args, **kwargs) + return decorated_view if view is None: diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index 8efb55cdb6..8ba6b53e7e 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -17,60 +17,61 @@ from services.account_service import RegisterService class ActivateCheckApi(Resource): def get(self): parser = reqparse.RequestParser() - parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='args') - parser.add_argument('email', type=email, required=False, nullable=True, location='args') - parser.add_argument('token', type=str, required=True, nullable=False, location='args') + parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="args") + parser.add_argument("email", type=email, required=False, nullable=True, location="args") + parser.add_argument("token", type=str, required=True, nullable=False, location="args") args = parser.parse_args() - workspaceId = args['workspace_id'] - reg_email = args['email'] - token = args['token'] + workspaceId = args["workspace_id"] + reg_email = args["email"] + token = args["token"] invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token) - return {'is_valid': invitation is not None, 'workspace_name': invitation['tenant'].name if invitation else None} + return {"is_valid": invitation is not None, "workspace_name": invitation["tenant"].name if invitation else None} class ActivateApi(Resource): def post(self): parser = reqparse.RequestParser() - parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='json') - parser.add_argument('email', type=email, required=False, nullable=True, location='json') - parser.add_argument('token', type=str, required=True, nullable=False, location='json') - parser.add_argument('name', type=str_len(30), required=True, nullable=False, location='json') - parser.add_argument('password', type=valid_password, required=True, nullable=False, location='json') - parser.add_argument('interface_language', type=supported_language, required=True, nullable=False, - location='json') - parser.add_argument('timezone', type=timezone, required=True, nullable=False, location='json') + parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json") + parser.add_argument("email", type=email, required=False, nullable=True, location="json") + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json") + parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json") + parser.add_argument( + "interface_language", type=supported_language, required=True, nullable=False, location="json" + ) + parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json") args = parser.parse_args() - invitation = RegisterService.get_invitation_if_token_valid(args['workspace_id'], args['email'], args['token']) + invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"]) if invitation is None: raise AlreadyActivateError() - RegisterService.revoke_token(args['workspace_id'], args['email'], args['token']) + RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"]) - account = invitation['account'] - account.name = args['name'] + account = invitation["account"] + account.name = args["name"] # generate password salt salt = secrets.token_bytes(16) base64_salt = base64.b64encode(salt).decode() # encrypt password with salt - password_hashed = hash_password(args['password'], salt) + password_hashed = hash_password(args["password"], salt) base64_password_hashed = base64.b64encode(password_hashed).decode() account.password = base64_password_hashed account.password_salt = base64_salt - account.interface_language = args['interface_language'] - account.timezone = args['timezone'] - account.interface_theme = 'light' + account.interface_language = args["interface_language"] + account.timezone = args["timezone"] + account.interface_theme = "light" account.status = AccountStatus.ACTIVE.value account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() - return {'result': 'success'} + return {"result": "success"} -api.add_resource(ActivateCheckApi, '/activate/check') -api.add_resource(ActivateApi, '/activate') +api.add_resource(ActivateCheckApi, "/activate/check") +api.add_resource(ActivateApi, "/activate") diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index f79b93b74f..50db6eebc1 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -19,18 +19,19 @@ class ApiKeyAuthDataSource(Resource): data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id) if data_source_api_key_bindings: return { - 'sources': [{ - 'id': data_source_api_key_binding.id, - 'category': data_source_api_key_binding.category, - 'provider': data_source_api_key_binding.provider, - 'disabled': data_source_api_key_binding.disabled, - 'created_at': int(data_source_api_key_binding.created_at.timestamp()), - 'updated_at': int(data_source_api_key_binding.updated_at.timestamp()), - } - for data_source_api_key_binding in - data_source_api_key_bindings] + "sources": [ + { + "id": data_source_api_key_binding.id, + "category": data_source_api_key_binding.category, + "provider": data_source_api_key_binding.provider, + "disabled": data_source_api_key_binding.disabled, + "created_at": int(data_source_api_key_binding.created_at.timestamp()), + "updated_at": int(data_source_api_key_binding.updated_at.timestamp()), + } + for data_source_api_key_binding in data_source_api_key_bindings + ] } - return {'sources': []} + return {"sources": []} class ApiKeyAuthDataSourceBinding(Resource): @@ -42,16 +43,16 @@ class ApiKeyAuthDataSourceBinding(Resource): if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('category', type=str, required=True, nullable=False, location='json') - parser.add_argument('provider', type=str, required=True, nullable=False, location='json') - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("category", type=str, required=True, nullable=False, location="json") + parser.add_argument("provider", type=str, required=True, nullable=False, location="json") + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() ApiKeyAuthService.validate_api_key_auth_args(args) try: ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args) except Exception as e: raise ApiKeyAuthFailedError(str(e)) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ApiKeyAuthDataSourceBindingDelete(Resource): @@ -65,9 +66,9 @@ class ApiKeyAuthDataSourceBindingDelete(Resource): ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(ApiKeyAuthDataSource, '/api-key-auth/data-source') -api.add_resource(ApiKeyAuthDataSourceBinding, '/api-key-auth/data-source/binding') -api.add_resource(ApiKeyAuthDataSourceBindingDelete, '/api-key-auth/data-source/') +api.add_resource(ApiKeyAuthDataSource, "/api-key-auth/data-source") +api.add_resource(ApiKeyAuthDataSourceBinding, "/api-key-auth/data-source/binding") +api.add_resource(ApiKeyAuthDataSourceBindingDelete, "/api-key-auth/data-source/") diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index 1992ed391a..fd31e5ccc3 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -17,13 +17,13 @@ from ..wraps import account_initialization_required def get_oauth_providers(): with current_app.app_context(): - notion_oauth = NotionOAuth(client_id=dify_config.NOTION_CLIENT_ID, - client_secret=dify_config.NOTION_CLIENT_SECRET, - redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/data-source/callback/notion') + notion_oauth = NotionOAuth( + client_id=dify_config.NOTION_CLIENT_ID, + client_secret=dify_config.NOTION_CLIENT_SECRET, + redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/data-source/callback/notion", + ) - OAUTH_PROVIDERS = { - 'notion': notion_oauth - } + OAUTH_PROVIDERS = {"notion": notion_oauth} return OAUTH_PROVIDERS @@ -37,16 +37,16 @@ class OAuthDataSource(Resource): oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) print(vars(oauth_provider)) if not oauth_provider: - return {'error': 'Invalid provider'}, 400 - if dify_config.NOTION_INTEGRATION_TYPE == 'internal': + return {"error": "Invalid provider"}, 400 + if dify_config.NOTION_INTEGRATION_TYPE == "internal": internal_secret = dify_config.NOTION_INTERNAL_SECRET if not internal_secret: - return {'error': 'Internal secret is not set'}, + return ({"error": "Internal secret is not set"},) oauth_provider.save_internal_access_token(internal_secret) - return {'data': ''} + return {"data": ""} else: auth_url = oauth_provider.get_authorization_url() - return {'data': auth_url}, 200 + return {"data": auth_url}, 200 class OAuthDataSourceCallback(Resource): @@ -55,17 +55,17 @@ class OAuthDataSourceCallback(Resource): with current_app.app_context(): oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) if not oauth_provider: - return {'error': 'Invalid provider'}, 400 - if 'code' in request.args: - code = request.args.get('code') + return {"error": "Invalid provider"}, 400 + if "code" in request.args: + code = request.args.get("code") - return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}') - elif 'error' in request.args: - error = request.args.get('error') + return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}") + elif "error" in request.args: + error = request.args.get("error") - return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}') + return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}") else: - return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied') + return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied") class OAuthDataSourceBinding(Resource): @@ -74,17 +74,18 @@ class OAuthDataSourceBinding(Resource): with current_app.app_context(): oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) if not oauth_provider: - return {'error': 'Invalid provider'}, 400 - if 'code' in request.args: - code = request.args.get('code') + return {"error": "Invalid provider"}, 400 + if "code" in request.args: + code = request.args.get("code") try: oauth_provider.get_access_token(code) except requests.exceptions.HTTPError as e: logging.exception( - f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}") - return {'error': 'OAuth data source process failed'}, 400 + f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}" + ) + return {"error": "OAuth data source process failed"}, 400 - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class OAuthDataSourceSync(Resource): @@ -98,18 +99,17 @@ class OAuthDataSourceSync(Resource): with current_app.app_context(): oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) if not oauth_provider: - return {'error': 'Invalid provider'}, 400 + return {"error": "Invalid provider"}, 400 try: oauth_provider.sync_data_source(binding_id) except requests.exceptions.HTTPError as e: - logging.exception( - f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}") - return {'error': 'OAuth data source process failed'}, 400 + logging.exception(f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}") + return {"error": "OAuth data source process failed"}, 400 - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(OAuthDataSource, '/oauth/data-source/') -api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/callback/') -api.add_resource(OAuthDataSourceBinding, '/oauth/data-source/binding/') -api.add_resource(OAuthDataSourceSync, '/oauth/data-source///sync') +api.add_resource(OAuthDataSource, "/oauth/data-source/") +api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/") +api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/") +api.add_resource(OAuthDataSourceSync, "/oauth/data-source///sync") diff --git a/api/controllers/console/auth/error.py b/api/controllers/console/auth/error.py index 53dab3298f..ea23e097d0 100644 --- a/api/controllers/console/auth/error.py +++ b/api/controllers/console/auth/error.py @@ -2,31 +2,30 @@ from libs.exception import BaseHTTPException class ApiKeyAuthFailedError(BaseHTTPException): - error_code = 'auth_failed' + error_code = "auth_failed" description = "{message}" code = 500 class InvalidEmailError(BaseHTTPException): - error_code = 'invalid_email' + error_code = "invalid_email" description = "The email address is not valid." code = 400 class PasswordMismatchError(BaseHTTPException): - error_code = 'password_mismatch' + error_code = "password_mismatch" description = "The passwords do not match." code = 400 class InvalidTokenError(BaseHTTPException): - error_code = 'invalid_or_expired_token' + error_code = "invalid_or_expired_token" description = "The token is invalid or has expired." code = 400 class PasswordResetRateLimitExceededError(BaseHTTPException): - error_code = 'password_reset_rate_limit_exceeded' + error_code = "password_reset_rate_limit_exceeded" description = "Password reset rate limit exceeded. Try again later." code = 429 - diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index d78be770ab..0b01a4906a 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -21,14 +21,13 @@ from services.errors.account import RateLimitExceededError class ForgotPasswordSendEmailApi(Resource): - @setup_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('email', type=str, required=True, location='json') + parser.add_argument("email", type=str, required=True, location="json") args = parser.parse_args() - email = args['email'] + email = args["email"] if not email_validate(email): raise InvalidEmailError() @@ -49,38 +48,36 @@ class ForgotPasswordSendEmailApi(Resource): class ForgotPasswordCheckApi(Resource): - @setup_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('token', type=str, required=True, nullable=False, location='json') + parser.add_argument("token", type=str, required=True, nullable=False, location="json") args = parser.parse_args() - token = args['token'] + token = args["token"] reset_data = AccountService.get_reset_password_data(token) if reset_data is None: - return {'is_valid': False, 'email': None} - return {'is_valid': True, 'email': reset_data.get('email')} + return {"is_valid": False, "email": None} + return {"is_valid": True, "email": reset_data.get("email")} class ForgotPasswordResetApi(Resource): - @setup_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('token', type=str, required=True, nullable=False, location='json') - parser.add_argument('new_password', type=valid_password, required=True, nullable=False, location='json') - parser.add_argument('password_confirm', type=valid_password, required=True, nullable=False, location='json') + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") + parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") args = parser.parse_args() - new_password = args['new_password'] - password_confirm = args['password_confirm'] + new_password = args["new_password"] + password_confirm = args["password_confirm"] if str(new_password).strip() != str(password_confirm).strip(): raise PasswordMismatchError() - token = args['token'] + token = args["token"] reset_data = AccountService.get_reset_password_data(token) if reset_data is None: @@ -94,14 +91,14 @@ class ForgotPasswordResetApi(Resource): password_hashed = hash_password(new_password, salt) base64_password_hashed = base64.b64encode(password_hashed).decode() - account = Account.query.filter_by(email=reset_data.get('email')).first() + account = Account.query.filter_by(email=reset_data.get("email")).first() account.password = base64_password_hashed account.password_salt = base64_salt db.session.commit() - return {'result': 'success'} + return {"result": "success"} -api.add_resource(ForgotPasswordSendEmailApi, '/forgot-password') -api.add_resource(ForgotPasswordCheckApi, '/forgot-password/validity') -api.add_resource(ForgotPasswordResetApi, '/forgot-password/resets') +api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password") +api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity") +api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets") diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index c135ece67e..62837af2b9 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -20,37 +20,39 @@ class LoginApi(Resource): def post(self): """Authenticate user and login.""" parser = reqparse.RequestParser() - parser.add_argument('email', type=email, required=True, location='json') - parser.add_argument('password', type=valid_password, required=True, location='json') - parser.add_argument('remember_me', type=bool, required=False, default=False, location='json') + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("password", type=valid_password, required=True, location="json") + parser.add_argument("remember_me", type=bool, required=False, default=False, location="json") args = parser.parse_args() # todo: Verify the recaptcha try: - account = AccountService.authenticate(args['email'], args['password']) + account = AccountService.authenticate(args["email"], args["password"]) except services.errors.account.AccountLoginError as e: - return {'code': 'unauthorized', 'message': str(e)}, 401 + return {"code": "unauthorized", "message": str(e)}, 401 # SELF_HOSTED only have one workspace tenants = TenantService.get_join_tenants(account) if len(tenants) == 0: - return {'result': 'fail', 'data': 'workspace not found, please contact system admin to invite you to join in a workspace'} + return { + "result": "fail", + "data": "workspace not found, please contact system admin to invite you to join in a workspace", + } token = AccountService.login(account, ip_address=get_remote_ip(request)) - return {'result': 'success', 'data': token} + return {"result": "success", "data": token} class LogoutApi(Resource): - @setup_required def get(self): account = cast(Account, flask_login.current_user) - token = request.headers.get('Authorization', '').split(' ')[1] + token = request.headers.get("Authorization", "").split(" ")[1] AccountService.logout(account=account, token=token) flask_login.logout_user() - return {'result': 'success'} + return {"result": "success"} class ResetPasswordApi(Resource): @@ -80,11 +82,11 @@ class ResetPasswordApi(Resource): # 'subject': 'Reset your Dify password', # 'html': """ #

Dear User,

- #

The Dify team has generated a new password for you, details as follows:

+ #

The Dify team has generated a new password for you, details as follows:

#

{new_password}

#

Please change your password to log in as soon as possible.

#

Regards,

- #

The Dify Team

+ #

The Dify Team

# """ # } @@ -101,8 +103,8 @@ class ResetPasswordApi(Resource): # # handle error # pass - return {'result': 'success'} + return {"result": "success"} -api.add_resource(LoginApi, '/login') -api.add_resource(LogoutApi, '/logout') +api.add_resource(LoginApi, "/login") +api.add_resource(LogoutApi, "/logout") diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 4a651bfe7b..ae1b49f3ec 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -25,7 +25,7 @@ def get_oauth_providers(): github_oauth = GitHubOAuth( client_id=dify_config.GITHUB_CLIENT_ID, client_secret=dify_config.GITHUB_CLIENT_SECRET, - redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/github', + redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/github", ) if not dify_config.GOOGLE_CLIENT_ID or not dify_config.GOOGLE_CLIENT_SECRET: google_oauth = None @@ -33,10 +33,10 @@ def get_oauth_providers(): google_oauth = GoogleOAuth( client_id=dify_config.GOOGLE_CLIENT_ID, client_secret=dify_config.GOOGLE_CLIENT_SECRET, - redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/google', + redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/google", ) - OAUTH_PROVIDERS = {'github': github_oauth, 'google': google_oauth} + OAUTH_PROVIDERS = {"github": github_oauth, "google": google_oauth} return OAUTH_PROVIDERS @@ -47,7 +47,7 @@ class OAuthLogin(Resource): oauth_provider = OAUTH_PROVIDERS.get(provider) print(vars(oauth_provider)) if not oauth_provider: - return {'error': 'Invalid provider'}, 400 + return {"error": "Invalid provider"}, 400 auth_url = oauth_provider.get_authorization_url() return redirect(auth_url) @@ -59,20 +59,20 @@ class OAuthCallback(Resource): with current_app.app_context(): oauth_provider = OAUTH_PROVIDERS.get(provider) if not oauth_provider: - return {'error': 'Invalid provider'}, 400 + return {"error": "Invalid provider"}, 400 - code = request.args.get('code') + code = request.args.get("code") try: token = oauth_provider.get_access_token(code) user_info = oauth_provider.get_user_info(token) except requests.exceptions.HTTPError as e: - logging.exception(f'An error occurred during the OAuth process with {provider}: {e.response.text}') - return {'error': 'OAuth process failed'}, 400 + logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}") + return {"error": "OAuth process failed"}, 400 account = _generate_account(provider, user_info) # Check account status if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: - return {'error': 'Account is banned or closed.'}, 403 + return {"error": "Account is banned or closed."}, 403 if account.status == AccountStatus.PENDING.value: account.status = AccountStatus.ACTIVE.value @@ -83,7 +83,7 @@ class OAuthCallback(Resource): token = AccountService.login(account, ip_address=get_remote_ip(request)) - return redirect(f'{dify_config.CONSOLE_WEB_URL}?console_token={token}') + return redirect(f"{dify_config.CONSOLE_WEB_URL}?console_token={token}") def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: @@ -101,7 +101,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): if not account: # Create account - account_name = user_info.name if user_info.name else 'Dify' + account_name = user_info.name if user_info.name else "Dify" account = RegisterService.register( email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider ) @@ -121,5 +121,5 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): return account -api.add_resource(OAuthLogin, '/oauth/login/') -api.add_resource(OAuthCallback, '/oauth/authorize/') +api.add_resource(OAuthLogin, "/oauth/login/") +api.add_resource(OAuthCallback, "/oauth/authorize/") diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index 72a6129efa..9a1d914869 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -9,28 +9,24 @@ from services.billing_service import BillingService class Subscription(Resource): - @setup_required @login_required @account_initialization_required @only_edition_cloud def get(self): - parser = reqparse.RequestParser() - parser.add_argument('plan', type=str, required=True, location='args', choices=['professional', 'team']) - parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year']) + parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"]) + parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"]) args = parser.parse_args() BillingService.is_tenant_owner_or_admin(current_user) - return BillingService.get_subscription(args['plan'], - args['interval'], - current_user.email, - current_user.current_tenant_id) + return BillingService.get_subscription( + args["plan"], args["interval"], current_user.email, current_user.current_tenant_id + ) class Invoices(Resource): - @setup_required @login_required @account_initialization_required @@ -40,5 +36,5 @@ class Invoices(Resource): return BillingService.get_invoices(current_user.email, current_user.current_tenant_id) -api.add_resource(Subscription, '/billing/subscription') -api.add_resource(Invoices, '/billing/invoices') +api.add_resource(Subscription, "/billing/subscription") +api.add_resource(Invoices, "/billing/invoices") diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 0ca0f0a856..0e1acab946 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -22,19 +22,22 @@ from tasks.document_indexing_sync_task import document_indexing_sync_task class DataSourceApi(Resource): - @setup_required @login_required @account_initialization_required @marshal_with(integrate_list_fields) def get(self): # get workspace data source integrates - data_source_integrates = db.session.query(DataSourceOauthBinding).filter( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.disabled == False - ).all() + data_source_integrates = ( + db.session.query(DataSourceOauthBinding) + .filter( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.disabled == False, + ) + .all() + ) - base_url = request.url_root.rstrip('/') + base_url = request.url_root.rstrip("/") data_source_oauth_base_path = "/console/api/oauth/data-source" providers = ["notion"] @@ -44,26 +47,30 @@ class DataSourceApi(Resource): existing_integrates = filter(lambda item: item.provider == provider, data_source_integrates) if existing_integrates: for existing_integrate in list(existing_integrates): - integrate_data.append({ - 'id': existing_integrate.id, - 'provider': provider, - 'created_at': existing_integrate.created_at, - 'is_bound': True, - 'disabled': existing_integrate.disabled, - 'source_info': existing_integrate.source_info, - 'link': f'{base_url}{data_source_oauth_base_path}/{provider}' - }) + integrate_data.append( + { + "id": existing_integrate.id, + "provider": provider, + "created_at": existing_integrate.created_at, + "is_bound": True, + "disabled": existing_integrate.disabled, + "source_info": existing_integrate.source_info, + "link": f"{base_url}{data_source_oauth_base_path}/{provider}", + } + ) else: - integrate_data.append({ - 'id': None, - 'provider': provider, - 'created_at': None, - 'source_info': None, - 'is_bound': False, - 'disabled': None, - 'link': f'{base_url}{data_source_oauth_base_path}/{provider}' - }) - return {'data': integrate_data}, 200 + integrate_data.append( + { + "id": None, + "provider": provider, + "created_at": None, + "source_info": None, + "is_bound": False, + "disabled": None, + "link": f"{base_url}{data_source_oauth_base_path}/{provider}", + } + ) + return {"data": integrate_data}, 200 @setup_required @login_required @@ -71,92 +78,82 @@ class DataSourceApi(Resource): def patch(self, binding_id, action): binding_id = str(binding_id) action = str(action) - data_source_binding = DataSourceOauthBinding.query.filter_by( - id=binding_id - ).first() + data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first() if data_source_binding is None: - raise NotFound('Data source binding not found.') + raise NotFound("Data source binding not found.") # enable binding - if action == 'enable': + if action == "enable": if data_source_binding.disabled: data_source_binding.disabled = False data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(data_source_binding) db.session.commit() else: - raise ValueError('Data source is not disabled.') + raise ValueError("Data source is not disabled.") # disable binding - if action == 'disable': + if action == "disable": if not data_source_binding.disabled: data_source_binding.disabled = True data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(data_source_binding) db.session.commit() else: - raise ValueError('Data source is disabled.') - return {'result': 'success'}, 200 + raise ValueError("Data source is disabled.") + return {"result": "success"}, 200 class DataSourceNotionListApi(Resource): - @setup_required @login_required @account_initialization_required @marshal_with(integrate_notion_info_list_fields) def get(self): - dataset_id = request.args.get('dataset_id', default=None, type=str) + dataset_id = request.args.get("dataset_id", default=None, type=str) exist_page_ids = [] # import notion in the exist dataset if dataset_id: dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') - if dataset.data_source_type != 'notion_import': - raise ValueError('Dataset is not notion type.') + raise NotFound("Dataset not found.") + if dataset.data_source_type != "notion_import": + raise ValueError("Dataset is not notion type.") documents = Document.query.filter_by( dataset_id=dataset_id, tenant_id=current_user.current_tenant_id, - data_source_type='notion_import', - enabled=True + data_source_type="notion_import", + enabled=True, ).all() if documents: for document in documents: data_source_info = json.loads(document.data_source_info) - exist_page_ids.append(data_source_info['notion_page_id']) + exist_page_ids.append(data_source_info["notion_page_id"]) # get all authorized pages data_source_bindings = DataSourceOauthBinding.query.filter_by( - tenant_id=current_user.current_tenant_id, - provider='notion', - disabled=False + tenant_id=current_user.current_tenant_id, provider="notion", disabled=False ).all() if not data_source_bindings: - return { - 'notion_info': [] - }, 200 + return {"notion_info": []}, 200 pre_import_info_list = [] for data_source_binding in data_source_bindings: source_info = data_source_binding.source_info - pages = source_info['pages'] + pages = source_info["pages"] # Filter out already bound pages for page in pages: - if page['page_id'] in exist_page_ids: - page['is_bound'] = True + if page["page_id"] in exist_page_ids: + page["is_bound"] = True else: - page['is_bound'] = False + page["is_bound"] = False pre_import_info = { - 'workspace_name': source_info['workspace_name'], - 'workspace_icon': source_info['workspace_icon'], - 'workspace_id': source_info['workspace_id'], - 'pages': pages, + "workspace_name": source_info["workspace_name"], + "workspace_icon": source_info["workspace_icon"], + "workspace_id": source_info["workspace_id"], + "pages": pages, } pre_import_info_list.append(pre_import_info) - return { - 'notion_info': pre_import_info_list - }, 200 + return {"notion_info": pre_import_info_list}, 200 class DataSourceNotionApi(Resource): - @setup_required @login_required @account_initialization_required @@ -166,64 +163,67 @@ class DataSourceNotionApi(Resource): data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"' + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', ) ).first() if not data_source_binding: - raise NotFound('Data source binding not found.') + raise NotFound("Data source binding not found.") extractor = NotionExtractor( notion_workspace_id=workspace_id, notion_obj_id=page_id, notion_page_type=page_type, notion_access_token=data_source_binding.access_token, - tenant_id=current_user.current_tenant_id + tenant_id=current_user.current_tenant_id, ) text_docs = extractor.extract() - return { - 'content': "\n".join([doc.page_content for doc in text_docs]) - }, 200 + return {"content": "\n".join([doc.page_content for doc in text_docs])}, 200 @setup_required @login_required @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json') - parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json') + parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json") + parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json") + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) args = parser.parse_args() # validate args DocumentService.estimate_args_validate(args) - notion_info_list = args['notion_info_list'] + notion_info_list = args["notion_info_list"] extract_settings = [] for notion_info in notion_info_list: - workspace_id = notion_info['workspace_id'] - for page in notion_info['pages']: + workspace_id = notion_info["workspace_id"] + for page in notion_info["pages"]: extract_setting = ExtractSetting( datasource_type="notion_import", notion_info={ "notion_workspace_id": workspace_id, - "notion_obj_id": page['page_id'], - "notion_page_type": page['type'], - "tenant_id": current_user.current_tenant_id + "notion_obj_id": page["page_id"], + "notion_page_type": page["type"], + "tenant_id": current_user.current_tenant_id, }, - document_model=args['doc_form'] + document_model=args["doc_form"], ) extract_settings.append(extract_setting) indexing_runner = IndexingRunner() - response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings, - args['process_rule'], args['doc_form'], - args['doc_language']) + response = indexing_runner.indexing_estimate( + current_user.current_tenant_id, + extract_settings, + args["process_rule"], + args["doc_form"], + args["doc_language"], + ) return response, 200 class DataSourceNotionDatasetSyncApi(Resource): - @setup_required @login_required @account_initialization_required @@ -240,7 +240,6 @@ class DataSourceNotionDatasetSyncApi(Resource): class DataSourceNotionDocumentSyncApi(Resource): - @setup_required @login_required @account_initialization_required @@ -258,10 +257,14 @@ class DataSourceNotionDocumentSyncApi(Resource): return 200 -api.add_resource(DataSourceApi, '/data-source/integrates', '/data-source/integrates//') -api.add_resource(DataSourceNotionListApi, '/notion/pre-import/pages') -api.add_resource(DataSourceNotionApi, - '/notion/workspaces//pages///preview', - '/datasets/notion-indexing-estimate') -api.add_resource(DataSourceNotionDatasetSyncApi, '/datasets//notion/sync') -api.add_resource(DataSourceNotionDocumentSyncApi, '/datasets//documents//notion/sync') +api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates//") +api.add_resource(DataSourceNotionListApi, "/notion/pre-import/pages") +api.add_resource( + DataSourceNotionApi, + "/notion/workspaces//pages///preview", + "/datasets/notion-indexing-estimate", +) +api.add_resource(DataSourceNotionDatasetSyncApi, "/datasets//notion/sync") +api.add_resource( + DataSourceNotionDocumentSyncApi, "/datasets//documents//notion/sync" +) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index a5bc2dd86a..44c1390c14 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -24,52 +24,47 @@ from fields.app_fields import related_app_list from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields from fields.document_fields import document_status_fields from libs.login import login_required -from models.dataset import Dataset, Document, DocumentSegment +from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment from models.model import ApiToken, UploadFile from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService def _validate_name(name): if not name or len(name) < 1 or len(name) > 40: - raise ValueError('Name must be between 1 to 40 characters.') + raise ValueError("Name must be between 1 to 40 characters.") return name def _validate_description_length(description): if len(description) > 400: - raise ValueError('Description cannot exceed 400 characters.') + raise ValueError("Description cannot exceed 400 characters.") return description class DatasetListApi(Resource): - @setup_required @login_required @account_initialization_required def get(self): - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) - ids = request.args.getlist('ids') - provider = request.args.get('provider', default="vendor") - search = request.args.get('keyword', default=None, type=str) - tag_ids = request.args.getlist('tag_ids') + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) + ids = request.args.getlist("ids") + provider = request.args.get("provider", default="vendor") + search = request.args.get("keyword", default=None, type=str) + tag_ids = request.args.getlist("tag_ids") if ids: datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id) else: - datasets, total = DatasetService.get_datasets(page, limit, provider, - current_user.current_tenant_id, current_user, search, tag_ids) + datasets, total = DatasetService.get_datasets( + page, limit, provider, current_user.current_tenant_id, current_user, search, tag_ids + ) # check embedding setting provider_manager = ProviderManager() - configurations = provider_manager.get_configurations( - tenant_id=current_user.current_tenant_id - ) + configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) - embedding_models = configurations.get_models( - model_type=ModelType.TEXT_EMBEDDING, - only_active=True - ) + embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) model_names = [] for embedding_model in embedding_models: @@ -77,28 +72,22 @@ class DatasetListApi(Resource): data = marshal(datasets, dataset_detail_fields) for item in data: - if item['indexing_technique'] == 'high_quality': + if item["indexing_technique"] == "high_quality": item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" if item_model in model_names: - item['embedding_available'] = True + item["embedding_available"] = True else: - item['embedding_available'] = False + item["embedding_available"] = False else: - item['embedding_available'] = True + item["embedding_available"] = True - if item.get('permission') == 'partial_members': - part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item['id']) - item.update({'partial_member_list': part_users_list}) + if item.get("permission") == "partial_members": + part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item["id"]) + item.update({"partial_member_list": part_users_list}) else: - item.update({'partial_member_list': []}) + item.update({"partial_member_list": []}) - response = { - 'data': data, - 'has_more': len(datasets) == limit, - 'limit': limit, - 'total': total, - 'page': page - } + response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} return response, 200 @setup_required @@ -106,13 +95,21 @@ class DatasetListApi(Resource): @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('name', nullable=False, required=True, - help='type is required. Name must be between 1 to 40 characters.', - type=_validate_name) - parser.add_argument('indexing_technique', type=str, location='json', - choices=Dataset.INDEXING_TECHNIQUE_LIST, - nullable=True, - help='Invalid indexing technique.') + parser.add_argument( + "name", + nullable=False, + required=True, + help="type is required. Name must be between 1 to 40 characters.", + type=_validate_name, + ) + parser.add_argument( + "indexing_technique", + type=str, + location="json", + choices=Dataset.INDEXING_TECHNIQUE_LIST, + nullable=True, + help="Invalid indexing technique.", + ) args = parser.parse_args() # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator @@ -122,9 +119,10 @@ class DatasetListApi(Resource): try: dataset = DatasetService.create_empty_dataset( tenant_id=current_user.current_tenant_id, - name=args['name'], - indexing_technique=args['indexing_technique'], - account=current_user + name=args["name"], + indexing_technique=args["indexing_technique"], + account=current_user, + permission=DatasetPermissionEnum.ONLY_ME, ) except services.errors.dataset.DatasetNameDuplicateError: raise DatasetNameDuplicateError() @@ -142,42 +140,36 @@ class DatasetApi(Resource): if dataset is None: raise NotFound("Dataset not found.") try: - DatasetService.check_dataset_permission( - dataset, current_user) + DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) data = marshal(dataset, dataset_detail_fields) - if data.get('permission') == 'partial_members': + if data.get("permission") == "partial_members": part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) - data.update({'partial_member_list': part_users_list}) + data.update({"partial_member_list": part_users_list}) # check embedding setting provider_manager = ProviderManager() - configurations = provider_manager.get_configurations( - tenant_id=current_user.current_tenant_id - ) + configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) - embedding_models = configurations.get_models( - model_type=ModelType.TEXT_EMBEDDING, - only_active=True - ) + embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) model_names = [] for embedding_model in embedding_models: model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") - if data['indexing_technique'] == 'high_quality': + if data["indexing_technique"] == "high_quality": item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}" if item_model in model_names: - data['embedding_available'] = True + data["embedding_available"] = True else: - data['embedding_available'] = False + data["embedding_available"] = False else: - data['embedding_available'] = True + data["embedding_available"] = True - if data.get('permission') == 'partial_members': + if data.get("permission") == "partial_members": part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) - data.update({'partial_member_list': part_users_list}) + data.update({"partial_member_list": part_users_list}) return data, 200 @@ -191,42 +183,49 @@ class DatasetApi(Resource): raise NotFound("Dataset not found.") parser = reqparse.RequestParser() - parser.add_argument('name', nullable=False, - help='type is required. Name must be between 1 to 40 characters.', - type=_validate_name) - parser.add_argument('description', - location='json', store_missing=False, - type=_validate_description_length) - parser.add_argument('indexing_technique', type=str, location='json', - choices=Dataset.INDEXING_TECHNIQUE_LIST, - nullable=True, - help='Invalid indexing technique.') - parser.add_argument('permission', type=str, location='json', choices=( - 'only_me', 'all_team_members', 'partial_members'), help='Invalid permission.' - ) - parser.add_argument('embedding_model', type=str, - location='json', help='Invalid embedding model.') - parser.add_argument('embedding_model_provider', type=str, - location='json', help='Invalid embedding model provider.') - parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.') - parser.add_argument('partial_member_list', type=list, location='json', help='Invalid parent user list.') + parser.add_argument( + "name", + nullable=False, + help="type is required. Name must be between 1 to 40 characters.", + type=_validate_name, + ) + parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length) + parser.add_argument( + "indexing_technique", + type=str, + location="json", + choices=Dataset.INDEXING_TECHNIQUE_LIST, + nullable=True, + help="Invalid indexing technique.", + ) + parser.add_argument( + "permission", + type=str, + location="json", + choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), + help="Invalid permission.", + ) + parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.") + parser.add_argument( + "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider." + ) + parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.") + parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.") args = parser.parse_args() data = request.get_json() # check embedding model setting - if data.get('indexing_technique') == 'high_quality': - DatasetService.check_embedding_model_setting(dataset.tenant_id, - data.get('embedding_model_provider'), - data.get('embedding_model') - ) + if data.get("indexing_technique") == "high_quality": + DatasetService.check_embedding_model_setting( + dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model") + ) # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator DatasetPermissionService.check_permission( - current_user, dataset, data.get('permission'), data.get('partial_member_list') + current_user, dataset, data.get("permission"), data.get("partial_member_list") ) - dataset = DatasetService.update_dataset( - dataset_id_str, args, current_user) + dataset = DatasetService.update_dataset(dataset_id_str, args, current_user) if dataset is None: raise NotFound("Dataset not found.") @@ -234,16 +233,19 @@ class DatasetApi(Resource): result_data = marshal(dataset, dataset_detail_fields) tenant_id = current_user.current_tenant_id - if data.get('partial_member_list') and data.get('permission') == 'partial_members': + if data.get("partial_member_list") and data.get("permission") == "partial_members": DatasetPermissionService.update_partial_member_list( - tenant_id, dataset_id_str, data.get('partial_member_list') + tenant_id, dataset_id_str, data.get("partial_member_list") ) # clear partial member list when permission is only_me or all_team_members - elif data.get('permission') == 'only_me' or data.get('permission') == 'all_team_members': + elif ( + data.get("permission") == DatasetPermissionEnum.ONLY_ME + or data.get("permission") == DatasetPermissionEnum.ALL_TEAM + ): DatasetPermissionService.clear_partial_member_list(dataset_id_str) partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) - result_data.update({'partial_member_list': partial_member_list}) + result_data.update({"partial_member_list": partial_member_list}) return result_data, 200 @@ -260,12 +262,13 @@ class DatasetApi(Resource): try: if DatasetService.delete_dataset(dataset_id_str, current_user): DatasetPermissionService.clear_partial_member_list(dataset_id_str) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 else: raise NotFound("Dataset not found.") except services.errors.dataset.DatasetInUseError: raise DatasetInUseError() + class DatasetUseCheckApi(Resource): @setup_required @login_required @@ -274,10 +277,10 @@ class DatasetUseCheckApi(Resource): dataset_id_str = str(dataset_id) dataset_is_using = DatasetService.dataset_use_check(dataset_id_str) - return {'is_using': dataset_is_using}, 200 + return {"is_using": dataset_is_using}, 200 + class DatasetQueryApi(Resource): - @setup_required @login_required @account_initialization_required @@ -292,51 +295,53 @@ class DatasetQueryApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) - dataset_queries, total = DatasetService.get_dataset_queries( - dataset_id=dataset.id, - page=page, - per_page=limit - ) + dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit) response = { - 'data': marshal(dataset_queries, dataset_query_detail_fields), - 'has_more': len(dataset_queries) == limit, - 'limit': limit, - 'total': total, - 'page': page + "data": marshal(dataset_queries, dataset_query_detail_fields), + "has_more": len(dataset_queries) == limit, + "limit": limit, + "total": total, + "page": page, } return response, 200 class DatasetIndexingEstimateApi(Resource): - @setup_required @login_required @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json') - parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') - parser.add_argument('indexing_technique', type=str, required=True, - choices=Dataset.INDEXING_TECHNIQUE_LIST, - nullable=True, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, - location='json') + parser.add_argument("info_list", type=dict, required=True, nullable=True, location="json") + parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json") + parser.add_argument( + "indexing_technique", + type=str, + required=True, + choices=Dataset.INDEXING_TECHNIQUE_LIST, + nullable=True, + location="json", + ) + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument("dataset_id", type=str, required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) args = parser.parse_args() # validate args DocumentService.estimate_args_validate(args) extract_settings = [] - if args['info_list']['data_source_type'] == 'upload_file': - file_ids = args['info_list']['file_info_list']['file_ids'] - file_details = db.session.query(UploadFile).filter( - UploadFile.tenant_id == current_user.current_tenant_id, - UploadFile.id.in_(file_ids) - ).all() + if args["info_list"]["data_source_type"] == "upload_file": + file_ids = args["info_list"]["file_info_list"]["file_ids"] + file_details = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)) + .all() + ) if file_details is None: raise NotFound("File not found.") @@ -344,55 +349,58 @@ class DatasetIndexingEstimateApi(Resource): if file_details: for file_detail in file_details: extract_setting = ExtractSetting( - datasource_type="upload_file", - upload_file=file_detail, - document_model=args['doc_form'] + datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"] ) extract_settings.append(extract_setting) - elif args['info_list']['data_source_type'] == 'notion_import': - notion_info_list = args['info_list']['notion_info_list'] + elif args["info_list"]["data_source_type"] == "notion_import": + notion_info_list = args["info_list"]["notion_info_list"] for notion_info in notion_info_list: - workspace_id = notion_info['workspace_id'] - for page in notion_info['pages']: + workspace_id = notion_info["workspace_id"] + for page in notion_info["pages"]: extract_setting = ExtractSetting( datasource_type="notion_import", notion_info={ "notion_workspace_id": workspace_id, - "notion_obj_id": page['page_id'], - "notion_page_type": page['type'], - "tenant_id": current_user.current_tenant_id + "notion_obj_id": page["page_id"], + "notion_page_type": page["type"], + "tenant_id": current_user.current_tenant_id, }, - document_model=args['doc_form'] + document_model=args["doc_form"], ) extract_settings.append(extract_setting) - elif args['info_list']['data_source_type'] == 'website_crawl': - website_info_list = args['info_list']['website_info_list'] - for url in website_info_list['urls']: + elif args["info_list"]["data_source_type"] == "website_crawl": + website_info_list = args["info_list"]["website_info_list"] + for url in website_info_list["urls"]: extract_setting = ExtractSetting( datasource_type="website_crawl", website_info={ - "provider": website_info_list['provider'], - "job_id": website_info_list['job_id'], + "provider": website_info_list["provider"], + "job_id": website_info_list["job_id"], "url": url, "tenant_id": current_user.current_tenant_id, - "mode": 'crawl', - "only_main_content": website_info_list['only_main_content'] + "mode": "crawl", + "only_main_content": website_info_list["only_main_content"], }, - document_model=args['doc_form'] + document_model=args["doc_form"], ) extract_settings.append(extract_setting) else: - raise ValueError('Data source type not support') + raise ValueError("Data source type not support") indexing_runner = IndexingRunner() try: - response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings, - args['process_rule'], args['doc_form'], - args['doc_language'], args['dataset_id'], - args['indexing_technique']) + response = indexing_runner.indexing_estimate( + current_user.current_tenant_id, + extract_settings, + args["process_rule"], + args["doc_form"], + args["doc_language"], + args["dataset_id"], + args["indexing_technique"], + ) except LLMBadRequestError: raise ProviderNotInitializeError( - "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except Exception as e: @@ -402,7 +410,6 @@ class DatasetIndexingEstimateApi(Resource): class DatasetRelatedAppListApi(Resource): - @setup_required @login_required @account_initialization_required @@ -426,52 +433,52 @@ class DatasetRelatedAppListApi(Resource): if app_model: related_apps.append(app_model) - return { - 'data': related_apps, - 'total': len(related_apps) - }, 200 + return {"data": related_apps, "total": len(related_apps)}, 200 class DatasetIndexingStatusApi(Resource): - @setup_required @login_required @account_initialization_required def get(self, dataset_id): dataset_id = str(dataset_id) - documents = db.session.query(Document).filter( - Document.dataset_id == dataset_id, - Document.tenant_id == current_user.current_tenant_id - ).all() + documents = ( + db.session.query(Document) + .filter(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id) + .all() + ) documents_status = [] for document in documents: - completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() - total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() + completed_segments = DocumentSegment.query.filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ).count() + total_segments = DocumentSegment.query.filter( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" + ).count() document.completed_segments = completed_segments document.total_segments = total_segments documents_status.append(marshal(document, document_status_fields)) - data = { - 'data': documents_status - } + data = {"data": documents_status} return data class DatasetApiKeyApi(Resource): max_keys = 10 - token_prefix = 'dataset-' - resource_type = 'dataset' + token_prefix = "dataset-" + resource_type = "dataset" @setup_required @login_required @account_initialization_required @marshal_with(api_key_list) def get(self): - keys = db.session.query(ApiToken). \ - filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \ - all() + keys = ( + db.session.query(ApiToken) + .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) + .all() + ) return {"items": keys} @setup_required @@ -483,15 +490,17 @@ class DatasetApiKeyApi(Resource): if not current_user.is_admin_or_owner: raise Forbidden() - current_key_count = db.session.query(ApiToken). \ - filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \ - count() + current_key_count = ( + db.session.query(ApiToken) + .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) + .count() + ) if current_key_count >= self.max_keys: flask_restful.abort( 400, message=f"Cannot create more than {self.max_keys} API keys for this resource type.", - code='max_keys_exceeded' + code="max_keys_exceeded", ) key = ApiToken.generate_api_key(self.token_prefix, 24) @@ -505,7 +514,7 @@ class DatasetApiKeyApi(Resource): class DatasetApiDeleteApi(Resource): - resource_type = 'dataset' + resource_type = "dataset" @setup_required @login_required @@ -517,18 +526,23 @@ class DatasetApiDeleteApi(Resource): if not current_user.is_admin_or_owner: raise Forbidden() - key = db.session.query(ApiToken). \ - filter(ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.type == self.resource_type, - ApiToken.id == api_key_id). \ - first() + key = ( + db.session.query(ApiToken) + .filter( + ApiToken.tenant_id == current_user.current_tenant_id, + ApiToken.type == self.resource_type, + ApiToken.id == api_key_id, + ) + .first() + ) if key is None: - flask_restful.abort(404, message='API key not found') + flask_restful.abort(404, message="API key not found") db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete() db.session.commit() - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class DatasetApiBaseUrlApi(Resource): @@ -537,8 +551,10 @@ class DatasetApiBaseUrlApi(Resource): @account_initialization_required def get(self): return { - 'api_base_url': (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL - else request.host_url.rstrip('/')) + '/v1' + "api_base_url": ( + dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/") + ) + + "/v1" } @@ -549,15 +565,26 @@ class DatasetRetrievalSettingApi(Resource): def get(self): vector_type = dify_config.VECTOR_STORE match vector_type: - case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT: + case ( + VectorType.MILVUS + | VectorType.RELYT + | VectorType.PGVECTOR + | VectorType.TIDB_VECTOR + | VectorType.CHROMA + | VectorType.TENCENT + ): + return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} + case ( + VectorType.QDRANT + | VectorType.WEAVIATE + | VectorType.OPENSEARCH + | VectorType.ANALYTICDB + | VectorType.MYSCALE + | VectorType.ORACLE + | VectorType.ELASTICSEARCH + ): return { - 'retrieval_method': [ - RetrievalMethod.SEMANTIC_SEARCH.value - ] - } - case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH: - return { - 'retrieval_method': [ + "retrieval_method": [ RetrievalMethod.SEMANTIC_SEARCH.value, RetrievalMethod.FULL_TEXT_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value, @@ -573,15 +600,27 @@ class DatasetRetrievalSettingMockApi(Resource): @account_initialization_required def get(self, vector_type): match vector_type: - case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT: + case ( + VectorType.MILVUS + | VectorType.RELYT + | VectorType.TIDB_VECTOR + | VectorType.CHROMA + | VectorType.TENCENT + | VectorType.PGVECTO_RS + ): + return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} + case ( + VectorType.QDRANT + | VectorType.WEAVIATE + | VectorType.OPENSEARCH + | VectorType.ANALYTICDB + | VectorType.MYSCALE + | VectorType.ORACLE + | VectorType.ELASTICSEARCH + | VectorType.PGVECTOR + ): return { - 'retrieval_method': [ - RetrievalMethod.SEMANTIC_SEARCH.value - ] - } - case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH: - return { - 'retrieval_method': [ + "retrieval_method": [ RetrievalMethod.SEMANTIC_SEARCH.value, RetrievalMethod.FULL_TEXT_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value, @@ -591,7 +630,6 @@ class DatasetRetrievalSettingMockApi(Resource): raise ValueError(f"Unsupported vector db type {vector_type}.") - class DatasetErrorDocs(Resource): @setup_required @login_required @@ -603,10 +641,7 @@ class DatasetErrorDocs(Resource): raise NotFound("Dataset not found.") results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str) - return { - 'data': [marshal(item, document_status_fields) for item in results], - 'total': len(results) - }, 200 + return {"data": [marshal(item, document_status_fields) for item in results], "total": len(results)}, 200 class DatasetPermissionUserListApi(Resource): @@ -626,21 +661,21 @@ class DatasetPermissionUserListApi(Resource): partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) return { - 'data': partial_members_list, + "data": partial_members_list, }, 200 -api.add_resource(DatasetListApi, '/datasets') -api.add_resource(DatasetApi, '/datasets/') -api.add_resource(DatasetUseCheckApi, '/datasets//use-check') -api.add_resource(DatasetQueryApi, '/datasets//queries') -api.add_resource(DatasetErrorDocs, '/datasets//error-docs') -api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate') -api.add_resource(DatasetRelatedAppListApi, '/datasets//related-apps') -api.add_resource(DatasetIndexingStatusApi, '/datasets//indexing-status') -api.add_resource(DatasetApiKeyApi, '/datasets/api-keys') -api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/') -api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info') -api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting') -api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/') -api.add_resource(DatasetPermissionUserListApi, '/datasets//permission-part-users') +api.add_resource(DatasetListApi, "/datasets") +api.add_resource(DatasetApi, "/datasets/") +api.add_resource(DatasetUseCheckApi, "/datasets//use-check") +api.add_resource(DatasetQueryApi, "/datasets//queries") +api.add_resource(DatasetErrorDocs, "/datasets//error-docs") +api.add_resource(DatasetIndexingEstimateApi, "/datasets/indexing-estimate") +api.add_resource(DatasetRelatedAppListApi, "/datasets//related-apps") +api.add_resource(DatasetIndexingStatusApi, "/datasets//indexing-status") +api.add_resource(DatasetApiKeyApi, "/datasets/api-keys") +api.add_resource(DatasetApiDeleteApi, "/datasets/api-keys/") +api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info") +api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting") +api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/") +api.add_resource(DatasetPermissionUserListApi, "/datasets//permission-part-users") diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 976b97660a..6bc29a8643 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -57,7 +57,7 @@ class DocumentResource(Resource): def get_document(self, dataset_id: str, document_id: str) -> Document: dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") try: DatasetService.check_dataset_permission(dataset, current_user) @@ -67,17 +67,17 @@ class DocumentResource(Resource): document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") if document.tenant_id != current_user.current_tenant_id: - raise Forbidden('No permission.') + raise Forbidden("No permission.") return document def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]: dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") try: DatasetService.check_dataset_permission(dataset, current_user) @@ -87,7 +87,7 @@ class DocumentResource(Resource): documents = DocumentService.get_batch_documents(dataset_id, batch) if not documents: - raise NotFound('Documents not found.') + raise NotFound("Documents not found.") return documents @@ -99,11 +99,11 @@ class GetProcessRuleApi(Resource): def get(self): req_data = request.args - document_id = req_data.get('document_id') + document_id = req_data.get("document_id") # get default rules - mode = DocumentService.DEFAULT_RULES['mode'] - rules = DocumentService.DEFAULT_RULES['rules'] + mode = DocumentService.DEFAULT_RULES["mode"] + rules = DocumentService.DEFAULT_RULES["rules"] if document_id: # get the latest process rule document = Document.query.get_or_404(document_id) @@ -111,7 +111,7 @@ class GetProcessRuleApi(Resource): dataset = DatasetService.get_dataset(document.dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") try: DatasetService.check_dataset_permission(dataset, current_user) @@ -119,19 +119,18 @@ class GetProcessRuleApi(Resource): raise Forbidden(str(e)) # get the latest process rule - dataset_process_rule = db.session.query(DatasetProcessRule). \ - filter(DatasetProcessRule.dataset_id == document.dataset_id). \ - order_by(DatasetProcessRule.created_at.desc()). \ - limit(1). \ - one_or_none() + dataset_process_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.dataset_id == document.dataset_id) + .order_by(DatasetProcessRule.created_at.desc()) + .limit(1) + .one_or_none() + ) if dataset_process_rule: mode = dataset_process_rule.mode rules = dataset_process_rule.rules_dict - return { - 'mode': mode, - 'rules': rules - } + return {"mode": mode, "rules": rules} class DatasetDocumentListApi(Resource): @@ -140,49 +139,48 @@ class DatasetDocumentListApi(Resource): @account_initialization_required def get(self, dataset_id): dataset_id = str(dataset_id) - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) - search = request.args.get('keyword', default=None, type=str) - sort = request.args.get('sort', default='-created_at', type=str) + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) + search = request.args.get("keyword", default=None, type=str) + sort = request.args.get("sort", default="-created_at", type=str) # "yes", "true", "t", "y", "1" convert to True, while others convert to False. try: - fetch = string_to_bool(request.args.get('fetch', default='false')) + fetch = string_to_bool(request.args.get("fetch", default="false")) except (ArgumentTypeError, ValueError, Exception) as e: fetch = False dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") try: DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - query = Document.query.filter_by( - dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) + query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) if search: - search = f'%{search}%' + search = f"%{search}%" query = query.filter(Document.name.like(search)) - if sort.startswith('-'): + if sort.startswith("-"): sort_logic = desc sort = sort[1:] else: sort_logic = asc - if sort == 'hit_count': - sub_query = db.select(DocumentSegment.document_id, - db.func.sum(DocumentSegment.hit_count).label("total_hit_count")) \ - .group_by(DocumentSegment.document_id) \ + if sort == "hit_count": + sub_query = ( + db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count")) + .group_by(DocumentSegment.document_id) .subquery() + ) - query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id) \ - .order_by( - sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)), - sort_logic(Document.position), - ) - elif sort == 'created_at': + query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by( + sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)), + sort_logic(Document.position), + ) + elif sort == "created_at": query = query.order_by( sort_logic(Document.created_at), sort_logic(Document.position), @@ -193,48 +191,47 @@ class DatasetDocumentListApi(Resource): desc(Document.position), ) - paginated_documents = query.paginate( - page=page, per_page=limit, max_per_page=100, error_out=False) + paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) documents = paginated_documents.items if fetch: for document in documents: - completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() - total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() + completed_segments = DocumentSegment.query.filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ).count() + total_segments = DocumentSegment.query.filter( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" + ).count() document.completed_segments = completed_segments document.total_segments = total_segments data = marshal(documents, document_with_segments_fields) else: data = marshal(documents, document_fields) response = { - 'data': data, - 'has_more': len(documents) == limit, - 'limit': limit, - 'total': paginated_documents.total, - 'page': page + "data": data, + "has_more": len(documents) == limit, + "limit": limit, + "total": paginated_documents.total, + "page": page, } return response - documents_and_batch_fields = { - 'documents': fields.List(fields.Nested(document_fields)), - 'batch': fields.String - } + documents_and_batch_fields = {"documents": fields.List(fields.Nested(document_fields)), "batch": fields.String} @setup_required @login_required @account_initialization_required @marshal_with(documents_and_batch_fields) - @cloud_edition_billing_resource_check('vector_space') + @cloud_edition_billing_resource_check("vector_space") def post(self, dataset_id): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_dataset_editor: @@ -246,21 +243,22 @@ class DatasetDocumentListApi(Resource): raise Forbidden(str(e)) parser = reqparse.RequestParser() - parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, - location='json') - parser.add_argument('data_source', type=dict, required=False, location='json') - parser.add_argument('process_rule', type=dict, required=False, location='json') - parser.add_argument('duplicate', type=bool, default=True, nullable=False, location='json') - parser.add_argument('original_document_id', type=str, required=False, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, - location='json') - parser.add_argument('retrieval_model', type=dict, required=False, nullable=False, - location='json') + parser.add_argument( + "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" + ) + parser.add_argument("data_source", type=dict, required=False, location="json") + parser.add_argument("process_rule", type=dict, required=False, location="json") + parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json") + parser.add_argument("original_document_id", type=str, required=False, location="json") + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) + parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") args = parser.parse_args() - if not dataset.indexing_technique and not args['indexing_technique']: - raise ValueError('indexing_technique is required.') + if not dataset.indexing_technique and not args["indexing_technique"]: + raise ValueError("indexing_technique is required.") # validate args DocumentService.document_create_args_validate(args) @@ -274,51 +272,53 @@ class DatasetDocumentListApi(Resource): except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() - return { - 'documents': documents, - 'batch': batch - } + return {"documents": documents, "batch": batch} class DatasetInitApi(Resource): - @setup_required @login_required @account_initialization_required @marshal_with(dataset_and_document_fields) - @cloud_edition_billing_resource_check('vector_space') + @cloud_edition_billing_resource_check("vector_space") def post(self): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, required=True, - nullable=False, location='json') - parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json') - parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, - location='json') - parser.add_argument('retrieval_model', type=dict, required=False, nullable=False, - location='json') + parser.add_argument( + "indexing_technique", + type=str, + choices=Dataset.INDEXING_TECHNIQUE_LIST, + required=True, + nullable=False, + location="json", + ) + parser.add_argument("data_source", type=dict, required=True, nullable=True, location="json") + parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json") + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) + parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") args = parser.parse_args() # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator if not current_user.is_dataset_editor: raise Forbidden() - if args['indexing_technique'] == 'high_quality': + if args["indexing_technique"] == "high_quality": try: model_manager = ModelManager() model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.TEXT_EMBEDDING + tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING ) except InvokeAuthorizationError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -327,9 +327,7 @@ class DatasetInitApi(Resource): try: dataset, documents, batch = DocumentService.save_document_without_dataset_id( - tenant_id=current_user.current_tenant_id, - document_data=args, - account=current_user + tenant_id=current_user.current_tenant_id, document_data=args, account=current_user ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -338,17 +336,12 @@ class DatasetInitApi(Resource): except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() - response = { - 'dataset': dataset, - 'documents': documents, - 'batch': batch - } + response = {"dataset": dataset, "documents": documents, "batch": batch} return response class DocumentIndexingEstimateApi(DocumentResource): - @setup_required @login_required @account_initialization_required @@ -357,50 +350,49 @@ class DocumentIndexingEstimateApi(DocumentResource): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - if document.indexing_status in ['completed', 'error']: + if document.indexing_status in ["completed", "error"]: raise DocumentAlreadyFinishedError() data_process_rule = document.dataset_process_rule data_process_rule_dict = data_process_rule.to_dict() - response = { - "tokens": 0, - "total_price": 0, - "currency": "USD", - "total_segments": 0, - "preview": [] - } + response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []} - if document.data_source_type == 'upload_file': + if document.data_source_type == "upload_file": data_source_info = document.data_source_info_dict - if data_source_info and 'upload_file_id' in data_source_info: - file_id = data_source_info['upload_file_id'] + if data_source_info and "upload_file_id" in data_source_info: + file_id = data_source_info["upload_file_id"] - file = db.session.query(UploadFile).filter( - UploadFile.tenant_id == document.tenant_id, - UploadFile.id == file_id - ).first() + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) + .first() + ) # raise error if file not found if not file: - raise NotFound('File not found.') + raise NotFound("File not found.") extract_setting = ExtractSetting( - datasource_type="upload_file", - upload_file=file, - document_model=document.doc_form + datasource_type="upload_file", upload_file=file, document_model=document.doc_form ) indexing_runner = IndexingRunner() try: - response = indexing_runner.indexing_estimate(current_user.current_tenant_id, [extract_setting], - data_process_rule_dict, document.doc_form, - 'English', dataset_id) + response = indexing_runner.indexing_estimate( + current_user.current_tenant_id, + [extract_setting], + data_process_rule_dict, + document.doc_form, + "English", + dataset_id, + ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except Exception as e: @@ -410,7 +402,6 @@ class DocumentIndexingEstimateApi(DocumentResource): class DocumentBatchIndexingEstimateApi(DocumentResource): - @setup_required @login_required @account_initialization_required @@ -418,13 +409,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): dataset_id = str(dataset_id) batch = str(batch) documents = self.get_batch_documents(dataset_id, batch) - response = { - "tokens": 0, - "total_price": 0, - "currency": "USD", - "total_segments": 0, - "preview": [] - } + response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []} if not documents: return response data_process_rule = documents[0].dataset_process_rule @@ -432,82 +417,83 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): info_list = [] extract_settings = [] for document in documents: - if document.indexing_status in ['completed', 'error']: + if document.indexing_status in ["completed", "error"]: raise DocumentAlreadyFinishedError() data_source_info = document.data_source_info_dict # format document files info - if data_source_info and 'upload_file_id' in data_source_info: - file_id = data_source_info['upload_file_id'] + if data_source_info and "upload_file_id" in data_source_info: + file_id = data_source_info["upload_file_id"] info_list.append(file_id) # format document notion info - elif data_source_info and 'notion_workspace_id' in data_source_info and 'notion_page_id' in data_source_info: + elif ( + data_source_info and "notion_workspace_id" in data_source_info and "notion_page_id" in data_source_info + ): pages = [] - page = { - 'page_id': data_source_info['notion_page_id'], - 'type': data_source_info['type'] - } + page = {"page_id": data_source_info["notion_page_id"], "type": data_source_info["type"]} pages.append(page) - notion_info = { - 'workspace_id': data_source_info['notion_workspace_id'], - 'pages': pages - } + notion_info = {"workspace_id": data_source_info["notion_workspace_id"], "pages": pages} info_list.append(notion_info) - if document.data_source_type == 'upload_file': - file_id = data_source_info['upload_file_id'] - file_detail = db.session.query(UploadFile).filter( - UploadFile.tenant_id == current_user.current_tenant_id, - UploadFile.id == file_id - ).first() + if document.data_source_type == "upload_file": + file_id = data_source_info["upload_file_id"] + file_detail = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id) + .first() + ) if file_detail is None: raise NotFound("File not found.") extract_setting = ExtractSetting( - datasource_type="upload_file", - upload_file=file_detail, - document_model=document.doc_form + datasource_type="upload_file", upload_file=file_detail, document_model=document.doc_form ) extract_settings.append(extract_setting) - elif document.data_source_type == 'notion_import': + elif document.data_source_type == "notion_import": extract_setting = ExtractSetting( datasource_type="notion_import", notion_info={ - "notion_workspace_id": data_source_info['notion_workspace_id'], - "notion_obj_id": data_source_info['notion_page_id'], - "notion_page_type": data_source_info['type'], - "tenant_id": current_user.current_tenant_id + "notion_workspace_id": data_source_info["notion_workspace_id"], + "notion_obj_id": data_source_info["notion_page_id"], + "notion_page_type": data_source_info["type"], + "tenant_id": current_user.current_tenant_id, }, - document_model=document.doc_form + document_model=document.doc_form, ) extract_settings.append(extract_setting) - elif document.data_source_type == 'website_crawl': + elif document.data_source_type == "website_crawl": extract_setting = ExtractSetting( datasource_type="website_crawl", website_info={ - "provider": data_source_info['provider'], - "job_id": data_source_info['job_id'], - "url": data_source_info['url'], + "provider": data_source_info["provider"], + "job_id": data_source_info["job_id"], + "url": data_source_info["url"], "tenant_id": current_user.current_tenant_id, - "mode": data_source_info['mode'], - "only_main_content": data_source_info['only_main_content'] + "mode": data_source_info["mode"], + "only_main_content": data_source_info["only_main_content"], }, - document_model=document.doc_form + document_model=document.doc_form, ) extract_settings.append(extract_setting) else: - raise ValueError('Data source type not support') + raise ValueError("Data source type not support") indexing_runner = IndexingRunner() try: - response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings, - data_process_rule_dict, document.doc_form, - 'English', dataset_id) + response = indexing_runner.indexing_estimate( + current_user.current_tenant_id, + extract_settings, + data_process_rule_dict, + document.doc_form, + "English", + dataset_id, + ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except Exception as e: @@ -516,7 +502,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): class DocumentBatchIndexingStatusApi(DocumentResource): - @setup_required @login_required @account_initialization_required @@ -526,24 +511,24 @@ class DocumentBatchIndexingStatusApi(DocumentResource): documents = self.get_batch_documents(dataset_id, batch) documents_status = [] for document in documents: - completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() - total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() + completed_segments = DocumentSegment.query.filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ).count() + total_segments = DocumentSegment.query.filter( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" + ).count() document.completed_segments = completed_segments document.total_segments = total_segments if document.is_paused: - document.indexing_status = 'paused' + document.indexing_status = "paused" documents_status.append(marshal(document, document_status_fields)) - data = { - 'data': documents_status - } + data = {"data": documents_status} return data class DocumentIndexingStatusApi(DocumentResource): - @setup_required @login_required @account_initialization_required @@ -552,25 +537,24 @@ class DocumentIndexingStatusApi(DocumentResource): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - completed_segments = DocumentSegment.query \ - .filter(DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document_id), - DocumentSegment.status != 're_segment') \ - .count() - total_segments = DocumentSegment.query \ - .filter(DocumentSegment.document_id == str(document_id), - DocumentSegment.status != 're_segment') \ - .count() + completed_segments = DocumentSegment.query.filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document_id), + DocumentSegment.status != "re_segment", + ).count() + total_segments = DocumentSegment.query.filter( + DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment" + ).count() document.completed_segments = completed_segments document.total_segments = total_segments if document.is_paused: - document.indexing_status = 'paused' + document.indexing_status = "paused" return marshal(document, document_status_fields) class DocumentDetailApi(DocumentResource): - METADATA_CHOICES = {'all', 'only', 'without'} + METADATA_CHOICES = {"all", "only", "without"} @setup_required @login_required @@ -580,77 +564,75 @@ class DocumentDetailApi(DocumentResource): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - metadata = request.args.get('metadata', 'all') + metadata = request.args.get("metadata", "all") if metadata not in self.METADATA_CHOICES: - raise InvalidMetadataError(f'Invalid metadata value: {metadata}') + raise InvalidMetadataError(f"Invalid metadata value: {metadata}") - if metadata == 'only': - response = { - 'id': document.id, - 'doc_type': document.doc_type, - 'doc_metadata': document.doc_metadata - } - elif metadata == 'without': + if metadata == "only": + response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata} + elif metadata == "without": process_rules = DatasetService.get_process_rules(dataset_id) data_source_info = document.data_source_detail_dict response = { - 'id': document.id, - 'position': document.position, - 'data_source_type': document.data_source_type, - 'data_source_info': data_source_info, - 'dataset_process_rule_id': document.dataset_process_rule_id, - 'dataset_process_rule': process_rules, - 'name': document.name, - 'created_from': document.created_from, - 'created_by': document.created_by, - 'created_at': document.created_at.timestamp(), - 'tokens': document.tokens, - 'indexing_status': document.indexing_status, - 'completed_at': int(document.completed_at.timestamp()) if document.completed_at else None, - 'updated_at': int(document.updated_at.timestamp()) if document.updated_at else None, - 'indexing_latency': document.indexing_latency, - 'error': document.error, - 'enabled': document.enabled, - 'disabled_at': int(document.disabled_at.timestamp()) if document.disabled_at else None, - 'disabled_by': document.disabled_by, - 'archived': document.archived, - 'segment_count': document.segment_count, - 'average_segment_length': document.average_segment_length, - 'hit_count': document.hit_count, - 'display_status': document.display_status, - 'doc_form': document.doc_form + "id": document.id, + "position": document.position, + "data_source_type": document.data_source_type, + "data_source_info": data_source_info, + "dataset_process_rule_id": document.dataset_process_rule_id, + "dataset_process_rule": process_rules, + "name": document.name, + "created_from": document.created_from, + "created_by": document.created_by, + "created_at": document.created_at.timestamp(), + "tokens": document.tokens, + "indexing_status": document.indexing_status, + "completed_at": int(document.completed_at.timestamp()) if document.completed_at else None, + "updated_at": int(document.updated_at.timestamp()) if document.updated_at else None, + "indexing_latency": document.indexing_latency, + "error": document.error, + "enabled": document.enabled, + "disabled_at": int(document.disabled_at.timestamp()) if document.disabled_at else None, + "disabled_by": document.disabled_by, + "archived": document.archived, + "segment_count": document.segment_count, + "average_segment_length": document.average_segment_length, + "hit_count": document.hit_count, + "display_status": document.display_status, + "doc_form": document.doc_form, + "doc_language": document.doc_language, } else: process_rules = DatasetService.get_process_rules(dataset_id) data_source_info = document.data_source_detail_dict response = { - 'id': document.id, - 'position': document.position, - 'data_source_type': document.data_source_type, - 'data_source_info': data_source_info, - 'dataset_process_rule_id': document.dataset_process_rule_id, - 'dataset_process_rule': process_rules, - 'name': document.name, - 'created_from': document.created_from, - 'created_by': document.created_by, - 'created_at': document.created_at.timestamp(), - 'tokens': document.tokens, - 'indexing_status': document.indexing_status, - 'completed_at': int(document.completed_at.timestamp()) if document.completed_at else None, - 'updated_at': int(document.updated_at.timestamp()) if document.updated_at else None, - 'indexing_latency': document.indexing_latency, - 'error': document.error, - 'enabled': document.enabled, - 'disabled_at': int(document.disabled_at.timestamp()) if document.disabled_at else None, - 'disabled_by': document.disabled_by, - 'archived': document.archived, - 'doc_type': document.doc_type, - 'doc_metadata': document.doc_metadata, - 'segment_count': document.segment_count, - 'average_segment_length': document.average_segment_length, - 'hit_count': document.hit_count, - 'display_status': document.display_status, - 'doc_form': document.doc_form + "id": document.id, + "position": document.position, + "data_source_type": document.data_source_type, + "data_source_info": data_source_info, + "dataset_process_rule_id": document.dataset_process_rule_id, + "dataset_process_rule": process_rules, + "name": document.name, + "created_from": document.created_from, + "created_by": document.created_by, + "created_at": document.created_at.timestamp(), + "tokens": document.tokens, + "indexing_status": document.indexing_status, + "completed_at": int(document.completed_at.timestamp()) if document.completed_at else None, + "updated_at": int(document.updated_at.timestamp()) if document.updated_at else None, + "indexing_latency": document.indexing_latency, + "error": document.error, + "enabled": document.enabled, + "disabled_at": int(document.disabled_at.timestamp()) if document.disabled_at else None, + "disabled_by": document.disabled_by, + "archived": document.archived, + "doc_type": document.doc_type, + "doc_metadata": document.doc_metadata, + "segment_count": document.segment_count, + "average_segment_length": document.average_segment_length, + "hit_count": document.hit_count, + "display_status": document.display_status, + "doc_form": document.doc_form, + "doc_language": document.doc_language, } return response, 200 @@ -671,7 +653,7 @@ class DocumentProcessingApi(DocumentResource): if action == "pause": if document.indexing_status != "indexing": - raise InvalidActionError('Document not in indexing state.') + raise InvalidActionError("Document not in indexing state.") document.paused_by = current_user.id document.paused_at = datetime.now(timezone.utc).replace(tzinfo=None) @@ -680,7 +662,7 @@ class DocumentProcessingApi(DocumentResource): elif action == "resume": if document.indexing_status not in ["paused", "error"]: - raise InvalidActionError('Document not in paused or error state.') + raise InvalidActionError("Document not in paused or error state.") document.paused_by = None document.paused_at = None @@ -689,7 +671,7 @@ class DocumentProcessingApi(DocumentResource): else: raise InvalidActionError() - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class DocumentDeleteApi(DocumentResource): @@ -710,9 +692,9 @@ class DocumentDeleteApi(DocumentResource): try: DocumentService.delete_document(document) except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError('Cannot delete document during indexing.') + raise DocumentIndexingError("Cannot delete document during indexing.") - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class DocumentMetadataApi(DocumentResource): @@ -726,26 +708,26 @@ class DocumentMetadataApi(DocumentResource): req_data = request.get_json() - doc_type = req_data.get('doc_type') - doc_metadata = req_data.get('doc_metadata') + doc_type = req_data.get("doc_type") + doc_metadata = req_data.get("doc_metadata") # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() if doc_type is None or doc_metadata is None: - raise ValueError('Both doc_type and doc_metadata must be provided.') + raise ValueError("Both doc_type and doc_metadata must be provided.") if doc_type not in DocumentService.DOCUMENT_METADATA_SCHEMA: - raise ValueError('Invalid doc_type.') + raise ValueError("Invalid doc_type.") if not isinstance(doc_metadata, dict): - raise ValueError('doc_metadata must be a dictionary.') + raise ValueError("doc_metadata must be a dictionary.") metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type] document.doc_metadata = {} - if doc_type == 'others': + if doc_type == "others": document.doc_metadata = doc_metadata else: for key, value_type in metadata_schema.items(): @@ -757,14 +739,14 @@ class DocumentMetadataApi(DocumentResource): document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() - return {'result': 'success', 'message': 'Document metadata updated.'}, 200 + return {"result": "success", "message": "Document metadata updated."}, 200 class DocumentStatusApi(DocumentResource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('vector_space') + @cloud_edition_billing_resource_check("vector_space") def patch(self, dataset_id, document_id, action): dataset_id = str(dataset_id) document_id = str(document_id) @@ -784,14 +766,14 @@ class DocumentStatusApi(DocumentResource): document = self.get_document(dataset_id, document_id) - indexing_cache_key = 'document_{}_indexing'.format(document.id) + indexing_cache_key = "document_{}_indexing".format(document.id) cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise InvalidActionError("Document is being indexed, please try again later") if action == "enable": if document.enabled: - raise InvalidActionError('Document already enabled.') + raise InvalidActionError("Document already enabled.") document.enabled = True document.disabled_at = None @@ -804,13 +786,13 @@ class DocumentStatusApi(DocumentResource): add_document_to_index_task.delay(document_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 elif action == "disable": - if not document.completed_at or document.indexing_status != 'completed': - raise InvalidActionError('Document is not completed.') + if not document.completed_at or document.indexing_status != "completed": + raise InvalidActionError("Document is not completed.") if not document.enabled: - raise InvalidActionError('Document already disabled.') + raise InvalidActionError("Document already disabled.") document.enabled = False document.disabled_at = datetime.now(timezone.utc).replace(tzinfo=None) @@ -823,11 +805,11 @@ class DocumentStatusApi(DocumentResource): remove_document_from_index_task.delay(document_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 elif action == "archive": if document.archived: - raise InvalidActionError('Document already archived.') + raise InvalidActionError("Document already archived.") document.archived = True document.archived_at = datetime.now(timezone.utc).replace(tzinfo=None) @@ -841,10 +823,10 @@ class DocumentStatusApi(DocumentResource): remove_document_from_index_task.delay(document_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 elif action == "un_archive": if not document.archived: - raise InvalidActionError('Document is not archived.') + raise InvalidActionError("Document is not archived.") document.archived = False document.archived_at = None @@ -857,13 +839,12 @@ class DocumentStatusApi(DocumentResource): add_document_to_index_task.delay(document_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 else: raise InvalidActionError() class DocumentPauseApi(DocumentResource): - @setup_required @login_required @account_initialization_required @@ -874,7 +855,7 @@ class DocumentPauseApi(DocumentResource): dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") document = DocumentService.get_document(dataset.id, document_id) @@ -890,9 +871,9 @@ class DocumentPauseApi(DocumentResource): # pause document DocumentService.pause_document(document) except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError('Cannot pause completed document.') + raise DocumentIndexingError("Cannot pause completed document.") - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class DocumentRecoverApi(DocumentResource): @@ -905,7 +886,7 @@ class DocumentRecoverApi(DocumentResource): document_id = str(document_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") document = DocumentService.get_document(dataset.id, document_id) # 404 if document not found @@ -919,9 +900,9 @@ class DocumentRecoverApi(DocumentResource): # pause document DocumentService.recover_document(document) except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError('Document is not in paused status.') + raise DocumentIndexingError("Document is not in paused status.") - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class DocumentRetryApi(DocumentResource): @@ -932,15 +913,14 @@ class DocumentRetryApi(DocumentResource): """retry document.""" parser = reqparse.RequestParser() - parser.add_argument('document_ids', type=list, required=True, nullable=False, - location='json') + parser.add_argument("document_ids", type=list, required=True, nullable=False, location="json") args = parser.parse_args() dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) retry_documents = [] if not dataset: - raise NotFound('Dataset not found.') - for document_id in args['document_ids']: + raise NotFound("Dataset not found.") + for document_id in args["document_ids"]: try: document_id = str(document_id) @@ -955,7 +935,7 @@ class DocumentRetryApi(DocumentResource): raise ArchivedDocumentImmutableError() # 400 if document is completed - if document.indexing_status == 'completed': + if document.indexing_status == "completed": raise DocumentAlreadyFinishedError() retry_documents.append(document) except Exception as e: @@ -964,7 +944,7 @@ class DocumentRetryApi(DocumentResource): # retry document DocumentService.retry_document(dataset_id, retry_documents) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class DocumentRenameApi(DocumentResource): @@ -979,13 +959,13 @@ class DocumentRenameApi(DocumentResource): dataset = DatasetService.get_dataset(dataset_id) DatasetService.check_dataset_operator_permission(current_user, dataset) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, nullable=False, location='json') + parser.add_argument("name", type=str, required=True, nullable=False, location="json") args = parser.parse_args() try: - document = DocumentService.rename_document(dataset_id, document_id, args['name']) + document = DocumentService.rename_document(dataset_id, document_id, args["name"]) except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError('Cannot delete document during indexing.') + raise DocumentIndexingError("Cannot delete document during indexing.") return document @@ -999,51 +979,43 @@ class WebsiteDocumentSyncApi(DocumentResource): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") if document.tenant_id != current_user.current_tenant_id: - raise Forbidden('No permission.') - if document.data_source_type != 'website_crawl': - raise ValueError('Document is not a website document.') + raise Forbidden("No permission.") + if document.data_source_type != "website_crawl": + raise ValueError("Document is not a website document.") # 403 if document is archived if DocumentService.check_archived(document): raise ArchivedDocumentImmutableError() # sync document DocumentService.sync_website_document(dataset_id, document) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(GetProcessRuleApi, '/datasets/process-rule') -api.add_resource(DatasetDocumentListApi, - '/datasets//documents') -api.add_resource(DatasetInitApi, - '/datasets/init') -api.add_resource(DocumentIndexingEstimateApi, - '/datasets//documents//indexing-estimate') -api.add_resource(DocumentBatchIndexingEstimateApi, - '/datasets//batch//indexing-estimate') -api.add_resource(DocumentBatchIndexingStatusApi, - '/datasets//batch//indexing-status') -api.add_resource(DocumentIndexingStatusApi, - '/datasets//documents//indexing-status') -api.add_resource(DocumentDetailApi, - '/datasets//documents/') -api.add_resource(DocumentProcessingApi, - '/datasets//documents//processing/') -api.add_resource(DocumentDeleteApi, - '/datasets//documents/') -api.add_resource(DocumentMetadataApi, - '/datasets//documents//metadata') -api.add_resource(DocumentStatusApi, - '/datasets//documents//status/') -api.add_resource(DocumentPauseApi, '/datasets//documents//processing/pause') -api.add_resource(DocumentRecoverApi, '/datasets//documents//processing/resume') -api.add_resource(DocumentRetryApi, '/datasets//retry') -api.add_resource(DocumentRenameApi, - '/datasets//documents//rename') +api.add_resource(GetProcessRuleApi, "/datasets/process-rule") +api.add_resource(DatasetDocumentListApi, "/datasets//documents") +api.add_resource(DatasetInitApi, "/datasets/init") +api.add_resource( + DocumentIndexingEstimateApi, "/datasets//documents//indexing-estimate" +) +api.add_resource(DocumentBatchIndexingEstimateApi, "/datasets//batch//indexing-estimate") +api.add_resource(DocumentBatchIndexingStatusApi, "/datasets//batch//indexing-status") +api.add_resource(DocumentIndexingStatusApi, "/datasets//documents//indexing-status") +api.add_resource(DocumentDetailApi, "/datasets//documents/") +api.add_resource( + DocumentProcessingApi, "/datasets//documents//processing/" +) +api.add_resource(DocumentDeleteApi, "/datasets//documents/") +api.add_resource(DocumentMetadataApi, "/datasets//documents//metadata") +api.add_resource(DocumentStatusApi, "/datasets//documents//status/") +api.add_resource(DocumentPauseApi, "/datasets//documents//processing/pause") +api.add_resource(DocumentRecoverApi, "/datasets//documents//processing/resume") +api.add_resource(DocumentRetryApi, "/datasets//retry") +api.add_resource(DocumentRenameApi, "/datasets//documents//rename") -api.add_resource(WebsiteDocumentSyncApi, '/datasets//documents//website-sync') +api.add_resource(WebsiteDocumentSyncApi, "/datasets//documents//website-sync") diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index a4210d5a0c..2405649387 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -40,7 +40,7 @@ class DatasetDocumentSegmentListApi(Resource): document_id = str(document_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") try: DatasetService.check_dataset_permission(dataset, current_user) @@ -50,37 +50,33 @@ class DatasetDocumentSegmentListApi(Resource): document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") parser = reqparse.RequestParser() - parser.add_argument('last_id', type=str, default=None, location='args') - parser.add_argument('limit', type=int, default=20, location='args') - parser.add_argument('status', type=str, - action='append', default=[], location='args') - parser.add_argument('hit_count_gte', type=int, - default=None, location='args') - parser.add_argument('enabled', type=str, default='all', location='args') - parser.add_argument('keyword', type=str, default=None, location='args') + parser.add_argument("last_id", type=str, default=None, location="args") + parser.add_argument("limit", type=int, default=20, location="args") + parser.add_argument("status", type=str, action="append", default=[], location="args") + parser.add_argument("hit_count_gte", type=int, default=None, location="args") + parser.add_argument("enabled", type=str, default="all", location="args") + parser.add_argument("keyword", type=str, default=None, location="args") args = parser.parse_args() - last_id = args['last_id'] - limit = min(args['limit'], 100) - status_list = args['status'] - hit_count_gte = args['hit_count_gte'] - keyword = args['keyword'] + last_id = args["last_id"] + limit = min(args["limit"], 100) + status_list = args["status"] + hit_count_gte = args["hit_count_gte"] + keyword = args["keyword"] query = DocumentSegment.query.filter( - DocumentSegment.document_id == str(document_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id ) if last_id is not None: last_segment = db.session.get(DocumentSegment, str(last_id)) if last_segment: - query = query.filter( - DocumentSegment.position > last_segment.position) + query = query.filter(DocumentSegment.position > last_segment.position) else: - return {'data': [], 'has_more': False, 'limit': limit}, 200 + return {"data": [], "has_more": False, "limit": limit}, 200 if status_list: query = query.filter(DocumentSegment.status.in_(status_list)) @@ -89,12 +85,12 @@ class DatasetDocumentSegmentListApi(Resource): query = query.filter(DocumentSegment.hit_count >= hit_count_gte) if keyword: - query = query.where(DocumentSegment.content.ilike(f'%{keyword}%')) + query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) - if args['enabled'].lower() != 'all': - if args['enabled'].lower() == 'true': + if args["enabled"].lower() != "all": + if args["enabled"].lower() == "true": query = query.filter(DocumentSegment.enabled == True) - elif args['enabled'].lower() == 'false': + elif args["enabled"].lower() == "false": query = query.filter(DocumentSegment.enabled == False) total = query.count() @@ -106,11 +102,11 @@ class DatasetDocumentSegmentListApi(Resource): segments = segments[:-1] return { - 'data': marshal(segments, segment_fields), - 'doc_form': document.doc_form, - 'has_more': has_more, - 'limit': limit, - 'total': total + "data": marshal(segments, segment_fields), + "doc_form": document.doc_form, + "has_more": has_more, + "limit": limit, + "total": total, }, 200 @@ -118,12 +114,12 @@ class DatasetDocumentSegmentApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('vector_space') + @cloud_edition_billing_resource_check("vector_space") def patch(self, dataset_id, segment_id, action): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # The role of the current user in the ta table must be admin, owner, or editor @@ -134,7 +130,7 @@ class DatasetDocumentSegmentApi(Resource): DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": # check embedding model setting try: model_manager = ModelManager() @@ -142,32 +138,32 @@ class DatasetDocumentSegmentApi(Resource): tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id ).first() if not segment: - raise NotFound('Segment not found.') + raise NotFound("Segment not found.") - if segment.status != 'completed': - raise NotFound('Segment is not completed, enable or disable function is not allowed') + if segment.status != "completed": + raise NotFound("Segment is not completed, enable or disable function is not allowed") - document_indexing_cache_key = 'document_{}_indexing'.format(segment.document_id) + document_indexing_cache_key = "document_{}_indexing".format(segment.document_id) cache_result = redis_client.get(document_indexing_cache_key) if cache_result is not None: raise InvalidActionError("Document is being indexed, please try again later") - indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + indexing_cache_key = "segment_{}_indexing".format(segment.id) cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise InvalidActionError("Segment is being indexed, please try again later") @@ -186,7 +182,7 @@ class DatasetDocumentSegmentApi(Resource): enable_segment_to_index_task.delay(segment.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 elif action == "disable": if not segment.enabled: raise InvalidActionError("Segment is already disabled.") @@ -201,7 +197,7 @@ class DatasetDocumentSegmentApi(Resource): disable_segment_from_index_task.delay(segment.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 else: raise InvalidActionError() @@ -210,35 +206,36 @@ class DatasetDocumentSegmentAddApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('vector_space') - @cloud_edition_billing_knowledge_limit_check('add_segment') + @cloud_edition_billing_resource_check("vector_space") + @cloud_edition_billing_knowledge_limit_check("add_segment") def post(self, dataset_id, document_id): # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") if not current_user.is_editor: raise Forbidden() # check embedding model setting - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": try: model_manager = ModelManager() model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) try: @@ -247,37 +244,34 @@ class DatasetDocumentSegmentAddApi(Resource): raise Forbidden(str(e)) # validate args parser = reqparse.RequestParser() - parser.add_argument('content', type=str, required=True, nullable=False, location='json') - parser.add_argument('answer', type=str, required=False, nullable=True, location='json') - parser.add_argument('keywords', type=list, required=False, nullable=True, location='json') + parser.add_argument("content", type=str, required=True, nullable=False, location="json") + parser.add_argument("answer", type=str, required=False, nullable=True, location="json") + parser.add_argument("keywords", type=list, required=False, nullable=True, location="json") args = parser.parse_args() SegmentService.segment_create_args_validate(args, document) segment = SegmentService.create_segment(args, document, dataset) - return { - 'data': marshal(segment, segment_fields), - 'doc_form': document.doc_form - }, 200 + return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 class DatasetDocumentSegmentUpdateApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('vector_space') + @cloud_edition_billing_resource_check("vector_space") def patch(self, dataset_id, document_id, segment_id): # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') - if dataset.indexing_technique == 'high_quality': + raise NotFound("Document not found.") + if dataset.indexing_technique == "high_quality": # check embedding model setting try: model_manager = ModelManager() @@ -285,22 +279,22 @@ class DatasetDocumentSegmentUpdateApi(Resource): tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # check segment segment_id = str(segment_id) segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id ).first() if not segment: - raise NotFound('Segment not found.') + raise NotFound("Segment not found.") # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() @@ -310,16 +304,13 @@ class DatasetDocumentSegmentUpdateApi(Resource): raise Forbidden(str(e)) # validate args parser = reqparse.RequestParser() - parser.add_argument('content', type=str, required=True, nullable=False, location='json') - parser.add_argument('answer', type=str, required=False, nullable=True, location='json') - parser.add_argument('keywords', type=list, required=False, nullable=True, location='json') + parser.add_argument("content", type=str, required=True, nullable=False, location="json") + parser.add_argument("answer", type=str, required=False, nullable=True, location="json") + parser.add_argument("keywords", type=list, required=False, nullable=True, location="json") args = parser.parse_args() SegmentService.segment_create_args_validate(args, document) segment = SegmentService.update_segment(args, segment, document, dataset) - return { - 'data': marshal(segment, segment_fields), - 'doc_form': document.doc_form - }, 200 + return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 @setup_required @login_required @@ -329,22 +320,21 @@ class DatasetDocumentSegmentUpdateApi(Resource): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") # check segment segment_id = str(segment_id) segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id ).first() if not segment: - raise NotFound('Segment not found.') + raise NotFound("Segment not found.") # The role of the current user in the ta table must be admin or owner if not current_user.is_editor: raise Forbidden() @@ -353,36 +343,36 @@ class DatasetDocumentSegmentUpdateApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) SegmentService.delete_segment(segment, document, dataset) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class DatasetDocumentSegmentBatchImportApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('vector_space') - @cloud_edition_billing_knowledge_limit_check('add_segment') + @cloud_edition_billing_resource_check("vector_space") + @cloud_edition_billing_knowledge_limit_check("add_segment") def post(self, dataset_id, document_id): # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") # get file from request - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() # check file type - if not file.filename.endswith('.csv'): + if not file.filename.endswith(".csv"): raise ValueError("Invalid file type. Only CSV files are allowed") try: @@ -390,51 +380,47 @@ class DatasetDocumentSegmentBatchImportApi(Resource): df = pd.read_csv(file) result = [] for index, row in df.iterrows(): - if document.doc_form == 'qa_model': - data = {'content': row[0], 'answer': row[1]} + if document.doc_form == "qa_model": + data = {"content": row[0], "answer": row[1]} else: - data = {'content': row[0]} + data = {"content": row[0]} result.append(data) if len(result) == 0: raise ValueError("The CSV file is empty.") # async job job_id = str(uuid.uuid4()) - indexing_cache_key = 'segment_batch_import_{}'.format(str(job_id)) + indexing_cache_key = "segment_batch_import_{}".format(str(job_id)) # send batch add segments task - redis_client.setnx(indexing_cache_key, 'waiting') - batch_create_segment_to_index_task.delay(str(job_id), result, dataset_id, document_id, - current_user.current_tenant_id, current_user.id) + redis_client.setnx(indexing_cache_key, "waiting") + batch_create_segment_to_index_task.delay( + str(job_id), result, dataset_id, document_id, current_user.current_tenant_id, current_user.id + ) except Exception as e: - return {'error': str(e)}, 500 - return { - 'job_id': job_id, - 'job_status': 'waiting' - }, 200 + return {"error": str(e)}, 500 + return {"job_id": job_id, "job_status": "waiting"}, 200 @setup_required @login_required @account_initialization_required def get(self, job_id): job_id = str(job_id) - indexing_cache_key = 'segment_batch_import_{}'.format(job_id) + indexing_cache_key = "segment_batch_import_{}".format(job_id) cache_result = redis_client.get(indexing_cache_key) if cache_result is None: raise ValueError("The job is not exist.") - return { - 'job_id': job_id, - 'job_status': cache_result.decode() - }, 200 + return {"job_id": job_id, "job_status": cache_result.decode()}, 200 -api.add_resource(DatasetDocumentSegmentListApi, - '/datasets//documents//segments') -api.add_resource(DatasetDocumentSegmentApi, - '/datasets//segments//') -api.add_resource(DatasetDocumentSegmentAddApi, - '/datasets//documents//segment') -api.add_resource(DatasetDocumentSegmentUpdateApi, - '/datasets//documents//segments/') -api.add_resource(DatasetDocumentSegmentBatchImportApi, - '/datasets//documents//segments/batch_import', - '/datasets/batch_import_status/') +api.add_resource(DatasetDocumentSegmentListApi, "/datasets//documents//segments") +api.add_resource(DatasetDocumentSegmentApi, "/datasets//segments//") +api.add_resource(DatasetDocumentSegmentAddApi, "/datasets//documents//segment") +api.add_resource( + DatasetDocumentSegmentUpdateApi, + "/datasets//documents//segments/", +) +api.add_resource( + DatasetDocumentSegmentBatchImportApi, + "/datasets//documents//segments/batch_import", + "/datasets/batch_import_status/", +) diff --git a/api/controllers/console/datasets/error.py b/api/controllers/console/datasets/error.py index 9270b610c2..6a7a3971a8 100644 --- a/api/controllers/console/datasets/error.py +++ b/api/controllers/console/datasets/error.py @@ -2,90 +2,90 @@ from libs.exception import BaseHTTPException class NoFileUploadedError(BaseHTTPException): - error_code = 'no_file_uploaded' + error_code = "no_file_uploaded" description = "Please upload your file." code = 400 class TooManyFilesError(BaseHTTPException): - error_code = 'too_many_files' + error_code = "too_many_files" description = "Only one file is allowed." code = 400 class FileTooLargeError(BaseHTTPException): - error_code = 'file_too_large' + error_code = "file_too_large" description = "File size exceeded. {message}" code = 413 class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' + error_code = "unsupported_file_type" description = "File type not allowed." code = 415 class HighQualityDatasetOnlyError(BaseHTTPException): - error_code = 'high_quality_dataset_only' + error_code = "high_quality_dataset_only" description = "Current operation only supports 'high-quality' datasets." code = 400 class DatasetNotInitializedError(BaseHTTPException): - error_code = 'dataset_not_initialized' + error_code = "dataset_not_initialized" description = "The dataset is still being initialized or indexing. Please wait a moment." code = 400 class ArchivedDocumentImmutableError(BaseHTTPException): - error_code = 'archived_document_immutable' + error_code = "archived_document_immutable" description = "The archived document is not editable." code = 403 class DatasetNameDuplicateError(BaseHTTPException): - error_code = 'dataset_name_duplicate' + error_code = "dataset_name_duplicate" description = "The dataset name already exists. Please modify your dataset name." code = 409 class InvalidActionError(BaseHTTPException): - error_code = 'invalid_action' + error_code = "invalid_action" description = "Invalid action." code = 400 class DocumentAlreadyFinishedError(BaseHTTPException): - error_code = 'document_already_finished' + error_code = "document_already_finished" description = "The document has been processed. Please refresh the page or go to the document details." code = 400 class DocumentIndexingError(BaseHTTPException): - error_code = 'document_indexing' + error_code = "document_indexing" description = "The document is being processed and cannot be edited." code = 400 class InvalidMetadataError(BaseHTTPException): - error_code = 'invalid_metadata' + error_code = "invalid_metadata" description = "The metadata content is incorrect. Please check and verify." code = 400 class WebsiteCrawlError(BaseHTTPException): - error_code = 'crawl_failed' + error_code = "crawl_failed" description = "{message}" code = 500 class DatasetInUseError(BaseHTTPException): - error_code = 'dataset_in_use' + error_code = "dataset_in_use" description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it." code = 409 class IndexingEstimateError(BaseHTTPException): - error_code = 'indexing_estimate_error' + error_code = "indexing_estimate_error" description = "Knowledge indexing estimate failed: {message}" code = 500 diff --git a/api/controllers/console/datasets/file.py b/api/controllers/console/datasets/file.py index 3b2083bcc3..846aa70e86 100644 --- a/api/controllers/console/datasets/file.py +++ b/api/controllers/console/datasets/file.py @@ -21,7 +21,6 @@ PREVIEW_WORDS_LIMIT = 3000 class FileApi(Resource): - @setup_required @login_required @account_initialization_required @@ -31,23 +30,22 @@ class FileApi(Resource): batch_count_limit = dify_config.UPLOAD_FILE_BATCH_LIMIT image_file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT return { - 'file_size_limit': file_size_limit, - 'batch_count_limit': batch_count_limit, - 'image_file_size_limit': image_file_size_limit + "file_size_limit": file_size_limit, + "batch_count_limit": batch_count_limit, + "image_file_size_limit": image_file_size_limit, }, 200 @setup_required @login_required @account_initialization_required @marshal_with(file_fields) - @cloud_edition_billing_resource_check(resource='documents') + @cloud_edition_billing_resource_check("documents") def post(self): - # get file from request - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: @@ -69,7 +67,7 @@ class FilePreviewApi(Resource): def get(self, file_id): file_id = str(file_id) text = FileService.get_file_preview(file_id) - return {'content': text} + return {"content": text} class FileSupportTypeApi(Resource): @@ -78,10 +76,10 @@ class FileSupportTypeApi(Resource): @account_initialization_required def get(self): etl_type = dify_config.ETL_TYPE - allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS - return {'allowed_extensions': allowed_extensions} + allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS + return {"allowed_extensions": allowed_extensions} -api.add_resource(FileApi, '/files/upload') -api.add_resource(FilePreviewApi, '/files//preview') -api.add_resource(FileSupportTypeApi, '/files/support-type') +api.add_resource(FileApi, "/files/upload") +api.add_resource(FilePreviewApi, "/files//preview") +api.add_resource(FileSupportTypeApi, "/files/support-type") diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 8771bf909e..0b4a7be986 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -29,7 +29,6 @@ from services.hit_testing_service import HitTestingService class HitTestingApi(Resource): - @setup_required @login_required @account_initialization_required @@ -46,8 +45,8 @@ class HitTestingApi(Resource): raise Forbidden(str(e)) parser = reqparse.RequestParser() - parser.add_argument('query', type=str, location='json') - parser.add_argument('retrieval_model', type=dict, required=False, location='json') + parser.add_argument("query", type=str, location="json") + parser.add_argument("retrieval_model", type=dict, required=False, location="json") args = parser.parse_args() HitTestingService.hit_testing_args_check(args) @@ -55,13 +54,13 @@ class HitTestingApi(Resource): try: response = HitTestingService.retrieve( dataset=dataset, - query=args['query'], + query=args["query"], account=current_user, - retrieval_model=args['retrieval_model'], - limit=10 + retrieval_model=args["retrieval_model"], + limit=10, ) - return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)} + return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)} except services.errors.index.IndexNotInitializedError: raise DatasetNotInitializedError() except ProviderTokenNotInitError as ex: @@ -73,7 +72,8 @@ class HitTestingApi(Resource): except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model or Reranking Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except InvokeError as e: raise CompletionRequestError(e.description) except ValueError as e: @@ -83,4 +83,4 @@ class HitTestingApi(Resource): raise InternalServerError(str(e)) -api.add_resource(HitTestingApi, '/datasets//hit-testing') +api.add_resource(HitTestingApi, "/datasets//hit-testing") diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py index bbd91256f1..cb54f1aacb 100644 --- a/api/controllers/console/datasets/website.py +++ b/api/controllers/console/datasets/website.py @@ -9,16 +9,14 @@ from services.website_service import WebsiteService class WebsiteCrawlApi(Resource): - @setup_required @login_required @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('provider', type=str, choices=['firecrawl'], - required=True, nullable=True, location='json') - parser.add_argument('url', type=str, required=True, nullable=True, location='json') - parser.add_argument('options', type=dict, required=True, nullable=True, location='json') + parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, nullable=True, location="json") + parser.add_argument("url", type=str, required=True, nullable=True, location="json") + parser.add_argument("options", type=dict, required=True, nullable=True, location="json") args = parser.parse_args() WebsiteService.document_create_args_validate(args) # crawl url @@ -35,15 +33,15 @@ class WebsiteCrawlStatusApi(Resource): @account_initialization_required def get(self, job_id: str): parser = reqparse.RequestParser() - parser.add_argument('provider', type=str, choices=['firecrawl'], required=True, location='args') + parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, location="args") args = parser.parse_args() # get crawl status try: - result = WebsiteService.get_crawl_status(job_id, args['provider']) + result = WebsiteService.get_crawl_status(job_id, args["provider"]) except Exception as e: raise WebsiteCrawlError(str(e)) return result, 200 -api.add_resource(WebsiteCrawlApi, '/website/crawl') -api.add_resource(WebsiteCrawlStatusApi, '/website/crawl/status/') +api.add_resource(WebsiteCrawlApi, "/website/crawl") +api.add_resource(WebsiteCrawlStatusApi, "/website/crawl/status/") diff --git a/api/controllers/console/error.py b/api/controllers/console/error.py index 888dad83cc..1c70ea6c59 100644 --- a/api/controllers/console/error.py +++ b/api/controllers/console/error.py @@ -2,35 +2,41 @@ from libs.exception import BaseHTTPException class AlreadySetupError(BaseHTTPException): - error_code = 'already_setup' + error_code = "already_setup" description = "Dify has been successfully installed. Please refresh the page or return to the dashboard homepage." code = 403 class NotSetupError(BaseHTTPException): - error_code = 'not_setup' - description = "Dify has not been initialized and installed yet. " \ - "Please proceed with the initialization and installation process first." + error_code = "not_setup" + description = ( + "Dify has not been initialized and installed yet. " + "Please proceed with the initialization and installation process first." + ) code = 401 + class NotInitValidateError(BaseHTTPException): - error_code = 'not_init_validated' - description = "Init validation has not been completed yet. " \ - "Please proceed with the init validation process first." + error_code = "not_init_validated" + description = ( + "Init validation has not been completed yet. " "Please proceed with the init validation process first." + ) code = 401 + class InitValidateFailedError(BaseHTTPException): - error_code = 'init_validate_failed' + error_code = "init_validate_failed" description = "Init validation failed. Please check the password and try again." code = 401 + class AccountNotLinkTenantError(BaseHTTPException): - error_code = 'account_not_link_tenant' + error_code = "account_not_link_tenant" description = "Account not link tenant." code = 403 class AlreadyActivateError(BaseHTTPException): - error_code = 'already_activate' + error_code = "already_activate" description = "Auth Token is invalid or account already activated, please check again." code = 403 diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index 27cc83042a..71cb060ecc 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -33,14 +33,10 @@ class ChatAudioApi(InstalledAppResource): def post(self, installed_app): app_model = installed_app.app - file = request.files['file'] + file = request.files["file"] try: - response = AudioService.transcript_asr( - app_model=app_model, - file=file, - end_user=None - ) + response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None) return response except services.errors.app_model_config.AppModelConfigBrokenError: @@ -76,30 +72,31 @@ class ChatTextApi(InstalledAppResource): app_model = installed_app.app try: parser = reqparse.RequestParser() - parser.add_argument('message_id', type=str, required=False, location='json') - parser.add_argument('voice', type=str, location='json') - parser.add_argument('text', type=str, location='json') - parser.add_argument('streaming', type=bool, location='json') + parser.add_argument("message_id", type=str, required=False, location="json") + parser.add_argument("voice", type=str, location="json") + parser.add_argument("text", type=str, location="json") + parser.add_argument("streaming", type=bool, location="json") args = parser.parse_args() - message_id = args.get('message_id', None) - text = args.get('text', None) - if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] - and app_model.workflow - and app_model.workflow.features_dict): - text_to_speech = app_model.workflow.features_dict.get('text_to_speech') - voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice') + message_id = args.get("message_id", None) + text = args.get("text", None) + if ( + app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + and app_model.workflow + and app_model.workflow.features_dict + ): + text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") else: try: - voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice') + voice = ( + args.get("voice") + if args.get("voice") + else app_model.app_model_config.text_to_speech_dict.get("voice") + ) except Exception: voice = None - response = AudioService.transcript_tts( - app_model=app_model, - message_id=message_id, - voice=voice, - text=text - ) + response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text) return response except services.errors.app_model_config.AppModelConfigBrokenError: logging.exception("App model config broken.") @@ -127,7 +124,7 @@ class ChatTextApi(InstalledAppResource): raise InternalServerError() -api.add_resource(ChatAudioApi, '/installed-apps//audio-to-text', endpoint='installed_app_audio') -api.add_resource(ChatTextApi, '/installed-apps//text-to-audio', endpoint='installed_app_text') +api.add_resource(ChatAudioApi, "/installed-apps//audio-to-text", endpoint="installed_app_audio") +api.add_resource(ChatTextApi, "/installed-apps//text-to-audio", endpoint="installed_app_text") # api.add_resource(ChatTextApiWithMessageId, '/installed-apps//text-to-audio/message-id', # endpoint='installed_app_text_with_message_id') diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 869b56e13b..c039e8bca5 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -30,33 +30,28 @@ from services.app_generate_service import AppGenerateService # define completion api for user class CompletionApi(InstalledAppResource): - def post(self, installed_app): app_model = installed_app.app - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, location='json', default='') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, location="json", default="") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' - args['auto_generate_name'] = False + streaming = args["response_mode"] == "streaming" + args["auto_generate_name"] = False installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() try: response = AppGenerateService.generate( - app_model=app_model, - user=current_user, - args=args, - invoke_from=InvokeFrom.EXPLORE, - streaming=streaming + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming ) return helper.compact_generate_response(response) @@ -85,12 +80,12 @@ class CompletionApi(InstalledAppResource): class CompletionStopApi(InstalledAppResource): def post(self, installed_app, task_id): app_model = installed_app.app - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ChatApi(InstalledAppResource): @@ -101,25 +96,21 @@ class ChatApi(InstalledAppResource): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, required=True, location='json') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('conversation_id', type=uuid_value, location='json') - parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, required=True, location="json") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("conversation_id", type=uuid_value, location="json") + parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") args = parser.parse_args() - args['auto_generate_name'] = False + args["auto_generate_name"] = False installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() try: response = AppGenerateService.generate( - app_model=app_model, - user=current_user, - args=args, - invoke_from=InvokeFrom.EXPLORE, - streaming=True + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True ) return helper.compact_generate_response(response) @@ -154,10 +145,22 @@ class ChatStopApi(InstalledAppResource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(CompletionApi, '/installed-apps//completion-messages', endpoint='installed_app_completion') -api.add_resource(CompletionStopApi, '/installed-apps//completion-messages//stop', endpoint='installed_app_stop_completion') -api.add_resource(ChatApi, '/installed-apps//chat-messages', endpoint='installed_app_chat_completion') -api.add_resource(ChatStopApi, '/installed-apps//chat-messages//stop', endpoint='installed_app_stop_chat_completion') +api.add_resource( + CompletionApi, "/installed-apps//completion-messages", endpoint="installed_app_completion" +) +api.add_resource( + CompletionStopApi, + "/installed-apps//completion-messages//stop", + endpoint="installed_app_stop_completion", +) +api.add_resource( + ChatApi, "/installed-apps//chat-messages", endpoint="installed_app_chat_completion" +) +api.add_resource( + ChatStopApi, + "/installed-apps//chat-messages//stop", + endpoint="installed_app_stop_chat_completion", +) diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index ea0fa4e17e..2918024b64 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -16,7 +16,6 @@ from services.web_conversation_service import WebConversationService class ConversationListApi(InstalledAppResource): - @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, installed_app): app_model = installed_app.app @@ -25,21 +24,21 @@ class ConversationListApi(InstalledAppResource): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') - parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args") args = parser.parse_args() pinned = None - if 'pinned' in args and args['pinned'] is not None: - pinned = True if args['pinned'] == 'true' else False + if "pinned" in args and args["pinned"] is not None: + pinned = True if args["pinned"] == "true" else False try: return WebConversationService.pagination_by_last_id( app_model=app_model, user=current_user, - last_id=args['last_id'], - limit=args['limit'], + last_id=args["last_id"], + limit=args["limit"], invoke_from=InvokeFrom.EXPLORE, pinned=pinned, ) @@ -65,7 +64,6 @@ class ConversationApi(InstalledAppResource): class ConversationRenameApi(InstalledAppResource): - @marshal_with(simple_conversation_fields) def post(self, installed_app, c_id): app_model = installed_app.app @@ -76,24 +74,19 @@ class ConversationRenameApi(InstalledAppResource): conversation_id = str(c_id) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=False, location='json') - parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json') + parser.add_argument("name", type=str, required=False, location="json") + parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") args = parser.parse_args() try: return ConversationService.rename( - app_model, - conversation_id, - current_user, - args['name'], - args['auto_generate'] + app_model, conversation_id, current_user, args["name"], args["auto_generate"] ) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") class ConversationPinApi(InstalledAppResource): - def patch(self, installed_app, c_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) @@ -123,8 +116,26 @@ class ConversationUnPinApi(InstalledAppResource): return {"result": "success"} -api.add_resource(ConversationRenameApi, '/installed-apps//conversations//name', endpoint='installed_app_conversation_rename') -api.add_resource(ConversationListApi, '/installed-apps//conversations', endpoint='installed_app_conversations') -api.add_resource(ConversationApi, '/installed-apps//conversations/', endpoint='installed_app_conversation') -api.add_resource(ConversationPinApi, '/installed-apps//conversations//pin', endpoint='installed_app_conversation_pin') -api.add_resource(ConversationUnPinApi, '/installed-apps//conversations//unpin', endpoint='installed_app_conversation_unpin') +api.add_resource( + ConversationRenameApi, + "/installed-apps//conversations//name", + endpoint="installed_app_conversation_rename", +) +api.add_resource( + ConversationListApi, "/installed-apps//conversations", endpoint="installed_app_conversations" +) +api.add_resource( + ConversationApi, + "/installed-apps//conversations/", + endpoint="installed_app_conversation", +) +api.add_resource( + ConversationPinApi, + "/installed-apps//conversations//pin", + endpoint="installed_app_conversation_pin", +) +api.add_resource( + ConversationUnPinApi, + "/installed-apps//conversations//unpin", + endpoint="installed_app_conversation_unpin", +) diff --git a/api/controllers/console/explore/error.py b/api/controllers/console/explore/error.py index 9c3216ecc8..18221b7797 100644 --- a/api/controllers/console/explore/error.py +++ b/api/controllers/console/explore/error.py @@ -2,24 +2,24 @@ from libs.exception import BaseHTTPException class NotCompletionAppError(BaseHTTPException): - error_code = 'not_completion_app' + error_code = "not_completion_app" description = "Not Completion App" code = 400 class NotChatAppError(BaseHTTPException): - error_code = 'not_chat_app' + error_code = "not_chat_app" description = "App mode is invalid." code = 400 class NotWorkflowAppError(BaseHTTPException): - error_code = 'not_workflow_app' + error_code = "not_workflow_app" description = "Only support workflow app." code = 400 class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException): - error_code = 'app_suggested_questions_after_answer_disabled' + error_code = "app_suggested_questions_after_answer_disabled" description = "Function Suggested questions after answer disabled." code = 403 diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index ec7bbed307..3f1e64a247 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -21,72 +21,72 @@ class InstalledAppsListApi(Resource): @marshal_with(installed_app_list_fields) def get(self): current_tenant_id = current_user.current_tenant_id - installed_apps = db.session.query(InstalledApp).filter( - InstalledApp.tenant_id == current_tenant_id - ).all() + installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all() current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) installed_apps = [ { - 'id': installed_app.id, - 'app': installed_app.app, - 'app_owner_tenant_id': installed_app.app_owner_tenant_id, - 'is_pinned': installed_app.is_pinned, - 'last_used_at': installed_app.last_used_at, - 'editable': current_user.role in ["owner", "admin"], - 'uninstallable': current_tenant_id == installed_app.app_owner_tenant_id + "id": installed_app.id, + "app": installed_app.app, + "app_owner_tenant_id": installed_app.app_owner_tenant_id, + "is_pinned": installed_app.is_pinned, + "last_used_at": installed_app.last_used_at, + "editable": current_user.role in ["owner", "admin"], + "uninstallable": current_tenant_id == installed_app.app_owner_tenant_id, } for installed_app in installed_apps + if installed_app.app is not None ] - installed_apps.sort(key=lambda app: (-app['is_pinned'], - app['last_used_at'] is None, - -app['last_used_at'].timestamp() if app['last_used_at'] is not None else 0)) + installed_apps.sort( + key=lambda app: ( + -app["is_pinned"], + app["last_used_at"] is None, + -app["last_used_at"].timestamp() if app["last_used_at"] is not None else 0, + ) + ) - return {'installed_apps': installed_apps} + return {"installed_apps": installed_apps} @login_required @account_initialization_required - @cloud_edition_billing_resource_check('apps') + @cloud_edition_billing_resource_check("apps") def post(self): parser = reqparse.RequestParser() - parser.add_argument('app_id', type=str, required=True, help='Invalid app_id') + parser.add_argument("app_id", type=str, required=True, help="Invalid app_id") args = parser.parse_args() - recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first() + recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first() if recommended_app is None: - raise NotFound('App not found') + raise NotFound("App not found") current_tenant_id = current_user.current_tenant_id - app = db.session.query(App).filter( - App.id == args['app_id'] - ).first() + app = db.session.query(App).filter(App.id == args["app_id"]).first() if app is None: - raise NotFound('App not found') + raise NotFound("App not found") if not app.is_public: - raise Forbidden('You can\'t install a non-public app') + raise Forbidden("You can't install a non-public app") - installed_app = InstalledApp.query.filter(and_( - InstalledApp.app_id == args['app_id'], - InstalledApp.tenant_id == current_tenant_id - )).first() + installed_app = InstalledApp.query.filter( + and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id) + ).first() if installed_app is None: # todo: position recommended_app.install_count += 1 new_installed_app = InstalledApp( - app_id=args['app_id'], + app_id=args["app_id"], tenant_id=current_tenant_id, app_owner_tenant_id=app.tenant_id, is_pinned=False, - last_used_at=datetime.now(timezone.utc).replace(tzinfo=None) + last_used_at=datetime.now(timezone.utc).replace(tzinfo=None), ) db.session.add(new_installed_app) db.session.commit() - return {'message': 'App installed successfully'} + return {"message": "App installed successfully"} class InstalledAppApi(InstalledAppResource): @@ -94,30 +94,31 @@ class InstalledAppApi(InstalledAppResource): update and delete an installed app use InstalledAppResource to apply default decorators and get installed_app """ + def delete(self, installed_app): if installed_app.app_owner_tenant_id == current_user.current_tenant_id: - raise BadRequest('You can\'t uninstall an app owned by the current tenant') + raise BadRequest("You can't uninstall an app owned by the current tenant") db.session.delete(installed_app) db.session.commit() - return {'result': 'success', 'message': 'App uninstalled successfully'} + return {"result": "success", "message": "App uninstalled successfully"} def patch(self, installed_app): parser = reqparse.RequestParser() - parser.add_argument('is_pinned', type=inputs.boolean) + parser.add_argument("is_pinned", type=inputs.boolean) args = parser.parse_args() commit_args = False - if 'is_pinned' in args: - installed_app.is_pinned = args['is_pinned'] + if "is_pinned" in args: + installed_app.is_pinned = args["is_pinned"] commit_args = True if commit_args: db.session.commit() - return {'result': 'success', 'message': 'App info updated successfully'} + return {"result": "success", "message": "App info updated successfully"} -api.add_resource(InstalledAppsListApi, '/installed-apps') -api.add_resource(InstalledAppApi, '/installed-apps/') +api.add_resource(InstalledAppsListApi, "/installed-apps") +api.add_resource(InstalledAppApi, "/installed-apps/") diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 3523a86900..f5eb185172 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -44,19 +44,21 @@ class MessageListApi(InstalledAppResource): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') - parser.add_argument('first_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") + parser.add_argument("first_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() try: - return MessageService.pagination_by_first_id(app_model, current_user, - args['conversation_id'], args['first_id'], args['limit']) + return MessageService.pagination_by_first_id( + app_model, current_user, args["conversation_id"], args["first_id"], args["limit"] + ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.message.FirstMessageNotExistsError: raise NotFound("First Message Not Exists.") + class MessageFeedbackApi(InstalledAppResource): def post(self, installed_app, message_id): app_model = installed_app.app @@ -64,30 +66,32 @@ class MessageFeedbackApi(InstalledAppResource): message_id = str(message_id) parser = reqparse.RequestParser() - parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') + parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, current_user, args['rating']) + MessageService.create_feedback(app_model, message_id, current_user, args["rating"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") - return {'result': 'success'} + return {"result": "success"} class MessageMoreLikeThisApi(InstalledAppResource): def get(self, installed_app, message_id): app_model = installed_app.app - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() message_id = str(message_id) parser = reqparse.RequestParser() - parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args') + parser.add_argument( + "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args" + ) args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' + streaming = args["response_mode"] == "streaming" try: response = AppGenerateService.generate_more_like_this( @@ -95,7 +99,7 @@ class MessageMoreLikeThisApi(InstalledAppResource): user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE, - streaming=streaming + streaming=streaming, ) return helper.compact_generate_response(response) except MessageNotExistsError: @@ -128,10 +132,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource): try: questions = MessageService.get_suggested_questions_after_answer( - app_model=app_model, - user=current_user, - message_id=message_id, - invoke_from=InvokeFrom.EXPLORE + app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE ) except MessageNotExistsError: raise NotFound("Message not found") @@ -151,10 +152,22 @@ class MessageSuggestedQuestionApi(InstalledAppResource): logging.exception("internal server error.") raise InternalServerError() - return {'data': questions} + return {"data": questions} -api.add_resource(MessageListApi, '/installed-apps//messages', endpoint='installed_app_messages') -api.add_resource(MessageFeedbackApi, '/installed-apps//messages//feedbacks', endpoint='installed_app_message_feedback') -api.add_resource(MessageMoreLikeThisApi, '/installed-apps//messages//more-like-this', endpoint='installed_app_more_like_this') -api.add_resource(MessageSuggestedQuestionApi, '/installed-apps//messages//suggested-questions', endpoint='installed_app_suggested_question') +api.add_resource(MessageListApi, "/installed-apps//messages", endpoint="installed_app_messages") +api.add_resource( + MessageFeedbackApi, + "/installed-apps//messages//feedbacks", + endpoint="installed_app_message_feedback", +) +api.add_resource( + MessageMoreLikeThisApi, + "/installed-apps//messages//more-like-this", + endpoint="installed_app_more_like_this", +) +api.add_resource( + MessageSuggestedQuestionApi, + "/installed-apps//messages//suggested-questions", + endpoint="installed_app_suggested_question", +) diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index 0a168d6306..ad55b04043 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,4 +1,3 @@ - from flask_restful import fields, marshal_with from configs import dify_config @@ -11,33 +10,32 @@ from services.app_service import AppService class AppParameterApi(InstalledAppResource): """Resource for app variables.""" + variable_fields = { - 'key': fields.String, - 'name': fields.String, - 'description': fields.String, - 'type': fields.String, - 'default': fields.String, - 'max_length': fields.Integer, - 'options': fields.List(fields.String) + "key": fields.String, + "name": fields.String, + "description": fields.String, + "type": fields.String, + "default": fields.String, + "max_length": fields.Integer, + "options": fields.List(fields.String), } - system_parameters_fields = { - 'image_file_size_limit': fields.String - } + system_parameters_fields = {"image_file_size_limit": fields.String} parameters_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw, - 'suggested_questions_after_answer': fields.Raw, - 'speech_to_text': fields.Raw, - 'text_to_speech': fields.Raw, - 'retriever_resource': fields.Raw, - 'annotation_reply': fields.Raw, - 'more_like_this': fields.Raw, - 'user_input_form': fields.Raw, - 'sensitive_word_avoidance': fields.Raw, - 'file_upload': fields.Raw, - 'system_parameters': fields.Nested(system_parameters_fields) + "opening_statement": fields.String, + "suggested_questions": fields.Raw, + "suggested_questions_after_answer": fields.Raw, + "speech_to_text": fields.Raw, + "text_to_speech": fields.Raw, + "retriever_resource": fields.Raw, + "annotation_reply": fields.Raw, + "more_like_this": fields.Raw, + "user_input_form": fields.Raw, + "sensitive_word_avoidance": fields.Raw, + "file_upload": fields.Raw, + "system_parameters": fields.Nested(system_parameters_fields), } @marshal_with(parameters_fields) @@ -56,30 +54,35 @@ class AppParameterApi(InstalledAppResource): app_model_config = app_model.app_model_config features_dict = app_model_config.to_dict() - user_input_form = features_dict.get('user_input_form', []) + user_input_form = features_dict.get("user_input_form", []) return { - 'opening_statement': features_dict.get('opening_statement'), - 'suggested_questions': features_dict.get('suggested_questions', []), - 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer', - {"enabled": False}), - 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}), - 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}), - 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}), - 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}), - 'more_like_this': features_dict.get('more_like_this', {"enabled": False}), - 'user_input_form': user_input_form, - 'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance', - {"enabled": False, "type": "", "configs": []}), - 'file_upload': features_dict.get('file_upload', {"image": { - "enabled": False, - "number_limits": 3, - "detail": "high", - "transfer_methods": ["remote_url", "local_file"] - }}), - 'system_parameters': { - 'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT - } + "opening_statement": features_dict.get("opening_statement"), + "suggested_questions": features_dict.get("suggested_questions", []), + "suggested_questions_after_answer": features_dict.get( + "suggested_questions_after_answer", {"enabled": False} + ), + "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}), + "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}), + "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}), + "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}), + "more_like_this": features_dict.get("more_like_this", {"enabled": False}), + "user_input_form": user_input_form, + "sensitive_word_avoidance": features_dict.get( + "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []} + ), + "file_upload": features_dict.get( + "file_upload", + { + "image": { + "enabled": False, + "number_limits": 3, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"], + } + }, + ), + "system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT}, } @@ -90,6 +93,7 @@ class ExploreAppMetaApi(InstalledAppResource): return AppService().get_app_meta(app_model) -api.add_resource(AppParameterApi, '/installed-apps//parameters', - endpoint='installed_app_parameters') -api.add_resource(ExploreAppMetaApi, '/installed-apps//meta', endpoint='installed_app_meta') +api.add_resource( + AppParameterApi, "/installed-apps//parameters", endpoint="installed_app_parameters" +) +api.add_resource(ExploreAppMetaApi, "/installed-apps//meta", endpoint="installed_app_meta") diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 6e10e2ec92..5daaa1e7c3 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -8,28 +8,28 @@ from libs.login import login_required from services.recommended_app_service import RecommendedAppService app_fields = { - 'id': fields.String, - 'name': fields.String, - 'mode': fields.String, - 'icon': fields.String, - 'icon_background': fields.String + "id": fields.String, + "name": fields.String, + "mode": fields.String, + "icon": fields.String, + "icon_background": fields.String, } recommended_app_fields = { - 'app': fields.Nested(app_fields, attribute='app'), - 'app_id': fields.String, - 'description': fields.String(attribute='description'), - 'copyright': fields.String, - 'privacy_policy': fields.String, - 'custom_disclaimer': fields.String, - 'category': fields.String, - 'position': fields.Integer, - 'is_listed': fields.Boolean + "app": fields.Nested(app_fields, attribute="app"), + "app_id": fields.String, + "description": fields.String(attribute="description"), + "copyright": fields.String, + "privacy_policy": fields.String, + "custom_disclaimer": fields.String, + "category": fields.String, + "position": fields.Integer, + "is_listed": fields.Boolean, } recommended_app_list_fields = { - 'recommended_apps': fields.List(fields.Nested(recommended_app_fields)), - 'categories': fields.List(fields.String) + "recommended_apps": fields.List(fields.Nested(recommended_app_fields)), + "categories": fields.List(fields.String), } @@ -40,11 +40,11 @@ class RecommendedAppListApi(Resource): def get(self): # language args parser = reqparse.RequestParser() - parser.add_argument('language', type=str, location='args') + parser.add_argument("language", type=str, location="args") args = parser.parse_args() - if args.get('language') and args.get('language') in languages: - language_prefix = args.get('language') + if args.get("language") and args.get("language") in languages: + language_prefix = args.get("language") elif current_user and current_user.interface_language: language_prefix = current_user.interface_language else: @@ -61,5 +61,5 @@ class RecommendedAppApi(Resource): return RecommendedAppService.get_recommend_app_detail(app_id) -api.add_resource(RecommendedAppListApi, '/explore/apps') -api.add_resource(RecommendedAppApi, '/explore/apps/') +api.add_resource(RecommendedAppListApi, "/explore/apps") +api.add_resource(RecommendedAppApi, "/explore/apps/") diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index cf86b2fee1..a7ccf737a8 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -11,56 +11,54 @@ from libs.helper import TimestampField, uuid_value from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService -feedback_fields = { - 'rating': fields.String -} +feedback_fields = {"rating": fields.String} message_fields = { - 'id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'answer': fields.String, - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), - 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), - 'created_at': TimestampField + "id": fields.String, + "inputs": fields.Raw, + "query": fields.String, + "answer": fields.String, + "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), + "created_at": TimestampField, } class SavedMessageListApi(InstalledAppResource): saved_message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(message_fields)), } @marshal_with(saved_message_infinite_scroll_pagination_fields) def get(self, installed_app): app_model = installed_app.app - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - return SavedMessageService.pagination_by_last_id(app_model, current_user, args['last_id'], args['limit']) + return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"]) def post(self, installed_app): app_model = installed_app.app - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('message_id', type=uuid_value, required=True, location='json') + parser.add_argument("message_id", type=uuid_value, required=True, location="json") args = parser.parse_args() try: - SavedMessageService.save(app_model, current_user, args['message_id']) + SavedMessageService.save(app_model, current_user, args["message_id"]) except MessageNotExistsError: raise NotFound("Message Not Exists.") - return {'result': 'success'} + return {"result": "success"} class SavedMessageApi(InstalledAppResource): @@ -69,13 +67,21 @@ class SavedMessageApi(InstalledAppResource): message_id = str(message_id) - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() SavedMessageService.delete(app_model, current_user, message_id) - return {'result': 'success'} + return {"result": "success"} -api.add_resource(SavedMessageListApi, '/installed-apps//saved-messages', endpoint='installed_app_saved_messages') -api.add_resource(SavedMessageApi, '/installed-apps//saved-messages/', endpoint='installed_app_saved_message') +api.add_resource( + SavedMessageListApi, + "/installed-apps//saved-messages", + endpoint="installed_app_saved_messages", +) +api.add_resource( + SavedMessageApi, + "/installed-apps//saved-messages/", + endpoint="installed_app_saved_message", +) diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index 7c5e211d47..45f99b1db9 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -35,17 +35,13 @@ class InstalledAppWorkflowRunApi(InstalledAppResource): raise NotWorkflowAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') - parser.add_argument('files', type=list, required=False, location='json') + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("files", type=list, required=False, location="json") args = parser.parse_args() try: response = AppGenerateService.generate( - app_model=app_model, - user=current_user, - args=args, - invoke_from=InvokeFrom.EXPLORE, - streaming=True + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True ) return helper.compact_generate_response(response) @@ -76,10 +72,10 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) - return { - "result": "success" - } + return {"result": "success"} -api.add_resource(InstalledAppWorkflowRunApi, '/installed-apps//workflows/run') -api.add_resource(InstalledAppWorkflowTaskStopApi, '/installed-apps//workflows/tasks//stop') +api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps//workflows/run") +api.add_resource( + InstalledAppWorkflowTaskStopApi, "/installed-apps//workflows/tasks//stop" +) diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index 84890f1b46..3c9317847b 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -14,29 +14,33 @@ def installed_app_required(view=None): def decorator(view): @wraps(view) def decorated(*args, **kwargs): - if not kwargs.get('installed_app_id'): - raise ValueError('missing installed_app_id in path parameters') + if not kwargs.get("installed_app_id"): + raise ValueError("missing installed_app_id in path parameters") - installed_app_id = kwargs.get('installed_app_id') + installed_app_id = kwargs.get("installed_app_id") installed_app_id = str(installed_app_id) - del kwargs['installed_app_id'] + del kwargs["installed_app_id"] - installed_app = db.session.query(InstalledApp).filter( - InstalledApp.id == str(installed_app_id), - InstalledApp.tenant_id == current_user.current_tenant_id - ).first() + installed_app = ( + db.session.query(InstalledApp) + .filter( + InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id + ) + .first() + ) if installed_app is None: - raise NotFound('Installed app not found') + raise NotFound("Installed app not found") if not installed_app.app: db.session.delete(installed_app) db.session.commit() - raise NotFound('Installed app not found') + raise NotFound("Installed app not found") return view(installed_app, *args, **kwargs) + return decorated if view: diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index fe73bcb985..5d6a8bf152 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -13,23 +13,18 @@ from services.code_based_extension_service import CodeBasedExtensionService class CodeBasedExtensionAPI(Resource): - @setup_required @login_required @account_initialization_required def get(self): parser = reqparse.RequestParser() - parser.add_argument('module', type=str, required=True, location='args') + parser.add_argument("module", type=str, required=True, location="args") args = parser.parse_args() - return { - 'module': args['module'], - 'data': CodeBasedExtensionService.get_code_based_extension(args['module']) - } + return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])} class APIBasedExtensionAPI(Resource): - @setup_required @login_required @account_initialization_required @@ -44,23 +39,22 @@ class APIBasedExtensionAPI(Resource): @marshal_with(api_based_extension_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') - parser.add_argument('api_endpoint', type=str, required=True, location='json') - parser.add_argument('api_key', type=str, required=True, location='json') + parser.add_argument("name", type=str, required=True, location="json") + parser.add_argument("api_endpoint", type=str, required=True, location="json") + parser.add_argument("api_key", type=str, required=True, location="json") args = parser.parse_args() extension_data = APIBasedExtension( tenant_id=current_user.current_tenant_id, - name=args['name'], - api_endpoint=args['api_endpoint'], - api_key=args['api_key'] + name=args["name"], + api_endpoint=args["api_endpoint"], + api_key=args["api_key"], ) return APIBasedExtensionService.save(extension_data) class APIBasedExtensionDetailAPI(Resource): - @setup_required @login_required @account_initialization_required @@ -82,16 +76,16 @@ class APIBasedExtensionDetailAPI(Resource): extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') - parser.add_argument('api_endpoint', type=str, required=True, location='json') - parser.add_argument('api_key', type=str, required=True, location='json') + parser.add_argument("name", type=str, required=True, location="json") + parser.add_argument("api_endpoint", type=str, required=True, location="json") + parser.add_argument("api_key", type=str, required=True, location="json") args = parser.parse_args() - extension_data_from_db.name = args['name'] - extension_data_from_db.api_endpoint = args['api_endpoint'] + extension_data_from_db.name = args["name"] + extension_data_from_db.api_endpoint = args["api_endpoint"] - if args['api_key'] != HIDDEN_VALUE: - extension_data_from_db.api_key = args['api_key'] + if args["api_key"] != HIDDEN_VALUE: + extension_data_from_db.api_key = args["api_key"] return APIBasedExtensionService.save(extension_data_from_db) @@ -106,10 +100,10 @@ class APIBasedExtensionDetailAPI(Resource): APIBasedExtensionService.delete(extension_data_from_db) - return {'result': 'success'} + return {"result": "success"} -api.add_resource(CodeBasedExtensionAPI, '/code-based-extension') +api.add_resource(CodeBasedExtensionAPI, "/code-based-extension") -api.add_resource(APIBasedExtensionAPI, '/api-based-extension') -api.add_resource(APIBasedExtensionDetailAPI, '/api-based-extension/') +api.add_resource(APIBasedExtensionAPI, "/api-based-extension") +api.add_resource(APIBasedExtensionDetailAPI, "/api-based-extension/") diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 8475cd8488..f0482f749d 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -10,7 +10,6 @@ from .wraps import account_initialization_required, cloud_utm_record class FeatureApi(Resource): - @setup_required @login_required @account_initialization_required @@ -24,5 +23,5 @@ class SystemFeatureApi(Resource): return FeatureService.get_system_features().model_dump() -api.add_resource(FeatureApi, '/features') -api.add_resource(SystemFeatureApi, '/system-features') +api.add_resource(FeatureApi, "/features") +api.add_resource(SystemFeatureApi, "/system-features") diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index 6feb1003a9..7d3ae677ee 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -14,12 +14,11 @@ from .wraps import only_edition_self_hosted class InitValidateAPI(Resource): - def get(self): init_status = get_init_validate_status() if init_status: - return { 'status': 'finished' } - return {'status': 'not_started' } + return {"status": "finished"} + return {"status": "not_started"} @only_edition_self_hosted def post(self): @@ -29,22 +28,23 @@ class InitValidateAPI(Resource): raise AlreadySetupError() parser = reqparse.RequestParser() - parser.add_argument('password', type=str_len(30), - required=True, location='json') - input_password = parser.parse_args()['password'] + parser.add_argument("password", type=str_len(30), required=True, location="json") + input_password = parser.parse_args()["password"] - if input_password != os.environ.get('INIT_PASSWORD'): - session['is_init_validated'] = False + if input_password != os.environ.get("INIT_PASSWORD"): + session["is_init_validated"] = False raise InitValidateFailedError() - - session['is_init_validated'] = True - return {'result': 'success'}, 201 + + session["is_init_validated"] = True + return {"result": "success"}, 201 + def get_init_validate_status(): - if dify_config.EDITION == 'SELF_HOSTED': - if os.environ.get('INIT_PASSWORD'): - return session.get('is_init_validated') or DifySetup.query.first() - + if dify_config.EDITION == "SELF_HOSTED": + if os.environ.get("INIT_PASSWORD"): + return session.get("is_init_validated") or DifySetup.query.first() + return True -api.add_resource(InitValidateAPI, '/init') + +api.add_resource(InitValidateAPI, "/init") diff --git a/api/controllers/console/ping.py b/api/controllers/console/ping.py index 7664ba8c16..cd28cc946e 100644 --- a/api/controllers/console/ping.py +++ b/api/controllers/console/ping.py @@ -4,14 +4,11 @@ from controllers.console import api class PingApi(Resource): - def get(self): """ For connection health check """ - return { - "result": "pong" - } + return {"result": "pong"} -api.add_resource(PingApi, '/ping') +api.add_resource(PingApi, "/ping") diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index ef7cc6bc03..827695e00f 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -16,17 +16,13 @@ from .wraps import only_edition_self_hosted class SetupApi(Resource): - def get(self): - if dify_config.EDITION == 'SELF_HOSTED': + if dify_config.EDITION == "SELF_HOSTED": setup_status = get_setup_status() if setup_status: - return { - 'step': 'finished', - 'setup_at': setup_status.setup_at.isoformat() - } - return {'step': 'not_started'} - return {'step': 'finished'} + return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()} + return {"step": "not_started"} + return {"step": "finished"} @only_edition_self_hosted def post(self): @@ -38,28 +34,22 @@ class SetupApi(Resource): tenant_count = TenantService.get_tenant_count() if tenant_count > 0: raise AlreadySetupError() - + if not get_init_validate_status(): raise NotInitValidateError() parser = reqparse.RequestParser() - parser.add_argument('email', type=email, - required=True, location='json') - parser.add_argument('name', type=str_len( - 30), required=True, location='json') - parser.add_argument('password', type=valid_password, - required=True, location='json') + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("name", type=str_len(30), required=True, location="json") + parser.add_argument("password", type=valid_password, required=True, location="json") args = parser.parse_args() # setup RegisterService.setup( - email=args['email'], - name=args['name'], - password=args['password'], - ip_address=get_remote_ip(request) + email=args["email"], name=args["name"], password=args["password"], ip_address=get_remote_ip(request) ) - return {'result': 'success'}, 201 + return {"result": "success"}, 201 def setup_required(view): @@ -68,7 +58,7 @@ def setup_required(view): # check setup if not get_init_validate_status(): raise NotInitValidateError() - + elif not get_setup_status(): raise NotSetupError() @@ -78,9 +68,10 @@ def setup_required(view): def get_setup_status(): - if dify_config.EDITION == 'SELF_HOSTED': + if dify_config.EDITION == "SELF_HOSTED": return DifySetup.query.first() else: return True -api.add_resource(SetupApi, '/setup') + +api.add_resource(SetupApi, "/setup") diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 004afaa531..7293aeeb34 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -14,19 +14,18 @@ from services.tag_service import TagService def _validate_name(name): if not name or len(name) < 1 or len(name) > 40: - raise ValueError('Name must be between 1 to 50 characters.') + raise ValueError("Name must be between 1 to 50 characters.") return name class TagListApi(Resource): - @setup_required @login_required @account_initialization_required @marshal_with(tag_fields) def get(self): - tag_type = request.args.get('type', type=str) - keyword = request.args.get('keyword', default=None, type=str) + tag_type = request.args.get("type", type=str) + keyword = request.args.get("keyword", default=None, type=str) tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword) return tags, 200 @@ -40,28 +39,21 @@ class TagListApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('name', nullable=False, required=True, - help='Name must be between 1 to 50 characters.', - type=_validate_name) - parser.add_argument('type', type=str, location='json', - choices=Tag.TAG_TYPE_LIST, - nullable=True, - help='Invalid tag type.') + parser.add_argument( + "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name + ) + parser.add_argument( + "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." + ) args = parser.parse_args() tag = TagService.save_tags(args) - response = { - 'id': tag.id, - 'name': tag.name, - 'type': tag.type, - 'binding_count': 0 - } + response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} return response, 200 class TagUpdateDeleteApi(Resource): - @setup_required @login_required @account_initialization_required @@ -72,20 +64,15 @@ class TagUpdateDeleteApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('name', nullable=False, required=True, - help='Name must be between 1 to 50 characters.', - type=_validate_name) + parser.add_argument( + "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name + ) args = parser.parse_args() tag = TagService.update_tags(args, tag_id) binding_count = TagService.get_tag_binding_count(tag_id) - response = { - 'id': tag.id, - 'name': tag.name, - 'type': tag.type, - 'binding_count': binding_count - } + response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} return response, 200 @@ -104,7 +91,6 @@ class TagUpdateDeleteApi(Resource): class TagBindingCreateApi(Resource): - @setup_required @login_required @account_initialization_required @@ -114,14 +100,15 @@ class TagBindingCreateApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('tag_ids', type=list, nullable=False, required=True, location='json', - help='Tag IDs is required.') - parser.add_argument('target_id', type=str, nullable=False, required=True, location='json', - help='Target ID is required.') - parser.add_argument('type', type=str, location='json', - choices=Tag.TAG_TYPE_LIST, - nullable=True, - help='Invalid tag type.') + parser.add_argument( + "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required." + ) + parser.add_argument( + "target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required." + ) + parser.add_argument( + "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." + ) args = parser.parse_args() TagService.save_tag_binding(args) @@ -129,7 +116,6 @@ class TagBindingCreateApi(Resource): class TagBindingDeleteApi(Resource): - @setup_required @login_required @account_initialization_required @@ -139,21 +125,18 @@ class TagBindingDeleteApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('tag_id', type=str, nullable=False, required=True, - help='Tag ID is required.') - parser.add_argument('target_id', type=str, nullable=False, required=True, - help='Target ID is required.') - parser.add_argument('type', type=str, location='json', - choices=Tag.TAG_TYPE_LIST, - nullable=True, - help='Invalid tag type.') + parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") + parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") + parser.add_argument( + "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." + ) args = parser.parse_args() TagService.delete_tag_binding(args) return 200 -api.add_resource(TagListApi, '/tags') -api.add_resource(TagUpdateDeleteApi, '/tags/') -api.add_resource(TagBindingCreateApi, '/tag-bindings/create') -api.add_resource(TagBindingDeleteApi, '/tag-bindings/remove') +api.add_resource(TagListApi, "/tags") +api.add_resource(TagUpdateDeleteApi, "/tags/") +api.add_resource(TagBindingCreateApi, "/tag-bindings/create") +api.add_resource(TagBindingDeleteApi, "/tag-bindings/remove") diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index 1fcf4bdc00..76adbfe6a9 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -1,4 +1,3 @@ - import json import logging @@ -11,42 +10,39 @@ from . import api class VersionApi(Resource): - def get(self): parser = reqparse.RequestParser() - parser.add_argument('current_version', type=str, required=True, location='args') + parser.add_argument("current_version", type=str, required=True, location="args") args = parser.parse_args() check_update_url = dify_config.CHECK_UPDATE_URL result = { - 'version': dify_config.CURRENT_VERSION, - 'release_date': '', - 'release_notes': '', - 'can_auto_update': False, - 'features': { - 'can_replace_logo': dify_config.CAN_REPLACE_LOGO, - 'model_load_balancing_enabled': dify_config.MODEL_LB_ENABLED - } + "version": dify_config.CURRENT_VERSION, + "release_date": "", + "release_notes": "", + "can_auto_update": False, + "features": { + "can_replace_logo": dify_config.CAN_REPLACE_LOGO, + "model_load_balancing_enabled": dify_config.MODEL_LB_ENABLED, + }, } if not check_update_url: return result try: - response = requests.get(check_update_url, { - 'current_version': args.get('current_version') - }) + response = requests.get(check_update_url, {"current_version": args.get("current_version")}) except Exception as error: logging.warning("Check update version error: {}.".format(str(error))) - result['version'] = args.get('current_version') + result["version"] = args.get("current_version") return result content = json.loads(response.content) - result['version'] = content['version'] - result['release_date'] = content['releaseDate'] - result['release_notes'] = content['releaseNotes'] - result['can_auto_update'] = content['canAutoUpdate'] + result["version"] = content["version"] + result["release_date"] = content["releaseDate"] + result["release_notes"] = content["releaseNotes"] + result["can_auto_update"] = content["canAutoUpdate"] return result -api.add_resource(VersionApi, '/version') +api.add_resource(VersionApi, "/version") diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 1056d5eb62..dec426128f 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -26,52 +26,53 @@ from services.errors.account import CurrentPasswordIncorrectError as ServiceCurr class AccountInitApi(Resource): - @setup_required @login_required def post(self): account = current_user - if account.status == 'active': + if account.status == "active": raise AccountAlreadyInitedError() parser = reqparse.RequestParser() - if dify_config.EDITION == 'CLOUD': - parser.add_argument('invitation_code', type=str, location='json') + if dify_config.EDITION == "CLOUD": + parser.add_argument("invitation_code", type=str, location="json") - parser.add_argument( - 'interface_language', type=supported_language, required=True, location='json') - parser.add_argument('timezone', type=timezone, - required=True, location='json') + parser.add_argument("interface_language", type=supported_language, required=True, location="json") + parser.add_argument("timezone", type=timezone, required=True, location="json") args = parser.parse_args() - if dify_config.EDITION == 'CLOUD': - if not args['invitation_code']: - raise ValueError('invitation_code is required') + if dify_config.EDITION == "CLOUD": + if not args["invitation_code"]: + raise ValueError("invitation_code is required") # check invitation code - invitation_code = db.session.query(InvitationCode).filter( - InvitationCode.code == args['invitation_code'], - InvitationCode.status == 'unused', - ).first() + invitation_code = ( + db.session.query(InvitationCode) + .filter( + InvitationCode.code == args["invitation_code"], + InvitationCode.status == "unused", + ) + .first() + ) if not invitation_code: raise InvalidInvitationCodeError() - invitation_code.status = 'used' + invitation_code.status = "used" invitation_code.used_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) invitation_code.used_by_tenant_id = account.current_tenant_id invitation_code.used_by_account_id = account.id - account.interface_language = args['interface_language'] - account.timezone = args['timezone'] - account.interface_theme = 'light' - account.status = 'active' + account.interface_language = args["interface_language"] + account.timezone = args["timezone"] + account.interface_theme = "light" + account.status = "active" account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() - return {'result': 'success'} + return {"result": "success"} class AccountProfileApi(Resource): @@ -90,15 +91,14 @@ class AccountNameApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') + parser.add_argument("name", type=str, required=True, location="json") args = parser.parse_args() # Validate account name length - if len(args['name']) < 3 or len(args['name']) > 30: - raise ValueError( - "Account name must be between 3 and 30 characters.") + if len(args["name"]) < 3 or len(args["name"]) > 30: + raise ValueError("Account name must be between 3 and 30 characters.") - updated_account = AccountService.update_account(current_user, name=args['name']) + updated_account = AccountService.update_account(current_user, name=args["name"]) return updated_account @@ -110,10 +110,10 @@ class AccountAvatarApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('avatar', type=str, required=True, location='json') + parser.add_argument("avatar", type=str, required=True, location="json") args = parser.parse_args() - updated_account = AccountService.update_account(current_user, avatar=args['avatar']) + updated_account = AccountService.update_account(current_user, avatar=args["avatar"]) return updated_account @@ -125,11 +125,10 @@ class AccountInterfaceLanguageApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument( - 'interface_language', type=supported_language, required=True, location='json') + parser.add_argument("interface_language", type=supported_language, required=True, location="json") args = parser.parse_args() - updated_account = AccountService.update_account(current_user, interface_language=args['interface_language']) + updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"]) return updated_account @@ -141,11 +140,10 @@ class AccountInterfaceThemeApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('interface_theme', type=str, choices=[ - 'light', 'dark'], required=True, location='json') + parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json") args = parser.parse_args() - updated_account = AccountService.update_account(current_user, interface_theme=args['interface_theme']) + updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"]) return updated_account @@ -157,15 +155,14 @@ class AccountTimezoneApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('timezone', type=str, - required=True, location='json') + parser.add_argument("timezone", type=str, required=True, location="json") args = parser.parse_args() # Validate timezone string, e.g. America/New_York, Asia/Shanghai - if args['timezone'] not in pytz.all_timezones: + if args["timezone"] not in pytz.all_timezones: raise ValueError("Invalid timezone string.") - updated_account = AccountService.update_account(current_user, timezone=args['timezone']) + updated_account = AccountService.update_account(current_user, timezone=args["timezone"]) return updated_account @@ -177,20 +174,16 @@ class AccountPasswordApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('password', type=str, - required=False, location='json') - parser.add_argument('new_password', type=str, - required=True, location='json') - parser.add_argument('repeat_new_password', type=str, - required=True, location='json') + parser.add_argument("password", type=str, required=False, location="json") + parser.add_argument("new_password", type=str, required=True, location="json") + parser.add_argument("repeat_new_password", type=str, required=True, location="json") args = parser.parse_args() - if args['new_password'] != args['repeat_new_password']: + if args["new_password"] != args["repeat_new_password"]: raise RepeatPasswordNotMatchError() try: - AccountService.update_account_password( - current_user, args['password'], args['new_password']) + AccountService.update_account_password(current_user, args["password"], args["new_password"]) except ServiceCurrentPasswordIncorrectError: raise CurrentPasswordIncorrectError() @@ -199,14 +192,14 @@ class AccountPasswordApi(Resource): class AccountIntegrateApi(Resource): integrate_fields = { - 'provider': fields.String, - 'created_at': TimestampField, - 'is_bound': fields.Boolean, - 'link': fields.String + "provider": fields.String, + "created_at": TimestampField, + "is_bound": fields.Boolean, + "link": fields.String, } integrate_list_fields = { - 'data': fields.List(fields.Nested(integrate_fields)), + "data": fields.List(fields.Nested(integrate_fields)), } @setup_required @@ -216,10 +209,9 @@ class AccountIntegrateApi(Resource): def get(self): account = current_user - account_integrates = db.session.query(AccountIntegrate).filter( - AccountIntegrate.account_id == account.id).all() + account_integrates = db.session.query(AccountIntegrate).filter(AccountIntegrate.account_id == account.id).all() - base_url = request.url_root.rstrip('/') + base_url = request.url_root.rstrip("/") oauth_base_path = "/console/api/oauth/login" providers = ["github", "google"] @@ -227,36 +219,38 @@ class AccountIntegrateApi(Resource): for provider in providers: existing_integrate = next((ai for ai in account_integrates if ai.provider == provider), None) if existing_integrate: - integrate_data.append({ - 'id': existing_integrate.id, - 'provider': provider, - 'created_at': existing_integrate.created_at, - 'is_bound': True, - 'link': None - }) + integrate_data.append( + { + "id": existing_integrate.id, + "provider": provider, + "created_at": existing_integrate.created_at, + "is_bound": True, + "link": None, + } + ) else: - integrate_data.append({ - 'id': None, - 'provider': provider, - 'created_at': None, - 'is_bound': False, - 'link': f'{base_url}{oauth_base_path}/{provider}' - }) - - return {'data': integrate_data} - + integrate_data.append( + { + "id": None, + "provider": provider, + "created_at": None, + "is_bound": False, + "link": f"{base_url}{oauth_base_path}/{provider}", + } + ) + return {"data": integrate_data} # Register API resources -api.add_resource(AccountInitApi, '/account/init') -api.add_resource(AccountProfileApi, '/account/profile') -api.add_resource(AccountNameApi, '/account/name') -api.add_resource(AccountAvatarApi, '/account/avatar') -api.add_resource(AccountInterfaceLanguageApi, '/account/interface-language') -api.add_resource(AccountInterfaceThemeApi, '/account/interface-theme') -api.add_resource(AccountTimezoneApi, '/account/timezone') -api.add_resource(AccountPasswordApi, '/account/password') -api.add_resource(AccountIntegrateApi, '/account/integrates') +api.add_resource(AccountInitApi, "/account/init") +api.add_resource(AccountProfileApi, "/account/profile") +api.add_resource(AccountNameApi, "/account/name") +api.add_resource(AccountAvatarApi, "/account/avatar") +api.add_resource(AccountInterfaceLanguageApi, "/account/interface-language") +api.add_resource(AccountInterfaceThemeApi, "/account/interface-theme") +api.add_resource(AccountTimezoneApi, "/account/timezone") +api.add_resource(AccountPasswordApi, "/account/password") +api.add_resource(AccountIntegrateApi, "/account/integrates") # api.add_resource(AccountEmailApi, '/account/email') # api.add_resource(AccountEmailVerifyApi, '/account/email-verify') diff --git a/api/controllers/console/workspace/error.py b/api/controllers/console/workspace/error.py index 99f55835bc..9e13c7b924 100644 --- a/api/controllers/console/workspace/error.py +++ b/api/controllers/console/workspace/error.py @@ -2,36 +2,36 @@ from libs.exception import BaseHTTPException class RepeatPasswordNotMatchError(BaseHTTPException): - error_code = 'repeat_password_not_match' + error_code = "repeat_password_not_match" description = "New password and repeat password does not match." code = 400 class CurrentPasswordIncorrectError(BaseHTTPException): - error_code = 'current_password_incorrect' + error_code = "current_password_incorrect" description = "Current password is incorrect." code = 400 class ProviderRequestFailedError(BaseHTTPException): - error_code = 'provider_request_failed' + error_code = "provider_request_failed" description = None code = 400 class InvalidInvitationCodeError(BaseHTTPException): - error_code = 'invalid_invitation_code' + error_code = "invalid_invitation_code" description = "Invalid invitation code." code = 400 class AccountAlreadyInitedError(BaseHTTPException): - error_code = 'account_already_inited' + error_code = "account_already_inited" description = "The account has been initialized. Please refresh the page." code = 400 class AccountNotInitializedError(BaseHTTPException): - error_code = 'account_not_initialized' + error_code = "account_not_initialized" description = "The account has not been initialized yet. Please proceed with the initialization process first." code = 400 diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index 50514e39f6..771a866624 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -22,10 +22,16 @@ class LoadBalancingCredentialsValidateApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() # validate model load balancing credentials @@ -38,18 +44,18 @@ class LoadBalancingCredentialsValidateApi(Resource): model_load_balancing_service.validate_load_balancing_credentials( tenant_id=tenant_id, provider=provider, - model=args['model'], - model_type=args['model_type'], - credentials=args['credentials'] + model=args["model"], + model_type=args["model_type"], + credentials=args["credentials"], ) except CredentialsValidateFailedError as ex: result = False error = str(ex) - response = {'result': 'success' if result else 'error'} + response = {"result": "success" if result else "error"} if not result: - response['error'] = error + response["error"] = error return response @@ -65,10 +71,16 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() # validate model load balancing config credentials @@ -81,26 +93,30 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): model_load_balancing_service.validate_load_balancing_credentials( tenant_id=tenant_id, provider=provider, - model=args['model'], - model_type=args['model_type'], - credentials=args['credentials'], + model=args["model"], + model_type=args["model_type"], + credentials=args["credentials"], config_id=config_id, ) except CredentialsValidateFailedError as ex: result = False error = str(ex) - response = {'result': 'success' if result else 'error'} + response = {"result": "success" if result else "error"} if not result: - response['error'] = error + response["error"] = error return response # Load Balancing Config -api.add_resource(LoadBalancingCredentialsValidateApi, - '/workspaces/current/model-providers//models/load-balancing-configs/credentials-validate') +api.add_resource( + LoadBalancingCredentialsValidateApi, + "/workspaces/current/model-providers//models/load-balancing-configs/credentials-validate", +) -api.add_resource(LoadBalancingConfigCredentialsValidateApi, - '/workspaces/current/model-providers//models/load-balancing-configs//credentials-validate') +api.add_resource( + LoadBalancingConfigCredentialsValidateApi, + "/workspaces/current/model-providers//models/load-balancing-configs//credentials-validate", +) diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 34e9da3841..3e87bebf59 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -23,7 +23,7 @@ class MemberListApi(Resource): @marshal_with(account_with_role_list_fields) def get(self): members = TenantService.get_tenant_members(current_user.current_tenant) - return {'result': 'success', 'accounts': members}, 200 + return {"result": "success", "accounts": members}, 200 class MemberInviteEmailApi(Resource): @@ -32,48 +32,46 @@ class MemberInviteEmailApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('members') + @cloud_edition_billing_resource_check("members") def post(self): parser = reqparse.RequestParser() - parser.add_argument('emails', type=str, required=True, location='json', action='append') - parser.add_argument('role', type=str, required=True, default='admin', location='json') - parser.add_argument('language', type=str, required=False, location='json') + parser.add_argument("emails", type=str, required=True, location="json", action="append") + parser.add_argument("role", type=str, required=True, default="admin", location="json") + parser.add_argument("language", type=str, required=False, location="json") args = parser.parse_args() - invitee_emails = args['emails'] - invitee_role = args['role'] - interface_language = args['language'] + invitee_emails = args["emails"] + invitee_role = args["role"] + interface_language = args["language"] if not TenantAccountRole.is_non_owner_role(invitee_role): - return {'code': 'invalid-role', 'message': 'Invalid role'}, 400 + return {"code": "invalid-role", "message": "Invalid role"}, 400 inviter = current_user invitation_results = [] console_web_url = dify_config.CONSOLE_WEB_URL for invitee_email in invitee_emails: try: - token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter) - invitation_results.append({ - 'status': 'success', - 'email': invitee_email, - 'url': f'{console_web_url}/activate?email={invitee_email}&token={token}' - }) + token = RegisterService.invite_new_member( + inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter + ) + invitation_results.append( + { + "status": "success", + "email": invitee_email, + "url": f"{console_web_url}/activate?email={invitee_email}&token={token}", + } + ) except AccountAlreadyInTenantError: - invitation_results.append({ - 'status': 'success', - 'email': invitee_email, - 'url': f'{console_web_url}/signin' - }) + invitation_results.append( + {"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"} + ) break except Exception as e: - invitation_results.append({ - 'status': 'failed', - 'email': invitee_email, - 'message': str(e) - }) + invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)}) return { - 'result': 'success', - 'invitation_results': invitation_results, + "result": "success", + "invitation_results": invitation_results, }, 201 @@ -91,15 +89,15 @@ class MemberCancelInviteApi(Resource): try: TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user) except services.errors.account.CannotOperateSelfError as e: - return {'code': 'cannot-operate-self', 'message': str(e)}, 400 + return {"code": "cannot-operate-self", "message": str(e)}, 400 except services.errors.account.NoPermissionError as e: - return {'code': 'forbidden', 'message': str(e)}, 403 + return {"code": "forbidden", "message": str(e)}, 403 except services.errors.account.MemberNotInTenantError as e: - return {'code': 'member-not-found', 'message': str(e)}, 404 + return {"code": "member-not-found", "message": str(e)}, 404 except Exception as e: raise ValueError(str(e)) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class MemberUpdateRoleApi(Resource): @@ -110,12 +108,12 @@ class MemberUpdateRoleApi(Resource): @account_initialization_required def put(self, member_id): parser = reqparse.RequestParser() - parser.add_argument('role', type=str, required=True, location='json') + parser.add_argument("role", type=str, required=True, location="json") args = parser.parse_args() - new_role = args['role'] + new_role = args["role"] if not TenantAccountRole.is_valid_role(new_role): - return {'code': 'invalid-role', 'message': 'Invalid role'}, 400 + return {"code": "invalid-role", "message": "Invalid role"}, 400 member = db.session.get(Account, str(member_id)) if not member: @@ -128,7 +126,7 @@ class MemberUpdateRoleApi(Resource): # todo: 403 - return {'result': 'success'} + return {"result": "success"} class DatasetOperatorMemberListApi(Resource): @@ -140,11 +138,11 @@ class DatasetOperatorMemberListApi(Resource): @marshal_with(account_with_role_list_fields) def get(self): members = TenantService.get_dataset_operator_members(current_user.current_tenant) - return {'result': 'success', 'accounts': members}, 200 + return {"result": "success", "accounts": members}, 200 -api.add_resource(MemberListApi, '/workspaces/current/members') -api.add_resource(MemberInviteEmailApi, '/workspaces/current/members/invite-email') -api.add_resource(MemberCancelInviteApi, '/workspaces/current/members/') -api.add_resource(MemberUpdateRoleApi, '/workspaces/current/members//update-role') -api.add_resource(DatasetOperatorMemberListApi, '/workspaces/current/dataset-operators') +api.add_resource(MemberListApi, "/workspaces/current/members") +api.add_resource(MemberInviteEmailApi, "/workspaces/current/members/invite-email") +api.add_resource(MemberCancelInviteApi, "/workspaces/current/members/") +api.add_resource(MemberUpdateRoleApi, "/workspaces/current/members//update-role") +api.add_resource(DatasetOperatorMemberListApi, "/workspaces/current/dataset-operators") diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index c888159f83..8c38420226 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -17,7 +17,6 @@ from services.model_provider_service import ModelProviderService class ModelProviderListApi(Resource): - @setup_required @login_required @account_initialization_required @@ -25,21 +24,23 @@ class ModelProviderListApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model_type', type=str, required=False, nullable=True, - choices=[mt.value for mt in ModelType], location='args') + parser.add_argument( + "model_type", + type=str, + required=False, + nullable=True, + choices=[mt.value for mt in ModelType], + location="args", + ) args = parser.parse_args() model_provider_service = ModelProviderService() - provider_list = model_provider_service.get_provider_list( - tenant_id=tenant_id, - model_type=args.get('model_type') - ) + provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type")) return jsonable_encoder({"data": provider_list}) class ModelProviderCredentialApi(Resource): - @setup_required @login_required @account_initialization_required @@ -47,25 +48,18 @@ class ModelProviderCredentialApi(Resource): tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() - credentials = model_provider_service.get_provider_credentials( - tenant_id=tenant_id, - provider=provider - ) + credentials = model_provider_service.get_provider_credentials(tenant_id=tenant_id, provider=provider) - return { - "credentials": credentials - } + return {"credentials": credentials} class ModelProviderValidateApi(Resource): - @setup_required @login_required @account_initialization_required def post(self, provider: str): - parser = reqparse.RequestParser() - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() tenant_id = current_user.current_tenant_id @@ -77,24 +71,21 @@ class ModelProviderValidateApi(Resource): try: model_provider_service.provider_credentials_validate( - tenant_id=tenant_id, - provider=provider, - credentials=args['credentials'] + tenant_id=tenant_id, provider=provider, credentials=args["credentials"] ) except CredentialsValidateFailedError as ex: result = False error = str(ex) - response = {'result': 'success' if result else 'error'} + response = {"result": "success" if result else "error"} if not result: - response['error'] = error + response["error"] = error return response class ModelProviderApi(Resource): - @setup_required @login_required @account_initialization_required @@ -103,21 +94,19 @@ class ModelProviderApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() model_provider_service = ModelProviderService() try: model_provider_service.save_provider_credentials( - tenant_id=current_user.current_tenant_id, - provider=provider, - credentials=args['credentials'] + tenant_id=current_user.current_tenant_id, provider=provider, credentials=args["credentials"] ) except CredentialsValidateFailedError as ex: raise ValueError(str(ex)) - return {'result': 'success'}, 201 + return {"result": "success"}, 201 @setup_required @login_required @@ -127,12 +116,9 @@ class ModelProviderApi(Resource): raise Forbidden() model_provider_service = ModelProviderService() - model_provider_service.remove_provider_credentials( - tenant_id=current_user.current_tenant_id, - provider=provider - ) + model_provider_service.remove_provider_credentials(tenant_id=current_user.current_tenant_id, provider=provider) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class ModelProviderIconApi(Resource): @@ -146,16 +132,13 @@ class ModelProviderIconApi(Resource): def get(self, provider: str, icon_type: str, lang: str): model_provider_service = ModelProviderService() icon, mimetype = model_provider_service.get_model_provider_icon( - provider=provider, - icon_type=icon_type, - lang=lang + provider=provider, icon_type=icon_type, lang=lang ) return send_file(io.BytesIO(icon), mimetype=mimetype) class PreferredProviderTypeUpdateApi(Resource): - @setup_required @login_required @account_initialization_required @@ -166,18 +149,22 @@ class PreferredProviderTypeUpdateApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False, - choices=['system', 'custom'], location='json') + parser.add_argument( + "preferred_provider_type", + type=str, + required=True, + nullable=False, + choices=["system", "custom"], + location="json", + ) args = parser.parse_args() model_provider_service = ModelProviderService() model_provider_service.switch_preferred_provider( - tenant_id=tenant_id, - provider=provider, - preferred_provider_type=args['preferred_provider_type'] + tenant_id=tenant_id, provider=provider, preferred_provider_type=args["preferred_provider_type"] ) - return {'result': 'success'} + return {"result": "success"} class ModelProviderPaymentCheckoutUrlApi(Resource): @@ -185,13 +172,15 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): @login_required @account_initialization_required def get(self, provider: str): - if provider != 'anthropic': - raise ValueError(f'provider name {provider} is invalid') + if provider != "anthropic": + raise ValueError(f"provider name {provider} is invalid") BillingService.is_tenant_owner_or_admin(current_user) - data = BillingService.get_model_provider_payment_link(provider_name=provider, - tenant_id=current_user.current_tenant_id, - account_id=current_user.id, - prefilled_email=current_user.email) + data = BillingService.get_model_provider_payment_link( + provider_name=provider, + tenant_id=current_user.current_tenant_id, + account_id=current_user.id, + prefilled_email=current_user.email, + ) return data @@ -201,10 +190,7 @@ class ModelProviderFreeQuotaSubmitApi(Resource): @account_initialization_required def post(self, provider: str): model_provider_service = ModelProviderService() - result = model_provider_service.free_quota_submit( - tenant_id=current_user.current_tenant_id, - provider=provider - ) + result = model_provider_service.free_quota_submit(tenant_id=current_user.current_tenant_id, provider=provider) return result @@ -215,32 +201,36 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource): @account_initialization_required def get(self, provider: str): parser = reqparse.RequestParser() - parser.add_argument('token', type=str, required=False, nullable=True, location='args') + parser.add_argument("token", type=str, required=False, nullable=True, location="args") args = parser.parse_args() model_provider_service = ModelProviderService() result = model_provider_service.free_quota_qualification_verify( - tenant_id=current_user.current_tenant_id, - provider=provider, - token=args['token'] + tenant_id=current_user.current_tenant_id, provider=provider, token=args["token"] ) return result -api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers') +api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers") -api.add_resource(ModelProviderCredentialApi, '/workspaces/current/model-providers//credentials') -api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers//credentials/validate') -api.add_resource(ModelProviderApi, '/workspaces/current/model-providers/') -api.add_resource(ModelProviderIconApi, '/workspaces/current/model-providers//' - '/') +api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers//credentials") +api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers//credentials/validate") +api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/") +api.add_resource( + ModelProviderIconApi, "/workspaces/current/model-providers//" "/" +) -api.add_resource(PreferredProviderTypeUpdateApi, - '/workspaces/current/model-providers//preferred-provider-type') -api.add_resource(ModelProviderPaymentCheckoutUrlApi, - '/workspaces/current/model-providers//checkout-url') -api.add_resource(ModelProviderFreeQuotaSubmitApi, - '/workspaces/current/model-providers//free-quota-submit') -api.add_resource(ModelProviderFreeQuotaQualificationVerifyApi, - '/workspaces/current/model-providers//free-quota-qualification-verify') +api.add_resource( + PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers//preferred-provider-type" +) +api.add_resource( + ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers//checkout-url" +) +api.add_resource( + ModelProviderFreeQuotaSubmitApi, "/workspaces/current/model-providers//free-quota-submit" +) +api.add_resource( + ModelProviderFreeQuotaQualificationVerifyApi, + "/workspaces/current/model-providers//free-quota-qualification-verify", +) diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 69f2253e97..dc88f6b812 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -16,27 +16,29 @@ from services.model_provider_service import ModelProviderService class DefaultModelApi(Resource): - @setup_required @login_required @account_initialization_required def get(self): parser = reqparse.RequestParser() - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='args') + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="args", + ) args = parser.parse_args() tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() default_model_entity = model_provider_service.get_default_model_of_model_type( - tenant_id=tenant_id, - model_type=args['model_type'] + tenant_id=tenant_id, model_type=args["model_type"] ) - return jsonable_encoder({ - "data": default_model_entity - }) + return jsonable_encoder({"data": default_model_entity}) @setup_required @login_required @@ -44,40 +46,39 @@ class DefaultModelApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json') + parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json") args = parser.parse_args() tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() - model_settings = args['model_settings'] + model_settings = args["model_settings"] for model_setting in model_settings: - if 'model_type' not in model_setting or model_setting['model_type'] not in [mt.value for mt in ModelType]: - raise ValueError('invalid model type') + if "model_type" not in model_setting or model_setting["model_type"] not in [mt.value for mt in ModelType]: + raise ValueError("invalid model type") - if 'provider' not in model_setting: + if "provider" not in model_setting: continue - if 'model' not in model_setting: - raise ValueError('invalid model') + if "model" not in model_setting: + raise ValueError("invalid model") try: model_provider_service.update_default_model_of_model_type( tenant_id=tenant_id, - model_type=model_setting['model_type'], - provider=model_setting['provider'], - model=model_setting['model'] + model_type=model_setting["model_type"], + provider=model_setting["provider"], + model=model_setting["model"], ) except Exception: logging.warning(f"{model_setting['model_type']} save error") - return {'result': 'success'} + return {"result": "success"} class ModelProviderModelApi(Resource): - @setup_required @login_required @account_initialization_required @@ -85,14 +86,9 @@ class ModelProviderModelApi(Resource): tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() - models = model_provider_service.get_models_by_provider( - tenant_id=tenant_id, - provider=provider - ) + models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider) - return jsonable_encoder({ - "data": models - }) + return jsonable_encoder({"data": models}) @setup_required @login_required @@ -104,61 +100,66 @@ class ModelProviderModelApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') - parser.add_argument('credentials', type=dict, required=False, nullable=True, location='json') - parser.add_argument('load_balancing', type=dict, required=False, nullable=True, location='json') - parser.add_argument('config_from', type=str, required=False, nullable=True, location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") + parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json") + parser.add_argument("config_from", type=str, required=False, nullable=True, location="json") args = parser.parse_args() model_load_balancing_service = ModelLoadBalancingService() - if ('load_balancing' in args and args['load_balancing'] and - 'enabled' in args['load_balancing'] and args['load_balancing']['enabled']): - if 'configs' not in args['load_balancing']: - raise ValueError('invalid load balancing configs') + if ( + "load_balancing" in args + and args["load_balancing"] + and "enabled" in args["load_balancing"] + and args["load_balancing"]["enabled"] + ): + if "configs" not in args["load_balancing"]: + raise ValueError("invalid load balancing configs") # save load balancing configs model_load_balancing_service.update_load_balancing_configs( tenant_id=tenant_id, provider=provider, - model=args['model'], - model_type=args['model_type'], - configs=args['load_balancing']['configs'] + model=args["model"], + model_type=args["model_type"], + configs=args["load_balancing"]["configs"], ) # enable load balancing model_load_balancing_service.enable_model_load_balancing( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) else: # disable load balancing model_load_balancing_service.disable_model_load_balancing( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) - if args.get('config_from', '') != 'predefined-model': + if args.get("config_from", "") != "predefined-model": model_provider_service = ModelProviderService() try: model_provider_service.save_model_credentials( tenant_id=tenant_id, provider=provider, - model=args['model'], - model_type=args['model_type'], - credentials=args['credentials'] + model=args["model"], + model_type=args["model_type"], + credentials=args["credentials"], ) except CredentialsValidateFailedError as ex: + logging.exception(f"save model credentials error: {ex}") raise ValueError(str(ex)) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 @setup_required @login_required @@ -170,24 +171,26 @@ class ModelProviderModelApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) args = parser.parse_args() model_provider_service = ModelProviderService() model_provider_service.remove_model_credentials( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class ModelProviderModelCredentialApi(Resource): - @setup_required @login_required @account_initialization_required @@ -195,38 +198,34 @@ class ModelProviderModelCredentialApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='args') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='args') + parser.add_argument("model", type=str, required=True, nullable=False, location="args") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="args", + ) args = parser.parse_args() model_provider_service = ModelProviderService() credentials = model_provider_service.get_model_credentials( - tenant_id=tenant_id, - provider=provider, - model_type=args['model_type'], - model=args['model'] + tenant_id=tenant_id, provider=provider, model_type=args["model_type"], model=args["model"] ) model_load_balancing_service = ModelLoadBalancingService() is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) return { "credentials": credentials, - "load_balancing": { - "enabled": is_load_balancing_enabled, - "configs": load_balancing_configs - } + "load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs}, } class ModelProviderModelEnableApi(Resource): - @setup_required @login_required @account_initialization_required @@ -234,24 +233,26 @@ class ModelProviderModelEnableApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) args = parser.parse_args() model_provider_service = ModelProviderService() model_provider_service.enable_model( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) - return {'result': 'success'} + return {"result": "success"} class ModelProviderModelDisableApi(Resource): - @setup_required @login_required @account_initialization_required @@ -259,24 +260,26 @@ class ModelProviderModelDisableApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) args = parser.parse_args() model_provider_service = ModelProviderService() model_provider_service.disable_model( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) - return {'result': 'success'} + return {"result": "success"} class ModelProviderModelValidateApi(Resource): - @setup_required @login_required @account_initialization_required @@ -284,10 +287,16 @@ class ModelProviderModelValidateApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() model_provider_service = ModelProviderService() @@ -299,48 +308,42 @@ class ModelProviderModelValidateApi(Resource): model_provider_service.model_credentials_validate( tenant_id=tenant_id, provider=provider, - model=args['model'], - model_type=args['model_type'], - credentials=args['credentials'] + model=args["model"], + model_type=args["model_type"], + credentials=args["credentials"], ) except CredentialsValidateFailedError as ex: result = False error = str(ex) - response = {'result': 'success' if result else 'error'} + response = {"result": "success" if result else "error"} if not result: - response['error'] = error + response["error"] = error return response class ModelProviderModelParameterRuleApi(Resource): - @setup_required @login_required @account_initialization_required def get(self, provider: str): parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='args') + parser.add_argument("model", type=str, required=True, nullable=False, location="args") args = parser.parse_args() tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() parameter_rules = model_provider_service.get_model_parameter_rules( - tenant_id=tenant_id, - provider=provider, - model=args['model'] + tenant_id=tenant_id, provider=provider, model=args["model"] ) - return jsonable_encoder({ - "data": parameter_rules - }) + return jsonable_encoder({"data": parameter_rules}) class ModelProviderAvailableModelApi(Resource): - @setup_required @login_required @account_initialization_required @@ -348,27 +351,31 @@ class ModelProviderAvailableModelApi(Resource): tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() - models = model_provider_service.get_models_by_model_type( - tenant_id=tenant_id, - model_type=model_type - ) + models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type) - return jsonable_encoder({ - "data": models - }) + return jsonable_encoder({"data": models}) -api.add_resource(ModelProviderModelApi, '/workspaces/current/model-providers//models') -api.add_resource(ModelProviderModelEnableApi, '/workspaces/current/model-providers//models/enable', - endpoint='model-provider-model-enable') -api.add_resource(ModelProviderModelDisableApi, '/workspaces/current/model-providers//models/disable', - endpoint='model-provider-model-disable') -api.add_resource(ModelProviderModelCredentialApi, - '/workspaces/current/model-providers//models/credentials') -api.add_resource(ModelProviderModelValidateApi, - '/workspaces/current/model-providers//models/credentials/validate') +api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers//models") +api.add_resource( + ModelProviderModelEnableApi, + "/workspaces/current/model-providers//models/enable", + endpoint="model-provider-model-enable", +) +api.add_resource( + ModelProviderModelDisableApi, + "/workspaces/current/model-providers//models/disable", + endpoint="model-provider-model-disable", +) +api.add_resource( + ModelProviderModelCredentialApi, "/workspaces/current/model-providers//models/credentials" +) +api.add_resource( + ModelProviderModelValidateApi, "/workspaces/current/model-providers//models/credentials/validate" +) -api.add_resource(ModelProviderModelParameterRuleApi, - '/workspaces/current/model-providers//models/parameter-rules') -api.add_resource(ModelProviderAvailableModelApi, '/workspaces/current/models/model-types/') -api.add_resource(DefaultModelApi, '/workspaces/current/default-model') +api.add_resource( + ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers//models/parameter-rules" +) +api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/") +api.add_resource(DefaultModelApi, "/workspaces/current/default-model") diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index bafeabb08a..c41a898fdc 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -28,10 +28,18 @@ class ToolProviderListApi(Resource): tenant_id = current_user.current_tenant_id req = reqparse.RequestParser() - req.add_argument('type', type=str, choices=['builtin', 'model', 'api', 'workflow'], required=False, nullable=True, location='args') + req.add_argument( + "type", + type=str, + choices=["builtin", "model", "api", "workflow"], + required=False, + nullable=True, + location="args", + ) args = req.parse_args() - return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get('type', None)) + return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None)) + class ToolBuiltinProviderListToolsApi(Resource): @setup_required @@ -41,11 +49,14 @@ class ToolBuiltinProviderListToolsApi(Resource): user_id = current_user.id tenant_id = current_user.current_tenant_id - return jsonable_encoder(BuiltinToolManageService.list_builtin_tool_provider_tools( - user_id, - tenant_id, - provider, - )) + return jsonable_encoder( + BuiltinToolManageService.list_builtin_tool_provider_tools( + user_id, + tenant_id, + provider, + ) + ) + class ToolBuiltinProviderDeleteApi(Resource): @setup_required @@ -54,7 +65,7 @@ class ToolBuiltinProviderDeleteApi(Resource): def post(self, provider): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id @@ -63,7 +74,8 @@ class ToolBuiltinProviderDeleteApi(Resource): tenant_id, provider, ) - + + class ToolBuiltinProviderUpdateApi(Resource): @setup_required @login_required @@ -71,12 +83,12 @@ class ToolBuiltinProviderUpdateApi(Resource): def post(self, provider): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() @@ -84,9 +96,10 @@ class ToolBuiltinProviderUpdateApi(Resource): user_id, tenant_id, provider, - args['credentials'], + args["credentials"], ) - + + class ToolBuiltinProviderGetCredentialsApi(Resource): @setup_required @login_required @@ -101,6 +114,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource): provider, ) + class ToolBuiltinProviderIconApi(Resource): @setup_required def get(self, provider): @@ -108,6 +122,7 @@ class ToolBuiltinProviderIconApi(Resource): icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) + class ToolApiProviderAddApi(Resource): @setup_required @login_required @@ -115,35 +130,36 @@ class ToolApiProviderAddApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') - parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json') - parser.add_argument('schema', type=str, required=True, nullable=False, location='json') - parser.add_argument('provider', type=str, required=True, nullable=False, location='json') - parser.add_argument('icon', type=dict, required=True, nullable=False, location='json') - parser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json') - parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json', default=[]) - parser.add_argument('custom_disclaimer', type=str, required=False, nullable=True, location='json') + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json") + parser.add_argument("schema", type=str, required=True, nullable=False, location="json") + parser.add_argument("provider", type=str, required=True, nullable=False, location="json") + parser.add_argument("icon", type=dict, required=True, nullable=False, location="json") + parser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json") + parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[]) + parser.add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json") args = parser.parse_args() return ApiToolManageService.create_api_tool_provider( user_id, tenant_id, - args['provider'], - args['icon'], - args['credentials'], - args['schema_type'], - args['schema'], - args.get('privacy_policy', ''), - args.get('custom_disclaimer', ''), - args.get('labels', []), + args["provider"], + args["icon"], + args["credentials"], + args["schema_type"], + args["schema"], + args.get("privacy_policy", ""), + args.get("custom_disclaimer", ""), + args.get("labels", []), ) + class ToolApiProviderGetRemoteSchemaApi(Resource): @setup_required @login_required @@ -151,16 +167,17 @@ class ToolApiProviderGetRemoteSchemaApi(Resource): def get(self): parser = reqparse.RequestParser() - parser.add_argument('url', type=str, required=True, nullable=False, location='args') + parser.add_argument("url", type=str, required=True, nullable=False, location="args") args = parser.parse_args() return ApiToolManageService.get_api_tool_provider_remote_schema( current_user.id, current_user.current_tenant_id, - args['url'], + args["url"], ) - + + class ToolApiProviderListToolsApi(Resource): @setup_required @login_required @@ -171,15 +188,18 @@ class ToolApiProviderListToolsApi(Resource): parser = reqparse.RequestParser() - parser.add_argument('provider', type=str, required=True, nullable=False, location='args') + parser.add_argument("provider", type=str, required=True, nullable=False, location="args") args = parser.parse_args() - return jsonable_encoder(ApiToolManageService.list_api_tool_provider_tools( - user_id, - tenant_id, - args['provider'], - )) + return jsonable_encoder( + ApiToolManageService.list_api_tool_provider_tools( + user_id, + tenant_id, + args["provider"], + ) + ) + class ToolApiProviderUpdateApi(Resource): @setup_required @@ -188,37 +208,38 @@ class ToolApiProviderUpdateApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') - parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json') - parser.add_argument('schema', type=str, required=True, nullable=False, location='json') - parser.add_argument('provider', type=str, required=True, nullable=False, location='json') - parser.add_argument('original_provider', type=str, required=True, nullable=False, location='json') - parser.add_argument('icon', type=dict, required=True, nullable=False, location='json') - parser.add_argument('privacy_policy', type=str, required=True, nullable=True, location='json') - parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json') - parser.add_argument('custom_disclaimer', type=str, required=True, nullable=True, location='json') + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json") + parser.add_argument("schema", type=str, required=True, nullable=False, location="json") + parser.add_argument("provider", type=str, required=True, nullable=False, location="json") + parser.add_argument("original_provider", type=str, required=True, nullable=False, location="json") + parser.add_argument("icon", type=dict, required=True, nullable=False, location="json") + parser.add_argument("privacy_policy", type=str, required=True, nullable=True, location="json") + parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json") + parser.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json") args = parser.parse_args() return ApiToolManageService.update_api_tool_provider( user_id, tenant_id, - args['provider'], - args['original_provider'], - args['icon'], - args['credentials'], - args['schema_type'], - args['schema'], - args['privacy_policy'], - args['custom_disclaimer'], - args.get('labels', []), + args["provider"], + args["original_provider"], + args["icon"], + args["credentials"], + args["schema_type"], + args["schema"], + args["privacy_policy"], + args["custom_disclaimer"], + args.get("labels", []), ) + class ToolApiProviderDeleteApi(Resource): @setup_required @login_required @@ -226,22 +247,23 @@ class ToolApiProviderDeleteApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('provider', type=str, required=True, nullable=False, location='json') + parser.add_argument("provider", type=str, required=True, nullable=False, location="json") args = parser.parse_args() return ApiToolManageService.delete_api_tool_provider( user_id, tenant_id, - args['provider'], + args["provider"], ) + class ToolApiProviderGetApi(Resource): @setup_required @login_required @@ -252,16 +274,17 @@ class ToolApiProviderGetApi(Resource): parser = reqparse.RequestParser() - parser.add_argument('provider', type=str, required=True, nullable=False, location='args') + parser.add_argument("provider", type=str, required=True, nullable=False, location="args") args = parser.parse_args() return ApiToolManageService.get_api_tool_provider( user_id, tenant_id, - args['provider'], + args["provider"], ) + class ToolBuiltinProviderCredentialsSchemaApi(Resource): @setup_required @login_required @@ -269,6 +292,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): def get(self, provider): return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider) + class ToolApiProviderSchemaApi(Resource): @setup_required @login_required @@ -276,14 +300,15 @@ class ToolApiProviderSchemaApi(Resource): def post(self): parser = reqparse.RequestParser() - parser.add_argument('schema', type=str, required=True, nullable=False, location='json') + parser.add_argument("schema", type=str, required=True, nullable=False, location="json") args = parser.parse_args() return ApiToolManageService.parser_api_schema( - schema=args['schema'], + schema=args["schema"], ) + class ToolApiProviderPreviousTestApi(Resource): @setup_required @login_required @@ -291,25 +316,26 @@ class ToolApiProviderPreviousTestApi(Resource): def post(self): parser = reqparse.RequestParser() - parser.add_argument('tool_name', type=str, required=True, nullable=False, location='json') - parser.add_argument('provider_name', type=str, required=False, nullable=False, location='json') - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') - parser.add_argument('parameters', type=dict, required=True, nullable=False, location='json') - parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json') - parser.add_argument('schema', type=str, required=True, nullable=False, location='json') + parser.add_argument("tool_name", type=str, required=True, nullable=False, location="json") + parser.add_argument("provider_name", type=str, required=False, nullable=False, location="json") + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("parameters", type=dict, required=True, nullable=False, location="json") + parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json") + parser.add_argument("schema", type=str, required=True, nullable=False, location="json") args = parser.parse_args() return ApiToolManageService.test_api_tool_preview( current_user.current_tenant_id, - args['provider_name'] if args['provider_name'] else '', - args['tool_name'], - args['credentials'], - args['parameters'], - args['schema_type'], - args['schema'], + args["provider_name"] if args["provider_name"] else "", + args["tool_name"], + args["credentials"], + args["parameters"], + args["schema_type"], + args["schema"], ) + class ToolWorkflowProviderCreateApi(Resource): @setup_required @login_required @@ -317,35 +343,36 @@ class ToolWorkflowProviderCreateApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id reqparser = reqparse.RequestParser() - reqparser.add_argument('workflow_app_id', type=uuid_value, required=True, nullable=False, location='json') - reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json') - reqparser.add_argument('label', type=str, required=True, nullable=False, location='json') - reqparser.add_argument('description', type=str, required=True, nullable=False, location='json') - reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json') - reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json') - reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='') - reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json') + reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json") + reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") + reqparser.add_argument("label", type=str, required=True, nullable=False, location="json") + reqparser.add_argument("description", type=str, required=True, nullable=False, location="json") + reqparser.add_argument("icon", type=dict, required=True, nullable=False, location="json") + reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") + reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") + reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json") args = reqparser.parse_args() return WorkflowToolManageService.create_workflow_tool( user_id, tenant_id, - args['workflow_app_id'], - args['name'], - args['label'], - args['icon'], - args['description'], - args['parameters'], - args['privacy_policy'], - args.get('labels', []), + args["workflow_app_id"], + args["name"], + args["label"], + args["icon"], + args["description"], + args["parameters"], + args["privacy_policy"], + args.get("labels", []), ) + class ToolWorkflowProviderUpdateApi(Resource): @setup_required @login_required @@ -353,38 +380,39 @@ class ToolWorkflowProviderUpdateApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id reqparser = reqparse.RequestParser() - reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json') - reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json') - reqparser.add_argument('label', type=str, required=True, nullable=False, location='json') - reqparser.add_argument('description', type=str, required=True, nullable=False, location='json') - reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json') - reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json') - reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='') - reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json') - + reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") + reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") + reqparser.add_argument("label", type=str, required=True, nullable=False, location="json") + reqparser.add_argument("description", type=str, required=True, nullable=False, location="json") + reqparser.add_argument("icon", type=dict, required=True, nullable=False, location="json") + reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") + reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") + reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json") + args = reqparser.parse_args() - if not args['workflow_tool_id']: - raise ValueError('incorrect workflow_tool_id') - + if not args["workflow_tool_id"]: + raise ValueError("incorrect workflow_tool_id") + return WorkflowToolManageService.update_workflow_tool( user_id, tenant_id, - args['workflow_tool_id'], - args['name'], - args['label'], - args['icon'], - args['description'], - args['parameters'], - args['privacy_policy'], - args.get('labels', []), + args["workflow_tool_id"], + args["name"], + args["label"], + args["icon"], + args["description"], + args["parameters"], + args["privacy_policy"], + args.get("labels", []), ) + class ToolWorkflowProviderDeleteApi(Resource): @setup_required @login_required @@ -392,21 +420,22 @@ class ToolWorkflowProviderDeleteApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id reqparser = reqparse.RequestParser() - reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json') + reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") args = reqparser.parse_args() return WorkflowToolManageService.delete_workflow_tool( user_id, tenant_id, - args['workflow_tool_id'], + args["workflow_tool_id"], ) - + + class ToolWorkflowProviderGetApi(Resource): @setup_required @login_required @@ -416,28 +445,29 @@ class ToolWorkflowProviderGetApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('workflow_tool_id', type=uuid_value, required=False, nullable=True, location='args') - parser.add_argument('workflow_app_id', type=uuid_value, required=False, nullable=True, location='args') + parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args") + parser.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args") args = parser.parse_args() - if args.get('workflow_tool_id'): + if args.get("workflow_tool_id"): tool = WorkflowToolManageService.get_workflow_tool_by_tool_id( user_id, tenant_id, - args['workflow_tool_id'], + args["workflow_tool_id"], ) - elif args.get('workflow_app_id'): + elif args.get("workflow_app_id"): tool = WorkflowToolManageService.get_workflow_tool_by_app_id( user_id, tenant_id, - args['workflow_app_id'], + args["workflow_app_id"], ) else: - raise ValueError('incorrect workflow_tool_id or workflow_app_id') + raise ValueError("incorrect workflow_tool_id or workflow_app_id") return jsonable_encoder(tool) - + + class ToolWorkflowProviderListToolApi(Resource): @setup_required @login_required @@ -447,15 +477,18 @@ class ToolWorkflowProviderListToolApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='args') + parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args") args = parser.parse_args() - return jsonable_encoder(WorkflowToolManageService.list_single_workflow_tools( - user_id, - tenant_id, - args['workflow_tool_id'], - )) + return jsonable_encoder( + WorkflowToolManageService.list_single_workflow_tools( + user_id, + tenant_id, + args["workflow_tool_id"], + ) + ) + class ToolBuiltinListApi(Resource): @setup_required @@ -465,11 +498,17 @@ class ToolBuiltinListApi(Resource): user_id = current_user.id tenant_id = current_user.current_tenant_id - return jsonable_encoder([provider.to_dict() for provider in BuiltinToolManageService.list_builtin_tools( - user_id, - tenant_id, - )]) - + return jsonable_encoder( + [ + provider.to_dict() + for provider in BuiltinToolManageService.list_builtin_tools( + user_id, + tenant_id, + ) + ] + ) + + class ToolApiListApi(Resource): @setup_required @login_required @@ -478,11 +517,17 @@ class ToolApiListApi(Resource): user_id = current_user.id tenant_id = current_user.current_tenant_id - return jsonable_encoder([provider.to_dict() for provider in ApiToolManageService.list_api_tools( - user_id, - tenant_id, - )]) - + return jsonable_encoder( + [ + provider.to_dict() + for provider in ApiToolManageService.list_api_tools( + user_id, + tenant_id, + ) + ] + ) + + class ToolWorkflowListApi(Resource): @setup_required @login_required @@ -491,11 +536,17 @@ class ToolWorkflowListApi(Resource): user_id = current_user.id tenant_id = current_user.current_tenant_id - return jsonable_encoder([provider.to_dict() for provider in WorkflowToolManageService.list_tenant_workflow_tools( - user_id, - tenant_id, - )]) - + return jsonable_encoder( + [ + provider.to_dict() + for provider in WorkflowToolManageService.list_tenant_workflow_tools( + user_id, + tenant_id, + ) + ] + ) + + class ToolLabelsApi(Resource): @setup_required @login_required @@ -503,36 +554,41 @@ class ToolLabelsApi(Resource): def get(self): return jsonable_encoder(ToolLabelsService.list_tool_labels()) + # tool provider -api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers') +api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers") # builtin tool provider -api.add_resource(ToolBuiltinProviderListToolsApi, '/workspaces/current/tool-provider/builtin//tools') -api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provider/builtin//delete') -api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin//update') -api.add_resource(ToolBuiltinProviderGetCredentialsApi, '/workspaces/current/tool-provider/builtin//credentials') -api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin//credentials_schema') -api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin//icon') +api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin//tools") +api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin//delete") +api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin//update") +api.add_resource( + ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin//credentials" +) +api.add_resource( + ToolBuiltinProviderCredentialsSchemaApi, "/workspaces/current/tool-provider/builtin//credentials_schema" +) +api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin//icon") # api tool provider -api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add') -api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote') -api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools') -api.add_resource(ToolApiProviderUpdateApi, '/workspaces/current/tool-provider/api/update') -api.add_resource(ToolApiProviderDeleteApi, '/workspaces/current/tool-provider/api/delete') -api.add_resource(ToolApiProviderGetApi, '/workspaces/current/tool-provider/api/get') -api.add_resource(ToolApiProviderSchemaApi, '/workspaces/current/tool-provider/api/schema') -api.add_resource(ToolApiProviderPreviousTestApi, '/workspaces/current/tool-provider/api/test/pre') +api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add") +api.add_resource(ToolApiProviderGetRemoteSchemaApi, "/workspaces/current/tool-provider/api/remote") +api.add_resource(ToolApiProviderListToolsApi, "/workspaces/current/tool-provider/api/tools") +api.add_resource(ToolApiProviderUpdateApi, "/workspaces/current/tool-provider/api/update") +api.add_resource(ToolApiProviderDeleteApi, "/workspaces/current/tool-provider/api/delete") +api.add_resource(ToolApiProviderGetApi, "/workspaces/current/tool-provider/api/get") +api.add_resource(ToolApiProviderSchemaApi, "/workspaces/current/tool-provider/api/schema") +api.add_resource(ToolApiProviderPreviousTestApi, "/workspaces/current/tool-provider/api/test/pre") # workflow tool provider -api.add_resource(ToolWorkflowProviderCreateApi, '/workspaces/current/tool-provider/workflow/create') -api.add_resource(ToolWorkflowProviderUpdateApi, '/workspaces/current/tool-provider/workflow/update') -api.add_resource(ToolWorkflowProviderDeleteApi, '/workspaces/current/tool-provider/workflow/delete') -api.add_resource(ToolWorkflowProviderGetApi, '/workspaces/current/tool-provider/workflow/get') -api.add_resource(ToolWorkflowProviderListToolApi, '/workspaces/current/tool-provider/workflow/tools') +api.add_resource(ToolWorkflowProviderCreateApi, "/workspaces/current/tool-provider/workflow/create") +api.add_resource(ToolWorkflowProviderUpdateApi, "/workspaces/current/tool-provider/workflow/update") +api.add_resource(ToolWorkflowProviderDeleteApi, "/workspaces/current/tool-provider/workflow/delete") +api.add_resource(ToolWorkflowProviderGetApi, "/workspaces/current/tool-provider/workflow/get") +api.add_resource(ToolWorkflowProviderListToolApi, "/workspaces/current/tool-provider/workflow/tools") -api.add_resource(ToolBuiltinListApi, '/workspaces/current/tools/builtin') -api.add_resource(ToolApiListApi, '/workspaces/current/tools/api') -api.add_resource(ToolWorkflowListApi, '/workspaces/current/tools/workflow') +api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin") +api.add_resource(ToolApiListApi, "/workspaces/current/tools/api") +api.add_resource(ToolWorkflowListApi, "/workspaces/current/tools/workflow") -api.add_resource(ToolLabelsApi, '/workspaces/current/tool-labels') \ No newline at end of file +api.add_resource(ToolLabelsApi, "/workspaces/current/tool-labels") diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 7a11a45ae8..623f0b8b74 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -26,39 +26,34 @@ from services.file_service import FileService from services.workspace_service import WorkspaceService provider_fields = { - 'provider_name': fields.String, - 'provider_type': fields.String, - 'is_valid': fields.Boolean, - 'token_is_set': fields.Boolean, + "provider_name": fields.String, + "provider_type": fields.String, + "is_valid": fields.Boolean, + "token_is_set": fields.Boolean, } tenant_fields = { - 'id': fields.String, - 'name': fields.String, - 'plan': fields.String, - 'status': fields.String, - 'created_at': TimestampField, - 'role': fields.String, - 'in_trial': fields.Boolean, - 'trial_end_reason': fields.String, - 'custom_config': fields.Raw(attribute='custom_config'), + "id": fields.String, + "name": fields.String, + "plan": fields.String, + "status": fields.String, + "created_at": TimestampField, + "role": fields.String, + "in_trial": fields.Boolean, + "trial_end_reason": fields.String, + "custom_config": fields.Raw(attribute="custom_config"), } tenants_fields = { - 'id': fields.String, - 'name': fields.String, - 'plan': fields.String, - 'status': fields.String, - 'created_at': TimestampField, - 'current': fields.Boolean + "id": fields.String, + "name": fields.String, + "plan": fields.String, + "status": fields.String, + "created_at": TimestampField, + "current": fields.Boolean, } -workspace_fields = { - 'id': fields.String, - 'name': fields.String, - 'status': fields.String, - 'created_at': TimestampField -} +workspace_fields = {"id": fields.String, "name": fields.String, "status": fields.String, "created_at": TimestampField} class TenantListApi(Resource): @@ -71,7 +66,7 @@ class TenantListApi(Resource): for tenant in tenants: if tenant.id == current_user.current_tenant_id: tenant.current = True # Set current=True for current tenant - return {'workspaces': marshal(tenants, tenants_fields)}, 200 + return {"workspaces": marshal(tenants, tenants_fields)}, 200 class WorkspaceListApi(Resource): @@ -79,31 +74,37 @@ class WorkspaceListApi(Resource): @admin_required def get(self): parser = reqparse.RequestParser() - parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args') - parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") + parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc())\ - .paginate(page=args['page'], per_page=args['limit']) + tenants = ( + db.session.query(Tenant) + .order_by(Tenant.created_at.desc()) + .paginate(page=args["page"], per_page=args["limit"]) + ) has_more = False - if len(tenants.items) == args['limit']: + if len(tenants.items) == args["limit"]: current_page_first_tenant = tenants[-1] - rest_count = db.session.query(Tenant).filter( - Tenant.created_at < current_page_first_tenant.created_at, - Tenant.id != current_page_first_tenant.id - ).count() + rest_count = ( + db.session.query(Tenant) + .filter( + Tenant.created_at < current_page_first_tenant.created_at, Tenant.id != current_page_first_tenant.id + ) + .count() + ) if rest_count > 0: has_more = True total = db.session.query(Tenant).count() return { - 'data': marshal(tenants.items, workspace_fields), - 'has_more': has_more, - 'limit': args['limit'], - 'page': args['page'], - 'total': total - }, 200 + "data": marshal(tenants.items, workspace_fields), + "has_more": has_more, + "limit": args["limit"], + "page": args["page"], + "total": total, + }, 200 class TenantApi(Resource): @@ -112,8 +113,8 @@ class TenantApi(Resource): @account_initialization_required @marshal_with(tenant_fields) def get(self): - if request.path == '/info': - logging.warning('Deprecated URL /info was used.') + if request.path == "/info": + logging.warning("Deprecated URL /info was used.") tenant = current_user.current_tenant @@ -125,7 +126,7 @@ class TenantApi(Resource): tenant = tenants[0] # else, raise Unauthorized else: - raise Unauthorized('workspace is archived') + raise Unauthorized("workspace is archived") return WorkspaceService.get_tenant_info(tenant), 200 @@ -136,62 +137,64 @@ class SwitchWorkspaceApi(Resource): @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('tenant_id', type=str, required=True, location='json') + parser.add_argument("tenant_id", type=str, required=True, location="json") args = parser.parse_args() # check if tenant_id is valid, 403 if not try: - TenantService.switch_tenant(current_user, args['tenant_id']) + TenantService.switch_tenant(current_user, args["tenant_id"]) except Exception: raise AccountNotLinkTenantError("Account not link tenant") - new_tenant = db.session.query(Tenant).get(args['tenant_id']) # Get new tenant + new_tenant = db.session.query(Tenant).get(args["tenant_id"]) # Get new tenant + + return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)} - return {'result': 'success', 'new_tenant': marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)} - class CustomConfigWorkspaceApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('workspace_custom') + @cloud_edition_billing_resource_check("workspace_custom") def post(self): parser = reqparse.RequestParser() - parser.add_argument('remove_webapp_brand', type=bool, location='json') - parser.add_argument('replace_webapp_logo', type=str, location='json') + parser.add_argument("remove_webapp_brand", type=bool, location="json") + parser.add_argument("replace_webapp_logo", type=str, location="json") args = parser.parse_args() tenant = db.session.query(Tenant).filter(Tenant.id == current_user.current_tenant_id).one_or_404() custom_config_dict = { - 'remove_webapp_brand': args['remove_webapp_brand'], - 'replace_webapp_logo': args['replace_webapp_logo'] if args['replace_webapp_logo'] is not None else tenant.custom_config_dict.get('replace_webapp_logo') , + "remove_webapp_brand": args["remove_webapp_brand"], + "replace_webapp_logo": args["replace_webapp_logo"] + if args["replace_webapp_logo"] is not None + else tenant.custom_config_dict.get("replace_webapp_logo"), } tenant.custom_config_dict = custom_config_dict db.session.commit() - return {'result': 'success', 'tenant': marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)} - + return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)} + class WebappLogoWorkspaceApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('workspace_custom') + @cloud_edition_billing_resource_check("workspace_custom") def post(self): # get file from request - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() - extension = file.filename.split('.')[-1] - if extension.lower() not in ['svg', 'png']: + extension = file.filename.split(".")[-1] + if extension.lower() not in ["svg", "png"]: raise UnsupportedFileTypeError() try: @@ -201,14 +204,14 @@ class WebappLogoWorkspaceApi(Resource): raise FileTooLargeError(file_too_large_error.description) except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - - return { 'id': upload_file.id }, 201 + + return {"id": upload_file.id}, 201 -api.add_resource(TenantListApi, '/workspaces') # GET for getting all tenants -api.add_resource(WorkspaceListApi, '/all-workspaces') # GET for getting all tenants -api.add_resource(TenantApi, '/workspaces/current', endpoint='workspaces_current') # GET for getting current tenant info -api.add_resource(TenantApi, '/info', endpoint='info') # Deprecated -api.add_resource(SwitchWorkspaceApi, '/workspaces/switch') # POST for switching tenant -api.add_resource(CustomConfigWorkspaceApi, '/workspaces/custom-config') -api.add_resource(WebappLogoWorkspaceApi, '/workspaces/custom-config/webapp-logo/upload') +api.add_resource(TenantListApi, "/workspaces") # GET for getting all tenants +api.add_resource(WorkspaceListApi, "/all-workspaces") # GET for getting all tenants +api.add_resource(TenantApi, "/workspaces/current", endpoint="workspaces_current") # GET for getting current tenant info +api.add_resource(TenantApi, "/info", endpoint="info") # Deprecated +api.add_resource(SwitchWorkspaceApi, "/workspaces/switch") # POST for switching tenant +api.add_resource(CustomConfigWorkspaceApi, "/workspaces/custom-config") +api.add_resource(WebappLogoWorkspaceApi, "/workspaces/custom-config/webapp-logo/upload") diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 3baf69acfd..7667b30e34 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -16,7 +16,7 @@ def account_initialization_required(view): # check account initialization account = current_user - if account.status == 'uninitialized': + if account.status == "uninitialized": raise AccountNotInitializedError() return view(*args, **kwargs) @@ -27,7 +27,7 @@ def account_initialization_required(view): def only_edition_cloud(view): @wraps(view) def decorated(*args, **kwargs): - if dify_config.EDITION != 'CLOUD': + if dify_config.EDITION != "CLOUD": abort(404) return view(*args, **kwargs) @@ -38,7 +38,7 @@ def only_edition_cloud(view): def only_edition_self_hosted(view): @wraps(view) def decorated(*args, **kwargs): - if dify_config.EDITION != 'SELF_HOSTED': + if dify_config.EDITION != "SELF_HOSTED": abort(404) return view(*args, **kwargs) @@ -46,8 +46,7 @@ def only_edition_self_hosted(view): return decorated -def cloud_edition_billing_resource_check(resource: str, - error_msg: str = "You have reached the limit of your subscription."): +def cloud_edition_billing_resource_check(resource: str): def interceptor(view): @wraps(view) def decorated(*args, **kwargs): @@ -58,23 +57,23 @@ def cloud_edition_billing_resource_check(resource: str, vector_space = features.vector_space documents_upload_quota = features.documents_upload_quota annotation_quota_limit = features.annotation_quota_limit - if resource == 'members' and 0 < members.limit <= members.size: - abort(403, error_msg) - elif resource == 'apps' and 0 < apps.limit <= apps.size: - abort(403, error_msg) - elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size: - abort(403, error_msg) - elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size: + if resource == "members" and 0 < members.limit <= members.size: + abort(403, "The number of members has reached the limit of your subscription.") + elif resource == "apps" and 0 < apps.limit <= apps.size: + abort(403, "The number of apps has reached the limit of your subscription.") + elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size: + abort(403, "The capacity of the vector space has reached the limit of your subscription.") + elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size: # The api of file upload is used in the multiple places, so we need to check the source of the request from datasets - source = request.args.get('source') - if source == 'datasets': - abort(403, error_msg) + source = request.args.get("source") + if source == "datasets": + abort(403, "The number of documents has reached the limit of your subscription.") else: return view(*args, **kwargs) - elif resource == 'workspace_custom' and not features.can_replace_logo: - abort(403, error_msg) - elif resource == 'annotation' and 0 < annotation_quota_limit.limit < annotation_quota_limit.size: - abort(403, error_msg) + elif resource == "workspace_custom" and not features.can_replace_logo: + abort(403, "The workspace custom feature has reached the limit of your subscription.") + elif resource == "annotation" and 0 < annotation_quota_limit.limit < annotation_quota_limit.size: + abort(403, "The annotation quota has reached the limit of your subscription.") else: return view(*args, **kwargs) @@ -85,16 +84,18 @@ def cloud_edition_billing_resource_check(resource: str, return interceptor -def cloud_edition_billing_knowledge_limit_check(resource: str, - error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."): +def cloud_edition_billing_knowledge_limit_check(resource: str): def interceptor(view): @wraps(view) def decorated(*args, **kwargs): features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: - if resource == 'add_segment': - if features.billing.subscription.plan == 'sandbox': - abort(403, error_msg) + if resource == "add_segment": + if features.billing.subscription.plan == "sandbox": + abort( + 403, + "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.", + ) else: return view(*args, **kwargs) @@ -112,7 +113,7 @@ def cloud_utm_record(view): features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: - utm_info = request.cookies.get('utm_info') + utm_info = request.cookies.get("utm_info") if utm_info: utm_info = json.loads(utm_info) diff --git a/api/controllers/files/__init__.py b/api/controllers/files/__init__.py index 8d38ab9866..97d5c3f88f 100644 --- a/api/controllers/files/__init__.py +++ b/api/controllers/files/__init__.py @@ -2,7 +2,7 @@ from flask import Blueprint from libs.external_api import ExternalApi -bp = Blueprint('files', __name__) +bp = Blueprint("files", __name__) api = ExternalApi(bp) diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py index 247b5d45e1..2432285d93 100644 --- a/api/controllers/files/image_preview.py +++ b/api/controllers/files/image_preview.py @@ -13,35 +13,30 @@ class ImagePreviewApi(Resource): def get(self, file_id): file_id = str(file_id) - timestamp = request.args.get('timestamp') - nonce = request.args.get('nonce') - sign = request.args.get('sign') + timestamp = request.args.get("timestamp") + nonce = request.args.get("nonce") + sign = request.args.get("sign") if not timestamp or not nonce or not sign: - return {'content': 'Invalid request.'}, 400 + return {"content": "Invalid request."}, 400 try: - generator, mimetype = FileService.get_image_preview( - file_id, - timestamp, - nonce, - sign - ) + generator, mimetype = FileService.get_image_preview(file_id, timestamp, nonce, sign) except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() return Response(generator, mimetype=mimetype) - + class WorkspaceWebappLogoApi(Resource): def get(self, workspace_id): workspace_id = str(workspace_id) custom_config = TenantService.get_custom_config(workspace_id) - webapp_logo_file_id = custom_config.get('replace_webapp_logo') if custom_config is not None else None + webapp_logo_file_id = custom_config.get("replace_webapp_logo") if custom_config is not None else None if not webapp_logo_file_id: - raise NotFound('webapp logo is not found') + raise NotFound("webapp logo is not found") try: generator, mimetype = FileService.get_public_image_preview( @@ -53,11 +48,11 @@ class WorkspaceWebappLogoApi(Resource): return Response(generator, mimetype=mimetype) -api.add_resource(ImagePreviewApi, '/files//image-preview') -api.add_resource(WorkspaceWebappLogoApi, '/files/workspaces//webapp-logo') +api.add_resource(ImagePreviewApi, "/files//image-preview") +api.add_resource(WorkspaceWebappLogoApi, "/files/workspaces//webapp-logo") class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' + error_code = "unsupported_file_type" description = "File type not allowed." code = 415 diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index 5a07ad2ea5..38ac0815da 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -13,36 +13,39 @@ class ToolFilePreviewApi(Resource): parser = reqparse.RequestParser() - parser.add_argument('timestamp', type=str, required=True, location='args') - parser.add_argument('nonce', type=str, required=True, location='args') - parser.add_argument('sign', type=str, required=True, location='args') + parser.add_argument("timestamp", type=str, required=True, location="args") + parser.add_argument("nonce", type=str, required=True, location="args") + parser.add_argument("sign", type=str, required=True, location="args") args = parser.parse_args() - if not ToolFileManager.verify_file(file_id=file_id, - timestamp=args['timestamp'], - nonce=args['nonce'], - sign=args['sign'], + if not ToolFileManager.verify_file( + file_id=file_id, + timestamp=args["timestamp"], + nonce=args["nonce"], + sign=args["sign"], ): - raise Forbidden('Invalid request.') - + raise Forbidden("Invalid request.") + try: result = ToolFileManager.get_file_generator_by_tool_file_id( file_id, ) if not result: - raise NotFound('file is not found') - + raise NotFound("file is not found") + generator, mimetype = result except Exception: raise UnsupportedFileTypeError() return Response(generator, mimetype=mimetype) -api.add_resource(ToolFilePreviewApi, '/files/tools/.') + +api.add_resource(ToolFilePreviewApi, "/files/tools/.") + class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' + error_code = "unsupported_file_type" description = "File type not allowed." code = 415 diff --git a/api/controllers/inner_api/__init__.py b/api/controllers/inner_api/__init__.py index ad49a649ca..9f124736a9 100644 --- a/api/controllers/inner_api/__init__.py +++ b/api/controllers/inner_api/__init__.py @@ -2,8 +2,7 @@ from flask import Blueprint from libs.external_api import ExternalApi -bp = Blueprint('inner_api', __name__, url_prefix='/inner/api') +bp = Blueprint("inner_api", __name__, url_prefix="/inner/api") api = ExternalApi(bp) from .workspace import workspace - diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py index 06610d8933..914b60f263 100644 --- a/api/controllers/inner_api/workspace/workspace.py +++ b/api/controllers/inner_api/workspace/workspace.py @@ -9,29 +9,24 @@ from services.account_service import TenantService class EnterpriseWorkspace(Resource): - @setup_required @inner_api_only def post(self): parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') - parser.add_argument('owner_email', type=str, required=True, location='json') + parser.add_argument("name", type=str, required=True, location="json") + parser.add_argument("owner_email", type=str, required=True, location="json") args = parser.parse_args() - account = Account.query.filter_by(email=args['owner_email']).first() + account = Account.query.filter_by(email=args["owner_email"]).first() if account is None: - return { - 'message': 'owner account not found.' - }, 404 + return {"message": "owner account not found."}, 404 - tenant = TenantService.create_tenant(args['name']) - TenantService.create_tenant_member(tenant, account, role='owner') + tenant = TenantService.create_tenant(args["name"]) + TenantService.create_tenant_member(tenant, account, role="owner") tenant_was_created.send(tenant) - return { - 'message': 'enterprise workspace created.' - } + return {"message": "enterprise workspace created."} -api.add_resource(EnterpriseWorkspace, '/enterprise/workspace') +api.add_resource(EnterpriseWorkspace, "/enterprise/workspace") diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index 5c37f5276f..51ffe683ff 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -17,7 +17,7 @@ def inner_api_only(view): abort(404) # get header 'X-Inner-Api-Key' - inner_api_key = request.headers.get('X-Inner-Api-Key') + inner_api_key = request.headers.get("X-Inner-Api-Key") if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY: abort(401) @@ -33,29 +33,29 @@ def inner_api_user_auth(view): return view(*args, **kwargs) # get header 'X-Inner-Api-Key' - authorization = request.headers.get('Authorization') + authorization = request.headers.get("Authorization") if not authorization: return view(*args, **kwargs) - parts = authorization.split(':') + parts = authorization.split(":") if len(parts) != 2: return view(*args, **kwargs) user_id, token = parts - if ' ' in user_id: - user_id = user_id.split(' ')[1] + if " " in user_id: + user_id = user_id.split(" ")[1] - inner_api_key = request.headers.get('X-Inner-Api-Key') + inner_api_key = request.headers.get("X-Inner-Api-Key") - data_to_sign = f'DIFY {user_id}' + data_to_sign = f"DIFY {user_id}" - signature = hmac_new(inner_api_key.encode('utf-8'), data_to_sign.encode('utf-8'), sha1) - signature = b64encode(signature.digest()).decode('utf-8') + signature = hmac_new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1) + signature = b64encode(signature.digest()).decode("utf-8") if signature != token: return view(*args, **kwargs) - kwargs['user'] = db.session.query(EndUser).filter(EndUser.id == user_id).first() + kwargs["user"] = db.session.query(EndUser).filter(EndUser.id == user_id).first() return view(*args, **kwargs) diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index 082660a891..ad39c160ac 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -2,7 +2,7 @@ from flask import Blueprint from libs.external_api import ExternalApi -bp = Blueprint('service_api', __name__, url_prefix='/v1') +bp = Blueprint("service_api", __name__, url_prefix="/v1") api = ExternalApi(bp) diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 3b3cf1b026..ecc2d73deb 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,4 +1,3 @@ - from flask_restful import Resource, fields, marshal_with from configs import dify_config @@ -13,32 +12,30 @@ class AppParameterApi(Resource): """Resource for app variables.""" variable_fields = { - 'key': fields.String, - 'name': fields.String, - 'description': fields.String, - 'type': fields.String, - 'default': fields.String, - 'max_length': fields.Integer, - 'options': fields.List(fields.String) + "key": fields.String, + "name": fields.String, + "description": fields.String, + "type": fields.String, + "default": fields.String, + "max_length": fields.Integer, + "options": fields.List(fields.String), } - system_parameters_fields = { - 'image_file_size_limit': fields.String - } + system_parameters_fields = {"image_file_size_limit": fields.String} parameters_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw, - 'suggested_questions_after_answer': fields.Raw, - 'speech_to_text': fields.Raw, - 'text_to_speech': fields.Raw, - 'retriever_resource': fields.Raw, - 'annotation_reply': fields.Raw, - 'more_like_this': fields.Raw, - 'user_input_form': fields.Raw, - 'sensitive_word_avoidance': fields.Raw, - 'file_upload': fields.Raw, - 'system_parameters': fields.Nested(system_parameters_fields) + "opening_statement": fields.String, + "suggested_questions": fields.Raw, + "suggested_questions_after_answer": fields.Raw, + "speech_to_text": fields.Raw, + "text_to_speech": fields.Raw, + "retriever_resource": fields.Raw, + "annotation_reply": fields.Raw, + "more_like_this": fields.Raw, + "user_input_form": fields.Raw, + "sensitive_word_avoidance": fields.Raw, + "file_upload": fields.Raw, + "system_parameters": fields.Nested(system_parameters_fields), } @validate_app_token @@ -56,30 +53,35 @@ class AppParameterApi(Resource): app_model_config = app_model.app_model_config features_dict = app_model_config.to_dict() - user_input_form = features_dict.get('user_input_form', []) + user_input_form = features_dict.get("user_input_form", []) return { - 'opening_statement': features_dict.get('opening_statement'), - 'suggested_questions': features_dict.get('suggested_questions', []), - 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer', - {"enabled": False}), - 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}), - 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}), - 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}), - 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}), - 'more_like_this': features_dict.get('more_like_this', {"enabled": False}), - 'user_input_form': user_input_form, - 'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance', - {"enabled": False, "type": "", "configs": []}), - 'file_upload': features_dict.get('file_upload', {"image": { - "enabled": False, - "number_limits": 3, - "detail": "high", - "transfer_methods": ["remote_url", "local_file"] - }}), - 'system_parameters': { - 'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT - } + "opening_statement": features_dict.get("opening_statement"), + "suggested_questions": features_dict.get("suggested_questions", []), + "suggested_questions_after_answer": features_dict.get( + "suggested_questions_after_answer", {"enabled": False} + ), + "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}), + "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}), + "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}), + "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}), + "more_like_this": features_dict.get("more_like_this", {"enabled": False}), + "user_input_form": user_input_form, + "sensitive_word_avoidance": features_dict.get( + "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []} + ), + "file_upload": features_dict.get( + "file_upload", + { + "image": { + "enabled": False, + "number_limits": 3, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"], + } + }, + ), + "system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT}, } @@ -89,16 +91,14 @@ class AppMetaApi(Resource): """Get app meta""" return AppService().get_app_meta(app_model) + class AppInfoApi(Resource): @validate_app_token def get(self, app_model: App): - """Get app infomation""" - return { - 'name':app_model.name, - 'description':app_model.description - } + """Get app information""" + return {"name": app_model.name, "description": app_model.description} -api.add_resource(AppParameterApi, '/parameters') -api.add_resource(AppMetaApi, '/meta') -api.add_resource(AppInfoApi, '/info') +api.add_resource(AppParameterApi, "/parameters") +api.add_resource(AppMetaApi, "/meta") +api.add_resource(AppInfoApi, "/info") diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 3c009af343..85aab047a7 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -33,14 +33,10 @@ from services.errors.audio import ( class AudioApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) def post(self, app_model: App, end_user: EndUser): - file = request.files['file'] + file = request.files["file"] try: - response = AudioService.transcript_asr( - app_model=app_model, - file=file, - end_user=end_user - ) + response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user) return response except services.errors.app_model_config.AppModelConfigBrokenError: @@ -74,30 +70,32 @@ class TextApi(Resource): def post(self, app_model: App, end_user: EndUser): try: parser = reqparse.RequestParser() - parser.add_argument('message_id', type=str, required=False, location='json') - parser.add_argument('voice', type=str, location='json') - parser.add_argument('text', type=str, location='json') - parser.add_argument('streaming', type=bool, location='json') + parser.add_argument("message_id", type=str, required=False, location="json") + parser.add_argument("voice", type=str, location="json") + parser.add_argument("text", type=str, location="json") + parser.add_argument("streaming", type=bool, location="json") args = parser.parse_args() - message_id = args.get('message_id', None) - text = args.get('text', None) - if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] - and app_model.workflow - and app_model.workflow.features_dict): - text_to_speech = app_model.workflow.features_dict.get('text_to_speech') - voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice') + message_id = args.get("message_id", None) + text = args.get("text", None) + if ( + app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + and app_model.workflow + and app_model.workflow.features_dict + ): + text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") else: try: - voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice') + voice = ( + args.get("voice") + if args.get("voice") + else app_model.app_model_config.text_to_speech_dict.get("voice") + ) except Exception: voice = None response = AudioService.transcript_tts( - app_model=app_model, - message_id=message_id, - end_user=end_user.external_user_id, - voice=voice, - text=text + app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, voice=voice, text=text ) return response @@ -127,5 +125,5 @@ class TextApi(Resource): raise InternalServerError() -api.add_resource(AudioApi, '/audio-to-text') -api.add_resource(TextApi, '/text-to-audio') +api.add_resource(AudioApi, "/audio-to-text") +api.add_resource(TextApi, "/text-to-audio") diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 2511f46baf..f1771baf31 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -33,21 +33,21 @@ from services.app_generate_service import AppGenerateService class CompletionApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise AppUnavailableError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, location='json', default='') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, location="json", default="") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' + streaming = args["response_mode"] == "streaming" - args['auto_generate_name'] = False + args["auto_generate_name"] = False try: response = AppGenerateService.generate( @@ -84,12 +84,12 @@ class CompletionApi(Resource): class CompletionStopApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser, task_id): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise AppUnavailableError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ChatApi(Resource): @@ -100,25 +100,21 @@ class ChatApi(Resource): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, required=True, location='json') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('conversation_id', type=uuid_value, location='json') - parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') - parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, required=True, location="json") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("conversation_id", type=uuid_value, location="json") + parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") + parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json") args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' + streaming = args["response_mode"] == "streaming" try: response = AppGenerateService.generate( - app_model=app_model, - user=end_user, - args=args, - invoke_from=InvokeFrom.SERVICE_API, - streaming=streaming + app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming ) return helper.compact_generate_response(response) @@ -153,10 +149,10 @@ class ChatStopApi(Resource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(CompletionApi, '/completion-messages') -api.add_resource(CompletionStopApi, '/completion-messages//stop') -api.add_resource(ChatApi, '/chat-messages') -api.add_resource(ChatStopApi, '/chat-messages//stop') +api.add_resource(CompletionApi, "/completion-messages") +api.add_resource(CompletionStopApi, "/completion-messages//stop") +api.add_resource(ChatApi, "/chat-messages") +api.add_resource(ChatStopApi, "/chat-messages//stop") diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 44bda8e771..734027a1c5 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -14,7 +14,6 @@ from services.conversation_service import ConversationService class ConversationApi(Resource): - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, app_model: App, end_user: EndUser): @@ -23,17 +22,26 @@ class ConversationApi(Resource): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + parser.add_argument( + "sort_by", + type=str, + choices=["created_at", "-created_at", "updated_at", "-updated_at"], + required=False, + default="-updated_at", + location="args", + ) args = parser.parse_args() try: return ConversationService.pagination_by_last_id( app_model=app_model, user=end_user, - last_id=args['last_id'], - limit=args['limit'], - invoke_from=InvokeFrom.SERVICE_API + last_id=args["last_id"], + limit=args["limit"], + invoke_from=InvokeFrom.SERVICE_API, + sort_by=args["sort_by"], ) except services.errors.conversation.LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") @@ -53,11 +61,10 @@ class ConversationDetailApi(Resource): ConversationService.delete(app_model, conversation_id, end_user) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ConversationRenameApi(Resource): - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) @marshal_with(simple_conversation_fields) def post(self, app_model: App, end_user: EndUser, c_id): @@ -68,22 +75,16 @@ class ConversationRenameApi(Resource): conversation_id = str(c_id) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=False, location='json') - parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json') + parser.add_argument("name", type=str, required=False, location="json") + parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") args = parser.parse_args() try: - return ConversationService.rename( - app_model, - conversation_id, - end_user, - args['name'], - args['auto_generate'] - ) + return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"]) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") -api.add_resource(ConversationRenameApi, '/conversations//name', endpoint='conversation_name') -api.add_resource(ConversationApi, '/conversations') -api.add_resource(ConversationDetailApi, '/conversations/', endpoint='conversation_detail') +api.add_resource(ConversationRenameApi, "/conversations//name", endpoint="conversation_name") +api.add_resource(ConversationApi, "/conversations") +api.add_resource(ConversationDetailApi, "/conversations/", endpoint="conversation_detail") diff --git a/api/controllers/service_api/app/error.py b/api/controllers/service_api/app/error.py index ac9edb1b4f..ca91da80c1 100644 --- a/api/controllers/service_api/app/error.py +++ b/api/controllers/service_api/app/error.py @@ -2,104 +2,108 @@ from libs.exception import BaseHTTPException class AppUnavailableError(BaseHTTPException): - error_code = 'app_unavailable' + error_code = "app_unavailable" description = "App unavailable, please check your app configurations." code = 400 class NotCompletionAppError(BaseHTTPException): - error_code = 'not_completion_app' + error_code = "not_completion_app" description = "Please check if your Completion app mode matches the right API route." code = 400 class NotChatAppError(BaseHTTPException): - error_code = 'not_chat_app' + error_code = "not_chat_app" description = "Please check if your app mode matches the right API route." code = 400 class NotWorkflowAppError(BaseHTTPException): - error_code = 'not_workflow_app' + error_code = "not_workflow_app" description = "Please check if your app mode matches the right API route." code = 400 class ConversationCompletedError(BaseHTTPException): - error_code = 'conversation_completed' + error_code = "conversation_completed" description = "The conversation has ended. Please start a new conversation." code = 400 class ProviderNotInitializeError(BaseHTTPException): - error_code = 'provider_not_initialize' - description = "No valid model provider credentials found. " \ - "Please go to Settings -> Model Provider to complete your provider credentials." + error_code = "provider_not_initialize" + description = ( + "No valid model provider credentials found. " + "Please go to Settings -> Model Provider to complete your provider credentials." + ) code = 400 class ProviderQuotaExceededError(BaseHTTPException): - error_code = 'provider_quota_exceeded' - description = "Your quota for Dify Hosted OpenAI has been exhausted. " \ - "Please go to Settings -> Model Provider to complete your own provider credentials." + error_code = "provider_quota_exceeded" + description = ( + "Your quota for Dify Hosted OpenAI has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials." + ) code = 400 class ProviderModelCurrentlyNotSupportError(BaseHTTPException): - error_code = 'model_currently_not_support' + error_code = "model_currently_not_support" description = "Dify Hosted OpenAI trial currently not support the GPT-4 model." code = 400 class CompletionRequestError(BaseHTTPException): - error_code = 'completion_request_error' + error_code = "completion_request_error" description = "Completion request failed." code = 400 class NoAudioUploadedError(BaseHTTPException): - error_code = 'no_audio_uploaded' + error_code = "no_audio_uploaded" description = "Please upload your audio." code = 400 class AudioTooLargeError(BaseHTTPException): - error_code = 'audio_too_large' + error_code = "audio_too_large" description = "Audio size exceeded. {message}" code = 413 class UnsupportedAudioTypeError(BaseHTTPException): - error_code = 'unsupported_audio_type' + error_code = "unsupported_audio_type" description = "Audio type not allowed." code = 415 class ProviderNotSupportSpeechToTextError(BaseHTTPException): - error_code = 'provider_not_support_speech_to_text' + error_code = "provider_not_support_speech_to_text" description = "Provider not support speech to text." code = 400 class NoFileUploadedError(BaseHTTPException): - error_code = 'no_file_uploaded' + error_code = "no_file_uploaded" description = "Please upload your file." code = 400 class TooManyFilesError(BaseHTTPException): - error_code = 'too_many_files' + error_code = "too_many_files" description = "Only one file is allowed." code = 400 class FileTooLargeError(BaseHTTPException): - error_code = 'file_too_large' + error_code = "file_too_large" description = "File size exceeded. {message}" code = 413 class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' + error_code = "unsupported_file_type" description = "File type not allowed." code = 415 diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index 5dbc1b1d1b..e0a772eb31 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -16,15 +16,13 @@ from services.file_service import FileService class FileApi(Resource): - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) @marshal_with(file_fields) def post(self, app_model: App, end_user: EndUser): - - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if not file.mimetype: @@ -43,4 +41,4 @@ class FileApi(Resource): return upload_file, 201 -api.add_resource(FileApi, '/files/upload') +api.add_resource(FileApi, "/files/upload") diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 875870e667..b39aaf7dd8 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -17,61 +17,59 @@ from services.message_service import MessageService class MessageListApi(Resource): - feedback_fields = { - 'rating': fields.String - } + feedback_fields = {"rating": fields.String} retriever_resource_fields = { - 'id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'dataset_id': fields.String, - 'dataset_name': fields.String, - 'document_id': fields.String, - 'document_name': fields.String, - 'data_source_type': fields.String, - 'segment_id': fields.String, - 'score': fields.Float, - 'hit_count': fields.Integer, - 'word_count': fields.Integer, - 'segment_position': fields.Integer, - 'index_node_hash': fields.String, - 'content': fields.String, - 'created_at': TimestampField + "id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "dataset_id": fields.String, + "dataset_name": fields.String, + "document_id": fields.String, + "document_name": fields.String, + "data_source_type": fields.String, + "segment_id": fields.String, + "score": fields.Float, + "hit_count": fields.Integer, + "word_count": fields.Integer, + "segment_position": fields.Integer, + "index_node_hash": fields.String, + "content": fields.String, + "created_at": TimestampField, } agent_thought_fields = { - 'id': fields.String, - 'chain_id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'thought': fields.String, - 'tool': fields.String, - 'tool_labels': fields.Raw, - 'tool_input': fields.String, - 'created_at': TimestampField, - 'observation': fields.String, - 'message_files': fields.List(fields.String, attribute='files') + "id": fields.String, + "chain_id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "thought": fields.String, + "tool": fields.String, + "tool_labels": fields.Raw, + "tool_input": fields.String, + "created_at": TimestampField, + "observation": fields.String, + "message_files": fields.List(fields.String, attribute="files"), } message_fields = { - 'id': fields.String, - 'conversation_id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'answer': fields.String(attribute='re_sign_file_url_answer'), - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), - 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), - 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), - 'created_at': TimestampField, - 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)), - 'status': fields.String, - 'error': fields.String, + "id": fields.String, + "conversation_id": fields.String, + "inputs": fields.Raw, + "query": fields.String, + "answer": fields.String(attribute="re_sign_file_url_answer"), + "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), + "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), + "created_at": TimestampField, + "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), + "status": fields.String, + "error": fields.String, } message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(message_fields)), } @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) @@ -82,14 +80,15 @@ class MessageListApi(Resource): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') - parser.add_argument('first_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") + parser.add_argument("first_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() try: - return MessageService.pagination_by_first_id(app_model, end_user, - args['conversation_id'], args['first_id'], args['limit']) + return MessageService.pagination_by_first_id( + app_model, end_user, args["conversation_id"], args["first_id"], args["limit"] + ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.message.FirstMessageNotExistsError: @@ -102,15 +101,15 @@ class MessageFeedbackApi(Resource): message_id = str(message_id) parser = reqparse.RequestParser() - parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') + parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args['rating']) + MessageService.create_feedback(app_model, message_id, end_user, args["rating"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") - return {'result': 'success'} + return {"result": "success"} class MessageSuggestedApi(Resource): @@ -123,10 +122,7 @@ class MessageSuggestedApi(Resource): try: questions = MessageService.get_suggested_questions_after_answer( - app_model=app_model, - user=end_user, - message_id=message_id, - invoke_from=InvokeFrom.SERVICE_API + app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.SERVICE_API ) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") @@ -136,9 +132,9 @@ class MessageSuggestedApi(Resource): logging.exception("internal server error.") raise InternalServerError() - return {'result': 'success', 'data': questions} + return {"result": "success", "data": questions} -api.add_resource(MessageListApi, '/messages') -api.add_resource(MessageFeedbackApi, '/messages//feedbacks') -api.add_resource(MessageSuggestedApi, '/messages//suggested') +api.add_resource(MessageListApi, "/messages") +api.add_resource(MessageFeedbackApi, "/messages//feedbacks") +api.add_resource(MessageSuggestedApi, "/messages//suggested") diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 9446f9d588..5822e0921b 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -30,19 +30,20 @@ from services.app_generate_service import AppGenerateService logger = logging.getLogger(__name__) workflow_run_fields = { - 'id': fields.String, - 'workflow_id': fields.String, - 'status': fields.String, - 'inputs': fields.Raw, - 'outputs': fields.Raw, - 'error': fields.String, - 'total_steps': fields.Integer, - 'total_tokens': fields.Integer, - 'created_at': fields.DateTime, - 'finished_at': fields.DateTime, - 'elapsed_time': fields.Float, + "id": fields.String, + "workflow_id": fields.String, + "status": fields.String, + "inputs": fields.Raw, + "outputs": fields.Raw, + "error": fields.String, + "total_steps": fields.Integer, + "total_tokens": fields.Integer, + "created_at": fields.DateTime, + "finished_at": fields.DateTime, + "elapsed_time": fields.Float, } + class WorkflowRunDetailApi(Resource): @validate_app_token @marshal_with(workflow_run_fields) @@ -56,6 +57,8 @@ class WorkflowRunDetailApi(Resource): workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_id).first() return workflow_run + + class WorkflowRunApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser): @@ -67,20 +70,16 @@ class WorkflowRunApi(Resource): raise NotWorkflowAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") args = parser.parse_args() - streaming = args.get('response_mode') == 'streaming' + streaming = args.get("response_mode") == "streaming" try: response = AppGenerateService.generate( - app_model=app_model, - user=end_user, - args=args, - invoke_from=InvokeFrom.SERVICE_API, - streaming=streaming + app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming ) return helper.compact_generate_response(response) @@ -111,11 +110,9 @@ class WorkflowTaskStopApi(Resource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) - return { - "result": "success" - } + return {"result": "success"} -api.add_resource(WorkflowRunApi, '/workflows/run') -api.add_resource(WorkflowRunDetailApi, '/workflows/run/') -api.add_resource(WorkflowTaskStopApi, '/workflows/tasks//stop') +api.add_resource(WorkflowRunApi, "/workflows/run") +api.add_resource(WorkflowRunDetailApi, "/workflows/run/") +api.add_resource(WorkflowTaskStopApi, "/workflows/tasks//stop") diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 8dd16c0787..c2c0672a03 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -10,13 +10,13 @@ from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager from fields.dataset_fields import dataset_detail_fields from libs.login import current_user -from models.dataset import Dataset +from models.dataset import Dataset, DatasetPermissionEnum from services.dataset_service import DatasetService def _validate_name(name): if not name or len(name) < 1 or len(name) > 40: - raise ValueError('Name must be between 1 to 40 characters.') + raise ValueError("Name must be between 1 to 40 characters.") return name @@ -26,24 +26,18 @@ class DatasetListApi(DatasetApiResource): def get(self, tenant_id): """Resource for getting datasets.""" - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) - provider = request.args.get('provider', default="vendor") - search = request.args.get('keyword', default=None, type=str) - tag_ids = request.args.getlist('tag_ids') + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) + provider = request.args.get("provider", default="vendor") + search = request.args.get("keyword", default=None, type=str) + tag_ids = request.args.getlist("tag_ids") - datasets, total = DatasetService.get_datasets(page, limit, provider, - tenant_id, current_user, search, tag_ids) + datasets, total = DatasetService.get_datasets(page, limit, provider, tenant_id, current_user, search, tag_ids) # check embedding setting provider_manager = ProviderManager() - configurations = provider_manager.get_configurations( - tenant_id=current_user.current_tenant_id - ) + configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) - embedding_models = configurations.get_models( - model_type=ModelType.TEXT_EMBEDDING, - only_active=True - ) + embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) model_names = [] for embedding_model in embedding_models: @@ -51,47 +45,59 @@ class DatasetListApi(DatasetApiResource): data = marshal(datasets, dataset_detail_fields) for item in data: - if item['indexing_technique'] == 'high_quality': + if item["indexing_technique"] == "high_quality": item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" if item_model in model_names: - item['embedding_available'] = True + item["embedding_available"] = True else: - item['embedding_available'] = False + item["embedding_available"] = False else: - item['embedding_available'] = True - response = { - 'data': data, - 'has_more': len(datasets) == limit, - 'limit': limit, - 'total': total, - 'page': page - } + item["embedding_available"] = True + response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} return response, 200 - def post(self, tenant_id): """Resource for creating datasets.""" parser = reqparse.RequestParser() - parser.add_argument('name', nullable=False, required=True, - help='type is required. Name must be between 1 to 40 characters.', - type=_validate_name) - parser.add_argument('indexing_technique', type=str, location='json', - choices=Dataset.INDEXING_TECHNIQUE_LIST, - help='Invalid indexing technique.') + parser.add_argument( + "name", + nullable=False, + required=True, + help="type is required. Name must be between 1 to 40 characters.", + type=_validate_name, + ) + parser.add_argument( + "indexing_technique", + type=str, + location="json", + choices=Dataset.INDEXING_TECHNIQUE_LIST, + help="Invalid indexing technique.", + ) + parser.add_argument( + "permission", + type=str, + location="json", + choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), + help="Invalid permission.", + required=False, + nullable=False, + ) args = parser.parse_args() try: dataset = DatasetService.create_empty_dataset( tenant_id=tenant_id, - name=args['name'], - indexing_technique=args['indexing_technique'], - account=current_user + name=args["name"], + indexing_technique=args["indexing_technique"], + account=current_user, + permission=args["permission"], ) except services.errors.dataset.DatasetNameDuplicateError: raise DatasetNameDuplicateError() return marshal(dataset, dataset_detail_fields), 200 + class DatasetApi(DatasetApiResource): """Resource for dataset.""" @@ -103,7 +109,7 @@ class DatasetApi(DatasetApiResource): dataset_id (UUID): The ID of the dataset to be deleted. Returns: - dict: A dictionary with a key 'result' and a value 'success' + dict: A dictionary with a key 'result' and a value 'success' if the dataset was successfully deleted. Omitted in HTTP response. int: HTTP status code 204 indicating that the operation was successful. @@ -115,11 +121,12 @@ class DatasetApi(DatasetApiResource): try: if DatasetService.delete_dataset(dataset_id_str, current_user): - return {'result': 'success'}, 204 + return {"result": "success"}, 204 else: raise NotFound("Dataset not found.") except services.errors.dataset.DatasetInUseError: raise DatasetInUseError() -api.add_resource(DatasetListApi, '/datasets') -api.add_resource(DatasetApi, '/datasets/') + +api.add_resource(DatasetListApi, "/datasets") +api.add_resource(DatasetApi, "/datasets/") diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index ac1ea820a6..fb48a6c76c 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -27,47 +27,40 @@ from services.file_service import FileService class DocumentAddByTextApi(DatasetApiResource): """Resource for documents.""" - @cloud_edition_billing_resource_check('vector_space', 'dataset') - @cloud_edition_billing_resource_check('documents', 'dataset') + @cloud_edition_billing_resource_check("vector_space", "dataset") + @cloud_edition_billing_resource_check("documents", "dataset") def post(self, tenant_id, dataset_id): """Create document by text.""" parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, nullable=False, location='json') - parser.add_argument('text', type=str, required=True, nullable=False, location='json') - parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json') - parser.add_argument('original_document_id', type=str, required=False, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, - location='json') - parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, - location='json') - parser.add_argument('retrieval_model', type=dict, required=False, nullable=False, - location='json') + parser.add_argument("name", type=str, required=True, nullable=False, location="json") + parser.add_argument("text", type=str, required=True, nullable=False, location="json") + parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") + parser.add_argument("original_document_id", type=str, required=False, location="json") + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) + parser.add_argument( + "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" + ) + parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") args = parser.parse_args() dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise ValueError('Dataset is not exist.') + raise ValueError("Dataset is not exist.") - if not dataset.indexing_technique and not args['indexing_technique']: - raise ValueError('indexing_technique is required.') + if not dataset.indexing_technique and not args["indexing_technique"]: + raise ValueError("indexing_technique is required.") - upload_file = FileService.upload_text(args.get('text'), args.get('name')) + upload_file = FileService.upload_text(args.get("text"), args.get("name")) data_source = { - 'type': 'upload_file', - 'info_list': { - 'data_source_type': 'upload_file', - 'file_info_list': { - 'file_ids': [upload_file.id] - } - } + "type": "upload_file", + "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, } - args['data_source'] = data_source + args["data_source"] = data_source # validate args DocumentService.document_create_args_validate(args) @@ -76,60 +69,49 @@ class DocumentAddByTextApi(DatasetApiResource): dataset=dataset, document_data=args, account=current_user, - dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, - created_from='api' + dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, + created_from="api", ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) document = documents[0] - documents_and_batch_fields = { - 'document': marshal(document, document_fields), - 'batch': batch - } + documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} return documents_and_batch_fields, 200 class DocumentUpdateByTextApi(DatasetApiResource): """Resource for update documents.""" - @cloud_edition_billing_resource_check('vector_space', 'dataset') + @cloud_edition_billing_resource_check("vector_space", "dataset") def post(self, tenant_id, dataset_id, document_id): """Update document by text.""" parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=False, nullable=True, location='json') - parser.add_argument('text', type=str, required=False, nullable=True, location='json') - parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, - location='json') - parser.add_argument('retrieval_model', type=dict, required=False, nullable=False, - location='json') + parser.add_argument("name", type=str, required=False, nullable=True, location="json") + parser.add_argument("text", type=str, required=False, nullable=True, location="json") + parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) + parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") args = parser.parse_args() dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise ValueError('Dataset is not exist.') + raise ValueError("Dataset is not exist.") - if args['text']: - upload_file = FileService.upload_text(args.get('text'), args.get('name')) + if args["text"]: + upload_file = FileService.upload_text(args.get("text"), args.get("name")) data_source = { - 'type': 'upload_file', - 'info_list': { - 'data_source_type': 'upload_file', - 'file_info_list': { - 'file_ids': [upload_file.id] - } - } + "type": "upload_file", + "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, } - args['data_source'] = data_source + args["data_source"] = data_source # validate args - args['original_document_id'] = str(document_id) + args["original_document_id"] = str(document_id) DocumentService.document_create_args_validate(args) try: @@ -137,65 +119,53 @@ class DocumentUpdateByTextApi(DatasetApiResource): dataset=dataset, document_data=args, account=current_user, - dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, - created_from='api' + dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, + created_from="api", ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) document = documents[0] - documents_and_batch_fields = { - 'document': marshal(document, document_fields), - 'batch': batch - } + documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} return documents_and_batch_fields, 200 class DocumentAddByFileApi(DatasetApiResource): """Resource for documents.""" - @cloud_edition_billing_resource_check('vector_space', 'dataset') - @cloud_edition_billing_resource_check('documents', 'dataset') + + @cloud_edition_billing_resource_check("vector_space", "dataset") + @cloud_edition_billing_resource_check("documents", "dataset") def post(self, tenant_id, dataset_id): """Create document by upload file.""" args = {} - if 'data' in request.form: - args = json.loads(request.form['data']) - if 'doc_form' not in args: - args['doc_form'] = 'text_model' - if 'doc_language' not in args: - args['doc_language'] = 'English' + if "data" in request.form: + args = json.loads(request.form["data"]) + if "doc_form" not in args: + args["doc_form"] = "text_model" + if "doc_language" not in args: + args["doc_language"] = "English" # get dataset info dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise ValueError('Dataset is not exist.') - if not dataset.indexing_technique and not args.get('indexing_technique'): - raise ValueError('indexing_technique is required.') + raise ValueError("Dataset is not exist.") + if not dataset.indexing_technique and not args.get("indexing_technique"): + raise ValueError("indexing_technique is required.") # save file info - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() upload_file = FileService.upload_file(file, current_user) - data_source = { - 'type': 'upload_file', - 'info_list': { - 'file_info_list': { - 'file_ids': [upload_file.id] - } - } - } - args['data_source'] = data_source + data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} + args["data_source"] = data_source # validate args DocumentService.document_create_args_validate(args) @@ -204,63 +174,49 @@ class DocumentAddByFileApi(DatasetApiResource): dataset=dataset, document_data=args, account=dataset.created_by_account, - dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, - created_from='api' + dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, + created_from="api", ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) document = documents[0] - documents_and_batch_fields = { - 'document': marshal(document, document_fields), - 'batch': batch - } + documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} return documents_and_batch_fields, 200 class DocumentUpdateByFileApi(DatasetApiResource): """Resource for update documents.""" - @cloud_edition_billing_resource_check('vector_space', 'dataset') + @cloud_edition_billing_resource_check("vector_space", "dataset") def post(self, tenant_id, dataset_id, document_id): """Update document by upload file.""" args = {} - if 'data' in request.form: - args = json.loads(request.form['data']) - if 'doc_form' not in args: - args['doc_form'] = 'text_model' - if 'doc_language' not in args: - args['doc_language'] = 'English' + if "data" in request.form: + args = json.loads(request.form["data"]) + if "doc_form" not in args: + args["doc_form"] = "text_model" + if "doc_language" not in args: + args["doc_language"] = "English" # get dataset info dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise ValueError('Dataset is not exist.') - if 'file' in request.files: + raise ValueError("Dataset is not exist.") + if "file" in request.files: # save file info - file = request.files['file'] - + file = request.files["file"] if len(request.files) > 1: raise TooManyFilesError() upload_file = FileService.upload_file(file, current_user) - data_source = { - 'type': 'upload_file', - 'info_list': { - 'file_info_list': { - 'file_ids': [upload_file.id] - } - } - } - args['data_source'] = data_source + data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} + args["data_source"] = data_source # validate args - args['original_document_id'] = str(document_id) + args["original_document_id"] = str(document_id) DocumentService.document_create_args_validate(args) try: @@ -268,16 +224,13 @@ class DocumentUpdateByFileApi(DatasetApiResource): dataset=dataset, document_data=args, account=dataset.created_by_account, - dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, - created_from='api' + dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, + created_from="api", ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) document = documents[0] - documents_and_batch_fields = { - 'document': marshal(document, document_fields), - 'batch': batch - } + documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} return documents_and_batch_fields, 200 @@ -289,13 +242,10 @@ class DocumentDeleteApi(DatasetApiResource): tenant_id = str(tenant_id) # get dataset info - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise ValueError('Dataset is not exist.') + raise ValueError("Dataset is not exist.") document = DocumentService.get_document(dataset.id, document_id) @@ -311,44 +261,39 @@ class DocumentDeleteApi(DatasetApiResource): # delete document DocumentService.delete_document(document) except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError('Cannot delete document during indexing.') + raise DocumentIndexingError("Cannot delete document during indexing.") - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class DocumentListApi(DatasetApiResource): def get(self, tenant_id, dataset_id): dataset_id = str(dataset_id) tenant_id = str(tenant_id) - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) - search = request.args.get('keyword', default=None, type=str) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) + search = request.args.get("keyword", default=None, type=str) + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") - query = Document.query.filter_by( - dataset_id=str(dataset_id), tenant_id=tenant_id) + query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id) if search: - search = f'%{search}%' + search = f"%{search}%" query = query.filter(Document.name.like(search)) query = query.order_by(desc(Document.created_at)) - paginated_documents = query.paginate( - page=page, per_page=limit, max_per_page=100, error_out=False) + paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) documents = paginated_documents.items response = { - 'data': marshal(documents, document_fields), - 'has_more': len(documents) == limit, - 'limit': limit, - 'total': paginated_documents.total, - 'page': page + "data": marshal(documents, document_fields), + "has_more": len(documents) == limit, + "limit": limit, + "total": paginated_documents.total, + "page": page, } return response @@ -360,38 +305,36 @@ class DocumentIndexingStatusApi(DatasetApiResource): batch = str(batch) tenant_id = str(tenant_id) # get dataset - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # get documents documents = DocumentService.get_batch_documents(dataset_id, batch) if not documents: - raise NotFound('Documents not found.') + raise NotFound("Documents not found.") documents_status = [] for document in documents: - completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() - total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() + completed_segments = DocumentSegment.query.filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ).count() + total_segments = DocumentSegment.query.filter( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" + ).count() document.completed_segments = completed_segments document.total_segments = total_segments if document.is_paused: - document.indexing_status = 'paused' + document.indexing_status = "paused" documents_status.append(marshal(document, document_status_fields)) - data = { - 'data': documents_status - } + data = {"data": documents_status} return data -api.add_resource(DocumentAddByTextApi, '/datasets//document/create_by_text') -api.add_resource(DocumentAddByFileApi, '/datasets//document/create_by_file') -api.add_resource(DocumentUpdateByTextApi, '/datasets//documents//update_by_text') -api.add_resource(DocumentUpdateByFileApi, '/datasets//documents//update_by_file') -api.add_resource(DocumentDeleteApi, '/datasets//documents/') -api.add_resource(DocumentListApi, '/datasets//documents') -api.add_resource(DocumentIndexingStatusApi, '/datasets//documents//indexing-status') +api.add_resource(DocumentAddByTextApi, "/datasets//document/create_by_text") +api.add_resource(DocumentAddByFileApi, "/datasets//document/create_by_file") +api.add_resource(DocumentUpdateByTextApi, "/datasets//documents//update_by_text") +api.add_resource(DocumentUpdateByFileApi, "/datasets//documents//update_by_file") +api.add_resource(DocumentDeleteApi, "/datasets//documents/") +api.add_resource(DocumentListApi, "/datasets//documents") +api.add_resource(DocumentIndexingStatusApi, "/datasets//documents//indexing-status") diff --git a/api/controllers/service_api/dataset/error.py b/api/controllers/service_api/dataset/error.py index e77693b6c9..5ff5e08c72 100644 --- a/api/controllers/service_api/dataset/error.py +++ b/api/controllers/service_api/dataset/error.py @@ -2,78 +2,78 @@ from libs.exception import BaseHTTPException class NoFileUploadedError(BaseHTTPException): - error_code = 'no_file_uploaded' + error_code = "no_file_uploaded" description = "Please upload your file." code = 400 class TooManyFilesError(BaseHTTPException): - error_code = 'too_many_files' + error_code = "too_many_files" description = "Only one file is allowed." code = 400 class FileTooLargeError(BaseHTTPException): - error_code = 'file_too_large' + error_code = "file_too_large" description = "File size exceeded. {message}" code = 413 class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' + error_code = "unsupported_file_type" description = "File type not allowed." code = 415 class HighQualityDatasetOnlyError(BaseHTTPException): - error_code = 'high_quality_dataset_only' + error_code = "high_quality_dataset_only" description = "Current operation only supports 'high-quality' datasets." code = 400 class DatasetNotInitializedError(BaseHTTPException): - error_code = 'dataset_not_initialized' + error_code = "dataset_not_initialized" description = "The dataset is still being initialized or indexing. Please wait a moment." code = 400 class ArchivedDocumentImmutableError(BaseHTTPException): - error_code = 'archived_document_immutable' + error_code = "archived_document_immutable" description = "The archived document is not editable." code = 403 class DatasetNameDuplicateError(BaseHTTPException): - error_code = 'dataset_name_duplicate' + error_code = "dataset_name_duplicate" description = "The dataset name already exists. Please modify your dataset name." code = 409 class InvalidActionError(BaseHTTPException): - error_code = 'invalid_action' + error_code = "invalid_action" description = "Invalid action." code = 400 class DocumentAlreadyFinishedError(BaseHTTPException): - error_code = 'document_already_finished' + error_code = "document_already_finished" description = "The document has been processed. Please refresh the page or go to the document details." code = 400 class DocumentIndexingError(BaseHTTPException): - error_code = 'document_indexing' + error_code = "document_indexing" description = "The document is being processed and cannot be edited." code = 400 class InvalidMetadataError(BaseHTTPException): - error_code = 'invalid_metadata' + error_code = "invalid_metadata" description = "The metadata content is incorrect. Please check and verify." code = 400 class DatasetInUseError(BaseHTTPException): - error_code = 'dataset_in_use' + error_code = "dataset_in_use" description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it." code = 409 diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 0fa2aa65b2..e69db29fdc 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -21,52 +21,51 @@ from services.dataset_service import DatasetService, DocumentService, SegmentSer class SegmentApi(DatasetApiResource): """Resource for segments.""" - @cloud_edition_billing_resource_check('vector_space', 'dataset') - @cloud_edition_billing_knowledge_limit_check('add_segment', 'dataset') + @cloud_edition_billing_resource_check("vector_space", "dataset") + @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") def post(self, tenant_id, dataset_id, document_id): """Create single segment.""" # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check document document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") + if document.indexing_status != "completed": + raise NotFound("Document is not completed.") + if not document.enabled: + raise NotFound("Document is disabled.") # check embedding model setting - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": try: model_manager = ModelManager() model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") - except ProviderTokenNotInitError as ex: + "in the Settings -> Model Provider." + ) + except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # validate args parser = reqparse.RequestParser() - parser.add_argument('segments', type=list, required=False, nullable=True, location='json') + parser.add_argument("segments", type=list, required=False, nullable=True, location="json") args = parser.parse_args() - if args['segments'] is not None: - for args_item in args['segments']: + if args["segments"] is not None: + for args_item in args["segments"]: SegmentService.segment_create_args_validate(args_item, document) - segments = SegmentService.multi_create_segment(args['segments'], document, dataset) - return { - 'data': marshal(segments, segment_fields), - 'doc_form': document.doc_form - }, 200 + segments = SegmentService.multi_create_segment(args["segments"], document, dataset) + return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200 else: return {"error": "Segemtns is required"}, 400 @@ -75,61 +74,53 @@ class SegmentApi(DatasetApiResource): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check document document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") # check embedding model setting - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": try: model_manager = ModelManager() model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) parser = reqparse.RequestParser() - parser.add_argument('status', type=str, - action='append', default=[], location='args') - parser.add_argument('keyword', type=str, default=None, location='args') + parser.add_argument("status", type=str, action="append", default=[], location="args") + parser.add_argument("keyword", type=str, default=None, location="args") args = parser.parse_args() - status_list = args['status'] - keyword = args['keyword'] + status_list = args["status"] + keyword = args["keyword"] query = DocumentSegment.query.filter( - DocumentSegment.document_id == str(document_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id ) if status_list: query = query.filter(DocumentSegment.status.in_(status_list)) if keyword: - query = query.where(DocumentSegment.content.ilike(f'%{keyword}%')) + query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) total = query.count() segments = query.order_by(DocumentSegment.position).all() - return { - 'data': marshal(segments, segment_fields), - 'doc_form': document.doc_form, - 'total': total - }, 200 + return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form, "total": total}, 200 class DatasetSegmentApi(DatasetApiResource): @@ -137,48 +128,41 @@ class DatasetSegmentApi(DatasetApiResource): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") # check segment segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id ).first() if not segment: - raise NotFound('Segment not found.') + raise NotFound("Segment not found.") SegmentService.delete_segment(segment, document, dataset) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 - @cloud_edition_billing_resource_check('vector_space', 'dataset') + @cloud_edition_billing_resource_check("vector_space", "dataset") def post(self, tenant_id, dataset_id, document_id, segment_id): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') - if dataset.indexing_technique == 'high_quality': + raise NotFound("Document not found.") + if dataset.indexing_technique == "high_quality": # check embedding model setting try: model_manager = ModelManager() @@ -186,35 +170,34 @@ class DatasetSegmentApi(DatasetApiResource): tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # check segment segment_id = str(segment_id) segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id ).first() if not segment: - raise NotFound('Segment not found.') + raise NotFound("Segment not found.") # validate args parser = reqparse.RequestParser() - parser.add_argument('segment', type=dict, required=False, nullable=True, location='json') + parser.add_argument("segment", type=dict, required=False, nullable=True, location="json") args = parser.parse_args() - SegmentService.segment_create_args_validate(args['segment'], document) - segment = SegmentService.update_segment(args['segment'], segment, document, dataset) - return { - 'data': marshal(segment, segment_fields), - 'doc_form': document.doc_form - }, 200 + SegmentService.segment_create_args_validate(args["segment"], document) + segment = SegmentService.update_segment(args["segment"], segment, document, dataset) + return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 -api.add_resource(SegmentApi, '/datasets//documents//segments') -api.add_resource(DatasetSegmentApi, '/datasets//documents//segments/') +api.add_resource(SegmentApi, "/datasets//documents//segments") +api.add_resource( + DatasetSegmentApi, "/datasets//documents//segments/" +) diff --git a/api/controllers/service_api/index.py b/api/controllers/service_api/index.py index c910063ebd..d24c4597e2 100644 --- a/api/controllers/service_api/index.py +++ b/api/controllers/service_api/index.py @@ -13,4 +13,4 @@ class IndexApi(Resource): } -api.add_resource(IndexApi, '/') +api.add_resource(IndexApi, "/") diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 819512edf0..b935b23ed6 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -21,9 +21,10 @@ class WhereisUserArg(Enum): """ Enum for whereis_user_arg. """ - QUERY = 'query' - JSON = 'json' - FORM = 'form' + + QUERY = "query" + JSON = "json" + FORM = "form" class FetchUserArg(BaseModel): @@ -35,13 +36,13 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio def decorator(view_func): @wraps(view_func) def decorated_view(*args, **kwargs): - api_token = validate_and_get_api_token('app') + api_token = validate_and_get_api_token("app") app_model = db.session.query(App).filter(App.id == api_token.app_id).first() if not app_model: raise Forbidden("The app no longer exists.") - if app_model.status != 'normal': + if app_model.status != "normal": raise Forbidden("The app's status is abnormal.") if not app_model.enable_api: @@ -51,15 +52,15 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio if tenant.status == TenantStatus.ARCHIVE: raise Forbidden("The workspace's status is archived.") - kwargs['app_model'] = app_model + kwargs["app_model"] = app_model if fetch_user_arg: if fetch_user_arg.fetch_from == WhereisUserArg.QUERY: - user_id = request.args.get('user') + user_id = request.args.get("user") elif fetch_user_arg.fetch_from == WhereisUserArg.JSON: - user_id = request.get_json().get('user') + user_id = request.get_json().get("user") elif fetch_user_arg.fetch_from == WhereisUserArg.FORM: - user_id = request.form.get('user') + user_id = request.form.get("user") else: # use default-user user_id = None @@ -70,9 +71,10 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio if user_id: user_id = str(user_id) - kwargs['end_user'] = create_or_update_end_user_for_user_id(app_model, user_id) + kwargs["end_user"] = create_or_update_end_user_for_user_id(app_model, user_id) return view_func(*args, **kwargs) + return decorated_view if view is None: @@ -81,9 +83,7 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio return decorator(view) -def cloud_edition_billing_resource_check(resource: str, - api_token_type: str, - error_msg: str = "You have reached the limit of your subscription."): +def cloud_edition_billing_resource_check(resource: str, api_token_type: str): def interceptor(view): def decorated(*args, **kwargs): api_token = validate_and_get_api_token(api_token_type) @@ -95,34 +95,36 @@ def cloud_edition_billing_resource_check(resource: str, vector_space = features.vector_space documents_upload_quota = features.documents_upload_quota - if resource == 'members' and 0 < members.limit <= members.size: - raise Forbidden(error_msg) - elif resource == 'apps' and 0 < apps.limit <= apps.size: - raise Forbidden(error_msg) - elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size: - raise Forbidden(error_msg) - elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size: - raise Forbidden(error_msg) + if resource == "members" and 0 < members.limit <= members.size: + raise Forbidden("The number of members has reached the limit of your subscription.") + elif resource == "apps" and 0 < apps.limit <= apps.size: + raise Forbidden("The number of apps has reached the limit of your subscription.") + elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size: + raise Forbidden("The capacity of the vector space has reached the limit of your subscription.") + elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size: + raise Forbidden("The number of documents has reached the limit of your subscription.") else: return view(*args, **kwargs) return view(*args, **kwargs) + return decorated + return interceptor -def cloud_edition_billing_knowledge_limit_check(resource: str, - api_token_type: str, - error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."): +def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str): def interceptor(view): @wraps(view) def decorated(*args, **kwargs): api_token = validate_and_get_api_token(api_token_type) features = FeatureService.get_features(api_token.tenant_id) if features.billing.enabled: - if resource == 'add_segment': - if features.billing.subscription.plan == 'sandbox': - raise Forbidden(error_msg) + if resource == "add_segment": + if features.billing.subscription.plan == "sandbox": + raise Forbidden( + "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan." + ) else: return view(*args, **kwargs) @@ -132,17 +134,20 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, return interceptor + def validate_dataset_token(view=None): def decorator(view): @wraps(view) def decorated(*args, **kwargs): - api_token = validate_and_get_api_token('dataset') - tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \ - .filter(Tenant.id == api_token.tenant_id) \ - .filter(TenantAccountJoin.tenant_id == Tenant.id) \ - .filter(TenantAccountJoin.role.in_(['owner'])) \ - .filter(Tenant.status == TenantStatus.NORMAL) \ - .one_or_none() # TODO: only owner information is required, so only one is returned. + api_token = validate_and_get_api_token("dataset") + tenant_account_join = ( + db.session.query(Tenant, TenantAccountJoin) + .filter(Tenant.id == api_token.tenant_id) + .filter(TenantAccountJoin.tenant_id == Tenant.id) + .filter(TenantAccountJoin.role.in_(["owner"])) + .filter(Tenant.status == TenantStatus.NORMAL) + .one_or_none() + ) # TODO: only owner information is required, so only one is returned. if tenant_account_join: tenant, ta = tenant_account_join account = Account.query.filter_by(id=ta.account_id).first() @@ -156,6 +161,7 @@ def validate_dataset_token(view=None): else: raise Unauthorized("Tenant does not exist.") return view(api_token.tenant_id, *args, **kwargs) + return decorated if view: @@ -170,20 +176,24 @@ def validate_and_get_api_token(scope=None): """ Validate and get API token. """ - auth_header = request.headers.get('Authorization') - if auth_header is None or ' ' not in auth_header: + auth_header = request.headers.get("Authorization") + if auth_header is None or " " not in auth_header: raise Unauthorized("Authorization header must be provided and start with 'Bearer'") auth_scheme, auth_token = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() - if auth_scheme != 'bearer': + if auth_scheme != "bearer": raise Unauthorized("Authorization scheme must be 'Bearer'") - api_token = db.session.query(ApiToken).filter( - ApiToken.token == auth_token, - ApiToken.type == scope, - ).first() + api_token = ( + db.session.query(ApiToken) + .filter( + ApiToken.token == auth_token, + ApiToken.type == scope, + ) + .first() + ) if not api_token: raise Unauthorized("Access token is invalid") @@ -199,23 +209,26 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] Create or update session terminal based on user ID. """ if not user_id: - user_id = 'DEFAULT-USER' + user_id = "DEFAULT-USER" - end_user = db.session.query(EndUser) \ + end_user = ( + db.session.query(EndUser) .filter( - EndUser.tenant_id == app_model.tenant_id, - EndUser.app_id == app_model.id, - EndUser.session_id == user_id, - EndUser.type == 'service_api' - ).first() + EndUser.tenant_id == app_model.tenant_id, + EndUser.app_id == app_model.id, + EndUser.session_id == user_id, + EndUser.type == "service_api", + ) + .first() + ) if end_user is None: end_user = EndUser( tenant_id=app_model.tenant_id, app_id=app_model.id, - type='service_api', - is_anonymous=True if user_id == 'DEFAULT-USER' else False, - session_id=user_id + type="service_api", + is_anonymous=True if user_id == "DEFAULT-USER" else False, + session_id=user_id, ) db.session.add(end_user) db.session.commit() diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py index aa19bdc034..630b9468a7 100644 --- a/api/controllers/web/__init__.py +++ b/api/controllers/web/__init__.py @@ -2,7 +2,7 @@ from flask import Blueprint from libs.external_api import ExternalApi -bp = Blueprint('web', __name__, url_prefix='/api') +bp = Blueprint("web", __name__, url_prefix="/api") api = ExternalApi(bp) diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index f4db82552c..aabca93338 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -10,33 +10,32 @@ from services.app_service import AppService class AppParameterApi(WebApiResource): """Resource for app variables.""" + variable_fields = { - 'key': fields.String, - 'name': fields.String, - 'description': fields.String, - 'type': fields.String, - 'default': fields.String, - 'max_length': fields.Integer, - 'options': fields.List(fields.String) + "key": fields.String, + "name": fields.String, + "description": fields.String, + "type": fields.String, + "default": fields.String, + "max_length": fields.Integer, + "options": fields.List(fields.String), } - system_parameters_fields = { - 'image_file_size_limit': fields.String - } + system_parameters_fields = {"image_file_size_limit": fields.String} parameters_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw, - 'suggested_questions_after_answer': fields.Raw, - 'speech_to_text': fields.Raw, - 'text_to_speech': fields.Raw, - 'retriever_resource': fields.Raw, - 'annotation_reply': fields.Raw, - 'more_like_this': fields.Raw, - 'user_input_form': fields.Raw, - 'sensitive_word_avoidance': fields.Raw, - 'file_upload': fields.Raw, - 'system_parameters': fields.Nested(system_parameters_fields) + "opening_statement": fields.String, + "suggested_questions": fields.Raw, + "suggested_questions_after_answer": fields.Raw, + "speech_to_text": fields.Raw, + "text_to_speech": fields.Raw, + "retriever_resource": fields.Raw, + "annotation_reply": fields.Raw, + "more_like_this": fields.Raw, + "user_input_form": fields.Raw, + "sensitive_word_avoidance": fields.Raw, + "file_upload": fields.Raw, + "system_parameters": fields.Nested(system_parameters_fields), } @marshal_with(parameters_fields) @@ -53,30 +52,35 @@ class AppParameterApi(WebApiResource): app_model_config = app_model.app_model_config features_dict = app_model_config.to_dict() - user_input_form = features_dict.get('user_input_form', []) + user_input_form = features_dict.get("user_input_form", []) return { - 'opening_statement': features_dict.get('opening_statement'), - 'suggested_questions': features_dict.get('suggested_questions', []), - 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer', - {"enabled": False}), - 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}), - 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}), - 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}), - 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}), - 'more_like_this': features_dict.get('more_like_this', {"enabled": False}), - 'user_input_form': user_input_form, - 'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance', - {"enabled": False, "type": "", "configs": []}), - 'file_upload': features_dict.get('file_upload', {"image": { - "enabled": False, - "number_limits": 3, - "detail": "high", - "transfer_methods": ["remote_url", "local_file"] - }}), - 'system_parameters': { - 'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT - } + "opening_statement": features_dict.get("opening_statement"), + "suggested_questions": features_dict.get("suggested_questions", []), + "suggested_questions_after_answer": features_dict.get( + "suggested_questions_after_answer", {"enabled": False} + ), + "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}), + "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}), + "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}), + "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}), + "more_like_this": features_dict.get("more_like_this", {"enabled": False}), + "user_input_form": user_input_form, + "sensitive_word_avoidance": features_dict.get( + "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []} + ), + "file_upload": features_dict.get( + "file_upload", + { + "image": { + "enabled": False, + "number_limits": 3, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"], + } + }, + ), + "system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT}, } @@ -86,5 +90,5 @@ class AppMeta(WebApiResource): return AppService().get_app_meta(app_model) -api.add_resource(AppParameterApi, '/parameters') -api.add_resource(AppMeta, '/meta') \ No newline at end of file +api.add_resource(AppParameterApi, "/parameters") +api.add_resource(AppMeta, "/meta") diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 0e905f905a..d062d2893b 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -31,14 +31,10 @@ from services.errors.audio import ( class AudioApi(WebApiResource): def post(self, app_model: App, end_user): - file = request.files['file'] + file = request.files["file"] try: - response = AudioService.transcript_asr( - app_model=app_model, - file=file, - end_user=end_user - ) + response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user) return response except services.errors.app_model_config.AppModelConfigBrokenError: @@ -70,34 +66,36 @@ class AudioApi(WebApiResource): class TextApi(WebApiResource): def post(self, app_model: App, end_user): from flask_restful import reqparse + try: parser = reqparse.RequestParser() - parser.add_argument('message_id', type=str, required=False, location='json') - parser.add_argument('voice', type=str, location='json') - parser.add_argument('text', type=str, location='json') - parser.add_argument('streaming', type=bool, location='json') + parser.add_argument("message_id", type=str, required=False, location="json") + parser.add_argument("voice", type=str, location="json") + parser.add_argument("text", type=str, location="json") + parser.add_argument("streaming", type=bool, location="json") args = parser.parse_args() - message_id = args.get('message_id', None) - text = args.get('text', None) - if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] - and app_model.workflow - and app_model.workflow.features_dict): - text_to_speech = app_model.workflow.features_dict.get('text_to_speech') - voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice') + message_id = args.get("message_id", None) + text = args.get("text", None) + if ( + app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + and app_model.workflow + and app_model.workflow.features_dict + ): + text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") else: try: - voice = args.get('voice') if args.get( - 'voice') else app_model.app_model_config.text_to_speech_dict.get('voice') + voice = ( + args.get("voice") + if args.get("voice") + else app_model.app_model_config.text_to_speech_dict.get("voice") + ) except Exception: voice = None response = AudioService.transcript_tts( - app_model=app_model, - message_id=message_id, - end_user=end_user.external_user_id, - voice=voice, - text=text + app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, voice=voice, text=text ) return response @@ -127,5 +125,5 @@ class TextApi(WebApiResource): raise InternalServerError() -api.add_resource(AudioApi, '/audio-to-text') -api.add_resource(TextApi, '/text-to-audio') +api.add_resource(AudioApi, "/audio-to-text") +api.add_resource(TextApi, "/text-to-audio") diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 948d5fabb5..0837eedfb0 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -15,6 +15,7 @@ from controllers.web.error import ( ProviderNotInitializeError, ProviderQuotaExceededError, ) +from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from controllers.web.wraps import WebApiResource from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom @@ -24,34 +25,30 @@ from libs import helper from libs.helper import uuid_value from models.model import AppMode from services.app_generate_service import AppGenerateService +from services.errors.llm import InvokeRateLimitError # define completion api for user class CompletionApi(WebApiResource): - def post(self, app_model, end_user): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, location='json', default='') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, location="json", default="") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("retriever_from", type=str, required=False, default="web_app", location="json") args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' - args['auto_generate_name'] = False + streaming = args["response_mode"] == "streaming" + args["auto_generate_name"] = False try: response = AppGenerateService.generate( - app_model=app_model, - user=end_user, - args=args, - invoke_from=InvokeFrom.WEB_APP, - streaming=streaming + app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.WEB_APP, streaming=streaming ) return helper.compact_generate_response(response) @@ -79,12 +76,12 @@ class CompletionApi(WebApiResource): class CompletionStopApi(WebApiResource): def post(self, app_model, end_user, task_id): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ChatApi(WebApiResource): @@ -94,25 +91,21 @@ class ChatApi(WebApiResource): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, required=True, location='json') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('conversation_id', type=uuid_value, location='json') - parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, required=True, location="json") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("conversation_id", type=uuid_value, location="json") + parser.add_argument("retriever_from", type=str, required=False, default="web_app", location="json") args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' - args['auto_generate_name'] = False + streaming = args["response_mode"] == "streaming" + args["auto_generate_name"] = False try: response = AppGenerateService.generate( - app_model=app_model, - user=end_user, - args=args, - invoke_from=InvokeFrom.WEB_APP, - streaming=streaming + app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.WEB_APP, streaming=streaming ) return helper.compact_generate_response(response) @@ -129,6 +122,8 @@ class ChatApi(WebApiResource): raise ProviderQuotaExceededError() except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() + except InvokeRateLimitError as ex: + raise InvokeRateLimitHttpError(ex.description) except InvokeError as e: raise CompletionRequestError(e.description) except ValueError as e: @@ -146,10 +141,10 @@ class ChatStopApi(WebApiResource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(CompletionApi, '/completion-messages') -api.add_resource(CompletionStopApi, '/completion-messages//stop') -api.add_resource(ChatApi, '/chat-messages') -api.add_resource(ChatStopApi, '/chat-messages//stop') +api.add_resource(CompletionApi, "/completion-messages") +api.add_resource(CompletionStopApi, "/completion-messages//stop") +api.add_resource(ChatApi, "/chat-messages") +api.add_resource(ChatStopApi, "/chat-messages//stop") diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index b83ea3a525..6bbfa94c27 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -15,7 +15,6 @@ from services.web_conversation_service import WebConversationService class ConversationListApi(WebApiResource): - @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) @@ -23,23 +22,32 @@ class ConversationListApi(WebApiResource): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') - parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args") + parser.add_argument( + "sort_by", + type=str, + choices=["created_at", "-created_at", "updated_at", "-updated_at"], + required=False, + default="-updated_at", + location="args", + ) args = parser.parse_args() pinned = None - if 'pinned' in args and args['pinned'] is not None: - pinned = True if args['pinned'] == 'true' else False + if "pinned" in args and args["pinned"] is not None: + pinned = True if args["pinned"] == "true" else False try: return WebConversationService.pagination_by_last_id( app_model=app_model, user=end_user, - last_id=args['last_id'], - limit=args['limit'], + last_id=args["last_id"], + limit=args["limit"], invoke_from=InvokeFrom.WEB_APP, pinned=pinned, + sort_by=args["sort_by"], ) except LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") @@ -62,7 +70,6 @@ class ConversationApi(WebApiResource): class ConversationRenameApi(WebApiResource): - @marshal_with(simple_conversation_fields) def post(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) @@ -72,24 +79,17 @@ class ConversationRenameApi(WebApiResource): conversation_id = str(c_id) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=False, location='json') - parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json') + parser.add_argument("name", type=str, required=False, location="json") + parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") args = parser.parse_args() try: - return ConversationService.rename( - app_model, - conversation_id, - end_user, - args['name'], - args['auto_generate'] - ) + return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"]) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") class ConversationPinApi(WebApiResource): - def patch(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: @@ -117,8 +117,8 @@ class ConversationUnPinApi(WebApiResource): return {"result": "success"} -api.add_resource(ConversationRenameApi, '/conversations//name', endpoint='web_conversation_name') -api.add_resource(ConversationListApi, '/conversations') -api.add_resource(ConversationApi, '/conversations/') -api.add_resource(ConversationPinApi, '/conversations//pin') -api.add_resource(ConversationUnPinApi, '/conversations//unpin') +api.add_resource(ConversationRenameApi, "/conversations//name", endpoint="web_conversation_name") +api.add_resource(ConversationListApi, "/conversations") +api.add_resource(ConversationApi, "/conversations/") +api.add_resource(ConversationPinApi, "/conversations//pin") +api.add_resource(ConversationUnPinApi, "/conversations//unpin") diff --git a/api/controllers/web/error.py b/api/controllers/web/error.py index bc87f51051..9fe5d08d54 100644 --- a/api/controllers/web/error.py +++ b/api/controllers/web/error.py @@ -2,122 +2,134 @@ from libs.exception import BaseHTTPException class AppUnavailableError(BaseHTTPException): - error_code = 'app_unavailable' + error_code = "app_unavailable" description = "App unavailable, please check your app configurations." code = 400 class NotCompletionAppError(BaseHTTPException): - error_code = 'not_completion_app' + error_code = "not_completion_app" description = "Please check if your Completion app mode matches the right API route." code = 400 class NotChatAppError(BaseHTTPException): - error_code = 'not_chat_app' + error_code = "not_chat_app" description = "Please check if your app mode matches the right API route." code = 400 class NotWorkflowAppError(BaseHTTPException): - error_code = 'not_workflow_app' + error_code = "not_workflow_app" description = "Please check if your Workflow app mode matches the right API route." code = 400 class ConversationCompletedError(BaseHTTPException): - error_code = 'conversation_completed' + error_code = "conversation_completed" description = "The conversation has ended. Please start a new conversation." code = 400 class ProviderNotInitializeError(BaseHTTPException): - error_code = 'provider_not_initialize' - description = "No valid model provider credentials found. " \ - "Please go to Settings -> Model Provider to complete your provider credentials." + error_code = "provider_not_initialize" + description = ( + "No valid model provider credentials found. " + "Please go to Settings -> Model Provider to complete your provider credentials." + ) code = 400 class ProviderQuotaExceededError(BaseHTTPException): - error_code = 'provider_quota_exceeded' - description = "Your quota for Dify Hosted OpenAI has been exhausted. " \ - "Please go to Settings -> Model Provider to complete your own provider credentials." + error_code = "provider_quota_exceeded" + description = ( + "Your quota for Dify Hosted OpenAI has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials." + ) code = 400 class ProviderModelCurrentlyNotSupportError(BaseHTTPException): - error_code = 'model_currently_not_support' + error_code = "model_currently_not_support" description = "Dify Hosted OpenAI trial currently not support the GPT-4 model." code = 400 class CompletionRequestError(BaseHTTPException): - error_code = 'completion_request_error' + error_code = "completion_request_error" description = "Completion request failed." code = 400 class AppMoreLikeThisDisabledError(BaseHTTPException): - error_code = 'app_more_like_this_disabled' + error_code = "app_more_like_this_disabled" description = "The 'More like this' feature is disabled. Please refresh your page." code = 403 class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException): - error_code = 'app_suggested_questions_after_answer_disabled' + error_code = "app_suggested_questions_after_answer_disabled" description = "The 'Suggested Questions After Answer' feature is disabled. Please refresh your page." code = 403 class NoAudioUploadedError(BaseHTTPException): - error_code = 'no_audio_uploaded' + error_code = "no_audio_uploaded" description = "Please upload your audio." code = 400 class AudioTooLargeError(BaseHTTPException): - error_code = 'audio_too_large' + error_code = "audio_too_large" description = "Audio size exceeded. {message}" code = 413 class UnsupportedAudioTypeError(BaseHTTPException): - error_code = 'unsupported_audio_type' + error_code = "unsupported_audio_type" description = "Audio type not allowed." code = 415 class ProviderNotSupportSpeechToTextError(BaseHTTPException): - error_code = 'provider_not_support_speech_to_text' + error_code = "provider_not_support_speech_to_text" description = "Provider not support speech to text." code = 400 class NoFileUploadedError(BaseHTTPException): - error_code = 'no_file_uploaded' + error_code = "no_file_uploaded" description = "Please upload your file." code = 400 class TooManyFilesError(BaseHTTPException): - error_code = 'too_many_files' + error_code = "too_many_files" description = "Only one file is allowed." code = 400 class FileTooLargeError(BaseHTTPException): - error_code = 'file_too_large' + error_code = "file_too_large" description = "File size exceeded. {message}" code = 413 class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' + error_code = "unsupported_file_type" description = "File type not allowed." code = 415 class WebSSOAuthRequiredError(BaseHTTPException): - error_code = 'web_sso_auth_required' + error_code = "web_sso_auth_required" description = "Web SSO authentication required." code = 401 + + +class InvokeRateLimitError(BaseHTTPException): + """Raised when the Invoke returns rate limit error.""" + + error_code = "rate_limit_error" + description = "Rate Limit Error" + code = 429 diff --git a/api/controllers/web/feature.py b/api/controllers/web/feature.py index 69b38faaf6..0563ed2238 100644 --- a/api/controllers/web/feature.py +++ b/api/controllers/web/feature.py @@ -9,4 +9,4 @@ class SystemFeatureApi(Resource): return FeatureService.get_system_features().model_dump() -api.add_resource(SystemFeatureApi, '/system-features') +api.add_resource(SystemFeatureApi, "/system-features") diff --git a/api/controllers/web/file.py b/api/controllers/web/file.py index ca83f6037a..253b1d511c 100644 --- a/api/controllers/web/file.py +++ b/api/controllers/web/file.py @@ -10,14 +10,13 @@ from services.file_service import FileService class FileApi(WebApiResource): - @marshal_with(file_fields) def post(self, app_model, end_user): # get file from request - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: @@ -32,4 +31,4 @@ class FileApi(WebApiResource): return upload_file, 201 -api.add_resource(FileApi, '/files/upload') +api.add_resource(FileApi, "/files/upload") diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 865d2270ad..56aaaa930a 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -33,48 +33,46 @@ from services.message_service import MessageService class MessageListApi(WebApiResource): - feedback_fields = { - 'rating': fields.String - } + feedback_fields = {"rating": fields.String} retriever_resource_fields = { - 'id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'dataset_id': fields.String, - 'dataset_name': fields.String, - 'document_id': fields.String, - 'document_name': fields.String, - 'data_source_type': fields.String, - 'segment_id': fields.String, - 'score': fields.Float, - 'hit_count': fields.Integer, - 'word_count': fields.Integer, - 'segment_position': fields.Integer, - 'index_node_hash': fields.String, - 'content': fields.String, - 'created_at': TimestampField + "id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "dataset_id": fields.String, + "dataset_name": fields.String, + "document_id": fields.String, + "document_name": fields.String, + "data_source_type": fields.String, + "segment_id": fields.String, + "score": fields.Float, + "hit_count": fields.Integer, + "word_count": fields.Integer, + "segment_position": fields.Integer, + "index_node_hash": fields.String, + "content": fields.String, + "created_at": TimestampField, } message_fields = { - 'id': fields.String, - 'conversation_id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'answer': fields.String(attribute='re_sign_file_url_answer'), - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), - 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), - 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), - 'created_at': TimestampField, - 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)), - 'status': fields.String, - 'error': fields.String, + "id": fields.String, + "conversation_id": fields.String, + "inputs": fields.Raw, + "query": fields.String, + "answer": fields.String(attribute="re_sign_file_url_answer"), + "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), + "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), + "created_at": TimestampField, + "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), + "status": fields.String, + "error": fields.String, } message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(message_fields)), } @marshal_with(message_infinite_scroll_pagination_fields) @@ -84,14 +82,15 @@ class MessageListApi(WebApiResource): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') - parser.add_argument('first_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") + parser.add_argument("first_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() try: - return MessageService.pagination_by_first_id(app_model, end_user, - args['conversation_id'], args['first_id'], args['limit']) + return MessageService.pagination_by_first_id( + app_model, end_user, args["conversation_id"], args["first_id"], args["limit"] + ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.message.FirstMessageNotExistsError: @@ -103,29 +102,31 @@ class MessageFeedbackApi(WebApiResource): message_id = str(message_id) parser = reqparse.RequestParser() - parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') + parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args['rating']) + MessageService.create_feedback(app_model, message_id, end_user, args["rating"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") - return {'result': 'success'} + return {"result": "success"} class MessageMoreLikeThisApi(WebApiResource): def get(self, app_model, end_user, message_id): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() message_id = str(message_id) parser = reqparse.RequestParser() - parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args') + parser.add_argument( + "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args" + ) args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' + streaming = args["response_mode"] == "streaming" try: response = AppGenerateService.generate_more_like_this( @@ -133,7 +134,7 @@ class MessageMoreLikeThisApi(WebApiResource): user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP, - streaming=streaming + streaming=streaming, ) return helper.compact_generate_response(response) @@ -166,10 +167,7 @@ class MessageSuggestedQuestionApi(WebApiResource): try: questions = MessageService.get_suggested_questions_after_answer( - app_model=app_model, - user=end_user, - message_id=message_id, - invoke_from=InvokeFrom.WEB_APP + app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP ) except MessageNotExistsError: raise NotFound("Message not found") @@ -189,10 +187,10 @@ class MessageSuggestedQuestionApi(WebApiResource): logging.exception("internal server error.") raise InternalServerError() - return {'data': questions} + return {"data": questions} -api.add_resource(MessageListApi, '/messages') -api.add_resource(MessageFeedbackApi, '/messages//feedbacks') -api.add_resource(MessageMoreLikeThisApi, '/messages//more-like-this') -api.add_resource(MessageSuggestedQuestionApi, '/messages//suggested-questions') +api.add_resource(MessageListApi, "/messages") +api.add_resource(MessageFeedbackApi, "/messages//feedbacks") +api.add_resource(MessageMoreLikeThisApi, "/messages//more-like-this") +api.add_resource(MessageSuggestedQuestionApi, "/messages//suggested-questions") diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index ccc8683a79..a01ffd8612 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -9,37 +9,37 @@ from controllers.web.error import WebSSOAuthRequiredError from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site +from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService class PassportResource(Resource): """Base resource for passport.""" + def get(self): - system_features = FeatureService.get_system_features() - if system_features.sso_enforced_for_web: - raise WebSSOAuthRequiredError() - - app_code = request.headers.get('X-App-Code') + app_code = request.headers.get("X-App-Code") if app_code is None: - raise Unauthorized('X-App-Code header is missing.') + raise Unauthorized("X-App-Code header is missing.") + + if system_features.sso_enforced_for_web: + app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) + if app_web_sso_enabled: + raise WebSSOAuthRequiredError() # get site from db and check if it is normal - site = db.session.query(Site).filter( - Site.code == app_code, - Site.status == 'normal' - ).first() + site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() if not site: raise NotFound() # get app from db and check if it is normal and enable_site app_model = db.session.query(App).filter(App.id == site.app_id).first() - if not app_model or app_model.status != 'normal' or not app_model.enable_site: + if not app_model or app_model.status != "normal" or not app_model.enable_site: raise NotFound() end_user = EndUser( tenant_id=app_model.tenant_id, app_id=app_model.id, - type='browser', + type="browser", is_anonymous=True, session_id=generate_session_id(), ) @@ -49,20 +49,20 @@ class PassportResource(Resource): payload = { "iss": site.app_id, - 'sub': 'Web API Passport', - 'app_id': site.app_id, - 'app_code': app_code, - 'end_user_id': end_user.id, + "sub": "Web API Passport", + "app_id": site.app_id, + "app_code": app_code, + "end_user_id": end_user.id, } tk = PassportService().issue(payload) return { - 'access_token': tk, + "access_token": tk, } -api.add_resource(PassportResource, '/passport') +api.add_resource(PassportResource, "/passport") def generate_session_id(): @@ -71,7 +71,6 @@ def generate_session_id(): """ while True: session_id = str(uuid.uuid4()) - existing_count = db.session.query(EndUser) \ - .filter(EndUser.session_id == session_id).count() + existing_count = db.session.query(EndUser).filter(EndUser.session_id == session_id).count() if existing_count == 0: return session_id diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index e17869ffdb..8253f5fc57 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -10,67 +10,65 @@ from libs.helper import TimestampField, uuid_value from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService -feedback_fields = { - 'rating': fields.String -} +feedback_fields = {"rating": fields.String} message_fields = { - 'id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'answer': fields.String, - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), - 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), - 'created_at': TimestampField + "id": fields.String, + "inputs": fields.Raw, + "query": fields.String, + "answer": fields.String, + "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), + "created_at": TimestampField, } class SavedMessageListApi(WebApiResource): saved_message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(message_fields)), } @marshal_with(saved_message_infinite_scroll_pagination_fields) def get(self, app_model, end_user): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - return SavedMessageService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit']) + return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"]) def post(self, app_model, end_user): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('message_id', type=uuid_value, required=True, location='json') + parser.add_argument("message_id", type=uuid_value, required=True, location="json") args = parser.parse_args() try: - SavedMessageService.save(app_model, end_user, args['message_id']) + SavedMessageService.save(app_model, end_user, args["message_id"]) except MessageNotExistsError: raise NotFound("Message Not Exists.") - return {'result': 'success'} + return {"result": "success"} class SavedMessageApi(WebApiResource): def delete(self, app_model, end_user, message_id): message_id = str(message_id) - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() SavedMessageService.delete(app_model, end_user, message_id) - return {'result': 'success'} + return {"result": "success"} -api.add_resource(SavedMessageListApi, '/saved-messages') -api.add_resource(SavedMessageApi, '/saved-messages/') +api.add_resource(SavedMessageListApi, "/saved-messages") +api.add_resource(SavedMessageApi, "/saved-messages/") diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index 0f4a7cabe5..0564b15ea3 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -1,4 +1,3 @@ - from flask_restful import fields, marshal_with from werkzeug.exceptions import Forbidden @@ -16,41 +15,42 @@ class AppSiteApi(WebApiResource): """Resource for app sites.""" model_config_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw(attribute='suggested_questions_list'), - 'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'), - 'more_like_this': fields.Raw(attribute='more_like_this_dict'), - 'model': fields.Raw(attribute='model_dict'), - 'user_input_form': fields.Raw(attribute='user_input_form_list'), - 'pre_prompt': fields.String, + "opening_statement": fields.String, + "suggested_questions": fields.Raw(attribute="suggested_questions_list"), + "suggested_questions_after_answer": fields.Raw(attribute="suggested_questions_after_answer_dict"), + "more_like_this": fields.Raw(attribute="more_like_this_dict"), + "model": fields.Raw(attribute="model_dict"), + "user_input_form": fields.Raw(attribute="user_input_form_list"), + "pre_prompt": fields.String, } site_fields = { - 'title': fields.String, - 'chat_color_theme': fields.String, - 'chat_color_theme_inverted': fields.Boolean, - 'icon_type': fields.String, - 'icon': fields.String, - 'icon_background': fields.String, - 'icon_url': AppIconUrlField, - 'description': fields.String, - 'copyright': fields.String, - 'privacy_policy': fields.String, - 'custom_disclaimer': fields.String, - 'default_language': fields.String, - 'prompt_public': fields.Boolean, - 'show_workflow_steps': fields.Boolean, + "title": fields.String, + "chat_color_theme": fields.String, + "chat_color_theme_inverted": fields.Boolean, + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, + "description": fields.String, + "copyright": fields.String, + "privacy_policy": fields.String, + "custom_disclaimer": fields.String, + "default_language": fields.String, + "prompt_public": fields.Boolean, + "show_workflow_steps": fields.Boolean, + "use_icon_as_answer_icon": fields.Boolean, } app_fields = { - 'app_id': fields.String, - 'end_user_id': fields.String, - 'enable_site': fields.Boolean, - 'site': fields.Nested(site_fields), - 'model_config': fields.Nested(model_config_fields, allow_null=True), - 'plan': fields.String, - 'can_replace_logo': fields.Boolean, - 'custom_config': fields.Raw(attribute='custom_config'), + "app_id": fields.String, + "end_user_id": fields.String, + "enable_site": fields.Boolean, + "site": fields.Nested(site_fields), + "model_config": fields.Nested(model_config_fields, allow_null=True), + "plan": fields.String, + "can_replace_logo": fields.Boolean, + "custom_config": fields.Raw(attribute="custom_config"), } @marshal_with(app_fields) @@ -70,7 +70,7 @@ class AppSiteApi(WebApiResource): return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo) -api.add_resource(AppSiteApi, '/site') +api.add_resource(AppSiteApi, "/site") class AppSiteInfo: @@ -88,9 +88,13 @@ class AppSiteInfo: if can_replace_logo: base_url = dify_config.FILES_URL - remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False) - replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None + remove_webapp_brand = tenant.custom_config_dict.get("remove_webapp_brand", False) + replace_webapp_logo = ( + f"{base_url}/files/workspaces/{tenant.id}/webapp-logo" + if tenant.custom_config_dict.get("replace_webapp_logo") + else None + ) self.custom_config = { - 'remove_webapp_brand': remove_webapp_brand, - 'replace_webapp_logo': replace_webapp_logo, + "remove_webapp_brand": remove_webapp_brand, + "replace_webapp_logo": replace_webapp_logo, } diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 77c468e417..55b0c3e2ab 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -33,17 +33,13 @@ class WorkflowRunApi(WebApiResource): raise NotWorkflowAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') - parser.add_argument('files', type=list, required=False, location='json') + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("files", type=list, required=False, location="json") args = parser.parse_args() try: response = AppGenerateService.generate( - app_model=app_model, - user=end_user, - args=args, - invoke_from=InvokeFrom.WEB_APP, - streaming=True + app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.WEB_APP, streaming=True ) return helper.compact_generate_response(response) @@ -73,10 +69,8 @@ class WorkflowTaskStopApi(WebApiResource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) - return { - "result": "success" - } + return {"result": "success"} -api.add_resource(WorkflowRunApi, '/workflows/run') -api.add_resource(WorkflowTaskStopApi, '/workflows/tasks//stop') +api.add_resource(WorkflowRunApi, "/workflows/run") +api.add_resource(WorkflowTaskStopApi, "/workflows/tasks//stop") diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index f5ab49d7e1..93dc691d62 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -8,6 +8,7 @@ from controllers.web.error import WebSSOAuthRequiredError from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site +from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService @@ -18,7 +19,9 @@ def validate_jwt_token(view=None): app_model, end_user = decode_jwt_token() return view(app_model, end_user, *args, **kwargs) + return decorated + if view: return decorator(view) return decorator @@ -26,56 +29,62 @@ def validate_jwt_token(view=None): def decode_jwt_token(): system_features = FeatureService.get_system_features() - + app_code = request.headers.get("X-App-Code") try: - auth_header = request.headers.get('Authorization') + auth_header = request.headers.get("Authorization") if auth_header is None: - raise Unauthorized('Authorization header is missing.') + raise Unauthorized("Authorization header is missing.") - if ' ' not in auth_header: - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if " " not in auth_header: + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") auth_scheme, tk = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() - if auth_scheme != 'bearer': - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if auth_scheme != "bearer": + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") decoded = PassportService().verify(tk) - app_code = decoded.get('app_code') - app_model = db.session.query(App).filter(App.id == decoded['app_id']).first() + app_code = decoded.get("app_code") + app_model = db.session.query(App).filter(App.id == decoded["app_id"]).first() site = db.session.query(Site).filter(Site.code == app_code).first() if not app_model: raise NotFound() if not app_code or not site: - raise BadRequest('Site URL is no longer valid.') + raise BadRequest("Site URL is no longer valid.") if app_model.enable_site is False: - raise BadRequest('Site is disabled.') - end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first() + raise BadRequest("Site is disabled.") + end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first() if not end_user: raise NotFound() - _validate_web_sso_token(decoded, system_features) + _validate_web_sso_token(decoded, system_features, app_code) return app_model, end_user except Unauthorized as e: if system_features.sso_enforced_for_web: - raise WebSSOAuthRequiredError() + app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) + if app_web_sso_enabled: + raise WebSSOAuthRequiredError() raise Unauthorized(e.description) -def _validate_web_sso_token(decoded, system_features): +def _validate_web_sso_token(decoded, system_features, app_code): + app_web_sso_enabled = False + # Check if SSO is enforced for web, and if the token source is not SSO, raise an error and redirect to SSO login if system_features.sso_enforced_for_web: - source = decoded.get('token_source') - if not source or source != 'sso': - raise WebSSOAuthRequiredError() + app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) + if app_web_sso_enabled: + source = decoded.get("token_source") + if not source or source != "sso": + raise WebSSOAuthRequiredError() # Check if SSO is not enforced for web, and if the token source is SSO, raise an error and redirect to normal passport login - if not system_features.sso_enforced_for_web: - source = decoded.get('token_source') - if source and source == 'sso': - raise Unauthorized('sso token expired.') + if not system_features.sso_enforced_for_web or not app_web_sso_enabled: + source = decoded.get("token_source") + if source and source == "sso": + raise Unauthorized("sso token expired.") class WebApiResource(Resource): diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index f4e6675bd4..1a621d2090 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -93,7 +93,7 @@ class DatasetConfigManager: reranking_model=dataset_configs.get('reranking_model'), weights=dataset_configs.get('weights'), reranking_enabled=dataset_configs.get('reranking_enabled', True), - rerank_mode=dataset_configs["reranking_mode"], + rerank_mode=dataset_configs.get('reranking_mode', 'reranking_model'), ) ) diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py index 3eb006b46e..15fa4d99fd 100644 --- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -1,6 +1,6 @@ import re -from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity +from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType from core.external_data_tool.factory import ExternalDataToolFactory @@ -13,7 +13,7 @@ class BasicVariablesConfigManager: :param config: model config args """ external_data_variables = [] - variables = [] + variable_entities = [] # old external_data_tools external_data_tools = config.get('external_data_tools', []) @@ -30,50 +30,41 @@ class BasicVariablesConfigManager: ) # variables and external_data_tools - for variable in config.get('user_input_form', []): - typ = list(variable.keys())[0] - if typ == 'external_data_tool': - val = variable[typ] - if 'config' not in val: + for variables in config.get('user_input_form', []): + variable_type = list(variables.keys())[0] + if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL: + variable = variables[variable_type] + if 'config' not in variable: continue external_data_variables.append( ExternalDataVariableEntity( - variable=val['variable'], - type=val['type'], - config=val['config'] + variable=variable['variable'], + type=variable['type'], + config=variable['config'] ) ) - elif typ in [ - VariableEntity.Type.TEXT_INPUT.value, - VariableEntity.Type.PARAGRAPH.value, - VariableEntity.Type.NUMBER.value, + elif variable_type in [ + VariableEntityType.TEXT_INPUT, + VariableEntityType.PARAGRAPH, + VariableEntityType.NUMBER, + VariableEntityType.SELECT, ]: - variables.append( + variable = variables[variable_type] + variable_entities.append( VariableEntity( - type=VariableEntity.Type.value_of(typ), - variable=variable[typ].get('variable'), - description=variable[typ].get('description'), - label=variable[typ].get('label'), - required=variable[typ].get('required', False), - max_length=variable[typ].get('max_length'), - default=variable[typ].get('default'), - ) - ) - elif typ == VariableEntity.Type.SELECT.value: - variables.append( - VariableEntity( - type=VariableEntity.Type.SELECT, - variable=variable[typ].get('variable'), - description=variable[typ].get('description'), - label=variable[typ].get('label'), - required=variable[typ].get('required', False), - options=variable[typ].get('options'), - default=variable[typ].get('default'), + type=variable_type, + variable=variable.get('variable'), + description=variable.get('description'), + label=variable.get('label'), + required=variable.get('required', False), + max_length=variable.get('max_length'), + options=variable.get('options'), + default=variable.get('default'), ) ) - return variables, external_data_variables + return variable_entities, external_data_variables @classmethod def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: @@ -183,4 +174,4 @@ class BasicVariablesConfigManager: config=config ) - return config, ["external_data_tools"] \ No newline at end of file + return config, ["external_data_tools"] diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 05a42a898e..bbb10d3d76 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -82,43 +82,29 @@ class PromptTemplateEntity(BaseModel): advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None +class VariableEntityType(str, Enum): + TEXT_INPUT = "text-input" + SELECT = "select" + PARAGRAPH = "paragraph" + NUMBER = "number" + EXTERNAL_DATA_TOOL = "external-data-tool" + + class VariableEntity(BaseModel): """ Variable Entity. """ - class Type(Enum): - TEXT_INPUT = 'text-input' - SELECT = 'select' - PARAGRAPH = 'paragraph' - NUMBER = 'number' - - @classmethod - def value_of(cls, value: str) -> 'VariableEntity.Type': - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f'invalid variable type value {value}') variable: str label: str description: Optional[str] = None - type: Type + type: VariableEntityType required: bool = False max_length: Optional[int] = None options: Optional[list[str]] = None default: Optional[str] = None hint: Optional[str] = None - @property - def name(self) -> str: - return self.variable - class ExternalDataVariableEntity(BaseModel): """ @@ -252,4 +238,4 @@ class WorkflowUIBasedAppConfig(AppConfig): """ Workflow UI Based App Config Entity. """ - workflow_id: str \ No newline at end of file + workflow_id: str diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 351eb05d8a..e7c9ebe097 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -4,7 +4,7 @@ import os import threading import uuid from collections.abc import Generator -from typing import Union +from typing import Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -29,7 +29,7 @@ from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from extensions.ext_database import db from models.account import Account from models.model import App, Conversation, EndUser, Message @@ -39,6 +39,26 @@ logger = logging.getLogger(__name__) class AdvancedChatAppGenerator(MessageBasedAppGenerator): + @overload + def generate( + self, app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[True] = True, + ) -> Generator[str, None, None]: ... + + @overload + def generate( + self, app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[False] = False, + ) -> dict: ... + def generate( self, app_model: App, workflow: Workflow, @@ -46,7 +66,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): args: dict, invoke_from: InvokeFrom, stream: bool = True, - ) -> Union[dict, Generator[dict, None, None]]: + ): """ Generate App response. @@ -73,8 +93,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # get conversation conversation = None - if args.get('conversation_id'): - conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + conversation_id = args.get('conversation_id') + if conversation_id: + conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user) # parse files files = args['files'] if args.get('files') else [] @@ -133,8 +154,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): node_id: str, user: Account, args: dict, - stream: bool = True) \ - -> Union[dict, Generator[dict, None, None]]: + stream: bool = True): """ Generate App response. @@ -157,8 +177,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # get conversation conversation = None - if args.get('conversation_id'): - conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + conversation_id = args.get('conversation_id') + if conversation_id: + conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user) # convert to app config app_config = AdvancedChatAppConfigManager.get_app_config( @@ -200,8 +221,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): invoke_from: InvokeFrom, application_generate_entity: AdvancedChatAppGenerateEntity, conversation: Conversation | None = None, - stream: bool = True) \ - -> Union[dict, Generator[dict, None, None]]: + stream: bool = True): is_first_conversation = False if not conversation: is_first_conversation = True @@ -270,11 +290,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # Create a variable pool. system_inputs = { - SystemVariable.QUERY: query, - SystemVariable.FILES: files, - SystemVariable.CONVERSATION_ID: conversation_id, - SystemVariable.USER_ID: user_id, - SystemVariable.DIALOGUE_COUNT: conversation_dialogue_count, + SystemVariableKey.QUERY: query, + SystemVariableKey.FILES: files, + SystemVariableKey.CONVERSATION_ID: conversation_id, + SystemVariableKey.USER_ID: user_id, + SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count, } variable_pool = VariablePool( system_variables=system_inputs, @@ -362,7 +382,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + if os.environ.get("DEBUG", "false").lower() == 'true': logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index ac51a4e840..2b3596ded2 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -49,7 +49,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.entities.node_entities import NodeType -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk from events.message_event import message_was_created @@ -74,7 +74,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc _workflow: Workflow _user: Union[Account, EndUser] # Deprecated - _workflow_system_variables: dict[SystemVariable, Any] + _workflow_system_variables: dict[SystemVariableKey, Any] _iteration_nested_relations: dict[str, list[str]] def __init__( @@ -108,10 +108,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._message = message # Deprecated self._workflow_system_variables = { - SystemVariable.QUERY: message.query, - SystemVariable.FILES: application_generate_entity.files, - SystemVariable.CONVERSATION_ID: conversation.id, - SystemVariable.USER_ID: user_id, + SystemVariableKey.QUERY: message.query, + SystemVariableKey.FILES: application_generate_entity.files, + SystemVariableKey.CONVERSATION_ID: conversation.id, + SystemVariableKey.USER_ID: user_id, } self._task_state = AdvancedChatTaskState( diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 53780bdfb0..daf37f4a50 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -3,7 +3,7 @@ import os import threading import uuid from collections.abc import Generator -from typing import Any, Union +from typing import Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -28,6 +28,24 @@ logger = logging.getLogger(__name__) class AgentChatAppGenerator(MessageBasedAppGenerator): + @overload + def generate( + self, app_model: App, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[True] = True, + ) -> Generator[dict, None, None]: ... + + @overload + def generate( + self, app_model: App, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[False] = False, + ) -> dict: ... + def generate(self, app_model: App, user: Union[Account, EndUser], args: Any, diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 6f48aa2363..9e331dff4d 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -1,7 +1,7 @@ from collections.abc import Mapping from typing import Any, Optional -from core.app.app_config.entities import AppConfig, VariableEntity +from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType class BaseAppGenerator: @@ -9,29 +9,29 @@ class BaseAppGenerator: user_inputs = user_inputs or {} # Filter input variables from form configuration, handle required fields, default values, and option values variables = app_config.variables - filtered_inputs = {var.name: self._validate_input(inputs=user_inputs, var=var) for var in variables} + filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables} filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()} return filtered_inputs def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity): - user_input_value = inputs.get(var.name) + user_input_value = inputs.get(var.variable) if var.required and not user_input_value: - raise ValueError(f'{var.name} is required in input form') + raise ValueError(f'{var.variable} is required in input form') if not var.required and not user_input_value: # TODO: should we return None here if the default value is None? return var.default or '' if ( var.type in ( - VariableEntity.Type.TEXT_INPUT, - VariableEntity.Type.SELECT, - VariableEntity.Type.PARAGRAPH, + VariableEntityType.TEXT_INPUT, + VariableEntityType.SELECT, + VariableEntityType.PARAGRAPH, ) and user_input_value and not isinstance(user_input_value, str) ): - raise ValueError(f"(type '{var.type}') {var.name} in input form must be a string") - if var.type == VariableEntity.Type.NUMBER and isinstance(user_input_value, str): + raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string") + if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str): # may raise ValueError if user_input_value is not a valid number try: if '.' in user_input_value: @@ -39,14 +39,14 @@ class BaseAppGenerator: else: return int(user_input_value) except ValueError: - raise ValueError(f"{var.name} in input form must be a valid number") - if var.type == VariableEntity.Type.SELECT: + raise ValueError(f"{var.variable} in input form must be a valid number") + if var.type == VariableEntityType.SELECT: options = var.options or [] if user_input_value not in options: - raise ValueError(f'{var.name} in input form must be one of the following: {options}') - elif var.type in (VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH): + raise ValueError(f'{var.variable} in input form must be one of the following: {options}') + elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH): if var.max_length and user_input_value and len(user_input_value) > var.max_length: - raise ValueError(f'{var.name} in input form must be less than {var.max_length} characters') + raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters') return user_input_value diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 5b896e2845..ab15928b74 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -3,7 +3,7 @@ import os import threading import uuid from collections.abc import Generator -from typing import Any, Union +from typing import Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -28,13 +28,31 @@ logger = logging.getLogger(__name__) class ChatAppGenerator(MessageBasedAppGenerator): + @overload + def generate( + self, app_model: App, + user: Union[Account, EndUser], + args: Any, + invoke_from: InvokeFrom, + stream: Literal[True] = True, + ) -> Generator[str, None, None]: ... + + @overload + def generate( + self, app_model: App, + user: Union[Account, EndUser], + args: Any, + invoke_from: InvokeFrom, + stream: Literal[False] = False, + ) -> dict: ... + def generate( self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True, - ) -> Union[dict, Generator[dict, None, None]]: + ) -> Union[dict, Generator[str, None, None]]: """ Generate App response. diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index c4e1caf65a..c0b13b40fd 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -3,7 +3,7 @@ import os import threading import uuid from collections.abc import Generator -from typing import Any, Union +from typing import Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -30,12 +30,30 @@ logger = logging.getLogger(__name__) class CompletionAppGenerator(MessageBasedAppGenerator): + @overload + def generate( + self, app_model: App, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[True] = True, + ) -> Generator[str, None, None]: ... + + @overload + def generate( + self, app_model: App, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[False] = False, + ) -> dict: ... + def generate(self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True) \ - -> Union[dict, Generator[dict, None, None]]: + -> Union[dict, Generator[str, None, None]]: """ Generate App response. @@ -203,7 +221,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): user: Union[Account, EndUser], invoke_from: InvokeFrom, stream: bool = True) \ - -> Union[dict, Generator[dict, None, None]]: + -> Union[dict, Generator[str, None, None]]: """ Generate App response. diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 49ef7b7b40..fceed95b91 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -1,6 +1,7 @@ import json import logging from collections.abc import Generator +from datetime import datetime, timezone from typing import Optional, Union from sqlalchemy import and_ @@ -36,17 +37,17 @@ logger = logging.getLogger(__name__) class MessageBasedAppGenerator(BaseAppGenerator): def _handle_response( - self, application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity, - AdvancedChatAppGenerateEntity - ], - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, - user: Union[Account, EndUser], - stream: bool = False, + self, application_generate_entity: Union[ + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + AgentChatAppGenerateEntity, + AdvancedChatAppGenerateEntity + ], + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + user: Union[Account, EndUser], + stream: bool = False, ) -> Union[ ChatbotAppBlockingResponse, CompletionAppBlockingResponse, @@ -193,6 +194,9 @@ class MessageBasedAppGenerator(BaseAppGenerator): db.session.add(conversation) db.session.commit() db.session.refresh(conversation) + else: + conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + db.session.commit() message = Message( app_id=app_config.app_id, diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index df40aec154..26bb6c0f4f 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -4,7 +4,7 @@ import os import threading import uuid from collections.abc import Generator -from typing import Union +from typing import Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -32,6 +32,26 @@ logger = logging.getLogger(__name__) class WorkflowAppGenerator(BaseAppGenerator): + @overload + def generate( + self, app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[True] = True, + ) -> Generator[str, None, None]: ... + + @overload + def generate( + self, app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[False] = False, + ) -> dict: ... + def generate( self, app_model: App, workflow: Workflow, @@ -107,7 +127,7 @@ class WorkflowAppGenerator(BaseAppGenerator): application_generate_entity: WorkflowAppGenerateEntity, invoke_from: InvokeFrom, stream: bool = True, - ) -> Union[dict, Generator[dict, None, None]]: + ) -> Union[dict, Generator[str, None, None]]: """ Generate App response. diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 994919391e..e388d0184b 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -12,7 +12,7 @@ from core.app.entities.app_invoke_entities import ( ) from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import UserFrom from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db @@ -67,8 +67,8 @@ class WorkflowAppRunner: # Create a variable pool. system_inputs = { - SystemVariable.FILES: files, - SystemVariable.USER_ID: user_id, + SystemVariableKey.FILES: files, + SystemVariableKey.USER_ID: user_id, } variable_pool = VariablePool( system_variables=system_inputs, diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 5022eb0438..de8542d7b9 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -43,7 +43,7 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.entities.node_entities import NodeType -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.end.end_node import EndNode from extensions.ext_database import db from models.account import Account @@ -67,7 +67,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa _user: Union[Account, EndUser] _task_state: WorkflowTaskState _application_generate_entity: WorkflowAppGenerateEntity - _workflow_system_variables: dict[SystemVariable, Any] + _workflow_system_variables: dict[SystemVariableKey, Any] _iteration_nested_relations: dict[str, list[str]] def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, @@ -92,8 +92,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa self._workflow = workflow self._workflow_system_variables = { - SystemVariable.FILES: application_generate_entity.files, - SystemVariable.USER_ID: user_id + SystemVariableKey.FILES: application_generate_entity.files, + SystemVariableKey.USER_ID: user_id } self._task_state = WorkflowTaskState( diff --git a/api/core/app/task_pipeline/workflow_cycle_state_manager.py b/api/core/app/task_pipeline/workflow_cycle_state_manager.py index 8baa8ba09e..bd98c82720 100644 --- a/api/core/app/task_pipeline/workflow_cycle_state_manager.py +++ b/api/core/app/task_pipeline/workflow_cycle_state_manager.py @@ -2,7 +2,7 @@ from typing import Any, Union from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from models.account import Account from models.model import EndUser from models.workflow import Workflow @@ -13,4 +13,4 @@ class WorkflowCycleStateManager: _workflow: Workflow _user: Union[Account, EndUser] _task_state: Union[AdvancedChatTaskState, WorkflowTaskState] - _workflow_system_variables: dict[SystemVariable, Any] + _workflow_system_variables: dict[SystemVariableKey, Any] diff --git a/api/core/errors/error.py b/api/core/errors/error.py index 859a747c12..53323a2eeb 100644 --- a/api/core/errors/error.py +++ b/api/core/errors/error.py @@ -43,3 +43,8 @@ class ModelCurrentlyNotSupportError(Exception): Custom exception raised when the model not support """ description = "Model Currently Not Support" + + +class InvokeRateLimitError(Exception): + """Raised when the Invoke returns rate limit error.""" + description = "Rate Limit Error" diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index 0296126d8b..8d73aa2b8b 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -65,7 +65,7 @@ class Extensible: if os.path.exists(builtin_file_path): with open(builtin_file_path, encoding='utf-8') as f: position = int(f.read().strip()) - position_map[extension_name] = position + position_map[extension_name] = position if (extension_name + '.py') not in file_names: logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index afb2bbbbf3..4662ebb47a 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -1,15 +1,13 @@ import logging -import time from enum import Enum from threading import Lock -from typing import Literal, Optional +from typing import Optional -from httpx import get, post +from httpx import Timeout, post from pydantic import BaseModel from yarl import URL from configs import dify_config -from core.helper.code_executor.entities import CodeDependency from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer @@ -17,12 +15,6 @@ from core.helper.code_executor.template_transformer import TemplateTransformer logger = logging.getLogger(__name__) -# Code Executor -CODE_EXECUTION_ENDPOINT = dify_config.CODE_EXECUTION_ENDPOINT -CODE_EXECUTION_API_KEY = dify_config.CODE_EXECUTION_API_KEY - -CODE_EXECUTION_TIMEOUT = (10, 60) - class CodeExecutionException(Exception): pass @@ -66,18 +58,17 @@ class CodeExecutor: def execute_code(cls, language: CodeLanguage, preload: str, - code: str, - dependencies: Optional[list[CodeDependency]] = None) -> str: + code: str) -> str: """ Execute code :param language: code language :param code: code :return: """ - url = URL(CODE_EXECUTION_ENDPOINT) / 'v1' / 'sandbox' / 'run' + url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / 'v1' / 'sandbox' / 'run' headers = { - 'X-Api-Key': CODE_EXECUTION_API_KEY + 'X-Api-Key': dify_config.CODE_EXECUTION_API_KEY } data = { @@ -87,11 +78,13 @@ class CodeExecutor: 'enable_network': True } - if dependencies: - data['dependencies'] = [dependency.model_dump() for dependency in dependencies] - try: - response = post(str(url), json=data, headers=headers, timeout=CODE_EXECUTION_TIMEOUT) + response = post(str(url), json=data, headers=headers, + timeout=Timeout( + connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT, + read=dify_config.CODE_EXECUTION_READ_TIMEOUT, + write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT, + pool=None)) if response.status_code == 503: raise CodeExecutionException('Code execution service is unavailable') elif response.status_code != 200: @@ -102,7 +95,7 @@ class CodeExecutor: raise CodeExecutionException('Failed to execute code, which is likely a network issue,' ' please check if the sandbox service is running.' f' ( Error: {str(e)} )') - + try: response = response.json() except: @@ -110,16 +103,16 @@ class CodeExecutor: if (code := response.get('code')) != 0: raise CodeExecutionException(f"Got error code: {code}. Got error msg: {response.get('message')}") - + response = CodeExecutionResponse(**response) - + if response.data.error: raise CodeExecutionException(response.data.error) - - return response.data.stdout + + return response.data.stdout or '' @classmethod - def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict: + def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict) -> dict: """ Execute code :param language: code language @@ -131,67 +124,11 @@ class CodeExecutor: if not template_transformer: raise CodeExecutionException(f'Unsupported language {language}') - runner, preload, dependencies = template_transformer.transform_caller(code, inputs, dependencies) + runner, preload = template_transformer.transform_caller(code, inputs) try: - response = cls.execute_code(language, preload, runner, dependencies) + response = cls.execute_code(language, preload, runner) except CodeExecutionException as e: raise e return template_transformer.transform_response(response) - - @classmethod - def list_dependencies(cls, language: str) -> list[CodeDependency]: - if language not in cls.supported_dependencies_languages: - return [] - - with cls.dependencies_cache_lock: - if language in cls.dependencies_cache: - # check expiration - dependencies = cls.dependencies_cache[language] - if dependencies['expiration'] > time.time(): - return dependencies['data'] - # remove expired cache - del cls.dependencies_cache[language] - - dependencies = cls._get_dependencies(language) - with cls.dependencies_cache_lock: - cls.dependencies_cache[language] = { - 'data': dependencies, - 'expiration': time.time() + 60 - } - - return dependencies - - @classmethod - def _get_dependencies(cls, language: Literal['python3']) -> list[CodeDependency]: - """ - List dependencies - """ - url = URL(CODE_EXECUTION_ENDPOINT) / 'v1' / 'sandbox' / 'dependencies' - - headers = { - 'X-Api-Key': CODE_EXECUTION_API_KEY - } - - running_language = cls.code_language_to_running_language.get(language) - if isinstance(running_language, Enum): - running_language = running_language.value - - data = { - 'language': running_language, - } - - try: - response = get(str(url), params=data, headers=headers, timeout=CODE_EXECUTION_TIMEOUT) - if response.status_code != 200: - raise Exception(f'Failed to list dependencies, got status code {response.status_code}, please check if the sandbox service is running') - response = response.json() - dependencies = response.get('data', {}).get('dependencies', []) - return [ - CodeDependency(**dependency) for dependency in dependencies - if dependency.get('name') not in Python3TemplateTransformer.get_standard_packages() - ] - except Exception as e: - logger.exception(f'Failed to list dependencies: {e}') - return [] \ No newline at end of file diff --git a/api/core/helper/code_executor/code_node_provider.py b/api/core/helper/code_executor/code_node_provider.py index 761c0e2b25..3f099b7ac5 100644 --- a/api/core/helper/code_executor/code_node_provider.py +++ b/api/core/helper/code_executor/code_node_provider.py @@ -2,8 +2,6 @@ from abc import abstractmethod from pydantic import BaseModel -from core.helper.code_executor.code_executor import CodeExecutor - class CodeNodeProvider(BaseModel): @staticmethod @@ -23,10 +21,6 @@ class CodeNodeProvider(BaseModel): """ pass - @classmethod - def get_default_available_packages(cls) -> list[dict]: - return [p.model_dump() for p in CodeExecutor.list_dependencies(cls.get_language())] - @classmethod def get_default_config(cls) -> dict: return { @@ -50,6 +44,5 @@ class CodeNodeProvider(BaseModel): "children": None } } - }, - "available_dependencies": cls.get_default_available_packages(), + } } diff --git a/api/core/helper/code_executor/entities.py b/api/core/helper/code_executor/entities.py deleted file mode 100644 index cc10288521..0000000000 --- a/api/core/helper/code_executor/entities.py +++ /dev/null @@ -1,6 +0,0 @@ -from pydantic import BaseModel - - -class CodeDependency(BaseModel): - name: str - version: str diff --git a/api/core/helper/code_executor/jinja2/jinja2_formatter.py b/api/core/helper/code_executor/jinja2/jinja2_formatter.py index 63f48a56e2..f1e5da584c 100644 --- a/api/core/helper/code_executor/jinja2/jinja2_formatter.py +++ b/api/core/helper/code_executor/jinja2/jinja2_formatter.py @@ -3,7 +3,7 @@ from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage class Jinja2Formatter: @classmethod - def format(cls, template: str, inputs: str) -> str: + def format(cls, template: str, inputs: dict) -> str: """ Format template :param template: template diff --git a/api/core/helper/code_executor/jinja2/jinja2_transformer.py b/api/core/helper/code_executor/jinja2/jinja2_transformer.py index a8f8095d52..b8cb29600e 100644 --- a/api/core/helper/code_executor/jinja2/jinja2_transformer.py +++ b/api/core/helper/code_executor/jinja2/jinja2_transformer.py @@ -1,14 +1,9 @@ from textwrap import dedent -from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer from core.helper.code_executor.template_transformer import TemplateTransformer class Jinja2TemplateTransformer(TemplateTransformer): - @classmethod - def get_standard_packages(cls) -> set[str]: - return {'jinja2'} | Python3TemplateTransformer.get_standard_packages() - @classmethod def transform_response(cls, response: str) -> dict: """ diff --git a/api/core/helper/code_executor/python3/python3_code_provider.py b/api/core/helper/code_executor/python3/python3_code_provider.py index efcb8a9d1e..923724b49d 100644 --- a/api/core/helper/code_executor/python3/python3_code_provider.py +++ b/api/core/helper/code_executor/python3/python3_code_provider.py @@ -13,7 +13,7 @@ class Python3CodeProvider(CodeNodeProvider): def get_default_code(cls) -> str: return dedent( """ - def main(arg1: int, arg2: int) -> dict: + def main(arg1: str, arg2: str) -> dict: return { "result": arg1 + arg2, } diff --git a/api/core/helper/code_executor/python3/python3_transformer.py b/api/core/helper/code_executor/python3/python3_transformer.py index 4a5fa35093..75a5a44d08 100644 --- a/api/core/helper/code_executor/python3/python3_transformer.py +++ b/api/core/helper/code_executor/python3/python3_transformer.py @@ -4,30 +4,6 @@ from core.helper.code_executor.template_transformer import TemplateTransformer class Python3TemplateTransformer(TemplateTransformer): - @classmethod - def get_standard_packages(cls) -> set[str]: - return { - 'base64', - 'binascii', - 'collections', - 'datetime', - 'functools', - 'hashlib', - 'hmac', - 'itertools', - 'json', - 'math', - 'operator', - 'os', - 'random', - 're', - 'string', - 'sys', - 'time', - 'traceback', - 'uuid', - } - @classmethod def get_runner_script(cls) -> str: runner_script = dedent(f""" diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index da7ef469d9..cf66558b65 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -2,9 +2,6 @@ import json import re from abc import ABC, abstractmethod from base64 import b64encode -from typing import Optional - -from core.helper.code_executor.entities import CodeDependency class TemplateTransformer(ABC): @@ -13,12 +10,7 @@ class TemplateTransformer(ABC): _result_tag: str = '<>' @classmethod - def get_standard_packages(cls) -> set[str]: - return set() - - @classmethod - def transform_caller(cls, code: str, inputs: dict, - dependencies: Optional[list[CodeDependency]] = None) -> tuple[str, str, list[CodeDependency]]: + def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]: """ Transform code to python runner :param code: code @@ -28,14 +20,7 @@ class TemplateTransformer(ABC): runner_script = cls.assemble_runner_script(code, inputs) preload_script = cls.get_preload_script() - packages = dependencies or [] - standard_packages = cls.get_standard_packages() - for package in standard_packages: - if package not in packages: - packages.append(CodeDependency(name=package, version='')) - packages = list({dep.name: dep for dep in packages if dep.name}.values()) - - return runner_script, preload_script, packages + return runner_script, preload_script @classmethod def extract_result_str_from_response(cls, response: str) -> str: diff --git a/api/core/helper/position_helper.py b/api/core/helper/position_helper.py index dd1534c791..8cf184ac44 100644 --- a/api/core/helper/position_helper.py +++ b/api/core/helper/position_helper.py @@ -3,6 +3,7 @@ from collections import OrderedDict from collections.abc import Callable from typing import Any +from configs import dify_config from core.tools.utils.yaml_utils import load_yaml_file @@ -19,6 +20,87 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> return {name: index for index, name in enumerate(positions)} +def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]: + """ + Get the mapping for tools from name to index from a YAML file. + :param folder_path: + :param file_name: the YAML file name, default to '_position.yaml' + :return: a dict with name as key and index as value + """ + position_map = get_position_map(folder_path, file_name=file_name) + + return pin_position_map( + position_map, + pin_list=dify_config.POSITION_TOOL_PINS_LIST, + ) + + +def get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]: + """ + Get the mapping for providers from name to index from a YAML file. + :param folder_path: + :param file_name: the YAML file name, default to '_position.yaml' + :return: a dict with name as key and index as value + """ + position_map = get_position_map(folder_path, file_name=file_name) + return pin_position_map( + position_map, + pin_list=dify_config.POSITION_PROVIDER_PINS_LIST, + ) + + +def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]: + """ + Pin the items in the pin list to the beginning of the position map. + Overall logic: exclude > include > pin + :param position_map: the position map to be sorted and filtered + :param pin_list: the list of pins to be put at the beginning + :return: the sorted position map + """ + positions = sorted(original_position_map.keys(), key=lambda x: original_position_map[x]) + + # Add pins to position map + position_map = {name: idx for idx, name in enumerate(pin_list)} + + # Add remaining positions to position map + start_idx = len(position_map) + for name in positions: + if name not in position_map: + position_map[name] = start_idx + start_idx += 1 + + return position_map + + +def is_filtered( + include_set: set[str], + exclude_set: set[str], + data: Any, + name_func: Callable[[Any], str], +) -> bool: + """ + Chcek if the object should be filtered out. + Overall logic: exclude > include > pin + :param include_set: the set of names to be included + :param exclude_set: the set of names to be excluded + :param name_func: the function to get the name of the object + :param data: the data to be filtered + :return: True if the object should be filtered out, False otherwise + """ + if not data: + return False + if not include_set and not exclude_set: + return False + + name = name_func(data) + + if name in exclude_set: # exclude_set is prioritized + return True + if include_set and name not in include_set: # filter out only if include_set is not empty + return True + return False + + def sort_by_position_map( position_map: dict[str, int], data: list[Any], diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index 5f7fec5833..ddcd751286 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -58,7 +58,8 @@ class HostingConfiguration: self.moderation_config = self.init_moderation_config(config) - def init_azure_openai(self, app_config: Config) -> HostingProvider: + @staticmethod + def init_azure_openai(app_config: Config) -> HostingProvider: quota_unit = QuotaUnit.TIMES if app_config.get("HOSTED_AZURE_OPENAI_ENABLED"): credentials = { @@ -145,7 +146,8 @@ class HostingConfiguration: quota_unit=quota_unit, ) - def init_anthropic(self, app_config: Config) -> HostingProvider: + @staticmethod + def init_anthropic(app_config: Config) -> HostingProvider: quota_unit = QuotaUnit.TOKENS quotas = [] @@ -180,7 +182,8 @@ class HostingConfiguration: quota_unit=quota_unit, ) - def init_minimax(self, app_config: Config) -> HostingProvider: + @staticmethod + def init_minimax(app_config: Config) -> HostingProvider: quota_unit = QuotaUnit.TOKENS if app_config.get("HOSTED_MINIMAX_ENABLED"): quotas = [FreeHostingQuota()] @@ -197,7 +200,8 @@ class HostingConfiguration: quota_unit=quota_unit, ) - def init_spark(self, app_config: Config) -> HostingProvider: + @staticmethod + def init_spark(app_config: Config) -> HostingProvider: quota_unit = QuotaUnit.TOKENS if app_config.get("HOSTED_SPARK_ENABLED"): quotas = [FreeHostingQuota()] @@ -214,7 +218,8 @@ class HostingConfiguration: quota_unit=quota_unit, ) - def init_zhipuai(self, app_config: Config) -> HostingProvider: + @staticmethod + def init_zhipuai(app_config: Config) -> HostingProvider: quota_unit = QuotaUnit.TOKENS if app_config.get("HOSTED_ZHIPUAI_ENABLED"): quotas = [FreeHostingQuota()] @@ -231,7 +236,8 @@ class HostingConfiguration: quota_unit=quota_unit, ) - def init_moderation_config(self, app_config: Config) -> HostedModerationConfig: + @staticmethod + def init_moderation_config(app_config: Config) -> HostedModerationConfig: if app_config.get("HOSTED_MODERATION_ENABLED") \ and app_config.get("HOSTED_MODERATION_PROVIDERS"): return HostedModerationConfig( diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index b20c6ed187..7a1c5e760b 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -16,9 +16,7 @@ from configs import dify_config from core.errors.error import ProviderTokenNotInitError from core.llm_generator.llm_generator import LLMGenerator from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.model_entities import ModelType, PriceType -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting @@ -255,11 +253,8 @@ class IndexingRunner: tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) - tokens = 0 preview_texts = [] total_segments = 0 - total_price = 0 - currency = 'USD' index_type = doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() all_text_docs = [] @@ -286,54 +281,22 @@ class IndexingRunner: for document in documents: if len(preview_texts) < 5: preview_texts.append(document.page_content) - if indexing_technique == 'high_quality' or embedding_model_instance: - tokens += embedding_model_instance.get_text_embedding_num_tokens( - texts=[self.filter_string(document.page_content)] - ) if doc_form and doc_form == 'qa_model': - model_instance = self.model_manager.get_default_model_instance( - tenant_id=tenant_id, - model_type=ModelType.LLM - ) - - model_type_instance = model_instance.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) if len(preview_texts) > 0: # qa model document response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], doc_language) document_qa_list = self.format_split_text(response) - price_info = model_type_instance.get_price( - model=model_instance.model, - credentials=model_instance.credentials, - price_type=PriceType.INPUT, - tokens=total_segments * 2000, - ) + return { "total_segments": total_segments * 20, - "tokens": total_segments * 2000, - "total_price": '{:f}'.format(price_info.total_amount), - "currency": price_info.currency, "qa_preview": document_qa_list, "preview": preview_texts } - if embedding_model_instance: - embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance) - embedding_price_info = embedding_model_type_instance.get_price( - model=embedding_model_instance.model, - credentials=embedding_model_instance.credentials, - price_type=PriceType.INPUT, - tokens=tokens - ) - total_price = '{:f}'.format(embedding_price_info.total_amount) - currency = embedding_price_info.currency return { "total_segments": total_segments, - "tokens": tokens, - "total_price": total_price, - "currency": currency, "preview": preview_texts } @@ -411,7 +374,8 @@ class IndexingRunner: return text_docs - def filter_string(self, text): + @staticmethod + def filter_string(text): text = re.sub(r'<\|', '<', text) text = re.sub(r'\|>', '>', text) text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text) @@ -419,7 +383,8 @@ class IndexingRunner: text = re.sub('\uFFFE', '', text) return text - def _get_splitter(self, processing_rule: DatasetProcessRule, + @staticmethod + def _get_splitter(processing_rule: DatasetProcessRule, embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: """ Get the NodeParser object according to the processing rule. @@ -611,7 +576,8 @@ class IndexingRunner: return all_documents - def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str: + @staticmethod + def _document_clean(text: str, processing_rule: DatasetProcessRule) -> str: """ Clean the document text according to the processing rules. """ @@ -640,7 +606,8 @@ class IndexingRunner: return text - def format_split_text(self, text): + @staticmethod + def format_split_text(text): regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" matches = re.findall(regex, text, re.UNICODE) @@ -700,10 +667,12 @@ class IndexingRunner: DatasetDocument.tokens: tokens, DatasetDocument.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at, + DatasetDocument.error: None, } ) - def _process_keyword_index(self, flask_app, dataset_id, document_id, documents): + @staticmethod + def _process_keyword_index(flask_app, dataset_id, document_id, documents): with flask_app.app_context(): dataset = Dataset.query.filter_by(id=dataset_id).first() if not dataset: @@ -714,6 +683,7 @@ class IndexingRunner: document_ids = [document.metadata['doc_id'] for document in documents] db.session.query(DocumentSegment).filter( DocumentSegment.document_id == document_id, + DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id.in_(document_ids), DocumentSegment.status == "indexing" ).update({ @@ -745,6 +715,7 @@ class IndexingRunner: document_ids = [document.metadata['doc_id'] for document in chunk_documents] db.session.query(DocumentSegment).filter( DocumentSegment.document_id == dataset_document.id, + DocumentSegment.dataset_id == dataset.id, DocumentSegment.index_node_id.in_(document_ids), DocumentSegment.status == "indexing" ).update({ @@ -757,13 +728,15 @@ class IndexingRunner: return tokens - def _check_document_paused_status(self, document_id: str): + @staticmethod + def _check_document_paused_status(document_id: str): indexing_cache_key = 'document_{}_is_paused'.format(document_id) result = redis_client.get(indexing_cache_key) if result: raise DocumentIsPausedException() - def _update_document_index_status(self, document_id: str, after_indexing_status: str, + @staticmethod + def _update_document_index_status(document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None) -> None: """ Update the document indexing status. @@ -785,14 +758,16 @@ class IndexingRunner: DatasetDocument.query.filter_by(id=document_id).update(update_params) db.session.commit() - def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None: + @staticmethod + def _update_segments_by_document(dataset_document_id: str, update_params: dict) -> None: """ Update the document segment by document id. """ DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) db.session.commit() - def batch_add_segments(self, segments: list[DocumentSegment], dataset: Dataset): + @staticmethod + def batch_add_segments(segments: list[DocumentSegment], dataset: Dataset): """ Batch add segments index processing """ diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 7c23e14297..7b1a7ada5b 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -44,7 +44,8 @@ class ModelInstance: credentials=self.credentials ) - def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict: + @staticmethod + def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str) -> dict: """ Fetch credentials from provider model bundle :param provider_model_bundle: provider model bundle @@ -63,7 +64,8 @@ class ModelInstance: return credentials - def _get_load_balancing_manager(self, configuration: ProviderConfiguration, + @staticmethod + def _get_load_balancing_manager(configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict) -> Optional["LBModelManager"]: @@ -368,6 +370,15 @@ class ModelManager: return ModelInstance(provider_model_bundle, model) + def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]: + """ + Return first provider and the first model in the provider + :param tenant_id: tenant id + :param model_type: model type + :return: provider name, model name + """ + return self._provider_manager.get_first_provider_first_model(tenant_id, model_type) + def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance: """ Get default model instance @@ -502,13 +513,12 @@ class LBModelManager: config.id ) - res = redis_client.exists(cooldown_cache_key) res = cast(bool, res) return res - @classmethod - def get_config_in_cooldown_and_ttl(cls, tenant_id: str, + @staticmethod + def get_config_in_cooldown_and_ttl(tenant_id: str, provider: str, model_type: ModelType, model: str, diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 0de216bf89..716bb63566 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -151,9 +151,9 @@ class AIModel(ABC): os.path.join(provider_model_type_path, model_schema_yaml) for model_schema_yaml in os.listdir(provider_model_type_path) if not model_schema_yaml.startswith('__') - and not model_schema_yaml.startswith('_') - and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml)) - and model_schema_yaml.endswith('.yaml') + and not model_schema_yaml.startswith('_') + and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml)) + and model_schema_yaml.endswith('.yaml') ] # get _position.yaml file path diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index 19ce401999..81be1a06a7 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -1,6 +1,6 @@ import base64 +import io import json -import mimetypes from collections.abc import Generator from typing import Optional, Union, cast @@ -18,6 +18,7 @@ from anthropic.types import ( ) from anthropic.types.beta.tools import ToolsBetaMessage from httpx import Timeout +from PIL import Image from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta @@ -462,7 +463,8 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): # fetch image data from url try: image_content = requests.get(message_content.data).content - mime_type, _ = mimetypes.guess_type(message_content.data) + with Image.open(io.BytesIO(image_content)) as img: + mime_type = f"image/{img.format.lower()}" base64_data = base64.b64encode(image_content).decode('utf-8') except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") diff --git a/api/core/model_runtime/model_providers/azure_ai_studio/__init__.py b/api/core/model_runtime/model_providers/azure_ai_studio/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/azure_ai_studio/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/azure_ai_studio/_assets/icon_l_en.png new file mode 100644 index 0000000000..4b941654a7 Binary files /dev/null and b/api/core/model_runtime/model_providers/azure_ai_studio/_assets/icon_l_en.png differ diff --git a/api/core/model_runtime/model_providers/azure_ai_studio/_assets/icon_s_en.png b/api/core/model_runtime/model_providers/azure_ai_studio/_assets/icon_s_en.png new file mode 100644 index 0000000000..ca3043dc8d Binary files /dev/null and b/api/core/model_runtime/model_providers/azure_ai_studio/_assets/icon_s_en.png differ diff --git a/api/core/model_runtime/model_providers/azure_ai_studio/azure_ai_studio.py b/api/core/model_runtime/model_providers/azure_ai_studio/azure_ai_studio.py new file mode 100644 index 0000000000..75d21d1ce9 --- /dev/null +++ b/api/core/model_runtime/model_providers/azure_ai_studio/azure_ai_studio.py @@ -0,0 +1,17 @@ +import logging + +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class AzureAIStudioProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. + """ + pass diff --git a/api/core/model_runtime/model_providers/azure_ai_studio/azure_ai_studio.yaml b/api/core/model_runtime/model_providers/azure_ai_studio/azure_ai_studio.yaml new file mode 100644 index 0000000000..9e17ba0884 --- /dev/null +++ b/api/core/model_runtime/model_providers/azure_ai_studio/azure_ai_studio.yaml @@ -0,0 +1,65 @@ +provider: azure_ai_studio +label: + zh_Hans: Azure AI Studio + en_US: Azure AI Studio +icon_small: + en_US: icon_s_en.png +icon_large: + en_US: icon_l_en.png +description: + en_US: Azure AI Studio + zh_Hans: Azure AI Studio +background: "#93c5fd" +help: + title: + en_US: How to deploy customized model on Azure AI Studio + zh_Hans: 如何在Azure AI Studio上的私有化部署的模型 + url: + en_US: https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models + zh_Hans: https://learn.microsoft.com/zh-cn/azure/ai-studio/how-to/deploy-models +supported_model_types: + - llm + - rerank +configurate_methods: + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your model name + zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: endpoint + label: + en_US: Azure AI Studio Endpoint + type: text-input + required: true + placeholder: + zh_Hans: 请输入你的Azure AI Studio推理端点 + en_US: 'Enter your API Endpoint, eg: https://example.com' + - variable: api_key + required: true + label: + en_US: API Key + zh_Hans: API Key + type: secret-input + placeholder: + en_US: Enter your Azure AI Studio API Key + zh_Hans: 在此输入您的 Azure AI Studio API Key + show_on: + - variable: __model_type + value: llm + - variable: jwt_token + required: true + label: + en_US: JWT Token + zh_Hans: JWT令牌 + type: secret-input + placeholder: + en_US: Enter your Azure AI Studio JWT Token + zh_Hans: 在此输入您的 Azure AI Studio 推理 API Key + show_on: + - variable: __model_type + value: rerank diff --git a/api/core/model_runtime/model_providers/azure_ai_studio/llm/__init__.py b/api/core/model_runtime/model_providers/azure_ai_studio/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/azure_ai_studio/llm/llm.py b/api/core/model_runtime/model_providers/azure_ai_studio/llm/llm.py new file mode 100644 index 0000000000..42eae6c1e5 --- /dev/null +++ b/api/core/model_runtime/model_providers/azure_ai_studio/llm/llm.py @@ -0,0 +1,334 @@ +import logging +from collections.abc import Generator +from typing import Any, Optional, Union + +from azure.ai.inference import ChatCompletionsClient +from azure.ai.inference.models import StreamingChatCompletionsUpdate +from azure.core.credentials import AzureKeyCredential +from azure.core.exceptions import ( + ClientAuthenticationError, + DecodeError, + DeserializationError, + HttpResponseError, + ResourceExistsError, + ResourceModifiedError, + ResourceNotFoundError, + ResourceNotModifiedError, + SerializationError, + ServiceRequestError, + ServiceResponseError, +) + +from core.model_runtime.callbacks.base_callback import Callback +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + I18nObject, + ModelType, + ParameterRule, + ParameterType, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel + +logger = logging.getLogger(__name__) + + +class AzureAIStudioLargeLanguageModel(LargeLanguageModel): + """ + Model class for Azure AI Studio large language model. + """ + + client: Any = None + + from azure.ai.inference.models import StreamingChatCompletionsUpdate + + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + """ + Invoke large language model + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param tools: tools for tool calling + :param stop: stop words + :param stream: is stream response + :param user: unique user id + :return: full response or stream response chunk generator result + """ + + if not self.client: + endpoint = credentials.get("endpoint") + api_key = credentials.get("api_key") + self.client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(api_key)) + + messages = [{"role": msg.role.value, "content": msg.content} for msg in prompt_messages] + + payload = { + "messages": messages, + "max_tokens": model_parameters.get("max_tokens", 4096), + "temperature": model_parameters.get("temperature", 0), + "top_p": model_parameters.get("top_p", 1), + "stream": stream, + } + + if stop: + payload["stop"] = stop + + if tools: + payload["tools"] = [tool.model_dump() for tool in tools] + + try: + response = self.client.complete(**payload) + + if stream: + return self._handle_stream_response(response, model, prompt_messages) + else: + return self._handle_non_stream_response(response, model, prompt_messages, credentials) + except Exception as e: + raise self._transform_invoke_error(e) + + def _handle_stream_response(self, response, model: str, prompt_messages: list[PromptMessage]) -> Generator: + for chunk in response: + if isinstance(chunk, StreamingChatCompletionsUpdate): + if chunk.choices: + delta = chunk.choices[0].delta + if delta.content: + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=delta.content, tool_calls=[]), + ), + ) + + def _handle_non_stream_response( + self, response, model: str, prompt_messages: list[PromptMessage], credentials: dict + ) -> LLMResult: + assistant_text = response.choices[0].message.content + assistant_prompt_message = AssistantPromptMessage(content=assistant_text) + usage = self._calc_response_usage( + model, credentials, response.usage.prompt_tokens, response.usage.completion_tokens + ) + result = LLMResult(model=model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage) + + if hasattr(response, "system_fingerprint"): + result.system_fingerprint = response.system_fingerprint + + return result + + def _invoke_result_generator( + self, + model: str, + result: Generator, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Generator: + """ + Invoke result generator + + :param result: result generator + :return: result generator + """ + callbacks = callbacks or [] + prompt_message = AssistantPromptMessage(content="") + usage = None + system_fingerprint = None + real_model = model + + try: + for chunk in result: + if isinstance(chunk, dict): + content = chunk["choices"][0]["message"]["content"] + usage = chunk["usage"] + chunk = LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=content, tool_calls=[]), + ), + system_fingerprint=chunk.get("system_fingerprint"), + ) + + yield chunk + + self._trigger_new_chunk_callbacks( + chunk=chunk, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + callbacks=callbacks, + ) + + prompt_message.content += chunk.delta.message.content + real_model = chunk.model + if hasattr(chunk.delta, "usage"): + usage = chunk.delta.usage + + if chunk.system_fingerprint: + system_fingerprint = chunk.system_fingerprint + except Exception as e: + raise self._transform_invoke_error(e) + + self._trigger_after_invoke_callbacks( + model=model, + result=LLMResult( + model=real_model, + prompt_messages=prompt_messages, + message=prompt_message, + usage=usage if usage else LLMUsage.empty_usage(), + system_fingerprint=system_fingerprint, + ), + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + callbacks=callbacks, + ) + + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param tools: tools for tool calling + :return: + """ + # Implement token counting logic here + # Might need to use a tokenizer specific to the Azure AI Studio model + return 0 + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + endpoint = credentials.get("endpoint") + api_key = credentials.get("api_key") + client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(api_key)) + client.get_model_info() + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [ + ServiceRequestError, + ], + InvokeServerUnavailableError: [ + ServiceResponseError, + ], + InvokeAuthorizationError: [ + ClientAuthenticationError, + ], + InvokeBadRequestError: [ + HttpResponseError, + DecodeError, + ResourceExistsError, + ResourceNotFoundError, + ResourceModifiedError, + ResourceNotModifiedError, + SerializationError, + DeserializationError, + ], + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + Used to define customizable model schema + """ + rules = [ + ParameterRule( + name="temperature", + type=ParameterType.FLOAT, + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), + ), + ParameterRule( + name="top_p", + type=ParameterType.FLOAT, + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), + ), + ParameterRule( + name="max_tokens", + type=ParameterType.INT, + use_template="max_tokens", + min=1, + default=512, + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), + ), + ] + + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.LLM, + features=[], + model_properties={}, + parameter_rules=rules, + ) + + return entity diff --git a/api/core/model_runtime/model_providers/azure_ai_studio/rerank/__init__.py b/api/core/model_runtime/model_providers/azure_ai_studio/rerank/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/azure_ai_studio/rerank/rerank.py b/api/core/model_runtime/model_providers/azure_ai_studio/rerank/rerank.py new file mode 100644 index 0000000000..6ed7ab277c --- /dev/null +++ b/api/core/model_runtime/model_providers/azure_ai_studio/rerank/rerank.py @@ -0,0 +1,164 @@ +import json +import logging +import os +import ssl +import urllib.request +from typing import Optional + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel + +logger = logging.getLogger(__name__) + + +class AzureRerankModel(RerankModel): + """ + Model class for Azure AI Studio rerank model. + """ + + def _allow_self_signed_https(self, allowed): + # bypass the server certificate verification on client side + if allowed and not os.environ.get("PYTHONHTTPSVERIFY", "") and getattr(ssl, "_create_unverified_context", None): + ssl._create_default_https_context = ssl._create_unverified_context + + def _azure_rerank(self, query_input: str, docs: list[str], endpoint: str, api_key: str): + # self._allow_self_signed_https(True) # Enable if using self-signed certificate + + data = {"inputs": query_input, "docs": docs} + + body = json.dumps(data).encode("utf-8") + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} + + req = urllib.request.Request(endpoint, body, headers) + + try: + with urllib.request.urlopen(req) as response: + result = response.read() + return json.loads(result) + except urllib.error.HTTPError as error: + logger.error(f"The request failed with status code: {error.code}") + logger.error(error.info()) + logger.error(error.read().decode("utf8", "ignore")) + raise + + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: + """ + Invoke rerank model + + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id + :return: rerank result + """ + try: + if len(docs) == 0: + return RerankResult(model=model, docs=[]) + + endpoint = credentials.get("endpoint") + api_key = credentials.get("jwt_token") + + if not endpoint or not api_key: + raise ValueError("Azure endpoint and API key must be provided in credentials") + + result = self._azure_rerank(query, docs, endpoint, api_key) + logger.info(f"Azure rerank result: {result}") + + rerank_documents = [] + for idx, (doc, score_dict) in enumerate(zip(docs, result)): + score = score_dict["score"] + rerank_document = RerankDocument(index=idx, text=doc, score=score) + + if score_threshold is None or score >= score_threshold: + rerank_documents.append(rerank_document) + + rerank_documents.sort(key=lambda x: x.score, reverse=True) + + if top_n: + rerank_documents = rerank_documents[:top_n] + + return RerankResult(model=model, docs=rerank_documents) + + except Exception as e: + logger.exception(f"Exception in Azure rerank: {e}") + raise + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self._invoke( + model=model, + credentials=credentials, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + score_threshold=0.8, + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [urllib.error.URLError], + InvokeServerUnavailableError: [urllib.error.HTTPError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError, json.JSONDecodeError], + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + used to define customizable model schema + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.RERANK, + model_properties={}, + parameter_rules=[], + ) + + return entity diff --git a/api/core/model_runtime/model_providers/baichuan/baichuan.yaml b/api/core/model_runtime/model_providers/baichuan/baichuan.yaml index 792126af7f..81e6e36215 100644 --- a/api/core/model_runtime/model_providers/baichuan/baichuan.yaml +++ b/api/core/model_runtime/model_providers/baichuan/baichuan.yaml @@ -27,11 +27,3 @@ provider_credential_schema: placeholder: zh_Hans: 在此输入您的 API Key en_US: Enter your API Key - - variable: secret_key - label: - en_US: Secret Key - type: secret-input - required: false - placeholder: - zh_Hans: 在此输入您的 Secret Key - en_US: Enter your Secret Key diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-53b.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-53b.yaml index 04849500dc..8360dd5faf 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-53b.yaml +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-53b.yaml @@ -43,3 +43,4 @@ parameter_rules: zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。 en_US: Allow the model to perform external search to enhance the generation results. required: false +deprecated: true diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo-192k.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo-192k.yaml index c8156c152b..0ce0265cfe 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo-192k.yaml +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo-192k.yaml @@ -43,3 +43,4 @@ parameter_rules: zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。 en_US: Allow the model to perform external search to enhance the generation results. required: false +deprecated: true diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo.yaml index f91329c77a..ccb4ee8b92 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo.yaml +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo.yaml @@ -4,36 +4,32 @@ label: model_type: llm features: - agent-thought + - multi-tool-call model_properties: mode: chat context_size: 32000 parameter_rules: - name: temperature use_template: temperature + default: 0.3 - name: top_p use_template: top_p + default: 0.85 - name: top_k label: zh_Hans: 取样数量 en_US: Top k type: int + min: 0 + max: 20 + default: 5 help: zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 en_US: Only sample from the top K options for each subsequent token. required: false - name: max_tokens use_template: max_tokens - required: true - default: 8000 - min: 1 - max: 192000 - - name: presence_penalty - use_template: presence_penalty - - name: frequency_penalty - use_template: frequency_penalty - default: 1 - min: 1 - max: 2 + default: 2048 - name: with_search_enhance label: zh_Hans: 搜索增强 diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo-128k.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo-128k.yaml index bf72e82296..c6c6c7e9e9 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo-128k.yaml +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo-128k.yaml @@ -4,36 +4,44 @@ label: model_type: llm features: - agent-thought + - multi-tool-call model_properties: mode: chat context_size: 128000 parameter_rules: - name: temperature use_template: temperature + default: 0.3 - name: top_p use_template: top_p + default: 0.85 - name: top_k label: zh_Hans: 取样数量 en_US: Top k type: int + min: 0 + max: 20 + default: 5 help: zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 en_US: Only sample from the top K options for each subsequent token. required: false - name: max_tokens use_template: max_tokens - required: true - default: 8000 - min: 1 - max: 128000 - - name: presence_penalty - use_template: presence_penalty - - name: frequency_penalty - use_template: frequency_penalty - default: 1 - min: 1 - max: 2 + default: 2048 + - name: res_format + label: + zh_Hans: 回复格式 + en_US: response format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object - name: with_search_enhance label: zh_Hans: 搜索增强 diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo.yaml index 85882519b8..ee8a9ff0d5 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo.yaml +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo.yaml @@ -4,36 +4,44 @@ label: model_type: llm features: - agent-thought + - multi-tool-call model_properties: mode: chat context_size: 32000 parameter_rules: - name: temperature use_template: temperature + default: 0.3 - name: top_p use_template: top_p + default: 0.85 - name: top_k label: zh_Hans: 取样数量 en_US: Top k type: int + min: 0 + max: 20 + default: 5 help: zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 en_US: Only sample from the top K options for each subsequent token. required: false - name: max_tokens use_template: max_tokens - required: true - default: 8000 - min: 1 - max: 32000 - - name: presence_penalty - use_template: presence_penalty - - name: frequency_penalty - use_template: frequency_penalty - default: 1 - min: 1 - max: 2 + default: 2048 + - name: res_format + label: + zh_Hans: 回复格式 + en_US: response format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object - name: with_search_enhance label: zh_Hans: 搜索增强 diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan4.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan4.yaml index f8c6566081..e5e6aeb491 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan4.yaml +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan4.yaml @@ -4,36 +4,44 @@ label: model_type: llm features: - agent-thought + - multi-tool-call model_properties: mode: chat context_size: 32000 parameter_rules: - name: temperature use_template: temperature + default: 0.3 - name: top_p use_template: top_p + default: 0.85 - name: top_k label: zh_Hans: 取样数量 en_US: Top k type: int + min: 0 + max: 20 + default: 5 help: zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 en_US: Only sample from the top K options for each subsequent token. required: false - name: max_tokens use_template: max_tokens - required: true - default: 8000 - min: 1 - max: 32000 - - name: presence_penalty - use_template: presence_penalty - - name: frequency_penalty - use_template: frequency_penalty - default: 1 - min: 1 - max: 2 + default: 2048 + - name: res_format + label: + zh_Hans: 回复格式 + en_US: response format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object - name: with_search_enhance label: zh_Hans: 搜索增强 diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py index d7d8b7c91b..a8fd9dce91 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py @@ -1,11 +1,10 @@ -from collections.abc import Generator -from enum import Enum -from hashlib import md5 -from json import dumps, loads -from typing import Any, Union +import json +from collections.abc import Iterator +from typing import Any, Optional, Union from requests import post +from core.model_runtime.entities.message_entities import PromptMessageTool from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( BadRequestError, InsufficientAccountBalance, @@ -16,203 +15,133 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors impor ) -class BaichuanMessage: - class Role(Enum): - USER = 'user' - ASSISTANT = 'assistant' - # Baichuan does not have system message - _SYSTEM = 'system' - - role: str = Role.USER.value - content: str - usage: dict[str, int] = None - stop_reason: str = '' - - def to_dict(self) -> dict[str, Any]: - return { - 'role': self.role, - 'content': self.content, - } - - def __init__(self, content: str, role: str = 'user') -> None: - self.content = content - self.role = role - class BaichuanModel: api_key: str - secret_key: str - def __init__(self, api_key: str, secret_key: str = '') -> None: + def __init__(self, api_key: str) -> None: self.api_key = api_key - self.secret_key = secret_key - def _model_mapping(self, model: str) -> str: + @property + def _model_mapping(self) -> dict: return { - 'baichuan2-turbo': 'Baichuan2-Turbo', - 'baichuan2-turbo-192k': 'Baichuan2-Turbo-192k', - 'baichuan2-53b': 'Baichuan2-53B', - 'baichuan3-turbo': 'Baichuan3-Turbo', - 'baichuan3-turbo-128k': 'Baichuan3-Turbo-128k', - 'baichuan4': 'Baichuan4', - }[model] + "baichuan2-turbo": "Baichuan2-Turbo", + "baichuan3-turbo": "Baichuan3-Turbo", + "baichuan3-turbo-128k": "Baichuan3-Turbo-128k", + "baichuan4": "Baichuan4", + } - def _handle_chat_generate_response(self, response) -> BaichuanMessage: - resp = response.json() - choices = resp.get('choices', []) - message = BaichuanMessage(content='', role='assistant') - for choice in choices: - message.content += choice['message']['content'] - message.role = choice['message']['role'] - if choice['finish_reason']: - message.stop_reason = choice['finish_reason'] + @property + def request_headers(self) -> dict[str, Any]: + return { + "Content-Type": "application/json", + "Authorization": "Bearer " + self.api_key, + } - if 'usage' in resp: - message.usage = { - 'prompt_tokens': resp['usage']['prompt_tokens'], - 'completion_tokens': resp['usage']['completion_tokens'], - 'total_tokens': resp['usage']['total_tokens'], - } + def _build_parameters( + self, + model: str, + stream: bool, + messages: list[dict], + parameters: dict[str, Any], + tools: Optional[list[PromptMessageTool]] = None, + ) -> dict[str, Any]: + if model in self._model_mapping.keys(): + # the LargeLanguageModel._code_block_mode_wrapper() method will remove the response_format of parameters. + # we need to rename it to res_format to get its value + if parameters.get("res_format") == "json_object": + parameters["response_format"] = {"type": "json_object"} - return message - - def _handle_chat_stream_generate_response(self, response) -> Generator: - for line in response.iter_lines(): - if not line: - continue - line = line.decode('utf-8') - # remove the first `data: ` prefix - if line.startswith('data:'): - line = line[5:].strip() - try: - data = loads(line) - except Exception as e: - if line.strip() == '[DONE]': - return - choices = data.get('choices', []) - # save stop reason temporarily - stop_reason = '' - for choice in choices: - if choice.get('finish_reason'): - stop_reason = choice['finish_reason'] + if tools or parameters.get("with_search_enhance") is True: + parameters["tools"] = [] - if len(choice['delta']['content']) == 0: - continue - yield BaichuanMessage(**choice['delta']) - - # if there is usage, the response is the last one, yield it and return - if 'usage' in data: - message = BaichuanMessage(content='', role='assistant') - message.usage = { - 'prompt_tokens': data['usage']['prompt_tokens'], - 'completion_tokens': data['usage']['completion_tokens'], - 'total_tokens': data['usage']['total_tokens'], - } - message.stop_reason = stop_reason - yield message - - def _build_parameters(self, model: str, stream: bool, messages: list[BaichuanMessage], - parameters: dict[str, Any]) \ - -> dict[str, Any]: - if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b' - or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'): - prompt_messages = [] - for message in messages: - if message.role == BaichuanMessage.Role.USER.value or message.role == BaichuanMessage.Role._SYSTEM.value: - # check if the latest message is a user message - if len(prompt_messages) > 0 and prompt_messages[-1]['role'] == BaichuanMessage.Role.USER.value: - prompt_messages[-1]['content'] += message.content - else: - prompt_messages.append({ - 'content': message.content, - 'role': BaichuanMessage.Role.USER.value, - }) - elif message.role == BaichuanMessage.Role.ASSISTANT.value: - prompt_messages.append({ - 'content': message.content, - 'role': message.role, - }) - # [baichuan] frequency_penalty must be between 1 and 2 - if 'frequency_penalty' in parameters: - if parameters['frequency_penalty'] < 1 or parameters['frequency_penalty'] > 2: - parameters['frequency_penalty'] = 1 + # with_search_enhance is deprecated, use web_search instead + if parameters.get("with_search_enhance") is True: + parameters["tools"].append( + { + "type": "web_search", + "web_search": {"enable": True}, + } + ) + if tools: + for tool in tools: + parameters["tools"].append( + { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.parameters, + }, + } + ) # turbo api accepts flat parameters return { - 'model': self._model_mapping(model), - 'stream': stream, - 'messages': prompt_messages, + "model": self._model_mapping.get(model), + "stream": stream, + "messages": messages, **parameters, } else: raise BadRequestError(f"Unknown model: {model}") - - def _build_headers(self, model: str, data: dict[str, Any]) -> dict[str, Any]: - if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b' - or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'): - # there is no secret key for turbo api - return { - 'Content-Type': 'application/json', - 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) ', - 'Authorization': 'Bearer ' + self.api_key, - } - else: - raise BadRequestError(f"Unknown model: {model}") - - def _calculate_md5(self, input_string): - return md5(input_string.encode('utf-8')).hexdigest() - def generate(self, model: str, stream: bool, messages: list[BaichuanMessage], - parameters: dict[str, Any], timeout: int) \ - -> Union[Generator, BaichuanMessage]: - - if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b' - or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'): - api_base = 'https://api.baichuan-ai.com/v1/chat/completions' + def generate( + self, + model: str, + stream: bool, + messages: list[dict], + parameters: dict[str, Any], + timeout: int, + tools: Optional[list[PromptMessageTool]] = None, + ) -> Union[Iterator, dict]: + + if model in self._model_mapping.keys(): + api_base = "https://api.baichuan-ai.com/v1/chat/completions" else: raise BadRequestError(f"Unknown model: {model}") - - try: - data = self._build_parameters(model, stream, messages, parameters) - headers = self._build_headers(model, data) - except KeyError: - raise InternalServerError(f"Failed to build parameters for model: {model}") + + data = self._build_parameters(model, stream, messages, parameters, tools) try: response = post( url=api_base, - headers=headers, - data=dumps(data), + headers=self.request_headers, + data=json.dumps(data), timeout=timeout, - stream=stream + stream=stream, ) except Exception as e: raise InternalServerError(f"Failed to invoke model: {e}") - + if response.status_code != 200: try: resp = response.json() # try to parse error message - err = resp['error']['code'] - msg = resp['error']['message'] + err = resp["error"]["type"] + msg = resp["error"]["message"] except Exception as e: - raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") + raise InternalServerError( + f"Failed to convert response to json: {e} with text: {response.text}" + ) - if err == 'invalid_api_key': + if err == "invalid_api_key": raise InvalidAPIKeyError(msg) - elif err == 'insufficient_quota': + elif err == "insufficient_quota": raise InsufficientAccountBalance(msg) - elif err == 'invalid_authentication': + elif err == "invalid_authentication": raise InvalidAuthenticationError(msg) - elif 'rate' in err: + elif err == "invalid_request_error": + raise BadRequestError(msg) + elif "rate" in err: raise RateLimitReachedError(msg) - elif 'internal' in err: + elif "internal" in err: raise InternalServerError(msg) - elif err == 'api_key_empty': + elif err == "api_key_empty": raise InvalidAPIKeyError(msg) else: raise InternalServerError(f"Unknown error: {err} with message: {msg}") - + if stream: - return self._handle_chat_stream_generate_response(response) + return response.iter_lines() else: - return self._handle_chat_generate_response(response) \ No newline at end of file + return response.json() diff --git a/api/core/model_runtime/model_providers/baichuan/llm/llm.py b/api/core/model_runtime/model_providers/baichuan/llm/llm.py index edcd3af420..4f44682e9f 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/llm.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/llm.py @@ -1,7 +1,12 @@ -from collections.abc import Generator +import json +from collections.abc import Generator, Iterator from typing import cast -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, +) from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -21,7 +26,7 @@ from core.model_runtime.errors.invoke import ( from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer -from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanMessage, BaichuanModel +from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanModel from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( BadRequestError, InsufficientAccountBalance, @@ -33,19 +38,40 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors impor class BaichuanLarguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages, - model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: + return self._generate( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stream=stream, + ) + + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: return self._num_tokens_from_messages(prompt_messages) - def _num_tokens_from_messages(self, messages: list[PromptMessage], ) -> int: + def _num_tokens_from_messages( + self, + messages: list[PromptMessage], + ) -> int: """Calculate num tokens for baichuan model""" def tokens(text: str): @@ -59,10 +85,10 @@ class BaichuanLarguageModel(LargeLanguageModel): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -84,19 +110,18 @@ class BaichuanLarguageModel(LargeLanguageModel): elif isinstance(message, AssistantPromptMessage): message = cast(AssistantPromptMessage, message) message_dict = {"role": "assistant", "content": message.content} + if message.tool_calls: + message_dict["tool_calls"] = [tool_call.dict() for tool_call in + message.tool_calls] elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) - message_dict = {"role": "user", "content": message.content} + message_dict = {"role": "system", "content": message.content} elif isinstance(message, ToolPromptMessage): - # copy from core/model_runtime/model_providers/anthropic/llm/llm.py message = cast(ToolPromptMessage, message) message_dict = { - "role": "user", - "content": [{ - "type": "tool_result", - "tool_use_id": message.tool_call_id, - "content": message.content - }] + "role": "tool", + "content": message.content, + "tool_call_id": message.tool_call_id } else: raise ValueError(f"Unknown message type {type(message)}") @@ -105,102 +130,159 @@ class BaichuanLarguageModel(LargeLanguageModel): def validate_credentials(self, model: str, credentials: dict) -> None: # ping - instance = BaichuanModel( - api_key=credentials['api_key'], - secret_key=credentials.get('secret_key', '') - ) + instance = BaichuanModel(api_key=credentials["api_key"]) try: - instance.generate(model=model, stream=False, messages=[ - BaichuanMessage(content='ping', role='user') - ], parameters={ - 'max_tokens': 1, - }, timeout=60) + instance.generate( + model=model, + stream=False, + messages=[{"content": "ping", "role": "user"}], + parameters={ + "max_tokens": 1, + }, + timeout=60, + ) except Exception as e: raise CredentialsValidateFailedError(f"Invalid API key: {e}") - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - if tools is not None and len(tools) > 0: - raise InvokeBadRequestError("Baichuan model doesn't support tools") + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stream: bool = True, + ) -> LLMResult | Generator: - instance = BaichuanModel( - api_key=credentials['api_key'], - secret_key=credentials.get('secret_key', '') - ) - - # convert prompt messages to baichuan messages - messages = [ - BaichuanMessage( - content=message.content if isinstance(message.content, str) else ''.join([ - content.data for content in message.content - ]), - role=message.role.value - ) for message in prompt_messages - ] + instance = BaichuanModel(api_key=credentials["api_key"]) + messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages] # invoke model - response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters, - timeout=60) + response = instance.generate( + model=model, + stream=stream, + messages=messages, + parameters=model_parameters, + timeout=60, + tools=tools, + ) if stream: - return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response) + return self._handle_chat_generate_stream_response( + model, prompt_messages, credentials, response + ) - return self._handle_chat_generate_response(model, prompt_messages, credentials, response) + return self._handle_chat_generate_response( + model, prompt_messages, credentials, response + ) + + def _handle_chat_generate_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: dict, + ) -> LLMResult: + choices = response.get("choices", []) + assistant_message = AssistantPromptMessage(content='', tool_calls=[]) + if choices and choices[0]["finish_reason"] == "tool_calls": + for choice in choices: + for tool_call in choice["message"]["tool_calls"]: + tool = AssistantPromptMessage.ToolCall( + id=tool_call.get("id", ""), + type=tool_call.get("type", ""), + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=tool_call.get("function", {}).get("name", ""), + arguments=tool_call.get("function", {}).get("arguments", "") + ), + ) + assistant_message.tool_calls.append(tool) + else: + for choice in choices: + assistant_message.content += choice["message"]["content"] + assistant_message.role = choice["message"]["role"] + + usage = response.get("usage") + if usage: + # transform usage + prompt_tokens = usage["prompt_tokens"] + completion_tokens = usage["completion_tokens"] + else: + # calculate num tokens + prompt_tokens = self._num_tokens_from_messages(prompt_messages) + completion_tokens = self._num_tokens_from_messages([assistant_message]) + + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) - def _handle_chat_generate_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: BaichuanMessage) -> LLMResult: - # convert baichuan message to llm result - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=response.usage['prompt_tokens'], - completion_tokens=response.usage['completion_tokens']) return LLMResult( model=model, prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=response.content, - tool_calls=[] - ), + message=assistant_message, usage=usage, ) - def _handle_chat_generate_stream_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Generator[BaichuanMessage, None, None]) -> Generator: - for message in response: - if message.usage: - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=message.usage['prompt_tokens'], - completion_tokens=message.usage['completion_tokens']) + def _handle_chat_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Iterator, + ) -> Generator: + for line in response: + if not line: + continue + line = line.decode("utf-8") + # remove the first `data: ` prefix + if line.startswith("data:"): + line = line[5:].strip() + try: + data = json.loads(line) + except Exception as e: + if line.strip() == "[DONE]": + return + choices = data.get("choices", []) + + stop_reason = "" + for choice in choices: + if choice.get("finish_reason"): + stop_reason = choice["finish_reason"] + + if len(choice["delta"]["content"]) == 0: + continue yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, message=AssistantPromptMessage( - content=message.content, - tool_calls=[] + content=choice["delta"]["content"], tool_calls=[] ), - usage=usage, - finish_reason=message.stop_reason if message.stop_reason else None, + finish_reason=stop_reason, ), ) - else: + + # if there is usage, the response is the last one, yield it and return + if "usage" in data: + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=data["usage"]["prompt_tokens"], + completion_tokens=data["usage"]["completion_tokens"], + ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), - finish_reason=message.stop_reason if message.stop_reason else None, + message=AssistantPromptMessage(content="", tool_calls=[]), + usage=usage, + finish_reason=stop_reason, ), ) @@ -215,21 +297,13 @@ class BaichuanLarguageModel(LargeLanguageModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, InsufficientAccountBalance, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index 335fa493cd..3f7266f600 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -1,8 +1,8 @@ # standard import import base64 +import io import json import logging -import mimetypes from collections.abc import Generator from typing import Optional, Union, cast @@ -17,6 +17,7 @@ from botocore.exceptions import ( ServiceNotInRegionError, UnknownServiceError, ) +from PIL.Image import Image # local import from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta @@ -381,9 +382,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel): try: url = message_content.data image_content = requests.get(url).content - if '?' in url: - url = url.split('?')[0] - mime_type, _ = mimetypes.guess_type(url) + with Image.open(io.BytesIO(image_content)) as img: + mime_type = f"image/{img.format.lower()}" base64_data = base64.b64encode(image_content).decode('utf-8') except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index ebcd0af35b..84241fb6c8 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -1,7 +1,7 @@ import base64 +import io import json import logging -import mimetypes from collections.abc import Generator from typing import Optional, Union, cast @@ -12,6 +12,7 @@ import google.generativeai.client as client import requests from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory from google.generativeai.types.content_types import to_part +from PIL import Image from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( @@ -371,7 +372,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel): # fetch image data from url try: image_content = requests.get(message_content.data).content - mime_type, _ = mimetypes.guess_type(message_content.data) + with Image.open(io.BytesIO(image_content)) as img: + mime_type = f"image/{img.format.lower()}" base64_data = base64.b64encode(image_content).decode('utf-8') except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index b1660afafb..e2d17e3257 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -6,7 +6,7 @@ from typing import Optional from pydantic import BaseModel, ConfigDict from core.helper.module_import_helper import load_single_subclass_from_source -from core.helper.position_helper import get_position_map, sort_to_dict_by_position_map +from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity from core.model_runtime.model_providers.__base.model_provider import ModelProvider @@ -234,7 +234,7 @@ class ModelProviderFactory: ] # get _position.yaml file path - position_map = get_position_map(model_providers_path) + position_map = get_provider_position_map(model_providers_path) # traverse all model_provider_dir_paths model_providers: list[ModelProviderExtension] = [] diff --git a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-128k.yaml b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-128k.yaml index 0d2e51c47f..1078e84c59 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-128k.yaml +++ b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-128k.yaml @@ -21,6 +21,18 @@ parameter_rules: default: 1024 min: 1 max: 128000 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.06' output: '0.06' diff --git a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-32k.yaml b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-32k.yaml index 9ff537014a..9c739d0501 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-32k.yaml +++ b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-32k.yaml @@ -21,6 +21,18 @@ parameter_rules: default: 1024 min: 1 max: 32000 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.024' output: '0.024' diff --git a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-8k.yaml b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-8k.yaml index 0f308d3676..187a86999e 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-8k.yaml +++ b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-8k.yaml @@ -21,6 +21,18 @@ parameter_rules: default: 512 min: 1 max: 8192 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.012' output: '0.012' diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini-2024-07-18.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini-2024-07-18.yaml index 6f23e0647d..03e28772e6 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini-2024-07-18.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini-2024-07-18.yaml @@ -37,6 +37,9 @@ parameter_rules: options: - text - json_object + - json_schema + - name: json_schema + use_template: json_schema pricing: input: '0.15' output: '0.60' diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py index e5cc884b6d..6279125f46 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py @@ -150,9 +150,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): except json.JSONDecodeError as e: raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') - if (completion_type is LLMMode.CHAT and json_result['object'] == ''): + if (completion_type is LLMMode.CHAT and json_result.get('object','') == ''): json_result['object'] = 'chat.completion' - elif (completion_type is LLMMode.COMPLETION and json_result['object'] == ''): + elif (completion_type is LLMMode.COMPLETION and json_result.get('object','') == ''): json_result['object'] = 'text_completion' if (completion_type is LLMMode.CHAT @@ -428,7 +428,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): if new_tool_call.function.arguments: tool_call.function.arguments += new_tool_call.function.arguments - finish_reason = 'Unknown' + finish_reason = None # The default value of finish_reason is None for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter): chunk = chunk.strip() @@ -437,6 +437,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): if chunk.startswith(':'): continue decoded_chunk = chunk.strip().lstrip('data: ').lstrip() + if decoded_chunk == '[DONE]': # Some provider returns "data: [DONE]" + continue try: chunk_json = json.loads(decoded_chunk) @@ -647,7 +649,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): else: raise ValueError(f"Got unknown type {message}") - if message.name: + if message.name and message_dict.get("role", "") != "tool": message_dict["name"] = message.name return message_dict diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v3.yaml b/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v3.yaml new file mode 100644 index 0000000000..171a379ee2 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v3.yaml @@ -0,0 +1,9 @@ +model: text-embedding-v3 +model_type: text-embedding +model_properties: + context_size: 8192 + max_chunks: 25 +pricing: + input: "0.0007" + unit: "0.001" + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py index e7e1b5c764..97dcb72f7c 100644 --- a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py @@ -137,9 +137,19 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel): input=text, text_type="document", ) - data = response.output["embeddings"][0] - embeddings.append(data["embedding"]) - embedding_used_tokens += response.usage["total_tokens"] + if response.output and "embeddings" in response.output and response.output["embeddings"]: + data = response.output["embeddings"][0] + if "embedding" in data: + embeddings.append(data["embedding"]) + else: + raise ValueError("Embedding data is missing in the response.") + else: + raise ValueError("Response output is missing or does not contain embeddings.") + + if response.usage and "total_tokens" in response.usage: + embedding_used_tokens += response.usage["total_tokens"] + else: + raise ValueError("Response usage is missing or does not contain total tokens.") return [list(map(float, e)) for e in embeddings], embedding_used_tokens diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py index 8901549110..1a7368a2cf 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py @@ -1,4 +1,5 @@ import base64 +import io import json import logging from collections.abc import Generator @@ -18,6 +19,7 @@ from anthropic.types import ( ) from google.cloud import aiplatform from google.oauth2 import service_account +from PIL import Image from vertexai.generative_models import HarmBlockThreshold, HarmCategory from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage @@ -332,7 +334,8 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): # fetch image data from url try: image_content = requests.get(message_content.data).content - mime_type, _ = mimetypes.guess_type(message_content.data) + with Image.open(io.BytesIO(image_content)) as img: + mime_type = f"image/{img.format.lower()}" base64_data = base64.b64encode(image_content).decode('utf-8') except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") diff --git a/api/core/model_runtime/model_providers/volcengine_maas/client.py b/api/core/model_runtime/model_providers/volcengine_maas/client.py index 471cb3c94e..a4d89dabcb 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/client.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/client.py @@ -1,6 +1,25 @@ import re -from collections.abc import Callable, Generator -from typing import cast +from collections.abc import Generator +from typing import Optional, cast + +from volcenginesdkarkruntime import Ark +from volcenginesdkarkruntime.types.chat import ( + ChatCompletion, + ChatCompletionAssistantMessageParam, + ChatCompletionChunk, + ChatCompletionContentPartImageParam, + ChatCompletionContentPartTextParam, + ChatCompletionMessageParam, + ChatCompletionMessageToolCallParam, + ChatCompletionSystemMessageParam, + ChatCompletionToolMessageParam, + ChatCompletionToolParam, + ChatCompletionUserMessageParam, +) +from volcenginesdkarkruntime.types.chat.chat_completion_content_part_image_param import ImageURL +from volcenginesdkarkruntime.types.chat.chat_completion_message_tool_call_param import Function +from volcenginesdkarkruntime.types.create_embedding_response import CreateEmbeddingResponse +from volcenginesdkarkruntime.types.shared_params import FunctionDefinition from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -12,123 +31,195 @@ from core.model_runtime.entities.message_entities import ( ToolPromptMessage, UserPromptMessage, ) -from core.model_runtime.model_providers.volcengine_maas.errors import wrap_error -from core.model_runtime.model_providers.volcengine_maas.volc_sdk import ChatRole, MaasException, MaasService + +DEFAULT_V2_ENDPOINT = "maas-api.ml-platform-cn-beijing.volces.com" +DEFAULT_V3_ENDPOINT = "https://ark.cn-beijing.volces.com/api/v3" -class MaaSClient(MaasService): - def __init__(self, host: str, region: str): +class ArkClientV3: + endpoint_id: Optional[str] = None + ark: Optional[Ark] = None + + def __init__(self, *args, **kwargs): + self.ark = Ark(*args, **kwargs) self.endpoint_id = None - super().__init__(host, region) - - def set_endpoint_id(self, endpoint_id: str): - self.endpoint_id = endpoint_id - - @classmethod - def from_credential(cls, credentials: dict) -> 'MaaSClient': - host = credentials['api_endpoint_host'] - region = credentials['volc_region'] - ak = credentials['volc_access_key_id'] - sk = credentials['volc_secret_access_key'] - endpoint_id = credentials['endpoint_id'] - - client = cls(host, region) - client.set_endpoint_id(endpoint_id) - client.set_ak(ak) - client.set_sk(sk) - return client - - def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict: - req = { - 'parameters': params, - 'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages], - **extra_model_kwargs, - } - if not stream: - return super().chat( - self.endpoint_id, - req, - ) - return super().stream_chat( - self.endpoint_id, - req, - ) - - def embeddings(self, texts: list[str]) -> dict: - req = { - 'input': texts - } - return super().embeddings(self.endpoint_id, req) @staticmethod - def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict: + def is_legacy(credentials: dict) -> bool: + # match default v2 endpoint + if ArkClientV3.is_compatible_with_legacy(credentials): + return False + # match default v3 endpoint + if credentials.get("api_endpoint_host") == DEFAULT_V3_ENDPOINT: + return False + # only v3 support api_key + if credentials.get("auth_method") == "api_key": + return False + # these cases are considered as sdk v2 + # - modified default v2 endpoint + # - modified default v3 endpoint and auth without api_key + return True + + @staticmethod + def is_compatible_with_legacy(credentials: dict) -> bool: + endpoint = credentials.get("api_endpoint_host") + return endpoint == DEFAULT_V2_ENDPOINT + + @classmethod + def from_credentials(cls, credentials): + """Initialize the client using the credentials provided.""" + args = { + "base_url": credentials['api_endpoint_host'], + "region": credentials['volc_region'], + } + if credentials.get("auth_method") == "api_key": + args = { + **args, + "api_key": credentials['volc_api_key'], + } + else: + args = { + **args, + "ak": credentials['volc_access_key_id'], + "sk": credentials['volc_secret_access_key'], + } + + if cls.is_compatible_with_legacy(credentials): + args = { + **args, + "base_url": DEFAULT_V3_ENDPOINT + } + + client = ArkClientV3( + **args + ) + client.endpoint_id = credentials['endpoint_id'] + return client + + @staticmethod + def convert_prompt_message(message: PromptMessage) -> ChatCompletionMessageParam: + """Converts a PromptMessage to a ChatCompletionMessageParam""" if isinstance(message, UserPromptMessage): message = cast(UserPromptMessage, message) if isinstance(message.content, str): - message_dict = {"role": ChatRole.USER, - "content": message.content} + content = message.content else: content = [] for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: - raise ValueError( - 'Content object type only support image_url') + content.append(ChatCompletionContentPartTextParam( + text=message_content.text, + type='text', + )) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast( ImagePromptMessageContent, message_content) image_data = re.sub( r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data) - content.append({ - 'type': 'image_url', - 'image_url': { - 'url': '', - 'image_bytes': image_data, - 'detail': message_content.detail, - } - }) - - message_dict = {'role': ChatRole.USER, 'content': content} + content.append(ChatCompletionContentPartImageParam( + image_url=ImageURL( + url=image_data, + detail=message_content.detail.value, + ), + type='image_url', + )) + message_dict = ChatCompletionUserMessageParam( + role='user', + content=content + ) elif isinstance(message, AssistantPromptMessage): message = cast(AssistantPromptMessage, message) - message_dict = {'role': ChatRole.ASSISTANT, - 'content': message.content} - if message.tool_calls: - message_dict['tool_calls'] = [ - { - 'name': call.function.name, - 'arguments': call.function.arguments - } for call in message.tool_calls + message_dict = ChatCompletionAssistantMessageParam( + content=message.content, + role='assistant', + tool_calls=None if not message.tool_calls else [ + ChatCompletionMessageToolCallParam( + id=call.id, + function=Function( + name=call.function.name, + arguments=call.function.arguments + ), + type='function' + ) for call in message.tool_calls ] + ) elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) - message_dict = {'role': ChatRole.SYSTEM, - 'content': message.content} + message_dict = ChatCompletionSystemMessageParam( + content=message.content, + role='system' + ) elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) - message_dict = {'role': ChatRole.FUNCTION, - 'content': message.content, - 'name': message.tool_call_id} + message_dict = ChatCompletionToolMessageParam( + content=message.content, + role='tool', + tool_call_id=message.tool_call_id + ) else: raise ValueError(f"Got unknown PromptMessage type {message}") return message_dict @staticmethod - def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator: - try: - resp = fn() - except MaasException as e: - raise wrap_error(e) + def _convert_tool_prompt(message: PromptMessageTool) -> ChatCompletionToolParam: + return ChatCompletionToolParam( + type='function', + function=FunctionDefinition( + name=message.name, + description=message.description, + parameters=message.parameters, + ) + ) - return resp + def chat(self, messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + frequency_penalty: Optional[float] = None, + max_tokens: Optional[int] = None, + presence_penalty: Optional[float] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + ) -> ChatCompletion: + """Block chat""" + return self.ark.chat.completions.create( + model=self.endpoint_id, + messages=[self.convert_prompt_message(message) for message in messages], + tools=[self._convert_tool_prompt(tool) for tool in tools] if tools else None, + stop=stop, + frequency_penalty=frequency_penalty, + max_tokens=max_tokens, + presence_penalty=presence_penalty, + top_p=top_p, + temperature=temperature, + ) - @staticmethod - def transform_tool_prompt_to_maas_config(tool: PromptMessageTool): - return { - "type": "function", - "function": { - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters, - } - } + def stream_chat(self, messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + frequency_penalty: Optional[float] = None, + max_tokens: Optional[int] = None, + presence_penalty: Optional[float] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + ) -> Generator[ChatCompletionChunk]: + """Stream chat""" + chunks = self.ark.chat.completions.create( + stream=True, + model=self.endpoint_id, + messages=[self.convert_prompt_message(message) for message in messages], + tools=[self._convert_tool_prompt(tool) for tool in tools] if tools else None, + stop=stop, + frequency_penalty=frequency_penalty, + max_tokens=max_tokens, + presence_penalty=presence_penalty, + top_p=top_p, + temperature=temperature, + ) + for chunk in chunks: + if not chunk.choices: + continue + yield chunk + + def embeddings(self, texts: list[str]) -> CreateEmbeddingResponse: + return self.ark.embeddings.create(model=self.endpoint_id, input=texts) diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py new file mode 100644 index 0000000000..1978c11680 --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py @@ -0,0 +1,134 @@ +import re +from collections.abc import Callable, Generator +from typing import cast + +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentType, + PromptMessageTool, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.model_providers.volcengine_maas.legacy.errors import wrap_error +from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasException, MaasService + + +class MaaSClient(MaasService): + def __init__(self, host: str, region: str): + self.endpoint_id = None + super().__init__(host, region) + + def set_endpoint_id(self, endpoint_id: str): + self.endpoint_id = endpoint_id + + @classmethod + def from_credential(cls, credentials: dict) -> 'MaaSClient': + host = credentials['api_endpoint_host'] + region = credentials['volc_region'] + ak = credentials['volc_access_key_id'] + sk = credentials['volc_secret_access_key'] + endpoint_id = credentials['endpoint_id'] + + client = cls(host, region) + client.set_endpoint_id(endpoint_id) + client.set_ak(ak) + client.set_sk(sk) + return client + + def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict: + req = { + 'parameters': params, + 'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages], + **extra_model_kwargs, + } + if not stream: + return super().chat( + self.endpoint_id, + req, + ) + return super().stream_chat( + self.endpoint_id, + req, + ) + + def embeddings(self, texts: list[str]) -> dict: + req = { + 'input': texts + } + return super().embeddings(self.endpoint_id, req) + + @staticmethod + def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict: + if isinstance(message, UserPromptMessage): + message = cast(UserPromptMessage, message) + if isinstance(message.content, str): + message_dict = {"role": ChatRole.USER, + "content": message.content} + else: + content = [] + for message_content in message.content: + if message_content.type == PromptMessageContentType.TEXT: + raise ValueError( + 'Content object type only support image_url') + elif message_content.type == PromptMessageContentType.IMAGE: + message_content = cast( + ImagePromptMessageContent, message_content) + image_data = re.sub( + r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data) + content.append({ + 'type': 'image_url', + 'image_url': { + 'url': '', + 'image_bytes': image_data, + 'detail': message_content.detail, + } + }) + + message_dict = {'role': ChatRole.USER, 'content': content} + elif isinstance(message, AssistantPromptMessage): + message = cast(AssistantPromptMessage, message) + message_dict = {'role': ChatRole.ASSISTANT, + 'content': message.content} + if message.tool_calls: + message_dict['tool_calls'] = [ + { + 'name': call.function.name, + 'arguments': call.function.arguments + } for call in message.tool_calls + ] + elif isinstance(message, SystemPromptMessage): + message = cast(SystemPromptMessage, message) + message_dict = {'role': ChatRole.SYSTEM, + 'content': message.content} + elif isinstance(message, ToolPromptMessage): + message = cast(ToolPromptMessage, message) + message_dict = {'role': ChatRole.FUNCTION, + 'content': message.content, + 'name': message.tool_call_id} + else: + raise ValueError(f"Got unknown PromptMessage type {message}") + + return message_dict + + @staticmethod + def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator: + try: + resp = fn() + except MaasException as e: + raise wrap_error(e) + + return resp + + @staticmethod + def transform_tool_prompt_to_maas_config(tool: PromptMessageTool): + return { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.parameters, + } + } diff --git a/api/core/model_runtime/model_providers/volcengine_maas/errors.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py similarity index 97% rename from api/core/model_runtime/model_providers/volcengine_maas/errors.py rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py index 63397a456e..21ffaf1258 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/errors.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py @@ -1,4 +1,4 @@ -from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException +from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasException class ClientSDKRequestError(MaasException): diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py similarity index 100% rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/__init__.py rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/__init__.py similarity index 100% rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/__init__.py rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/__init__.py diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py similarity index 100% rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py similarity index 100% rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/util.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py similarity index 100% rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/util.py rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/common.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/common.py similarity index 100% rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/common.py rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/common.py diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/maas.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py similarity index 100% rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/maas.py rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py index add5822bef..996c66e604 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py @@ -1,8 +1,10 @@ import logging from collections.abc import Generator +from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk + from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -27,19 +29,21 @@ from core.model_runtime.errors.invoke import ( ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient -from core.model_runtime.model_providers.volcengine_maas.errors import ( +from core.model_runtime.model_providers.volcengine_maas.client import ArkClientV3 +from core.model_runtime.model_providers.volcengine_maas.legacy.client import MaaSClient +from core.model_runtime.model_providers.volcengine_maas.legacy.errors import ( AuthErrors, BadRequestErrors, ConnectionErrors, + MaasException, RateLimitErrors, ServerUnavailableErrors, ) from core.model_runtime.model_providers.volcengine_maas.llm.models import ( get_model_config, get_v2_req_params, + get_v3_req_params, ) -from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException logger = logging.getLogger(__name__) @@ -49,13 +53,20 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): model_parameters: dict, tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: - return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + if ArkClientV3.is_legacy(credentials): + return self._generate_v2(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + return self._generate_v3(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate credentials """ - # ping + if ArkClientV3.is_legacy(credentials): + return self._validate_credentials_v2(credentials) + return self._validate_credentials_v3(credentials) + + @staticmethod + def _validate_credentials_v2(credentials: dict) -> None: client = MaaSClient.from_credential(credentials) try: client.chat( @@ -70,18 +81,24 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): except MaasException as e: raise CredentialsValidateFailedError(e.message) + @staticmethod + def _validate_credentials_v3(credentials: dict) -> None: + client = ArkClientV3.from_credentials(credentials) + try: + client.chat(max_tokens=16, temperature=0.7, top_p=0.9, + messages=[UserPromptMessage(content='ping\nAnswer: ')], ) + except Exception as e: + raise CredentialsValidateFailedError(e) + def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: list[PromptMessageTool] | None = None) -> int: - if len(prompt_messages) == 0: + if ArkClientV3.is_legacy(credentials): + return self._get_num_tokens_v2(prompt_messages) + return self._get_num_tokens_v3(prompt_messages) + + def _get_num_tokens_v2(self, messages: list[PromptMessage]) -> int: + if len(messages) == 0: return 0 - return self._num_tokens_from_messages(prompt_messages) - - def _num_tokens_from_messages(self, messages: list[PromptMessage]) -> int: - """ - Calculate num tokens. - - :param messages: messages - """ num_tokens = 0 messages_dict = [ MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages] @@ -92,9 +109,22 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): return num_tokens - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ + def _get_num_tokens_v3(self, messages: list[PromptMessage]) -> int: + if len(messages) == 0: + return 0 + num_tokens = 0 + messages_dict = [ + ArkClientV3.convert_prompt_message(m) for m in messages] + for message in messages_dict: + for key, value in message.items(): + num_tokens += self._get_num_tokens_by_gpt2(str(key)) + num_tokens += self._get_num_tokens_by_gpt2(str(value)) + + return num_tokens + + def _generate_v2(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + model_parameters: dict, tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: client = MaaSClient.from_credential(credentials) @@ -106,77 +136,151 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): ] resp = MaaSClient.wrap_exception( lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs)) - if not stream: - return self._handle_chat_response(model, credentials, prompt_messages, resp) - return self._handle_stream_chat_response(model, credentials, prompt_messages, resp) - def _handle_stream_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: Generator) -> Generator: - for index, r in enumerate(resp): - choices = r['choices'] + def _handle_stream_chat_response() -> Generator: + for index, r in enumerate(resp): + choices = r['choices'] + if not choices: + continue + choice = choices[0] + message = choice['message'] + usage = None + if r.get('usage'): + usage = self._calc_response_usage(model=model, credentials=credentials, + prompt_tokens=r['usage']['prompt_tokens'], + completion_tokens=r['usage']['completion_tokens'] + ) + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=index, + message=AssistantPromptMessage( + content=message['content'] if message['content'] else '', + tool_calls=[] + ), + usage=usage, + finish_reason=choice.get('finish_reason'), + ), + ) + + def _handle_chat_response() -> LLMResult: + choices = resp['choices'] if not choices: - continue + raise ValueError("No choices found") + choice = choices[0] message = choice['message'] - usage = None - if r.get('usage'): - usage = self._calc_usage(model, credentials, r['usage']) - yield LLMResultChunk( + + # parse tool calls + tool_calls = [] + if message['tool_calls']: + for call in message['tool_calls']: + tool_call = AssistantPromptMessage.ToolCall( + id=call['function']['name'], + type=call['type'], + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=call['function']['name'], + arguments=call['function']['arguments'] + ) + ) + tool_calls.append(tool_call) + + usage = resp['usage'] + return LLMResult( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=AssistantPromptMessage( - content=message['content'] if message['content'] else '', - tool_calls=[] - ), - usage=usage, - finish_reason=choice.get('finish_reason'), + message=AssistantPromptMessage( + content=message['content'] if message['content'] else '', + tool_calls=tool_calls, ), + usage=self._calc_response_usage(model=model, credentials=credentials, + prompt_tokens=usage['prompt_tokens'], + completion_tokens=usage['completion_tokens'] + ), ) - def _handle_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: dict) -> LLMResult: - choices = resp['choices'] - if not choices: - return - choice = choices[0] - message = choice['message'] + if not stream: + return _handle_chat_response() + return _handle_stream_chat_response() - # parse tool calls - tool_calls = [] - if message['tool_calls']: - for call in message['tool_calls']: - tool_call = AssistantPromptMessage.ToolCall( - id=call['function']['name'], - type=call['type'], - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=call['function']['name'], - arguments=call['function']['arguments'] - ) + def _generate_v3(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + model_parameters: dict, tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ + -> LLMResult | Generator: + + client = ArkClientV3.from_credentials(credentials) + req_params = get_v3_req_params(credentials, model_parameters, stop) + if tools: + req_params['tools'] = tools + + def _handle_stream_chat_response(chunks: Generator[ChatCompletionChunk]) -> Generator: + for chunk in chunks: + if not chunk.choices: + continue + choice = chunk.choices[0] + + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=choice.index, + message=AssistantPromptMessage( + content=choice.delta.content, + tool_calls=[] + ), + usage=self._calc_response_usage(model=model, credentials=credentials, + prompt_tokens=chunk.usage.prompt_tokens, + completion_tokens=chunk.usage.completion_tokens + ) if chunk.usage else None, + finish_reason=choice.finish_reason, + ), ) - tool_calls.append(tool_call) - return LLMResult( - model=model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=message['content'] if message['content'] else '', - tool_calls=tool_calls, - ), - usage=self._calc_usage(model, credentials, resp['usage']), - ) + def _handle_chat_response(resp: ChatCompletion) -> LLMResult: + choice = resp.choices[0] + message = choice.message + # parse tool calls + tool_calls = [] + if message.tool_calls: + for call in message.tool_calls: + tool_call = AssistantPromptMessage.ToolCall( + id=call.id, + type=call.type, + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=call.function.name, + arguments=call.function.arguments + ) + ) + tool_calls.append(tool_call) - def _calc_usage(self, model: str, credentials: dict, usage: dict) -> LLMUsage: - return self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=usage['prompt_tokens'], - completion_tokens=usage['completion_tokens'] - ) + usage = resp.usage + return LLMResult( + model=model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage( + content=message.content if message.content else "", + tool_calls=tool_calls, + ), + usage=self._calc_response_usage(model=model, credentials=credentials, + prompt_tokens=usage.prompt_tokens, + completion_tokens=usage.completion_tokens + ), + ) + + if not stream: + resp = client.chat(prompt_messages, **req_params) + return _handle_chat_response(resp) + + chunks = client.stream_chat(prompt_messages, **req_params) + return _handle_stream_chat_response(chunks) def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ used to define customizable model schema """ model_config = get_model_config(credentials) - + rules = [ ParameterRule( name='temperature', @@ -212,7 +316,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): use_template='presence_penalty', label=I18nObject( en_US='Presence Penalty', - zh_Hans= '存在惩罚', + zh_Hans='存在惩罚', ), min=-2.0, max=2.0, @@ -222,8 +326,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): type=ParameterType.FLOAT, use_template='frequency_penalty', label=I18nObject( - en_US= 'Frequency Penalty', - zh_Hans= '频率惩罚', + en_US='Frequency Penalty', + zh_Hans='频率惩罚', ), min=-2.0, max=2.0, @@ -245,7 +349,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): model_properties = {} model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size model_properties[ModelPropertyKey.MODE] = model_config.properties.mode.value - + entity = AIModelEntity( model=model, label=I18nObject( diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py index c5e53d8955..a882f68a36 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py @@ -5,10 +5,11 @@ from core.model_runtime.entities.model_entities import ModelFeature class ModelProperties(BaseModel): - context_size: int - max_tokens: int + context_size: int + max_tokens: int mode: LLMMode + class ModelConfig(BaseModel): properties: ModelProperties features: list[ModelFeature] @@ -24,23 +25,23 @@ configs: dict[str, ModelConfig] = { features=[ModelFeature.TOOL_CALL] ), 'Doubao-pro-32k': ModelConfig( - properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT), + properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT), features=[ModelFeature.TOOL_CALL] ), 'Doubao-lite-32k': ModelConfig( - properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT), + properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT), features=[ModelFeature.TOOL_CALL] ), 'Doubao-pro-128k': ModelConfig( - properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT), + properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT), features=[ModelFeature.TOOL_CALL] ), 'Doubao-lite-128k': ModelConfig( - properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT), + features=[] ), 'Skylark2-pro-4k': ModelConfig( - properties=ModelProperties(context_size=4096, max_tokens=4000, mode=LLMMode.CHAT), + properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), features=[] ), 'Llama3-8B': ModelConfig( @@ -53,23 +54,23 @@ configs: dict[str, ModelConfig] = { ), 'Moonshot-v1-8k': ModelConfig( properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT), - features=[] + features=[ModelFeature.TOOL_CALL] ), 'Moonshot-v1-32k': ModelConfig( properties=ModelProperties(context_size=32768, max_tokens=16384, mode=LLMMode.CHAT), - features=[] + features=[ModelFeature.TOOL_CALL] ), 'Moonshot-v1-128k': ModelConfig( properties=ModelProperties(context_size=131072, max_tokens=65536, mode=LLMMode.CHAT), - features=[] + features=[ModelFeature.TOOL_CALL] ), 'GLM3-130B': ModelConfig( properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT), - features=[] + features=[ModelFeature.TOOL_CALL] ), 'GLM3-130B-Fin': ModelConfig( properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT), - features=[] + features=[ModelFeature.TOOL_CALL] ), 'Mistral-7B': ModelConfig( properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT), @@ -77,23 +78,24 @@ configs: dict[str, ModelConfig] = { ) } -def get_model_config(credentials: dict)->ModelConfig: + +def get_model_config(credentials: dict) -> ModelConfig: base_model = credentials.get('base_model_name', '') model_configs = configs.get(base_model) if not model_configs: return ModelConfig( - properties=ModelProperties( + properties=ModelProperties( context_size=int(credentials.get('context_size', 0)), max_tokens=int(credentials.get('max_tokens', 0)), - mode= LLMMode.value_of(credentials.get('mode', 'chat')), + mode=LLMMode.value_of(credentials.get('mode', 'chat')), ), features=[] ) return model_configs -def get_v2_req_params(credentials: dict, model_parameters: dict, - stop: list[str] | None=None): +def get_v2_req_params(credentials: dict, model_parameters: dict, + stop: list[str] | None = None): req_params = {} # predefined properties model_configs = get_model_config(credentials) @@ -116,8 +118,36 @@ def get_v2_req_params(credentials: dict, model_parameters: dict, if model_parameters.get('frequency_penalty'): req_params['frequency_penalty'] = model_parameters.get( 'frequency_penalty') - + if stop: req_params['stop'] = stop - return req_params \ No newline at end of file + return req_params + + +def get_v3_req_params(credentials: dict, model_parameters: dict, + stop: list[str] | None = None): + req_params = {} + # predefined properties + model_configs = get_model_config(credentials) + if model_configs: + req_params['max_tokens'] = model_configs.properties.max_tokens + + # model parameters + if model_parameters.get('max_tokens'): + req_params['max_tokens'] = model_parameters.get('max_tokens') + if model_parameters.get('temperature'): + req_params['temperature'] = model_parameters.get('temperature') + if model_parameters.get('top_p'): + req_params['top_p'] = model_parameters.get('top_p') + if model_parameters.get('presence_penalty'): + req_params['presence_penalty'] = model_parameters.get( + 'presence_penalty') + if model_parameters.get('frequency_penalty'): + req_params['frequency_penalty'] = model_parameters.get( + 'frequency_penalty') + + if stop: + req_params['stop'] = stop + + return req_params diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py index 2d8f972b94..74cf26247c 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py @@ -2,26 +2,29 @@ from pydantic import BaseModel class ModelProperties(BaseModel): - context_size: int - max_chunks: int + context_size: int + max_chunks: int + class ModelConfig(BaseModel): properties: ModelProperties + ModelConfigs = { 'Doubao-embedding': ModelConfig( - properties=ModelProperties(context_size=4096, max_chunks=1) + properties=ModelProperties(context_size=4096, max_chunks=32) ), } -def get_model_config(credentials: dict)->ModelConfig: + +def get_model_config(credentials: dict) -> ModelConfig: base_model = credentials.get('base_model_name', '') model_configs = ModelConfigs.get(base_model) if not model_configs: return ModelConfig( - properties=ModelProperties( + properties=ModelProperties( context_size=int(credentials.get('context_size', 0)), max_chunks=int(credentials.get('max_chunks', 0)), ) ) - return model_configs \ No newline at end of file + return model_configs diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py index 8ac632369e..d54aeeb0b1 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py @@ -22,16 +22,17 @@ from core.model_runtime.errors.invoke import ( ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient -from core.model_runtime.model_providers.volcengine_maas.errors import ( +from core.model_runtime.model_providers.volcengine_maas.client import ArkClientV3 +from core.model_runtime.model_providers.volcengine_maas.legacy.client import MaaSClient +from core.model_runtime.model_providers.volcengine_maas.legacy.errors import ( AuthErrors, BadRequestErrors, ConnectionErrors, + MaasException, RateLimitErrors, ServerUnavailableErrors, ) from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import get_model_config -from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): @@ -51,6 +52,14 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ + if ArkClientV3.is_legacy(credentials): + return self._generate_v2(model, credentials, texts, user) + + return self._generate_v3(model, credentials, texts, user) + + def _generate_v2(self, model: str, credentials: dict, + texts: list[str], user: Optional[str] = None) \ + -> TextEmbeddingResult: client = MaaSClient.from_credential(credentials) resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts)) @@ -65,6 +74,23 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): return result + def _generate_v3(self, model: str, credentials: dict, + texts: list[str], user: Optional[str] = None) \ + -> TextEmbeddingResult: + client = ArkClientV3.from_credentials(credentials) + resp = client.embeddings(texts) + + usage = self._calc_response_usage( + model=model, credentials=credentials, tokens=resp.usage.total_tokens) + + result = TextEmbeddingResult( + model=model, + embeddings=[v.embedding for v in resp.data], + usage=usage + ) + + return result + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ Get number of tokens for given prompt messages @@ -88,11 +114,22 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): :param credentials: model credentials :return: """ + if ArkClientV3.is_legacy(credentials): + return self._validate_credentials_v2(model, credentials) + return self._validate_credentials_v3(model, credentials) + + def _validate_credentials_v2(self, model: str, credentials: dict) -> None: try: self._invoke(model=model, credentials=credentials, texts=['ping']) except MaasException as e: raise CredentialsValidateFailedError(e.message) + def _validate_credentials_v3(self, model: str, credentials: dict) -> None: + try: + self._invoke(model=model, credentials=credentials, texts=['ping']) + except Exception as e: + raise CredentialsValidateFailedError(e) + @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ @@ -116,9 +153,10 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): generate custom model entities from credentials """ model_config = get_model_config(credentials) - model_properties = {} - model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size - model_properties[ModelPropertyKey.MAX_CHUNKS] = model_config.properties.max_chunks + model_properties = { + ModelPropertyKey.CONTEXT_SIZE: model_config.properties.context_size, + ModelPropertyKey.MAX_CHUNKS: model_config.properties.max_chunks + } entity = AIModelEntity( model=model, label=I18nObject(en_US=model), diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml index a00c1b7994..13e00da76f 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml +++ b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml @@ -30,8 +30,28 @@ model_credential_schema: en_US: Enter your Model Name zh_Hans: 输入模型名称 credential_form_schemas: + - variable: auth_method + required: true + label: + en_US: Authentication Method + zh_Hans: 鉴权方式 + type: select + default: aksk + options: + - label: + en_US: API Key + value: api_key + - label: + en_US: Access Key / Secret Access Key + value: aksk + placeholder: + en_US: Enter your Authentication Method + zh_Hans: 选择鉴权方式 - variable: volc_access_key_id required: true + show_on: + - variable: auth_method + value: aksk label: en_US: Access Key zh_Hans: Access Key @@ -41,6 +61,9 @@ model_credential_schema: zh_Hans: 输入您的 Access Key - variable: volc_secret_access_key required: true + show_on: + - variable: auth_method + value: aksk label: en_US: Secret Access Key zh_Hans: Secret Access Key @@ -48,6 +71,17 @@ model_credential_schema: placeholder: en_US: Enter your Secret Access Key zh_Hans: 输入您的 Secret Access Key + - variable: volc_api_key + required: true + show_on: + - variable: auth_method + value: api_key + label: + en_US: API Key + type: secret-input + placeholder: + en_US: Enter your API Key + zh_Hans: 输入您的 API Key - variable: volc_region required: true label: @@ -64,7 +98,7 @@ model_credential_schema: en_US: API Endpoint Host zh_Hans: API Endpoint Host type: text-input - default: maas-api.ml-platform-cn-beijing.volces.com + default: https://ark.cn-beijing.volces.com/api/v3 placeholder: en_US: Enter your API Endpoint Host zh_Hans: 输入 API Endpoint Host diff --git a/api/core/model_runtime/model_providers/xinference/tts/tts.py b/api/core/model_runtime/model_providers/xinference/tts/tts.py index bfa752df8c..8cc99fef7c 100644 --- a/api/core/model_runtime/model_providers/xinference/tts/tts.py +++ b/api/core/model_runtime/model_providers/xinference/tts/tts.py @@ -174,6 +174,11 @@ class XinferenceText2SpeechModel(TTSModel): return voices[language] elif 'all' in voices: return voices['all'] + else: + all_voices = [] + for lang, lang_voices in voices.items(): + all_voices.extend(lang_voices) + return all_voices return self.model_voices['__default']['all'] diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flash.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flash.yaml index 1b1d499ba7..0e3c001f06 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flash.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flash.yaml @@ -38,7 +38,7 @@ parameter_rules: min: 1 max: 8192 pricing: - input: '0.0001' - output: '0.0001' + input: '0' + output: '0' unit: '0.001' currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml index 5bdb442840..b0f95c0a68 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml @@ -37,3 +37,8 @@ parameter_rules: default: 1024 min: 1 max: 8192 +pricing: + input: '0.001' + output: '0.001' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml index 6b5bcc5bcf..271eecf199 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml @@ -37,3 +37,8 @@ parameter_rules: default: 1024 min: 1 max: 8192 +pricing: + input: '0.1' + output: '0.1' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_long.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_long.yaml index 9d92e58f6c..150e07b60a 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_long.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_long.yaml @@ -30,4 +30,9 @@ parameter_rules: use_template: max_tokens default: 1024 min: 1 - max: 4096 + max: 8192 +pricing: + input: '0.001' + output: '0.001' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_plus.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_plus.yaml new file mode 100644 index 0000000000..237a951cd5 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_plus.yaml @@ -0,0 +1,44 @@ +model: glm-4-plus +label: + en_US: glm-4-plus +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: chat +parameter_rules: + - name: temperature + use_template: temperature + default: 0.95 + min: 0.0 + max: 1.0 + help: + zh_Hans: 采样温度,控制输出的随机性,必须为正数取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95 值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 + en_US: Sampling temperature, controls the randomness of the output, must be a positive number. The value range is (0.0,1.0], which cannot be equal to 0. The default value is 0.95. The larger the value, the more random and creative the output will be; the smaller the value, The output will be more stable or certain. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. + - name: top_p + use_template: top_p + default: 0.7 + help: + zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 + en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. + - name: incremental + label: + zh_Hans: 增量返回 + en_US: Incremental + type: boolean + help: + zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。 + en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return. + required: false + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 8192 +pricing: + input: '0.05' + output: '0.05' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml index ddea331c8e..c7a4093d7a 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml @@ -34,4 +34,9 @@ parameter_rules: use_template: max_tokens default: 1024 min: 1 - max: 8192 + max: 1024 +pricing: + input: '0.05' + output: '0.05' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_plus.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_plus.yaml new file mode 100644 index 0000000000..a7aee5b4ca --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_plus.yaml @@ -0,0 +1,42 @@ +model: glm-4v-plus +label: + en_US: glm-4v-plus +model_type: llm +model_properties: + mode: chat +features: + - vision +parameter_rules: + - name: temperature + use_template: temperature + default: 0.95 + min: 0.0 + max: 1.0 + help: + zh_Hans: 采样温度,控制输出的随机性,必须为正数取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95 值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 + en_US: Sampling temperature, controls the randomness of the output, must be a positive number. The value range is (0.0,1.0], which cannot be equal to 0. The default value is 0.95. The larger the value, the more random and creative the output will be; the smaller the value, The output will be more stable or certain. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. + - name: top_p + use_template: top_p + default: 0.7 + help: + zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 + en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. + - name: incremental + label: + zh_Hans: 增量返回 + en_US: Incremental + type: boolean + help: + zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。 + en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return. + required: false + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 1024 +pricing: + input: '0.01' + output: '0.01' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index ff971964a8..b2cdc7ad7a 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -153,7 +153,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): :return: full response or stream response chunk generator result """ extra_model_kwargs = {} - if stop: + # request to glm-4v-plus with stop words will always response "finish_reason":"network_error" + if stop and model!= 'glm-4v-plus': extra_model_kwargs['stop'] = stop client = ZhipuAI( @@ -174,7 +175,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]: if isinstance(copy_prompt_message.content, list): # check if model is 'glm-4v' - if model != 'glm-4v': + if model not in ('glm-4v', 'glm-4v-plus'): # not support list message continue # get image and @@ -207,7 +208,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): else: new_prompt_messages.append(copy_prompt_message) - if model == 'glm-4v': + if model == 'glm-4v' or model == 'glm-4v-plus': params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters) else: params = { @@ -304,7 +305,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): return params - def _construct_glm_4v_messages(self, prompt_message: Union[str | list[PromptMessageContent]]) -> list[dict]: + def _construct_glm_4v_messages(self, prompt_message: Union[str, list[PromptMessageContent]]) -> list[dict]: if isinstance(prompt_message, str): return [{'type': 'text', 'text': prompt_message}] @@ -444,6 +445,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, + finish_reason=delta.finish_reason ) ) diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 698398e0cb..a0f3ac7f86 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -204,6 +204,7 @@ class LangFuseDataTrace(BaseTraceInstance): node_generation_data = LangfuseGeneration( name="llm", trace_id=trace_id, + model=process_data.get("model_name"), parent_observation_id=node_execution_id, start_time=created_at, end_time=finished_at, @@ -419,3 +420,11 @@ class LangFuseDataTrace(BaseTraceInstance): except Exception as e: logger.debug(f"LangFuse API check failed: {str(e)}") raise ValueError(f"LangFuse API check failed: {str(e)}") + + def get_project_key(self): + try: + projects = self.langfuse_client.client.projects.get() + return projects.data[0].id + except Exception as e: + logger.debug(f"LangFuse get project key failed: {str(e)}") + raise ValueError(f"LangFuse get project key failed: {str(e)}") diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index fde8a06c61..cc242905bd 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -139,8 +139,7 @@ class LangSmithDataTrace(BaseTraceInstance): json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {} ) node_total_tokens = execution_metadata.get("total_tokens", 0) - - metadata = json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {} + metadata = execution_metadata.copy() metadata.update( { "workflow_run_id": trace_info.workflow_run_id, @@ -156,6 +155,12 @@ class LangSmithDataTrace(BaseTraceInstance): process_data = json.loads(node_execution.process_data) if node_execution.process_data else {} if process_data and process_data.get("model_mode") == "chat": run_type = LangSmithRunType.llm + metadata.update( + { + 'ls_provider': process_data.get('model_provider', ''), + 'ls_model_name': process_data.get('model_name', ''), + } + ) elif node_type == "knowledge-retrieval": run_type = LangSmithRunType.retriever else: diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 068b490ec8..1416d6bd2d 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -38,7 +38,7 @@ provider_config_map = { TracingProviderEnum.LANGFUSE.value: { 'config_class': LangfuseConfig, 'secret_keys': ['public_key', 'secret_key'], - 'other_keys': ['host'], + 'other_keys': ['host', 'project_key'], 'trace_instance': LangFuseDataTrace }, TracingProviderEnum.LANGSMITH.value: { @@ -123,7 +123,6 @@ class OpsTraceManager: for key in other_keys: new_config[key] = decrypt_tracing_config.get(key, "") - return config_class(**new_config).model_dump() @classmethod @@ -252,6 +251,19 @@ class OpsTraceManager: tracing_config = config_type(**tracing_config) return trace_instance(tracing_config).api_check() + @staticmethod + def get_trace_config_project_key(tracing_config: dict, tracing_provider: str): + """ + get trace config is project key + :param tracing_config: tracing config + :param tracing_provider: tracing provider + :return: + """ + config_type, trace_instance = provider_config_map[tracing_provider]['config_class'], \ + provider_config_map[tracing_provider]['trace_instance'] + tracing_config = config_type(**tracing_config) + return trace_instance(tracing_config).get_project_key() + class TraceTask: def __init__( diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 65a4cada88..67eee2c294 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -5,6 +5,7 @@ from typing import Optional from sqlalchemy.exc import IntegrityError +from configs import dify_config from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle from core.entities.provider_entities import ( @@ -18,12 +19,9 @@ from core.entities.provider_entities import ( ) from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType +from core.helper.position_helper import is_filtered from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.entities.provider_entities import ( - CredentialFormSchema, - FormType, - ProviderEntity, -) +from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderEntity from core.model_runtime.model_providers import model_provider_factory from extensions import ext_hosting_provider from extensions.ext_database import db @@ -45,6 +43,7 @@ class ProviderManager: """ ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers. """ + def __init__(self) -> None: self.decoding_rsa_key = None self.decoding_cipher_rsa = None @@ -117,6 +116,16 @@ class ProviderManager: # Construct ProviderConfiguration objects for each provider for provider_entity in provider_entities: + + # handle include, exclude + if is_filtered( + include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET, + exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET, + data=provider_entity, + name_func=lambda x: x.provider, + ): + continue + provider_name = provider_entity.provider provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, []) provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, []) @@ -271,6 +280,24 @@ class ProviderManager: ) ) + def get_first_provider_first_model(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]: + """ + Get names of first model and its provider + + :param tenant_id: workspace id + :param model_type: model type + :return: provider name, model name + """ + provider_configurations = self.get_configurations(tenant_id) + + # get available models from provider_configurations + all_models = provider_configurations.get_models( + model_type=model_type, + only_active=False + ) + + return all_models[0].provider.provider, all_models[0].model + def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \ -> TenantDefaultModel: """ @@ -323,7 +350,8 @@ class ProviderManager: return default_model - def _get_all_providers(self, tenant_id: str) -> dict[str, list[Provider]]: + @staticmethod + def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]: """ Get all provider records of the workspace. @@ -342,7 +370,8 @@ class ProviderManager: return provider_name_to_provider_records_dict - def _get_all_provider_models(self, tenant_id: str) -> dict[str, list[ProviderModel]]: + @staticmethod + def _get_all_provider_models(tenant_id: str) -> dict[str, list[ProviderModel]]: """ Get all provider model records of the workspace. @@ -362,7 +391,8 @@ class ProviderManager: return provider_name_to_provider_model_records_dict - def _get_all_preferred_model_providers(self, tenant_id: str) -> dict[str, TenantPreferredModelProvider]: + @staticmethod + def _get_all_preferred_model_providers(tenant_id: str) -> dict[str, TenantPreferredModelProvider]: """ Get All preferred provider types of the workspace. @@ -381,7 +411,8 @@ class ProviderManager: return provider_name_to_preferred_provider_type_records_dict - def _get_all_provider_model_settings(self, tenant_id: str) -> dict[str, list[ProviderModelSetting]]: + @staticmethod + def _get_all_provider_model_settings(tenant_id: str) -> dict[str, list[ProviderModelSetting]]: """ Get All provider model settings of the workspace. @@ -400,7 +431,8 @@ class ProviderManager: return provider_name_to_provider_model_settings_dict - def _get_all_provider_load_balancing_configs(self, tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]: + @staticmethod + def _get_all_provider_load_balancing_configs(tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]: """ Get All provider load balancing configs of the workspace. @@ -431,7 +463,8 @@ class ProviderManager: return provider_name_to_provider_load_balancing_model_configs_dict - def _init_trial_provider_records(self, tenant_id: str, + @staticmethod + def _init_trial_provider_records(tenant_id: str, provider_name_to_provider_records_dict: dict[str, list]) -> dict[str, list]: """ Initialize trial provider records if not exists. @@ -764,7 +797,8 @@ class ProviderManager: credentials=current_using_credentials ) - def _choice_current_using_quota_type(self, quota_configurations: list[QuotaConfiguration]) -> ProviderQuotaType: + @staticmethod + def _choice_current_using_quota_type(quota_configurations: list[QuotaConfiguration]) -> ProviderQuotaType: """ Choice current using quota type. paid quotas > provider free quotas > hosting trial quotas @@ -791,7 +825,8 @@ class ProviderManager: raise ValueError('No quota type available') - def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]: + @staticmethod + def _extract_secret_variables(credential_form_schemas: list[CredentialFormSchema]) -> list[str]: """ Extract secret input form variables. diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 3932e90042..9f45771794 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -146,7 +146,7 @@ class RetrievalService: ) if documents: - if reranking_model and retrival_method == RetrievalMethod.SEMANTIC_SEARCH.value: + if reranking_model and reranking_model.get('reranking_model_name') and reranking_model.get('reranking_provider_name') and retrival_method == RetrievalMethod.SEMANTIC_SEARCH.value: data_post_processor = DataPostProcessor(str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False) @@ -180,7 +180,7 @@ class RetrievalService: top_k=top_k ) if documents: - if reranking_model and retrival_method == RetrievalMethod.FULL_TEXT_SEARCH.value: + if reranking_model and reranking_model.get('reranking_model_name') and reranking_model.get('reranking_provider_name') and retrival_method == RetrievalMethod.FULL_TEXT_SEARCH.value: data_post_processor = DataPostProcessor(str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False) diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index 01ba6fb324..233539756f 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -1,5 +1,7 @@ import json -from typing import Any +import logging +from typing import Any, Optional +from urllib.parse import urlparse import requests from elasticsearch import Elasticsearch @@ -7,16 +9,20 @@ from flask import current_app from pydantic import BaseModel, model_validator from core.rag.datasource.entity.embedding import Embeddings +from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType from core.rag.models.document import Document +from extensions.ext_redis import redis_client from models.dataset import Dataset +logger = logging.getLogger(__name__) + class ElasticSearchConfig(BaseModel): host: str - port: str + port: int username: str password: str @@ -37,12 +43,19 @@ class ElasticSearchVector(BaseVector): def __init__(self, index_name: str, config: ElasticSearchConfig, attributes: list): super().__init__(index_name.lower()) self._client = self._init_client(config) + self._version = self._get_version() + self._check_version() self._attributes = attributes def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch: try: + parsed_url = urlparse(config.host) + if parsed_url.scheme in ['http', 'https']: + hosts = f'{config.host}:{config.port}' + else: + hosts = f'http://{config.host}:{config.port}' client = Elasticsearch( - hosts=f'{config.host}:{config.port}', + hosts=hosts, basic_auth=(config.username, config.password), request_timeout=100000, retry_on_timeout=True, @@ -53,42 +66,27 @@ class ElasticSearchVector(BaseVector): return client + def _get_version(self) -> str: + info = self._client.info() + return info['version']['number'] + + def _check_version(self): + if self._version < '8.0.0': + raise ValueError("Elasticsearch vector database version must be greater than 8.0.0") + def get_type(self) -> str: return 'elasticsearch' def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): uuids = self._get_uuids(documents) - texts = [d.page_content for d in documents] - metadatas = [d.metadata for d in documents] - - if not self._client.indices.exists(index=self._collection_name): - dim = len(embeddings[0]) - mapping = { - "properties": { - "text": { - "type": "text" - }, - "vector": { - "type": "dense_vector", - "index": True, - "dims": dim, - "similarity": "l2_norm" - }, - } - } - self._client.indices.create(index=self._collection_name, mappings=mapping) - - added_ids = [] - for i, text in enumerate(texts): + for i in range(len(documents)): self._client.index(index=self._collection_name, id=uuids[i], document={ - "text": text, - "vector": embeddings[i] if embeddings[i] else None, - "metadata": metadatas[i] if metadatas[i] else {}, + Field.CONTENT_KEY.value: documents[i].page_content, + Field.VECTOR.value: embeddings[i] if embeddings[i] else None, + Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {} }) - added_ids.append(uuids[i]) - self._client.indices.refresh(index=self._collection_name) return uuids @@ -116,28 +114,21 @@ class ElasticSearchVector(BaseVector): self._client.indices.delete(index=self._collection_name) def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: - query_str = { - "query": { - "script_score": { - "query": { - "match_all": {} - }, - "script": { - "source": "cosineSimilarity(params.query_vector, 'vector') + 1.0", - "params": { - "query_vector": query_vector - } - } - } - } + top_k = kwargs.get("top_k", 10) + knn = { + "field": Field.VECTOR.value, + "query_vector": query_vector, + "k": top_k } - results = self._client.search(index=self._collection_name, body=query_str) + results = self._client.search(index=self._collection_name, knn=knn, size=top_k) docs_and_scores = [] for hit in results['hits']['hits']: docs_and_scores.append( - (Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata']), hit['_score'])) + (Document(page_content=hit['_source'][Field.CONTENT_KEY.value], + vector=hit['_source'][Field.VECTOR.value], + metadata=hit['_source'][Field.METADATA_KEY.value]), hit['_score'])) docs = [] for doc, score in docs_and_scores: @@ -146,25 +137,61 @@ class ElasticSearchVector(BaseVector): doc.metadata['score'] = score docs.append(doc) - # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True) - return docs + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: query_str = { "match": { - "text": query + Field.CONTENT_KEY.value: query } } results = self._client.search(index=self._collection_name, query=query_str) docs = [] for hit in results['hits']['hits']: - docs.append(Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata'])) + docs.append(Document( + page_content=hit['_source'][Field.CONTENT_KEY.value], + vector=hit['_source'][Field.VECTOR.value], + metadata=hit['_source'][Field.METADATA_KEY.value], + )) return docs def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): - return self.add_texts(texts, embeddings, **kwargs) + metadatas = [d.metadata for d in texts] + self.create_collection(embeddings, metadatas) + self.add_texts(texts, embeddings, **kwargs) + + def create_collection( + self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + ): + lock_name = f'vector_indexing_lock_{self._collection_name}' + with redis_client.lock(lock_name, timeout=20): + collection_exist_cache_key = f'vector_indexing_{self._collection_name}' + if redis_client.get(collection_exist_cache_key): + logger.info(f"Collection {self._collection_name} already exists.") + return + + if not self._client.indices.exists(index=self._collection_name): + dim = len(embeddings[0]) + mappings = { + "properties": { + Field.CONTENT_KEY.value: {"type": "text"}, + Field.VECTOR.value: { # Make sure the dimension is correct here + "type": "dense_vector", + "dims": dim, + "similarity": "cosine" + }, + Field.METADATA_KEY.value: { + "type": "object", + "properties": { + "doc_id": {"type": "keyword"} # Map doc_id to keyword type + } + } + } + } + self._client.indices.create(index=self._collection_name, mappings=mappings) + + redis_client.set(collection_exist_cache_key, 1, ex=3600) class ElasticSearchVectorFactory(AbstractVectorFactory): diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index 4ae1a3395b..05e75effef 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -122,7 +122,7 @@ class MyScaleVector(BaseVector): def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 5) - score_threshold = kwargs.get("score_threshold", 0.0) + score_threshold = kwargs.get('score_threshold') or 0.0 where_str = f"WHERE dist < {1 - score_threshold}" if \ self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 else "" sql = f""" diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 9535455909..2b12b8a4b2 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -281,20 +281,25 @@ class NotionExtractor(BaseExtractor): for table_header_cell_text in tabel_header_cell: text = table_header_cell_text["text"]["content"] table_header_cell_texts.append(text) - # get table columns text and format + else: + table_header_cell_texts.append('') + # Initialize Markdown table with headers + markdown_table = "| " + " | ".join(table_header_cell_texts) + " |\n" + markdown_table += "| " + " | ".join(['---'] * len(table_header_cell_texts)) + " |\n" + + # Process data to format each row in Markdown table format results = data["results"] for i in range(len(results) - 1): column_texts = [] - tabel_column_cells = data["results"][i + 1]['table_row']['cells'] - for j in range(len(tabel_column_cells)): - if tabel_column_cells[j]: - for table_column_cell_text in tabel_column_cells[j]: + table_column_cells = data["results"][i + 1]['table_row']['cells'] + for j in range(len(table_column_cells)): + if table_column_cells[j]: + for table_column_cell_text in table_column_cells[j]: column_text = table_column_cell_text["text"]["content"] - column_texts.append(f'{table_header_cell_texts[j]}:{column_text}') - - cur_result_text = "\n".join(column_texts) - result_lines_arr.append(cur_result_text) - + column_texts.append(column_text) + # Add row to Markdown table + markdown_table += "| " + " | ".join(column_texts) + " |\n" + result_lines_arr.append(markdown_table) if data["next_cursor"] is None: done = True break diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index c3f0b75cfb..15822867bb 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -170,6 +170,8 @@ class WordExtractor(BaseExtractor): if run.element.xpath('.//a:blip'): for blip in run.element.xpath('.//a:blip'): image_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed") + if not image_id: + continue image_part = paragraph.part.rels[image_id].target_part if image_part in image_map: @@ -256,6 +258,6 @@ class WordExtractor(BaseExtractor): content.append(parsed_paragraph) elif isinstance(element.tag, str) and element.tag.endswith('tbl'): # table table = tables.pop(0) - content.append(self._table_to_markdown(table,image_map)) + content.append(self._table_to_markdown(table, image_map)) return '\n'.join(content) diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index 33e78ce8c5..176d0c1ed6 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -57,7 +57,7 @@ class BaseIndexProcessor(ABC): character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( chunk_size=segmentation["max_tokens"], - chunk_overlap=segmentation.get('chunk_overlap', 0), + chunk_overlap=segmentation.get('chunk_overlap', 0) or 0, fixed_separator=separator, separators=["\n\n", "。", ". ", " ", ""], embedding_model_instance=embedding_model_instance diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index e945364796..c970e3dafa 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -614,8 +614,9 @@ class DatasetRetrieval: top_k: int, score_threshold: float) -> list[Document]: filter_documents = [] for document in all_documents: - if score_threshold and document.metadata['score'] >= score_threshold: + if score_threshold is None or document.metadata['score'] >= score_threshold: filter_documents.append(document) + if not filter_documents: return [] filter_documents = sorted(filter_documents, key=lambda x: x.metadata['score'], reverse=True) diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py index b3adcedc76..943f9918a7 100644 --- a/api/core/rag/splitter/text_splitter.py +++ b/api/core/rag/splitter/text_splitter.py @@ -30,15 +30,14 @@ def _split_text_with_regex( if keep_separator: # The parentheses in the pattern keep the delimiters in the result. _splits = re.split(f"({re.escape(separator)})", text) - splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)] - if len(_splits) % 2 == 0: + splits = [_splits[i - 1] + _splits[i] for i in range(1, len(_splits), 2)] + if len(_splits) % 2 != 0: splits += _splits[-1:] - splits = [_splits[0]] + splits else: splits = re.split(separator, text) else: splits = list(text) - return [s for s in splits if s != ""] + return [s for s in splits if (s != "" and s != '\n')] class TextSplitter(BaseDocumentTransformer, ABC): @@ -109,7 +108,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): else: return text - def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]: + def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int]) -> list[str]: # We now want to combine these smaller pieces into medium size # chunks to send to the LLM. separator_len = self._length_function(separator) @@ -117,8 +116,9 @@ class TextSplitter(BaseDocumentTransformer, ABC): docs = [] current_doc: list[str] = [] total = 0 + index = 0 for d in splits: - _len = self._length_function(d) + _len = lengths[index] if ( total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size @@ -146,6 +146,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): current_doc = current_doc[1:] current_doc.append(d) total += _len + (separator_len if len(current_doc) > 1 else 0) + index += 1 doc = self._join_docs(current_doc, separator) if doc is not None: docs.append(doc) @@ -494,11 +495,10 @@ class RecursiveCharacterTextSplitter(TextSplitter): self._separators = separators or ["\n\n", "\n", " ", ""] def _split_text(self, text: str, separators: list[str]) -> list[str]: - """Split incoming text and return chunks.""" final_chunks = [] - # Get appropriate separator to use separator = separators[-1] new_separators = [] + for i, _s in enumerate(separators): if _s == "": separator = _s @@ -509,25 +509,31 @@ class RecursiveCharacterTextSplitter(TextSplitter): break splits = _split_text_with_regex(text, separator, self._keep_separator) - # Now go merging things, recursively splitting longer texts. _good_splits = [] + _good_splits_lengths = [] # cache the lengths of the splits _separator = "" if self._keep_separator else separator + for s in splits: - if self._length_function(s) < self._chunk_size: + s_len = self._length_function(s) + if s_len < self._chunk_size: _good_splits.append(s) + _good_splits_lengths.append(s_len) else: if _good_splits: - merged_text = self._merge_splits(_good_splits, _separator) + merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths) final_chunks.extend(merged_text) _good_splits = [] + _good_splits_lengths = [] if not new_separators: final_chunks.append(s) else: other_info = self._split_text(s, new_separators) final_chunks.extend(other_info) + if _good_splits: - merged_text = self._merge_splits(_good_splits, _separator) + merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths) final_chunks.extend(merged_text) + return final_chunks def split_text(self, text: str) -> list[str]: diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 2e4433d9f6..e31dec55d2 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -148,7 +148,7 @@ class ToolParameter(BaseModel): form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm") llm_description: Optional[str] = None required: Optional[bool] = False - default: Optional[Union[int, str]] = None + default: Optional[Union[float, int, str]] = None min: Optional[Union[float, int]] = None max: Optional[Union[float, int]] = None options: Optional[list[ToolParameterOption]] = None diff --git a/api/core/tools/provider/_position.yaml b/api/core/tools/provider/_position.yaml index 25d9f403a0..40c3356116 100644 --- a/api/core/tools/provider/_position.yaml +++ b/api/core/tools/provider/_position.yaml @@ -1,5 +1,6 @@ - google - bing +- perplexity - duckduckgo - searchapi - serper @@ -10,6 +11,7 @@ - wikipedia - nominatim - yahoo +- alphavantage - arxiv - pubmed - stablediffusion @@ -30,5 +32,7 @@ - dingtalk - feishu - feishu_base +- feishu_document +- feishu_message - slack - tianditu diff --git a/api/core/tools/provider/builtin/_positions.py b/api/core/tools/provider/builtin/_positions.py index ae806eaff4..062668fc5b 100644 --- a/api/core/tools/provider/builtin/_positions.py +++ b/api/core/tools/provider/builtin/_positions.py @@ -1,6 +1,6 @@ import os.path -from core.helper.position_helper import get_position_map, sort_by_position_map +from core.helper.position_helper import get_tool_position_map, sort_by_position_map from core.tools.entities.api_entities import UserToolProvider @@ -10,11 +10,11 @@ class BuiltinToolProviderSort: @classmethod def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]: if not cls._position: - cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..')) + cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), '..')) def name_func(provider: UserToolProvider) -> str: return provider.name sorted_providers = sort_by_position_map(cls._position, providers, name_func) - return sorted_providers \ No newline at end of file + return sorted_providers diff --git a/api/core/tools/provider/builtin/alphavantage/_assets/icon.svg b/api/core/tools/provider/builtin/alphavantage/_assets/icon.svg new file mode 100644 index 0000000000..785432943b --- /dev/null +++ b/api/core/tools/provider/builtin/alphavantage/_assets/icon.svg @@ -0,0 +1,7 @@ + + + 形状结合 + + + + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/alphavantage/alphavantage.py b/api/core/tools/provider/builtin/alphavantage/alphavantage.py new file mode 100644 index 0000000000..01f2acfb5b --- /dev/null +++ b/api/core/tools/provider/builtin/alphavantage/alphavantage.py @@ -0,0 +1,22 @@ +from typing import Any + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.alphavantage.tools.query_stock import QueryStockTool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class AlphaVantageProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + try: + QueryStockTool().fork_tool_runtime( + runtime={ + "credentials": credentials, + } + ).invoke( + user_id='', + tool_parameters={ + "code": "AAPL", # Apple Inc. + }, + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/alphavantage/alphavantage.yaml b/api/core/tools/provider/builtin/alphavantage/alphavantage.yaml new file mode 100644 index 0000000000..710510cfd8 --- /dev/null +++ b/api/core/tools/provider/builtin/alphavantage/alphavantage.yaml @@ -0,0 +1,31 @@ +identity: + author: zhuhao + name: alphavantage + label: + en_US: AlphaVantage + zh_Hans: AlphaVantage + pt_BR: AlphaVantage + description: + en_US: AlphaVantage is an online platform that provides financial market data and APIs, making it convenient for individual investors and developers to access stock quotes, technical indicators, and stock analysis. + zh_Hans: AlphaVantage是一个在线平台,它提供金融市场数据和API,便于个人投资者和开发者获取股票报价、技术指标和股票分析。 + pt_BR: AlphaVantage is an online platform that provides financial market data and APIs, making it convenient for individual investors and developers to access stock quotes, technical indicators, and stock analysis. + icon: icon.svg + tags: + - finance +credentials_for_provider: + api_key: + type: secret-input + required: true + label: + en_US: AlphaVantage API key + zh_Hans: AlphaVantage API key + pt_BR: AlphaVantage API key + placeholder: + en_US: Please input your AlphaVantage API key + zh_Hans: 请输入你的 AlphaVantage API key + pt_BR: Please input your AlphaVantage API key + help: + en_US: Get your AlphaVantage API key from AlphaVantage + zh_Hans: 从 AlphaVantage 获取您的 AlphaVantage API key + pt_BR: Get your AlphaVantage API key from AlphaVantage + url: https://www.alphavantage.co/support/#api-key diff --git a/api/core/tools/provider/builtin/alphavantage/tools/query_stock.py b/api/core/tools/provider/builtin/alphavantage/tools/query_stock.py new file mode 100644 index 0000000000..5c379b746d --- /dev/null +++ b/api/core/tools/provider/builtin/alphavantage/tools/query_stock.py @@ -0,0 +1,49 @@ +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +ALPHAVANTAGE_API_URL = "https://www.alphavantage.co/query" + + +class QueryStockTool(BuiltinTool): + + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + + stock_code = tool_parameters.get('code', '') + if not stock_code: + return self.create_text_message('Please tell me your stock code') + + if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'): + return self.create_text_message("Alpha Vantage API key is required.") + + params = { + "function": "TIME_SERIES_DAILY", + "symbol": stock_code, + "outputsize": "compact", + "datatype": "json", + "apikey": self.runtime.credentials['api_key'] + } + response = requests.get(url=ALPHAVANTAGE_API_URL, params=params) + response.raise_for_status() + result = self._handle_response(response.json()) + return self.create_json_message(result) + + def _handle_response(self, response: dict[str, Any]) -> dict[str, Any]: + result = response.get('Time Series (Daily)', {}) + if not result: + return {} + stock_result = {} + for k, v in result.items(): + stock_result[k] = {} + stock_result[k]['open'] = v.get('1. open') + stock_result[k]['high'] = v.get('2. high') + stock_result[k]['low'] = v.get('3. low') + stock_result[k]['close'] = v.get('4. close') + stock_result[k]['volume'] = v.get('5. volume') + return stock_result diff --git a/api/core/tools/provider/builtin/alphavantage/tools/query_stock.yaml b/api/core/tools/provider/builtin/alphavantage/tools/query_stock.yaml new file mode 100644 index 0000000000..d89f34e373 --- /dev/null +++ b/api/core/tools/provider/builtin/alphavantage/tools/query_stock.yaml @@ -0,0 +1,27 @@ +identity: + name: query_stock + author: zhuhao + label: + en_US: query_stock + zh_Hans: query_stock + pt_BR: query_stock +description: + human: + en_US: Retrieve information such as daily opening price, daily highest price, daily lowest price, daily closing price, and daily trading volume for a specified stock symbol. + zh_Hans: 获取指定股票代码的每日开盘价、每日最高价、每日最低价、每日收盘价和每日交易量等信息。 + pt_BR: Retrieve information such as daily opening price, daily highest price, daily lowest price, daily closing price, and daily trading volume for a specified stock symbol + llm: Retrieve information such as daily opening price, daily highest price, daily lowest price, daily closing price, and daily trading volume for a specified stock symbol +parameters: + - name: code + type: string + required: true + label: + en_US: stock code + zh_Hans: 股票代码 + pt_BR: stock code + human_description: + en_US: stock code + zh_Hans: 股票代码 + pt_BR: stock code + llm_description: stock code for query from alphavantage + form: llm diff --git a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.yaml b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.yaml index 63a8c99d97..e256748e8f 100644 --- a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.yaml +++ b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.yaml @@ -25,7 +25,7 @@ parameters: pt_BR: Prompt human_description: en_US: Image prompt, you can check the official documentation of DallE 3 - zh_Hans: 图像提示词,您可以查看DallE 3 的官方文档 + zh_Hans: 图像提示词,您可以查看 DallE 3 的官方文档 pt_BR: Imagem prompt, você pode verificar a documentação oficial do DallE 3 llm_description: Image prompt of DallE 3, you should describe the image you want to generate as a list of words as possible as detailed form: llm diff --git a/api/core/tools/provider/builtin/cogview/tools/cogview3.yaml b/api/core/tools/provider/builtin/cogview/tools/cogview3.yaml index ba0b271a1c..1de3f599b6 100644 --- a/api/core/tools/provider/builtin/cogview/tools/cogview3.yaml +++ b/api/core/tools/provider/builtin/cogview/tools/cogview3.yaml @@ -25,7 +25,7 @@ parameters: pt_BR: Prompt human_description: en_US: Image prompt, you can check the official documentation of CogView 3 - zh_Hans: 图像提示词,您可以查看CogView 3 的官方文档 + zh_Hans: 图像提示词,您可以查看 CogView 3 的官方文档 pt_BR: Image prompt, you can check the official documentation of CogView 3 llm_description: Image prompt of CogView 3, you should describe the image you want to generate as a list of words as possible as detailed form: llm diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle2.yaml b/api/core/tools/provider/builtin/dalle/tools/dalle2.yaml index 90c73ecc57..e43e5df8cd 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle2.yaml +++ b/api/core/tools/provider/builtin/dalle/tools/dalle2.yaml @@ -24,7 +24,7 @@ parameters: pt_BR: Prompt human_description: en_US: Image prompt, you can check the official documentation of DallE 2 - zh_Hans: 图像提示词,您可以查看DallE 2 的官方文档 + zh_Hans: 图像提示词,您可以查看 DallE 2 的官方文档 pt_BR: Image prompt, you can check the official documentation of DallE 2 llm_description: Image prompt of DallE 2, you should describe the image you want to generate as a list of words as possible as detailed form: llm diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle3.yaml b/api/core/tools/provider/builtin/dalle/tools/dalle3.yaml index 7ba5c56889..0cea8af761 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle3.yaml +++ b/api/core/tools/provider/builtin/dalle/tools/dalle3.yaml @@ -25,7 +25,7 @@ parameters: pt_BR: Prompt human_description: en_US: Image prompt, you can check the official documentation of DallE 3 - zh_Hans: 图像提示词,您可以查看DallE 3 的官方文档 + zh_Hans: 图像提示词,您可以查看 DallE 3 的官方文档 pt_BR: Image prompt, you can check the official documentation of DallE 3 llm_description: Image prompt of DallE 3, you should describe the image you want to generate as a list of words as possible as detailed form: llm diff --git a/api/core/tools/provider/builtin/feishu_document/_assets/icon.svg b/api/core/tools/provider/builtin/feishu_document/_assets/icon.svg new file mode 100644 index 0000000000..5a0a6416b3 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/_assets/icon.svg @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/api/core/tools/provider/builtin/feishu_document/feishu_document.py b/api/core/tools/provider/builtin/feishu_document/feishu_document.py new file mode 100644 index 0000000000..c4f8f26e2c --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/feishu_document.py @@ -0,0 +1,15 @@ +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class FeishuDocumentProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + app_id = credentials.get('app_id') + app_secret = credentials.get('app_secret') + if not app_id or not app_secret: + raise ToolProviderCredentialValidationError("app_id and app_secret is required") + try: + assert FeishuRequest(app_id, app_secret).tenant_access_token is not None + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/feishu_document/feishu_document.yaml b/api/core/tools/provider/builtin/feishu_document/feishu_document.yaml new file mode 100644 index 0000000000..8eaa6b2704 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/feishu_document.yaml @@ -0,0 +1,34 @@ +identity: + author: Doug Lea + name: feishu_document + label: + en_US: Lark Cloud Document + zh_Hans: 飞书云文档 + description: + en_US: Lark Cloud Document + zh_Hans: 飞书云文档 + icon: icon.svg + tags: + - social + - productivity +credentials_for_provider: + app_id: + type: text-input + required: true + label: + en_US: APP ID + placeholder: + en_US: Please input your feishu app id + zh_Hans: 请输入你的飞书 app id + help: + en_US: Get your app_id and app_secret from Feishu + zh_Hans: 从飞书获取您的 app_id 和 app_secret + url: https://open.feishu.cn + app_secret: + type: secret-input + required: true + label: + en_US: APP Secret + placeholder: + en_US: Please input your app secret + zh_Hans: 请输入你的飞书 app secret diff --git a/api/core/tools/provider/builtin/feishu_document/tools/create_document.py b/api/core/tools/provider/builtin/feishu_document/tools/create_document.py new file mode 100644 index 0000000000..0ff82e621b --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/tools/create_document.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class CreateDocumentTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get('app_id') + app_secret = self.runtime.credentials.get('app_secret') + client = FeishuRequest(app_id, app_secret) + + title = tool_parameters.get('title') + content = tool_parameters.get('content') + folder_token = tool_parameters.get('folder_token') + + res = client.create_document(title, content, folder_token) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/create_document.yaml b/api/core/tools/provider/builtin/feishu_document/tools/create_document.yaml new file mode 100644 index 0000000000..ddf2729f0e --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/tools/create_document.yaml @@ -0,0 +1,47 @@ +identity: + name: create_document + author: Doug Lea + label: + en_US: Create Lark document + zh_Hans: 创建飞书文档 +description: + human: + en_US: Create Lark document + zh_Hans: 创建飞书文档,支持创建空文档和带内容的文档,支持 markdown 语法创建。 + llm: A tool for creating Feishu documents. +parameters: + - name: title + type: string + required: false + label: + en_US: Document title + zh_Hans: 文档标题 + human_description: + en_US: Document title, only supports plain text content. + zh_Hans: 文档标题,只支持纯文本内容。 + llm_description: 文档标题,只支持纯文本内容,可以为空。 + form: llm + + - name: content + type: string + required: false + label: + en_US: Document content + zh_Hans: 文档内容 + human_description: + en_US: Document content, supports markdown syntax, can be empty. + zh_Hans: 文档内容,支持 markdown 语法,可以为空。 + llm_description: 文档内容,支持 markdown 语法,可以为空。 + form: llm + + - name: folder_token + type: string + required: false + label: + en_US: folder_token + zh_Hans: 文档所在文件夹的 Token + human_description: + en_US: The token of the folder where the document is located. If it is not passed or is empty, it means the root directory. + zh_Hans: 文档所在文件夹的 Token,不传或传空表示根目录。 + llm_description: 文档所在文件夹的 Token,不传或传空表示根目录。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.py b/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.py new file mode 100644 index 0000000000..16ef90908b --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.py @@ -0,0 +1,17 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class GetDocumentRawContentTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get('app_id') + app_secret = self.runtime.credentials.get('app_secret') + client = FeishuRequest(app_id, app_secret) + + document_id = tool_parameters.get('document_id') + + res = client.get_document_raw_content(document_id) + return self.create_json_message(res) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.yaml b/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.yaml new file mode 100644 index 0000000000..e5b0937e03 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.yaml @@ -0,0 +1,23 @@ +identity: + name: get_document_raw_content + author: Doug Lea + label: + en_US: Get Document Raw Content + zh_Hans: 获取文档纯文本内容 +description: + human: + en_US: Get document raw content + zh_Hans: 获取文档纯文本内容 + llm: A tool for getting the plain text content of Feishu documents +parameters: + - name: document_id + type: string + required: true + label: + en_US: document_id + zh_Hans: 飞书文档的唯一标识 + human_description: + en_US: Unique ID of Feishu document document_id + zh_Hans: 飞书文档的唯一标识 document_id + llm_description: 飞书文档的唯一标识 document_id + form: llm diff --git a/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.py b/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.py new file mode 100644 index 0000000000..97d17bdb04 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class ListDocumentBlockTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get('app_id') + app_secret = self.runtime.credentials.get('app_secret') + client = FeishuRequest(app_id, app_secret) + + document_id = tool_parameters.get('document_id') + page_size = tool_parameters.get('page_size', 500) + page_token = tool_parameters.get('page_token', '') + + res = client.list_document_block(document_id, page_token, page_size) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.yaml b/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.yaml new file mode 100644 index 0000000000..d51e5a837c --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.yaml @@ -0,0 +1,48 @@ +identity: + name: list_document_block + author: Doug Lea + label: + en_US: List Document Block + zh_Hans: 获取飞书文档所有块 +description: + human: + en_US: List document block + zh_Hans: 获取飞书文档所有块的富文本内容并分页返回。 + llm: A tool to get all blocks of Feishu documents +parameters: + - name: document_id + type: string + required: true + label: + en_US: document_id + zh_Hans: 飞书文档的唯一标识 + human_description: + en_US: Unique ID of Feishu document document_id + zh_Hans: 飞书文档的唯一标识 document_id + llm_description: 飞书文档的唯一标识 document_id + form: llm + + - name: page_size + type: number + required: false + default: 500 + label: + en_US: page_size + zh_Hans: 分页大小 + human_description: + en_US: Paging size, the default and maximum value is 500. + zh_Hans: 分页大小, 默认值和最大值为 500。 + llm_description: 分页大小, 表示一次请求最多返回多少条数据,默认值和最大值为 500。 + form: llm + + - name: page_token + type: string + required: false + label: + en_US: page_token + zh_Hans: 分页标记 + human_description: + en_US: Pagination tag, used to paginate query results so that more items can be obtained in the next traversal. + zh_Hans: 分页标记,用于分页查询结果,以便下次遍历时获取更多项。 + llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_document/tools/write_document.py b/api/core/tools/provider/builtin/feishu_document/tools/write_document.py new file mode 100644 index 0000000000..914a44dce6 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/tools/write_document.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class CreateDocumentTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get('app_id') + app_secret = self.runtime.credentials.get('app_secret') + client = FeishuRequest(app_id, app_secret) + + document_id = tool_parameters.get('document_id') + content = tool_parameters.get('content') + position = tool_parameters.get('position') + + res = client.write_document(document_id, content, position) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/write_document.yaml b/api/core/tools/provider/builtin/feishu_document/tools/write_document.yaml new file mode 100644 index 0000000000..8ee219d4a7 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/tools/write_document.yaml @@ -0,0 +1,56 @@ +identity: + name: write_document + author: Doug Lea + label: + en_US: Write Document + zh_Hans: 在飞书文档中新增内容 +description: + human: + en_US: Adding new content to Lark documents + zh_Hans: 在飞书文档中新增内容 + llm: A tool for adding new content to Lark documents. +parameters: + - name: document_id + type: string + required: true + label: + en_US: document_id + zh_Hans: 飞书文档的唯一标识 + human_description: + en_US: Unique ID of Feishu document document_id + zh_Hans: 飞书文档的唯一标识 document_id + llm_description: 飞书文档的唯一标识 document_id + form: llm + + - name: content + type: string + required: true + label: + en_US: document content + zh_Hans: 文档内容 + human_description: + en_US: Document content, supports markdown syntax, can be empty. + zh_Hans: 文档内容,支持 markdown 语法,可以为空。 + llm_description: + form: llm + + - name: position + type: select + required: true + default: start + label: + en_US: Choose where to add content + zh_Hans: 选择添加内容的位置 + human_description: + en_US: Please fill in start or end to add content at the beginning or end of the document respectively. + zh_Hans: 请填入 start 或 end, 分别表示在文档开头(start)或结尾(end)添加内容。 + form: llm + options: + - value: start + label: + en_US: start + zh_Hans: 在文档开头添加内容 + - value: end + label: + en_US: end + zh_Hans: 在文档结尾添加内容 diff --git a/api/core/tools/provider/builtin/feishu_message/_assets/icon.svg b/api/core/tools/provider/builtin/feishu_message/_assets/icon.svg new file mode 100644 index 0000000000..222a1571f9 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_message/_assets/icon.svg @@ -0,0 +1,19 @@ + + + + diff --git a/api/core/tools/provider/builtin/feishu_message/feishu_message.py b/api/core/tools/provider/builtin/feishu_message/feishu_message.py new file mode 100644 index 0000000000..6d7fed330c --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_message/feishu_message.py @@ -0,0 +1,15 @@ +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class FeishuMessageProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + app_id = credentials.get('app_id') + app_secret = credentials.get('app_secret') + if not app_id or not app_secret: + raise ToolProviderCredentialValidationError("app_id and app_secret is required") + try: + assert FeishuRequest(app_id, app_secret).tenant_access_token is not None + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/feishu_message/feishu_message.yaml b/api/core/tools/provider/builtin/feishu_message/feishu_message.yaml new file mode 100644 index 0000000000..1bd8953ddd --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_message/feishu_message.yaml @@ -0,0 +1,34 @@ +identity: + author: Doug Lea + name: feishu_message + label: + en_US: Lark Message + zh_Hans: 飞书消息 + description: + en_US: Lark Message + zh_Hans: 飞书消息 + icon: icon.svg + tags: + - social + - productivity +credentials_for_provider: + app_id: + type: text-input + required: true + label: + en_US: APP ID + placeholder: + en_US: Please input your feishu app id + zh_Hans: 请输入你的飞书 app id + help: + en_US: Get your app_id and app_secret from Feishu + zh_Hans: 从飞书获取您的 app_id 和 app_secret + url: https://open.feishu.cn + app_secret: + type: secret-input + required: true + label: + en_US: APP Secret + placeholder: + en_US: Please input your app secret + zh_Hans: 请输入你的飞书 app secret diff --git a/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.py b/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.py new file mode 100644 index 0000000000..74f6866ba3 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.py @@ -0,0 +1,20 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class SendBotMessageTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get('app_id') + app_secret = self.runtime.credentials.get('app_secret') + client = FeishuRequest(app_id, app_secret) + + receive_id_type = tool_parameters.get('receive_id_type') + receive_id = tool_parameters.get('receive_id') + msg_type = tool_parameters.get('msg_type') + content = tool_parameters.get('content') + + res = client.send_bot_message(receive_id_type, receive_id, msg_type, content) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.yaml b/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.yaml new file mode 100644 index 0000000000..6e398b18ab --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.yaml @@ -0,0 +1,91 @@ +identity: + name: send_bot_message + author: Doug Lea + label: + en_US: Send Bot Message + zh_Hans: 发送飞书应用消息 +description: + human: + en_US: Send bot message + zh_Hans: 发送飞书应用消息 + llm: A tool for sending Feishu application messages. +parameters: + - name: receive_id_type + type: select + required: true + options: + - value: open_id + label: + en_US: open id + zh_Hans: open id + - value: union_id + label: + en_US: union id + zh_Hans: union id + - value: user_id + label: + en_US: user id + zh_Hans: user id + - value: email + label: + en_US: email + zh_Hans: email + - value: chat_id + label: + en_US: chat id + zh_Hans: chat id + label: + en_US: User ID Type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID Type + zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id、email、chat_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id、email、chat_id。 + form: llm + + - name: receive_id + type: string + required: true + label: + en_US: Receive Id + zh_Hans: 消息接收者的 ID + human_description: + en_US: The ID of the message receiver. The ID type should correspond to the query parameter receive_id_type. + zh_Hans: 消息接收者的 ID,ID 类型应与查询参数 receive_id_type 对应。 + llm_description: 消息接收者的 ID,ID 类型应与查询参数 receive_id_type 对应。 + form: llm + + - name: msg_type + type: string + required: true + options: + - value: text + label: + en_US: text + zh_Hans: 文本 + - value: interactive + label: + en_US: message card + zh_Hans: 消息卡片 + label: + en_US: Message type + zh_Hans: 消息类型 + human_description: + en_US: Message type, optional values are, text (text), interactive (message card). + zh_Hans: 消息类型,可选值有:text(文本)、interactive(消息卡片)。 + llm_description: 消息类型,可选值有:text(文本)、interactive(消息卡片)。 + form: llm + + - name: content + type: string + required: true + label: + en_US: Message content + zh_Hans: 消息内容 + human_description: + en_US: Message content + zh_Hans: | + 消息内容,JSON 结构序列化后的字符串。不同 msg_type 对应不同内容, + 具体格式说明参考:https://open.larkoffice.com/document/server-docs/im-v1/message-content-description/create_json + llm_description: 消息内容,JSON 结构序列化后的字符串。不同 msg_type 对应不同内容。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.py b/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.py new file mode 100644 index 0000000000..7159f59ffa --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class SendWebhookMessageTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) ->ToolInvokeMessage: + app_id = self.runtime.credentials.get('app_id') + app_secret = self.runtime.credentials.get('app_secret') + client = FeishuRequest(app_id, app_secret) + + webhook = tool_parameters.get('webhook') + msg_type = tool_parameters.get('msg_type') + content = tool_parameters.get('content') + + res = client.send_webhook_message(webhook, msg_type, content) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.yaml b/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.yaml new file mode 100644 index 0000000000..8b39ce4874 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.yaml @@ -0,0 +1,58 @@ +identity: + name: send_webhook_message + author: Doug Lea + label: + en_US: Send Webhook Message + zh_Hans: 使用自定义机器人发送飞书消息 +description: + human: + en_US: Send webhook message + zh_Hans: 使用自定义机器人发送飞书消息 + llm: A tool for sending Lark messages using a custom robot. +parameters: + - name: webhook + type: string + required: true + label: + en_US: webhook + zh_Hans: webhook 的地址 + human_description: + en_US: The address of the webhook + zh_Hans: webhook 的地址 + llm_description: webhook 的地址 + form: llm + + - name: msg_type + type: string + required: true + options: + - value: text + label: + en_US: text + zh_Hans: 文本 + - value: interactive + label: + en_US: message card + zh_Hans: 消息卡片 + label: + en_US: Message type + zh_Hans: 消息类型 + human_description: + en_US: Message type, optional values are, text (text), interactive (message card). + zh_Hans: 消息类型,可选值有:text(文本)、interactive(消息卡片)。 + llm_description: 消息类型,可选值有:text(文本)、interactive(消息卡片)。 + form: llm + + - name: content + type: string + required: true + label: + en_US: Message content + zh_Hans: 消息内容 + human_description: + en_US: Message content + zh_Hans: | + 消息内容,JSON 结构序列化后的字符串。不同 msg_type 对应不同内容, + 具体格式说明参考:https://open.larkoffice.com/document/server-docs/im-v1/message-content-description/create_json + llm_description: 消息内容,JSON 结构序列化后的字符串。不同 msg_type 对应不同内容。 + form: llm diff --git a/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py b/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py new file mode 100644 index 0000000000..b753be4791 --- /dev/null +++ b/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py @@ -0,0 +1,73 @@ +from novita_client import ( + Txt2ImgV3Embedding, + Txt2ImgV3HiresFix, + Txt2ImgV3LoRA, + Txt2ImgV3Refiner, + V3TaskImage, +) + + +class NovitaAiToolBase: + def _extract_loras(self, loras_str: str): + if not loras_str: + return [] + + loras_ori_list = lora_str.strip().split(';') + result_list = [] + for lora_str in loras_ori_list: + lora_info = lora_str.strip().split(',') + lora = Txt2ImgV3LoRA( + model_name=lora_info[0].strip(), + strength=float(lora_info[1]), + ) + result_list.append(lora) + + return result_list + + def _extract_embeddings(self, embeddings_str: str): + if not embeddings_str: + return [] + + embeddings_ori_list = embeddings_str.strip().split(';') + result_list = [] + for embedding_str in embeddings_ori_list: + embedding = Txt2ImgV3Embedding( + model_name=embedding_str.strip() + ) + result_list.append(embedding) + + return result_list + + def _extract_hires_fix(self, hires_fix_str: str): + hires_fix_info = hires_fix_str.strip().split(',') + if 'upscaler' in hires_fix_info: + hires_fix = Txt2ImgV3HiresFix( + target_width=int(hires_fix_info[0]), + target_height=int(hires_fix_info[1]), + strength=float(hires_fix_info[2]), + upscaler=hires_fix_info[3].strip() + ) + else: + hires_fix = Txt2ImgV3HiresFix( + target_width=int(hires_fix_info[0]), + target_height=int(hires_fix_info[1]), + strength=float(hires_fix_info[2]) + ) + + return hires_fix + + def _extract_refiner(self, switch_at: str): + refiner = Txt2ImgV3Refiner( + switch_at=float(switch_at) + ) + return refiner + + def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool: + """ + is hit nsfw + """ + if image.nsfw_detection_result is None: + return False + if image.nsfw_detection_result.valid and image.nsfw_detection_result.confidence >= confidence_threshold: + return True + return False diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py index c9524d6a66..5fef3d2da7 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py @@ -4,19 +4,15 @@ from typing import Any, Union from novita_client import ( NovitaClient, - Txt2ImgV3Embedding, - Txt2ImgV3HiresFix, - Txt2ImgV3LoRA, - Txt2ImgV3Refiner, - V3TaskImage, ) from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.novitaai._novita_tool_base import NovitaAiToolBase from core.tools.tool.builtin_tool import BuiltinTool -class NovitaAiTxt2ImgTool(BuiltinTool): +class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase): def _invoke(self, user_id: str, tool_parameters: dict[str, Any], @@ -73,65 +69,19 @@ class NovitaAiTxt2ImgTool(BuiltinTool): # process loras if 'loras' in res_parameters: - loras_ori_list = res_parameters.get('loras').strip().split(';') - locals_list = [] - for lora_str in loras_ori_list: - lora_info = lora_str.strip().split(',') - lora = Txt2ImgV3LoRA( - model_name=lora_info[0].strip(), - strength=float(lora_info[1]), - ) - locals_list.append(lora) - - res_parameters['loras'] = locals_list + res_parameters['loras'] = self._extract_loras(res_parameters.get('loras')) # process embeddings if 'embeddings' in res_parameters: - embeddings_ori_list = res_parameters.get('embeddings').strip().split(';') - locals_list = [] - for embedding_str in embeddings_ori_list: - embedding = Txt2ImgV3Embedding( - model_name=embedding_str.strip() - ) - locals_list.append(embedding) - - res_parameters['embeddings'] = locals_list + res_parameters['embeddings'] = self._extract_embeddings(res_parameters.get('embeddings')) # process hires_fix if 'hires_fix' in res_parameters: - hires_fix_ori = res_parameters.get('hires_fix') - hires_fix_info = hires_fix_ori.strip().split(',') - if 'upscaler' in hires_fix_info: - hires_fix = Txt2ImgV3HiresFix( - target_width=int(hires_fix_info[0]), - target_height=int(hires_fix_info[1]), - strength=float(hires_fix_info[2]), - upscaler=hires_fix_info[3].strip() - ) - else: - hires_fix = Txt2ImgV3HiresFix( - target_width=int(hires_fix_info[0]), - target_height=int(hires_fix_info[1]), - strength=float(hires_fix_info[2]) - ) + res_parameters['hires_fix'] = self._extract_hires_fix(res_parameters.get('hires_fix')) - res_parameters['hires_fix'] = hires_fix - - if 'refiner_switch_at' in res_parameters: - refiner = Txt2ImgV3Refiner( - switch_at=float(res_parameters.get('refiner_switch_at')) - ) - del res_parameters['refiner_switch_at'] - res_parameters['refiner'] = refiner + # process refiner + if 'refiner_switch_at' in res_parameters: + res_parameters['refiner'] = self._extract_refiner(res_parameters.get('refiner_switch_at')) + del res_parameters['refiner_switch_at'] return res_parameters - - def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool: - """ - is hit nsfw - """ - if image.nsfw_detection_result is None: - return False - if image.nsfw_detection_result.valid and image.nsfw_detection_result.confidence >= confidence_threshold: - return True - return False diff --git a/api/core/tools/provider/builtin/onebot/_assets/icon.ico b/api/core/tools/provider/builtin/onebot/_assets/icon.ico new file mode 100644 index 0000000000..1b07e965b9 Binary files /dev/null and b/api/core/tools/provider/builtin/onebot/_assets/icon.ico differ diff --git a/api/core/tools/provider/builtin/onebot/onebot.py b/api/core/tools/provider/builtin/onebot/onebot.py new file mode 100644 index 0000000000..42f321e919 --- /dev/null +++ b/api/core/tools/provider/builtin/onebot/onebot.py @@ -0,0 +1,12 @@ +from typing import Any + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class OneBotProvider(BuiltinToolProviderController): + + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + + if not credentials.get("ob11_http_url"): + raise ToolProviderCredentialValidationError('OneBot HTTP URL is required.') diff --git a/api/core/tools/provider/builtin/onebot/onebot.yaml b/api/core/tools/provider/builtin/onebot/onebot.yaml new file mode 100644 index 0000000000..1922adc4de --- /dev/null +++ b/api/core/tools/provider/builtin/onebot/onebot.yaml @@ -0,0 +1,35 @@ +identity: + author: RockChinQ + name: onebot + label: + en_US: OneBot v11 Protocol + zh_Hans: OneBot v11 协议 + description: + en_US: Unofficial OneBot v11 Protocol Tool + zh_Hans: 非官方 OneBot v11 协议工具 + icon: icon.ico +credentials_for_provider: + ob11_http_url: + type: text-input + required: true + label: + en_US: HTTP URL + zh_Hans: HTTP URL + description: + en_US: Forward HTTP URL of OneBot v11 + zh_Hans: OneBot v11 正向 HTTP URL + help: + en_US: Fill this with the HTTP URL of your OneBot server + zh_Hans: 请在你的 OneBot 协议端开启 正向 HTTP 并填写其 URL + access_token: + type: secret-input + required: false + label: + en_US: Access Token + zh_Hans: 访问令牌 + description: + en_US: Access Token for OneBot v11 Protocol + zh_Hans: OneBot 协议访问令牌 + help: + en_US: Fill this if you set a access token in your OneBot server + zh_Hans: 如果你在 OneBot 服务器中设置了 access token,请填写此项 diff --git a/api/core/tools/provider/builtin/onebot/tools/__init__.py b/api/core/tools/provider/builtin/onebot/tools/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py b/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py new file mode 100644 index 0000000000..2a1a9f86de --- /dev/null +++ b/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py @@ -0,0 +1,64 @@ +from typing import Any, Union + +import requests +from yarl import URL + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class SendGroupMsg(BuiltinTool): + """OneBot v11 Tool: Send Group Message""" + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + + # Get parameters + send_group_id = tool_parameters.get('group_id', '') + + message = tool_parameters.get('message', '') + if not message: + return self.create_json_message( + { + 'error': 'Message is empty.' + } + ) + + auto_escape = tool_parameters.get('auto_escape', False) + + try: + url = URL(self.runtime.credentials['ob11_http_url']) / 'send_group_msg' + + resp = requests.post( + url, + json={ + 'group_id': send_group_id, + 'message': message, + 'auto_escape': auto_escape + }, + headers={ + 'Authorization': 'Bearer ' + self.runtime.credentials['access_token'] + } + ) + + if resp.status_code != 200: + return self.create_json_message( + { + 'error': f'Failed to send group message: {resp.text}' + } + ) + + return self.create_json_message( + { + 'response': resp.json() + } + ) + except Exception as e: + return self.create_json_message( + { + 'error': f'Failed to send group message: {e}' + } + ) diff --git a/api/core/tools/provider/builtin/onebot/tools/send_group_msg.yaml b/api/core/tools/provider/builtin/onebot/tools/send_group_msg.yaml new file mode 100644 index 0000000000..64beaa8545 --- /dev/null +++ b/api/core/tools/provider/builtin/onebot/tools/send_group_msg.yaml @@ -0,0 +1,46 @@ +identity: + name: send_group_msg + author: RockChinQ + label: + en_US: Send Group Message + zh_Hans: 发送群消息 +description: + human: + en_US: Send a message to a group + zh_Hans: 发送消息到群聊 + llm: A tool for sending a message segment to a group +parameters: + - name: group_id + type: number + required: true + label: + en_US: Target Group ID + zh_Hans: 目标群 ID + human_description: + en_US: The group ID of the target group + zh_Hans: 目标群的群 ID + llm_description: The group ID of the target group + form: llm + - name: message + type: string + required: true + label: + en_US: Message + zh_Hans: 消息 + human_description: + en_US: The message to send + zh_Hans: 要发送的消息。支持 CQ码(需要同时设置 auto_escape 为 true) + llm_description: The message to send + form: llm + - name: auto_escape + type: boolean + required: false + default: false + label: + en_US: Auto Escape + zh_Hans: 自动转义 + human_description: + en_US: If true, the message will be treated as a CQ code for parsing, otherwise it will be treated as plain text for direct sending. Since Dify currently does not support passing Object-format message chains, developers can send complex message components through CQ codes. + zh_Hans: 若为 true 则会把 message 视为 CQ 码解析,否则视为 纯文本 直接发送。由于 Dify 目前不支持传入 Object格式 的消息,故开发者可以通过 CQ 码来发送复杂消息组件。 + llm_description: If true, the message will be treated as a CQ code for parsing, otherwise it will be treated as plain text for direct sending. + form: form diff --git a/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py b/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py new file mode 100644 index 0000000000..8ef4d72ab6 --- /dev/null +++ b/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py @@ -0,0 +1,64 @@ +from typing import Any, Union + +import requests +from yarl import URL + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class SendPrivateMsg(BuiltinTool): + """OneBot v11 Tool: Send Private Message""" + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + + # Get parameters + send_user_id = tool_parameters.get('user_id', '') + + message = tool_parameters.get('message', '') + if not message: + return self.create_json_message( + { + 'error': 'Message is empty.' + } + ) + + auto_escape = tool_parameters.get('auto_escape', False) + + try: + url = URL(self.runtime.credentials['ob11_http_url']) / 'send_private_msg' + + resp = requests.post( + url, + json={ + 'user_id': send_user_id, + 'message': message, + 'auto_escape': auto_escape + }, + headers={ + 'Authorization': 'Bearer ' + self.runtime.credentials['access_token'] + } + ) + + if resp.status_code != 200: + return self.create_json_message( + { + 'error': f'Failed to send private message: {resp.text}' + } + ) + + return self.create_json_message( + { + 'response': resp.json() + } + ) + except Exception as e: + return self.create_json_message( + { + 'error': f'Failed to send private message: {e}' + } + ) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/onebot/tools/send_private_msg.yaml b/api/core/tools/provider/builtin/onebot/tools/send_private_msg.yaml new file mode 100644 index 0000000000..8200ce4a83 --- /dev/null +++ b/api/core/tools/provider/builtin/onebot/tools/send_private_msg.yaml @@ -0,0 +1,46 @@ +identity: + name: send_private_msg + author: RockChinQ + label: + en_US: Send Private Message + zh_Hans: 发送私聊消息 +description: + human: + en_US: Send a private message to a user + zh_Hans: 发送私聊消息给用户 + llm: A tool for sending a message segment to a user in private chat +parameters: + - name: user_id + type: number + required: true + label: + en_US: Target User ID + zh_Hans: 目标用户 ID + human_description: + en_US: The user ID of the target user + zh_Hans: 目标用户的用户 ID + llm_description: The user ID of the target user + form: llm + - name: message + type: string + required: true + label: + en_US: Message + zh_Hans: 消息 + human_description: + en_US: The message to send + zh_Hans: 要发送的消息。支持 CQ码(需要同时设置 auto_escape 为 true) + llm_description: The message to send + form: llm + - name: auto_escape + type: boolean + required: false + default: false + label: + en_US: Auto Escape + zh_Hans: 自动转义 + human_description: + en_US: If true, the message will be treated as a CQ code for parsing, otherwise it will be treated as plain text for direct sending. Since Dify currently does not support passing Object-format message chains, developers can send complex message components through CQ codes. + zh_Hans: 若为 true 则会把 message 视为 CQ 码解析,否则视为 纯文本 直接发送。由于 Dify 目前不支持传入 Object格式 的消息,故开发者可以通过 CQ 码来发送复杂消息组件。 + llm_description: If true, the message will be treated as a CQ code for parsing, otherwise it will be treated as plain text for direct sending. + form: form diff --git a/api/core/tools/provider/builtin/perplexity/_assets/icon.svg b/api/core/tools/provider/builtin/perplexity/_assets/icon.svg new file mode 100644 index 0000000000..c2974c142f --- /dev/null +++ b/api/core/tools/provider/builtin/perplexity/_assets/icon.svg @@ -0,0 +1,3 @@ + + + diff --git a/api/core/tools/provider/builtin/perplexity/perplexity.py b/api/core/tools/provider/builtin/perplexity/perplexity.py new file mode 100644 index 0000000000..ff91edf18d --- /dev/null +++ b/api/core/tools/provider/builtin/perplexity/perplexity.py @@ -0,0 +1,46 @@ +from typing import Any + +import requests + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.perplexity.tools.perplexity_search import PERPLEXITY_API_URL +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class PerplexityProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + headers = { + "Authorization": f"Bearer {credentials.get('perplexity_api_key')}", + "Content-Type": "application/json" + } + + payload = { + "model": "llama-3.1-sonar-small-128k-online", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "Hello" + } + ], + "max_tokens": 5, + "temperature": 0.1, + "top_p": 0.9, + "stream": False + } + + try: + response = requests.post(PERPLEXITY_API_URL, json=payload, headers=headers) + response.raise_for_status() + except requests.RequestException as e: + raise ToolProviderCredentialValidationError( + f"Failed to validate Perplexity API key: {str(e)}" + ) + + if response.status_code != 200: + raise ToolProviderCredentialValidationError( + f"Perplexity API key is invalid. Status code: {response.status_code}" + ) diff --git a/api/core/tools/provider/builtin/perplexity/perplexity.yaml b/api/core/tools/provider/builtin/perplexity/perplexity.yaml new file mode 100644 index 0000000000..c0b504f300 --- /dev/null +++ b/api/core/tools/provider/builtin/perplexity/perplexity.yaml @@ -0,0 +1,26 @@ +identity: + author: Dify + name: perplexity + label: + en_US: Perplexity + zh_Hans: Perplexity + description: + en_US: Perplexity.AI + zh_Hans: Perplexity.AI + icon: icon.svg + tags: + - search +credentials_for_provider: + perplexity_api_key: + type: secret-input + required: true + label: + en_US: Perplexity API key + zh_Hans: Perplexity API key + placeholder: + en_US: Please input your Perplexity API key + zh_Hans: 请输入你的 Perplexity API key + help: + en_US: Get your Perplexity API key from Perplexity + zh_Hans: 从 Perplexity 获取您的 Perplexity API key + url: https://www.perplexity.ai/settings/api diff --git a/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.py b/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.py new file mode 100644 index 0000000000..5b1a263f9b --- /dev/null +++ b/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.py @@ -0,0 +1,72 @@ +import json +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +PERPLEXITY_API_URL = "https://api.perplexity.ai/chat/completions" + +class PerplexityAITool(BuiltinTool): + def _parse_response(self, response: dict) -> dict: + """Parse the response from Perplexity AI API""" + if 'choices' in response and len(response['choices']) > 0: + message = response['choices'][0]['message'] + return { + 'content': message.get('content', ''), + 'role': message.get('role', ''), + 'citations': response.get('citations', []) + } + else: + return {'content': 'Unable to get a valid response', 'role': 'assistant', 'citations': []} + + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + headers = { + "Authorization": f"Bearer {self.runtime.credentials['perplexity_api_key']}", + "Content-Type": "application/json" + } + + payload = { + "model": tool_parameters.get('model', 'llama-3.1-sonar-small-128k-online'), + "messages": [ + { + "role": "system", + "content": "Be precise and concise." + }, + { + "role": "user", + "content": tool_parameters['query'] + } + ], + "max_tokens": tool_parameters.get('max_tokens', 4096), + "temperature": tool_parameters.get('temperature', 0.7), + "top_p": tool_parameters.get('top_p', 1), + "top_k": tool_parameters.get('top_k', 5), + "presence_penalty": tool_parameters.get('presence_penalty', 0), + "frequency_penalty": tool_parameters.get('frequency_penalty', 1), + "stream": False + } + + if 'search_recency_filter' in tool_parameters: + payload['search_recency_filter'] = tool_parameters['search_recency_filter'] + if 'return_citations' in tool_parameters: + payload['return_citations'] = tool_parameters['return_citations'] + if 'search_domain_filter' in tool_parameters: + if isinstance(tool_parameters['search_domain_filter'], str): + payload['search_domain_filter'] = [tool_parameters['search_domain_filter']] + elif isinstance(tool_parameters['search_domain_filter'], list): + payload['search_domain_filter'] = tool_parameters['search_domain_filter'] + + + response = requests.post(url=PERPLEXITY_API_URL, json=payload, headers=headers) + response.raise_for_status() + valuable_res = self._parse_response(response.json()) + + return [ + self.create_json_message(valuable_res), + self.create_text_message(json.dumps(valuable_res, ensure_ascii=False, indent=2)) + ] diff --git a/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.yaml b/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.yaml new file mode 100644 index 0000000000..02a645df33 --- /dev/null +++ b/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.yaml @@ -0,0 +1,178 @@ +identity: + name: perplexity + author: Dify + label: + en_US: Perplexity Search +description: + human: + en_US: Search information using Perplexity AI's language models. + llm: This tool is used to search information using Perplexity AI's language models. +parameters: + - name: query + type: string + required: true + label: + en_US: Query + zh_Hans: 查询 + human_description: + en_US: The text query to be processed by the AI model. + zh_Hans: 要由 AI 模型处理的文本查询。 + form: llm + - name: model + type: select + required: false + label: + en_US: Model Name + zh_Hans: 模型名称 + human_description: + en_US: The Perplexity AI model to use for generating the response. + zh_Hans: 用于生成响应的 Perplexity AI 模型。 + form: form + default: "llama-3.1-sonar-small-128k-online" + options: + - value: llama-3.1-sonar-small-128k-online + label: + en_US: llama-3.1-sonar-small-128k-online + zh_Hans: llama-3.1-sonar-small-128k-online + - value: llama-3.1-sonar-large-128k-online + label: + en_US: llama-3.1-sonar-large-128k-online + zh_Hans: llama-3.1-sonar-large-128k-online + - value: llama-3.1-sonar-huge-128k-online + label: + en_US: llama-3.1-sonar-huge-128k-online + zh_Hans: llama-3.1-sonar-huge-128k-online + - name: max_tokens + type: number + required: false + label: + en_US: Max Tokens + zh_Hans: 最大令牌数 + pt_BR: Máximo de Tokens + human_description: + en_US: The maximum number of tokens to generate in the response. + zh_Hans: 在响应中生成的最大令牌数。 + pt_BR: O número máximo de tokens a serem gerados na resposta. + form: form + default: 4096 + min: 1 + max: 4096 + - name: temperature + type: number + required: false + label: + en_US: Temperature + zh_Hans: 温度 + pt_BR: Temperatura + human_description: + en_US: Controls randomness in the output. Lower values make the output more focused and deterministic. + zh_Hans: 控制输出的随机性。较低的值使输出更加集中和确定。 + form: form + default: 0.7 + min: 0 + max: 1 + - name: top_k + type: number + required: false + label: + en_US: Top K + zh_Hans: 取样数量 + human_description: + en_US: The number of top results to consider for response generation. + zh_Hans: 用于生成响应的顶部结果数量。 + form: form + default: 5 + min: 1 + max: 100 + - name: top_p + type: number + required: false + label: + en_US: Top P + zh_Hans: Top P + human_description: + en_US: Controls diversity via nucleus sampling. + zh_Hans: 通过核心采样控制多样性。 + form: form + default: 1 + min: 0.1 + max: 1 + step: 0.1 + - name: presence_penalty + type: number + required: false + label: + en_US: Presence Penalty + zh_Hans: 存在惩罚 + human_description: + en_US: Positive values penalize new tokens based on whether they appear in the text so far. + zh_Hans: 正值会根据新词元是否已经出现在文本中来对其进行惩罚。 + form: form + default: 0 + min: -1.0 + max: 1.0 + step: 0.1 + - name: frequency_penalty + type: number + required: false + label: + en_US: Frequency Penalty + zh_Hans: 频率惩罚 + human_description: + en_US: Positive values penalize new tokens based on their existing frequency in the text so far. + zh_Hans: 正值会根据新词元在文本中已经出现的频率来对其进行惩罚。 + form: form + default: 1 + min: 0.1 + max: 1.0 + step: 0.1 + - name: return_citations + type: boolean + required: false + label: + en_US: Return Citations + zh_Hans: 返回引用 + human_description: + en_US: Whether to return citations in the response. + zh_Hans: 是否在响应中返回引用。 + form: form + default: true + - name: search_domain_filter + type: string + required: false + label: + en_US: Search Domain Filter + zh_Hans: 搜索域过滤器 + human_description: + en_US: Domain to filter the search results. + zh_Hans: 用于过滤搜索结果的域名。 + form: form + default: "" + - name: search_recency_filter + type: select + required: false + label: + en_US: Search Recency Filter + zh_Hans: 搜索时间过滤器 + human_description: + en_US: Filter for search results based on recency. + zh_Hans: 基于时间筛选搜索结果。 + form: form + default: "month" + options: + - value: day + label: + en_US: Day + zh_Hans: 天 + - value: week + label: + en_US: Week + zh_Hans: 周 + - value: month + label: + en_US: Month + zh_Hans: 月 + - value: year + label: + en_US: Year + zh_Hans: 年 diff --git a/api/core/tools/provider/builtin/siliconflow/_assets/icon.svg b/api/core/tools/provider/builtin/siliconflow/_assets/icon.svg new file mode 100644 index 0000000000..ad6b384f7a --- /dev/null +++ b/api/core/tools/provider/builtin/siliconflow/_assets/icon.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/siliconflow/siliconflow.py b/api/core/tools/provider/builtin/siliconflow/siliconflow.py new file mode 100644 index 0000000000..0df78280df --- /dev/null +++ b/api/core/tools/provider/builtin/siliconflow/siliconflow.py @@ -0,0 +1,19 @@ +import requests + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class SiliconflowProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + url = "https://api.siliconflow.cn/v1/models" + headers = { + "accept": "application/json", + "authorization": f"Bearer {credentials.get('siliconFlow_api_key')}", + } + + response = requests.get(url, headers=headers) + if response.status_code != 200: + raise ToolProviderCredentialValidationError( + "SiliconFlow API key is invalid" + ) diff --git a/api/core/tools/provider/builtin/siliconflow/siliconflow.yaml b/api/core/tools/provider/builtin/siliconflow/siliconflow.yaml new file mode 100644 index 0000000000..46be99f262 --- /dev/null +++ b/api/core/tools/provider/builtin/siliconflow/siliconflow.yaml @@ -0,0 +1,21 @@ +identity: + author: hjlarry + name: siliconflow + label: + en_US: SiliconFlow + zh_CN: 硅基流动 + description: + en_US: The image generation API provided by SiliconFlow includes Flux and Stable Diffusion models. + zh_CN: 硅基流动提供的图片生成 API,包含 Flux 和 Stable Diffusion 模型。 + icon: icon.svg + tags: + - image +credentials_for_provider: + siliconFlow_api_key: + type: secret-input + required: true + label: + en_US: SiliconFlow API Key + placeholder: + en_US: Please input your SiliconFlow API key + url: https://cloud.siliconflow.cn/account/ak diff --git a/api/core/tools/provider/builtin/siliconflow/tools/flux.py b/api/core/tools/provider/builtin/siliconflow/tools/flux.py new file mode 100644 index 0000000000..ed9f4be574 --- /dev/null +++ b/api/core/tools/provider/builtin/siliconflow/tools/flux.py @@ -0,0 +1,44 @@ +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +FLUX_URL = ( + "https://api.siliconflow.cn/v1/black-forest-labs/FLUX.1-schnell/text-to-image" +) + + +class FluxTool(BuiltinTool): + + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + + headers = { + "accept": "application/json", + "content-type": "application/json", + "authorization": f"Bearer {self.runtime.credentials['siliconFlow_api_key']}", + } + + payload = { + "prompt": tool_parameters.get("prompt"), + "image_size": tool_parameters.get("image_size", "1024x1024"), + "seed": tool_parameters.get("seed"), + "num_inference_steps": tool_parameters.get("num_inference_steps", 20), + } + + response = requests.post(FLUX_URL, json=payload, headers=headers) + if response.status_code != 200: + return self.create_text_message(f"Got Error Response:{response.text}") + + res = response.json() + result = [self.create_json_message(res)] + for image in res.get("images", []): + result.append( + self.create_image_message( + image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value + ) + ) + return result diff --git a/api/core/tools/provider/builtin/siliconflow/tools/flux.yaml b/api/core/tools/provider/builtin/siliconflow/tools/flux.yaml new file mode 100644 index 0000000000..2a0698700c --- /dev/null +++ b/api/core/tools/provider/builtin/siliconflow/tools/flux.yaml @@ -0,0 +1,73 @@ +identity: + name: flux + author: hjlarry + label: + en_US: Flux + icon: icon.svg +description: + human: + en_US: Generate image via SiliconFlow's flux schnell. + llm: This tool is used to generate image from prompt via SiliconFlow's flux schnell model. +parameters: + - name: prompt + type: string + required: true + label: + en_US: prompt + zh_Hans: 提示词 + human_description: + en_US: The text prompt used to generate the image. + zh_Hans: 用于生成图片的文字提示词 + llm_description: this prompt text will be used to generate image. + form: llm + - name: image_size + type: select + required: true + options: + - value: 1024x1024 + label: + en_US: 1024x1024 + - value: 768x1024 + label: + en_US: 768x1024 + - value: 576x1024 + label: + en_US: 576x1024 + - value: 512x1024 + label: + en_US: 512x1024 + - value: 1024x576 + label: + en_US: 1024x576 + - value: 768x512 + label: + en_US: 768x512 + default: 1024x1024 + label: + en_US: Choose Image Size + zh_Hans: 选择生成的图片大小 + form: form + - name: num_inference_steps + type: number + required: true + default: 20 + min: 1 + max: 100 + label: + en_US: Num Inference Steps + zh_Hans: 生成图片的步数 + form: form + human_description: + en_US: The number of inference steps to perform. More steps produce higher quality but take longer. + zh_Hans: 执行的推理步骤数量。更多的步骤可以产生更高质量的结果,但需要更长的时间。 + - name: seed + type: number + min: 0 + max: 9999999999 + label: + en_US: Seed + zh_Hans: 种子 + human_description: + en_US: The same seed and prompt can produce similar images. + zh_Hans: 相同的种子和提示可以产生相似的图像。 + form: form diff --git a/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py new file mode 100644 index 0000000000..e8134a6565 --- /dev/null +++ b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py @@ -0,0 +1,51 @@ +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +SDURL = { + "sd_3": "https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-3-medium/text-to-image", + "sd_xl": "https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-xl-base-1.0/text-to-image", +} + + +class StableDiffusionTool(BuiltinTool): + + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + + headers = { + "accept": "application/json", + "content-type": "application/json", + "authorization": f"Bearer {self.runtime.credentials['siliconFlow_api_key']}", + } + + model = tool_parameters.get("model", "sd_3") + url = SDURL.get(model) + + payload = { + "prompt": tool_parameters.get("prompt"), + "negative_prompt": tool_parameters.get("negative_prompt", ""), + "image_size": tool_parameters.get("image_size", "1024x1024"), + "batch_size": tool_parameters.get("batch_size", 1), + "seed": tool_parameters.get("seed"), + "guidance_scale": tool_parameters.get("guidance_scale", 7.5), + "num_inference_steps": tool_parameters.get("num_inference_steps", 20), + } + + response = requests.post(url, json=payload, headers=headers) + if response.status_code != 200: + return self.create_text_message(f"Got Error Response:{response.text}") + + res = response.json() + result = [self.create_json_message(res)] + for image in res.get("images", []): + result.append( + self.create_image_message( + image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value + ) + ) + return result diff --git a/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.yaml b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.yaml new file mode 100644 index 0000000000..dce10adc87 --- /dev/null +++ b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.yaml @@ -0,0 +1,121 @@ +identity: + name: stable_diffusion + author: hjlarry + label: + en_US: Stable Diffusion + icon: icon.svg +description: + human: + en_US: Generate image via SiliconFlow's stable diffusion model. + llm: This tool is used to generate image from prompt via SiliconFlow's stable diffusion model. +parameters: + - name: prompt + type: string + required: true + label: + en_US: prompt + zh_Hans: 提示词 + human_description: + en_US: The text prompt used to generate the image. + zh_Hans: 用于生成图片的文字提示词 + llm_description: this prompt text will be used to generate image. + form: llm + - name: negative_prompt + type: string + label: + en_US: negative prompt + zh_Hans: 负面提示词 + human_description: + en_US: Describe what you don't want included in the image. + zh_Hans: 描述您不希望包含在图片中的内容。 + llm_description: Describe what you don't want included in the image. + form: llm + - name: model + type: select + required: true + options: + - value: sd_3 + label: + en_US: Stable Diffusion 3 + - value: sd_xl + label: + en_US: Stable Diffusion XL + default: sd_3 + label: + en_US: Choose Image Model + zh_Hans: 选择生成图片的模型 + form: form + - name: image_size + type: select + required: true + options: + - value: 1024x1024 + label: + en_US: 1024x1024 + - value: 1024x2048 + label: + en_US: 1024x2048 + - value: 1152x2048 + label: + en_US: 1152x2048 + - value: 1536x1024 + label: + en_US: 1536x1024 + - value: 1536x2048 + label: + en_US: 1536x2048 + - value: 2048x1152 + label: + en_US: 2048x1152 + default: 1024x1024 + label: + en_US: Choose Image Size + zh_Hans: 选择生成图片的大小 + form: form + - name: batch_size + type: number + required: true + default: 1 + min: 1 + max: 4 + label: + en_US: Number Images + zh_Hans: 生成图片的数量 + form: form + - name: guidance_scale + type: number + required: true + default: 7.5 + min: 0 + max: 100 + label: + en_US: Guidance Scale + zh_Hans: 与提示词紧密性 + human_description: + en_US: Classifier Free Guidance. How close you want the model to stick to your prompt when looking for a related image to show you. + zh_Hans: 无分类器引导。您希望模型在寻找相关图片向您展示时,与您的提示保持多紧密的关联度。 + form: form + - name: num_inference_steps + type: number + required: true + default: 20 + min: 1 + max: 100 + label: + en_US: Num Inference Steps + zh_Hans: 生成图片的步数 + human_description: + en_US: The number of inference steps to perform. More steps produce higher quality but take longer. + zh_Hans: 执行的推理步骤数量。更多的步骤可以产生更高质量的结果,但需要更长的时间。 + form: form + - name: seed + type: number + min: 0 + max: 9999999999 + label: + en_US: Seed + zh_Hans: 种子 + human_description: + en_US: The same seed and prompt can produce similar images. + zh_Hans: 相同的种子和提示可以产生相似的图像。 + form: form diff --git a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py index 0c5ebc23ac..4be9207d66 100644 --- a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py +++ b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py @@ -27,7 +27,7 @@ DRAW_TEXT_OPTIONS = { "seed_resize_from_w": -1, # Samplers - # "sampler_name": "DPM++ 2M", + "sampler_name": "DPM++ 2M", # "scheduler": "", # "sampler_index": "Automatic", @@ -178,6 +178,23 @@ class StableDiffusionTool(BuiltinTool): return [d['model_name'] for d in response.json()] except Exception as e: return [] + + def get_sample_methods(self) -> list[str]: + """ + get sample method + """ + try: + base_url = self.runtime.credentials.get('base_url', None) + if not base_url: + return [] + api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'samplers') + response = get(url=api_url, timeout=(2, 10)) + if response.status_code != 200: + return [] + else: + return [d['name'] for d in response.json()] + except Exception as e: + return [] def img2img(self, base_url: str, tool_parameters: dict[str, Any]) \ -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: @@ -339,7 +356,27 @@ class StableDiffusionTool(BuiltinTool): label=I18nObject(en_US=i, zh_Hans=i) ) for i in models]) ) + except: pass - + + sample_methods = self.get_sample_methods() + if len(sample_methods) != 0: + parameters.append( + ToolParameter(name='sampler_name', + label=I18nObject(en_US='Sampling method', zh_Hans='Sampling method'), + human_description=I18nObject( + en_US='Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion', + zh_Hans='Stable Diffusion 的Sampling method,您可以查看 Stable Diffusion 的官方文档', + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + llm_description='Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion', + required=True, + default=sample_methods[0], + options=[ToolParameterOption( + value=i, + label=I18nObject(en_US=i, zh_Hans=i) + ) for i in sample_methods]) + ) return parameters diff --git a/api/core/tools/provider/builtin/stepfun/tools/image.py b/api/core/tools/provider/builtin/stepfun/tools/image.py index 5e544aada6..c571f54675 100644 --- a/api/core/tools/provider/builtin/stepfun/tools/image.py +++ b/api/core/tools/provider/builtin/stepfun/tools/image.py @@ -17,11 +17,8 @@ class StepfunTool(BuiltinTool): """ invoke tools """ - base_url = self.runtime.credentials.get('stepfun_base_url', None) - if not base_url: - base_url = None - else: - base_url = str(URL(base_url) / 'v1') + base_url = self.runtime.credentials.get('stepfun_base_url', 'https://api.stepfun.com') + base_url = str(URL(base_url) / 'v1') client = OpenAI( api_key=self.runtime.credentials['stepfun_api_key'], diff --git a/api/core/tools/provider/builtin/stepfun/tools/image.yaml b/api/core/tools/provider/builtin/stepfun/tools/image.yaml index 1e20b157aa..dcc5bd2db2 100644 --- a/api/core/tools/provider/builtin/stepfun/tools/image.yaml +++ b/api/core/tools/provider/builtin/stepfun/tools/image.yaml @@ -25,7 +25,7 @@ parameters: pt_BR: Prompt human_description: en_US: Image prompt, you can check the official documentation of step-1x - zh_Hans: 图像提示词,您可以查看step-1x 的官方文档 + zh_Hans: 图像提示词,您可以查看 step-1x 的官方文档 pt_BR: Image prompt, you can check the official documentation of step-1x llm_description: Image prompt of step-1x you should describe the image you want to generate as a list of words as possible as detailed form: llm diff --git a/api/core/tools/provider/workflow_tool_provider.py b/api/core/tools/provider/workflow_tool_provider.py index f7911fea1d..f14abac767 100644 --- a/api/core/tools/provider/workflow_tool_provider.py +++ b/api/core/tools/provider/workflow_tool_provider.py @@ -1,6 +1,6 @@ from typing import Optional -from core.app.app_config.entities import VariableEntity +from core.app.app_config.entities import VariableEntity, VariableEntityType from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ( @@ -18,6 +18,13 @@ from models.model import App, AppMode from models.tools import WorkflowToolProvider from models.workflow import Workflow +VARIABLE_TO_PARAMETER_TYPE_MAPPING = { + VariableEntityType.TEXT_INPUT: ToolParameter.ToolParameterType.STRING, + VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING, + VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT, + VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER, +} + class WorkflowToolProviderController(ToolProviderController): provider_id: str @@ -28,7 +35,7 @@ class WorkflowToolProviderController(ToolProviderController): if not app: raise ValueError('app not found') - + controller = WorkflowToolProviderController(**{ 'identity': { 'author': db_provider.user.name if db_provider.user_id and db_provider.user else '', @@ -46,7 +53,7 @@ class WorkflowToolProviderController(ToolProviderController): 'credentials_schema': {}, 'provider_id': db_provider.id or '', }) - + # init tools controller.tools = [controller._get_db_provider_tool(db_provider, app)] @@ -56,7 +63,7 @@ class WorkflowToolProviderController(ToolProviderController): @property def provider_type(self) -> ToolProviderType: return ToolProviderType.WORKFLOW - + def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool: """ get db provider tool @@ -93,23 +100,11 @@ class WorkflowToolProviderController(ToolProviderController): if variable: parameter_type = None options = None - if variable.type in [ - VariableEntity.Type.TEXT_INPUT, - VariableEntity.Type.PARAGRAPH, - ]: - parameter_type = ToolParameter.ToolParameterType.STRING - elif variable.type in [ - VariableEntity.Type.SELECT - ]: - parameter_type = ToolParameter.ToolParameterType.SELECT - elif variable.type in [ - VariableEntity.Type.NUMBER - ]: - parameter_type = ToolParameter.ToolParameterType.NUMBER - else: + if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING: raise ValueError(f'unsupported variable type {variable.type}') - - if variable.type == VariableEntity.Type.SELECT and variable.options: + parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type] + + if variable.type == VariableEntityType.SELECT and variable.options: options = [ ToolParameterOption( value=option, @@ -200,7 +195,7 @@ class WorkflowToolProviderController(ToolProviderController): """ if self.tools is not None: return self.tools - + db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == self.provider_id, @@ -208,11 +203,11 @@ class WorkflowToolProviderController(ToolProviderController): if not db_providers: return [] - + self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)] return self.tools - + def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: """ get tool by name @@ -226,5 +221,5 @@ class WorkflowToolProviderController(ToolProviderController): for tool in self.tools: if tool.identity.name == tool_name: return tool - + return None diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index 69e3dfa061..38f10032e2 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -144,7 +144,7 @@ class ApiTool(Tool): path_params[parameter['name']] = value elif parameter['in'] == 'query': - params[parameter['name']] = value + if value !='': params[parameter['name']] = value elif parameter['in'] == 'cookie': cookies[parameter['name']] = value diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index d7ddb40e6b..4a0188af49 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -10,14 +10,11 @@ from configs import dify_config from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom from core.helper.module_import_helper import load_single_subclass_from_source +from core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ( - ApiProviderAuthType, - ToolInvokeFrom, - ToolParameter, -) +from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter from core.tools.errors import ToolProviderNotFoundError from core.tools.provider.api_tool_provider import ApiToolProviderController from core.tools.provider.builtin._positions import BuiltinToolProviderSort @@ -26,10 +23,7 @@ from core.tools.tool.api_tool import ApiTool from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.tool import Tool from core.tools.tool_label_manager import ToolLabelManager -from core.tools.utils.configuration import ( - ToolConfigurationManager, - ToolParameterConfigurationManager, -) +from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager from core.tools.utils.tool_parameter_converter import ToolParameterConverter from core.workflow.nodes.tool.entities import ToolEntity from extensions.ext_database import db @@ -38,6 +32,7 @@ from services.tools.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) + class ToolManager: _builtin_provider_lock = Lock() _builtin_providers = {} @@ -107,7 +102,7 @@ class ToolManager: tenant_id: str, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \ - -> Union[BuiltinTool, ApiTool]: + -> Union[BuiltinTool, ApiTool]: """ get the tool runtime @@ -346,7 +341,7 @@ class ToolManager: provider_class = load_single_subclass_from_source( module_name=f'core.tools.provider.builtin.{provider}.{provider}', script_path=path.join(path.dirname(path.realpath(__file__)), - 'provider', 'builtin', provider, f'{provider}.py'), + 'provider', 'builtin', provider, f'{provider}.py'), parent_type=BuiltinToolProviderController) provider: BuiltinToolProviderController = provider_class() cls._builtin_providers[provider.identity.name] = provider @@ -414,6 +409,15 @@ class ToolManager: # append builtin providers for provider in builtin_providers: + # handle include, exclude + if is_filtered( + include_set=dify_config.POSITION_TOOL_INCLUDES_SET, + exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, + data=provider, + name_func=lambda x: x.identity.name + ): + continue + user_provider = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider, db_provider=find_db_builtin_provider(provider.identity.name), @@ -473,7 +477,7 @@ class ToolManager: @classmethod def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[ - ApiToolProviderController, dict[str, Any]]: + ApiToolProviderController, dict[str, Any]]: """ get the api provider @@ -593,4 +597,5 @@ class ToolManager: else: raise ValueError(f"provider type {provider_type} not found") + ToolManager.load_builtin_providers_cache() diff --git a/api/core/tools/utils/feishu_api_utils.py b/api/core/tools/utils/feishu_api_utils.py new file mode 100644 index 0000000000..e6b288868f --- /dev/null +++ b/api/core/tools/utils/feishu_api_utils.py @@ -0,0 +1,143 @@ +import httpx + +from extensions.ext_redis import redis_client + + +class FeishuRequest: + def __init__(self, app_id: str, app_secret: str): + self.app_id = app_id + self.app_secret = app_secret + + @property + def tenant_access_token(self): + feishu_tenant_access_token = f"tools:{self.app_id}:feishu_tenant_access_token" + if redis_client.exists(feishu_tenant_access_token): + return redis_client.get(feishu_tenant_access_token).decode() + res = self.get_tenant_access_token(self.app_id, self.app_secret) + redis_client.setex(feishu_tenant_access_token, res.get("expire"), res.get("tenant_access_token")) + return res.get("tenant_access_token") + + def _send_request(self, url: str, method: str = "post", require_token: bool = True, payload: dict = None, + params: dict = None): + headers = { + "Content-Type": "application/json", + "user-agent": "Dify", + } + if require_token: + headers["tenant-access-token"] = f"{self.tenant_access_token}" + res = httpx.request(method=method, url=url, headers=headers, json=payload, params=params, timeout=30).json() + if res.get("code") != 0: + raise Exception(res) + return res + + def get_tenant_access_token(self, app_id: str, app_secret: str) -> dict: + """ + API url: https://open.feishu.cn/document/server-docs/authentication-management/access-token/tenant_access_token_internal + Example Response: + { + "code": 0, + "msg": "ok", + "tenant_access_token": "t-caecc734c2e3328a62489fe0648c4b98779515d3", + "expire": 7200 + } + """ + url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/access_token/get_tenant_access_token" + payload = { + "app_id": app_id, + "app_secret": app_secret + } + res = self._send_request(url, require_token=False, payload=payload) + return res + + def create_document(self, title: str, content: str, folder_token: str) -> dict: + """ + API url: https://open.larkoffice.com/document/server-docs/docs/docs/docx-v1/document/create + Example Response: + { + "data": { + "title": "title", + "url": "https://svi136aogf123.feishu.cn/docx/VWbvd4fEdoW0WSxaY1McQTz8n7d", + "type": "docx", + "token": "VWbvd4fEdoW0WSxaY1McQTz8n7d" + }, + "log_id": "021721281231575fdbddc0200ff00060a9258ec0000103df61b5d", + "code": 0, + "msg": "创建飞书文档成功,请查看" + } + """ + url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/document/create_document" + payload = { + "title": title, + "content": content, + "folder_token": folder_token, + } + res = self._send_request(url, payload=payload) + return res.get("data") + + def write_document(self, document_id: str, content: str, position: str = "start") -> dict: + url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/document/write_document" + payload = { + "document_id": document_id, + "content": content, + "position": position + } + res = self._send_request(url, payload=payload) + return res.get("data") + + def get_document_raw_content(self, document_id: str) -> dict: + """ + API url: https://open.larkoffice.com/document/server-docs/docs/docs/docx-v1/document/raw_content + Example Response: + { + "code": 0, + "msg": "success", + "data": { + "content": "云文档\n多人实时协同,插入一切元素。不仅是在线文档,更是强大的创作和互动工具\n云文档:专为协作而生\n" + } + } + """ + params = { + "document_id": document_id, + } + url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/document/get_document_raw_content" + res = self._send_request(url, method="get", params=params) + return res.get("data").get("content") + + def list_document_block(self, document_id: str, page_token: str, page_size: int = 500) -> dict: + """ + API url: https://open.larkoffice.com/document/server-docs/docs/docs/docx-v1/document/list + """ + url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/document/list_document_block" + params = { + "document_id": document_id, + "page_size": page_size, + "page_token": page_token, + } + res = self._send_request(url, method="get", params=params) + return res.get("data") + + def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict: + """ + API url: https://open.larkoffice.com/document/server-docs/im-v1/message/create + """ + url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/message/send_bot_message" + params = { + "receive_id_type": receive_id_type, + } + payload = { + "receive_id": receive_id, + "msg_type": msg_type, + "content": content, + } + res = self._send_request(url, params=params, payload=payload) + return res.get("data") + + def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict: + url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/message/send_webhook_message" + payload = { + "webhook": webhook, + "msg_type": msg_type, + "content": content, + } + res = self._send_request(url, require_token=False, payload=payload) + return res diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py index 21155a6960..f751c43096 100644 --- a/api/core/tools/utils/yaml_utils.py +++ b/api/core/tools/utils/yaml_utils.py @@ -26,7 +26,6 @@ def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any raise YAMLError(f'Failed to load YAML file {file_path}: {e}') except Exception as e: if ignore_error: - logger.debug(f'Failed to load YAML file {file_path}: {e}') return default_value else: raise e diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 9fe3356faa..8120b2ac78 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -6,20 +6,20 @@ from typing_extensions import deprecated from core.app.segments import Segment, Variable, factory from core.file.file_obj import FileVar -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey VariableValue = Union[str, int, float, dict, list, FileVar] -SYSTEM_VARIABLE_NODE_ID = 'sys' -ENVIRONMENT_VARIABLE_NODE_ID = 'env' -CONVERSATION_VARIABLE_NODE_ID = 'conversation' +SYSTEM_VARIABLE_NODE_ID = "sys" +ENVIRONMENT_VARIABLE_NODE_ID = "env" +CONVERSATION_VARIABLE_NODE_ID = "conversation" class VariablePool: def __init__( self, - system_variables: Mapping[SystemVariable, Any], + system_variables: Mapping[SystemVariableKey, Any], user_inputs: Mapping[str, Any], environment_variables: Sequence[Variable], conversation_variables: Sequence[Variable] | None = None, @@ -68,7 +68,7 @@ class VariablePool: None """ if len(selector) < 2: - raise ValueError('Invalid selector') + raise ValueError("Invalid selector") if value is None: return @@ -95,13 +95,13 @@ class VariablePool: ValueError: If the selector is invalid. """ if len(selector) < 2: - raise ValueError('Invalid selector') + raise ValueError("Invalid selector") hash_key = hash(tuple(selector[1:])) value = self._variable_dictionary[selector[0]].get(hash_key) return value - @deprecated('This method is deprecated, use `get` instead.') + @deprecated("This method is deprecated, use `get` instead.") def get_any(self, selector: Sequence[str], /) -> Any | None: """ Retrieves the value from the variable pool based on the given selector. @@ -116,7 +116,7 @@ class VariablePool: ValueError: If the selector is invalid. """ if len(selector) < 2: - raise ValueError('Invalid selector') + raise ValueError("Invalid selector") hash_key = hash(tuple(selector[1:])) value = self._variable_dictionary[selector[0]].get(hash_key) return value.to_object() if value else None diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py index 4757cf32f8..da65f6b1fb 100644 --- a/api/core/workflow/enums.py +++ b/api/core/workflow/enums.py @@ -1,25 +1,13 @@ from enum import Enum -class SystemVariable(str, Enum): +class SystemVariableKey(str, Enum): """ System Variables. """ - QUERY = 'query' - FILES = 'files' - CONVERSATION_ID = 'conversation_id' - USER_ID = 'user_id' - DIALOGUE_COUNT = 'dialogue_count' - @classmethod - def value_of(cls, value: str): - """ - Get value of given system variable. - - :param value: system variable value - :return: system variable - """ - for system_variable in cls: - if system_variable.value == value: - return system_variable - raise ValueError(f'invalid system variable value {value}') + QUERY = "query" + FILES = "files" + CONVERSATION_ID = "conversation_id" + USER_ID = "user_id" + DIALOGUE_COUNT = "dialogue_count" diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 60678bc2ba..335991ae87 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -11,19 +11,10 @@ from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.code.entities import CodeNodeData from models.workflow import WorkflowNodeExecutionStatus -MAX_NUMBER = dify_config.CODE_MAX_NUMBER -MIN_NUMBER = dify_config.CODE_MIN_NUMBER -MAX_PRECISION = 20 -MAX_DEPTH = 5 -MAX_STRING_LENGTH = dify_config.CODE_MAX_STRING_LENGTH -MAX_STRING_ARRAY_LENGTH = dify_config.CODE_MAX_STRING_ARRAY_LENGTH -MAX_OBJECT_ARRAY_LENGTH = dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH -MAX_NUMBER_ARRAY_LENGTH = dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH - class CodeNode(BaseNode): _node_data_cls = CodeNodeData - node_type = NodeType.CODE + _node_type = NodeType.CODE @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: @@ -48,8 +39,7 @@ class CodeNode(BaseNode): :param variable_pool: variable pool :return: """ - node_data = self.node_data - node_data: CodeNodeData = cast(self._node_data_cls, node_data) + node_data = cast(CodeNodeData, self.node_data) # Get code language code_language = node_data.code_language @@ -68,7 +58,6 @@ class CodeNode(BaseNode): language=code_language, code=code, inputs=variables, - dependencies=node_data.dependencies ) # Transform result @@ -99,8 +88,9 @@ class CodeNode(BaseNode): else: raise ValueError(f"Output variable `{variable}` must be a string") - if len(value) > MAX_STRING_LENGTH: - raise ValueError(f'The length of output variable `{variable}` must be less than {MAX_STRING_LENGTH} characters') + if len(value) > dify_config.CODE_MAX_STRING_LENGTH: + raise ValueError(f'The length of output variable `{variable}` must be' + f' less than {dify_config.CODE_MAX_STRING_LENGTH} characters') return value.replace('\x00', '') @@ -117,13 +107,15 @@ class CodeNode(BaseNode): else: raise ValueError(f"Output variable `{variable}` must be a number") - if value > MAX_NUMBER or value < MIN_NUMBER: - raise ValueError(f'Output variable `{variable}` is out of range, it must be between {MIN_NUMBER} and {MAX_NUMBER}.') + if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER: + raise ValueError(f'Output variable `{variable}` is out of range,' + f' it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}.') if isinstance(value, float): # raise error if precision is too high - if len(str(value).split('.')[1]) > MAX_PRECISION: - raise ValueError(f'Output variable `{variable}` has too high precision, it must be less than {MAX_PRECISION} digits.') + if len(str(value).split('.')[1]) > dify_config.CODE_MAX_PRECISION: + raise ValueError(f'Output variable `{variable}` has too high precision,' + f' it must be less than {dify_config.CODE_MAX_PRECISION} digits.') return value @@ -136,8 +128,8 @@ class CodeNode(BaseNode): :param output_schema: output schema :return: """ - if depth > MAX_DEPTH: - raise ValueError("Depth limit reached, object too deep.") + if depth > dify_config.CODE_MAX_DEPTH: + raise ValueError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.") transformed_result = {} if output_schema is None: @@ -237,9 +229,10 @@ class CodeNode(BaseNode): f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.' ) else: - if len(result[output_name]) > MAX_NUMBER_ARRAY_LENGTH: + if len(result[output_name]) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH: raise ValueError( - f'The length of output variable `{prefix}{dot}{output_name}` must be less than {MAX_NUMBER_ARRAY_LENGTH} elements.' + f'The length of output variable `{prefix}{dot}{output_name}` must be' + f' less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements.' ) transformed_result[output_name] = [ @@ -259,9 +252,10 @@ class CodeNode(BaseNode): f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.' ) else: - if len(result[output_name]) > MAX_STRING_ARRAY_LENGTH: + if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH: raise ValueError( - f'The length of output variable `{prefix}{dot}{output_name}` must be less than {MAX_STRING_ARRAY_LENGTH} elements.' + f'The length of output variable `{prefix}{dot}{output_name}` must be' + f' less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements.' ) transformed_result[output_name] = [ @@ -281,9 +275,10 @@ class CodeNode(BaseNode): f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.' ) else: - if len(result[output_name]) > MAX_OBJECT_ARRAY_LENGTH: + if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH: raise ValueError( - f'The length of output variable `{prefix}{dot}{output_name}` must be less than {MAX_OBJECT_ARRAY_LENGTH} elements.' + f'The length of output variable `{prefix}{dot}{output_name}` must be' + f' less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements.' ) for i, value in enumerate(result[output_name]): diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index 83a5416d57..c0701ecccd 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -3,7 +3,6 @@ from typing import Literal, Optional from pydantic import BaseModel from core.helper.code_executor.code_executor import CodeLanguage -from core.helper.code_executor.entities import CodeDependency from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector @@ -16,8 +15,12 @@ class CodeNodeData(BaseNodeData): type: Literal['string', 'number', 'object', 'array[string]', 'array[number]', 'array[object]'] children: Optional[dict[str, 'Output']] = None + class Dependency(BaseModel): + name: str + version: str + variables: list[VariableSelector] code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT] code: str outputs: dict[str, Output] - dependencies: Optional[list[CodeDependency]] = None \ No newline at end of file + dependencies: Optional[list[Dependency]] = None \ No newline at end of file diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index 90d644e0e2..c066d469d8 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -5,10 +5,6 @@ from pydantic import BaseModel, ValidationInfo, field_validator from configs import dify_config from core.workflow.entities.base_node_data_entities import BaseNodeData -MAX_CONNECT_TIMEOUT = dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT -MAX_READ_TIMEOUT = dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT -MAX_WRITE_TIMEOUT = dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT - class HttpRequestNodeAuthorizationConfig(BaseModel): type: Literal[None, 'basic', 'bearer', 'custom'] @@ -41,9 +37,9 @@ class HttpRequestNodeBody(BaseModel): class HttpRequestNodeTimeout(BaseModel): - connect: int = MAX_CONNECT_TIMEOUT - read: int = MAX_READ_TIMEOUT - write: int = MAX_WRITE_TIMEOUT + connect: int = dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT + read: int = dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT + write: int = dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT class HttpRequestNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py index db18bd00b2..d16bff58bd 100644 --- a/api/core/workflow/nodes/http_request/http_executor.py +++ b/api/core/workflow/nodes/http_request/http_executor.py @@ -18,11 +18,6 @@ from core.workflow.nodes.http_request.entities import ( ) from core.workflow.utils.variable_template_parser import VariableTemplateParser -MAX_BINARY_SIZE = dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE -READABLE_MAX_BINARY_SIZE = dify_config.HTTP_REQUEST_NODE_READABLE_MAX_BINARY_SIZE -MAX_TEXT_SIZE = dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE -READABLE_MAX_TEXT_SIZE = dify_config.HTTP_REQUEST_NODE_READABLE_MAX_TEXT_SIZE - class HttpExecutorResponse: headers: dict[str, str] @@ -237,16 +232,14 @@ class HttpExecutor: else: raise ValueError(f'Invalid response type {type(response)}') - if executor_response.is_file: - if executor_response.size > MAX_BINARY_SIZE: - raise ValueError( - f'File size is too large, max size is {READABLE_MAX_BINARY_SIZE}, but current size is {executor_response.readable_size}.' - ) - else: - if executor_response.size > MAX_TEXT_SIZE: - raise ValueError( - f'Text size is too large, max size is {READABLE_MAX_TEXT_SIZE}, but current size is {executor_response.readable_size}.' - ) + threshold_size = dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE if executor_response.is_file \ + else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE + if executor_response.size > threshold_size: + raise ValueError( + f'{"File" if executor_response.is_file else "Text"} size is too large,' + f' max size is {threshold_size / 1024 / 1024:.2f} MB,' + f' but current size is {executor_response.readable_size}.' + ) return executor_response diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index 1facf8a4f4..037a7a1848 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -3,6 +3,7 @@ from mimetypes import guess_extension from os import path from typing import cast +from configs import dify_config from core.app.segments import parser from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.tools.tool_file_manager import ToolFileManager @@ -11,9 +12,6 @@ from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.http_request.entities import ( - MAX_CONNECT_TIMEOUT, - MAX_READ_TIMEOUT, - MAX_WRITE_TIMEOUT, HttpRequestNodeData, HttpRequestNodeTimeout, ) @@ -21,9 +19,9 @@ from core.workflow.nodes.http_request.http_executor import HttpExecutor, HttpExe from models.workflow import WorkflowNodeExecutionStatus HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout( - connect=min(10, MAX_CONNECT_TIMEOUT), - read=min(60, MAX_READ_TIMEOUT), - write=min(20, MAX_WRITE_TIMEOUT), + connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, + read=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, + write=dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, ) @@ -43,9 +41,9 @@ class HttpRequestNode(BaseNode): 'body': {'type': 'none'}, 'timeout': { **HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(), - 'max_connect_timeout': MAX_CONNECT_TIMEOUT, - 'max_read_timeout': MAX_READ_TIMEOUT, - 'max_write_timeout': MAX_WRITE_TIMEOUT, + 'max_connect_timeout': dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, + 'max_read_timeout': dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, + 'max_write_timeout': dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, }, }, } @@ -92,17 +90,15 @@ class HttpRequestNode(BaseNode): }, ) - def _get_request_timeout(self, node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout: + @staticmethod + def _get_request_timeout(node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout: timeout = node_data.timeout if timeout is None: return HTTP_REQUEST_DEFAULT_TIMEOUT timeout.connect = timeout.connect or HTTP_REQUEST_DEFAULT_TIMEOUT.connect - timeout.connect = min(timeout.connect, MAX_CONNECT_TIMEOUT) timeout.read = timeout.read or HTTP_REQUEST_DEFAULT_TIMEOUT.read - timeout.read = min(timeout.read, MAX_READ_TIMEOUT) timeout.write = timeout.write or HTTP_REQUEST_DEFAULT_TIMEOUT.write - timeout.write = min(timeout.write, MAX_WRITE_TIMEOUT) return timeout @classmethod diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index c20e0d4506..49f61bd597 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -24,7 +24,7 @@ from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptT from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.llm.entities import ( LLMNodeChatModelMessage, @@ -94,7 +94,7 @@ class LLMNode(BaseNode): # fetch prompt messages prompt_messages, stop = self._fetch_prompt_messages( node_data=node_data, - query=variable_pool.get_any(['sys', SystemVariable.QUERY.value]) + query=variable_pool.get_any(['sys', SystemVariableKey.QUERY.value]) if node_data.memory else None, query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None, inputs=inputs, @@ -109,11 +109,13 @@ class LLMNode(BaseNode): 'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving( model_mode=model_config.mode, prompt_messages=prompt_messages - ) + ), + 'model_provider': model_config.provider, + 'model_name': model_config.model, } # handle invoke result - result_text, usage = self._invoke_llm( + result_text, usage, finish_reason = self._invoke_llm( node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, @@ -129,7 +131,8 @@ class LLMNode(BaseNode): outputs = { 'text': result_text, - 'usage': jsonable_encoder(usage) + 'usage': jsonable_encoder(usage), + 'finish_reason': finish_reason } return NodeRunResult( @@ -167,14 +170,14 @@ class LLMNode(BaseNode): ) # handle invoke result - text, usage = self._handle_invoke_result( + text, usage, finish_reason = self._handle_invoke_result( invoke_result=invoke_result ) # deduct quota self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) - return text, usage + return text, usage, finish_reason def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]: """ @@ -186,6 +189,7 @@ class LLMNode(BaseNode): prompt_messages = [] full_text = '' usage = None + finish_reason = None for result in invoke_result: text = result.delta.message.content full_text += text @@ -201,10 +205,13 @@ class LLMNode(BaseNode): if not usage and result.delta.usage: usage = result.delta.usage + if not finish_reason and result.delta.finish_reason: + finish_reason = result.delta.finish_reason + if not usage: usage = LLMUsage.empty_usage() - return full_text, usage + return full_text, usage, finish_reason def _transform_chat_messages(self, messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate @@ -335,7 +342,7 @@ class LLMNode(BaseNode): if not node_data.vision.enabled: return [] - files = variable_pool.get_any(['sys', SystemVariable.FILES.value]) + files = variable_pool.get_any(['sys', SystemVariableKey.FILES.value]) if not files: return [] @@ -500,7 +507,7 @@ class LLMNode(BaseNode): return None # get conversation id - conversation_id = variable_pool.get_any(['sys', SystemVariable.CONVERSATION_ID.value]) + conversation_id = variable_pool.get_any(['sys', SystemVariableKey.CONVERSATION_ID.value]) if conversation_id is None: return None @@ -672,10 +679,10 @@ class LLMNode(BaseNode): variable_mapping['#context#'] = node_data.context.variable_selector if node_data.vision.enabled: - variable_mapping['#files#'] = ['sys', SystemVariable.FILES.value] + variable_mapping['#files#'] = ['sys', SystemVariableKey.FILES.value] if node_data.memory: - variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value] + variable_mapping['#sys.query#'] = ['sys', SystemVariableKey.QUERY.value] if node_data.prompt_config: enable_jinja = False diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 2e1464efce..f4057d50f3 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -63,7 +63,7 @@ class QuestionClassifierNode(LLMNode): ) # handle invoke result - result_text, usage = self._invoke_llm( + result_text, usage, finish_reason = self._invoke_llm( node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, @@ -93,6 +93,7 @@ class QuestionClassifierNode(LLMNode): prompt_messages=prompt_messages ), 'usage': jsonable_encoder(usage), + 'finish_reason': finish_reason } outputs = { 'class_name': category_name diff --git a/api/core/workflow/nodes/start/entities.py b/api/core/workflow/nodes/start/entities.py index 0bd5f203bf..b81ce15bd7 100644 --- a/api/core/workflow/nodes/start/entities.py +++ b/api/core/workflow/nodes/start/entities.py @@ -1,3 +1,7 @@ +from collections.abc import Sequence + +from pydantic import Field + from core.app.app_config.entities import VariableEntity from core.workflow.entities.base_node_data_entities import BaseNodeData @@ -6,4 +10,4 @@ class StartNodeData(BaseNodeData): """ Start Node Data """ - variables: list[VariableEntity] = [] + variables: Sequence[VariableEntity] = Field(default_factory=list) diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 661b403d32..54e66bd671 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,7 +1,7 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool +from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID, VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.start.entities import StartNodeData from models.workflow import WorkflowNodeExecutionStatus @@ -17,16 +17,16 @@ class StartNode(BaseNode): :param variable_pool: variable pool :return: """ - # Get cleaned inputs - cleaned_inputs = dict(variable_pool.user_inputs) + node_inputs = dict(variable_pool.user_inputs) + system_inputs = variable_pool.system_variables - for var in variable_pool.system_variables: - cleaned_inputs['sys.' + var.value] = variable_pool.system_variables[var] + for var in system_inputs: + node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_inputs[var] return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=cleaned_inputs, - outputs=cleaned_inputs + inputs=node_inputs, + outputs=node_inputs ) @classmethod diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 554e3b6074..ccce9ef360 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -2,7 +2,7 @@ from collections.abc import Mapping, Sequence from os import path from typing import Any, cast -from core.app.segments import ArrayAnyVariable, parser +from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter @@ -11,7 +11,7 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.utils.variable_template_parser import VariableTemplateParser @@ -141,8 +141,8 @@ class ToolNode(BaseNode): return result def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]: - variable = variable_pool.get(['sys', SystemVariable.FILES.value]) - assert isinstance(variable, ArrayAnyVariable) + variable = variable_pool.get(['sys', SystemVariableKey.FILES.value]) + assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) return list(variable.value) if variable else [] def _convert_tool_messages(self, messages: list[ToolInvokeMessage]): diff --git a/api/core/workflow/nodes/variable_aggregator/entities.py b/api/core/workflow/nodes/variable_aggregator/entities.py index cea88334b9..e5de38dc0f 100644 --- a/api/core/workflow/nodes/variable_aggregator/entities.py +++ b/api/core/workflow/nodes/variable_aggregator/entities.py @@ -17,7 +17,7 @@ class AdvancedSettings(BaseModel): """ Group. """ - output_type: Literal['string', 'number', 'array', 'object'] + output_type: Literal['string', 'number', 'object', 'array[string]', 'array[number]', 'array[object]'] variables: list[list[str]] group_name: str @@ -30,4 +30,4 @@ class VariableAssignerNodeData(BaseNodeData): type: str = 'variable-assigner' output_type: str variables: list[list[str]] - advanced_settings: Optional[AdvancedSettings] = None \ No newline at end of file + advanced_settings: Optional[AdvancedSettings] = None diff --git a/api/core/workflow/nodes/variable_assigner/__init__.py b/api/core/workflow/nodes/variable_assigner/__init__.py index 552cc367f2..d791d51523 100644 --- a/api/core/workflow/nodes/variable_assigner/__init__.py +++ b/api/core/workflow/nodes/variable_assigner/__init__.py @@ -1,109 +1,8 @@ -from collections.abc import Sequence -from enum import Enum -from typing import Optional, cast +from .node import VariableAssignerNode +from .node_data import VariableAssignerData, WriteMode -from sqlalchemy import select -from sqlalchemy.orm import Session - -from core.app.segments import SegmentType, Variable, factory -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseNode -from extensions.ext_database import db -from models import ConversationVariable, WorkflowNodeExecutionStatus - - -class VariableAssignerNodeError(Exception): - pass - - -class WriteMode(str, Enum): - OVER_WRITE = 'over-write' - APPEND = 'append' - CLEAR = 'clear' - - -class VariableAssignerData(BaseNodeData): - title: str = 'Variable Assigner' - desc: Optional[str] = 'Assign a value to a variable' - assigned_variable_selector: Sequence[str] - write_mode: WriteMode - input_variable_selector: Sequence[str] - - -class VariableAssignerNode(BaseNode): - _node_data_cls: type[BaseNodeData] = VariableAssignerData - _node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER - - def _run(self, variable_pool: VariablePool) -> NodeRunResult: - data = cast(VariableAssignerData, self.node_data) - - # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject - original_variable = variable_pool.get(data.assigned_variable_selector) - if not isinstance(original_variable, Variable): - raise VariableAssignerNodeError('assigned variable not found') - - match data.write_mode: - case WriteMode.OVER_WRITE: - income_value = variable_pool.get(data.input_variable_selector) - if not income_value: - raise VariableAssignerNodeError('input value not found') - updated_variable = original_variable.model_copy(update={'value': income_value.value}) - - case WriteMode.APPEND: - income_value = variable_pool.get(data.input_variable_selector) - if not income_value: - raise VariableAssignerNodeError('input value not found') - updated_value = original_variable.value + [income_value.value] - updated_variable = original_variable.model_copy(update={'value': updated_value}) - - case WriteMode.CLEAR: - income_value = get_zero_value(original_variable.value_type) - updated_variable = original_variable.model_copy(update={'value': income_value.to_object()}) - - case _: - raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}') - - # Over write the variable. - variable_pool.add(data.assigned_variable_selector, updated_variable) - - # Update conversation variable. - # TODO: Find a better way to use the database. - conversation_id = variable_pool.get(['sys', 'conversation_id']) - if not conversation_id: - raise VariableAssignerNodeError('conversation_id not found') - update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={ - 'value': income_value.to_object(), - }, - ) - - -def update_conversation_variable(conversation_id: str, variable: Variable): - stmt = select(ConversationVariable).where( - ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id - ) - with Session(db.engine) as session: - row = session.scalar(stmt) - if not row: - raise VariableAssignerNodeError('conversation variable not found in the database') - row.data = variable.model_dump_json() - session.commit() - - -def get_zero_value(t: SegmentType): - match t: - case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER: - return factory.build_segment([]) - case SegmentType.OBJECT: - return factory.build_segment({}) - case SegmentType.STRING: - return factory.build_segment('') - case SegmentType.NUMBER: - return factory.build_segment(0) - case _: - raise VariableAssignerNodeError(f'unsupported variable type: {t}') +__all__ = [ + 'VariableAssignerNode', + 'VariableAssignerData', + 'WriteMode', +] diff --git a/api/core/workflow/nodes/variable_assigner/exc.py b/api/core/workflow/nodes/variable_assigner/exc.py new file mode 100644 index 0000000000..914be22256 --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/exc.py @@ -0,0 +1,2 @@ +class VariableAssignerNodeError(Exception): + pass diff --git a/api/core/workflow/nodes/variable_assigner/node.py b/api/core/workflow/nodes/variable_assigner/node.py new file mode 100644 index 0000000000..8c2adcabb9 --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/node.py @@ -0,0 +1,92 @@ +from typing import cast + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.segments import SegmentType, Variable, factory +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import BaseNode +from extensions.ext_database import db +from models import ConversationVariable, WorkflowNodeExecutionStatus + +from .exc import VariableAssignerNodeError +from .node_data import VariableAssignerData, WriteMode + + +class VariableAssignerNode(BaseNode): + _node_data_cls: type[BaseNodeData] = VariableAssignerData + _node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER + + def _run(self, variable_pool: VariablePool) -> NodeRunResult: + data = cast(VariableAssignerData, self.node_data) + + # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject + original_variable = variable_pool.get(data.assigned_variable_selector) + if not isinstance(original_variable, Variable): + raise VariableAssignerNodeError('assigned variable not found') + + match data.write_mode: + case WriteMode.OVER_WRITE: + income_value = variable_pool.get(data.input_variable_selector) + if not income_value: + raise VariableAssignerNodeError('input value not found') + updated_variable = original_variable.model_copy(update={'value': income_value.value}) + + case WriteMode.APPEND: + income_value = variable_pool.get(data.input_variable_selector) + if not income_value: + raise VariableAssignerNodeError('input value not found') + updated_value = original_variable.value + [income_value.value] + updated_variable = original_variable.model_copy(update={'value': updated_value}) + + case WriteMode.CLEAR: + income_value = get_zero_value(original_variable.value_type) + updated_variable = original_variable.model_copy(update={'value': income_value.to_object()}) + + case _: + raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}') + + # Over write the variable. + variable_pool.add(data.assigned_variable_selector, updated_variable) + + # TODO: Move database operation to the pipeline. + # Update conversation variable. + conversation_id = variable_pool.get(['sys', 'conversation_id']) + if not conversation_id: + raise VariableAssignerNodeError('conversation_id not found') + update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={ + 'value': income_value.to_object(), + }, + ) + + +def update_conversation_variable(conversation_id: str, variable: Variable): + stmt = select(ConversationVariable).where( + ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id + ) + with Session(db.engine) as session: + row = session.scalar(stmt) + if not row: + raise VariableAssignerNodeError('conversation variable not found in the database') + row.data = variable.model_dump_json() + session.commit() + + +def get_zero_value(t: SegmentType): + match t: + case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER: + return factory.build_segment([]) + case SegmentType.OBJECT: + return factory.build_segment({}) + case SegmentType.STRING: + return factory.build_segment('') + case SegmentType.NUMBER: + return factory.build_segment(0) + case _: + raise VariableAssignerNodeError(f'unsupported variable type: {t}') diff --git a/api/core/workflow/nodes/variable_assigner/node_data.py b/api/core/workflow/nodes/variable_assigner/node_data.py new file mode 100644 index 0000000000..b3652b6802 --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/node_data.py @@ -0,0 +1,19 @@ +from collections.abc import Sequence +from enum import Enum +from typing import Optional + +from core.workflow.entities.base_node_data_entities import BaseNodeData + + +class WriteMode(str, Enum): + OVER_WRITE = 'over-write' + APPEND = 'append' + CLEAR = 'clear' + + +class VariableAssignerData(BaseNodeData): + title: str = 'Variable Assigner' + desc: Optional[str] = 'Assign a value to a variable' + assigned_variable_selector: Sequence[str] + write_mode: WriteMode + input_variable_selector: Sequence[str] diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 9cf5c505d1..1edc558676 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -20,11 +20,11 @@ if [[ "${MODE}" == "worker" ]]; then CONCURRENCY_OPTION="-c ${CELERY_WORKER_AMOUNT:-1}" fi - exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION --loglevel INFO \ + exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION --loglevel ${LOG_LEVEL} \ -Q ${CELERY_QUEUES:-dataset,generation,mail,ops_trace,app_deletion} elif [[ "${MODE}" == "beat" ]]; then - exec celery -A app.celery beat --loglevel INFO + exec celery -A app.celery beat --loglevel ${LOG_LEVEL} else if [[ "${DEBUG}" == "true" ]]; then exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug diff --git a/api/events/event_handlers/create_site_record_when_app_created.py b/api/events/event_handlers/create_site_record_when_app_created.py index ab07c5d366..1515661b2d 100644 --- a/api/events/event_handlers/create_site_record_when_app_created.py +++ b/api/events/event_handlers/create_site_record_when_app_created.py @@ -17,6 +17,8 @@ def handle(sender, **kwargs): default_language=account.interface_language, customize_token_strategy="not_allow", code=Site.generate_code(16), + created_by=app.created_by, + updated_by=app.updated_by, ) db.session.add(site) diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index 227c6635f0..3b7b0a37f4 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -1,3 +1,4 @@ +import openai import sentry_sdk from sentry_sdk.integrations.celery import CeleryIntegration from sentry_sdk.integrations.flask import FlaskIntegration @@ -9,7 +10,7 @@ def init_app(app): sentry_sdk.init( dsn=app.config.get("SENTRY_DSN"), integrations=[FlaskIntegration(), CeleryIntegration()], - ignore_errors=[HTTPException, ValueError], + ignore_errors=[HTTPException, ValueError, openai.APIStatusError], traces_sample_rate=app.config.get("SENTRY_TRACES_SAMPLE_RATE", 1.0), profiles_sample_rate=app.config.get("SENTRY_PROFILES_SAMPLE_RATE", 1.0), environment=app.config.get("DEPLOY_ENV"), diff --git a/api/extensions/storage/aliyun_storage.py b/api/extensions/storage/aliyun_storage.py index b962cedc55..bee237fc17 100644 --- a/api/extensions/storage/aliyun_storage.py +++ b/api/extensions/storage/aliyun_storage.py @@ -15,6 +15,7 @@ class AliyunStorage(BaseStorage): app_config = self.app.config self.bucket_name = app_config.get("ALIYUN_OSS_BUCKET_NAME") + self.folder = app.config.get("ALIYUN_OSS_PATH") oss_auth_method = aliyun_s3.Auth region = None if app_config.get("ALIYUN_OSS_AUTH_VERSION") == "v4": @@ -30,15 +31,29 @@ class AliyunStorage(BaseStorage): ) def save(self, filename, data): + if not self.folder or self.folder.endswith("/"): + filename = self.folder + filename + else: + filename = self.folder + "/" + filename self.client.put_object(filename, data) def load_once(self, filename: str) -> bytes: + if not self.folder or self.folder.endswith("/"): + filename = self.folder + filename + else: + filename = self.folder + "/" + filename + with closing(self.client.get_object(filename)) as obj: data = obj.read() return data def load_stream(self, filename: str) -> Generator: def generate(filename: str = filename) -> Generator: + if not self.folder or self.folder.endswith("/"): + filename = self.folder + filename + else: + filename = self.folder + "/" + filename + with closing(self.client.get_object(filename)) as obj: while chunk := obj.read(4096): yield chunk @@ -46,10 +61,24 @@ class AliyunStorage(BaseStorage): return generate() def download(self, filename, target_filepath): + if not self.folder or self.folder.endswith("/"): + filename = self.folder + filename + else: + filename = self.folder + "/" + filename + self.client.get_object_to_file(filename, target_filepath) def exists(self, filename): + if not self.folder or self.folder.endswith("/"): + filename = self.folder + filename + else: + filename = self.folder + "/" + filename + return self.client.object_exists(filename) def delete(self, filename): + if not self.folder or self.folder.endswith("/"): + filename = self.folder + filename + else: + filename = self.folder + "/" + filename self.client.delete_object(filename) diff --git a/api/extensions/storage/s3_storage.py b/api/extensions/storage/s3_storage.py index 022ce5b14a..0858be3af6 100644 --- a/api/extensions/storage/s3_storage.py +++ b/api/extensions/storage/s3_storage.py @@ -28,6 +28,19 @@ class S3Storage(BaseStorage): region_name=app_config.get("S3_REGION"), config=Config(s3={"addressing_style": app_config.get("S3_ADDRESS_STYLE")}), ) + # create bucket + try: + self.client.head_bucket(Bucket=self.bucket_name) + except ClientError as e: + # if bucket not exists, create it + if e.response["Error"]["Code"] == "404": + self.client.create_bucket(Bucket=self.bucket_name) + # if bucket is not accessible, pass, maybe the bucket is existing but not accessible + elif e.response["Error"]["Code"] == "403": + pass + else: + # other error, raise exception + raise def save(self, filename, data): self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data) diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index 26ed686783..aa353a3cc1 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -1,5 +1,6 @@ from flask_restful import fields +from fields.workflow_fields import workflow_partial_fields from libs.helper import AppIconUrlField, TimestampField app_detail_kernel_fields = { @@ -39,7 +40,10 @@ model_config_fields = { "completion_prompt_config": fields.Raw(attribute="completion_prompt_config_dict"), "dataset_configs": fields.Raw(attribute="dataset_configs_dict"), "file_upload": fields.Raw(attribute="file_upload_dict"), + "created_by": fields.String, "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, } app_detail_fields = { @@ -52,8 +56,13 @@ app_detail_fields = { "enable_site": fields.Boolean, "enable_api": fields.Boolean, "model_config": fields.Nested(model_config_fields, attribute="app_model_config", allow_null=True), + "workflow": fields.Nested(workflow_partial_fields, allow_null=True), "tracing": fields.Raw, + "use_icon_as_answer_icon": fields.Boolean, + "created_by": fields.String, "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, } prompt_config_fields = { @@ -63,6 +72,10 @@ prompt_config_fields = { model_config_partial_fields = { "model": fields.Raw(attribute="model_dict"), "pre_prompt": fields.String, + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, } tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String} @@ -78,7 +91,12 @@ app_partial_fields = { "icon_background": fields.String, "icon_url": AppIconUrlField, "model_config": fields.Nested(model_config_partial_fields, attribute="app_model_config", allow_null=True), + "workflow": fields.Nested(workflow_partial_fields, allow_null=True), + "use_icon_as_answer_icon": fields.Boolean, + "created_by": fields.String, "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, "tags": fields.List(fields.Nested(tag_fields)), } @@ -124,6 +142,11 @@ site_fields = { "prompt_public": fields.Boolean, "app_base_url": fields.String, "show_workflow_steps": fields.Boolean, + "use_icon_as_answer_icon": fields.Boolean, + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, } app_detail_fields_with_site = { @@ -138,9 +161,14 @@ app_detail_fields_with_site = { "enable_site": fields.Boolean, "enable_api": fields.Boolean, "model_config": fields.Nested(model_config_fields, attribute="app_model_config", allow_null=True), + "workflow": fields.Nested(workflow_partial_fields, allow_null=True), "site": fields.Nested(site_fields), "api_base_url": fields.String, + "use_icon_as_answer_icon": fields.Boolean, + "created_by": fields.String, "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, "deleted_tools": fields.List(fields.String), } @@ -160,4 +188,5 @@ app_site_fields = { "customize_token_strategy": fields.String, "prompt_public": fields.Boolean, "show_workflow_steps": fields.Boolean, + "use_icon_as_answer_icon": fields.Boolean, } diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 1b15fe3880..9207314fc2 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -111,6 +111,7 @@ conversation_fields = { "from_end_user_id": fields.String, "from_end_user_session_id": fields.String(), "from_account_id": fields.String, + "from_account_name": fields.String, "read_at": TimestampField, "created_at": TimestampField, "annotation": fields.Nested(annotation_fields, allow_null=True), @@ -146,10 +147,12 @@ conversation_with_summary_fields = { "from_end_user_id": fields.String, "from_end_user_session_id": fields.String, "from_account_id": fields.String, + "from_account_name": fields.String, "name": fields.String, "summary": fields.String(attribute="summary_or_query"), "read_at": TimestampField, "created_at": TimestampField, + "updated_at": TimestampField, "annotated": fields.Boolean, "model_config": fields.Nested(simple_model_config_fields), "message_count": fields.Integer, diff --git a/api/fields/installed_app_fields.py b/api/fields/installed_app_fields.py index 9afc1b1a4a..e0b3e340f6 100644 --- a/api/fields/installed_app_fields.py +++ b/api/fields/installed_app_fields.py @@ -10,6 +10,7 @@ app_fields = { "icon": fields.String, "icon_background": fields.String, "icon_url": AppIconUrlField, + "use_icon_as_answer_icon": fields.Boolean, } installed_app_fields = { diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 240b8f2eb0..2adef63ada 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -53,3 +53,11 @@ workflow_fields = { "environment_variables": fields.List(EnvironmentVariableField()), "conversation_variables": fields.List(fields.Nested(conversation_variable_fields)), } + +workflow_partial_fields = { + "id": fields.String, + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, +} diff --git a/api/migrations/versions/2024_08_20_0455-2dbe42621d96_rename_workflow__conversation_variables_.py b/api/migrations/versions/2024_08_20_0455-2dbe42621d96_rename_workflow__conversation_variables_.py new file mode 100644 index 0000000000..3dc7fed818 --- /dev/null +++ b/api/migrations/versions/2024_08_20_0455-2dbe42621d96_rename_workflow__conversation_variables_.py @@ -0,0 +1,28 @@ +"""rename workflow__conversation_variables to workflow_conversation_variables + +Revision ID: 2dbe42621d96 +Revises: a6be81136580 +Create Date: 2024-08-20 04:55:38.160010 + +""" +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '2dbe42621d96' +down_revision = 'a6be81136580' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.rename_table('workflow__conversation_variables', 'workflow_conversation_variables') + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.rename_table('workflow_conversation_variables', 'workflow__conversation_variables') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_08_25_0441-d0187d6a88dd_add_created_by_and_updated_by_to_app_.py b/api/migrations/versions/2024_08_25_0441-d0187d6a88dd_add_created_by_and_updated_by_to_app_.py new file mode 100644 index 0000000000..e0066a302c --- /dev/null +++ b/api/migrations/versions/2024_08_25_0441-d0187d6a88dd_add_created_by_and_updated_by_to_app_.py @@ -0,0 +1,52 @@ +"""add created_by and updated_by to app, modelconfig, and site + +Revision ID: d0187d6a88dd +Revises: 2dbe42621d96 +Create Date: 2024-08-25 04:41:18.157397 + +""" + +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = "d0187d6a88dd" +down_revision = "2dbe42621d96" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("app_model_configs", schema=None) as batch_op: + batch_op.add_column(sa.Column("created_by", models.types.StringUUID(), nullable=True)) + batch_op.add_column(sa.Column("updated_by", models.types.StringUUID(), nullable=True)) + + with op.batch_alter_table("apps", schema=None) as batch_op: + batch_op.add_column(sa.Column("created_by", models.types.StringUUID(), nullable=True)) + batch_op.add_column(sa.Column("updated_by", models.types.StringUUID(), nullable=True)) + + with op.batch_alter_table("sites", schema=None) as batch_op: + batch_op.add_column(sa.Column("created_by", models.types.StringUUID(), nullable=True)) + batch_op.add_column(sa.Column("updated_by", models.types.StringUUID(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("sites", schema=None) as batch_op: + batch_op.drop_column("updated_by") + batch_op.drop_column("created_by") + + with op.batch_alter_table("apps", schema=None) as batch_op: + batch_op.drop_column("updated_by") + batch_op.drop_column("created_by") + + with op.batch_alter_table("app_model_configs", schema=None) as batch_op: + batch_op.drop_column("updated_by") + batch_op.drop_column("created_by") + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_09_01_1255-030f4915f36a_add_use_icon_as_answer_icon_fields_for_.py b/api/migrations/versions/2024_09_01_1255-030f4915f36a_add_use_icon_as_answer_icon_fields_for_.py new file mode 100644 index 0000000000..4406d51ed0 --- /dev/null +++ b/api/migrations/versions/2024_09_01_1255-030f4915f36a_add_use_icon_as_answer_icon_fields_for_.py @@ -0,0 +1,45 @@ +"""add use_icon_as_answer_icon fields for app and site + +Revision ID: 030f4915f36a +Revises: d0187d6a88dd +Create Date: 2024-09-01 12:55:45.129687 + +""" + +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = "030f4915f36a" +down_revision = "d0187d6a88dd" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("apps", schema=None) as batch_op: + batch_op.add_column( + sa.Column("use_icon_as_answer_icon", sa.Boolean(), server_default=sa.text("false"), nullable=False) + ) + + with op.batch_alter_table("sites", schema=None) as batch_op: + batch_op.add_column( + sa.Column("use_icon_as_answer_icon", sa.Boolean(), server_default=sa.text("false"), nullable=False) + ) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + with op.batch_alter_table("sites", schema=None) as batch_op: + batch_op.drop_column("use_icon_as_answer_icon") + + with op.batch_alter_table("apps", schema=None) as batch_op: + batch_op.drop_column("use_icon_as_answer_icon") + + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index 3c6717c2c4..88c038a10b 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -1,4 +1,5 @@ import base64 +import enum import hashlib import hmac import json @@ -22,6 +23,11 @@ from .model import App, Tag, TagBinding, UploadFile from .types import StringUUID +class DatasetPermissionEnum(str, enum.Enum): + ONLY_ME = 'only_me' + ALL_TEAM = 'all_team_members' + PARTIAL_TEAM = 'partial_members' + class Dataset(db.Model): __tablename__ = 'datasets' __table_args__ = ( diff --git a/api/models/model.py b/api/models/model.py index 7301629771..e81d25fbc9 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -82,8 +82,11 @@ class App(db.Model): is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) tracing = db.Column(db.Text, nullable=True) max_active_requests = db.Column(db.Integer, nullable=True) + created_by = db.Column(StringUUID, nullable=True) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_by = db.Column(StringUUID, nullable=True) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) @property def desc_or_prompt(self): @@ -221,7 +224,9 @@ class AppModelConfig(db.Model): provider = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True) configs = db.Column(db.JSON, nullable=True) + created_by = db.Column(StringUUID, nullable=True) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_by = db.Column(StringUUID, nullable=True) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) opening_statement = db.Column(db.Text) suggested_questions = db.Column(db.Text) @@ -490,7 +495,6 @@ class InstalledApp(db.Model): return tenant - class Conversation(db.Model): __tablename__ = 'conversations' __table_args__ = ( @@ -623,6 +627,15 @@ class Conversation(db.Model): return None + @property + def from_account_name(self): + if self.from_account_id: + account = db.session.query(Account).filter(Account.id == self.from_account_id).first() + if account: + return account.name + + return None + @property def in_debug_mode(self): return self.override_model_configs is not None @@ -1102,12 +1115,15 @@ class Site(db.Model): copyright = db.Column(db.String(255)) privacy_policy = db.Column(db.String(255)) show_workflow_steps = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) + use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) custom_disclaimer = db.Column(db.String(255), nullable=True) customize_domain = db.Column(db.String(255)) customize_token_strategy = db.Column(db.String(255), nullable=False) prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + created_by = db.Column(StringUUID, nullable=True) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_by = db.Column(StringUUID, nullable=True) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) code = db.Column(db.String(255)) diff --git a/api/models/workflow.py b/api/models/workflow.py index 759e07c715..cdd5e1992d 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,5 +1,6 @@ import json from collections.abc import Mapping, Sequence +from datetime import datetime from enum import Enum from typing import Any, Optional, Union @@ -110,19 +111,32 @@ class Workflow(db.Model): db.Index('workflow_version_idx', 'tenant_id', 'app_id', 'version'), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(StringUUID, nullable=False) - app_id = db.Column(StringUUID, nullable=False) - type = db.Column(db.String(255), nullable=False) - version = db.Column(db.String(255), nullable=False) - graph = db.Column(db.Text) - features = db.Column(db.Text) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_by = db.Column(StringUUID) - updated_at = db.Column(db.DateTime) - _environment_variables = db.Column('environment_variables', db.Text, nullable=False, server_default='{}') - _conversation_variables = db.Column('conversation_variables', db.Text, nullable=False, server_default='{}') + id: Mapped[str] = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) + app_id: Mapped[str] = db.Column(StringUUID, nullable=False) + type: Mapped[str] = db.Column(db.String(255), nullable=False) + version: Mapped[str] = db.Column(db.String(255), nullable=False) + graph: Mapped[str] = db.Column(db.Text) + features: Mapped[str] = db.Column(db.Text) + created_by: Mapped[str] = db.Column(StringUUID, nullable=False) + created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_by: Mapped[str] = db.Column(StringUUID) + updated_at: Mapped[datetime] = db.Column(db.DateTime) + _environment_variables: Mapped[str] = db.Column('environment_variables', db.Text, nullable=False, server_default='{}') + _conversation_variables: Mapped[str] = db.Column('conversation_variables', db.Text, nullable=False, server_default='{}') + + def __init__(self, *, tenant_id: str, app_id: str, type: str, version: str, graph: str, + features: str, created_by: str, environment_variables: Sequence[Variable], + conversation_variables: Sequence[Variable]): + self.tenant_id = tenant_id + self.app_id = app_id + self.type = type + self.version = version + self.graph = graph + self.features = features + self.created_by = created_by + self.environment_variables = environment_variables or [] + self.conversation_variables = conversation_variables or [] @property def created_by_account(self): @@ -724,7 +738,7 @@ class WorkflowAppLog(db.Model): class ConversationVariable(db.Model): - __tablename__ = 'workflow__conversation_variables' + __tablename__ = 'workflow_conversation_variables' id: Mapped[str] = db.Column(StringUUID, primary_key=True) conversation_id: Mapped[str] = db.Column(StringUUID, nullable=False, primary_key=True) diff --git a/api/poetry.lock b/api/poetry.lock index 9bfeec30d7..7d26dbdc57 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -2,98 +2,113 @@ [[package]] name = "aiohappyeyeballs" -version = "2.3.4" +version = "2.4.0" description = "Happy Eyeballs for asyncio" optional = false -python-versions = "<4.0,>=3.8" +python-versions = ">=3.8" files = [ - {file = "aiohappyeyeballs-2.3.4-py3-none-any.whl", hash = "sha256:40a16ceffcf1fc9e142fd488123b2e218abc4188cf12ac20c67200e1579baa42"}, - {file = "aiohappyeyeballs-2.3.4.tar.gz", hash = "sha256:7e1ae8399c320a8adec76f6c919ed5ceae6edd4c3672f4d9eae2b27e37c80ff6"}, + {file = "aiohappyeyeballs-2.4.0-py3-none-any.whl", hash = "sha256:7ce92076e249169a13c2f49320d1967425eaf1f407522d707d59cac7628d62bd"}, + {file = "aiohappyeyeballs-2.4.0.tar.gz", hash = "sha256:55a1714f084e63d49639800f95716da97a1f173d46a16dfcfda0016abb93b6b2"}, ] [[package]] name = "aiohttp" -version = "3.10.1" +version = "3.10.5" description = "Async http client/server framework (asyncio)" optional = false python-versions = ">=3.8" files = [ - {file = "aiohttp-3.10.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:47b4c2412960e64d97258f40616efddaebcb34ff664c8a972119ed38fac2a62c"}, - {file = "aiohttp-3.10.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e7dbf637f87dd315fa1f36aaed8afa929ee2c607454fb7791e74c88a0d94da59"}, - {file = "aiohttp-3.10.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c8fb76214b5b739ce59e2236a6489d9dc3483649cfd6f563dbf5d8e40dbdd57d"}, - {file = "aiohttp-3.10.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c577cdcf8f92862363b3d598d971c6a84ed8f0bf824d4cc1ce70c2fb02acb4a"}, - {file = "aiohttp-3.10.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:777e23609899cb230ad2642b4bdf1008890f84968be78de29099a8a86f10b261"}, - {file = "aiohttp-3.10.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b07286a1090483799599a2f72f76ac396993da31f6e08efedb59f40876c144fa"}, - {file = "aiohttp-3.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b9db600a86414a9a653e3c1c7f6a2f6a1894ab8f83d11505247bd1b90ad57157"}, - {file = "aiohttp-3.10.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:01c3f1eb280008e51965a8d160a108c333136f4a39d46f516c64d2aa2e6a53f2"}, - {file = "aiohttp-3.10.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f5dd109a925fee4c9ac3f6a094900461a2712df41745f5d04782ebcbe6479ccb"}, - {file = "aiohttp-3.10.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:8c81ff4afffef9b1186639506d70ea90888218f5ddfff03870e74ec80bb59970"}, - {file = "aiohttp-3.10.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:2a384dfbe8bfebd203b778a30a712886d147c61943675f4719b56725a8bbe803"}, - {file = "aiohttp-3.10.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:b9fb6508893dc31cfcbb8191ef35abd79751db1d6871b3e2caee83959b4d91eb"}, - {file = "aiohttp-3.10.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:88596384c3bec644a96ae46287bb646d6a23fa6014afe3799156aef42669c6bd"}, - {file = "aiohttp-3.10.1-cp310-cp310-win32.whl", hash = "sha256:68164d43c580c2e8bf8e0eb4960142919d304052ccab92be10250a3a33b53268"}, - {file = "aiohttp-3.10.1-cp310-cp310-win_amd64.whl", hash = "sha256:d6bbe2c90c10382ca96df33b56e2060404a4f0f88673e1e84b44c8952517e5f3"}, - {file = "aiohttp-3.10.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f6979b4f20d3e557a867da9d9227de4c156fcdcb348a5848e3e6190fd7feb972"}, - {file = "aiohttp-3.10.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:03c0c380c83f8a8d4416224aafb88d378376d6f4cadebb56b060688251055cd4"}, - {file = "aiohttp-3.10.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1c2b104e81b3c3deba7e6f5bc1a9a0e9161c380530479970766a6655b8b77c7c"}, - {file = "aiohttp-3.10.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b023b68c61ab0cd48bd38416b421464a62c381e32b9dc7b4bdfa2905807452a4"}, - {file = "aiohttp-3.10.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1a07c76a82390506ca0eabf57c0540cf5a60c993c442928fe4928472c4c6e5e6"}, - {file = "aiohttp-3.10.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:41d8dab8c64ded1edf117d2a64f353efa096c52b853ef461aebd49abae979f16"}, - {file = "aiohttp-3.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:615348fab1a9ef7d0960a905e83ad39051ae9cb0d2837da739b5d3a7671e497a"}, - {file = "aiohttp-3.10.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:256ee6044214ee9d66d531bb374f065ee94e60667d6bbeaa25ca111fc3997158"}, - {file = "aiohttp-3.10.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b7d5bb926805022508b7ddeaad957f1fce7a8d77532068d7bdb431056dc630cd"}, - {file = "aiohttp-3.10.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:028faf71b338f069077af6315ad54281612705d68889f5d914318cbc2aab0d50"}, - {file = "aiohttp-3.10.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:5c12310d153b27aa630750be44e79313acc4e864c421eb7d2bc6fa3429c41bf8"}, - {file = "aiohttp-3.10.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:de1a91d5faded9054957ed0a9e01b9d632109341942fc123947ced358c5d9009"}, - {file = "aiohttp-3.10.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:9c186b270979fb1dee3ababe2d12fb243ed7da08b30abc83ebac3a928a4ddb15"}, - {file = "aiohttp-3.10.1-cp311-cp311-win32.whl", hash = "sha256:4a9ce70f5e00380377aac0e568abd075266ff992be2e271765f7b35d228a990c"}, - {file = "aiohttp-3.10.1-cp311-cp311-win_amd64.whl", hash = "sha256:a77c79bac8d908d839d32c212aef2354d2246eb9deb3e2cb01ffa83fb7a6ea5d"}, - {file = "aiohttp-3.10.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:2212296cdb63b092e295c3e4b4b442e7b7eb41e8a30d0f53c16d5962efed395d"}, - {file = "aiohttp-3.10.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:4dcb127ca3eb0a61205818a606393cbb60d93b7afb9accd2fd1e9081cc533144"}, - {file = "aiohttp-3.10.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:cb8b79a65332e1a426ccb6290ce0409e1dc16b4daac1cc5761e059127fa3d134"}, - {file = "aiohttp-3.10.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:68cc24f707ed9cb961f6ee04020ca01de2c89b2811f3cf3361dc7c96a14bfbcc"}, - {file = "aiohttp-3.10.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9cb54f5725b4b37af12edf6c9e834df59258c82c15a244daa521a065fbb11717"}, - {file = "aiohttp-3.10.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:51d03e948e53b3639ce4d438f3d1d8202898ec6655cadcc09ec99229d4adc2a9"}, - {file = "aiohttp-3.10.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:786299d719eb5d868f161aeec56d589396b053925b7e0ce36e983d30d0a3e55c"}, - {file = "aiohttp-3.10.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:abda4009a30d51d3f06f36bc7411a62b3e647fa6cc935ef667e3e3d3a7dd09b1"}, - {file = "aiohttp-3.10.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:67f7639424c313125213954e93a6229d3a1d386855d70c292a12628f600c7150"}, - {file = "aiohttp-3.10.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:8e5a26d7aac4c0d8414a347da162696eea0629fdce939ada6aedf951abb1d745"}, - {file = "aiohttp-3.10.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:120548d89f14b76a041088b582454d89389370632ee12bf39d919cc5c561d1ca"}, - {file = "aiohttp-3.10.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:f5293726943bdcea24715b121d8c4ae12581441d22623b0e6ab12d07ce85f9c4"}, - {file = "aiohttp-3.10.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1f8605e573ed6c44ec689d94544b2c4bb1390aaa723a8b5a2cc0a5a485987a68"}, - {file = "aiohttp-3.10.1-cp312-cp312-win32.whl", hash = "sha256:e7168782621be4448d90169a60c8b37e9b0926b3b79b6097bc180c0a8a119e73"}, - {file = "aiohttp-3.10.1-cp312-cp312-win_amd64.whl", hash = "sha256:8fbf8c0ded367c5c8eaf585f85ca8dd85ff4d5b73fb8fe1e6ac9e1b5e62e11f7"}, - {file = "aiohttp-3.10.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:54b7f4a20d7cc6bfa4438abbde069d417bb7a119f870975f78a2b99890226d55"}, - {file = "aiohttp-3.10.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2fa643ca990323db68911b92f3f7a0ca9ae300ae340d0235de87c523601e58d9"}, - {file = "aiohttp-3.10.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d8311d0d690487359fe2247ec5d2cac9946e70d50dced8c01ce9e72341c21151"}, - {file = "aiohttp-3.10.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:222821c60b8f6a64c5908cb43d69c0ee978a1188f6a8433d4757d39231b42cdb"}, - {file = "aiohttp-3.10.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e7b55d9ede66af7feb6de87ff277e0ccf6d51c7db74cc39337fe3a0e31b5872d"}, - {file = "aiohttp-3.10.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5a95151a5567b3b00368e99e9c5334a919514f60888a6b6d2054fea5e66e527e"}, - {file = "aiohttp-3.10.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e9e9171d2fe6bfd9d3838a6fe63b1e91b55e0bf726c16edf265536e4eafed19"}, - {file = "aiohttp-3.10.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a57e73f9523e980f6101dc9a83adcd7ac0006ea8bf7937ca3870391c7bb4f8ff"}, - {file = "aiohttp-3.10.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:0df51a3d70a2bfbb9c921619f68d6d02591f24f10e9c76de6f3388c89ed01de6"}, - {file = "aiohttp-3.10.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:b0de63ff0307eac3961b4af74382d30220d4813f36b7aaaf57f063a1243b4214"}, - {file = "aiohttp-3.10.1-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:8db9b749f589b5af8e4993623dbda6716b2b7a5fcb0fa2277bf3ce4b278c7059"}, - {file = "aiohttp-3.10.1-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:6b14c19172eb53b63931d3e62a9749d6519f7c121149493e6eefca055fcdb352"}, - {file = "aiohttp-3.10.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:5cd57ad998e3038aa87c38fe85c99ed728001bf5dde8eca121cadee06ee3f637"}, - {file = "aiohttp-3.10.1-cp38-cp38-win32.whl", hash = "sha256:df31641e3f02b77eb3c5fb63c0508bee0fc067cf153da0e002ebbb0db0b6d91a"}, - {file = "aiohttp-3.10.1-cp38-cp38-win_amd64.whl", hash = "sha256:93094eba50bc2ad4c40ff4997ead1fdcd41536116f2e7d6cfec9596a8ecb3615"}, - {file = "aiohttp-3.10.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:440954ddc6b77257e67170d57b1026aa9545275c33312357472504eef7b4cc0b"}, - {file = "aiohttp-3.10.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f9f8beed277488a52ee2b459b23c4135e54d6a819eaba2e120e57311015b58e9"}, - {file = "aiohttp-3.10.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d8a8221a63602008550022aa3a4152ca357e1dde7ab3dd1da7e1925050b56863"}, - {file = "aiohttp-3.10.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a702bd3663b5cbf3916e84bf332400d24cdb18399f0877ca6b313ce6c08bfb43"}, - {file = "aiohttp-3.10.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1988b370536eb14f0ce7f3a4a5b422ab64c4e255b3f5d7752c5f583dc8c967fc"}, - {file = "aiohttp-3.10.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7ccf1f0a304352c891d124ac1a9dea59b14b2abed1704aaa7689fc90ef9c5be1"}, - {file = "aiohttp-3.10.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc3ea6ef2a83edad84bbdb5d96e22f587b67c68922cd7b6f9d8f24865e655bcf"}, - {file = "aiohttp-3.10.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:89b47c125ab07f0831803b88aeb12b04c564d5f07a1c1a225d4eb4d2f26e8b5e"}, - {file = "aiohttp-3.10.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:21778552ef3d44aac3278cc6f6d13a6423504fa5f09f2df34bfe489ed9ded7f5"}, - {file = "aiohttp-3.10.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:bde0693073fd5e542e46ea100aa6c1a5d36282dbdbad85b1c3365d5421490a92"}, - {file = "aiohttp-3.10.1-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:bf66149bb348d8e713f3a8e0b4f5b952094c2948c408e1cfef03b49e86745d60"}, - {file = "aiohttp-3.10.1-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:587237571a85716d6f71f60d103416c9df7d5acb55d96d3d3ced65f39bff9c0c"}, - {file = "aiohttp-3.10.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:bfe33cba6e127d0b5b417623c9aa621f0a69f304742acdca929a9fdab4593693"}, - {file = "aiohttp-3.10.1-cp39-cp39-win32.whl", hash = "sha256:9fbff00646cf8211b330690eb2fd64b23e1ce5b63a342436c1d1d6951d53d8dd"}, - {file = "aiohttp-3.10.1-cp39-cp39-win_amd64.whl", hash = "sha256:5951c328f9ac42d7bce7a6ded535879bc9ae13032818d036749631fa27777905"}, - {file = "aiohttp-3.10.1.tar.gz", hash = "sha256:8b0d058e4e425d3b45e8ec70d49b402f4d6b21041e674798b1f91ba027c73f28"}, + {file = "aiohttp-3.10.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:18a01eba2574fb9edd5f6e5fb25f66e6ce061da5dab5db75e13fe1558142e0a3"}, + {file = "aiohttp-3.10.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:94fac7c6e77ccb1ca91e9eb4cb0ac0270b9fb9b289738654120ba8cebb1189c6"}, + {file = "aiohttp-3.10.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2f1f1c75c395991ce9c94d3e4aa96e5c59c8356a15b1c9231e783865e2772699"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f7acae3cf1a2a2361ec4c8e787eaaa86a94171d2417aae53c0cca6ca3118ff6"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:94c4381ffba9cc508b37d2e536b418d5ea9cfdc2848b9a7fea6aebad4ec6aac1"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c31ad0c0c507894e3eaa843415841995bf8de4d6b2d24c6e33099f4bc9fc0d4f"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0912b8a8fadeb32ff67a3ed44249448c20148397c1ed905d5dac185b4ca547bb"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0d93400c18596b7dc4794d48a63fb361b01a0d8eb39f28800dc900c8fbdaca91"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d00f3c5e0d764a5c9aa5a62d99728c56d455310bcc288a79cab10157b3af426f"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:d742c36ed44f2798c8d3f4bc511f479b9ceef2b93f348671184139e7d708042c"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:814375093edae5f1cb31e3407997cf3eacefb9010f96df10d64829362ae2df69"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:8224f98be68a84b19f48e0bdc14224b5a71339aff3a27df69989fa47d01296f3"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:d9a487ef090aea982d748b1b0d74fe7c3950b109df967630a20584f9a99c0683"}, + {file = "aiohttp-3.10.5-cp310-cp310-win32.whl", hash = "sha256:d9ef084e3dc690ad50137cc05831c52b6ca428096e6deb3c43e95827f531d5ef"}, + {file = "aiohttp-3.10.5-cp310-cp310-win_amd64.whl", hash = "sha256:66bf9234e08fe561dccd62083bf67400bdbf1c67ba9efdc3dac03650e97c6088"}, + {file = "aiohttp-3.10.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8c6a4e5e40156d72a40241a25cc226051c0a8d816610097a8e8f517aeacd59a2"}, + {file = "aiohttp-3.10.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2c634a3207a5445be65536d38c13791904fda0748b9eabf908d3fe86a52941cf"}, + {file = "aiohttp-3.10.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4aff049b5e629ef9b3e9e617fa6e2dfeda1bf87e01bcfecaf3949af9e210105e"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1942244f00baaacaa8155eca94dbd9e8cc7017deb69b75ef67c78e89fdad3c77"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e04a1f2a65ad2f93aa20f9ff9f1b672bf912413e5547f60749fa2ef8a644e061"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7f2bfc0032a00405d4af2ba27f3c429e851d04fad1e5ceee4080a1c570476697"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:424ae21498790e12eb759040bbb504e5e280cab64693d14775c54269fd1d2bb7"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:975218eee0e6d24eb336d0328c768ebc5d617609affaca5dbbd6dd1984f16ed0"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:4120d7fefa1e2d8fb6f650b11489710091788de554e2b6f8347c7a20ceb003f5"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:b90078989ef3fc45cf9221d3859acd1108af7560c52397ff4ace8ad7052a132e"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:ba5a8b74c2a8af7d862399cdedce1533642fa727def0b8c3e3e02fcb52dca1b1"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:02594361128f780eecc2a29939d9dfc870e17b45178a867bf61a11b2a4367277"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:8fb4fc029e135859f533025bc82047334e24b0d489e75513144f25408ecaf058"}, + {file = "aiohttp-3.10.5-cp311-cp311-win32.whl", hash = "sha256:e1ca1ef5ba129718a8fc827b0867f6aa4e893c56eb00003b7367f8a733a9b072"}, + {file = "aiohttp-3.10.5-cp311-cp311-win_amd64.whl", hash = "sha256:349ef8a73a7c5665cca65c88ab24abe75447e28aa3bc4c93ea5093474dfdf0ff"}, + {file = "aiohttp-3.10.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:305be5ff2081fa1d283a76113b8df7a14c10d75602a38d9f012935df20731487"}, + {file = "aiohttp-3.10.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:3a1c32a19ee6bbde02f1cb189e13a71b321256cc1d431196a9f824050b160d5a"}, + {file = "aiohttp-3.10.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:61645818edd40cc6f455b851277a21bf420ce347baa0b86eaa41d51ef58ba23d"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c225286f2b13bab5987425558baa5cbdb2bc925b2998038fa028245ef421e75"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8ba01ebc6175e1e6b7275c907a3a36be48a2d487549b656aa90c8a910d9f3178"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8eaf44ccbc4e35762683078b72bf293f476561d8b68ec8a64f98cf32811c323e"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1c43eb1ab7cbf411b8e387dc169acb31f0ca0d8c09ba63f9eac67829585b44f"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de7a5299827253023c55ea549444e058c0eb496931fa05d693b95140a947cb73"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4790f0e15f00058f7599dab2b206d3049d7ac464dc2e5eae0e93fa18aee9e7bf"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:44b324a6b8376a23e6ba25d368726ee3bc281e6ab306db80b5819999c737d820"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:0d277cfb304118079e7044aad0b76685d30ecb86f83a0711fc5fb257ffe832ca"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:54d9ddea424cd19d3ff6128601a4a4d23d54a421f9b4c0fff740505813739a91"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:4f1c9866ccf48a6df2b06823e6ae80573529f2af3a0992ec4fe75b1a510df8a6"}, + {file = "aiohttp-3.10.5-cp312-cp312-win32.whl", hash = "sha256:dc4826823121783dccc0871e3f405417ac116055bf184ac04c36f98b75aacd12"}, + {file = "aiohttp-3.10.5-cp312-cp312-win_amd64.whl", hash = "sha256:22c0a23a3b3138a6bf76fc553789cb1a703836da86b0f306b6f0dc1617398abc"}, + {file = "aiohttp-3.10.5-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:7f6b639c36734eaa80a6c152a238242bedcee9b953f23bb887e9102976343092"}, + {file = "aiohttp-3.10.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f29930bc2921cef955ba39a3ff87d2c4398a0394ae217f41cb02d5c26c8b1b77"}, + {file = "aiohttp-3.10.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f489a2c9e6455d87eabf907ac0b7d230a9786be43fbe884ad184ddf9e9c1e385"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:123dd5b16b75b2962d0fff566effb7a065e33cd4538c1692fb31c3bda2bfb972"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b98e698dc34966e5976e10bbca6d26d6724e6bdea853c7c10162a3235aba6e16"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c3b9162bab7e42f21243effc822652dc5bb5e8ff42a4eb62fe7782bcbcdfacf6"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1923a5c44061bffd5eebeef58cecf68096e35003907d8201a4d0d6f6e387ccaa"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d55f011da0a843c3d3df2c2cf4e537b8070a419f891c930245f05d329c4b0689"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:afe16a84498441d05e9189a15900640a2d2b5e76cf4efe8cbb088ab4f112ee57"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:f8112fb501b1e0567a1251a2fd0747baae60a4ab325a871e975b7bb67e59221f"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:1e72589da4c90337837fdfe2026ae1952c0f4a6e793adbbfbdd40efed7c63599"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:4d46c7b4173415d8e583045fbc4daa48b40e31b19ce595b8d92cf639396c15d5"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:33e6bc4bab477c772a541f76cd91e11ccb6d2efa2b8d7d7883591dfb523e5987"}, + {file = "aiohttp-3.10.5-cp313-cp313-win32.whl", hash = "sha256:c58c6837a2c2a7cf3133983e64173aec11f9c2cd8e87ec2fdc16ce727bcf1a04"}, + {file = "aiohttp-3.10.5-cp313-cp313-win_amd64.whl", hash = "sha256:38172a70005252b6893088c0f5e8a47d173df7cc2b2bd88650957eb84fcf5022"}, + {file = "aiohttp-3.10.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:f6f18898ace4bcd2d41a122916475344a87f1dfdec626ecde9ee802a711bc569"}, + {file = "aiohttp-3.10.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5ede29d91a40ba22ac1b922ef510aab871652f6c88ef60b9dcdf773c6d32ad7a"}, + {file = "aiohttp-3.10.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:673f988370f5954df96cc31fd99c7312a3af0a97f09e407399f61583f30da9bc"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58718e181c56a3c02d25b09d4115eb02aafe1a732ce5714ab70326d9776457c3"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4b38b1570242fbab8d86a84128fb5b5234a2f70c2e32f3070143a6d94bc854cf"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:074d1bff0163e107e97bd48cad9f928fa5a3eb4b9d33366137ffce08a63e37fe"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd31f176429cecbc1ba499d4aba31aaccfea488f418d60376b911269d3b883c5"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7384d0b87d4635ec38db9263e6a3f1eb609e2e06087f0aa7f63b76833737b471"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:8989f46f3d7ef79585e98fa991e6ded55d2f48ae56d2c9fa5e491a6e4effb589"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:c83f7a107abb89a227d6c454c613e7606c12a42b9a4ca9c5d7dad25d47c776ae"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:cde98f323d6bf161041e7627a5fd763f9fd829bcfcd089804a5fdce7bb6e1b7d"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:676f94c5480d8eefd97c0c7e3953315e4d8c2b71f3b49539beb2aa676c58272f"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:2d21ac12dc943c68135ff858c3a989f2194a709e6e10b4c8977d7fcd67dfd511"}, + {file = "aiohttp-3.10.5-cp38-cp38-win32.whl", hash = "sha256:17e997105bd1a260850272bfb50e2a328e029c941c2708170d9d978d5a30ad9a"}, + {file = "aiohttp-3.10.5-cp38-cp38-win_amd64.whl", hash = "sha256:1c19de68896747a2aa6257ae4cf6ef59d73917a36a35ee9d0a6f48cff0f94db8"}, + {file = "aiohttp-3.10.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7e2fe37ac654032db1f3499fe56e77190282534810e2a8e833141a021faaab0e"}, + {file = "aiohttp-3.10.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f5bf3ead3cb66ab990ee2561373b009db5bc0e857549b6c9ba84b20bc462e172"}, + {file = "aiohttp-3.10.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1b2c16a919d936ca87a3c5f0e43af12a89a3ce7ccbce59a2d6784caba945b68b"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad146dae5977c4dd435eb31373b3fe9b0b1bf26858c6fc452bf6af394067e10b"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8c5c6fa16412b35999320f5c9690c0f554392dc222c04e559217e0f9ae244b92"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:95c4dc6f61d610bc0ee1edc6f29d993f10febfe5b76bb470b486d90bbece6b22"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da452c2c322e9ce0cfef392e469a26d63d42860f829026a63374fde6b5c5876f"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:898715cf566ec2869d5cb4d5fb4be408964704c46c96b4be267442d265390f32"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:391cc3a9c1527e424c6865e087897e766a917f15dddb360174a70467572ac6ce"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:380f926b51b92d02a34119d072f178d80bbda334d1a7e10fa22d467a66e494db"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ce91db90dbf37bb6fa0997f26574107e1b9d5ff939315247b7e615baa8ec313b"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:9093a81e18c45227eebe4c16124ebf3e0d893830c6aca7cc310bfca8fe59d857"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:ee40b40aa753d844162dcc80d0fe256b87cba48ca0054f64e68000453caead11"}, + {file = "aiohttp-3.10.5-cp39-cp39-win32.whl", hash = "sha256:03f2645adbe17f274444953bdea69f8327e9d278d961d85657cb0d06864814c1"}, + {file = "aiohttp-3.10.5-cp39-cp39-win_amd64.whl", hash = "sha256:d17920f18e6ee090bdd3d0bfffd769d9f2cb4c8ffde3eb203777a3895c128862"}, + {file = "aiohttp-3.10.5.tar.gz", hash = "sha256:f071854b47d39591ce9a17981c46790acb30518e2f83dfca8db2dfa091178691"}, ] [package.dependencies] @@ -363,13 +378,13 @@ jmespath = ">=0.9.3,<1.0.0" [[package]] name = "aliyun-python-sdk-kms" -version = "2.16.3" +version = "2.16.4" description = "The kms module of Aliyun Python sdk." optional = false python-versions = "*" files = [ - {file = "aliyun-python-sdk-kms-2.16.3.tar.gz", hash = "sha256:c31b7d24e153271a3043e801e7b6b6b3f0db47e95a83c8d10cdab8c11662fc39"}, - {file = "aliyun_python_sdk_kms-2.16.3-py2.py3-none-any.whl", hash = "sha256:8bb8c293be94e0cc9114a5286a503d2ec215eaf8a1fb51de5d6c8bcac209d4a1"}, + {file = "aliyun-python-sdk-kms-2.16.4.tar.gz", hash = "sha256:0d5bb165c07b6a972939753a128507393f48011792ee0ec4f59b6021eabd9752"}, + {file = "aliyun_python_sdk_kms-2.16.4-py2.py3-none-any.whl", hash = "sha256:6d412663ef8c35dc3bb42be6a3ee76a9bc07acdadca6dd26815131062bedf4c5"}, ] [package.dependencies] @@ -536,6 +551,69 @@ files = [ [package.dependencies] cryptography = "*" +[[package]] +name = "azure-ai-inference" +version = "1.0.0b3" +description = "Microsoft Azure Ai Inference Client Library for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "azure-ai-inference-1.0.0b3.tar.gz", hash = "sha256:1e99dc74c3b335a457500311bbbadb348f54dc4c12252a93cb8ab78d6d217ff0"}, + {file = "azure_ai_inference-1.0.0b3-py3-none-any.whl", hash = "sha256:6734ca7334c809a170beb767f1f1455724ab3f006cb60045e42a833c0e764403"}, +] + +[package.dependencies] +azure-core = ">=1.30.0" +isodate = ">=0.6.1" +typing-extensions = ">=4.6.0" + +[[package]] +name = "azure-ai-ml" +version = "1.19.0" +description = "Microsoft Azure Machine Learning Client Library for Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "azure-ai-ml-1.19.0.tar.gz", hash = "sha256:94bb1afbb0497e539ae75455fc4a51b6942b5b68b3a275727ecce6ceb250eff9"}, + {file = "azure_ai_ml-1.19.0-py3-none-any.whl", hash = "sha256:f0385af06efbeae1f83113613e45343508d1288fd2f05857619e7c7d4d4f5302"}, +] + +[package.dependencies] +azure-common = ">=1.1" +azure-core = ">=1.23.0" +azure-mgmt-core = ">=1.3.0" +azure-storage-blob = ">=12.10.0" +azure-storage-file-datalake = ">=12.2.0" +azure-storage-file-share = "*" +colorama = "*" +isodate = "*" +jsonschema = ">=4.0.0" +marshmallow = ">=3.5" +msrest = ">=0.6.18" +opencensus-ext-azure = "*" +opencensus-ext-logging = "*" +pydash = ">=6.0.0" +pyjwt = "*" +pyyaml = ">=5.1.0" +strictyaml = "*" +tqdm = "*" +typing-extensions = "*" + +[package.extras] +designer = ["mldesigner"] +mount = ["azureml-dataprep-rslex (>=2.22.0)"] + +[[package]] +name = "azure-common" +version = "1.1.28" +description = "Microsoft Azure Client Library for Python (Common)" +optional = false +python-versions = "*" +files = [ + {file = "azure-common-1.1.28.zip", hash = "sha256:4ac0cd3214e36b6a1b6a442686722a5d8cc449603aa833f3f0f40bda836704a3"}, + {file = "azure_common-1.1.28-py2.py3-none-any.whl", hash = "sha256:5c12d3dcf4ec20599ca6b0d3e09e86e146353d443e7fcc050c9a19c1f9df20ad"}, +] + [[package]] name = "azure-core" version = "1.30.2" @@ -572,6 +650,20 @@ cryptography = ">=2.5" msal = ">=1.24.0" msal-extensions = ">=0.3.0" +[[package]] +name = "azure-mgmt-core" +version = "1.4.0" +description = "Microsoft Azure Management Core Library for Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "azure-mgmt-core-1.4.0.zip", hash = "sha256:d195208340094f98e5a6661b781cde6f6a051e79ce317caabd8ff97030a9b3ae"}, + {file = "azure_mgmt_core-1.4.0-py3-none-any.whl", hash = "sha256:81071675f186a585555ef01816f2774d49c1c9024cb76e5720c3c0f6b337bb7d"}, +] + +[package.dependencies] +azure-core = ">=1.26.2,<2.0.0" + [[package]] name = "azure-storage-blob" version = "12.13.0" @@ -588,6 +680,42 @@ azure-core = ">=1.23.1,<2.0.0" cryptography = ">=2.1.4" msrest = ">=0.6.21" +[[package]] +name = "azure-storage-file-datalake" +version = "12.8.0" +description = "Microsoft Azure File DataLake Storage Client Library for Python" +optional = false +python-versions = ">=3.6" +files = [ + {file = "azure-storage-file-datalake-12.8.0.zip", hash = "sha256:12e6306e5efb5ca28e0ccd9fa79a2c61acd589866d6109fe5601b18509da92f4"}, + {file = "azure_storage_file_datalake-12.8.0-py3-none-any.whl", hash = "sha256:b6cf5733fe794bf3c866efbe3ce1941409e35b6b125028ac558b436bf90f2de7"}, +] + +[package.dependencies] +azure-core = ">=1.23.1,<2.0.0" +azure-storage-blob = ">=12.13.0,<13.0.0" +msrest = ">=0.6.21" + +[[package]] +name = "azure-storage-file-share" +version = "12.17.0" +description = "Microsoft Azure Azure File Share Storage Client Library for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "azure-storage-file-share-12.17.0.tar.gz", hash = "sha256:f7b2c6cfc1b7cb80097a53b1ed2efa9e545b49a291430d369cdb49fafbc841d6"}, + {file = "azure_storage_file_share-12.17.0-py3-none-any.whl", hash = "sha256:c4652759a9d529bf08881bb53275bf38774bb643746b849d27c47118f9cf923d"}, +] + +[package.dependencies] +azure-core = ">=1.28.0" +cryptography = ">=2.1.4" +isodate = ">=0.6.1" +typing-extensions = ">=4.6.0" + +[package.extras] +aio = ["azure-core[aio] (>=1.28.0)"] + [[package]] name = "backoff" version = "2.2.1" @@ -700,13 +828,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.34.155" +version = "1.34.162" description = "Low-level, data-driven core of boto 3." optional = false python-versions = ">=3.8" files = [ - {file = "botocore-1.34.155-py3-none-any.whl", hash = "sha256:f2696c11bb0cad627d42512937befd2e3f966aedd15de00d90ee13cf7a16b328"}, - {file = "botocore-1.34.155.tar.gz", hash = "sha256:3aa88abfef23909f68d3e6679a3d4b4bb3c6288a6cfbf9e253aa68dac8edad64"}, + {file = "botocore-1.34.162-py3-none-any.whl", hash = "sha256:2d918b02db88d27a75b48275e6fb2506e9adaaddbec1ffa6a8a0898b34e769be"}, + {file = "botocore-1.34.162.tar.gz", hash = "sha256:adc23be4fb99ad31961236342b7cbf3c0bfc62532cd02852196032e8c0d682f3"}, ] [package.dependencies] @@ -1370,77 +1498,77 @@ testing = ["pytest (>=7.2.1)", "pytest-cov (>=4.0.0)", "tox (>=4.4.3)"] [[package]] name = "clickhouse-connect" -version = "0.7.18" +version = "0.7.19" description = "ClickHouse Database Core Driver for Python, Pandas, and Superset" optional = false python-versions = "~=3.8" files = [ - {file = "clickhouse-connect-0.7.18.tar.gz", hash = "sha256:516aba1fdcf58973b0d0d90168a60c49f6892b6db1183b932f80ae057994eadb"}, - {file = "clickhouse_connect-0.7.18-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:43e712b8fada717160153022314473826adffde00e8cbe8068e0aa1c187c2395"}, - {file = "clickhouse_connect-0.7.18-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0a21244d24c9b2a7d1ea2cf23f254884113e0f6d9950340369ce154d7d377165"}, - {file = "clickhouse_connect-0.7.18-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:347b19f3674b57906dea94dd0e8b72aaedc822131cc2a2383526b19933ed7a33"}, - {file = "clickhouse_connect-0.7.18-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23c5aa1b144491211f662ed26f279845fb367c37d49b681b783ca4f8c51c7891"}, - {file = "clickhouse_connect-0.7.18-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e99b4271ed08cc59162a6025086f1786ded5b8a29f4c38e2d3b2a58af04f85f5"}, - {file = "clickhouse_connect-0.7.18-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:27d76d1dbe988350567dab7fbcc0a54cdd25abedc5585326c753974349818694"}, - {file = "clickhouse_connect-0.7.18-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:d2cd40b4e07df277192ab6bcb187b3f61e0074ad0e256908bf443b3080be4a6c"}, - {file = "clickhouse_connect-0.7.18-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8f4ae2c4fb66b2b49f2e7f893fe730712a61a068e79f7272e60d4dd7d64df260"}, - {file = "clickhouse_connect-0.7.18-cp310-cp310-win32.whl", hash = "sha256:ed871195b25a4e1acfd37f59527ceb872096f0cd65d76af8c91f581c033b1cc0"}, - {file = "clickhouse_connect-0.7.18-cp310-cp310-win_amd64.whl", hash = "sha256:0c4989012e434b9c167bddf9298ca6eb076593e48a2cab7347cd70a446a7b5d3"}, - {file = "clickhouse_connect-0.7.18-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:52cfcd77fc63561e7b51940e32900c13731513d703d7fc54a3a6eb1fa4f7be4e"}, - {file = "clickhouse_connect-0.7.18-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:71d7bb9a24b0eacf8963044d6a1dd9e86dfcdd30afe1bd4a581c00910c83895a"}, - {file = "clickhouse_connect-0.7.18-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:395cfe09d1d39be4206fc1da96fe316f270077791f9758fcac44fd2765446dba"}, - {file = "clickhouse_connect-0.7.18-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac55b2b2eb068b02cbb1afbfc8b2255734e28a646d633c43a023a9b95e08023b"}, - {file = "clickhouse_connect-0.7.18-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4d59bb1df3814acb321f0fe87a4a6eea658463d5e59f6dc8ae10072df1205591"}, - {file = "clickhouse_connect-0.7.18-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:da5ea738641a7ad0ab7a8e1d8d6234639ea1e61c6eac970bbc6b94547d2c2fa7"}, - {file = "clickhouse_connect-0.7.18-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:72eb32a75026401777e34209694ffe64db0ce610475436647ed45589b4ab4efe"}, - {file = "clickhouse_connect-0.7.18-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:43bdd638b1ff27649d0ed9ed5000a8b8d754be891a8d279b27c72c03e3d12dcb"}, - {file = "clickhouse_connect-0.7.18-cp311-cp311-win32.whl", hash = "sha256:f45bdcba1dc84a1f60a8d827310f615ecbc322518c2d36bba7bf878631007152"}, - {file = "clickhouse_connect-0.7.18-cp311-cp311-win_amd64.whl", hash = "sha256:6df629ab4b646a49a74e791e14a1b6a73ccbe6c4ee25f864522588d376b66279"}, - {file = "clickhouse_connect-0.7.18-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:32a35e1e63e4ae708432cbe29c8d116518d2d7b9ecb575b912444c3078b20e20"}, - {file = "clickhouse_connect-0.7.18-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:357529b8c08305ab895cdc898b60a3dc9b36637dfa4dbfedfc1d00548fc88edc"}, - {file = "clickhouse_connect-0.7.18-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2aa124d2bb65e29443779723e52398e8724e4bf56db94c9a93fd8208b9d6e2bf"}, - {file = "clickhouse_connect-0.7.18-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e3646254607e38294e20bf2e20b780b1c3141fb246366a1ad2021531f2c9c1b"}, - {file = "clickhouse_connect-0.7.18-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:433e50309af9d46d1b52e5b93ea105332565558be35296c7555c9c2753687586"}, - {file = "clickhouse_connect-0.7.18-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:251e67753909f76f8b136cad734501e0daf5977ed62747e18baa2b187f41c92c"}, - {file = "clickhouse_connect-0.7.18-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:a9980916495da3ed057e56ce2c922fc23de614ea5d74ed470b8450b58902ccee"}, - {file = "clickhouse_connect-0.7.18-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:555e00660c04a524ea00409f783265ccd0d0192552eb9d4dc10d2aeaf2fa6575"}, - {file = "clickhouse_connect-0.7.18-cp312-cp312-win32.whl", hash = "sha256:f4770c100f0608511f7e572b63a6b222fb780fc67341c11746d361c2b03d36d3"}, - {file = "clickhouse_connect-0.7.18-cp312-cp312-win_amd64.whl", hash = "sha256:fd44a7885d992410668d083ba38d6a268a1567f49709300b4ff84eb6aef63b70"}, - {file = "clickhouse_connect-0.7.18-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9ac122dcabe1a9d3c14d331fade70a0adc78cf4006c8b91ee721942cdaa1190e"}, - {file = "clickhouse_connect-0.7.18-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1e89db8e8cc9187f2e9cd6aa32062f67b3b4de7b21b8703f103e89d659eda736"}, - {file = "clickhouse_connect-0.7.18-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c34bb25e5ab9a97a4154d43fdcd16751c9aa4a6e6f959016e4c5fe5b692728ed"}, - {file = "clickhouse_connect-0.7.18-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:929441a6689a78c63c6a05ee7eb39a183601d93714835ebd537c0572101f7ab1"}, - {file = "clickhouse_connect-0.7.18-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e8852df54b04361e57775d8ae571cd87e6983f7ed968890c62bbba6a2f2c88fd"}, - {file = "clickhouse_connect-0.7.18-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:56333eb772591162627455e2c21c8541ed628a9c6e7c115193ad00f24fc59440"}, - {file = "clickhouse_connect-0.7.18-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:ac6633d2996100552d2ae47ac5e4eb551e11f69d05637ea84f1e13ac0f2bc21a"}, - {file = "clickhouse_connect-0.7.18-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:265085ab548fb49981fe2aef9f46652ee24d5583bf12e652abb13ee2d7e77581"}, - {file = "clickhouse_connect-0.7.18-cp38-cp38-win32.whl", hash = "sha256:5ee6c1f74df5fb19b341c389cfed7535fb627cbb9cb1a9bdcbda85045b86cd49"}, - {file = "clickhouse_connect-0.7.18-cp38-cp38-win_amd64.whl", hash = "sha256:c7a28f810775ce68577181e752ecd2dc8caae77f288b6b9f6a7ce4d36657d4fb"}, - {file = "clickhouse_connect-0.7.18-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:67f9a3953693b609ab068071be5ac9521193f728b29057e913b386582f84b0c2"}, - {file = "clickhouse_connect-0.7.18-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:77e202b8606096769bf45e68b46e6bb8c78c2c451c29cb9b3a7bf505b4060d44"}, - {file = "clickhouse_connect-0.7.18-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8abcbd17f243ca8399a06fb08970d68e73d1ad671f84bb38518449248093f655"}, - {file = "clickhouse_connect-0.7.18-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:192605c2a9412e4c7d4baab85e432a58a0a5520615f05bc14f13c2836cfc6eeb"}, - {file = "clickhouse_connect-0.7.18-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c17108b190ab34645ee1981440ae129ecd7ca0cb6a93b4e5ce3ffc383355243f"}, - {file = "clickhouse_connect-0.7.18-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ac1be43360a6e602784eb60547a03a6c2c574744cb8982ec15aac0e0e57709bd"}, - {file = "clickhouse_connect-0.7.18-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:cf403781d4ffd5a47aa7eff591940df182de4d9c423cfdc7eb6ade1a1b100e22"}, - {file = "clickhouse_connect-0.7.18-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:937c6481ec083e2a0bcf178ea363b72d437ab0c8fcbe65143db64b12c1e077c0"}, - {file = "clickhouse_connect-0.7.18-cp39-cp39-win32.whl", hash = "sha256:77635fea4b3fc4b1568a32674f04d35f4e648e3180528a9bb776e46e76090e4a"}, - {file = "clickhouse_connect-0.7.18-cp39-cp39-win_amd64.whl", hash = "sha256:5ef60eb76be54b6d6bd8f189b076939e2cca16b50b92b763e7a9c7a62b488045"}, - {file = "clickhouse_connect-0.7.18-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:7bf76743d7b92b6cac6b4ef2e7a4c2d030ecf2fd542fcfccb374b2432b8d1027"}, - {file = "clickhouse_connect-0.7.18-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:65b344f174d63096eec098137b5d9c3bb545d67dd174966246c4aa80f9c0bc1e"}, - {file = "clickhouse_connect-0.7.18-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:24dcc19338cd540e6a3e32e8a7c72c5fc4930c0dd5a760f76af9d384b3e57ddc"}, - {file = "clickhouse_connect-0.7.18-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:31f5e42d5fd4eaab616926bae344c17202950d9d9c04716d46bccce6b31dbb73"}, - {file = "clickhouse_connect-0.7.18-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:a890421403c7a59ef85e3afc4ff0d641c5553c52fbb9d6ce30c0a0554649fac6"}, - {file = "clickhouse_connect-0.7.18-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:d61de71d2b82446dd66ade1b925270366c36a2b11779d5d1bcf71b1bfdd161e6"}, - {file = "clickhouse_connect-0.7.18-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e81c4f2172e8d6f3dc4dd64ff2dc426920c0caeed969b4ec5bdd0b2fad1533e4"}, - {file = "clickhouse_connect-0.7.18-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:092cb8e8acdcccce01d239760405fbd8c266052def49b13ad0a96814f5e521ca"}, - {file = "clickhouse_connect-0.7.18-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a1ae8b1bab7f06815abf9d833a66849faa2b9dfadcc5728fd14c494e2879afa8"}, - {file = "clickhouse_connect-0.7.18-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e08ebec4db83109024c97ca2d25740bf57915160d7676edd5c4390777c3e3ec0"}, - {file = "clickhouse_connect-0.7.18-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:e5e42ec23b59597b512b994fec68ac1c2fa6def8594848cc3ae2459cf5e9d76a"}, - {file = "clickhouse_connect-0.7.18-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1aad4543a1ae4d40dc815ef85031a1809fe101687380d516383b168a7407ab2"}, - {file = "clickhouse_connect-0.7.18-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46cb4c604bd696535b1e091efb8047b833ff4220d31dbd95558c3587fda533a7"}, - {file = "clickhouse_connect-0.7.18-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:05e1ef335b81bf6b5908767c3b55e842f1f8463742992653551796eeb8f2d7d6"}, - {file = "clickhouse_connect-0.7.18-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:094e089de4a50a170f5fd1c0ebb2ea357e055266220bb11dfd7ddf2d4e9c9123"}, + {file = "clickhouse-connect-0.7.19.tar.gz", hash = "sha256:ce8f21f035781c5ef6ff57dc162e8150779c009b59f14030ba61f8c9c10c06d0"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6ac74eb9e8d6331bae0303d0fc6bdc2125aa4c421ef646348b588760b38c29e9"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:300f3dea7dd48b2798533ed2486e4b0c3bb03c8d9df9aed3fac44161b92a30f9"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c72629f519105e21600680c791459d729889a290440bbdc61e43cd5eb61d928"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ece0fb202cd9267b3872210e8e0974e4c33c8f91ca9f1c4d92edea997189c72"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a6e5adf0359043d4d21c9a668cc1b6323a1159b3e1a77aea6f82ce528b5e4c5b"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:63432180179e90f6f3c18861216f902d1693979e3c26a7f9ef9912c92ce00d14"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:754b9c58b032835caaa9177b69059dc88307485d2cf6d0d545b3dedb13cb512a"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:24e2694e89d12bba405a14b84c36318620dc50f90adbc93182418742d8f6d73f"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-win32.whl", hash = "sha256:52929826b39b5b0f90f423b7a035930b8894b508768e620a5086248bcbad3707"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-win_amd64.whl", hash = "sha256:5c301284c87d132963388b6e8e4a690c0776d25acc8657366eccab485e53738f"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ee47af8926a7ec3a970e0ebf29a82cbbe3b1b7eae43336a81b3a0ca18091de5f"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ce429233b2d21a8a149c8cd836a2555393cbcf23d61233520db332942ffb8964"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:617c04f5c46eed3344a7861cd96fb05293e70d3b40d21541b1e459e7574efa96"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f08e33b8cc2dc1873edc5ee4088d4fc3c0dbb69b00e057547bcdc7e9680b43e5"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:921886b887f762e5cc3eef57ef784d419a3f66df85fd86fa2e7fbbf464c4c54a"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6ad0cf8552a9e985cfa6524b674ae7c8f5ba51df5bd3ecddbd86c82cdbef41a7"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:70f838ef0861cdf0e2e198171a1f3fd2ee05cf58e93495eeb9b17dfafb278186"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c5f0d207cb0dcc1adb28ced63f872d080924b7562b263a9d54d4693b670eb066"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-win32.whl", hash = "sha256:8c96c4c242b98fcf8005e678a26dbd4361748721b6fa158c1fe84ad15c7edbbe"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-win_amd64.whl", hash = "sha256:bda092bab224875ed7c7683707d63f8a2322df654c4716e6611893a18d83e908"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8f170d08166438d29f0dcfc8a91b672c783dc751945559e65eefff55096f9274"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:26b80cb8f66bde9149a9a2180e2cc4895c1b7d34f9dceba81630a9b9a9ae66b2"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9ba80e3598acf916c4d1b2515671f65d9efee612a783c17c56a5a646f4db59b9"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d38c30bd847af0ce7ff738152478f913854db356af4d5824096394d0eab873d"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d41d4b159071c0e4f607563932d4fa5c2a8fc27d3ba1200d0929b361e5191864"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3682c2426f5dbda574611210e3c7c951b9557293a49eb60a7438552435873889"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:6d492064dca278eb61be3a2d70a5f082e2ebc8ceebd4f33752ae234116192020"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:62612da163b934c1ff35df6155a47cf17ac0e2d2f9f0f8f913641e5c02cdf39f"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-win32.whl", hash = "sha256:196e48c977affc045794ec7281b4d711e169def00535ecab5f9fdeb8c177f149"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-win_amd64.whl", hash = "sha256:b771ca6a473d65103dcae82810d3a62475c5372fc38d8f211513c72b954fb020"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:85a016eebff440b76b90a4725bb1804ddc59e42bba77d21c2a2ec4ac1df9e28d"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f059d3e39be1bafbf3cf0e12ed19b3cbf30b468a4840ab85166fd023ce8c3a17"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39ed54ba0998fd6899fcc967af2b452da28bd06de22e7ebf01f15acbfd547eac"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e4b4d786572cb695a087a71cfdc53999f76b7f420f2580c9cffa8cc51442058"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3710ca989ceae03d5ae56a436b4fe246094dbc17a2946ff318cb460f31b69450"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:d104f25a054cb663495a51ccb26ea11bcdc53e9b54c6d47a914ee6fba7523e62"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:ee23b80ee4c5b05861582dd4cd11f0ca0d215a899e9ba299a6ec6e9196943b1b"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:942ec21211d369068ab0ac082312d4df53c638bfc41545d02c41a9055e212df8"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-win32.whl", hash = "sha256:cb8f0a59d1521a6b30afece7c000f6da2cd9f22092e90981aa83342032e5df99"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-win_amd64.whl", hash = "sha256:98d5779dba942459d5dc6aa083e3a8a83e1cf6191eaa883832118ad7a7e69c87"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9f57aaa32d90f3bd18aa243342b3e75f062dc56a7f988012a22f65fb7946e81d"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5fb25143e4446d3a73fdc1b7d976a0805f763c37bf8f9b2d612a74f65d647830"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b4e19c9952b7b9fe24a99cca0b36a37e17e2a0e59b14457a2ce8868aa32e30e"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9876509aa25804f1377cb1b54dd55c1f5f37a9fbc42fa0c4ac8ac51b38db5926"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:04cfb1dae8fb93117211cfe4e04412b075e47580391f9eee9a77032d8e7d46f4"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b04f7c57f61b5dfdbf49d4b5e4fa5e91ce86bee09bb389b641268afa8f511ab4"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:e5b563f32dcc9cb6ff1f6ed238e83c3e80eb15814b1ea130817c004c241a3c2e"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6018675a231130bd03a7b39a3e875e683286d98115085bfa3ac0918f555f4bfe"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-win32.whl", hash = "sha256:5cb67ae3309396033b825626d60fe2cd789c1d2a183faabef8ffdbbef153d7fb"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-win_amd64.whl", hash = "sha256:fd225af60478c068cde0952e8df8f731f24c828b75cc1a2e61c21057ff546ecd"}, + {file = "clickhouse_connect-0.7.19-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:6f31898e0281f820e35710b5c4ad1d40a6c01ffae5278afaef4a16877ac8cbfb"}, + {file = "clickhouse_connect-0.7.19-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:51c911b0b8281ab4a909320f41dd9c0662796bec157c8f2704de702c552104db"}, + {file = "clickhouse_connect-0.7.19-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1088da11789c519f9bb8927a14b16892e3c65e2893abe2680eae68bf6c63835"}, + {file = "clickhouse_connect-0.7.19-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:03953942cc073078b40619a735ebeaed9bf98efc71c6f43ce92a38540b1308ce"}, + {file = "clickhouse_connect-0.7.19-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:4ac0602fa305d097a0cd40cebbe10a808f6478c9f303d57a48a3a0ad09659544"}, + {file = "clickhouse_connect-0.7.19-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:4fdefe9eb2d38063835f8f1f326d666c3f61de9d6c3a1607202012c386ca7631"}, + {file = "clickhouse_connect-0.7.19-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ff6469822fe8d83f272ffbb3fb99dcc614e20b1d5cddd559505029052eff36e7"}, + {file = "clickhouse_connect-0.7.19-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46298e23f7e7829f0aa880a99837a82390c1371a643b21f8feb77702707b9eaa"}, + {file = "clickhouse_connect-0.7.19-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c6409390b13e09c19435ff65e2ebfcf01f9b2382e4b946191979a5d54ef8625c"}, + {file = "clickhouse_connect-0.7.19-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:cd7e7097b30b70eb695b7b3b6c79ba943548c053cc465fa74efa67a2354f6acd"}, + {file = "clickhouse_connect-0.7.19-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:15e080aead66e43c1f214b3e76ab26e3f342a4a4f50e3bbc3118bdd013d12e5f"}, + {file = "clickhouse_connect-0.7.19-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:194d2a32ba1b370cb5ac375dd4153871bb0394ff040344d8f449cb36ea951a96"}, + {file = "clickhouse_connect-0.7.19-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ac93aafd6a542fdcad4a2b6778575eab6dbdbf8806e86d92e1c1aa00d91cfee"}, + {file = "clickhouse_connect-0.7.19-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b208dd3e29db7154b02652c26157a1903bea03d27867ca5b749edc2285c62161"}, + {file = "clickhouse_connect-0.7.19-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:9724fdf3563b2335791443cb9e2114be7f77c20c8c4bbfb3571a3020606f0773"}, ] [package.dependencies] @@ -1460,115 +1588,115 @@ tzlocal = ["tzlocal (>=4.0)"] [[package]] name = "clickhouse-driver" -version = "0.2.8" +version = "0.2.9" description = "Python driver with native interface for ClickHouse" optional = false python-versions = "<4,>=3.7" files = [ - {file = "clickhouse-driver-0.2.8.tar.gz", hash = "sha256:844b3080e558acbacd42ee569ec83ca7aaa3728f7077b9314c8d09aaa393d752"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3a3a708e020ed2df59e424631f1822ffef4353912fcee143f3b7fc34e866621d"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d258d3c3ac0f03527e295eeaf3cebb0a976bc643f6817ccd1d0d71ce970641b4"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f63fb64a55dea29ed6a7d1d6805ebc95c37108c8a36677bc045d904ad600828"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1b16d5dbd53fe32a99d3c4ab6c478c8aa9ae02aec5a2bd2f24180b0b4c03e1a5"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ad2e1850ce91301ae203bc555fb83272dfebb09ad4df99db38c608d45fc22fa4"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ae9239f61a18050164185ec0a3e92469d084377a66ae033cc6b4efa15922867"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8f222f2577bf304e86eec73dbca9c19d7daa6abcafc0bef68bbf31dd461890b"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:451ac3de1191531d030751b05f122219b93b3c509e781fad81c2c91f0e9256b6"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:5a2c4fea88e91f1d5217b760ffea84631e647d8db2265b821cbe7b0e015c7807"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:19825a3044c48ab65dc6659eb9763e2f0821887bdd9ee14a2f9ae8c539281ebf"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:ae13044a10015225297868658a6f1843c2e34b9fcaa6268880e25c4fca9f3c4d"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:548a77efb86012800e76db6d45b3dcffea9a1a26fa3d5fd42021298f0b9a6f16"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-win32.whl", hash = "sha256:ebe4328eaaf937365114b5bab5626600ee57e57d4d099ba2ddbae48c2493f73d"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-win_amd64.whl", hash = "sha256:7beaeb4d7e6c3aba7e02375eeca85b20cc8e54dc31fcdb25d3c4308f2cd9465f"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8e06ef6bb701c8e42a9c686d77ad30805cf431bb79fa8fe0f4d3dee819e9a12c"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4afbcfa557419ed1783ecde3abbee1134e09b26c3ab0ada5b2118ae587357c2b"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85f628b4bf6db0fe8fe13da8576a9b95c23b463dff59f4c7aa58cedf529d7d97"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:036f4b3283796ca51610385c7b24bdac1bb873f8a2e97a179f66544594aa9840"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2c8916d3d324ce8fd31f8dedd293dc2c29204b94785a5398d1ec1e7ea4e16a26"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:30bee7cddd85c04ec49c753b53580364d907cc05c44daafe31b924a352e5e525"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:03c8a844f6b128348d099dc5d75fad70f4e85802d1649c1b835916ac94ae750a"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:33965329393fd7740b445758787ddacdf70f35fa3411f98a1a86918fff679a46"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:8cf85a7ebb0a56182c5b659602e20bae6b36c48a0edf518a6e6f56042d3fcee0"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:c10fd1f921ff82638cb9513b9b4acfb575b421c44ef6bf6cf57ee3c487b9d538"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:0a30d49bb6c34e3f5fe42e43dd6a7da0523ddfd05834ef02bd70b9363ea7de7e"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ea32c377a347b0801fc7f2b242f2ec7d78df58047097352672d0de5fbfa9e390"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-win32.whl", hash = "sha256:2a85529d1c0c3f2eedf7a4f736d0efc6e6c8032ac90ca5a63f7a067db58384fe"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-win_amd64.whl", hash = "sha256:1f438f83a7473ce7fe9c16cda8750e2fdda1b09fb87f0ec6b87a2b89acb13f24"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9b71bbef6ee08252cee0593329c8ca8e623547627807d38195331f476eaf8136"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f30b3dd388f28eb4052851effe671354db55aea87de748aaf607e7048f72413e"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e3bb27ce7ca61089c04dc04dbf207c9165d62a85eb9c99d1451fd686b6b773f9"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:59c04ec0b45602b6a63e0779ca7c3d3614be4710ec5ac7214da1b157d43527c5"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a61b14244993c7e0f312983455b7851576a85ab5a9fcc6374e75d2680a985e76"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c99a1b0b7759ccd1bf44c65210543c228ba704e3153014fd3aabfe56a227b1a5"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9f14d860088ab2c7eeb3782c9490ad3f6bf6b1e9235e9db9c3b0079cd4751ffa"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:303887a14a71faddcdee150bc8cde498c25c446b0a72ae586bd67d0c366dbff5"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:359814e4f989c138bfb83e3c81f8f88c8449721dcf32cb8cc25fdb86f4b53c99"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:42de61b4cf9053698b14dbe29e1e3d78cb0a7aaef874fd854df390de5c9cc1f1"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:3bf3089f220480e5a69cbec79f3b65c23afb5c2836e7285234140e5f237f2768"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:41daa4ae5ada22f10c758b0b3b477a51f5df56eef8569cff8e2275de6d9b1b96"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-win32.whl", hash = "sha256:03ea71c7167c6c38c3ba2bbed43615ce0c41ebf3bfa28d96ffcd93cd1cdd07d8"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-win_amd64.whl", hash = "sha256:76985286e10adb2115da116ae25647319bc485ad9e327cbc27296ccf0b052180"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:271529124914c439a5bbcf8a90e3101311d60c1813e03c0467e01fbabef489ee"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f8f499746bc027c6d05de09efa7b2e4f2241f66c1ac2d6b7748f90709b00e10"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f29f256520bb718c532e7fcd85250d4001f49acbaa9e6896bdf4a70d5557e2ef"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:104d062bdf7eab74e92efcbf72088b3241365242b4f119b3fe91057c4d80825c"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ee34ed08592a6eff5e176f42897c6ab4dfd8c07df16e9f392e18f1f2ee3fe3ca"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f5be9a8d89de881d5ea9d46f9d293caa72dbc7f40b105374cafd88f52b2099ea"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:c57efc768fa87e83d6778e7bbd180dd1ff5d647044983ec7d238a8577bd25fa5"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:e1a003475f2d54e9fea8de86b57bc26b409c9efea3d298409ab831f194d62c3b"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:fba71cf41934a23156290a70ef794a5dadc642b21cc25eb13e1f99f2512c8594"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:7289b0e9d1019fed418c577963edd66770222554d1da0c491ca436593667256e"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:16e810cc9be18fdada545b9a521054214dd607bb7aa2f280ca488da23a077e48"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-win32.whl", hash = "sha256:ed4a6590015f18f414250149255dc2ae81ae956b6e670b290d52c2ecb61ed517"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-win_amd64.whl", hash = "sha256:9d454f16ccf1b2185cc630f6fb2160b1abde27759c4e94c42e30b9ea911d58f0"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2e487d49c24448873a6802c34aa21858b9e3fb4a2605268a980a5c02b54a6bae"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e877de75b97ddb11a027a7499171ea0aa9cad569b18fce53c9d508353000cfae"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c60dcefddf6e2c65c92b7e6096c222ff6ed73b01b6c5712f9ce8a23f2ec80f1a"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:422cbbabfad3f9b533d9f517f6f4e174111a613cba878402f7ef632b0eadec3a"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ff8a8e25ff6051ff3d0528dbe36305b0140075d2fa49432149ee2a7841f23ed"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19c7a5960d4f7f9a8f9a560ae05020ff5afe874b565cce06510586a0096bb626"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f5b3333257b46f307b713ba507e4bf11b7531ba3765a4150924532298d645ffd"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:bbc2252a697c674e1b8b6123cf205d2b15979eddf74e7ada0e62a0ecc81a75c3"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:af7f1a9a99dafb0f2a91d1a2d4a3e37f86076147d59abbe69b28d39308fe20fb"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:580c34cc505c492a8abeacbd863ce46158643bece914d8fe2fadea0e94c4e0c1"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:5b905eaa6fd3b453299f946a2c8f4a6392f379597e51e46297c6a37699226cda"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6e2b5891c52841aedf803b8054085eb8a611ad4bf57916787a1a9aabf618fb77"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-win32.whl", hash = "sha256:b58a5612db8b3577dc2ae6fda4c783d61c2376396bb364545530aa6a767f166d"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-win_amd64.whl", hash = "sha256:96b0424bb5dd698c10b899091562a78f4933a9a039409f310fb74db405d73854"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:22cbed52daa584ca9a93efd772ee5c8c1f68ceaaeb21673985004ec2fd411c49"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e36156fe8a355fc830cc0ea1267c804c631c9dbd9b6accdca868a426213e5929"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c1341325f4180e1318d0d2cf0b268008ea250715c6f30a5ccce586860c000b5"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cb52161276f7d77d4af09f1aab97a16edf86014a89e3d9923f0a6b8fdaa12438"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0d1ccd47040c0a8753684a20a0f83b8a0820386889fdf460a3248e0eed142032"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fcda48e938d011e5f4dcebf965e6ec19e020e8efa207b98eeb99c12fa873236d"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2252ab3f8b3bbd705e1d7dc80395c7bea14f5ae51a268fc7be5328da77c0e200"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:e1b9ef3fa0cc6c9de77daa74a2f183186d0b5556c4f6870fc966a41fde6cae2b"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d0afa3c68fed6b5e6f23eb3f053d3aba86d09dbbc7706a0120ab5595d5c37003"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:102027bb687ff7a978f7110348f39f0dce450ab334787edbc64b8a9927238e32"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:9fc1ae52a171ded7d9f1f971b9b5bb0ce4d0490a54e102f3717cea51011d0308"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5a62c691be83b1da72ff3455790b50b0f894b7932ac962a8133f3f9c04c943b3"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-win32.whl", hash = "sha256:8b5068cef07cfba5be25a9a461c010ce7a0fe2de5b0b0262c6030684f43fa7f5"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-win_amd64.whl", hash = "sha256:cd71965d00b0f3ba992652d577b1d46b87100a67b3e0dc5c191c88092e484c81"}, - {file = "clickhouse_driver-0.2.8-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:4db0812c43f67e7b1805c05e2bc08f7d670ddfd8d8c671c9b47cdb52f4f74129"}, - {file = "clickhouse_driver-0.2.8-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56622ffefe94a82d9a30747e3486819104d1310d7a94f0e37da461d7112e9864"}, - {file = "clickhouse_driver-0.2.8-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c47c8ed61b2f35bb29d991f66d6e03d5cc786def56533480331b2a584854dd5"}, - {file = "clickhouse_driver-0.2.8-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dec001a1a49b993522dd134d2fca161352f139d42edcda0e983b8ea8f5023cda"}, - {file = "clickhouse_driver-0.2.8-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:c03bd486540a6c03aa5a164b7ec6c50980df9642ab1ce22cb70327e4090bdc60"}, - {file = "clickhouse_driver-0.2.8-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:c059c3da454f0cc0a6f056b542a0c1784cd0398613d25326b11fd1c6f9f7e8d2"}, - {file = "clickhouse_driver-0.2.8-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc7f9677c637b710046ec6c6c0cab25b4c4ff21620e44f462041d7455e9e8d13"}, - {file = "clickhouse_driver-0.2.8-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba3f6b8fdd7a2e6a831ebbcaaf346f7c8c5eb5085a350c9d4d1ce7053a050b70"}, - {file = "clickhouse_driver-0.2.8-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:20c2db3ae29950c80837d270b5ab63c74597afce226b474930060cac7969287b"}, - {file = "clickhouse_driver-0.2.8-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:b7767019a301dad314e7b515046535a45eda84bd9c29590bc3e99b1c334f69e7"}, - {file = "clickhouse_driver-0.2.8-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:ba8b8b80fa8850546aa40acc952835b1f149af17182cdf3db4f2133b2a241fe8"}, - {file = "clickhouse_driver-0.2.8-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:924f11e87e3dcbbc1c9e8158af9917f182cd5e96d37385485d6268f59b564142"}, - {file = "clickhouse_driver-0.2.8-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4c39e1477ad310a4d276db17c1e1cf6fb059c29eb8d21351afefd5a22de381c6"}, - {file = "clickhouse_driver-0.2.8-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e950b9a63af5fa233e3da0e57a7ebd85d4b319e65eef5f9daac84532836f4123"}, - {file = "clickhouse_driver-0.2.8-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:0698dc57373b2f42f3a95bd419d9fa07f2d02150f13a0db2909a2651208262b9"}, - {file = "clickhouse_driver-0.2.8-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:e0694ca2fb459c23e44036d975fe89544a7c9918618b5d8bda9a8aa2d24e5c37"}, - {file = "clickhouse_driver-0.2.8-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:62620348aeae5a905ccb8f7e6bff8d76aae9a95d81aa8c8f6fce0f2af7e104b8"}, - {file = "clickhouse_driver-0.2.8-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66276fd5092cccdd6f3123df4357a068fb1972b7e2622fab6f235948c50b6eed"}, - {file = "clickhouse_driver-0.2.8-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f86fe87327662b597824d0d7505cc600b0919473b22bbbd178a1a4d4e29283e1"}, - {file = "clickhouse_driver-0.2.8-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:54b9c6ff0aaabdcf7e80a6d9432459611b3413d6a66bec41cbcdad7212721cc7"}, + {file = "clickhouse-driver-0.2.9.tar.gz", hash = "sha256:050ea4870ead993910b39e7fae965dc1c347b2e8191dcd977cd4b385f9e19f87"}, + {file = "clickhouse_driver-0.2.9-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6ce04e9d0d0f39561f312d1ac1a8147bc9206e4267e1a23e20e0423ebac95534"}, + {file = "clickhouse_driver-0.2.9-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7ae5c8931bf290b9d85582e7955b9aad7f19ff9954e48caa4f9a180ea4d01078"}, + {file = "clickhouse_driver-0.2.9-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e51792f3bd12c32cb15a907f12de3c9d264843f0bb33dce400e3966c9f09a3f"}, + {file = "clickhouse_driver-0.2.9-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:42fc546c31e4a04c97b749769335a679c9044dc693fa7a93e38c97fd6727173d"}, + {file = "clickhouse_driver-0.2.9-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6a383a403d185185c64e49edd6a19b2ec973c5adcb8ebff7ed2fc539a2cc65a5"}, + {file = "clickhouse_driver-0.2.9-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f05321a97e816afc75b3e4f9eda989848fecf14ecf1a91d0f22c04258123d1f7"}, + {file = "clickhouse_driver-0.2.9-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:be47e793846aac28442b6b1c6554e0731b848a5a7759a54aa2489997354efe4a"}, + {file = "clickhouse_driver-0.2.9-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:780e42a215d1ae2f6d695d74dd6f087781fb2fa51c508b58f79e68c24c5364e0"}, + {file = "clickhouse_driver-0.2.9-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:9e28f1fe850675e173db586e9f1ac790e8f7edd507a4227cd54cd7445f8e75b6"}, + {file = "clickhouse_driver-0.2.9-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:125aae7f1308d3083dadbb3c78f828ae492e060f13e4007a0cf53a8169ed7b39"}, + {file = "clickhouse_driver-0.2.9-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:2f3c4fbb61e75c62a1ab93a1070d362de4cb5682f82833b2c12deccb3bae888d"}, + {file = "clickhouse_driver-0.2.9-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0dc03196a84e32d23b88b665be69afae98f57426f5fdf203e16715b756757961"}, + {file = "clickhouse_driver-0.2.9-cp310-cp310-win32.whl", hash = "sha256:25695d78a1d7ad6e221e800612eac08559f6182bf6dee0a220d08de7b612d993"}, + {file = "clickhouse_driver-0.2.9-cp310-cp310-win_amd64.whl", hash = "sha256:367acac95398d721a0a2a6cf87e93638c5588b79498a9848676ce7f182540a6c"}, + {file = "clickhouse_driver-0.2.9-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5a7353a7a08eee3aa0001d8a5d771cb1f37e2acae1b48178002431f23892121a"}, + {file = "clickhouse_driver-0.2.9-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6af1c6cbc3481205503ab72a34aa76d6519249c904aa3f7a84b31e7b435555be"}, + {file = "clickhouse_driver-0.2.9-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48033803abd1100bfff6b9a1769d831b672cd3cda5147e0323b956fd1416d38d"}, + {file = "clickhouse_driver-0.2.9-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1f202a58a540c85e47c31dabc8f84b6fe79dca5315c866450a538d58d6fa0571"}, + {file = "clickhouse_driver-0.2.9-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4df50fd84bfa4aa1eb7b52d48136066bfb64fabb7ceb62d4c318b45a296200b"}, + {file = "clickhouse_driver-0.2.9-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:433a650571a0d7766eb6f402e8f5930222997686c2ee01ded22f1d8fd46af9d4"}, + {file = "clickhouse_driver-0.2.9-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:232ee260475611cbf7adb554b81db6b5790b36e634fe2164f4ffcd2ca3e63a71"}, + {file = "clickhouse_driver-0.2.9-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:09049f7e71f15c9c9a03f597f77fc1f7b61ababd155c06c0d9e64d1453d945d7"}, + {file = "clickhouse_driver-0.2.9-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:424153d1d5f5a807f596a48cc88119f9fb3213ca7e38f57b8d15dcc964dd91f7"}, + {file = "clickhouse_driver-0.2.9-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:4f078fd1cf19c4ca63b8d1e0803df665310c8d5b644c5b02bf2465e8d6ef8f55"}, + {file = "clickhouse_driver-0.2.9-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:f138d939e26e767537f891170b69a55a88038919f5c10d8865b67b8777fe4848"}, + {file = "clickhouse_driver-0.2.9-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9aafabc7e32942f85dcb46f007f447ab69024831575df97cae28c6ed127654d1"}, + {file = "clickhouse_driver-0.2.9-cp311-cp311-win32.whl", hash = "sha256:935e16ebf1a1998d8493979d858821a755503c9b8af572d9c450173d4b88868c"}, + {file = "clickhouse_driver-0.2.9-cp311-cp311-win_amd64.whl", hash = "sha256:306b3102cba278b5dfec6f5f7dc8b78416c403901510475c74913345b56c9e42"}, + {file = "clickhouse_driver-0.2.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:fcb2fd00e58650ae206a6d5dbc83117240e622471aa5124733fbf2805eb8bda0"}, + {file = "clickhouse_driver-0.2.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b7a3e6b0a1eb218e3d870a94c76daaf65da46dca8f6888ea6542f94905c24d88"}, + {file = "clickhouse_driver-0.2.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a8d8e2888a857d8db3d98765a5ad23ab561241feaef68bbffc5a0bd9c142342"}, + {file = "clickhouse_driver-0.2.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:85d50c011467f5ff6772c4059345968b854b72e07a0219030b7c3f68419eb7f7"}, + {file = "clickhouse_driver-0.2.9-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:93b395c1370629ccce8fb3e14cd5be2646d227bd32018c21f753c543e9a7e96b"}, + {file = "clickhouse_driver-0.2.9-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6dbcee870c60d9835e5dce1456ab6b9d807e6669246357f4b321ef747b90fa43"}, + {file = "clickhouse_driver-0.2.9-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fffa5a5f317b1ec92e406a30a008929054cf3164d2324a3c465d0a0330273bf8"}, + {file = "clickhouse_driver-0.2.9-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:476702740a279744badbd177ae1c4a2d089ec128bd676861219d1f92078e4530"}, + {file = "clickhouse_driver-0.2.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:5cd6d95fab5ff80e9dc9baedc9a926f62f74072d42d5804388d63b63bec0bb63"}, + {file = "clickhouse_driver-0.2.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:05027d32d7cf3e46cb8d04f8c984745ae01bd1bc7b3579f9dadf9b3cca735697"}, + {file = "clickhouse_driver-0.2.9-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:3d11831842250b4c1b26503a6e9c511fc03db096608b7c6af743818c421a3032"}, + {file = "clickhouse_driver-0.2.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:81b4b671b785ebb0b8aeabf2432e47072413d81db959eb8cfd8b6ab58c5799c6"}, + {file = "clickhouse_driver-0.2.9-cp312-cp312-win32.whl", hash = "sha256:e893bd4e014877174a59e032b0e99809c95ec61328a0e6bd9352c74a2f6111a8"}, + {file = "clickhouse_driver-0.2.9-cp312-cp312-win_amd64.whl", hash = "sha256:de6624e28eeffd01668803d28ae89e3d4e359b1bff8b60e4933e1cb3c6f86f18"}, + {file = "clickhouse_driver-0.2.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:909205324089a9ee59bee7ecbfa94595435118cca310fd62efdf13f225aa2965"}, + {file = "clickhouse_driver-0.2.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03f31d6e47dc2b0f367f598f5629147ed056d7216c1788e25190fcfbfa02e749"}, + {file = "clickhouse_driver-0.2.9-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ed84179914b2b7bb434c2322a6e7fd83daa681c97a050450511b66d917a129bb"}, + {file = "clickhouse_driver-0.2.9-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:67d1bf63efb4ba14ae6c6da99622e4a549e68fc3ee14d859bf611d8e6a61b3fa"}, + {file = "clickhouse_driver-0.2.9-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9eed23ea41dd582d76f7a2ec7e09cbe5e9fec008f11a4799fa35ce44a3ebd283"}, + {file = "clickhouse_driver-0.2.9-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a654291132766efa2703058317749d7c69b69f02d89bac75703eaf7f775e20da"}, + {file = "clickhouse_driver-0.2.9-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:1c26c5ef16d0ef3cabc5bc03e827e01b0a4afb5b4eaf8850b7cf740cee04a1d4"}, + {file = "clickhouse_driver-0.2.9-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:b57e83d7986d3cbda6096974a9510eb53cb33ad9072288c87c820ba5eee3370e"}, + {file = "clickhouse_driver-0.2.9-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:153cc03b36f22cbde55aa6a5bbe99072a025567a54c48b262eb0da15d8cd7c83"}, + {file = "clickhouse_driver-0.2.9-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:83a857d99192936091f495826ae97497cd1873af213b1e069d56369fb182ab8e"}, + {file = "clickhouse_driver-0.2.9-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:bb05a9bb22cbe9ad187ad268f86adf7e60df6083331fe59c01571b7b725212dd"}, + {file = "clickhouse_driver-0.2.9-cp37-cp37m-win32.whl", hash = "sha256:3e282c5c25e32d96ed151e5460d2bf4ecb805ea64449197dd918e84e768016df"}, + {file = "clickhouse_driver-0.2.9-cp37-cp37m-win_amd64.whl", hash = "sha256:c46dccfb04a9afd61a1b0e60bfefceff917f76da2c863f9b36b39248496d5c77"}, + {file = "clickhouse_driver-0.2.9-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:612ca9028c718f362c97f552e63d313cf1a70a616ef8532ddb0effdaf12ebef9"}, + {file = "clickhouse_driver-0.2.9-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:471b884d318e012f68d858476052742048918854f7dfe87d78e819f87a848ffb"}, + {file = "clickhouse_driver-0.2.9-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58ee63c35e99da887eb035c8d6d9e64fd298a0efc1460395297dd5cc281a6912"}, + {file = "clickhouse_driver-0.2.9-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0819bb63d2c5025a1fb9589f57ef82602687cef11081d6dfa6f2ce44606a1772"}, + {file = "clickhouse_driver-0.2.9-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f6680ee18870bca1fbab1736c8203a965efaec119ab4c37821ad99add248ee08"}, + {file = "clickhouse_driver-0.2.9-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:713c498741b54debd3a10a5529e70b6ed85ca33c3e8629e24ae5cd8160b5a5f2"}, + {file = "clickhouse_driver-0.2.9-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:730837b8f63941065c9c955c44286aef0987fb084ffb3f55bf1e4fe07df62269"}, + {file = "clickhouse_driver-0.2.9-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:9f4e38b2ea09214c8e7848a19391009a18c56a3640e1ba1a606b9e57aeb63404"}, + {file = "clickhouse_driver-0.2.9-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:457f1d6639e0345b717ae603c79bd087a35361ce68c1c308d154b80b841e5e7d"}, + {file = "clickhouse_driver-0.2.9-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:49a55aeb8ea625a87965a96e361bbb1ad67d0931bfb2a575f899c1064e70c2da"}, + {file = "clickhouse_driver-0.2.9-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:9230058d8c9b1a04079afae4650fb67745f0f1c39db335728f64d48bd2c19246"}, + {file = "clickhouse_driver-0.2.9-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8798258bd556542dd9c6b8ebe62f9c5110c9dcdf97c57fb077e7b8b6d6da0826"}, + {file = "clickhouse_driver-0.2.9-cp38-cp38-win32.whl", hash = "sha256:ce8e3f4be46bcc63555863f70ab0035202b082b37e6f16876ef50e7bc4b47056"}, + {file = "clickhouse_driver-0.2.9-cp38-cp38-win_amd64.whl", hash = "sha256:2d982959ff628255808d895a67493f2dab0c3a9bfc65eeda0f00c8ae9962a1b3"}, + {file = "clickhouse_driver-0.2.9-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a46b227fab4420566ed24ee70d90076226d16fcf09c6ad4d428717efcf536446"}, + {file = "clickhouse_driver-0.2.9-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7eaa2ce5ea08cf5fddebb8c274c450e102f329f9e6966b6cd85aa671c48e5552"}, + {file = "clickhouse_driver-0.2.9-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f97f0083194d6e23b5ef6156ed0d5388c37847b298118199d7937ba26412a9e2"}, + {file = "clickhouse_driver-0.2.9-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a6cab5cdbb0f8ee51d879d977b78f07068b585225ac656f3c081896c362e8f83"}, + {file = "clickhouse_driver-0.2.9-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cdb1b011a53ee71539e9dc655f268b111bac484db300da92829ed59e910a8fd0"}, + {file = "clickhouse_driver-0.2.9-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7bf51bb761b281d20910b4b689c699ef98027845467daa5bb5dfdb53bd6ee404"}, + {file = "clickhouse_driver-0.2.9-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8ea462e3cebb121ff55002e9c8a9a0a3fd9b5bbbf688b4960f0a83c0172fb31"}, + {file = "clickhouse_driver-0.2.9-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:70bee21c245226ad0d637bf470472e2d487b86911b6d673a862127b934336ff4"}, + {file = "clickhouse_driver-0.2.9-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:253a3c223b944d691bf0abbd599f592ea3b36f0a71d2526833b1718f37eca5c2"}, + {file = "clickhouse_driver-0.2.9-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:a6549b53fc5c403dc556cb39b2ae94d73f9b113daa00438a660bb1dd5380ae4d"}, + {file = "clickhouse_driver-0.2.9-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:1c685cd4abe61af1c26279ff04b9f567eb4d6c1ec7fb265af7481b1f153043aa"}, + {file = "clickhouse_driver-0.2.9-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:7e25144219577491929d032a6c3ddd63c6cd7fa764af829a5637f798190d9b26"}, + {file = "clickhouse_driver-0.2.9-cp39-cp39-win32.whl", hash = "sha256:0b9925610d25405a8e6d83ff4f54fc2456a121adb0155999972f5edd6ba3efc8"}, + {file = "clickhouse_driver-0.2.9-cp39-cp39-win_amd64.whl", hash = "sha256:b243de483cfa02716053b0148d73558f4694f3c27b97fc1eaa97d7079563a14d"}, + {file = "clickhouse_driver-0.2.9-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:45a3d5b1d06750fd6a18c29b871494a2635670099ec7693e756a5885a4a70dbf"}, + {file = "clickhouse_driver-0.2.9-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8415ffebd6ca9eef3024763abc450f8659f1716d015bd563c537d01c7fbc3569"}, + {file = "clickhouse_driver-0.2.9-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ace48db993aa4bd31c42de0fa8d38c94ad47405916d6b61f7a7168a48fb52ac1"}, + {file = "clickhouse_driver-0.2.9-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b07123334fe143bfe6fa4e3d4b732d647d5fd2cfb9ec7f2f76104b46fe9d20c6"}, + {file = "clickhouse_driver-0.2.9-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:e2af3efa73d296420ce6362789f5b1febf75d4aa159a479393f01549115509d5"}, + {file = "clickhouse_driver-0.2.9-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:baf57eede88d07a1eb04352d26fc58a4d97991ca3d8840f7c5d48691dec9f251"}, + {file = "clickhouse_driver-0.2.9-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:275d0ccdab9c3571bdb3e9acfab4497930aa584ff2766b035bb2f854deaf8b82"}, + {file = "clickhouse_driver-0.2.9-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:293da77bfcac3168fb35b27c242f97c1a05502435c0686ecbb8e2e4abcb3de26"}, + {file = "clickhouse_driver-0.2.9-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8d6c2e5830705e4eeef33070ca4d5a24dfa221f28f2f540e5e6842c26e70b10b"}, + {file = "clickhouse_driver-0.2.9-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:11934bd78d97dd7e1a23a6222b5edd1e1b4d34e1ead5c846dc2b5c56fdc35ff5"}, + {file = "clickhouse_driver-0.2.9-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:b802b6f0fbdcc3ab81b87f09b694dde91ab049f44d1d2c08c3dc8ea9a5950cfa"}, + {file = "clickhouse_driver-0.2.9-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7af871c5315eb829ecf4533c790461ea8f73b3bfd5f533b0467e479fdf6ddcfd"}, + {file = "clickhouse_driver-0.2.9-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d577dd4867b9e26cf60590e1f500990c8701a6e3cfbb9e644f4d0c0fb607028"}, + {file = "clickhouse_driver-0.2.9-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2ed3dea2d1eca85fef5b8564ddd76dedb15a610c77d55d555b49d9f7c896b64b"}, + {file = "clickhouse_driver-0.2.9-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:91ec96f2c48e5bdeac9eea43a9bc9cc19acb2d2c59df0a13d5520dfc32457605"}, + {file = "clickhouse_driver-0.2.9-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:7667ab423452754f36ba8fb41e006a46baace9c94e2aca2a745689b9f2753dfb"}, + {file = "clickhouse_driver-0.2.9-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:653583b1f3b088d106f180d6f02c90917ecd669ec956b62903a05df4a7f44863"}, + {file = "clickhouse_driver-0.2.9-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ef3dd0cbdf2f0171caab90389af0ede068ec802bf46c6a77f14e6edc86671bc"}, + {file = "clickhouse_driver-0.2.9-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:11b1833ee8ff8d5df39a34a895e060b57bd81e05ea68822bc60476daff4ce1c8"}, + {file = "clickhouse_driver-0.2.9-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:8a3195639e6393b9d4aafe736036881ff86b6be5855d4bf7d9f5c31637181ec3"}, ] [package.dependencies] @@ -1943,13 +2071,13 @@ typing-inspect = ">=0.4.0,<1" [[package]] name = "db-dtypes" -version = "1.2.0" +version = "1.3.0" description = "Pandas Data Types for SQL systems (BigQuery, Spanner)" optional = false python-versions = ">=3.7" files = [ - {file = "db-dtypes-1.2.0.tar.gz", hash = "sha256:3531bb1fb8b5fbab33121fe243ccc2ade16ab2524f4c113b05cc702a1908e6ea"}, - {file = "db_dtypes-1.2.0-py2.py3-none-any.whl", hash = "sha256:6320bddd31d096447ef749224d64aab00972ed20e4392d86f7d8b81ad79f7ff0"}, + {file = "db_dtypes-1.3.0-py2.py3-none-any.whl", hash = "sha256:7e65c59f849ccbe6f7bc4d0253edcc212a7907662906921caba3e4aadd0bc277"}, + {file = "db_dtypes-1.3.0.tar.gz", hash = "sha256:7bcbc8858b07474dc85b77bb2f3ae488978d1336f5ea73b58c39d9118bc3e91b"}, ] [package.dependencies] @@ -2083,21 +2211,21 @@ files = [ [[package]] name = "duckduckgo-search" -version = "6.2.6" +version = "6.2.10" description = "Search for words, documents, images, news, maps and text translation using the DuckDuckGo.com search engine." optional = false python-versions = ">=3.8" files = [ - {file = "duckduckgo_search-6.2.6-py3-none-any.whl", hash = "sha256:c8171bcd6ff4d051f78c70ea23bd34c0d8e779d72973829d3a6b40ccc05cd7c2"}, - {file = "duckduckgo_search-6.2.6.tar.gz", hash = "sha256:96529ecfbd55afa28705b38413003cb3cfc620e55762d33184887545de27dc96"}, + {file = "duckduckgo_search-6.2.10-py3-none-any.whl", hash = "sha256:266c1528dcbc90931b7c800a2c1041a0cb447c83c485414d77a7e443be717ed6"}, + {file = "duckduckgo_search-6.2.10.tar.gz", hash = "sha256:53057368480ca496fc4e331a34648124711580cf43fbb65336eaa6fd2ee37cec"}, ] [package.dependencies] click = ">=8.1.7" -primp = ">=0.5.5" +primp = ">=0.6.1" [package.extras] -dev = ["mypy (>=1.11.0)", "pytest (>=8.3.1)", "pytest-asyncio (>=0.23.8)", "ruff (>=0.5.5)"] +dev = ["mypy (>=1.11.1)", "pytest (>=8.3.1)", "pytest-asyncio (>=0.23.8)", "ruff (>=0.6.1)"] lxml = ["lxml (>=5.2.2)"] [[package]] @@ -2203,18 +2331,18 @@ test = ["pytest (>=6)"] [[package]] name = "fastapi" -version = "0.112.0" +version = "0.112.1" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" optional = false python-versions = ">=3.8" files = [ - {file = "fastapi-0.112.0-py3-none-any.whl", hash = "sha256:3487ded9778006a45834b8c816ec4a48d522e2631ca9e75ec5a774f1b052f821"}, - {file = "fastapi-0.112.0.tar.gz", hash = "sha256:d262bc56b7d101d1f4e8fc0ad2ac75bb9935fec504d2b7117686cec50710cf05"}, + {file = "fastapi-0.112.1-py3-none-any.whl", hash = "sha256:bcbd45817fc2a1cd5da09af66815b84ec0d3d634eb173d1ab468ae3103e183e4"}, + {file = "fastapi-0.112.1.tar.gz", hash = "sha256:b2537146f8c23389a7faa8b03d0bd38d4986e6983874557d95eed2acc46448ef"}, ] [package.dependencies] pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0" -starlette = ">=0.37.2,<0.38.0" +starlette = ">=0.37.2,<0.39.0" typing-extensions = ">=4.8.0" [package.extras] @@ -3174,13 +3302,13 @@ dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "py [[package]] name = "google-resumable-media" -version = "2.7.1" +version = "2.7.2" description = "Utilities for Google Media Downloads and Resumable Uploads" optional = false python-versions = ">=3.7" files = [ - {file = "google-resumable-media-2.7.1.tar.gz", hash = "sha256:eae451a7b2e2cdbaaa0fd2eb00cc8a1ee5e95e16b55597359cbc3d27d7d90e33"}, - {file = "google_resumable_media-2.7.1-py2.py3-none-any.whl", hash = "sha256:103ebc4ba331ab1bfdac0250f8033627a2cd7cde09e7ccff9181e31ba4315b2c"}, + {file = "google_resumable_media-2.7.2-py2.py3-none-any.whl", hash = "sha256:3ce7551e9fe6d99e9a126101d2536612bb73486721951e9562fee0f90c6ababa"}, + {file = "google_resumable_media-2.7.2.tar.gz", hash = "sha256:5280aed4629f2b60b847b0d42f9857fd4935c11af266744df33d8074cae92fe0"}, ] [package.dependencies] @@ -3811,18 +3939,22 @@ test = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "p [[package]] name = "importlib-resources" -version = "6.4.0" +version = "6.4.4" description = "Read resources from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_resources-6.4.0-py3-none-any.whl", hash = "sha256:50d10f043df931902d4194ea07ec57960f66a80449ff867bfe782b4c486ba78c"}, - {file = "importlib_resources-6.4.0.tar.gz", hash = "sha256:cdb2b453b8046ca4e3798eb1d84f3cce1446a0e8e7b5ef4efb600f19fc398145"}, + {file = "importlib_resources-6.4.4-py3-none-any.whl", hash = "sha256:dda242603d1c9cd836c3368b1174ed74cb4049ecd209e7a1a0104620c18c5c11"}, + {file = "importlib_resources-6.4.4.tar.gz", hash = "sha256:20600c8b7361938dc0bb2d5ec0297802e575df486f5a544fa414da65e13721f7"}, ] [package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] -testing = ["jaraco.test (>=5.4)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)", "zipp (>=3.17)"] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] +cover = ["pytest-cov"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +enabler = ["pytest-enabler (>=2.2)"] +test = ["jaraco.test (>=5.4)", "pytest (>=6,!=8.1.*)", "zipp (>=3.17)"] +type = ["pytest-mypy"] [[package]] name = "iniconfig" @@ -3933,6 +4065,41 @@ files = [ [package.dependencies] ply = "*" +[[package]] +name = "jsonschema" +version = "4.23.0" +description = "An implementation of JSON Schema validation for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "jsonschema-4.23.0-py3-none-any.whl", hash = "sha256:fbadb6f8b144a8f8cf9f0b89ba94501d143e50411a1278633f56a7acf7fd5566"}, + {file = "jsonschema-4.23.0.tar.gz", hash = "sha256:d71497fef26351a33265337fa77ffeb82423f3ea21283cd9467bb03999266bc4"}, +] + +[package.dependencies] +attrs = ">=22.2.0" +jsonschema-specifications = ">=2023.03.6" +referencing = ">=0.28.4" +rpds-py = ">=0.7.1" + +[package.extras] +format = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3987", "uri-template", "webcolors (>=1.11)"] +format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "uri-template", "webcolors (>=24.6.0)"] + +[[package]] +name = "jsonschema-specifications" +version = "2023.12.1" +description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry" +optional = false +python-versions = ">=3.8" +files = [ + {file = "jsonschema_specifications-2023.12.1-py3-none-any.whl", hash = "sha256:87e4fdf3a94858b8a2ba2778d9ba57d8a9cafca7c7489c46ba0d30a8bc6a9c3c"}, + {file = "jsonschema_specifications-2023.12.1.tar.gz", hash = "sha256:48a76787b3e70f5ed53f1160d2b81f586e4ca6d1548c5de7085d1682674764cc"}, +] + +[package.dependencies] +referencing = ">=0.31.0" + [[package]] name = "kaleido" version = "0.2.1" @@ -4135,13 +4302,13 @@ six = "*" [[package]] name = "langfuse" -version = "2.42.1" +version = "2.44.0" description = "A client library for accessing langfuse" optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langfuse-2.42.1-py3-none-any.whl", hash = "sha256:8895d9645aea91815db51565f90e110a76d5e157a7b12eaf1cd6959e7aaa2263"}, - {file = "langfuse-2.42.1.tar.gz", hash = "sha256:f89faf1c14308d488c90f8b7d0368fff3d259f80ffe34d169b9cfc3f0dbfab82"}, + {file = "langfuse-2.44.0-py3-none-any.whl", hash = "sha256:adb73400a6ad6d597cc95c31381c82f81face3d5fb69391181f224a26f7e8562"}, + {file = "langfuse-2.44.0.tar.gz", hash = "sha256:dfa5378ff7022ae9fe5b8b842c0365347c98f9ef2b772dcee6a93a45442de28c"}, ] [package.dependencies] @@ -4160,16 +4327,17 @@ openai = ["openai (>=0.27.8)"] [[package]] name = "langsmith" -version = "0.1.98" +version = "0.1.101" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langsmith-0.1.98-py3-none-any.whl", hash = "sha256:f79e8a128652bbcee4606d10acb6236973b5cd7dde76e3741186d3b97b5698e9"}, - {file = "langsmith-0.1.98.tar.gz", hash = "sha256:e07678219a0502e8f26d35294e72127a39d25e32fafd091af5a7bb661e9a6bd1"}, + {file = "langsmith-0.1.101-py3-none-any.whl", hash = "sha256:572e2c90709cda1ad837ac86cedda7295f69933f2124c658a92a35fb890477cc"}, + {file = "langsmith-0.1.101.tar.gz", hash = "sha256:caf4d95f314bb6cd3c4e0632eed821fd5cd5d0f18cb824772fce6d7a9113895b"}, ] [package.dependencies] +httpx = ">=0.23.0,<1" orjson = ">=3.9.14,<4.0.0" pydantic = [ {version = ">=1,<3", markers = "python_full_version < \"3.12.4\""}, @@ -4209,153 +4377,149 @@ files = [ [[package]] name = "lxml" -version = "5.2.2" +version = "5.3.0" description = "Powerful and Pythonic XML processing library combining libxml2/libxslt with the ElementTree API." optional = false python-versions = ">=3.6" files = [ - {file = "lxml-5.2.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:364d03207f3e603922d0d3932ef363d55bbf48e3647395765f9bfcbdf6d23632"}, - {file = "lxml-5.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:50127c186f191b8917ea2fb8b206fbebe87fd414a6084d15568c27d0a21d60db"}, - {file = "lxml-5.2.2-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:74e4f025ef3db1c6da4460dd27c118d8cd136d0391da4e387a15e48e5c975147"}, - {file = "lxml-5.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:981a06a3076997adf7c743dcd0d7a0415582661e2517c7d961493572e909aa1d"}, - {file = "lxml-5.2.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aef5474d913d3b05e613906ba4090433c515e13ea49c837aca18bde190853dff"}, - {file = "lxml-5.2.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1e275ea572389e41e8b039ac076a46cb87ee6b8542df3fff26f5baab43713bca"}, - {file = "lxml-5.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5b65529bb2f21ac7861a0e94fdbf5dc0daab41497d18223b46ee8515e5ad297"}, - {file = "lxml-5.2.2-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:bcc98f911f10278d1daf14b87d65325851a1d29153caaf146877ec37031d5f36"}, - {file = "lxml-5.2.2-cp310-cp310-manylinux_2_28_ppc64le.whl", hash = "sha256:b47633251727c8fe279f34025844b3b3a3e40cd1b198356d003aa146258d13a2"}, - {file = "lxml-5.2.2-cp310-cp310-manylinux_2_28_s390x.whl", hash = "sha256:fbc9d316552f9ef7bba39f4edfad4a734d3d6f93341232a9dddadec4f15d425f"}, - {file = "lxml-5.2.2-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:13e69be35391ce72712184f69000cda04fc89689429179bc4c0ae5f0b7a8c21b"}, - {file = "lxml-5.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:3b6a30a9ab040b3f545b697cb3adbf3696c05a3a68aad172e3fd7ca73ab3c835"}, - {file = "lxml-5.2.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:a233bb68625a85126ac9f1fc66d24337d6e8a0f9207b688eec2e7c880f012ec0"}, - {file = "lxml-5.2.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:dfa7c241073d8f2b8e8dbc7803c434f57dbb83ae2a3d7892dd068d99e96efe2c"}, - {file = "lxml-5.2.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:1a7aca7964ac4bb07680d5c9d63b9d7028cace3e2d43175cb50bba8c5ad33316"}, - {file = "lxml-5.2.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ae4073a60ab98529ab8a72ebf429f2a8cc612619a8c04e08bed27450d52103c0"}, - {file = "lxml-5.2.2-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:ffb2be176fed4457e445fe540617f0252a72a8bc56208fd65a690fdb1f57660b"}, - {file = "lxml-5.2.2-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:e290d79a4107d7d794634ce3e985b9ae4f920380a813717adf61804904dc4393"}, - {file = "lxml-5.2.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:96e85aa09274955bb6bd483eaf5b12abadade01010478154b0ec70284c1b1526"}, - {file = "lxml-5.2.2-cp310-cp310-win32.whl", hash = "sha256:f956196ef61369f1685d14dad80611488d8dc1ef00be57c0c5a03064005b0f30"}, - {file = "lxml-5.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:875a3f90d7eb5c5d77e529080d95140eacb3c6d13ad5b616ee8095447b1d22e7"}, - {file = "lxml-5.2.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:45f9494613160d0405682f9eee781c7e6d1bf45f819654eb249f8f46a2c22545"}, - {file = "lxml-5.2.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b0b3f2df149efb242cee2ffdeb6674b7f30d23c9a7af26595099afaf46ef4e88"}, - {file = "lxml-5.2.2-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d28cb356f119a437cc58a13f8135ab8a4c8ece18159eb9194b0d269ec4e28083"}, - {file = "lxml-5.2.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:657a972f46bbefdbba2d4f14413c0d079f9ae243bd68193cb5061b9732fa54c1"}, - {file = "lxml-5.2.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b74b9ea10063efb77a965a8d5f4182806fbf59ed068b3c3fd6f30d2ac7bee734"}, - {file = "lxml-5.2.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:07542787f86112d46d07d4f3c4e7c760282011b354d012dc4141cc12a68cef5f"}, - {file = "lxml-5.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:303f540ad2dddd35b92415b74b900c749ec2010e703ab3bfd6660979d01fd4ed"}, - {file = "lxml-5.2.2-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:2eb2227ce1ff998faf0cd7fe85bbf086aa41dfc5af3b1d80867ecfe75fb68df3"}, - {file = "lxml-5.2.2-cp311-cp311-manylinux_2_28_ppc64le.whl", hash = "sha256:1d8a701774dfc42a2f0b8ccdfe7dbc140500d1049e0632a611985d943fcf12df"}, - {file = "lxml-5.2.2-cp311-cp311-manylinux_2_28_s390x.whl", hash = "sha256:56793b7a1a091a7c286b5f4aa1fe4ae5d1446fe742d00cdf2ffb1077865db10d"}, - {file = "lxml-5.2.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:eb00b549b13bd6d884c863554566095bf6fa9c3cecb2e7b399c4bc7904cb33b5"}, - {file = "lxml-5.2.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1a2569a1f15ae6c8c64108a2cd2b4a858fc1e13d25846be0666fc144715e32ab"}, - {file = "lxml-5.2.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:8cf85a6e40ff1f37fe0f25719aadf443686b1ac7652593dc53c7ef9b8492b115"}, - {file = "lxml-5.2.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:d237ba6664b8e60fd90b8549a149a74fcc675272e0e95539a00522e4ca688b04"}, - {file = "lxml-5.2.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0b3f5016e00ae7630a4b83d0868fca1e3d494c78a75b1c7252606a3a1c5fc2ad"}, - {file = "lxml-5.2.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:23441e2b5339bc54dc949e9e675fa35efe858108404ef9aa92f0456929ef6fe8"}, - {file = "lxml-5.2.2-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:2fb0ba3e8566548d6c8e7dd82a8229ff47bd8fb8c2da237607ac8e5a1b8312e5"}, - {file = "lxml-5.2.2-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:79d1fb9252e7e2cfe4de6e9a6610c7cbb99b9708e2c3e29057f487de5a9eaefa"}, - {file = "lxml-5.2.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:6dcc3d17eac1df7859ae01202e9bb11ffa8c98949dcbeb1069c8b9a75917e01b"}, - {file = "lxml-5.2.2-cp311-cp311-win32.whl", hash = "sha256:4c30a2f83677876465f44c018830f608fa3c6a8a466eb223535035fbc16f3438"}, - {file = "lxml-5.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:49095a38eb333aaf44c06052fd2ec3b8f23e19747ca7ec6f6c954ffea6dbf7be"}, - {file = "lxml-5.2.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:7429e7faa1a60cad26ae4227f4dd0459efde239e494c7312624ce228e04f6391"}, - {file = "lxml-5.2.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:50ccb5d355961c0f12f6cf24b7187dbabd5433f29e15147a67995474f27d1776"}, - {file = "lxml-5.2.2-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dc911208b18842a3a57266d8e51fc3cfaccee90a5351b92079beed912a7914c2"}, - {file = "lxml-5.2.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:33ce9e786753743159799fdf8e92a5da351158c4bfb6f2db0bf31e7892a1feb5"}, - {file = "lxml-5.2.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ec87c44f619380878bd49ca109669c9f221d9ae6883a5bcb3616785fa8f94c97"}, - {file = "lxml-5.2.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:08ea0f606808354eb8f2dfaac095963cb25d9d28e27edcc375d7b30ab01abbf6"}, - {file = "lxml-5.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75a9632f1d4f698b2e6e2e1ada40e71f369b15d69baddb8968dcc8e683839b18"}, - {file = "lxml-5.2.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:74da9f97daec6928567b48c90ea2c82a106b2d500f397eeb8941e47d30b1ca85"}, - {file = "lxml-5.2.2-cp312-cp312-manylinux_2_28_ppc64le.whl", hash = "sha256:0969e92af09c5687d769731e3f39ed62427cc72176cebb54b7a9d52cc4fa3b73"}, - {file = "lxml-5.2.2-cp312-cp312-manylinux_2_28_s390x.whl", hash = "sha256:9164361769b6ca7769079f4d426a41df6164879f7f3568be9086e15baca61466"}, - {file = "lxml-5.2.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:d26a618ae1766279f2660aca0081b2220aca6bd1aa06b2cf73f07383faf48927"}, - {file = "lxml-5.2.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab67ed772c584b7ef2379797bf14b82df9aa5f7438c5b9a09624dd834c1c1aaf"}, - {file = "lxml-5.2.2-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:3d1e35572a56941b32c239774d7e9ad724074d37f90c7a7d499ab98761bd80cf"}, - {file = "lxml-5.2.2-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:8268cbcd48c5375f46e000adb1390572c98879eb4f77910c6053d25cc3ac2c67"}, - {file = "lxml-5.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e282aedd63c639c07c3857097fc0e236f984ceb4089a8b284da1c526491e3f3d"}, - {file = "lxml-5.2.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6dfdc2bfe69e9adf0df4915949c22a25b39d175d599bf98e7ddf620a13678585"}, - {file = "lxml-5.2.2-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:4aefd911793b5d2d7a921233a54c90329bf3d4a6817dc465f12ffdfe4fc7b8fe"}, - {file = "lxml-5.2.2-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:8b8df03a9e995b6211dafa63b32f9d405881518ff1ddd775db4e7b98fb545e1c"}, - {file = "lxml-5.2.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:f11ae142f3a322d44513de1018b50f474f8f736bc3cd91d969f464b5bfef8836"}, - {file = "lxml-5.2.2-cp312-cp312-win32.whl", hash = "sha256:16a8326e51fcdffc886294c1e70b11ddccec836516a343f9ed0f82aac043c24a"}, - {file = "lxml-5.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:bbc4b80af581e18568ff07f6395c02114d05f4865c2812a1f02f2eaecf0bfd48"}, - {file = "lxml-5.2.2-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:e3d9d13603410b72787579769469af730c38f2f25505573a5888a94b62b920f8"}, - {file = "lxml-5.2.2-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:38b67afb0a06b8575948641c1d6d68e41b83a3abeae2ca9eed2ac59892b36706"}, - {file = "lxml-5.2.2-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c689d0d5381f56de7bd6966a4541bff6e08bf8d3871bbd89a0c6ab18aa699573"}, - {file = "lxml-5.2.2-cp36-cp36m-manylinux_2_28_x86_64.whl", hash = "sha256:cf2a978c795b54c539f47964ec05e35c05bd045db5ca1e8366988c7f2fe6b3ce"}, - {file = "lxml-5.2.2-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:739e36ef7412b2bd940f75b278749106e6d025e40027c0b94a17ef7968d55d56"}, - {file = "lxml-5.2.2-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:d8bbcd21769594dbba9c37d3c819e2d5847656ca99c747ddb31ac1701d0c0ed9"}, - {file = "lxml-5.2.2-cp36-cp36m-musllinux_1_2_x86_64.whl", hash = "sha256:2304d3c93f2258ccf2cf7a6ba8c761d76ef84948d87bf9664e14d203da2cd264"}, - {file = "lxml-5.2.2-cp36-cp36m-win32.whl", hash = "sha256:02437fb7308386867c8b7b0e5bc4cd4b04548b1c5d089ffb8e7b31009b961dc3"}, - {file = "lxml-5.2.2-cp36-cp36m-win_amd64.whl", hash = "sha256:edcfa83e03370032a489430215c1e7783128808fd3e2e0a3225deee278585196"}, - {file = "lxml-5.2.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:28bf95177400066596cdbcfc933312493799382879da504633d16cf60bba735b"}, - {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a745cc98d504d5bd2c19b10c79c61c7c3df9222629f1b6210c0368177589fb8"}, - {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b590b39ef90c6b22ec0be925b211298e810b4856909c8ca60d27ffbca6c12e6"}, - {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b336b0416828022bfd5a2e3083e7f5ba54b96242159f83c7e3eebaec752f1716"}, - {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:c2faf60c583af0d135e853c86ac2735ce178f0e338a3c7f9ae8f622fd2eb788c"}, - {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:4bc6cb140a7a0ad1f7bc37e018d0ed690b7b6520ade518285dc3171f7a117905"}, - {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:7ff762670cada8e05b32bf1e4dc50b140790909caa8303cfddc4d702b71ea184"}, - {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:57f0a0bbc9868e10ebe874e9f129d2917750adf008fe7b9c1598c0fbbfdde6a6"}, - {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:a6d2092797b388342c1bc932077ad232f914351932353e2e8706851c870bca1f"}, - {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:60499fe961b21264e17a471ec296dcbf4365fbea611bf9e303ab69db7159ce61"}, - {file = "lxml-5.2.2-cp37-cp37m-win32.whl", hash = "sha256:d9b342c76003c6b9336a80efcc766748a333573abf9350f4094ee46b006ec18f"}, - {file = "lxml-5.2.2-cp37-cp37m-win_amd64.whl", hash = "sha256:b16db2770517b8799c79aa80f4053cd6f8b716f21f8aca962725a9565ce3ee40"}, - {file = "lxml-5.2.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7ed07b3062b055d7a7f9d6557a251cc655eed0b3152b76de619516621c56f5d3"}, - {file = "lxml-5.2.2-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f60fdd125d85bf9c279ffb8e94c78c51b3b6a37711464e1f5f31078b45002421"}, - {file = "lxml-5.2.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8a7e24cb69ee5f32e003f50e016d5fde438010c1022c96738b04fc2423e61706"}, - {file = "lxml-5.2.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23cfafd56887eaed93d07bc4547abd5e09d837a002b791e9767765492a75883f"}, - {file = "lxml-5.2.2-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:19b4e485cd07b7d83e3fe3b72132e7df70bfac22b14fe4bf7a23822c3a35bff5"}, - {file = "lxml-5.2.2-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:7ce7ad8abebe737ad6143d9d3bf94b88b93365ea30a5b81f6877ec9c0dee0a48"}, - {file = "lxml-5.2.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:e49b052b768bb74f58c7dda4e0bdf7b79d43a9204ca584ffe1fb48a6f3c84c66"}, - {file = "lxml-5.2.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d14a0d029a4e176795cef99c056d58067c06195e0c7e2dbb293bf95c08f772a3"}, - {file = "lxml-5.2.2-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:be49ad33819d7dcc28a309b86d4ed98e1a65f3075c6acd3cd4fe32103235222b"}, - {file = "lxml-5.2.2-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:a6d17e0370d2516d5bb9062c7b4cb731cff921fc875644c3d751ad857ba9c5b1"}, - {file = "lxml-5.2.2-cp38-cp38-win32.whl", hash = "sha256:5b8c041b6265e08eac8a724b74b655404070b636a8dd6d7a13c3adc07882ef30"}, - {file = "lxml-5.2.2-cp38-cp38-win_amd64.whl", hash = "sha256:f61efaf4bed1cc0860e567d2ecb2363974d414f7f1f124b1df368bbf183453a6"}, - {file = "lxml-5.2.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:fb91819461b1b56d06fa4bcf86617fac795f6a99d12239fb0c68dbeba41a0a30"}, - {file = "lxml-5.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d4ed0c7cbecde7194cd3228c044e86bf73e30a23505af852857c09c24e77ec5d"}, - {file = "lxml-5.2.2-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54401c77a63cc7d6dc4b4e173bb484f28a5607f3df71484709fe037c92d4f0ed"}, - {file = "lxml-5.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:625e3ef310e7fa3a761d48ca7ea1f9d8718a32b1542e727d584d82f4453d5eeb"}, - {file = "lxml-5.2.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:519895c99c815a1a24a926d5b60627ce5ea48e9f639a5cd328bda0515ea0f10c"}, - {file = "lxml-5.2.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c7079d5eb1c1315a858bbf180000757db8ad904a89476653232db835c3114001"}, - {file = "lxml-5.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:343ab62e9ca78094f2306aefed67dcfad61c4683f87eee48ff2fd74902447726"}, - {file = "lxml-5.2.2-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:cd9e78285da6c9ba2d5c769628f43ef66d96ac3085e59b10ad4f3707980710d3"}, - {file = "lxml-5.2.2-cp39-cp39-manylinux_2_28_ppc64le.whl", hash = "sha256:546cf886f6242dff9ec206331209db9c8e1643ae642dea5fdbecae2453cb50fd"}, - {file = "lxml-5.2.2-cp39-cp39-manylinux_2_28_s390x.whl", hash = "sha256:02f6a8eb6512fdc2fd4ca10a49c341c4e109aa6e9448cc4859af5b949622715a"}, - {file = "lxml-5.2.2-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:339ee4a4704bc724757cd5dd9dc8cf4d00980f5d3e6e06d5847c1b594ace68ab"}, - {file = "lxml-5.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0a028b61a2e357ace98b1615fc03f76eb517cc028993964fe08ad514b1e8892d"}, - {file = "lxml-5.2.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:f90e552ecbad426eab352e7b2933091f2be77115bb16f09f78404861c8322981"}, - {file = "lxml-5.2.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:d83e2d94b69bf31ead2fa45f0acdef0757fa0458a129734f59f67f3d2eb7ef32"}, - {file = "lxml-5.2.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a02d3c48f9bb1e10c7788d92c0c7db6f2002d024ab6e74d6f45ae33e3d0288a3"}, - {file = "lxml-5.2.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:6d68ce8e7b2075390e8ac1e1d3a99e8b6372c694bbe612632606d1d546794207"}, - {file = "lxml-5.2.2-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:453d037e09a5176d92ec0fd282e934ed26d806331a8b70ab431a81e2fbabf56d"}, - {file = "lxml-5.2.2-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:3b019d4ee84b683342af793b56bb35034bd749e4cbdd3d33f7d1107790f8c472"}, - {file = "lxml-5.2.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:cb3942960f0beb9f46e2a71a3aca220d1ca32feb5a398656be934320804c0df9"}, - {file = "lxml-5.2.2-cp39-cp39-win32.whl", hash = "sha256:ac6540c9fff6e3813d29d0403ee7a81897f1d8ecc09a8ff84d2eea70ede1cdbf"}, - {file = "lxml-5.2.2-cp39-cp39-win_amd64.whl", hash = "sha256:610b5c77428a50269f38a534057444c249976433f40f53e3b47e68349cca1425"}, - {file = "lxml-5.2.2-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:b537bd04d7ccd7c6350cdaaaad911f6312cbd61e6e6045542f781c7f8b2e99d2"}, - {file = "lxml-5.2.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4820c02195d6dfb7b8508ff276752f6b2ff8b64ae5d13ebe02e7667e035000b9"}, - {file = "lxml-5.2.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f2a09f6184f17a80897172863a655467da2b11151ec98ba8d7af89f17bf63dae"}, - {file = "lxml-5.2.2-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:76acba4c66c47d27c8365e7c10b3d8016a7da83d3191d053a58382311a8bf4e1"}, - {file = "lxml-5.2.2-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:b128092c927eaf485928cec0c28f6b8bead277e28acf56800e972aa2c2abd7a2"}, - {file = "lxml-5.2.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:ae791f6bd43305aade8c0e22f816b34f3b72b6c820477aab4d18473a37e8090b"}, - {file = "lxml-5.2.2-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a2f6a1bc2460e643785a2cde17293bd7a8f990884b822f7bca47bee0a82fc66b"}, - {file = "lxml-5.2.2-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e8d351ff44c1638cb6e980623d517abd9f580d2e53bfcd18d8941c052a5a009"}, - {file = "lxml-5.2.2-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bec4bd9133420c5c52d562469c754f27c5c9e36ee06abc169612c959bd7dbb07"}, - {file = "lxml-5.2.2-pp37-pypy37_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:55ce6b6d803890bd3cc89975fca9de1dff39729b43b73cb15ddd933b8bc20484"}, - {file = "lxml-5.2.2-pp37-pypy37_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:8ab6a358d1286498d80fe67bd3d69fcbc7d1359b45b41e74c4a26964ca99c3f8"}, - {file = "lxml-5.2.2-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:06668e39e1f3c065349c51ac27ae430719d7806c026fec462e5693b08b95696b"}, - {file = "lxml-5.2.2-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:9cd5323344d8ebb9fb5e96da5de5ad4ebab993bbf51674259dbe9d7a18049525"}, - {file = "lxml-5.2.2-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89feb82ca055af0fe797a2323ec9043b26bc371365847dbe83c7fd2e2f181c34"}, - {file = "lxml-5.2.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e481bba1e11ba585fb06db666bfc23dbe181dbafc7b25776156120bf12e0d5a6"}, - {file = "lxml-5.2.2-pp38-pypy38_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:9d6c6ea6a11ca0ff9cd0390b885984ed31157c168565702959c25e2191674a14"}, - {file = "lxml-5.2.2-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:3d98de734abee23e61f6b8c2e08a88453ada7d6486dc7cdc82922a03968928db"}, - {file = "lxml-5.2.2-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:69ab77a1373f1e7563e0fb5a29a8440367dec051da6c7405333699d07444f511"}, - {file = "lxml-5.2.2-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:34e17913c431f5ae01d8658dbf792fdc457073dcdfbb31dc0cc6ab256e664a8d"}, - {file = "lxml-5.2.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05f8757b03208c3f50097761be2dea0aba02e94f0dc7023ed73a7bb14ff11eb0"}, - {file = "lxml-5.2.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6a520b4f9974b0a0a6ed73c2154de57cdfd0c8800f4f15ab2b73238ffed0b36e"}, - {file = "lxml-5.2.2-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:5e097646944b66207023bc3c634827de858aebc226d5d4d6d16f0b77566ea182"}, - {file = "lxml-5.2.2-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:b5e4ef22ff25bfd4ede5f8fb30f7b24446345f3e79d9b7455aef2836437bc38a"}, - {file = "lxml-5.2.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:ff69a9a0b4b17d78170c73abe2ab12084bdf1691550c5629ad1fe7849433f324"}, - {file = "lxml-5.2.2.tar.gz", hash = "sha256:bb2dc4898180bea79863d5487e5f9c7c34297414bad54bcd0f0852aee9cfdb87"}, + {file = "lxml-5.3.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:dd36439be765e2dde7660212b5275641edbc813e7b24668831a5c8ac91180656"}, + {file = "lxml-5.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ae5fe5c4b525aa82b8076c1a59d642c17b6e8739ecf852522c6321852178119d"}, + {file = "lxml-5.3.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:501d0d7e26b4d261fca8132854d845e4988097611ba2531408ec91cf3fd9d20a"}, + {file = "lxml-5.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb66442c2546446944437df74379e9cf9e9db353e61301d1a0e26482f43f0dd8"}, + {file = "lxml-5.3.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9e41506fec7a7f9405b14aa2d5c8abbb4dbbd09d88f9496958b6d00cb4d45330"}, + {file = "lxml-5.3.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f7d4a670107d75dfe5ad080bed6c341d18c4442f9378c9f58e5851e86eb79965"}, + {file = "lxml-5.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41ce1f1e2c7755abfc7e759dc34d7d05fd221723ff822947132dc934d122fe22"}, + {file = "lxml-5.3.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:44264ecae91b30e5633013fb66f6ddd05c006d3e0e884f75ce0b4755b3e3847b"}, + {file = "lxml-5.3.0-cp310-cp310-manylinux_2_28_ppc64le.whl", hash = "sha256:3c174dc350d3ec52deb77f2faf05c439331d6ed5e702fc247ccb4e6b62d884b7"}, + {file = "lxml-5.3.0-cp310-cp310-manylinux_2_28_s390x.whl", hash = "sha256:2dfab5fa6a28a0b60a20638dc48e6343c02ea9933e3279ccb132f555a62323d8"}, + {file = "lxml-5.3.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:b1c8c20847b9f34e98080da785bb2336ea982e7f913eed5809e5a3c872900f32"}, + {file = "lxml-5.3.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:2c86bf781b12ba417f64f3422cfc302523ac9cd1d8ae8c0f92a1c66e56ef2e86"}, + {file = "lxml-5.3.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:c162b216070f280fa7da844531169be0baf9ccb17263cf5a8bf876fcd3117fa5"}, + {file = "lxml-5.3.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:36aef61a1678cb778097b4a6eeae96a69875d51d1e8f4d4b491ab3cfb54b5a03"}, + {file = "lxml-5.3.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f65e5120863c2b266dbcc927b306c5b78e502c71edf3295dfcb9501ec96e5fc7"}, + {file = "lxml-5.3.0-cp310-cp310-win32.whl", hash = "sha256:ef0c1fe22171dd7c7c27147f2e9c3e86f8bdf473fed75f16b0c2e84a5030ce80"}, + {file = "lxml-5.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:052d99051e77a4f3e8482c65014cf6372e61b0a6f4fe9edb98503bb5364cfee3"}, + {file = "lxml-5.3.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:74bcb423462233bc5d6066e4e98b0264e7c1bed7541fff2f4e34fe6b21563c8b"}, + {file = "lxml-5.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a3d819eb6f9b8677f57f9664265d0a10dd6551d227afb4af2b9cd7bdc2ccbf18"}, + {file = "lxml-5.3.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5b8f5db71b28b8c404956ddf79575ea77aa8b1538e8b2ef9ec877945b3f46442"}, + {file = "lxml-5.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c3406b63232fc7e9b8783ab0b765d7c59e7c59ff96759d8ef9632fca27c7ee4"}, + {file = "lxml-5.3.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2ecdd78ab768f844c7a1d4a03595038c166b609f6395e25af9b0f3f26ae1230f"}, + {file = "lxml-5.3.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:168f2dfcfdedf611eb285efac1516c8454c8c99caf271dccda8943576b67552e"}, + {file = "lxml-5.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa617107a410245b8660028a7483b68e7914304a6d4882b5ff3d2d3eb5948d8c"}, + {file = "lxml-5.3.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:69959bd3167b993e6e710b99051265654133a98f20cec1d9b493b931942e9c16"}, + {file = "lxml-5.3.0-cp311-cp311-manylinux_2_28_ppc64le.whl", hash = "sha256:bd96517ef76c8654446fc3db9242d019a1bb5fe8b751ba414765d59f99210b79"}, + {file = "lxml-5.3.0-cp311-cp311-manylinux_2_28_s390x.whl", hash = "sha256:ab6dd83b970dc97c2d10bc71aa925b84788c7c05de30241b9e96f9b6d9ea3080"}, + {file = "lxml-5.3.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:eec1bb8cdbba2925bedc887bc0609a80e599c75b12d87ae42ac23fd199445654"}, + {file = "lxml-5.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6a7095eeec6f89111d03dabfe5883a1fd54da319c94e0fb104ee8f23616b572d"}, + {file = "lxml-5.3.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:6f651ebd0b21ec65dfca93aa629610a0dbc13dbc13554f19b0113da2e61a4763"}, + {file = "lxml-5.3.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:f422a209d2455c56849442ae42f25dbaaba1c6c3f501d58761c619c7836642ec"}, + {file = "lxml-5.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:62f7fdb0d1ed2065451f086519865b4c90aa19aed51081979ecd05a21eb4d1be"}, + {file = "lxml-5.3.0-cp311-cp311-win32.whl", hash = "sha256:c6379f35350b655fd817cd0d6cbeef7f265f3ae5fedb1caae2eb442bbeae9ab9"}, + {file = "lxml-5.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:9c52100e2c2dbb0649b90467935c4b0de5528833c76a35ea1a2691ec9f1ee7a1"}, + {file = "lxml-5.3.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:e99f5507401436fdcc85036a2e7dc2e28d962550afe1cbfc07c40e454256a859"}, + {file = "lxml-5.3.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:384aacddf2e5813a36495233b64cb96b1949da72bef933918ba5c84e06af8f0e"}, + {file = "lxml-5.3.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:874a216bf6afaf97c263b56371434e47e2c652d215788396f60477540298218f"}, + {file = "lxml-5.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:65ab5685d56914b9a2a34d67dd5488b83213d680b0c5d10b47f81da5a16b0b0e"}, + {file = "lxml-5.3.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aac0bbd3e8dd2d9c45ceb82249e8bdd3ac99131a32b4d35c8af3cc9db1657179"}, + {file = "lxml-5.3.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b369d3db3c22ed14c75ccd5af429086f166a19627e84a8fdade3f8f31426e52a"}, + {file = "lxml-5.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c24037349665434f375645fa9d1f5304800cec574d0310f618490c871fd902b3"}, + {file = "lxml-5.3.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:62d172f358f33a26d6b41b28c170c63886742f5b6772a42b59b4f0fa10526cb1"}, + {file = "lxml-5.3.0-cp312-cp312-manylinux_2_28_ppc64le.whl", hash = "sha256:c1f794c02903c2824fccce5b20c339a1a14b114e83b306ff11b597c5f71a1c8d"}, + {file = "lxml-5.3.0-cp312-cp312-manylinux_2_28_s390x.whl", hash = "sha256:5d6a6972b93c426ace71e0be9a6f4b2cfae9b1baed2eed2006076a746692288c"}, + {file = "lxml-5.3.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:3879cc6ce938ff4eb4900d901ed63555c778731a96365e53fadb36437a131a99"}, + {file = "lxml-5.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:74068c601baff6ff021c70f0935b0c7bc528baa8ea210c202e03757c68c5a4ff"}, + {file = "lxml-5.3.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:ecd4ad8453ac17bc7ba3868371bffb46f628161ad0eefbd0a855d2c8c32dd81a"}, + {file = "lxml-5.3.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:7e2f58095acc211eb9d8b5771bf04df9ff37d6b87618d1cbf85f92399c98dae8"}, + {file = "lxml-5.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e63601ad5cd8f860aa99d109889b5ac34de571c7ee902d6812d5d9ddcc77fa7d"}, + {file = "lxml-5.3.0-cp312-cp312-win32.whl", hash = "sha256:17e8d968d04a37c50ad9c456a286b525d78c4a1c15dd53aa46c1d8e06bf6fa30"}, + {file = "lxml-5.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:c1a69e58a6bb2de65902051d57fde951febad631a20a64572677a1052690482f"}, + {file = "lxml-5.3.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8c72e9563347c7395910de6a3100a4840a75a6f60e05af5e58566868d5eb2d6a"}, + {file = "lxml-5.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e92ce66cd919d18d14b3856906a61d3f6b6a8500e0794142338da644260595cd"}, + {file = "lxml-5.3.0-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1d04f064bebdfef9240478f7a779e8c5dc32b8b7b0b2fc6a62e39b928d428e51"}, + {file = "lxml-5.3.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c2fb570d7823c2bbaf8b419ba6e5662137f8166e364a8b2b91051a1fb40ab8b"}, + {file = "lxml-5.3.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0c120f43553ec759f8de1fee2f4794452b0946773299d44c36bfe18e83caf002"}, + {file = "lxml-5.3.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:562e7494778a69086f0312ec9689f6b6ac1c6b65670ed7d0267e49f57ffa08c4"}, + {file = "lxml-5.3.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:423b121f7e6fa514ba0c7918e56955a1d4470ed35faa03e3d9f0e3baa4c7e492"}, + {file = "lxml-5.3.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:c00f323cc00576df6165cc9d21a4c21285fa6b9989c5c39830c3903dc4303ef3"}, + {file = "lxml-5.3.0-cp313-cp313-manylinux_2_28_ppc64le.whl", hash = "sha256:1fdc9fae8dd4c763e8a31e7630afef517eab9f5d5d31a278df087f307bf601f4"}, + {file = "lxml-5.3.0-cp313-cp313-manylinux_2_28_s390x.whl", hash = "sha256:658f2aa69d31e09699705949b5fc4719cbecbd4a97f9656a232e7d6c7be1a367"}, + {file = "lxml-5.3.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:1473427aff3d66a3fa2199004c3e601e6c4500ab86696edffdbc84954c72d832"}, + {file = "lxml-5.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a87de7dd873bf9a792bf1e58b1c3887b9264036629a5bf2d2e6579fe8e73edff"}, + {file = "lxml-5.3.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:0d7b36afa46c97875303a94e8f3ad932bf78bace9e18e603f2085b652422edcd"}, + {file = "lxml-5.3.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:cf120cce539453ae086eacc0130a324e7026113510efa83ab42ef3fcfccac7fb"}, + {file = "lxml-5.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:df5c7333167b9674aa8ae1d4008fa4bc17a313cc490b2cca27838bbdcc6bb15b"}, + {file = "lxml-5.3.0-cp313-cp313-win32.whl", hash = "sha256:c802e1c2ed9f0c06a65bc4ed0189d000ada8049312cfeab6ca635e39c9608957"}, + {file = "lxml-5.3.0-cp313-cp313-win_amd64.whl", hash = "sha256:406246b96d552e0503e17a1006fd27edac678b3fcc9f1be71a2f94b4ff61528d"}, + {file = "lxml-5.3.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:8f0de2d390af441fe8b2c12626d103540b5d850d585b18fcada58d972b74a74e"}, + {file = "lxml-5.3.0-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1afe0a8c353746e610bd9031a630a95bcfb1a720684c3f2b36c4710a0a96528f"}, + {file = "lxml-5.3.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56b9861a71575f5795bde89256e7467ece3d339c9b43141dbdd54544566b3b94"}, + {file = "lxml-5.3.0-cp36-cp36m-manylinux_2_28_x86_64.whl", hash = "sha256:9fb81d2824dff4f2e297a276297e9031f46d2682cafc484f49de182aa5e5df99"}, + {file = "lxml-5.3.0-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:2c226a06ecb8cdef28845ae976da407917542c5e6e75dcac7cc33eb04aaeb237"}, + {file = "lxml-5.3.0-cp36-cp36m-musllinux_1_2_x86_64.whl", hash = "sha256:7d3d1ca42870cdb6d0d29939630dbe48fa511c203724820fc0fd507b2fb46577"}, + {file = "lxml-5.3.0-cp36-cp36m-win32.whl", hash = "sha256:094cb601ba9f55296774c2d57ad68730daa0b13dc260e1f941b4d13678239e70"}, + {file = "lxml-5.3.0-cp36-cp36m-win_amd64.whl", hash = "sha256:eafa2c8658f4e560b098fe9fc54539f86528651f61849b22111a9b107d18910c"}, + {file = "lxml-5.3.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:cb83f8a875b3d9b458cada4f880fa498646874ba4011dc974e071a0a84a1b033"}, + {file = "lxml-5.3.0-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:25f1b69d41656b05885aa185f5fdf822cb01a586d1b32739633679699f220391"}, + {file = "lxml-5.3.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23e0553b8055600b3bf4a00b255ec5c92e1e4aebf8c2c09334f8368e8bd174d6"}, + {file = "lxml-5.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ada35dd21dc6c039259596b358caab6b13f4db4d4a7f8665764d616daf9cc1d"}, + {file = "lxml-5.3.0-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:81b4e48da4c69313192d8c8d4311e5d818b8be1afe68ee20f6385d0e96fc9512"}, + {file = "lxml-5.3.0-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:2bc9fd5ca4729af796f9f59cd8ff160fe06a474da40aca03fcc79655ddee1a8b"}, + {file = "lxml-5.3.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:07da23d7ee08577760f0a71d67a861019103e4812c87e2fab26b039054594cc5"}, + {file = "lxml-5.3.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:ea2e2f6f801696ad7de8aec061044d6c8c0dd4037608c7cab38a9a4d316bfb11"}, + {file = "lxml-5.3.0-cp37-cp37m-win32.whl", hash = "sha256:5c54afdcbb0182d06836cc3d1be921e540be3ebdf8b8a51ee3ef987537455f84"}, + {file = "lxml-5.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:f2901429da1e645ce548bf9171784c0f74f0718c3f6150ce166be39e4dd66c3e"}, + {file = "lxml-5.3.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c56a1d43b2f9ee4786e4658c7903f05da35b923fb53c11025712562d5cc02753"}, + {file = "lxml-5.3.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ee8c39582d2652dcd516d1b879451500f8db3fe3607ce45d7c5957ab2596040"}, + {file = "lxml-5.3.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0fdf3a3059611f7585a78ee10399a15566356116a4288380921a4b598d807a22"}, + {file = "lxml-5.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:146173654d79eb1fc97498b4280c1d3e1e5d58c398fa530905c9ea50ea849b22"}, + {file = "lxml-5.3.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0a7056921edbdd7560746f4221dca89bb7a3fe457d3d74267995253f46343f15"}, + {file = "lxml-5.3.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:9e4b47ac0f5e749cfc618efdf4726269441014ae1d5583e047b452a32e221920"}, + {file = "lxml-5.3.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:f914c03e6a31deb632e2daa881fe198461f4d06e57ac3d0e05bbcab8eae01945"}, + {file = "lxml-5.3.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:213261f168c5e1d9b7535a67e68b1f59f92398dd17a56d934550837143f79c42"}, + {file = "lxml-5.3.0-cp38-cp38-win32.whl", hash = "sha256:218c1b2e17a710e363855594230f44060e2025b05c80d1f0661258142b2add2e"}, + {file = "lxml-5.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:315f9542011b2c4e1d280e4a20ddcca1761993dda3afc7a73b01235f8641e903"}, + {file = "lxml-5.3.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1ffc23010330c2ab67fac02781df60998ca8fe759e8efde6f8b756a20599c5de"}, + {file = "lxml-5.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2b3778cb38212f52fac9fe913017deea2fdf4eb1a4f8e4cfc6b009a13a6d3fcc"}, + {file = "lxml-5.3.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4b0c7a688944891086ba192e21c5229dea54382f4836a209ff8d0a660fac06be"}, + {file = "lxml-5.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:747a3d3e98e24597981ca0be0fd922aebd471fa99d0043a3842d00cdcad7ad6a"}, + {file = "lxml-5.3.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:86a6b24b19eaebc448dc56b87c4865527855145d851f9fc3891673ff97950540"}, + {file = "lxml-5.3.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b11a5d918a6216e521c715b02749240fb07ae5a1fefd4b7bf12f833bc8b4fe70"}, + {file = "lxml-5.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:68b87753c784d6acb8a25b05cb526c3406913c9d988d51f80adecc2b0775d6aa"}, + {file = "lxml-5.3.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:109fa6fede314cc50eed29e6e56c540075e63d922455346f11e4d7a036d2b8cf"}, + {file = "lxml-5.3.0-cp39-cp39-manylinux_2_28_ppc64le.whl", hash = "sha256:02ced472497b8362c8e902ade23e3300479f4f43e45f4105c85ef43b8db85229"}, + {file = "lxml-5.3.0-cp39-cp39-manylinux_2_28_s390x.whl", hash = "sha256:6b038cc86b285e4f9fea2ba5ee76e89f21ed1ea898e287dc277a25884f3a7dfe"}, + {file = "lxml-5.3.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:7437237c6a66b7ca341e868cda48be24b8701862757426852c9b3186de1da8a2"}, + {file = "lxml-5.3.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:7f41026c1d64043a36fda21d64c5026762d53a77043e73e94b71f0521939cc71"}, + {file = "lxml-5.3.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:482c2f67761868f0108b1743098640fbb2a28a8e15bf3f47ada9fa59d9fe08c3"}, + {file = "lxml-5.3.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:1483fd3358963cc5c1c9b122c80606a3a79ee0875bcac0204149fa09d6ff2727"}, + {file = "lxml-5.3.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:2dec2d1130a9cda5b904696cec33b2cfb451304ba9081eeda7f90f724097300a"}, + {file = "lxml-5.3.0-cp39-cp39-win32.whl", hash = "sha256:a0eabd0a81625049c5df745209dc7fcef6e2aea7793e5f003ba363610aa0a3ff"}, + {file = "lxml-5.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:89e043f1d9d341c52bf2af6d02e6adde62e0a46e6755d5eb60dc6e4f0b8aeca2"}, + {file = "lxml-5.3.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:7b1cd427cb0d5f7393c31b7496419da594fe600e6fdc4b105a54f82405e6626c"}, + {file = "lxml-5.3.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:51806cfe0279e06ed8500ce19479d757db42a30fd509940b1701be9c86a5ff9a"}, + {file = "lxml-5.3.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ee70d08fd60c9565ba8190f41a46a54096afa0eeb8f76bd66f2c25d3b1b83005"}, + {file = "lxml-5.3.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:8dc2c0395bea8254d8daebc76dcf8eb3a95ec2a46fa6fae5eaccee366bfe02ce"}, + {file = "lxml-5.3.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:6ba0d3dcac281aad8a0e5b14c7ed6f9fa89c8612b47939fc94f80b16e2e9bc83"}, + {file = "lxml-5.3.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:6e91cf736959057f7aac7adfc83481e03615a8e8dd5758aa1d95ea69e8931dba"}, + {file = "lxml-5.3.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:94d6c3782907b5e40e21cadf94b13b0842ac421192f26b84c45f13f3c9d5dc27"}, + {file = "lxml-5.3.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c300306673aa0f3ed5ed9372b21867690a17dba38c68c44b287437c362ce486b"}, + {file = "lxml-5.3.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78d9b952e07aed35fe2e1a7ad26e929595412db48535921c5013edc8aa4a35ce"}, + {file = "lxml-5.3.0-pp37-pypy37_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:01220dca0d066d1349bd6a1726856a78f7929f3878f7e2ee83c296c69495309e"}, + {file = "lxml-5.3.0-pp37-pypy37_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:2d9b8d9177afaef80c53c0a9e30fa252ff3036fb1c6494d427c066a4ce6a282f"}, + {file = "lxml-5.3.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:20094fc3f21ea0a8669dc4c61ed7fa8263bd37d97d93b90f28fc613371e7a875"}, + {file = "lxml-5.3.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:ace2c2326a319a0bb8a8b0e5b570c764962e95818de9f259ce814ee666603f19"}, + {file = "lxml-5.3.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:92e67a0be1639c251d21e35fe74df6bcc40cba445c2cda7c4a967656733249e2"}, + {file = "lxml-5.3.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd5350b55f9fecddc51385463a4f67a5da829bc741e38cf689f38ec9023f54ab"}, + {file = "lxml-5.3.0-pp38-pypy38_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:4c1fefd7e3d00921c44dc9ca80a775af49698bbfd92ea84498e56acffd4c5469"}, + {file = "lxml-5.3.0-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:71a8dd38fbd2f2319136d4ae855a7078c69c9a38ae06e0c17c73fd70fc6caad8"}, + {file = "lxml-5.3.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:97acf1e1fd66ab53dacd2c35b319d7e548380c2e9e8c54525c6e76d21b1ae3b1"}, + {file = "lxml-5.3.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:68934b242c51eb02907c5b81d138cb977b2129a0a75a8f8b60b01cb8586c7b21"}, + {file = "lxml-5.3.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b710bc2b8292966b23a6a0121f7a6c51d45d2347edcc75f016ac123b8054d3f2"}, + {file = "lxml-5.3.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18feb4b93302091b1541221196a2155aa296c363fd233814fa11e181adebc52f"}, + {file = "lxml-5.3.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:3eb44520c4724c2e1a57c0af33a379eee41792595023f367ba3952a2d96c2aab"}, + {file = "lxml-5.3.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:609251a0ca4770e5a8768ff902aa02bf636339c5a93f9349b48eb1f606f7f3e9"}, + {file = "lxml-5.3.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:516f491c834eb320d6c843156440fe7fc0d50b33e44387fcec5b02f0bc118a4c"}, + {file = "lxml-5.3.0.tar.gz", hash = "sha256:4e109ca30d1edec1ac60cdbe341905dc3b8f55b16855e03a54aaf59e51ec8c6f"}, ] [package.extras] @@ -4363,7 +4527,7 @@ cssselect = ["cssselect (>=0.7)"] html-clean = ["lxml-html-clean"] html5 = ["html5lib"] htmlsoup = ["BeautifulSoup4"] -source = ["Cython (>=3.0.10)"] +source = ["Cython (>=3.0.11)"] [[package]] name = "lz4" @@ -4561,13 +4725,13 @@ files = [ [[package]] name = "marshmallow" -version = "3.21.3" +version = "3.22.0" description = "A lightweight library for converting complex datatypes to and from native Python datatypes." optional = false python-versions = ">=3.8" files = [ - {file = "marshmallow-3.21.3-py3-none-any.whl", hash = "sha256:86ce7fb914aa865001a4b2092c4c2872d13bc347f3d42673272cabfdbad386f1"}, - {file = "marshmallow-3.21.3.tar.gz", hash = "sha256:4f57c5e050a54d66361e826f94fba213eb10b67b2fdb02c3e0343ce207ba1662"}, + {file = "marshmallow-3.22.0-py3-none-any.whl", hash = "sha256:71a2dce49ef901c3f97ed296ae5051135fd3febd2bf43afe0ae9a82143a494d9"}, + {file = "marshmallow-3.22.0.tar.gz", hash = "sha256:4972f529104a220bb8637d595aa4c9762afbe7f7a77d82dc58c1615d70c5823e"}, ] [package.dependencies] @@ -4575,7 +4739,7 @@ packaging = ">=17.0" [package.extras] dev = ["marshmallow[tests]", "pre-commit (>=3.5,<4.0)", "tox"] -docs = ["alabaster (==0.7.16)", "autodocsumm (==0.2.12)", "sphinx (==7.3.7)", "sphinx-issues (==4.1.0)", "sphinx-version-warning (==1.1.2)"] +docs = ["alabaster (==1.0.0)", "autodocsumm (==0.2.13)", "sphinx (==8.0.2)", "sphinx-issues (==4.1.0)", "sphinx-version-warning (==1.1.2)"] tests = ["pytest", "pytz", "simplejson"] [[package]] @@ -4639,15 +4803,15 @@ files = [ [[package]] name = "milvus-lite" -version = "2.4.8" +version = "2.4.9" description = "A lightweight version of Milvus wrapped with Python." optional = false python-versions = ">=3.7" files = [ - {file = "milvus_lite-2.4.8-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:b7e90b34b214884cd44cdc112ab243d4cb197b775498355e2437b6cafea025fe"}, - {file = "milvus_lite-2.4.8-py3-none-macosx_11_0_arm64.whl", hash = "sha256:519dfc62709d8f642d98a1c5b1dcde7080d107e6e312d677fef5a3412a40ac08"}, - {file = "milvus_lite-2.4.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:b21f36d24cbb0e920b4faad607019bb28c1b2c88b4d04680ac8c7697a4ae8a4d"}, - {file = "milvus_lite-2.4.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:08332a2b9abfe7c4e1d7926068937e46f8fb81f2707928b7bc02c9dc99cebe41"}, + {file = "milvus_lite-2.4.9-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:d3e617b3d68c09ad656d54bc3d8cc4ef6ef56c54015e1563d4fe4bcec6b7c90a"}, + {file = "milvus_lite-2.4.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:6e7029282d6829b277ebb92f64e2370be72b938e34770e1eb649346bda5d1d7f"}, + {file = "milvus_lite-2.4.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9b8e991e4e433596f6a399a165c1a506f823ec9133332e03d7f8a114bff4550d"}, + {file = "milvus_lite-2.4.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:7f53e674602101cfbcf0a4a59d19eaa139dfd5580639f3040ad73d901f24fc0b"}, ] [package.dependencies] @@ -5019,13 +5183,13 @@ twitter = ["twython"] [[package]] name = "novita-client" -version = "0.5.6" +version = "0.5.7" description = "novita SDK for Python" optional = false python-versions = ">=3.6" files = [ - {file = "novita_client-0.5.6-py3-none-any.whl", hash = "sha256:9fa6cfd12f13a75c7da42b27f811a560b0320da24cf256480f517bde479bc57c"}, - {file = "novita_client-0.5.6.tar.gz", hash = "sha256:2e4d956903d5da39d43127a41dcb020ae40322d2a6196413071b94b3d6988b98"}, + {file = "novita_client-0.5.7-py3-none-any.whl", hash = "sha256:844a4c09c98328c8d4f72e1d3f63f76285c2963dcc37ccb2de41cbfdbe7fa51d"}, + {file = "novita_client-0.5.7.tar.gz", hash = "sha256:65baf748757aafd8ab080a64f9ab069a40c0810fc1fa9be9c26596988a0aa4b4"}, ] [package.dependencies] @@ -5198,42 +5362,42 @@ tests = ["pytest", "pytest-cov"] [[package]] name = "onnxruntime" -version = "1.18.1" +version = "1.19.0" description = "ONNX Runtime is a runtime accelerator for Machine Learning models" optional = false python-versions = "*" files = [ - {file = "onnxruntime-1.18.1-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:29ef7683312393d4ba04252f1b287d964bd67d5e6048b94d2da3643986c74d80"}, - {file = "onnxruntime-1.18.1-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fc706eb1df06ddf55776e15a30519fb15dda7697f987a2bbda4962845e3cec05"}, - {file = "onnxruntime-1.18.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b7de69f5ced2a263531923fa68bbec52a56e793b802fcd81a03487b5e292bc3a"}, - {file = "onnxruntime-1.18.1-cp310-cp310-win32.whl", hash = "sha256:221e5b16173926e6c7de2cd437764492aa12b6811f45abd37024e7cf2ae5d7e3"}, - {file = "onnxruntime-1.18.1-cp310-cp310-win_amd64.whl", hash = "sha256:75211b619275199c861ee94d317243b8a0fcde6032e5a80e1aa9ded8ab4c6060"}, - {file = "onnxruntime-1.18.1-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:f26582882f2dc581b809cfa41a125ba71ad9e715738ec6402418df356969774a"}, - {file = "onnxruntime-1.18.1-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ef36f3a8b768506d02be349ac303fd95d92813ba3ba70304d40c3cd5c25d6a4c"}, - {file = "onnxruntime-1.18.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:170e711393e0618efa8ed27b59b9de0ee2383bd2a1f93622a97006a5ad48e434"}, - {file = "onnxruntime-1.18.1-cp311-cp311-win32.whl", hash = "sha256:9b6a33419b6949ea34e0dc009bc4470e550155b6da644571ecace4b198b0d88f"}, - {file = "onnxruntime-1.18.1-cp311-cp311-win_amd64.whl", hash = "sha256:5c1380a9f1b7788da742c759b6a02ba771fe1ce620519b2b07309decbd1a2fe1"}, - {file = "onnxruntime-1.18.1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:31bd57a55e3f983b598675dfc7e5d6f0877b70ec9864b3cc3c3e1923d0a01919"}, - {file = "onnxruntime-1.18.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b9e03c4ba9f734500691a4d7d5b381cd71ee2f3ce80a1154ac8f7aed99d1ecaa"}, - {file = "onnxruntime-1.18.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:781aa9873640f5df24524f96f6070b8c550c66cb6af35710fd9f92a20b4bfbf6"}, - {file = "onnxruntime-1.18.1-cp312-cp312-win32.whl", hash = "sha256:3a2d9ab6254ca62adbb448222e630dc6883210f718065063518c8f93a32432be"}, - {file = "onnxruntime-1.18.1-cp312-cp312-win_amd64.whl", hash = "sha256:ad93c560b1c38c27c0275ffd15cd7f45b3ad3fc96653c09ce2931179982ff204"}, - {file = "onnxruntime-1.18.1-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:3b55dc9d3c67626388958a3eb7ad87eb7c70f75cb0f7ff4908d27b8b42f2475c"}, - {file = "onnxruntime-1.18.1-cp38-cp38-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f80dbcfb6763cc0177a31168b29b4bd7662545b99a19e211de8c734b657e0669"}, - {file = "onnxruntime-1.18.1-cp38-cp38-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f1ff2c61a16d6c8631796c54139bafea41ee7736077a0fc64ee8ae59432f5c58"}, - {file = "onnxruntime-1.18.1-cp38-cp38-win32.whl", hash = "sha256:219855bd272fe0c667b850bf1a1a5a02499269a70d59c48e6f27f9c8bcb25d02"}, - {file = "onnxruntime-1.18.1-cp38-cp38-win_amd64.whl", hash = "sha256:afdf16aa607eb9a2c60d5ca2d5abf9f448e90c345b6b94c3ed14f4fb7e6a2d07"}, - {file = "onnxruntime-1.18.1-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:128df253ade673e60cea0955ec9d0e89617443a6d9ce47c2d79eb3f72a3be3de"}, - {file = "onnxruntime-1.18.1-cp39-cp39-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9839491e77e5c5a175cab3621e184d5a88925ee297ff4c311b68897197f4cde9"}, - {file = "onnxruntime-1.18.1-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ad3187c1faff3ac15f7f0e7373ef4788c582cafa655a80fdbb33eaec88976c66"}, - {file = "onnxruntime-1.18.1-cp39-cp39-win32.whl", hash = "sha256:34657c78aa4e0b5145f9188b550ded3af626651b15017bf43d280d7e23dbf195"}, - {file = "onnxruntime-1.18.1-cp39-cp39-win_amd64.whl", hash = "sha256:9c14fd97c3ddfa97da5feef595e2c73f14c2d0ec1d4ecbea99c8d96603c89589"}, + {file = "onnxruntime-1.19.0-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:6ce22a98dfec7b646ae305f52d0ce14a189a758b02ea501860ca719f4b0ae04b"}, + {file = "onnxruntime-1.19.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:19019c72873f26927aa322c54cf2bf7312b23451b27451f39b88f57016c94f8b"}, + {file = "onnxruntime-1.19.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8eaa16df99171dc636e30108d15597aed8c4c2dd9dbfdd07cc464d57d73fb275"}, + {file = "onnxruntime-1.19.0-cp310-cp310-win32.whl", hash = "sha256:0eb0f8dbe596fd0f4737fe511fdbb17603853a7d204c5b2ca38d3c7808fc556b"}, + {file = "onnxruntime-1.19.0-cp310-cp310-win_amd64.whl", hash = "sha256:616092d54ba8023b7bc0a5f6d900a07a37cc1cfcc631873c15f8c1d6e9e184d4"}, + {file = "onnxruntime-1.19.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:a2b53b3c287cd933e5eb597273926e899082d8c84ab96e1b34035764a1627e17"}, + {file = "onnxruntime-1.19.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e94984663963e74fbb468bde9ec6f19dcf890b594b35e249c4dc8789d08993c5"}, + {file = "onnxruntime-1.19.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6f379d1f050cfb55ce015d53727b78ee362febc065c38eed81512b22b757da73"}, + {file = "onnxruntime-1.19.0-cp311-cp311-win32.whl", hash = "sha256:4ccb48faea02503275ae7e79e351434fc43c294c4cb5c4d8bcb7479061396614"}, + {file = "onnxruntime-1.19.0-cp311-cp311-win_amd64.whl", hash = "sha256:9cdc8d311289a84e77722de68bd22b8adfb94eea26f4be6f9e017350faac8b18"}, + {file = "onnxruntime-1.19.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:1b59eaec1be9a8613c5fdeaafe67f73a062edce3ac03bbbdc9e2d98b58a30617"}, + {file = "onnxruntime-1.19.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:be4144d014a4b25184e63ce7a463a2e7796e2f3df931fccc6a6aefa6f1365dc5"}, + {file = "onnxruntime-1.19.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:10d7e7d4ca7021ce7f29a66dbc6071addf2de5839135339bd855c6d9c2bba371"}, + {file = "onnxruntime-1.19.0-cp312-cp312-win32.whl", hash = "sha256:87f2c58b577a1fb31dc5d92b647ecc588fd5f1ea0c3ad4526f5f80a113357c8d"}, + {file = "onnxruntime-1.19.0-cp312-cp312-win_amd64.whl", hash = "sha256:8a1f50d49676d7b69566536ff039d9e4e95fc482a55673719f46528218ecbb94"}, + {file = "onnxruntime-1.19.0-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:71423c8c4b2d7a58956271534302ec72721c62a41efd0c4896343249b8399ab0"}, + {file = "onnxruntime-1.19.0-cp38-cp38-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9d63630d45e9498f96e75bbeb7fd4a56acb10155de0de4d0e18d1b6cbb0b358a"}, + {file = "onnxruntime-1.19.0-cp38-cp38-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f3bfd15db1e8794d379a86c1a9116889f47f2cca40cc82208fc4f7e8c38e8522"}, + {file = "onnxruntime-1.19.0-cp38-cp38-win32.whl", hash = "sha256:3b098003b6b4cb37cc84942e5f1fe27f945dd857cbd2829c824c26b0ba4a247e"}, + {file = "onnxruntime-1.19.0-cp38-cp38-win_amd64.whl", hash = "sha256:cea067a6541d6787d903ee6843401c5b1332a266585160d9700f9f0939443886"}, + {file = "onnxruntime-1.19.0-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:c4fcff12dc5ca963c5f76b9822bb404578fa4a98c281e8c666b429192799a099"}, + {file = "onnxruntime-1.19.0-cp39-cp39-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f6dcad8a4db908fbe70b98c79cea1c8b6ac3316adf4ce93453136e33a524ac59"}, + {file = "onnxruntime-1.19.0-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4bc449907c6e8d99eee5ae5cc9c8fdef273d801dcd195393d3f9ab8ad3f49522"}, + {file = "onnxruntime-1.19.0-cp39-cp39-win32.whl", hash = "sha256:947febd48405afcf526e45ccff97ff23b15e530434705f734870d22ae7fcf236"}, + {file = "onnxruntime-1.19.0-cp39-cp39-win_amd64.whl", hash = "sha256:f60be47eff5ee77fd28a466b0fd41d7debc42a32179d1ddb21e05d6067d7b48b"}, ] [package.dependencies] coloredlogs = "*" flatbuffers = "*" -numpy = ">=1.21.6,<2.0" +numpy = ">=1.21.6" packaging = "*" protobuf = "*" sympy = "*" @@ -5261,6 +5425,65 @@ typing-extensions = ">=4.7,<5" [package.extras] datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] +[[package]] +name = "opencensus" +version = "0.11.4" +description = "A stats collection and distributed tracing framework" +optional = false +python-versions = "*" +files = [ + {file = "opencensus-0.11.4-py2.py3-none-any.whl", hash = "sha256:a18487ce68bc19900336e0ff4655c5a116daf10c1b3685ece8d971bddad6a864"}, + {file = "opencensus-0.11.4.tar.gz", hash = "sha256:cbef87d8b8773064ab60e5c2a1ced58bbaa38a6d052c41aec224958ce544eff2"}, +] + +[package.dependencies] +google-api-core = {version = ">=1.0.0,<3.0.0", markers = "python_version >= \"3.6\""} +opencensus-context = ">=0.1.3" +six = ">=1.16,<2.0" + +[[package]] +name = "opencensus-context" +version = "0.1.3" +description = "OpenCensus Runtime Context" +optional = false +python-versions = "*" +files = [ + {file = "opencensus-context-0.1.3.tar.gz", hash = "sha256:a03108c3c10d8c80bb5ddf5c8a1f033161fa61972a9917f9b9b3a18517f0088c"}, + {file = "opencensus_context-0.1.3-py2.py3-none-any.whl", hash = "sha256:073bb0590007af276853009fac7e4bab1d523c3f03baf4cb4511ca38967c6039"}, +] + +[[package]] +name = "opencensus-ext-azure" +version = "1.1.13" +description = "OpenCensus Azure Monitor Exporter" +optional = false +python-versions = "*" +files = [ + {file = "opencensus-ext-azure-1.1.13.tar.gz", hash = "sha256:aec30472177005379ba56a702a097d618c5f57558e1bb6676ec75f948130692a"}, + {file = "opencensus_ext_azure-1.1.13-py2.py3-none-any.whl", hash = "sha256:06001fac6f8588ba00726a3a7c6c7f2fc88bc8ad12a65afdca657923085393dd"}, +] + +[package.dependencies] +azure-core = ">=1.12.0,<2.0.0" +azure-identity = ">=1.5.0,<2.0.0" +opencensus = ">=0.11.4,<1.0.0" +psutil = ">=5.6.3" +requests = ">=2.19.0" + +[[package]] +name = "opencensus-ext-logging" +version = "0.1.1" +description = "OpenCensus logging Integration" +optional = false +python-versions = "*" +files = [ + {file = "opencensus-ext-logging-0.1.1.tar.gz", hash = "sha256:c203b70f034151dada529f543af330ba17aaffec27d8a5267d03c713eb1de334"}, + {file = "opencensus_ext_logging-0.1.1-py2.py3-none-any.whl", hash = "sha256:cfdaf5da5d8b195ff3d1af87a4066a6621a28046173f6be4b0b6caec4a3ca89f"}, +] + +[package.dependencies] +opencensus = ">=0.8.0,<1.0.0" + [[package]] name = "openpyxl" version = "3.1.5" @@ -5507,62 +5730,68 @@ cryptography = ">=3.2.1" [[package]] name = "orjson" -version = "3.10.6" +version = "3.10.7" description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy" optional = false python-versions = ">=3.8" files = [ - {file = "orjson-3.10.6-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:fb0ee33124db6eaa517d00890fc1a55c3bfe1cf78ba4a8899d71a06f2d6ff5c7"}, - {file = "orjson-3.10.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c1c4b53b24a4c06547ce43e5fee6ec4e0d8fe2d597f4647fc033fd205707365"}, - {file = "orjson-3.10.6-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eadc8fd310edb4bdbd333374f2c8fec6794bbbae99b592f448d8214a5e4050c0"}, - {file = "orjson-3.10.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:61272a5aec2b2661f4fa2b37c907ce9701e821b2c1285d5c3ab0207ebd358d38"}, - {file = "orjson-3.10.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:57985ee7e91d6214c837936dc1608f40f330a6b88bb13f5a57ce5257807da143"}, - {file = "orjson-3.10.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:633a3b31d9d7c9f02d49c4ab4d0a86065c4a6f6adc297d63d272e043472acab5"}, - {file = "orjson-3.10.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:1c680b269d33ec444afe2bdc647c9eb73166fa47a16d9a75ee56a374f4a45f43"}, - {file = "orjson-3.10.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f759503a97a6ace19e55461395ab0d618b5a117e8d0fbb20e70cfd68a47327f2"}, - {file = "orjson-3.10.6-cp310-none-win32.whl", hash = "sha256:95a0cce17f969fb5391762e5719575217bd10ac5a189d1979442ee54456393f3"}, - {file = "orjson-3.10.6-cp310-none-win_amd64.whl", hash = "sha256:df25d9271270ba2133cc88ee83c318372bdc0f2cd6f32e7a450809a111efc45c"}, - {file = "orjson-3.10.6-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:b1ec490e10d2a77c345def52599311849fc063ae0e67cf4f84528073152bb2ba"}, - {file = "orjson-3.10.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:55d43d3feb8f19d07e9f01e5b9be4f28801cf7c60d0fa0d279951b18fae1932b"}, - {file = "orjson-3.10.6-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ac3045267e98fe749408eee1593a142e02357c5c99be0802185ef2170086a863"}, - {file = "orjson-3.10.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c27bc6a28ae95923350ab382c57113abd38f3928af3c80be6f2ba7eb8d8db0b0"}, - {file = "orjson-3.10.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d27456491ca79532d11e507cadca37fb8c9324a3976294f68fb1eff2dc6ced5a"}, - {file = "orjson-3.10.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:05ac3d3916023745aa3b3b388e91b9166be1ca02b7c7e41045da6d12985685f0"}, - {file = "orjson-3.10.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1335d4ef59ab85cab66fe73fd7a4e881c298ee7f63ede918b7faa1b27cbe5212"}, - {file = "orjson-3.10.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4bbc6d0af24c1575edc79994c20e1b29e6fb3c6a570371306db0993ecf144dc5"}, - {file = "orjson-3.10.6-cp311-none-win32.whl", hash = "sha256:450e39ab1f7694465060a0550b3f6d328d20297bf2e06aa947b97c21e5241fbd"}, - {file = "orjson-3.10.6-cp311-none-win_amd64.whl", hash = "sha256:227df19441372610b20e05bdb906e1742ec2ad7a66ac8350dcfd29a63014a83b"}, - {file = "orjson-3.10.6-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:ea2977b21f8d5d9b758bb3f344a75e55ca78e3ff85595d248eee813ae23ecdfb"}, - {file = "orjson-3.10.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b6f3d167d13a16ed263b52dbfedff52c962bfd3d270b46b7518365bcc2121eed"}, - {file = "orjson-3.10.6-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f710f346e4c44a4e8bdf23daa974faede58f83334289df80bc9cd12fe82573c7"}, - {file = "orjson-3.10.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7275664f84e027dcb1ad5200b8b18373e9c669b2a9ec33d410c40f5ccf4b257e"}, - {file = "orjson-3.10.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0943e4c701196b23c240b3d10ed8ecd674f03089198cf503105b474a4f77f21f"}, - {file = "orjson-3.10.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:446dee5a491b5bc7d8f825d80d9637e7af43f86a331207b9c9610e2f93fee22a"}, - {file = "orjson-3.10.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:64c81456d2a050d380786413786b057983892db105516639cb5d3ee3c7fd5148"}, - {file = "orjson-3.10.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:960db0e31c4e52fa0fc3ecbaea5b2d3b58f379e32a95ae6b0ebeaa25b93dfd34"}, - {file = "orjson-3.10.6-cp312-none-win32.whl", hash = "sha256:a6ea7afb5b30b2317e0bee03c8d34c8181bc5a36f2afd4d0952f378972c4efd5"}, - {file = "orjson-3.10.6-cp312-none-win_amd64.whl", hash = "sha256:874ce88264b7e655dde4aeaacdc8fd772a7962faadfb41abe63e2a4861abc3dc"}, - {file = "orjson-3.10.6-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:66680eae4c4e7fc193d91cfc1353ad6d01b4801ae9b5314f17e11ba55e934183"}, - {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:caff75b425db5ef8e8f23af93c80f072f97b4fb3afd4af44482905c9f588da28"}, - {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3722fddb821b6036fd2a3c814f6bd9b57a89dc6337b9924ecd614ebce3271394"}, - {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c2c116072a8533f2fec435fde4d134610f806bdac20188c7bd2081f3e9e0133f"}, - {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6eeb13218c8cf34c61912e9df2de2853f1d009de0e46ea09ccdf3d757896af0a"}, - {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:965a916373382674e323c957d560b953d81d7a8603fbeee26f7b8248638bd48b"}, - {file = "orjson-3.10.6-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:03c95484d53ed8e479cade8628c9cea00fd9d67f5554764a1110e0d5aa2de96e"}, - {file = "orjson-3.10.6-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:e060748a04cccf1e0a6f2358dffea9c080b849a4a68c28b1b907f272b5127e9b"}, - {file = "orjson-3.10.6-cp38-none-win32.whl", hash = "sha256:738dbe3ef909c4b019d69afc19caf6b5ed0e2f1c786b5d6215fbb7539246e4c6"}, - {file = "orjson-3.10.6-cp38-none-win_amd64.whl", hash = "sha256:d40f839dddf6a7d77114fe6b8a70218556408c71d4d6e29413bb5f150a692ff7"}, - {file = "orjson-3.10.6-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:697a35a083c4f834807a6232b3e62c8b280f7a44ad0b759fd4dce748951e70db"}, - {file = "orjson-3.10.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd502f96bf5ea9a61cbc0b2b5900d0dd68aa0da197179042bdd2be67e51a1e4b"}, - {file = "orjson-3.10.6-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f215789fb1667cdc874c1b8af6a84dc939fd802bf293a8334fce185c79cd359b"}, - {file = "orjson-3.10.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a2debd8ddce948a8c0938c8c93ade191d2f4ba4649a54302a7da905a81f00b56"}, - {file = "orjson-3.10.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5410111d7b6681d4b0d65e0f58a13be588d01b473822483f77f513c7f93bd3b2"}, - {file = "orjson-3.10.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb1f28a137337fdc18384079fa5726810681055b32b92253fa15ae5656e1dddb"}, - {file = "orjson-3.10.6-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:bf2fbbce5fe7cd1aa177ea3eab2b8e6a6bc6e8592e4279ed3db2d62e57c0e1b2"}, - {file = "orjson-3.10.6-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:79b9b9e33bd4c517445a62b90ca0cc279b0f1f3970655c3df9e608bc3f91741a"}, - {file = "orjson-3.10.6-cp39-none-win32.whl", hash = "sha256:30b0a09a2014e621b1adf66a4f705f0809358350a757508ee80209b2d8dae219"}, - {file = "orjson-3.10.6-cp39-none-win_amd64.whl", hash = "sha256:49e3bc615652617d463069f91b867a4458114c5b104e13b7ae6872e5f79d0844"}, - {file = "orjson-3.10.6.tar.gz", hash = "sha256:e54b63d0a7c6c54a5f5f726bc93a2078111ef060fec4ecbf34c5db800ca3b3a7"}, + {file = "orjson-3.10.7-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:74f4544f5a6405b90da8ea724d15ac9c36da4d72a738c64685003337401f5c12"}, + {file = "orjson-3.10.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:34a566f22c28222b08875b18b0dfbf8a947e69df21a9ed5c51a6bf91cfb944ac"}, + {file = "orjson-3.10.7-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bf6ba8ebc8ef5792e2337fb0419f8009729335bb400ece005606336b7fd7bab7"}, + {file = "orjson-3.10.7-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ac7cf6222b29fbda9e3a472b41e6a5538b48f2c8f99261eecd60aafbdb60690c"}, + {file = "orjson-3.10.7-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:de817e2f5fc75a9e7dd350c4b0f54617b280e26d1631811a43e7e968fa71e3e9"}, + {file = "orjson-3.10.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:348bdd16b32556cf8d7257b17cf2bdb7ab7976af4af41ebe79f9796c218f7e91"}, + {file = "orjson-3.10.7-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:479fd0844ddc3ca77e0fd99644c7fe2de8e8be1efcd57705b5c92e5186e8a250"}, + {file = "orjson-3.10.7-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:fdf5197a21dd660cf19dfd2a3ce79574588f8f5e2dbf21bda9ee2d2b46924d84"}, + {file = "orjson-3.10.7-cp310-none-win32.whl", hash = "sha256:d374d36726746c81a49f3ff8daa2898dccab6596864ebe43d50733275c629175"}, + {file = "orjson-3.10.7-cp310-none-win_amd64.whl", hash = "sha256:cb61938aec8b0ffb6eef484d480188a1777e67b05d58e41b435c74b9d84e0b9c"}, + {file = "orjson-3.10.7-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:7db8539039698ddfb9a524b4dd19508256107568cdad24f3682d5773e60504a2"}, + {file = "orjson-3.10.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:480f455222cb7a1dea35c57a67578848537d2602b46c464472c995297117fa09"}, + {file = "orjson-3.10.7-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8a9c9b168b3a19e37fe2778c0003359f07822c90fdff8f98d9d2a91b3144d8e0"}, + {file = "orjson-3.10.7-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8de062de550f63185e4c1c54151bdddfc5625e37daf0aa1e75d2a1293e3b7d9a"}, + {file = "orjson-3.10.7-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6b0dd04483499d1de9c8f6203f8975caf17a6000b9c0c54630cef02e44ee624e"}, + {file = "orjson-3.10.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b58d3795dafa334fc8fd46f7c5dc013e6ad06fd5b9a4cc98cb1456e7d3558bd6"}, + {file = "orjson-3.10.7-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:33cfb96c24034a878d83d1a9415799a73dc77480e6c40417e5dda0710d559ee6"}, + {file = "orjson-3.10.7-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:e724cebe1fadc2b23c6f7415bad5ee6239e00a69f30ee423f319c6af70e2a5c0"}, + {file = "orjson-3.10.7-cp311-none-win32.whl", hash = "sha256:82763b46053727a7168d29c772ed5c870fdae2f61aa8a25994c7984a19b1021f"}, + {file = "orjson-3.10.7-cp311-none-win_amd64.whl", hash = "sha256:eb8d384a24778abf29afb8e41d68fdd9a156cf6e5390c04cc07bbc24b89e98b5"}, + {file = "orjson-3.10.7-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:44a96f2d4c3af51bfac6bc4ef7b182aa33f2f054fd7f34cc0ee9a320d051d41f"}, + {file = "orjson-3.10.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76ac14cd57df0572453543f8f2575e2d01ae9e790c21f57627803f5e79b0d3c3"}, + {file = "orjson-3.10.7-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bdbb61dcc365dd9be94e8f7df91975edc9364d6a78c8f7adb69c1cdff318ec93"}, + {file = "orjson-3.10.7-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b48b3db6bb6e0a08fa8c83b47bc169623f801e5cc4f24442ab2b6617da3b5313"}, + {file = "orjson-3.10.7-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:23820a1563a1d386414fef15c249040042b8e5d07b40ab3fe3efbfbbcbcb8864"}, + {file = "orjson-3.10.7-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0c6a008e91d10a2564edbb6ee5069a9e66df3fbe11c9a005cb411f441fd2c09"}, + {file = "orjson-3.10.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d352ee8ac1926d6193f602cbe36b1643bbd1bbcb25e3c1a657a4390f3000c9a5"}, + {file = "orjson-3.10.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:d2d9f990623f15c0ae7ac608103c33dfe1486d2ed974ac3f40b693bad1a22a7b"}, + {file = "orjson-3.10.7-cp312-none-win32.whl", hash = "sha256:7c4c17f8157bd520cdb7195f75ddbd31671997cbe10aee559c2d613592e7d7eb"}, + {file = "orjson-3.10.7-cp312-none-win_amd64.whl", hash = "sha256:1d9c0e733e02ada3ed6098a10a8ee0052dd55774de3d9110d29868d24b17faa1"}, + {file = "orjson-3.10.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:77d325ed866876c0fa6492598ec01fe30e803272a6e8b10e992288b009cbe149"}, + {file = "orjson-3.10.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ea2c232deedcb605e853ae1db2cc94f7390ac776743b699b50b071b02bea6fe"}, + {file = "orjson-3.10.7-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:3dcfbede6737fdbef3ce9c37af3fb6142e8e1ebc10336daa05872bfb1d87839c"}, + {file = "orjson-3.10.7-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:11748c135f281203f4ee695b7f80bb1358a82a63905f9f0b794769483ea854ad"}, + {file = "orjson-3.10.7-cp313-none-win32.whl", hash = "sha256:a7e19150d215c7a13f39eb787d84db274298d3f83d85463e61d277bbd7f401d2"}, + {file = "orjson-3.10.7-cp313-none-win_amd64.whl", hash = "sha256:eef44224729e9525d5261cc8d28d6b11cafc90e6bd0be2157bde69a52ec83024"}, + {file = "orjson-3.10.7-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:6ea2b2258eff652c82652d5e0f02bd5e0463a6a52abb78e49ac288827aaa1469"}, + {file = "orjson-3.10.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:430ee4d85841e1483d487e7b81401785a5dfd69db5de01314538f31f8fbf7ee1"}, + {file = "orjson-3.10.7-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4b6146e439af4c2472c56f8540d799a67a81226e11992008cb47e1267a9b3225"}, + {file = "orjson-3.10.7-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:084e537806b458911137f76097e53ce7bf5806dda33ddf6aaa66a028f8d43a23"}, + {file = "orjson-3.10.7-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4829cf2195838e3f93b70fd3b4292156fc5e097aac3739859ac0dcc722b27ac0"}, + {file = "orjson-3.10.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1193b2416cbad1a769f868b1749535d5da47626ac29445803dae7cc64b3f5c98"}, + {file = "orjson-3.10.7-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:4e6c3da13e5a57e4b3dca2de059f243ebec705857522f188f0180ae88badd354"}, + {file = "orjson-3.10.7-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:c31008598424dfbe52ce8c5b47e0752dca918a4fdc4a2a32004efd9fab41d866"}, + {file = "orjson-3.10.7-cp38-none-win32.whl", hash = "sha256:7122a99831f9e7fe977dc45784d3b2edc821c172d545e6420c375e5a935f5a1c"}, + {file = "orjson-3.10.7-cp38-none-win_amd64.whl", hash = "sha256:a763bc0e58504cc803739e7df040685816145a6f3c8a589787084b54ebc9f16e"}, + {file = "orjson-3.10.7-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:e76be12658a6fa376fcd331b1ea4e58f5a06fd0220653450f0d415b8fd0fbe20"}, + {file = "orjson-3.10.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ed350d6978d28b92939bfeb1a0570c523f6170efc3f0a0ef1f1df287cd4f4960"}, + {file = "orjson-3.10.7-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:144888c76f8520e39bfa121b31fd637e18d4cc2f115727865fdf9fa325b10412"}, + {file = "orjson-3.10.7-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:09b2d92fd95ad2402188cf51573acde57eb269eddabaa60f69ea0d733e789fe9"}, + {file = "orjson-3.10.7-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5b24a579123fa884f3a3caadaed7b75eb5715ee2b17ab5c66ac97d29b18fe57f"}, + {file = "orjson-3.10.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e72591bcfe7512353bd609875ab38050efe3d55e18934e2f18950c108334b4ff"}, + {file = "orjson-3.10.7-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:f4db56635b58cd1a200b0a23744ff44206ee6aa428185e2b6c4a65b3197abdcd"}, + {file = "orjson-3.10.7-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:0fa5886854673222618638c6df7718ea7fe2f3f2384c452c9ccedc70b4a510a5"}, + {file = "orjson-3.10.7-cp39-none-win32.whl", hash = "sha256:8272527d08450ab16eb405f47e0f4ef0e5ff5981c3d82afe0efd25dcbef2bcd2"}, + {file = "orjson-3.10.7-cp39-none-win_amd64.whl", hash = "sha256:974683d4618c0c7dbf4f69c95a979734bf183d0658611760017f6e70a145af58"}, + {file = "orjson-3.10.7.tar.gz", hash = "sha256:75ef0640403f945f3a1f9f6400686560dbfb0fb5b16589ad62cd477043c4eee3"}, ] [[package]] @@ -5907,13 +6136,13 @@ tests = ["pytest (>=5.4.1)", "pytest-cov (>=2.8.1)", "pytest-mypy (>=0.8.0)", "p [[package]] name = "posthog" -version = "3.5.0" +version = "3.5.2" description = "Integrate PostHog into any python application." optional = false python-versions = "*" files = [ - {file = "posthog-3.5.0-py2.py3-none-any.whl", hash = "sha256:3c672be7ba6f95d555ea207d4486c171d06657eb34b3ce25eb043bfe7b6b5b76"}, - {file = "posthog-3.5.0.tar.gz", hash = "sha256:8f7e3b2c6e8714d0c0c542a2109b83a7549f63b7113a133ab2763a89245ef2ef"}, + {file = "posthog-3.5.2-py2.py3-none-any.whl", hash = "sha256:605b3d92369971cc99290b1fcc8534cbddac3726ef7972caa993454a5ecfb644"}, + {file = "posthog-3.5.2.tar.gz", hash = "sha256:a383a80c1f47e0243f5ce359e81e06e2e7b37eb39d1d6f8d01c3e64ed29df2ee"}, ] [package.dependencies] @@ -5930,23 +6159,23 @@ test = ["coverage", "flake8", "freezegun (==0.3.15)", "mock (>=2.0.0)", "pylint" [[package]] name = "primp" -version = "0.5.5" +version = "0.6.1" description = "HTTP client that can impersonate web browsers, mimicking their headers and `TLS/JA3/JA4/HTTP2` fingerprints" optional = false python-versions = ">=3.8" files = [ - {file = "primp-0.5.5-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:cff9792e8422424528c23574b5364882d68134ee2743f4a2ae6a765746fb3028"}, - {file = "primp-0.5.5-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:78e13fc5d4d90d44a005dbd5dda116981828c803c86cf85816b3bb5363b045c8"}, - {file = "primp-0.5.5-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3714abfda79d3f5c90a5363db58994afbdbacc4b94fe14e9e5f8ab97e7b82577"}, - {file = "primp-0.5.5-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e54765900ee40eceb6bde43676d7e0b2e16ca1f77c0753981fe5e40afc0c2010"}, - {file = "primp-0.5.5-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:66c7eecc5a55225c42cfb99af857df04f994f3dd0d327c016d3af5414c1a2242"}, - {file = "primp-0.5.5-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:df262271cc1a41f4bf80d68396e967a27d7d3d3de355a3d016f953130e7a20be"}, - {file = "primp-0.5.5-cp38-abi3-win_amd64.whl", hash = "sha256:8b424118d6bab6f9d4980d0f35d5ccc1213ab9f1042497c6ee11730f2f94a876"}, - {file = "primp-0.5.5.tar.gz", hash = "sha256:8623e8a25fd686785296b12175f4173250a08db1de9ee4063282e262b94bf3f2"}, + {file = "primp-0.6.1-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:60cfe95e0bdf154b0f9036d38acaddc9aef02d6723ed125839b01449672d3946"}, + {file = "primp-0.6.1-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:e1e92433ecf32639f9e800bc3a5d58b03792bdec99421b7fb06500e2fae63c85"}, + {file = "primp-0.6.1-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e02353f13f07fb5a6f91df9e2f4d8ec9f41312de95088744dce1c9729a3865d"}, + {file = "primp-0.6.1-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:c5a2ccfdf488b17be225a529a31e2b22724b2e22fba8e1ae168a222f857c2dc0"}, + {file = "primp-0.6.1-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:f335c2ace907800a23bbb7bc6e15acc7fff659b86a2d5858817f6ed79cea07cf"}, + {file = "primp-0.6.1-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5dc15bd9d47ded7bc356fcb5d8321972dcbeba18e7d3b7250e12bb7365447b2b"}, + {file = "primp-0.6.1-cp38-abi3-win_amd64.whl", hash = "sha256:eebf0412ebba4089547b16b97b765d83f69f1433d811bb02b02cdcdbca20f672"}, + {file = "primp-0.6.1.tar.gz", hash = "sha256:64b3c12e3d463a887518811c46f3ec37cca02e6af1ddf1287e548342de436301"}, ] [package.extras] -dev = ["pytest (>=8.1.1)"] +dev = ["certifi", "pytest (>=8.1.1)"] [[package]] name = "prompt-toolkit" @@ -5999,6 +6228,35 @@ files = [ {file = "protobuf-4.25.4.tar.gz", hash = "sha256:0dc4a62cc4052a036ee2204d26fe4d835c62827c855c8a03f29fe6da146b380d"}, ] +[[package]] +name = "psutil" +version = "6.0.0" +description = "Cross-platform lib for process and system monitoring in Python." +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +files = [ + {file = "psutil-6.0.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:a021da3e881cd935e64a3d0a20983bda0bb4cf80e4f74fa9bfcb1bc5785360c6"}, + {file = "psutil-6.0.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:1287c2b95f1c0a364d23bc6f2ea2365a8d4d9b726a3be7294296ff7ba97c17f0"}, + {file = "psutil-6.0.0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:a9a3dbfb4de4f18174528d87cc352d1f788b7496991cca33c6996f40c9e3c92c"}, + {file = "psutil-6.0.0-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:6ec7588fb3ddaec7344a825afe298db83fe01bfaaab39155fa84cf1c0d6b13c3"}, + {file = "psutil-6.0.0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:1e7c870afcb7d91fdea2b37c24aeb08f98b6d67257a5cb0a8bc3ac68d0f1a68c"}, + {file = "psutil-6.0.0-cp27-none-win32.whl", hash = "sha256:02b69001f44cc73c1c5279d02b30a817e339ceb258ad75997325e0e6169d8b35"}, + {file = "psutil-6.0.0-cp27-none-win_amd64.whl", hash = "sha256:21f1fb635deccd510f69f485b87433460a603919b45e2a324ad65b0cc74f8fb1"}, + {file = "psutil-6.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2e8d0054fc88153ca0544f5c4d554d42e33df2e009c4ff42284ac9ebdef4132"}, + {file = "psutil-6.0.0-cp36-cp36m-win32.whl", hash = "sha256:fc8c9510cde0146432bbdb433322861ee8c3efbf8589865c8bf8d21cb30c4d14"}, + {file = "psutil-6.0.0-cp36-cp36m-win_amd64.whl", hash = "sha256:34859b8d8f423b86e4385ff3665d3f4d94be3cdf48221fbe476e883514fdb71c"}, + {file = "psutil-6.0.0-cp37-abi3-win32.whl", hash = "sha256:a495580d6bae27291324fe60cea0b5a7c23fa36a7cd35035a16d93bdcf076b9d"}, + {file = "psutil-6.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:33ea5e1c975250a720b3a6609c490db40dae5d83a4eb315170c4fe0d8b1f34b3"}, + {file = "psutil-6.0.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:ffe7fc9b6b36beadc8c322f84e1caff51e8703b88eee1da46d1e3a6ae11b4fd0"}, + {file = "psutil-6.0.0.tar.gz", hash = "sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2"}, +] + +[package.extras] +test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] + [[package]] name = "psycopg2-binary" version = "2.9.9" @@ -6363,13 +6621,13 @@ semver = ["semver (>=3.0.2)"] [[package]] name = "pydantic-settings" -version = "2.3.4" +version = "2.4.0" description = "Settings management using Pydantic" optional = false python-versions = ">=3.8" files = [ - {file = "pydantic_settings-2.3.4-py3-none-any.whl", hash = "sha256:11ad8bacb68a045f00e4f862c7a718c8a9ec766aa8fd4c32e39a0594b207b53a"}, - {file = "pydantic_settings-2.3.4.tar.gz", hash = "sha256:c5802e3d62b78e82522319bbc9b8f8ffb28ad1c988a99311d04f2a6051fca0a7"}, + {file = "pydantic_settings-2.4.0-py3-none-any.whl", hash = "sha256:bb6849dc067f1687574c12a639e231f3a6feeed0a12d710c1382045c5db1c315"}, + {file = "pydantic_settings-2.4.0.tar.gz", hash = "sha256:ed81c3a0f46392b4d7c0a565c05884e6e54b3456e6f0fe4d8814981172dc9a88"}, ] [package.dependencies] @@ -6377,9 +6635,27 @@ pydantic = ">=2.7.0" python-dotenv = ">=0.21.0" [package.extras] +azure-key-vault = ["azure-identity (>=1.16.0)", "azure-keyvault-secrets (>=4.8.0)"] toml = ["tomli (>=2.0.1)"] yaml = ["pyyaml (>=6.0.1)"] +[[package]] +name = "pydash" +version = "8.0.3" +description = "The kitchen sink of Python utility libraries for doing \"stuff\" in a functional way. Based on the Lo-Dash Javascript library." +optional = false +python-versions = ">=3.8" +files = [ + {file = "pydash-8.0.3-py3-none-any.whl", hash = "sha256:c16871476822ee6b59b87e206dd27888240eff50a7b4cd72a4b80b43b6b994d7"}, + {file = "pydash-8.0.3.tar.gz", hash = "sha256:1b27cd3da05b72f0e5ff786c523afd82af796936462e631ffd1b228d91f8b9aa"}, +] + +[package.dependencies] +typing-extensions = ">3.10,<4.6.0 || >4.6.0" + +[package.extras] +dev = ["build", "coverage", "furo", "invoke", "mypy", "pytest", "pytest-cov", "pytest-mypy-testing", "ruff", "sphinx", "sphinx-autodoc-typehints", "tox", "twine", "wheel"] + [[package]] name = "pygments" version = "2.18.0" @@ -6416,13 +6692,13 @@ tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] [[package]] name = "pymilvus" -version = "2.4.4" +version = "2.4.5" description = "Python Sdk for Milvus" optional = false python-versions = ">=3.8" files = [ - {file = "pymilvus-2.4.4-py3-none-any.whl", hash = "sha256:073b76bc36f6f4e70f0f0a0023a53324f0ba8ef9a60883f87cd30a44b6c6f2b5"}, - {file = "pymilvus-2.4.4.tar.gz", hash = "sha256:50c53eb103e034fbffe936fe942751ea3dbd2452e18cf79acc52360ed4987fb7"}, + {file = "pymilvus-2.4.5-py3-none-any.whl", hash = "sha256:dc4f2d1eac8db9cf3951de39566a1a244695760bb94d8310fbfc73d6d62bb267"}, + {file = "pymilvus-2.4.5.tar.gz", hash = "sha256:1a497fe9b41d6bf62b1d5e1c412960922dde1598576fcbb8818040c8af11149f"}, ] [package.dependencies] @@ -6431,7 +6707,7 @@ grpcio = ">=1.49.1,<=1.63.0" milvus-lite = {version = ">=2.4.0,<2.5.0", markers = "sys_platform != \"win32\""} pandas = ">=1.2.4" protobuf = ">=3.20.0" -setuptools = ">=67" +setuptools = ">69" ujson = ">=2.0.0" [package.extras] @@ -6545,13 +6821,13 @@ files = [ [[package]] name = "pytest" -version = "8.1.2" +version = "8.3.2" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.8" files = [ - {file = "pytest-8.1.2-py3-none-any.whl", hash = "sha256:6c06dc309ff46a05721e6fd48e492a775ed8165d2ecdf57f156a80c7e95bb142"}, - {file = "pytest-8.1.2.tar.gz", hash = "sha256:f3c45d1d5eed96b01a2aea70dee6a4a366d51d38f9957768083e4fecfc77f3ef"}, + {file = "pytest-8.3.2-py3-none-any.whl", hash = "sha256:4ba08f9ae7dcf84ded419494d229b48d0903ea6407b030eaec46df5e6a73bba5"}, + {file = "pytest-8.3.2.tar.gz", hash = "sha256:c132345d12ce551242c87269de812483f5bcc87cdbb4722e48487ba194f9fdce"}, ] [package.dependencies] @@ -6559,11 +6835,11 @@ colorama = {version = "*", markers = "sys_platform == \"win32\""} exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} iniconfig = "*" packaging = "*" -pluggy = ">=1.4,<2.0" +pluggy = ">=1.5,<2" tomli = {version = ">=1", markers = "python_version < \"3.11\""} [package.extras] -testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] [[package]] name = "pytest-benchmark" @@ -7147,6 +7423,21 @@ hiredis = {version = ">1.0.0", optional = true, markers = "extra == \"hiredis\"" hiredis = ["hiredis (>1.0.0)"] ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"] +[[package]] +name = "referencing" +version = "0.35.1" +description = "JSON Referencing + Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "referencing-0.35.1-py3-none-any.whl", hash = "sha256:eda6d3234d62814d1c64e305c1331c9a3a6132da475ab6382eaa997b21ee75de"}, + {file = "referencing-0.35.1.tar.gz", hash = "sha256:25b42124a6c8b632a425174f24087783efb348a6f1e0008e63cd4466fedf703c"}, +] + +[package.dependencies] +attrs = ">=22.2.0" +rpds-py = ">=0.7.0" + [[package]] name = "regex" version = "2024.7.24" @@ -7354,6 +7645,118 @@ pygments = ">=2.13.0,<3.0.0" [package.extras] jupyter = ["ipywidgets (>=7.5.1,<9)"] +[[package]] +name = "rpds-py" +version = "0.20.0" +description = "Python bindings to Rust's persistent data structures (rpds)" +optional = false +python-versions = ">=3.8" +files = [ + {file = "rpds_py-0.20.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3ad0fda1635f8439cde85c700f964b23ed5fc2d28016b32b9ee5fe30da5c84e2"}, + {file = "rpds_py-0.20.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9bb4a0d90fdb03437c109a17eade42dfbf6190408f29b2744114d11586611d6f"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6377e647bbfd0a0b159fe557f2c6c602c159fc752fa316572f012fc0bf67150"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb851b7df9dda52dc1415ebee12362047ce771fc36914586b2e9fcbd7d293b3e"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1e0f80b739e5a8f54837be5d5c924483996b603d5502bfff79bf33da06164ee2"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5a8c94dad2e45324fc74dce25e1645d4d14df9a4e54a30fa0ae8bad9a63928e3"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8e604fe73ba048c06085beaf51147eaec7df856824bfe7b98657cf436623daf"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:df3de6b7726b52966edf29663e57306b23ef775faf0ac01a3e9f4012a24a4140"}, + {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:cf258ede5bc22a45c8e726b29835b9303c285ab46fc7c3a4cc770736b5304c9f"}, + {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:55fea87029cded5df854ca7e192ec7bdb7ecd1d9a3f63d5c4eb09148acf4a7ce"}, + {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ae94bd0b2f02c28e199e9bc51485d0c5601f58780636185660f86bf80c89af94"}, + {file = "rpds_py-0.20.0-cp310-none-win32.whl", hash = "sha256:28527c685f237c05445efec62426d285e47a58fb05ba0090a4340b73ecda6dee"}, + {file = "rpds_py-0.20.0-cp310-none-win_amd64.whl", hash = "sha256:238a2d5b1cad28cdc6ed15faf93a998336eb041c4e440dd7f902528b8891b399"}, + {file = "rpds_py-0.20.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:ac2f4f7a98934c2ed6505aead07b979e6f999389f16b714448fb39bbaa86a489"}, + {file = "rpds_py-0.20.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:220002c1b846db9afd83371d08d239fdc865e8f8c5795bbaec20916a76db3318"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8d7919548df3f25374a1f5d01fbcd38dacab338ef5f33e044744b5c36729c8db"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:758406267907b3781beee0f0edfe4a179fbd97c0be2e9b1154d7f0a1279cf8e5"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3d61339e9f84a3f0767b1995adfb171a0d00a1185192718a17af6e124728e0f5"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1259c7b3705ac0a0bd38197565a5d603218591d3f6cee6e614e380b6ba61c6f6"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c1dc0f53856b9cc9a0ccca0a7cc61d3d20a7088201c0937f3f4048c1718a209"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7e60cb630f674a31f0368ed32b2a6b4331b8350d67de53c0359992444b116dd3"}, + {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:dbe982f38565bb50cb7fb061ebf762c2f254ca3d8c20d4006878766e84266272"}, + {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:514b3293b64187172bc77c8fb0cdae26981618021053b30d8371c3a902d4d5ad"}, + {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d0a26ffe9d4dd35e4dfdd1e71f46401cff0181c75ac174711ccff0459135fa58"}, + {file = "rpds_py-0.20.0-cp311-none-win32.whl", hash = "sha256:89c19a494bf3ad08c1da49445cc5d13d8fefc265f48ee7e7556839acdacf69d0"}, + {file = "rpds_py-0.20.0-cp311-none-win_amd64.whl", hash = "sha256:c638144ce971df84650d3ed0096e2ae7af8e62ecbbb7b201c8935c370df00a2c"}, + {file = "rpds_py-0.20.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:a84ab91cbe7aab97f7446652d0ed37d35b68a465aeef8fc41932a9d7eee2c1a6"}, + {file = "rpds_py-0.20.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:56e27147a5a4c2c21633ff8475d185734c0e4befd1c989b5b95a5d0db699b21b"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2580b0c34583b85efec8c5c5ec9edf2dfe817330cc882ee972ae650e7b5ef739"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b80d4a7900cf6b66bb9cee5c352b2d708e29e5a37fe9bf784fa97fc11504bf6c"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:50eccbf054e62a7b2209b28dc7a22d6254860209d6753e6b78cfaeb0075d7bee"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:49a8063ea4296b3a7e81a5dfb8f7b2d73f0b1c20c2af401fb0cdf22e14711a96"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea438162a9fcbee3ecf36c23e6c68237479f89f962f82dae83dc15feeceb37e4"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:18d7585c463087bddcfa74c2ba267339f14f2515158ac4db30b1f9cbdb62c8ef"}, + {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d4c7d1a051eeb39f5c9547e82ea27cbcc28338482242e3e0b7768033cb083821"}, + {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:e4df1e3b3bec320790f699890d41c59d250f6beda159ea3c44c3f5bac1976940"}, + {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2cf126d33a91ee6eedc7f3197b53e87a2acdac63602c0f03a02dd69e4b138174"}, + {file = "rpds_py-0.20.0-cp312-none-win32.whl", hash = "sha256:8bc7690f7caee50b04a79bf017a8d020c1f48c2a1077ffe172abec59870f1139"}, + {file = "rpds_py-0.20.0-cp312-none-win_amd64.whl", hash = "sha256:0e13e6952ef264c40587d510ad676a988df19adea20444c2b295e536457bc585"}, + {file = "rpds_py-0.20.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:aa9a0521aeca7d4941499a73ad7d4f8ffa3d1affc50b9ea11d992cd7eff18a29"}, + {file = "rpds_py-0.20.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4a1f1d51eccb7e6c32ae89243cb352389228ea62f89cd80823ea7dd1b98e0b91"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8a86a9b96070674fc88b6f9f71a97d2c1d3e5165574615d1f9168ecba4cecb24"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6c8ef2ebf76df43f5750b46851ed1cdf8f109d7787ca40035fe19fbdc1acc5a7"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b74b25f024b421d5859d156750ea9a65651793d51b76a2e9238c05c9d5f203a9"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:57eb94a8c16ab08fef6404301c38318e2c5a32216bf5de453e2714c964c125c8"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1940dae14e715e2e02dfd5b0f64a52e8374a517a1e531ad9412319dc3ac7879"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d20277fd62e1b992a50c43f13fbe13277a31f8c9f70d59759c88f644d66c619f"}, + {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:06db23d43f26478303e954c34c75182356ca9aa7797d22c5345b16871ab9c45c"}, + {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b2a5db5397d82fa847e4c624b0c98fe59d2d9b7cf0ce6de09e4d2e80f8f5b3f2"}, + {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5a35df9f5548fd79cb2f52d27182108c3e6641a4feb0f39067911bf2adaa3e57"}, + {file = "rpds_py-0.20.0-cp313-none-win32.whl", hash = "sha256:fd2d84f40633bc475ef2d5490b9c19543fbf18596dcb1b291e3a12ea5d722f7a"}, + {file = "rpds_py-0.20.0-cp313-none-win_amd64.whl", hash = "sha256:9bc2d153989e3216b0559251b0c260cfd168ec78b1fac33dd485750a228db5a2"}, + {file = "rpds_py-0.20.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:f2fbf7db2012d4876fb0d66b5b9ba6591197b0f165db8d99371d976546472a24"}, + {file = "rpds_py-0.20.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1e5f3cd7397c8f86c8cc72d5a791071431c108edd79872cdd96e00abd8497d29"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce9845054c13696f7af7f2b353e6b4f676dab1b4b215d7fe5e05c6f8bb06f965"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c3e130fd0ec56cb76eb49ef52faead8ff09d13f4527e9b0c400307ff72b408e1"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4b16aa0107ecb512b568244ef461f27697164d9a68d8b35090e9b0c1c8b27752"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:aa7f429242aae2947246587d2964fad750b79e8c233a2367f71b554e9447949c"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af0fc424a5842a11e28956e69395fbbeab2c97c42253169d87e90aac2886d751"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b8c00a3b1e70c1d3891f0db1b05292747f0dbcfb49c43f9244d04c70fbc40eb8"}, + {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:40ce74fc86ee4645d0a225498d091d8bc61f39b709ebef8204cb8b5a464d3c0e"}, + {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:4fe84294c7019456e56d93e8ababdad5a329cd25975be749c3f5f558abb48253"}, + {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:338ca4539aad4ce70a656e5187a3a31c5204f261aef9f6ab50e50bcdffaf050a"}, + {file = "rpds_py-0.20.0-cp38-none-win32.whl", hash = "sha256:54b43a2b07db18314669092bb2de584524d1ef414588780261e31e85846c26a5"}, + {file = "rpds_py-0.20.0-cp38-none-win_amd64.whl", hash = "sha256:a1862d2d7ce1674cffa6d186d53ca95c6e17ed2b06b3f4c476173565c862d232"}, + {file = "rpds_py-0.20.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:3fde368e9140312b6e8b6c09fb9f8c8c2f00999d1823403ae90cc00480221b22"}, + {file = "rpds_py-0.20.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9824fb430c9cf9af743cf7aaf6707bf14323fb51ee74425c380f4c846ea70789"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:11ef6ce74616342888b69878d45e9f779b95d4bd48b382a229fe624a409b72c5"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c52d3f2f82b763a24ef52f5d24358553e8403ce05f893b5347098014f2d9eff2"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9d35cef91e59ebbeaa45214861874bc6f19eb35de96db73e467a8358d701a96c"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d72278a30111e5b5525c1dd96120d9e958464316f55adb030433ea905866f4de"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4c29cbbba378759ac5786730d1c3cb4ec6f8ababf5c42a9ce303dc4b3d08cda"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6632f2d04f15d1bd6fe0eedd3b86d9061b836ddca4c03d5cf5c7e9e6b7c14580"}, + {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:d0b67d87bb45ed1cd020e8fbf2307d449b68abc45402fe1a4ac9e46c3c8b192b"}, + {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:ec31a99ca63bf3cd7f1a5ac9fe95c5e2d060d3c768a09bc1d16e235840861420"}, + {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:22e6c9976e38f4d8c4a63bd8a8edac5307dffd3ee7e6026d97f3cc3a2dc02a0b"}, + {file = "rpds_py-0.20.0-cp39-none-win32.whl", hash = "sha256:569b3ea770c2717b730b61998b6c54996adee3cef69fc28d444f3e7920313cf7"}, + {file = "rpds_py-0.20.0-cp39-none-win_amd64.whl", hash = "sha256:e6900ecdd50ce0facf703f7a00df12374b74bbc8ad9fe0f6559947fb20f82364"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:617c7357272c67696fd052811e352ac54ed1d9b49ab370261a80d3b6ce385045"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:9426133526f69fcaba6e42146b4e12d6bc6c839b8b555097020e2b78ce908dcc"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:deb62214c42a261cb3eb04d474f7155279c1a8a8c30ac89b7dcb1721d92c3c02"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fcaeb7b57f1a1e071ebd748984359fef83ecb026325b9d4ca847c95bc7311c92"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d454b8749b4bd70dd0a79f428731ee263fa6995f83ccb8bada706e8d1d3ff89d"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d807dc2051abe041b6649681dce568f8e10668e3c1c6543ebae58f2d7e617855"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3c20f0ddeb6e29126d45f89206b8291352b8c5b44384e78a6499d68b52ae511"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b7f19250ceef892adf27f0399b9e5afad019288e9be756d6919cb58892129f51"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:4f1ed4749a08379555cebf4650453f14452eaa9c43d0a95c49db50c18b7da075"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:dcedf0b42bcb4cfff4101d7771a10532415a6106062f005ab97d1d0ab5681c60"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:39ed0d010457a78f54090fafb5d108501b5aa5604cc22408fc1c0c77eac14344"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:bb273176be34a746bdac0b0d7e4e2c467323d13640b736c4c477881a3220a989"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f918a1a130a6dfe1d7fe0f105064141342e7dd1611f2e6a21cd2f5c8cb1cfb3e"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:f60012a73aa396be721558caa3a6fd49b3dd0033d1675c6d59c4502e870fcf0c"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d2b1ad682a3dfda2a4e8ad8572f3100f95fad98cb99faf37ff0ddfe9cbf9d03"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:614fdafe9f5f19c63ea02817fa4861c606a59a604a77c8cdef5aa01d28b97921"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fa518bcd7600c584bf42e6617ee8132869e877db2f76bcdc281ec6a4113a53ab"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f0475242f447cc6cb8a9dd486d68b2ef7fbee84427124c232bff5f63b1fe11e5"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f90a4cd061914a60bd51c68bcb4357086991bd0bb93d8aa66a6da7701370708f"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:def7400461c3a3f26e49078302e1c1b38f6752342c77e3cf72ce91ca69fb1bc1"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:65794e4048ee837494aea3c21a28ad5fc080994dfba5b036cf84de37f7ad5074"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:faefcc78f53a88f3076b7f8be0a8f8d35133a3ecf7f3770895c25f8813460f08"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:5b4f105deeffa28bbcdff6c49b34e74903139afa690e35d2d9e3c2c2fba18cec"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:fdfc3a892927458d98f3d55428ae46b921d1f7543b89382fdb483f5640daaec8"}, + {file = "rpds_py-0.20.0.tar.gz", hash = "sha256:d72a210824facfdaf8768cf2d7ca25a042c30320b3020de2fa04640920d4e121"}, +] + [[package]] name = "rsa" version = "4.9" @@ -7591,36 +7994,44 @@ tests = ["black (>=24.3.0)", "matplotlib (>=3.3.4)", "mypy (>=1.9)", "numpydoc ( [[package]] name = "scipy" -version = "1.14.0" +version = "1.14.1" description = "Fundamental algorithms for scientific computing in Python" optional = false python-versions = ">=3.10" files = [ - {file = "scipy-1.14.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7e911933d54ead4d557c02402710c2396529540b81dd554fc1ba270eb7308484"}, - {file = "scipy-1.14.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:687af0a35462402dd851726295c1a5ae5f987bd6e9026f52e9505994e2f84ef6"}, - {file = "scipy-1.14.0-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:07e179dc0205a50721022344fb85074f772eadbda1e1b3eecdc483f8033709b7"}, - {file = "scipy-1.14.0-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:6a9c9a9b226d9a21e0a208bdb024c3982932e43811b62d202aaf1bb59af264b1"}, - {file = "scipy-1.14.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:076c27284c768b84a45dcf2e914d4000aac537da74236a0d45d82c6fa4b7b3c0"}, - {file = "scipy-1.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42470ea0195336df319741e230626b6225a740fd9dce9642ca13e98f667047c0"}, - {file = "scipy-1.14.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:176c6f0d0470a32f1b2efaf40c3d37a24876cebf447498a4cefb947a79c21e9d"}, - {file = "scipy-1.14.0-cp310-cp310-win_amd64.whl", hash = "sha256:ad36af9626d27a4326c8e884917b7ec321d8a1841cd6dacc67d2a9e90c2f0359"}, - {file = "scipy-1.14.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6d056a8709ccda6cf36cdd2eac597d13bc03dba38360f418560a93050c76a16e"}, - {file = "scipy-1.14.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:f0a50da861a7ec4573b7c716b2ebdcdf142b66b756a0d392c236ae568b3a93fb"}, - {file = "scipy-1.14.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:94c164a9e2498e68308e6e148646e486d979f7fcdb8b4cf34b5441894bdb9caf"}, - {file = "scipy-1.14.0-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:a7d46c3e0aea5c064e734c3eac5cf9eb1f8c4ceee756262f2c7327c4c2691c86"}, - {file = "scipy-1.14.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9eee2989868e274aae26125345584254d97c56194c072ed96cb433f32f692ed8"}, - {file = "scipy-1.14.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e3154691b9f7ed73778d746da2df67a19d046a6c8087c8b385bc4cdb2cfca74"}, - {file = "scipy-1.14.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c40003d880f39c11c1edbae8144e3813904b10514cd3d3d00c277ae996488cdb"}, - {file = "scipy-1.14.0-cp311-cp311-win_amd64.whl", hash = "sha256:5b083c8940028bb7e0b4172acafda6df762da1927b9091f9611b0bcd8676f2bc"}, - {file = "scipy-1.14.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:bff2438ea1330e06e53c424893ec0072640dac00f29c6a43a575cbae4c99b2b9"}, - {file = "scipy-1.14.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:bbc0471b5f22c11c389075d091d3885693fd3f5e9a54ce051b46308bc787e5d4"}, - {file = "scipy-1.14.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:64b2ff514a98cf2bb734a9f90d32dc89dc6ad4a4a36a312cd0d6327170339eb0"}, - {file = "scipy-1.14.0-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:7d3da42fbbbb860211a811782504f38ae7aaec9de8764a9bef6b262de7a2b50f"}, - {file = "scipy-1.14.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d91db2c41dd6c20646af280355d41dfa1ec7eead235642178bd57635a3f82209"}, - {file = "scipy-1.14.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a01cc03bcdc777c9da3cfdcc74b5a75caffb48a6c39c8450a9a05f82c4250a14"}, - {file = "scipy-1.14.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:65df4da3c12a2bb9ad52b86b4dcf46813e869afb006e58be0f516bc370165159"}, - {file = "scipy-1.14.0-cp312-cp312-win_amd64.whl", hash = "sha256:4c4161597c75043f7154238ef419c29a64ac4a7c889d588ea77690ac4d0d9b20"}, - {file = "scipy-1.14.0.tar.gz", hash = "sha256:b5923f48cb840380f9854339176ef21763118a7300a88203ccd0bdd26e58527b"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3"}, + {file = "scipy-1.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d"}, + {file = "scipy-1.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69"}, + {file = "scipy-1.14.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad"}, + {file = "scipy-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8"}, + {file = "scipy-1.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37"}, + {file = "scipy-1.14.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2"}, + {file = "scipy-1.14.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2"}, + {file = "scipy-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc"}, + {file = "scipy-1.14.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310"}, + {file = "scipy-1.14.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066"}, + {file = "scipy-1.14.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1"}, + {file = "scipy-1.14.1-cp312-cp312-win_amd64.whl", hash = "sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e"}, + {file = "scipy-1.14.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d"}, + {file = "scipy-1.14.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e"}, + {file = "scipy-1.14.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06"}, + {file = "scipy-1.14.1-cp313-cp313-win_amd64.whl", hash = "sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84"}, + {file = "scipy-1.14.1.tar.gz", hash = "sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417"}, ] [package.dependencies] @@ -7628,8 +8039,8 @@ numpy = ">=1.23.5,<2.3" [package.extras] dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy (==1.10.0)", "pycodestyle", "pydevtool", "rich-click", "ruff (>=0.0.292)", "types-psutil", "typing_extensions"] -doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.13.1)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"] -test = ["Cython", "array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] +doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.13.1)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0,<=7.3.7)", "sphinx-design (>=0.4.0)"] +test = ["Cython", "array-api-strict (>=2.0)", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] [[package]] name = "sentry-sdk" @@ -7683,19 +8094,19 @@ tornado = ["tornado (>=5)"] [[package]] name = "setuptools" -version = "72.1.0" +version = "73.0.1" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-72.1.0-py3-none-any.whl", hash = "sha256:5a03e1860cf56bb6ef48ce186b0e557fdba433237481a9a625176c2831be15d1"}, - {file = "setuptools-72.1.0.tar.gz", hash = "sha256:8d243eff56d095e5817f796ede6ae32941278f542e0f941867cc05ae52b162ec"}, + {file = "setuptools-73.0.1-py3-none-any.whl", hash = "sha256:b208925fcb9f7af924ed2dc04708ea89791e24bde0d3020b27df0e116088b34e"}, + {file = "setuptools-73.0.1.tar.gz", hash = "sha256:d59a3e788ab7e012ab2c4baed1b376da6366883ee20d7a5fc426816e3d7b1193"}, ] [package.extras] -core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.text (>=3.7)", "more-itertools (>=8.8)", "ordered-set (>=3.1.1)", "packaging (>=24)", "platformdirs (>=2.6.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] -doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] -test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "mypy (==1.11.*)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (<0.4)", "pytest-ruff (>=0.2.1)", "pytest-ruff (>=0.3.2)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] +core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.text (>=3.7)", "more-itertools (>=8.8)", "packaging (>=24)", "platformdirs (>=2.6.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] +test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "mypy (==1.11.*)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (<0.4)", "pytest-ruff (>=0.2.1)", "pytest-ruff (>=0.3.2)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] [[package]] name = "sgmllib3k" @@ -7709,47 +8120,53 @@ files = [ [[package]] name = "shapely" -version = "2.0.5" +version = "2.0.6" description = "Manipulation and analysis of geometric objects" optional = false python-versions = ">=3.7" files = [ - {file = "shapely-2.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:89d34787c44f77a7d37d55ae821f3a784fa33592b9d217a45053a93ade899375"}, - {file = "shapely-2.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:798090b426142df2c5258779c1d8d5734ec6942f778dab6c6c30cfe7f3bf64ff"}, - {file = "shapely-2.0.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45211276900c4790d6bfc6105cbf1030742da67594ea4161a9ce6812a6721e68"}, - {file = "shapely-2.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e119444bc27ca33e786772b81760f2028d930ac55dafe9bc50ef538b794a8e1"}, - {file = "shapely-2.0.5-cp310-cp310-win32.whl", hash = "sha256:9a4492a2b2ccbeaebf181e7310d2dfff4fdd505aef59d6cb0f217607cb042fb3"}, - {file = "shapely-2.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:1e5cb5ee72f1bc7ace737c9ecd30dc174a5295fae412972d3879bac2e82c8fae"}, - {file = "shapely-2.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5bbfb048a74cf273db9091ff3155d373020852805a37dfc846ab71dde4be93ec"}, - {file = "shapely-2.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:93be600cbe2fbaa86c8eb70656369f2f7104cd231f0d6585c7d0aa555d6878b8"}, - {file = "shapely-2.0.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f8e71bb9a46814019f6644c4e2560a09d44b80100e46e371578f35eaaa9da1c"}, - {file = "shapely-2.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5251c28a29012e92de01d2e84f11637eb1d48184ee8f22e2df6c8c578d26760"}, - {file = "shapely-2.0.5-cp311-cp311-win32.whl", hash = "sha256:35110e80070d664781ec7955c7de557456b25727a0257b354830abb759bf8311"}, - {file = "shapely-2.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:6c6b78c0007a34ce7144f98b7418800e0a6a5d9a762f2244b00ea560525290c9"}, - {file = "shapely-2.0.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:03bd7b5fa5deb44795cc0a503999d10ae9d8a22df54ae8d4a4cd2e8a93466195"}, - {file = "shapely-2.0.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ff9521991ed9e201c2e923da014e766c1aa04771bc93e6fe97c27dcf0d40ace"}, - {file = "shapely-2.0.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b65365cfbf657604e50d15161ffcc68de5cdb22a601bbf7823540ab4918a98d"}, - {file = "shapely-2.0.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21f64e647a025b61b19585d2247137b3a38a35314ea68c66aaf507a1c03ef6fe"}, - {file = "shapely-2.0.5-cp312-cp312-win32.whl", hash = "sha256:3ac7dc1350700c139c956b03d9c3df49a5b34aaf91d024d1510a09717ea39199"}, - {file = "shapely-2.0.5-cp312-cp312-win_amd64.whl", hash = "sha256:30e8737983c9d954cd17feb49eb169f02f1da49e24e5171122cf2c2b62d65c95"}, - {file = "shapely-2.0.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:ff7731fea5face9ec08a861ed351734a79475631b7540ceb0b66fb9732a5f529"}, - {file = "shapely-2.0.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ff9e520af0c5a578e174bca3c18713cd47a6c6a15b6cf1f50ac17dc8bb8db6a2"}, - {file = "shapely-2.0.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49b299b91557b04acb75e9732645428470825061f871a2edc36b9417d66c1fc5"}, - {file = "shapely-2.0.5-cp37-cp37m-win32.whl", hash = "sha256:b5870633f8e684bf6d1ae4df527ddcb6f3895f7b12bced5c13266ac04f47d231"}, - {file = "shapely-2.0.5-cp37-cp37m-win_amd64.whl", hash = "sha256:401cb794c5067598f50518e5a997e270cd7642c4992645479b915c503866abed"}, - {file = "shapely-2.0.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e91ee179af539100eb520281ba5394919067c6b51824e6ab132ad4b3b3e76dd0"}, - {file = "shapely-2.0.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8af6f7260f809c0862741ad08b1b89cb60c130ae30efab62320bbf4ee9cc71fa"}, - {file = "shapely-2.0.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f5456dd522800306ba3faef77c5ba847ec30a0bd73ab087a25e0acdd4db2514f"}, - {file = "shapely-2.0.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b714a840402cde66fd7b663bb08cacb7211fa4412ea2a209688f671e0d0631fd"}, - {file = "shapely-2.0.5-cp38-cp38-win32.whl", hash = "sha256:7e8cf5c252fac1ea51b3162be2ec3faddedc82c256a1160fc0e8ddbec81b06d2"}, - {file = "shapely-2.0.5-cp38-cp38-win_amd64.whl", hash = "sha256:4461509afdb15051e73ab178fae79974387f39c47ab635a7330d7fee02c68a3f"}, - {file = "shapely-2.0.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7545a39c55cad1562be302d74c74586f79e07b592df8ada56b79a209731c0219"}, - {file = "shapely-2.0.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4c83a36f12ec8dee2066946d98d4d841ab6512a6ed7eb742e026a64854019b5f"}, - {file = "shapely-2.0.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89e640c2cd37378480caf2eeda9a51be64201f01f786d127e78eaeff091ec897"}, - {file = "shapely-2.0.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06efe39beafde3a18a21dde169d32f315c57da962826a6d7d22630025200c5e6"}, - {file = "shapely-2.0.5-cp39-cp39-win32.whl", hash = "sha256:8203a8b2d44dcb366becbc8c3d553670320e4acf0616c39e218c9561dd738d92"}, - {file = "shapely-2.0.5-cp39-cp39-win_amd64.whl", hash = "sha256:7fed9dbfbcfec2682d9a047b9699db8dcc890dfca857ecba872c42185fc9e64e"}, - {file = "shapely-2.0.5.tar.gz", hash = "sha256:bff2366bc786bfa6cb353d6b47d0443c570c32776612e527ee47b6df63fcfe32"}, + {file = "shapely-2.0.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:29a34e068da2d321e926b5073539fd2a1d4429a2c656bd63f0bd4c8f5b236d0b"}, + {file = "shapely-2.0.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e1c84c3f53144febf6af909d6b581bc05e8785d57e27f35ebaa5c1ab9baba13b"}, + {file = "shapely-2.0.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ad2fae12dca8d2b727fa12b007e46fbc522148a584f5d6546c539f3464dccde"}, + {file = "shapely-2.0.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3304883bd82d44be1b27a9d17f1167fda8c7f5a02a897958d86c59ec69b705e"}, + {file = "shapely-2.0.6-cp310-cp310-win32.whl", hash = "sha256:3ec3a0eab496b5e04633a39fa3d5eb5454628228201fb24903d38174ee34565e"}, + {file = "shapely-2.0.6-cp310-cp310-win_amd64.whl", hash = "sha256:28f87cdf5308a514763a5c38de295544cb27429cfa655d50ed8431a4796090c4"}, + {file = "shapely-2.0.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5aeb0f51a9db176da9a30cb2f4329b6fbd1e26d359012bb0ac3d3c7781667a9e"}, + {file = "shapely-2.0.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9a7a78b0d51257a367ee115f4d41ca4d46edbd0dd280f697a8092dd3989867b2"}, + {file = "shapely-2.0.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f32c23d2f43d54029f986479f7c1f6e09c6b3a19353a3833c2ffb226fb63a855"}, + {file = "shapely-2.0.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3dc9fb0eb56498912025f5eb352b5126f04801ed0e8bdbd867d21bdbfd7cbd0"}, + {file = "shapely-2.0.6-cp311-cp311-win32.whl", hash = "sha256:d93b7e0e71c9f095e09454bf18dad5ea716fb6ced5df3cb044564a00723f339d"}, + {file = "shapely-2.0.6-cp311-cp311-win_amd64.whl", hash = "sha256:c02eb6bf4cfb9fe6568502e85bb2647921ee49171bcd2d4116c7b3109724ef9b"}, + {file = "shapely-2.0.6-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cec9193519940e9d1b86a3b4f5af9eb6910197d24af02f247afbfb47bcb3fab0"}, + {file = "shapely-2.0.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:83b94a44ab04a90e88be69e7ddcc6f332da7c0a0ebb1156e1c4f568bbec983c3"}, + {file = "shapely-2.0.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:537c4b2716d22c92036d00b34aac9d3775e3691f80c7aa517c2c290351f42cd8"}, + {file = "shapely-2.0.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98fea108334be345c283ce74bf064fa00cfdd718048a8af7343c59eb40f59726"}, + {file = "shapely-2.0.6-cp312-cp312-win32.whl", hash = "sha256:42fd4cd4834747e4990227e4cbafb02242c0cffe9ce7ef9971f53ac52d80d55f"}, + {file = "shapely-2.0.6-cp312-cp312-win_amd64.whl", hash = "sha256:665990c84aece05efb68a21b3523a6b2057e84a1afbef426ad287f0796ef8a48"}, + {file = "shapely-2.0.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:42805ef90783ce689a4dde2b6b2f261e2c52609226a0438d882e3ced40bb3013"}, + {file = "shapely-2.0.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6d2cb146191a47bd0cee8ff5f90b47547b82b6345c0d02dd8b25b88b68af62d7"}, + {file = "shapely-2.0.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e3fdef0a1794a8fe70dc1f514440aa34426cc0ae98d9a1027fb299d45741c381"}, + {file = "shapely-2.0.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c665a0301c645615a107ff7f52adafa2153beab51daf34587170d85e8ba6805"}, + {file = "shapely-2.0.6-cp313-cp313-win32.whl", hash = "sha256:0334bd51828f68cd54b87d80b3e7cee93f249d82ae55a0faf3ea21c9be7b323a"}, + {file = "shapely-2.0.6-cp313-cp313-win_amd64.whl", hash = "sha256:d37d070da9e0e0f0a530a621e17c0b8c3c9d04105655132a87cfff8bd77cc4c2"}, + {file = "shapely-2.0.6-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:fa7468e4f5b92049c0f36d63c3e309f85f2775752e076378e36c6387245c5462"}, + {file = "shapely-2.0.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ed5867e598a9e8ac3291da6cc9baa62ca25706eea186117034e8ec0ea4355653"}, + {file = "shapely-2.0.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81d9dfe155f371f78c8d895a7b7f323bb241fb148d848a2bf2244f79213123fe"}, + {file = "shapely-2.0.6-cp37-cp37m-win32.whl", hash = "sha256:fbb7bf02a7542dba55129062570211cfb0defa05386409b3e306c39612e7fbcc"}, + {file = "shapely-2.0.6-cp37-cp37m-win_amd64.whl", hash = "sha256:837d395fac58aa01aa544495b97940995211e3e25f9aaf87bc3ba5b3a8cd1ac7"}, + {file = "shapely-2.0.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c6d88ade96bf02f6bfd667ddd3626913098e243e419a0325ebef2bbd481d1eb6"}, + {file = "shapely-2.0.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8b3b818c4407eaa0b4cb376fd2305e20ff6df757bf1356651589eadc14aab41b"}, + {file = "shapely-2.0.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1bbc783529a21f2bd50c79cef90761f72d41c45622b3e57acf78d984c50a5d13"}, + {file = "shapely-2.0.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2423f6c0903ebe5df6d32e0066b3d94029aab18425ad4b07bf98c3972a6e25a1"}, + {file = "shapely-2.0.6-cp38-cp38-win32.whl", hash = "sha256:2de00c3bfa80d6750832bde1d9487e302a6dd21d90cb2f210515cefdb616e5f5"}, + {file = "shapely-2.0.6-cp38-cp38-win_amd64.whl", hash = "sha256:3a82d58a1134d5e975f19268710e53bddd9c473743356c90d97ce04b73e101ee"}, + {file = "shapely-2.0.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:392f66f458a0a2c706254f473290418236e52aa4c9b476a072539d63a2460595"}, + {file = "shapely-2.0.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:eba5bae271d523c938274c61658ebc34de6c4b33fdf43ef7e938b5776388c1be"}, + {file = "shapely-2.0.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7060566bc4888b0c8ed14b5d57df8a0ead5c28f9b69fb6bed4476df31c51b0af"}, + {file = "shapely-2.0.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b02154b3e9d076a29a8513dffcb80f047a5ea63c897c0cd3d3679f29363cf7e5"}, + {file = "shapely-2.0.6-cp39-cp39-win32.whl", hash = "sha256:44246d30124a4f1a638a7d5419149959532b99dfa25b54393512e6acc9c211ac"}, + {file = "shapely-2.0.6-cp39-cp39-win_amd64.whl", hash = "sha256:2b542d7f1dbb89192d3512c52b679c822ba916f93479fa5d4fc2fe4fa0b3c9e8"}, + {file = "shapely-2.0.6.tar.gz", hash = "sha256:997f6159b1484059ec239cacaa53467fd8b5564dabe186cd84ac2944663b0bf6"}, ] [package.dependencies] @@ -7822,13 +8239,13 @@ files = [ [[package]] name = "soupsieve" -version = "2.5" +version = "2.6" description = "A modern CSS selector implementation for Beautiful Soup." optional = false python-versions = ">=3.8" files = [ - {file = "soupsieve-2.5-py3-none-any.whl", hash = "sha256:eaa337ff55a1579b6549dc679565eac1e3d000563bcb1c8ab0d0fefbc0c2cdc7"}, - {file = "soupsieve-2.5.tar.gz", hash = "sha256:5663d5a7b3bfaeee0bc4372e7fc48f9cff4940b3eec54a6451cc5299f1097690"}, + {file = "soupsieve-2.6-py3-none-any.whl", hash = "sha256:e72c4ff06e4fb6e4b5a9f0f55fe6e81514581fca1515028625d0f299c602ccc9"}, + {file = "soupsieve-2.6.tar.gz", hash = "sha256:e2e68417777af359ec65daac1057404a3c8a5455bb8abc36f1a9866ab1a51abb"}, ] [[package]] @@ -7935,13 +8352,13 @@ doc = ["sphinx"] [[package]] name = "starlette" -version = "0.37.2" +version = "0.38.2" description = "The little ASGI library that shines." optional = false python-versions = ">=3.8" files = [ - {file = "starlette-0.37.2-py3-none-any.whl", hash = "sha256:6fe59f29268538e5d0d182f2791a479a0c64638e6935d1c6989e63fb2699c6ee"}, - {file = "starlette-0.37.2.tar.gz", hash = "sha256:9af890290133b79fc3db55474ade20f6220a364a0402e0b556e7cd5e1e093823"}, + {file = "starlette-0.38.2-py3-none-any.whl", hash = "sha256:4ec6a59df6bbafdab5f567754481657f7ed90dc9d69b0c9ff017907dd54faeff"}, + {file = "starlette-0.38.2.tar.gz", hash = "sha256:c7c0441065252160993a1a37cf2a73bb64d271b17303e0b0c1eb7191cfb12d75"}, ] [package.dependencies] @@ -7950,15 +8367,29 @@ anyio = ">=3.4.0,<5" [package.extras] full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"] +[[package]] +name = "strictyaml" +version = "1.7.3" +description = "Strict, typed YAML parser" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "strictyaml-1.7.3-py3-none-any.whl", hash = "sha256:fb5c8a4edb43bebb765959e420f9b3978d7f1af88c80606c03fb420888f5d1c7"}, + {file = "strictyaml-1.7.3.tar.gz", hash = "sha256:22f854a5fcab42b5ddba8030a0e4be51ca89af0267961c8d6cfa86395586c407"}, +] + +[package.dependencies] +python-dateutil = ">=2.6.0" + [[package]] name = "sympy" -version = "1.13.1" +version = "1.13.2" description = "Computer algebra system (CAS) in Python" optional = false python-versions = ">=3.8" files = [ - {file = "sympy-1.13.1-py3-none-any.whl", hash = "sha256:db36cdc64bf61b9b24578b6f7bab1ecdd2452cf008f34faa33776680c26d66f8"}, - {file = "sympy-1.13.1.tar.gz", hash = "sha256:9cebf7e04ff162015ce31c9c6c9144daa34a93bd082f54fd8f12deca4f47515f"}, + {file = "sympy-1.13.2-py3-none-any.whl", hash = "sha256:c51d75517712f1aed280d4ce58506a4a88d635d6b5dd48b39102a7ae1f3fcfe9"}, + {file = "sympy-1.13.2.tar.gz", hash = "sha256:401449d84d07be9d0c7a46a64bd54fe097667d5e7181bfe67ec777be9e01cb13"}, ] [package.dependencies] @@ -8013,13 +8444,13 @@ test = ["pytest", "tornado (>=4.5)", "typeguard"] [[package]] name = "tencentcloud-sdk-python-common" -version = "3.0.1206" +version = "3.0.1216" description = "Tencent Cloud Common SDK for Python" optional = false python-versions = "*" files = [ - {file = "tencentcloud-sdk-python-common-3.0.1206.tar.gz", hash = "sha256:e32745e6d46b94b2c2c33cd68c7e70bff3d63e8e5e5d314bb0b41616521c90f2"}, - {file = "tencentcloud_sdk_python_common-3.0.1206-py2.py3-none-any.whl", hash = "sha256:2100697933d62135b093bae43eee0f8862b45ca0597da72779e304c9b392ac96"}, + {file = "tencentcloud-sdk-python-common-3.0.1216.tar.gz", hash = "sha256:7ad83b100574068fe25439fe47fd27253ff1c730348a309567a7ff88eda63cf8"}, + {file = "tencentcloud_sdk_python_common-3.0.1216-py2.py3-none-any.whl", hash = "sha256:5e1cf9b685923d567d379f96a7008084006ad68793cfa0a0524e65dc59fd09d7"}, ] [package.dependencies] @@ -8027,17 +8458,17 @@ requests = ">=2.16.0" [[package]] name = "tencentcloud-sdk-python-hunyuan" -version = "3.0.1206" +version = "3.0.1216" description = "Tencent Cloud Hunyuan SDK for Python" optional = false python-versions = "*" files = [ - {file = "tencentcloud-sdk-python-hunyuan-3.0.1206.tar.gz", hash = "sha256:2c37f2f50e54d23905d91d7a511a217317d944c701127daae548b7275cc32968"}, - {file = "tencentcloud_sdk_python_hunyuan-3.0.1206-py2.py3-none-any.whl", hash = "sha256:c650315bb5863f28d410fa1062122550d8015600947d04d95e2bff55d0590acc"}, + {file = "tencentcloud-sdk-python-hunyuan-3.0.1216.tar.gz", hash = "sha256:b295d67f97dba52ed358a1d9e061f94b1a4a87e45714efbf0987edab12642206"}, + {file = "tencentcloud_sdk_python_hunyuan-3.0.1216-py2.py3-none-any.whl", hash = "sha256:62d925b41424017929b532389061a076dca72dde455e85ec089947645010e691"}, ] [package.dependencies] -tencentcloud-sdk-python-common = "3.0.1206" +tencentcloud-sdk-python-common = "3.0.1216" [[package]] name = "threadpoolctl" @@ -8406,13 +8837,13 @@ requests = ">=2.0.0" [[package]] name = "typer" -version = "0.12.3" +version = "0.12.4" description = "Typer, build great CLIs. Easy to code. Based on Python type hints." optional = false python-versions = ">=3.7" files = [ - {file = "typer-0.12.3-py3-none-any.whl", hash = "sha256:070d7ca53f785acbccba8e7d28b08dcd88f79f1fbda035ade0aecec71ca5c914"}, - {file = "typer-0.12.3.tar.gz", hash = "sha256:49e73131481d804288ef62598d97a1ceef3058905aa536a1134f90891ba35482"}, + {file = "typer-0.12.4-py3-none-any.whl", hash = "sha256:819aa03699f438397e876aa12b0d63766864ecba1b579092cc9fe35d886e34b6"}, + {file = "typer-0.12.4.tar.gz", hash = "sha256:c9c1613ed6a166162705b3347b8d10b661ccc5d95692654d0fb628118f2c34e6"}, ] [package.dependencies] @@ -8689,13 +9120,13 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "uvicorn" -version = "0.30.5" +version = "0.30.6" description = "The lightning-fast ASGI server." optional = false python-versions = ">=3.8" files = [ - {file = "uvicorn-0.30.5-py3-none-any.whl", hash = "sha256:b2d86de274726e9878188fa07576c9ceeff90a839e2b6e25c917fe05f5a6c835"}, - {file = "uvicorn-0.30.5.tar.gz", hash = "sha256:ac6fdbd4425c5fd17a9fe39daf4d4d075da6fdc80f653e5894cdc2fd98752bee"}, + {file = "uvicorn-0.30.6-py3-none-any.whl", hash = "sha256:65fd46fe3fda5bdc1b03b94eb634923ff18cd35b2f084813ea79d1f103f711b5"}, + {file = "uvicorn-0.30.6.tar.gz", hash = "sha256:4b15decdda1e72be08209e860a1e10e92439ad5b97cf44cc945fcbee66fc5788"}, ] [package.dependencies] @@ -8715,42 +9146,42 @@ standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", [[package]] name = "uvloop" -version = "0.19.0" +version = "0.20.0" description = "Fast implementation of asyncio event loop on top of libuv" optional = false python-versions = ">=3.8.0" files = [ - {file = "uvloop-0.19.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:de4313d7f575474c8f5a12e163f6d89c0a878bc49219641d49e6f1444369a90e"}, - {file = "uvloop-0.19.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5588bd21cf1fcf06bded085f37e43ce0e00424197e7c10e77afd4bbefffef428"}, - {file = "uvloop-0.19.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b1fd71c3843327f3bbc3237bedcdb6504fd50368ab3e04d0410e52ec293f5b8"}, - {file = "uvloop-0.19.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5a05128d315e2912791de6088c34136bfcdd0c7cbc1cf85fd6fd1bb321b7c849"}, - {file = "uvloop-0.19.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:cd81bdc2b8219cb4b2556eea39d2e36bfa375a2dd021404f90a62e44efaaf957"}, - {file = "uvloop-0.19.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5f17766fb6da94135526273080f3455a112f82570b2ee5daa64d682387fe0dcd"}, - {file = "uvloop-0.19.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:4ce6b0af8f2729a02a5d1575feacb2a94fc7b2e983868b009d51c9a9d2149bef"}, - {file = "uvloop-0.19.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:31e672bb38b45abc4f26e273be83b72a0d28d074d5b370fc4dcf4c4eb15417d2"}, - {file = "uvloop-0.19.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:570fc0ed613883d8d30ee40397b79207eedd2624891692471808a95069a007c1"}, - {file = "uvloop-0.19.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5138821e40b0c3e6c9478643b4660bd44372ae1e16a322b8fc07478f92684e24"}, - {file = "uvloop-0.19.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:91ab01c6cd00e39cde50173ba4ec68a1e578fee9279ba64f5221810a9e786533"}, - {file = "uvloop-0.19.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:47bf3e9312f63684efe283f7342afb414eea4d3011542155c7e625cd799c3b12"}, - {file = "uvloop-0.19.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:da8435a3bd498419ee8c13c34b89b5005130a476bda1d6ca8cfdde3de35cd650"}, - {file = "uvloop-0.19.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:02506dc23a5d90e04d4f65c7791e65cf44bd91b37f24cfc3ef6cf2aff05dc7ec"}, - {file = "uvloop-0.19.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2693049be9d36fef81741fddb3f441673ba12a34a704e7b4361efb75cf30befc"}, - {file = "uvloop-0.19.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7010271303961c6f0fe37731004335401eb9075a12680738731e9c92ddd96ad6"}, - {file = "uvloop-0.19.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:5daa304d2161d2918fa9a17d5635099a2f78ae5b5960e742b2fcfbb7aefaa593"}, - {file = "uvloop-0.19.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:7207272c9520203fea9b93843bb775d03e1cf88a80a936ce760f60bb5add92f3"}, - {file = "uvloop-0.19.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:78ab247f0b5671cc887c31d33f9b3abfb88d2614b84e4303f1a63b46c046c8bd"}, - {file = "uvloop-0.19.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:472d61143059c84947aa8bb74eabbace30d577a03a1805b77933d6bd13ddebbd"}, - {file = "uvloop-0.19.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45bf4c24c19fb8a50902ae37c5de50da81de4922af65baf760f7c0c42e1088be"}, - {file = "uvloop-0.19.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:271718e26b3e17906b28b67314c45d19106112067205119dddbd834c2b7ce797"}, - {file = "uvloop-0.19.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:34175c9fd2a4bc3adc1380e1261f60306344e3407c20a4d684fd5f3be010fa3d"}, - {file = "uvloop-0.19.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e27f100e1ff17f6feeb1f33968bc185bf8ce41ca557deee9d9bbbffeb72030b7"}, - {file = "uvloop-0.19.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:13dfdf492af0aa0a0edf66807d2b465607d11c4fa48f4a1fd41cbea5b18e8e8b"}, - {file = "uvloop-0.19.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6e3d4e85ac060e2342ff85e90d0c04157acb210b9ce508e784a944f852a40e67"}, - {file = "uvloop-0.19.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8ca4956c9ab567d87d59d49fa3704cf29e37109ad348f2d5223c9bf761a332e7"}, - {file = "uvloop-0.19.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f467a5fd23b4fc43ed86342641f3936a68ded707f4627622fa3f82a120e18256"}, - {file = "uvloop-0.19.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:492e2c32c2af3f971473bc22f086513cedfc66a130756145a931a90c3958cb17"}, - {file = "uvloop-0.19.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2df95fca285a9f5bfe730e51945ffe2fa71ccbfdde3b0da5772b4ee4f2e770d5"}, - {file = "uvloop-0.19.0.tar.gz", hash = "sha256:0246f4fd1bf2bf702e06b0d45ee91677ee5c31242f39aab4ea6fe0c51aedd0fd"}, + {file = "uvloop-0.20.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:9ebafa0b96c62881d5cafa02d9da2e44c23f9f0cd829f3a32a6aff771449c996"}, + {file = "uvloop-0.20.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:35968fc697b0527a06e134999eef859b4034b37aebca537daeb598b9d45a137b"}, + {file = "uvloop-0.20.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b16696f10e59d7580979b420eedf6650010a4a9c3bd8113f24a103dfdb770b10"}, + {file = "uvloop-0.20.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9b04d96188d365151d1af41fa2d23257b674e7ead68cfd61c725a422764062ae"}, + {file = "uvloop-0.20.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:94707205efbe809dfa3a0d09c08bef1352f5d3d6612a506f10a319933757c006"}, + {file = "uvloop-0.20.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:89e8d33bb88d7263f74dc57d69f0063e06b5a5ce50bb9a6b32f5fcbe655f9e73"}, + {file = "uvloop-0.20.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e50289c101495e0d1bb0bfcb4a60adde56e32f4449a67216a1ab2750aa84f037"}, + {file = "uvloop-0.20.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e237f9c1e8a00e7d9ddaa288e535dc337a39bcbf679f290aee9d26df9e72bce9"}, + {file = "uvloop-0.20.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:746242cd703dc2b37f9d8b9f173749c15e9a918ddb021575a0205ec29a38d31e"}, + {file = "uvloop-0.20.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:82edbfd3df39fb3d108fc079ebc461330f7c2e33dbd002d146bf7c445ba6e756"}, + {file = "uvloop-0.20.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:80dc1b139516be2077b3e57ce1cb65bfed09149e1d175e0478e7a987863b68f0"}, + {file = "uvloop-0.20.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4f44af67bf39af25db4c1ac27e82e9665717f9c26af2369c404be865c8818dcf"}, + {file = "uvloop-0.20.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:4b75f2950ddb6feed85336412b9a0c310a2edbcf4cf931aa5cfe29034829676d"}, + {file = "uvloop-0.20.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:77fbc69c287596880ecec2d4c7a62346bef08b6209749bf6ce8c22bbaca0239e"}, + {file = "uvloop-0.20.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6462c95f48e2d8d4c993a2950cd3d31ab061864d1c226bbf0ee2f1a8f36674b9"}, + {file = "uvloop-0.20.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:649c33034979273fa71aa25d0fe120ad1777c551d8c4cd2c0c9851d88fcb13ab"}, + {file = "uvloop-0.20.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3a609780e942d43a275a617c0839d85f95c334bad29c4c0918252085113285b5"}, + {file = "uvloop-0.20.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:aea15c78e0d9ad6555ed201344ae36db5c63d428818b4b2a42842b3870127c00"}, + {file = "uvloop-0.20.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:f0e94b221295b5e69de57a1bd4aeb0b3a29f61be6e1b478bb8a69a73377db7ba"}, + {file = "uvloop-0.20.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:fee6044b64c965c425b65a4e17719953b96e065c5b7e09b599ff332bb2744bdf"}, + {file = "uvloop-0.20.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:265a99a2ff41a0fd56c19c3838b29bf54d1d177964c300dad388b27e84fd7847"}, + {file = "uvloop-0.20.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b10c2956efcecb981bf9cfb8184d27d5d64b9033f917115a960b83f11bfa0d6b"}, + {file = "uvloop-0.20.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:e7d61fe8e8d9335fac1bf8d5d82820b4808dd7a43020c149b63a1ada953d48a6"}, + {file = "uvloop-0.20.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2beee18efd33fa6fdb0976e18475a4042cd31c7433c866e8a09ab604c7c22ff2"}, + {file = "uvloop-0.20.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:d8c36fdf3e02cec92aed2d44f63565ad1522a499c654f07935c8f9d04db69e95"}, + {file = "uvloop-0.20.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a0fac7be202596c7126146660725157d4813aa29a4cc990fe51346f75ff8fde7"}, + {file = "uvloop-0.20.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9d0fba61846f294bce41eb44d60d58136090ea2b5b99efd21cbdf4e21927c56a"}, + {file = "uvloop-0.20.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95720bae002ac357202e0d866128eb1ac82545bcf0b549b9abe91b5178d9b541"}, + {file = "uvloop-0.20.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:36c530d8fa03bfa7085af54a48f2ca16ab74df3ec7108a46ba82fd8b411a2315"}, + {file = "uvloop-0.20.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e97152983442b499d7a71e44f29baa75b3b02e65d9c44ba53b10338e98dedb66"}, + {file = "uvloop-0.20.0.tar.gz", hash = "sha256:4603ca714a754fc8d9b197e325db25b2ea045385e8a3ad05d3463de725fdf469"}, ] [package.extras] @@ -8830,88 +9261,122 @@ files = [ {file = "vine-5.1.0.tar.gz", hash = "sha256:8b62e981d35c41049211cf62a0a1242d8c1ee9bd15bb196ce38aefd6799e61e0"}, ] +[[package]] +name = "volcengine-python-sdk" +version = "1.0.98" +description = "Volcengine SDK for Python" +optional = false +python-versions = "*" +files = [ + {file = "volcengine-python-sdk-1.0.98.tar.gz", hash = "sha256:1515e8d46cdcda387f9b45abbcaf0b04b982f7be68068de83f1e388281441784"}, +] + +[package.dependencies] +anyio = {version = ">=3.5.0,<5", optional = true, markers = "extra == \"ark\""} +certifi = ">=2017.4.17" +httpx = {version = ">=0.23.0,<1", optional = true, markers = "extra == \"ark\""} +pydantic = {version = ">=1.9.0,<3", optional = true, markers = "extra == \"ark\""} +python-dateutil = ">=2.1" +six = ">=1.10" +urllib3 = ">=1.23" + +[package.extras] +ark = ["anyio (>=3.5.0,<5)", "cached-property", "httpx (>=0.23.0,<1)", "pydantic (>=1.9.0,<3)"] + [[package]] name = "watchfiles" -version = "0.22.0" +version = "0.23.0" description = "Simple, modern and high performance file watching and code reload in python." optional = false python-versions = ">=3.8" files = [ - {file = "watchfiles-0.22.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:da1e0a8caebf17976e2ffd00fa15f258e14749db5e014660f53114b676e68538"}, - {file = "watchfiles-0.22.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:61af9efa0733dc4ca462347becb82e8ef4945aba5135b1638bfc20fad64d4f0e"}, - {file = "watchfiles-0.22.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d9188979a58a096b6f8090e816ccc3f255f137a009dd4bbec628e27696d67c1"}, - {file = "watchfiles-0.22.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2bdadf6b90c099ca079d468f976fd50062905d61fae183f769637cb0f68ba59a"}, - {file = "watchfiles-0.22.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:067dea90c43bf837d41e72e546196e674f68c23702d3ef80e4e816937b0a3ffd"}, - {file = "watchfiles-0.22.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bbf8a20266136507abf88b0df2328e6a9a7c7309e8daff124dda3803306a9fdb"}, - {file = "watchfiles-0.22.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1235c11510ea557fe21be5d0e354bae2c655a8ee6519c94617fe63e05bca4171"}, - {file = "watchfiles-0.22.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2444dc7cb9d8cc5ab88ebe792a8d75709d96eeef47f4c8fccb6df7c7bc5be71"}, - {file = "watchfiles-0.22.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c5af2347d17ab0bd59366db8752d9e037982e259cacb2ba06f2c41c08af02c39"}, - {file = "watchfiles-0.22.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9624a68b96c878c10437199d9a8b7d7e542feddda8d5ecff58fdc8e67b460848"}, - {file = "watchfiles-0.22.0-cp310-none-win32.whl", hash = "sha256:4b9f2a128a32a2c273d63eb1fdbf49ad64852fc38d15b34eaa3f7ca2f0d2b797"}, - {file = "watchfiles-0.22.0-cp310-none-win_amd64.whl", hash = "sha256:2627a91e8110b8de2406d8b2474427c86f5a62bf7d9ab3654f541f319ef22bcb"}, - {file = "watchfiles-0.22.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:8c39987a1397a877217be1ac0fb1d8b9f662c6077b90ff3de2c05f235e6a8f96"}, - {file = "watchfiles-0.22.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a927b3034d0672f62fb2ef7ea3c9fc76d063c4b15ea852d1db2dc75fe2c09696"}, - {file = "watchfiles-0.22.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:052d668a167e9fc345c24203b104c313c86654dd6c0feb4b8a6dfc2462239249"}, - {file = "watchfiles-0.22.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5e45fb0d70dda1623a7045bd00c9e036e6f1f6a85e4ef2c8ae602b1dfadf7550"}, - {file = "watchfiles-0.22.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c49b76a78c156979759d759339fb62eb0549515acfe4fd18bb151cc07366629c"}, - {file = "watchfiles-0.22.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c4a65474fd2b4c63e2c18ac67a0c6c66b82f4e73e2e4d940f837ed3d2fd9d4da"}, - {file = "watchfiles-0.22.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1cc0cba54f47c660d9fa3218158b8963c517ed23bd9f45fe463f08262a4adae1"}, - {file = "watchfiles-0.22.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:94ebe84a035993bb7668f58a0ebf998174fb723a39e4ef9fce95baabb42b787f"}, - {file = "watchfiles-0.22.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e0f0a874231e2839abbf473256efffe577d6ee2e3bfa5b540479e892e47c172d"}, - {file = "watchfiles-0.22.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:213792c2cd3150b903e6e7884d40660e0bcec4465e00563a5fc03f30ea9c166c"}, - {file = "watchfiles-0.22.0-cp311-none-win32.whl", hash = "sha256:b44b70850f0073b5fcc0b31ede8b4e736860d70e2dbf55701e05d3227a154a67"}, - {file = "watchfiles-0.22.0-cp311-none-win_amd64.whl", hash = "sha256:00f39592cdd124b4ec5ed0b1edfae091567c72c7da1487ae645426d1b0ffcad1"}, - {file = "watchfiles-0.22.0-cp311-none-win_arm64.whl", hash = "sha256:3218a6f908f6a276941422b035b511b6d0d8328edd89a53ae8c65be139073f84"}, - {file = "watchfiles-0.22.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:c7b978c384e29d6c7372209cbf421d82286a807bbcdeb315427687f8371c340a"}, - {file = "watchfiles-0.22.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bd4c06100bce70a20c4b81e599e5886cf504c9532951df65ad1133e508bf20be"}, - {file = "watchfiles-0.22.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:425440e55cd735386ec7925f64d5dde392e69979d4c8459f6bb4e920210407f2"}, - {file = "watchfiles-0.22.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:68fe0c4d22332d7ce53ad094622b27e67440dacefbaedd29e0794d26e247280c"}, - {file = "watchfiles-0.22.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a8a31bfd98f846c3c284ba694c6365620b637debdd36e46e1859c897123aa232"}, - {file = "watchfiles-0.22.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dc2e8fe41f3cac0660197d95216c42910c2b7e9c70d48e6d84e22f577d106fc1"}, - {file = "watchfiles-0.22.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55b7cc10261c2786c41d9207193a85c1db1b725cf87936df40972aab466179b6"}, - {file = "watchfiles-0.22.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:28585744c931576e535860eaf3f2c0ec7deb68e3b9c5a85ca566d69d36d8dd27"}, - {file = "watchfiles-0.22.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:00095dd368f73f8f1c3a7982a9801190cc88a2f3582dd395b289294f8975172b"}, - {file = "watchfiles-0.22.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:52fc9b0dbf54d43301a19b236b4a4614e610605f95e8c3f0f65c3a456ffd7d35"}, - {file = "watchfiles-0.22.0-cp312-none-win32.whl", hash = "sha256:581f0a051ba7bafd03e17127735d92f4d286af941dacf94bcf823b101366249e"}, - {file = "watchfiles-0.22.0-cp312-none-win_amd64.whl", hash = "sha256:aec83c3ba24c723eac14225194b862af176d52292d271c98820199110e31141e"}, - {file = "watchfiles-0.22.0-cp312-none-win_arm64.whl", hash = "sha256:c668228833c5619f6618699a2c12be057711b0ea6396aeaece4ded94184304ea"}, - {file = "watchfiles-0.22.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:d47e9ef1a94cc7a536039e46738e17cce058ac1593b2eccdede8bf72e45f372a"}, - {file = "watchfiles-0.22.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:28f393c1194b6eaadcdd8f941307fc9bbd7eb567995232c830f6aef38e8a6e88"}, - {file = "watchfiles-0.22.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd64f3a4db121bc161644c9e10a9acdb836853155a108c2446db2f5ae1778c3d"}, - {file = "watchfiles-0.22.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2abeb79209630da981f8ebca30a2c84b4c3516a214451bfc5f106723c5f45843"}, - {file = "watchfiles-0.22.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4cc382083afba7918e32d5ef12321421ef43d685b9a67cc452a6e6e18920890e"}, - {file = "watchfiles-0.22.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d048ad5d25b363ba1d19f92dcf29023988524bee6f9d952130b316c5802069cb"}, - {file = "watchfiles-0.22.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:103622865599f8082f03af4214eaff90e2426edff5e8522c8f9e93dc17caee13"}, - {file = "watchfiles-0.22.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d3e1f3cf81f1f823e7874ae563457828e940d75573c8fbf0ee66818c8b6a9099"}, - {file = "watchfiles-0.22.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8597b6f9dc410bdafc8bb362dac1cbc9b4684a8310e16b1ff5eee8725d13dcd6"}, - {file = "watchfiles-0.22.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:0b04a2cbc30e110303baa6d3ddce8ca3664bc3403be0f0ad513d1843a41c97d1"}, - {file = "watchfiles-0.22.0-cp38-none-win32.whl", hash = "sha256:b610fb5e27825b570554d01cec427b6620ce9bd21ff8ab775fc3a32f28bba63e"}, - {file = "watchfiles-0.22.0-cp38-none-win_amd64.whl", hash = "sha256:fe82d13461418ca5e5a808a9e40f79c1879351fcaeddbede094028e74d836e86"}, - {file = "watchfiles-0.22.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:3973145235a38f73c61474d56ad6199124e7488822f3a4fc97c72009751ae3b0"}, - {file = "watchfiles-0.22.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:280a4afbc607cdfc9571b9904b03a478fc9f08bbeec382d648181c695648202f"}, - {file = "watchfiles-0.22.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a0d883351a34c01bd53cfa75cd0292e3f7e268bacf2f9e33af4ecede7e21d1d"}, - {file = "watchfiles-0.22.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9165bcab15f2b6d90eedc5c20a7f8a03156b3773e5fb06a790b54ccecdb73385"}, - {file = "watchfiles-0.22.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dc1b9b56f051209be458b87edb6856a449ad3f803315d87b2da4c93b43a6fe72"}, - {file = "watchfiles-0.22.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8dc1fc25a1dedf2dd952909c8e5cb210791e5f2d9bc5e0e8ebc28dd42fed7562"}, - {file = "watchfiles-0.22.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dc92d2d2706d2b862ce0568b24987eba51e17e14b79a1abcd2edc39e48e743c8"}, - {file = "watchfiles-0.22.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97b94e14b88409c58cdf4a8eaf0e67dfd3ece7e9ce7140ea6ff48b0407a593ec"}, - {file = "watchfiles-0.22.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:96eec15e5ea7c0b6eb5bfffe990fc7c6bd833acf7e26704eb18387fb2f5fd087"}, - {file = "watchfiles-0.22.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:28324d6b28bcb8d7c1041648d7b63be07a16db5510bea923fc80b91a2a6cbed6"}, - {file = "watchfiles-0.22.0-cp39-none-win32.whl", hash = "sha256:8c3e3675e6e39dc59b8fe5c914a19d30029e36e9f99468dddffd432d8a7b1c93"}, - {file = "watchfiles-0.22.0-cp39-none-win_amd64.whl", hash = "sha256:25c817ff2a86bc3de3ed2df1703e3d24ce03479b27bb4527c57e722f8554d971"}, - {file = "watchfiles-0.22.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b810a2c7878cbdecca12feae2c2ae8af59bea016a78bc353c184fa1e09f76b68"}, - {file = "watchfiles-0.22.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:f7e1f9c5d1160d03b93fc4b68a0aeb82fe25563e12fbcdc8507f8434ab6f823c"}, - {file = "watchfiles-0.22.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:030bc4e68d14bcad2294ff68c1ed87215fbd9a10d9dea74e7cfe8a17869785ab"}, - {file = "watchfiles-0.22.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ace7d060432acde5532e26863e897ee684780337afb775107c0a90ae8dbccfd2"}, - {file = "watchfiles-0.22.0-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5834e1f8b71476a26df97d121c0c0ed3549d869124ed2433e02491553cb468c2"}, - {file = "watchfiles-0.22.0-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:0bc3b2f93a140df6806c8467c7f51ed5e55a931b031b5c2d7ff6132292e803d6"}, - {file = "watchfiles-0.22.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8fdebb655bb1ba0122402352b0a4254812717a017d2dc49372a1d47e24073795"}, - {file = "watchfiles-0.22.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c8e0aa0e8cc2a43561e0184c0513e291ca891db13a269d8d47cb9841ced7c71"}, - {file = "watchfiles-0.22.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:2f350cbaa4bb812314af5dab0eb8d538481e2e2279472890864547f3fe2281ed"}, - {file = "watchfiles-0.22.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:7a74436c415843af2a769b36bf043b6ccbc0f8d784814ba3d42fc961cdb0a9dc"}, - {file = "watchfiles-0.22.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00ad0bcd399503a84cc688590cdffbe7a991691314dde5b57b3ed50a41319a31"}, - {file = "watchfiles-0.22.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:72a44e9481afc7a5ee3291b09c419abab93b7e9c306c9ef9108cb76728ca58d2"}, - {file = "watchfiles-0.22.0.tar.gz", hash = "sha256:988e981aaab4f3955209e7e28c7794acdb690be1efa7f16f8ea5aba7ffdadacb"}, + {file = "watchfiles-0.23.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:bee8ce357a05c20db04f46c22be2d1a2c6a8ed365b325d08af94358e0688eeb4"}, + {file = "watchfiles-0.23.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4ccd3011cc7ee2f789af9ebe04745436371d36afe610028921cab9f24bb2987b"}, + {file = "watchfiles-0.23.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb02d41c33be667e6135e6686f1bb76104c88a312a18faa0ef0262b5bf7f1a0f"}, + {file = "watchfiles-0.23.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7cf12ac34c444362f3261fb3ff548f0037ddd4c5bb85f66c4be30d2936beb3c5"}, + {file = "watchfiles-0.23.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a0b2c25040a3c0ce0e66c7779cc045fdfbbb8d59e5aabfe033000b42fe44b53e"}, + {file = "watchfiles-0.23.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ecf2be4b9eece4f3da8ba5f244b9e51932ebc441c0867bd6af46a3d97eb068d6"}, + {file = "watchfiles-0.23.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:40cb8fa00028908211eb9f8d47744dca21a4be6766672e1ff3280bee320436f1"}, + {file = "watchfiles-0.23.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f48c917ffd36ff9a5212614c2d0d585fa8b064ca7e66206fb5c095015bc8207"}, + {file = "watchfiles-0.23.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9d183e3888ada88185ab17064079c0db8c17e32023f5c278d7bf8014713b1b5b"}, + {file = "watchfiles-0.23.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9837edf328b2805346f91209b7e660f65fb0e9ca18b7459d075d58db082bf981"}, + {file = "watchfiles-0.23.0-cp310-none-win32.whl", hash = "sha256:296e0b29ab0276ca59d82d2da22cbbdb39a23eed94cca69aed274595fb3dfe42"}, + {file = "watchfiles-0.23.0-cp310-none-win_amd64.whl", hash = "sha256:4ea756e425ab2dfc8ef2a0cb87af8aa7ef7dfc6fc46c6f89bcf382121d4fff75"}, + {file = "watchfiles-0.23.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:e397b64f7aaf26915bf2ad0f1190f75c855d11eb111cc00f12f97430153c2eab"}, + {file = "watchfiles-0.23.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b4ac73b02ca1824ec0a7351588241fd3953748d3774694aa7ddb5e8e46aef3e3"}, + {file = "watchfiles-0.23.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:130a896d53b48a1cecccfa903f37a1d87dbb74295305f865a3e816452f6e49e4"}, + {file = "watchfiles-0.23.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c5e7803a65eb2d563c73230e9d693c6539e3c975ccfe62526cadde69f3fda0cf"}, + {file = "watchfiles-0.23.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1aa4cc85202956d1a65c88d18c7b687b8319dbe6b1aec8969784ef7a10e7d1a"}, + {file = "watchfiles-0.23.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:87f889f6e58849ddb7c5d2cb19e2e074917ed1c6e3ceca50405775166492cca8"}, + {file = "watchfiles-0.23.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:37fd826dac84c6441615aa3f04077adcc5cac7194a021c9f0d69af20fb9fa788"}, + {file = "watchfiles-0.23.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ee7db6e36e7a2c15923072e41ea24d9a0cf39658cb0637ecc9307b09d28827e1"}, + {file = "watchfiles-0.23.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:2368c5371c17fdcb5a2ea71c5c9d49f9b128821bfee69503cc38eae00feb3220"}, + {file = "watchfiles-0.23.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:857af85d445b9ba9178db95658c219dbd77b71b8264e66836a6eba4fbf49c320"}, + {file = "watchfiles-0.23.0-cp311-none-win32.whl", hash = "sha256:1d636c8aeb28cdd04a4aa89030c4b48f8b2954d8483e5f989774fa441c0ed57b"}, + {file = "watchfiles-0.23.0-cp311-none-win_amd64.whl", hash = "sha256:46f1d8069a95885ca529645cdbb05aea5837d799965676e1b2b1f95a4206313e"}, + {file = "watchfiles-0.23.0-cp311-none-win_arm64.whl", hash = "sha256:e495ed2a7943503766c5d1ff05ae9212dc2ce1c0e30a80d4f0d84889298fa304"}, + {file = "watchfiles-0.23.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:1db691bad0243aed27c8354b12d60e8e266b75216ae99d33e927ff5238d270b5"}, + {file = "watchfiles-0.23.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:62d2b18cb1edaba311fbbfe83fb5e53a858ba37cacb01e69bc20553bb70911b8"}, + {file = "watchfiles-0.23.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e087e8fdf1270d000913c12e6eca44edd02aad3559b3e6b8ef00f0ce76e0636f"}, + {file = "watchfiles-0.23.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:dd41d5c72417b87c00b1b635738f3c283e737d75c5fa5c3e1c60cd03eac3af77"}, + {file = "watchfiles-0.23.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1e5f3ca0ff47940ce0a389457b35d6df601c317c1e1a9615981c474452f98de1"}, + {file = "watchfiles-0.23.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6991e3a78f642368b8b1b669327eb6751439f9f7eaaa625fae67dd6070ecfa0b"}, + {file = "watchfiles-0.23.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7f7252f52a09f8fa5435dc82b6af79483118ce6bd51eb74e6269f05ee22a7b9f"}, + {file = "watchfiles-0.23.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e01bcb8d767c58865207a6c2f2792ad763a0fe1119fb0a430f444f5b02a5ea0"}, + {file = "watchfiles-0.23.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:8e56fbcdd27fce061854ddec99e015dd779cae186eb36b14471fc9ae713b118c"}, + {file = "watchfiles-0.23.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:bd3e2d64500a6cad28bcd710ee6269fbeb2e5320525acd0cfab5f269ade68581"}, + {file = "watchfiles-0.23.0-cp312-none-win32.whl", hash = "sha256:eb99c954291b2fad0eff98b490aa641e128fbc4a03b11c8a0086de8b7077fb75"}, + {file = "watchfiles-0.23.0-cp312-none-win_amd64.whl", hash = "sha256:dccc858372a56080332ea89b78cfb18efb945da858fabeb67f5a44fa0bcb4ebb"}, + {file = "watchfiles-0.23.0-cp312-none-win_arm64.whl", hash = "sha256:6c21a5467f35c61eafb4e394303720893066897fca937bade5b4f5877d350ff8"}, + {file = "watchfiles-0.23.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:ba31c32f6b4dceeb2be04f717811565159617e28d61a60bb616b6442027fd4b9"}, + {file = "watchfiles-0.23.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:85042ab91814fca99cec4678fc063fb46df4cbb57b4835a1cc2cb7a51e10250e"}, + {file = "watchfiles-0.23.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:24655e8c1c9c114005c3868a3d432c8aa595a786b8493500071e6a52f3d09217"}, + {file = "watchfiles-0.23.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6b1a950ab299a4a78fd6369a97b8763732bfb154fdb433356ec55a5bce9515c1"}, + {file = "watchfiles-0.23.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8d3c5cd327dd6ce0edfc94374fb5883d254fe78a5e9d9dfc237a1897dc73cd1"}, + {file = "watchfiles-0.23.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ff785af8bacdf0be863ec0c428e3288b817e82f3d0c1d652cd9c6d509020dd0"}, + {file = "watchfiles-0.23.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:02b7ba9d4557149410747353e7325010d48edcfe9d609a85cb450f17fd50dc3d"}, + {file = "watchfiles-0.23.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:48a1b05c0afb2cd2f48c1ed2ae5487b116e34b93b13074ed3c22ad5c743109f0"}, + {file = "watchfiles-0.23.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:109a61763e7318d9f821b878589e71229f97366fa6a5c7720687d367f3ab9eef"}, + {file = "watchfiles-0.23.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:9f8e6bb5ac007d4a4027b25f09827ed78cbbd5b9700fd6c54429278dacce05d1"}, + {file = "watchfiles-0.23.0-cp313-none-win32.whl", hash = "sha256:f46c6f0aec8d02a52d97a583782d9af38c19a29900747eb048af358a9c1d8e5b"}, + {file = "watchfiles-0.23.0-cp313-none-win_amd64.whl", hash = "sha256:f449afbb971df5c6faeb0a27bca0427d7b600dd8f4a068492faec18023f0dcff"}, + {file = "watchfiles-0.23.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:2dddc2487d33e92f8b6222b5fb74ae2cfde5e8e6c44e0248d24ec23befdc5366"}, + {file = "watchfiles-0.23.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e75695cc952e825fa3e0684a7f4a302f9128721f13eedd8dbd3af2ba450932b8"}, + {file = "watchfiles-0.23.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2537ef60596511df79b91613a5bb499b63f46f01a11a81b0a2b0dedf645d0a9c"}, + {file = "watchfiles-0.23.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:20b423b58f5fdde704a226b598a2d78165fe29eb5621358fe57ea63f16f165c4"}, + {file = "watchfiles-0.23.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b98732ec893975455708d6fc9a6daab527fc8bbe65be354a3861f8c450a632a4"}, + {file = "watchfiles-0.23.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ee1f5fcbf5bc33acc0be9dd31130bcba35d6d2302e4eceafafd7d9018c7755ab"}, + {file = "watchfiles-0.23.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a8f195338a5a7b50a058522b39517c50238358d9ad8284fd92943643144c0c03"}, + {file = "watchfiles-0.23.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:524fcb8d59b0dbee2c9b32207084b67b2420f6431ed02c18bd191e6c575f5c48"}, + {file = "watchfiles-0.23.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:0eff099a4df36afaa0eea7a913aa64dcf2cbd4e7a4f319a73012210af4d23810"}, + {file = "watchfiles-0.23.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a8323daae27ea290ba3350c70c836c0d2b0fb47897fa3b0ca6a5375b952b90d3"}, + {file = "watchfiles-0.23.0-cp38-none-win32.whl", hash = "sha256:aafea64a3ae698695975251f4254df2225e2624185a69534e7fe70581066bc1b"}, + {file = "watchfiles-0.23.0-cp38-none-win_amd64.whl", hash = "sha256:c846884b2e690ba62a51048a097acb6b5cd263d8bd91062cd6137e2880578472"}, + {file = "watchfiles-0.23.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:a753993635eccf1ecb185dedcc69d220dab41804272f45e4aef0a67e790c3eb3"}, + {file = "watchfiles-0.23.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6bb91fa4d0b392f0f7e27c40981e46dda9eb0fbc84162c7fb478fe115944f491"}, + {file = "watchfiles-0.23.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1f67312efa3902a8e8496bfa9824d3bec096ff83c4669ea555c6bdd213aa516"}, + {file = "watchfiles-0.23.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7ca6b71dcc50d320c88fb2d88ecd63924934a8abc1673683a242a7ca7d39e781"}, + {file = "watchfiles-0.23.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2aec5c29915caf08771d2507da3ac08e8de24a50f746eb1ed295584ba1820330"}, + {file = "watchfiles-0.23.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1733b9bc2c8098c6bdb0ff7a3d7cb211753fecb7bd99bdd6df995621ee1a574b"}, + {file = "watchfiles-0.23.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:02ff5d7bd066c6a7673b17c8879cd8ee903078d184802a7ee851449c43521bdd"}, + {file = "watchfiles-0.23.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18e2de19801b0eaa4c5292a223effb7cfb43904cb742c5317a0ac686ed604765"}, + {file = "watchfiles-0.23.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:8ada449e22198c31fb013ae7e9add887e8d2bd2335401abd3cbc55f8c5083647"}, + {file = "watchfiles-0.23.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3af1b05361e1cc497bf1be654a664750ae61f5739e4bb094a2be86ec8c6db9b6"}, + {file = "watchfiles-0.23.0-cp39-none-win32.whl", hash = "sha256:486bda18be5d25ab5d932699ceed918f68eb91f45d018b0343e3502e52866e5e"}, + {file = "watchfiles-0.23.0-cp39-none-win_amd64.whl", hash = "sha256:d2d42254b189a346249424fb9bb39182a19289a2409051ee432fb2926bad966a"}, + {file = "watchfiles-0.23.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:6a9265cf87a5b70147bfb2fec14770ed5b11a5bb83353f0eee1c25a81af5abfe"}, + {file = "watchfiles-0.23.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:9f02a259fcbbb5fcfe7a0805b1097ead5ba7a043e318eef1db59f93067f0b49b"}, + {file = "watchfiles-0.23.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1ebaebb53b34690da0936c256c1cdb0914f24fb0e03da76d185806df9328abed"}, + {file = "watchfiles-0.23.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd257f98cff9c6cb39eee1a83c7c3183970d8a8d23e8cf4f47d9a21329285cee"}, + {file = "watchfiles-0.23.0-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:aba037c1310dd108411d27b3d5815998ef0e83573e47d4219f45753c710f969f"}, + {file = "watchfiles-0.23.0-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:a96ac14e184aa86dc43b8a22bb53854760a58b2966c2b41580de938e9bf26ed0"}, + {file = "watchfiles-0.23.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:11698bb2ea5e991d10f1f4f83a39a02f91e44e4bd05f01b5c1ec04c9342bf63c"}, + {file = "watchfiles-0.23.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:efadd40fca3a04063d40c4448c9303ce24dd6151dc162cfae4a2a060232ebdcb"}, + {file = "watchfiles-0.23.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:556347b0abb4224c5ec688fc58214162e92a500323f50182f994f3ad33385dcb"}, + {file = "watchfiles-0.23.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:1cf7f486169986c4b9d34087f08ce56a35126600b6fef3028f19ca16d5889071"}, + {file = "watchfiles-0.23.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f18de0f82c62c4197bea5ecf4389288ac755896aac734bd2cc44004c56e4ac47"}, + {file = "watchfiles-0.23.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:532e1f2c491274d1333a814e4c5c2e8b92345d41b12dc806cf07aaff786beb66"}, + {file = "watchfiles-0.23.0.tar.gz", hash = "sha256:9338ade39ff24f8086bb005d16c29f8e9f19e55b18dcb04dfa26fcbc09da497b"}, ] [package.dependencies] @@ -8977,94 +9442,108 @@ test = ["websockets"] [[package]] name = "websockets" -version = "12.0" +version = "13.0" description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" optional = false python-versions = ">=3.8" files = [ - {file = "websockets-12.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d554236b2a2006e0ce16315c16eaa0d628dab009c33b63ea03f41c6107958374"}, - {file = "websockets-12.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2d225bb6886591b1746b17c0573e29804619c8f755b5598d875bb4235ea639be"}, - {file = "websockets-12.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:eb809e816916a3b210bed3c82fb88eaf16e8afcf9c115ebb2bacede1797d2547"}, - {file = "websockets-12.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c588f6abc13f78a67044c6b1273a99e1cf31038ad51815b3b016ce699f0d75c2"}, - {file = "websockets-12.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5aa9348186d79a5f232115ed3fa9020eab66d6c3437d72f9d2c8ac0c6858c558"}, - {file = "websockets-12.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6350b14a40c95ddd53e775dbdbbbc59b124a5c8ecd6fbb09c2e52029f7a9f480"}, - {file = "websockets-12.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:70ec754cc2a769bcd218ed8d7209055667b30860ffecb8633a834dde27d6307c"}, - {file = "websockets-12.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6e96f5ed1b83a8ddb07909b45bd94833b0710f738115751cdaa9da1fb0cb66e8"}, - {file = "websockets-12.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4d87be612cbef86f994178d5186add3d94e9f31cc3cb499a0482b866ec477603"}, - {file = "websockets-12.0-cp310-cp310-win32.whl", hash = "sha256:befe90632d66caaf72e8b2ed4d7f02b348913813c8b0a32fae1cc5fe3730902f"}, - {file = "websockets-12.0-cp310-cp310-win_amd64.whl", hash = "sha256:363f57ca8bc8576195d0540c648aa58ac18cf85b76ad5202b9f976918f4219cf"}, - {file = "websockets-12.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5d873c7de42dea355d73f170be0f23788cf3fa9f7bed718fd2830eefedce01b4"}, - {file = "websockets-12.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3f61726cae9f65b872502ff3c1496abc93ffbe31b278455c418492016e2afc8f"}, - {file = "websockets-12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ed2fcf7a07334c77fc8a230755c2209223a7cc44fc27597729b8ef5425aa61a3"}, - {file = "websockets-12.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e332c210b14b57904869ca9f9bf4ca32f5427a03eeb625da9b616c85a3a506c"}, - {file = "websockets-12.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5693ef74233122f8ebab026817b1b37fe25c411ecfca084b29bc7d6efc548f45"}, - {file = "websockets-12.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e9e7db18b4539a29cc5ad8c8b252738a30e2b13f033c2d6e9d0549b45841c04"}, - {file = "websockets-12.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6e2df67b8014767d0f785baa98393725739287684b9f8d8a1001eb2839031447"}, - {file = "websockets-12.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:bea88d71630c5900690fcb03161ab18f8f244805c59e2e0dc4ffadae0a7ee0ca"}, - {file = "websockets-12.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:dff6cdf35e31d1315790149fee351f9e52978130cef6c87c4b6c9b3baf78bc53"}, - {file = "websockets-12.0-cp311-cp311-win32.whl", hash = "sha256:3e3aa8c468af01d70332a382350ee95f6986db479ce7af14d5e81ec52aa2b402"}, - {file = "websockets-12.0-cp311-cp311-win_amd64.whl", hash = "sha256:25eb766c8ad27da0f79420b2af4b85d29914ba0edf69f547cc4f06ca6f1d403b"}, - {file = "websockets-12.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0e6e2711d5a8e6e482cacb927a49a3d432345dfe7dea8ace7b5790df5932e4df"}, - {file = "websockets-12.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:dbcf72a37f0b3316e993e13ecf32f10c0e1259c28ffd0a85cee26e8549595fbc"}, - {file = "websockets-12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:12743ab88ab2af1d17dd4acb4645677cb7063ef4db93abffbf164218a5d54c6b"}, - {file = "websockets-12.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b645f491f3c48d3f8a00d1fce07445fab7347fec54a3e65f0725d730d5b99cb"}, - {file = "websockets-12.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9893d1aa45a7f8b3bc4510f6ccf8db8c3b62120917af15e3de247f0780294b92"}, - {file = "websockets-12.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f38a7b376117ef7aff996e737583172bdf535932c9ca021746573bce40165ed"}, - {file = "websockets-12.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:f764ba54e33daf20e167915edc443b6f88956f37fb606449b4a5b10ba42235a5"}, - {file = "websockets-12.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:1e4b3f8ea6a9cfa8be8484c9221ec0257508e3a1ec43c36acdefb2a9c3b00aa2"}, - {file = "websockets-12.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:9fdf06fd06c32205a07e47328ab49c40fc1407cdec801d698a7c41167ea45113"}, - {file = "websockets-12.0-cp312-cp312-win32.whl", hash = "sha256:baa386875b70cbd81798fa9f71be689c1bf484f65fd6fb08d051a0ee4e79924d"}, - {file = "websockets-12.0-cp312-cp312-win_amd64.whl", hash = "sha256:ae0a5da8f35a5be197f328d4727dbcfafa53d1824fac3d96cdd3a642fe09394f"}, - {file = "websockets-12.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5f6ffe2c6598f7f7207eef9a1228b6f5c818f9f4d53ee920aacd35cec8110438"}, - {file = "websockets-12.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9edf3fc590cc2ec20dc9d7a45108b5bbaf21c0d89f9fd3fd1685e223771dc0b2"}, - {file = "websockets-12.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8572132c7be52632201a35f5e08348137f658e5ffd21f51f94572ca6c05ea81d"}, - {file = "websockets-12.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:604428d1b87edbf02b233e2c207d7d528460fa978f9e391bd8aaf9c8311de137"}, - {file = "websockets-12.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1a9d160fd080c6285e202327aba140fc9a0d910b09e423afff4ae5cbbf1c7205"}, - {file = "websockets-12.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87b4aafed34653e465eb77b7c93ef058516cb5acf3eb21e42f33928616172def"}, - {file = "websockets-12.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b2ee7288b85959797970114deae81ab41b731f19ebcd3bd499ae9ca0e3f1d2c8"}, - {file = "websockets-12.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:7fa3d25e81bfe6a89718e9791128398a50dec6d57faf23770787ff441d851967"}, - {file = "websockets-12.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a571f035a47212288e3b3519944f6bf4ac7bc7553243e41eac50dd48552b6df7"}, - {file = "websockets-12.0-cp38-cp38-win32.whl", hash = "sha256:3c6cc1360c10c17463aadd29dd3af332d4a1adaa8796f6b0e9f9df1fdb0bad62"}, - {file = "websockets-12.0-cp38-cp38-win_amd64.whl", hash = "sha256:1bf386089178ea69d720f8db6199a0504a406209a0fc23e603b27b300fdd6892"}, - {file = "websockets-12.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:ab3d732ad50a4fbd04a4490ef08acd0517b6ae6b77eb967251f4c263011a990d"}, - {file = "websockets-12.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a1d9697f3337a89691e3bd8dc56dea45a6f6d975f92e7d5f773bc715c15dde28"}, - {file = "websockets-12.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1df2fbd2c8a98d38a66f5238484405b8d1d16f929bb7a33ed73e4801222a6f53"}, - {file = "websockets-12.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23509452b3bc38e3a057382c2e941d5ac2e01e251acce7adc74011d7d8de434c"}, - {file = "websockets-12.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2e5fc14ec6ea568200ea4ef46545073da81900a2b67b3e666f04adf53ad452ec"}, - {file = "websockets-12.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46e71dbbd12850224243f5d2aeec90f0aaa0f2dde5aeeb8fc8df21e04d99eff9"}, - {file = "websockets-12.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b81f90dcc6c85a9b7f29873beb56c94c85d6f0dac2ea8b60d995bd18bf3e2aae"}, - {file = "websockets-12.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:a02413bc474feda2849c59ed2dfb2cddb4cd3d2f03a2fedec51d6e959d9b608b"}, - {file = "websockets-12.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:bbe6013f9f791944ed31ca08b077e26249309639313fff132bfbf3ba105673b9"}, - {file = "websockets-12.0-cp39-cp39-win32.whl", hash = "sha256:cbe83a6bbdf207ff0541de01e11904827540aa069293696dd528a6640bd6a5f6"}, - {file = "websockets-12.0-cp39-cp39-win_amd64.whl", hash = "sha256:fc4e7fa5414512b481a2483775a8e8be7803a35b30ca805afa4998a84f9fd9e8"}, - {file = "websockets-12.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:248d8e2446e13c1d4326e0a6a4e9629cb13a11195051a73acf414812700badbd"}, - {file = "websockets-12.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f44069528d45a933997a6fef143030d8ca8042f0dfaad753e2906398290e2870"}, - {file = "websockets-12.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c4e37d36f0d19f0a4413d3e18c0d03d0c268ada2061868c1e6f5ab1a6d575077"}, - {file = "websockets-12.0-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d829f975fc2e527a3ef2f9c8f25e553eb7bc779c6665e8e1d52aa22800bb38b"}, - {file = "websockets-12.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:2c71bd45a777433dd9113847af751aae36e448bc6b8c361a566cb043eda6ec30"}, - {file = "websockets-12.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:0bee75f400895aef54157b36ed6d3b308fcab62e5260703add87f44cee9c82a6"}, - {file = "websockets-12.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:423fc1ed29f7512fceb727e2d2aecb952c46aa34895e9ed96071821309951123"}, - {file = "websockets-12.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:27a5e9964ef509016759f2ef3f2c1e13f403725a5e6a1775555994966a66e931"}, - {file = "websockets-12.0-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3181df4583c4d3994d31fb235dc681d2aaad744fbdbf94c4802485ececdecf2"}, - {file = "websockets-12.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:b067cb952ce8bf40115f6c19f478dc71c5e719b7fbaa511359795dfd9d1a6468"}, - {file = "websockets-12.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:00700340c6c7ab788f176d118775202aadea7602c5cc6be6ae127761c16d6b0b"}, - {file = "websockets-12.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e469d01137942849cff40517c97a30a93ae79917752b34029f0ec72df6b46399"}, - {file = "websockets-12.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ffefa1374cd508d633646d51a8e9277763a9b78ae71324183693959cf94635a7"}, - {file = "websockets-12.0-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba0cab91b3956dfa9f512147860783a1829a8d905ee218a9837c18f683239611"}, - {file = "websockets-12.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:2cb388a5bfb56df4d9a406783b7f9dbefb888c09b71629351cc6b036e9259370"}, - {file = "websockets-12.0-py3-none-any.whl", hash = "sha256:dc284bbc8d7c78a6c69e0c7325ab46ee5e40bb4d50e494d8131a07ef47500e9e"}, - {file = "websockets-12.0.tar.gz", hash = "sha256:81df9cbcbb6c260de1e007e58c011bfebe2dafc8435107b0537f393dd38c8b1b"}, + {file = "websockets-13.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ad4fa707ff9e2ffee019e946257b5300a45137a58f41fbd9a4db8e684ab61528"}, + {file = "websockets-13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6fd757f313c13c34dae9f126d3ba4cf97175859c719e57c6a614b781c86b617e"}, + {file = "websockets-13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cbac2eb7ce0fac755fb983c9247c4a60c4019bcde4c0e4d167aeb17520cc7ef1"}, + {file = "websockets-13.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4b83cf7354cbbc058e97b3e545dceb75b8d9cf17fd5a19db419c319ddbaaf7a"}, + {file = "websockets-13.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9202c0010c78fad1041e1c5285232b6508d3633f92825687549540a70e9e5901"}, + {file = "websockets-13.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e6566e79c8c7cbea75ec450f6e1828945fc5c9a4769ceb1c7b6e22470539712"}, + {file = "websockets-13.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e7fcad070dcd9ad37a09d89a4cbc2a5e3e45080b88977c0da87b3090f9f55ead"}, + {file = "websockets-13.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:0a8f7d65358a25172db00c69bcc7df834155ee24229f560d035758fd6613111a"}, + {file = "websockets-13.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:63b702fb31e3f058f946ccdfa551f4d57a06f7729c369e8815eb18643099db37"}, + {file = "websockets-13.0-cp310-cp310-win32.whl", hash = "sha256:3a20cf14ba7b482c4a1924b5e061729afb89c890ca9ed44ac4127c6c5986e424"}, + {file = "websockets-13.0-cp310-cp310-win_amd64.whl", hash = "sha256:587245f0704d0bb675f919898d7473e8827a6d578e5a122a21756ca44b811ec8"}, + {file = "websockets-13.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:06df8306c241c235075d2ae77367038e701e53bc8c1bb4f6644f4f53aa6dedd0"}, + {file = "websockets-13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:85a1f92a02f0b8c1bf02699731a70a8a74402bb3f82bee36e7768b19a8ed9709"}, + {file = "websockets-13.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9ed02c604349068d46d87ef4c2012c112c791f2bec08671903a6bb2bd9c06784"}, + {file = "websockets-13.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b89849171b590107f6724a7b0790736daead40926ddf47eadf998b4ff51d6414"}, + {file = "websockets-13.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:939a16849d71203628157a5e4a495da63967c744e1e32018e9b9e2689aca64d4"}, + {file = "websockets-13.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad818cdac37c0ad4c58e51cb4964eae4f18b43c4a83cb37170b0d90c31bd80cf"}, + {file = "websockets-13.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:cbfe82a07596a044de78bb7a62519e71690c5812c26c5f1d4b877e64e4f46309"}, + {file = "websockets-13.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e07e76c49f39c5b45cbd7362b94f001ae209a3ea4905ae9a09cfd53b3c76373d"}, + {file = "websockets-13.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:372f46a0096cfda23c88f7e42349a33f8375e10912f712e6b496d3a9a557290f"}, + {file = "websockets-13.0-cp311-cp311-win32.whl", hash = "sha256:376a43a4fd96725f13450d3d2e98f4f36c3525c562ab53d9a98dd2950dca9a8a"}, + {file = "websockets-13.0-cp311-cp311-win_amd64.whl", hash = "sha256:2be1382a4daa61e2f3e2be3b3c86932a8db9d1f85297feb6e9df22f391f94452"}, + {file = "websockets-13.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:b5407c34776b9b77bd89a5f95eb0a34aaf91889e3f911c63f13035220eb50107"}, + {file = "websockets-13.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:4782ec789f059f888c1e8fdf94383d0e64b531cffebbf26dd55afd53ab487ca4"}, + {file = "websockets-13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c8feb8e19ef65c9994e652c5b0324abd657bedd0abeb946fb4f5163012c1e730"}, + {file = "websockets-13.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d3f3d2e20c442b58dbac593cb1e02bc02d149a86056cc4126d977ad902472e3b"}, + {file = "websockets-13.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e39d393e0ab5b8bd01717cc26f2922026050188947ff54fe6a49dc489f7750b7"}, + {file = "websockets-13.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f661a4205741bdc88ac9c2b2ec003c72cee97e4acd156eb733662ff004ba429"}, + {file = "websockets-13.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:384129ad0490e06bab2b98c1da9b488acb35bb11e2464c728376c6f55f0d45f3"}, + {file = "websockets-13.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:df5c0eff91f61b8205a6c9f7b255ff390cdb77b61c7b41f79ca10afcbb22b6cb"}, + {file = "websockets-13.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:02cc9bb1a887dac0e08bf657c5d00aa3fac0d03215d35a599130c2034ae6663a"}, + {file = "websockets-13.0-cp312-cp312-win32.whl", hash = "sha256:d9726d2c9bd6aed8cb994d89b3910ca0079406edce3670886ec828a73e7bdd53"}, + {file = "websockets-13.0-cp312-cp312-win_amd64.whl", hash = "sha256:fa0839f35322f7b038d8adcf679e2698c3a483688cc92e3bd15ee4fb06669e9a"}, + {file = "websockets-13.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:da7e501e59857e8e3e9d10586139dc196b80445a591451ca9998aafba1af5278"}, + {file = "websockets-13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a00e1e587c655749afb5b135d8d3edcfe84ec6db864201e40a882e64168610b3"}, + {file = "websockets-13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a7fbf2a8fe7556a8f4e68cb3e736884af7bf93653e79f6219f17ebb75e97d8f0"}, + {file = "websockets-13.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ea9c9c7443a97ea4d84d3e4d42d0e8c4235834edae652993abcd2aff94affd7"}, + {file = "websockets-13.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:35c2221b539b360203f3f9ad168e527bf16d903e385068ae842c186efb13d0ea"}, + {file = "websockets-13.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:358d37c5c431dd050ffb06b4b075505aae3f4f795d7fff9794e5ed96ce99b998"}, + {file = "websockets-13.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:038e7a0f1bfafc7bf52915ab3506b7a03d1e06381e9f60440c856e8918138151"}, + {file = "websockets-13.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:fd038bc9e2c134847f1e0ce3191797fad110756e690c2fdd9702ed34e7a43abb"}, + {file = "websockets-13.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:93b8c2008f372379fb6e5d2b3f7c9ec32f7b80316543fd3a5ace6610c5cde1b0"}, + {file = "websockets-13.0-cp313-cp313-win32.whl", hash = "sha256:851fd0afb3bc0b73f7c5b5858975d42769a5fdde5314f4ef2c106aec63100687"}, + {file = "websockets-13.0-cp313-cp313-win_amd64.whl", hash = "sha256:7d14901fdcf212804970c30ab9ee8f3f0212e620c7ea93079d6534863444fb4e"}, + {file = "websockets-13.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ae7a519a56a714f64c3445cabde9fc2fc927e7eae44f413eae187cddd9e54178"}, + {file = "websockets-13.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5575031472ca87302aeb2ce2c2349f4c6ea978c86a9d1289bc5d16058ad4c10a"}, + {file = "websockets-13.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:9895df6cd0bfe79d09bcd1dbdc03862846f26fbd93797153de954306620c1d00"}, + {file = "websockets-13.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a4de299c947a54fca9ce1c5fd4a08eb92ffce91961becb13bd9195f7c6e71b47"}, + {file = "websockets-13.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:05c25f7b849702950b6fd0e233989bb73a0d2bc83faa3b7233313ca395205f6d"}, + {file = "websockets-13.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ede95125a30602b1691a4b1da88946bf27dae283cf30f22cd2cb8ca4b2e0d119"}, + {file = "websockets-13.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:addf0a16e4983280efed272d8cb3b2e05f0051755372461e7d966b80a6554e16"}, + {file = "websockets-13.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:06b3186e97bf9a33921fa60734d5ed90f2a9b407cce8d23c7333a0984049ef61"}, + {file = "websockets-13.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:eae368cac85adc4c7dc3b0d5f84ffcca609d658db6447387300478e44db70796"}, + {file = "websockets-13.0-cp38-cp38-win32.whl", hash = "sha256:337837ac788d955728b1ab01876d72b73da59819a3388e1c5e8e05c3999f1afa"}, + {file = "websockets-13.0-cp38-cp38-win_amd64.whl", hash = "sha256:f66e00e42f25ca7e91076366303e11c82572ca87cc5aae51e6e9c094f315ab41"}, + {file = "websockets-13.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:94c1c02721139fe9940b38d28fb15b4b782981d800d5f40f9966264fbf23dcc8"}, + {file = "websockets-13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bd4ba86513430513e2aa25a441bb538f6f83734dc368a2c5d18afdd39097aa33"}, + {file = "websockets-13.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a1ab8f0e0cadc5be5f3f9fa11a663957fecbf483d434762c8dfb8aa44948944a"}, + {file = "websockets-13.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3670def5d3dfd5af6f6e2b3b243ea8f1f72d8da1ef927322f0703f85c90d9603"}, + {file = "websockets-13.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6058b6be92743358885ad6dcdecb378fde4a4c74d4dd16a089d07580c75a0e80"}, + {file = "websockets-13.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:516062a0a8ef5ecbfa4acbaec14b199fc070577834f9fe3d40800a99f92523ca"}, + {file = "websockets-13.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:da7e918d82e7bdfc6f66d31febe1b2e28a1ca3387315f918de26f5e367f61572"}, + {file = "websockets-13.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:9cc7f35dcb49a4e32db82a849fcc0714c4d4acc9d2273aded2d61f87d7f660b7"}, + {file = "websockets-13.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:f5737c53eb2c8ed8f64b50d3dafd3c1dae739f78aa495a288421ac1b3de82717"}, + {file = "websockets-13.0-cp39-cp39-win32.whl", hash = "sha256:265e1f0d3f788ce8ef99dca591a1aec5263b26083ca0934467ad9a1d1181067c"}, + {file = "websockets-13.0-cp39-cp39-win_amd64.whl", hash = "sha256:4d70c89e3d3b347a7c4d3c33f8d323f0584c9ceb69b82c2ef8a174ca84ea3d4a"}, + {file = "websockets-13.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:602cbd010d8c21c8475f1798b705bb18567eb189c533ab5ef568bc3033fdf417"}, + {file = "websockets-13.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:bf8eb5dca4f484a60f5327b044e842e0d7f7cdbf02ea6dc4a4f811259f1f1f0b"}, + {file = "websockets-13.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89d795c1802d99a643bf689b277e8604c14b5af1bc0a31dade2cd7a678087212"}, + {file = "websockets-13.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:788bc841d250beccff67a20a5a53a15657a60111ef9c0c0a97fbdd614fae0fe2"}, + {file = "websockets-13.0-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7334752052532c156d28b8eaf3558137e115c7871ea82adff69b6d94a7bee273"}, + {file = "websockets-13.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:e7a1963302947332c3039e3f66209ec73b1626f8a0191649e0713c391e9f5b0d"}, + {file = "websockets-13.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:2e1cf4e1eb84b4fd74a47688e8b0940c89a04ad9f6937afa43d468e71128cd68"}, + {file = "websockets-13.0-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:c026ee729c4ce55708a14b839ba35086dfae265fc12813b62d34ce33f4980c1c"}, + {file = "websockets-13.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f5f9d23fbbf96eefde836d9692670bfc89e2d159f456d499c5efcf6a6281c1af"}, + {file = "websockets-13.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ad684cb7efce227d756bae3e8484f2e56aa128398753b54245efdfbd1108f2c"}, + {file = "websockets-13.0-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1e10b3fbed7be4a59831d3a939900e50fcd34d93716e433d4193a4d0d1d335d"}, + {file = "websockets-13.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:d42a818e634f789350cd8fb413a3f5eec1cf0400a53d02062534c41519f5125c"}, + {file = "websockets-13.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:e5ba5e9b332267d0f2c33ede390061850f1ac3ee6cd1bdcf4c5ea33ead971966"}, + {file = "websockets-13.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:f9af457ed593e35f467140d8b61d425495b127744a9d65d45a366f8678449a23"}, + {file = "websockets-13.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bcea3eb58c09c3a31cc83b45c06d5907f02ddaf10920aaa6443975310f699b95"}, + {file = "websockets-13.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c210d1460dc8d326ffdef9703c2f83269b7539a1690ad11ae04162bc1878d33d"}, + {file = "websockets-13.0-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b32f38bc81170fd56d0482d505b556e52bf9078b36819a8ba52624bd6667e39e"}, + {file = "websockets-13.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:81a11a1ddd5320429db47c04d35119c3e674d215173d87aaeb06ae80f6e9031f"}, + {file = "websockets-13.0-py3-none-any.whl", hash = "sha256:dbbac01e80aee253d44c4f098ab3cc17c822518519e869b284cfbb8cd16cc9de"}, + {file = "websockets-13.0.tar.gz", hash = "sha256:b7bf950234a482b7461afdb2ec99eee3548ec4d53f418c7990bb79c620476602"}, ] [[package]] name = "werkzeug" -version = "3.0.3" +version = "3.0.4" description = "The comprehensive WSGI web application library." optional = false python-versions = ">=3.8" files = [ - {file = "werkzeug-3.0.3-py3-none-any.whl", hash = "sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8"}, - {file = "werkzeug-3.0.3.tar.gz", hash = "sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18"}, + {file = "werkzeug-3.0.4-py3-none-any.whl", hash = "sha256:02c9eb92b7d6c06f31a782811505d2157837cea66aaede3e217c7c27c039476c"}, + {file = "werkzeug-3.0.4.tar.gz", hash = "sha256:34f2371506b250df4d4f84bfe7b0921e4762525762bbd936614909fe25cd7306"}, ] [package.dependencies] @@ -9387,13 +9866,13 @@ requests = "*" [[package]] name = "zipp" -version = "3.19.2" +version = "3.20.0" description = "Backport of pathlib-compatible object wrapper for zip files" optional = false python-versions = ">=3.8" files = [ - {file = "zipp-3.19.2-py3-none-any.whl", hash = "sha256:f091755f667055f2d02b32c53771a7a6c8b47e1fdbc4b72a8b9072b3eef8015c"}, - {file = "zipp-3.19.2.tar.gz", hash = "sha256:bf1dcf6450f873a13e952a29504887c89e6de7506209e5b1bcc3460135d4de19"}, + {file = "zipp-3.20.0-py3-none-any.whl", hash = "sha256:58da6168be89f0be59beb194da1250516fdaa062ccebd30127ac65d30045e10d"}, + {file = "zipp-3.20.0.tar.gz", hash = "sha256:0145e43d89664cfe1a2e533adc75adafed82fe2da404b4bbb6b026c0157bdb31"}, ] [package.extras] @@ -9584,4 +10063,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "a74c7b6a72145d5074aa84581df6e543ea422810caf0ba1561cd2d35497243ca" +content-hash = "e4c00268514d26bd07c6b72925e0e3b4558ec972895d252e60e9571e3ac38895" diff --git a/api/pyproject.toml b/api/pyproject.toml index 82ccd0b202..6a46ab95f3 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -6,8 +6,6 @@ requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [tool.ruff] -exclude = [ -] line-length = 120 [tool.ruff.lint] @@ -71,13 +69,8 @@ ignore = [ [tool.ruff.format] exclude = [ "core/**/*.py", - "controllers/**/*.py", "models/**/*.py", "migrations/**/*", - "services/**/*.py", - "tasks/**/*.py", - "tests/**/*.py", - "configs/**/*.py", ] [tool.pytest_env] @@ -153,7 +146,7 @@ langfuse = "^2.36.1" langsmith = "^0.1.77" mailchimp-transactional = "~1.0.50" markdown = "~3.5.1" -novita-client = "^0.5.6" +novita-client = "^0.5.7" numpy = "~1.26.4" openai = "~1.29.0" openpyxl = "~3.1.5" @@ -162,7 +155,7 @@ pandas = { version = "~2.2.2", extras = ["performance", "excel"] } psycopg2-binary = "~2.9.6" pycryptodome = "3.19.1" pydantic = "~2.8.2" -pydantic-settings = "~2.3.4" +pydantic-settings = "~2.4.0" pydantic_extra_types = "~2.9.0" pyjwt = "~2.8.0" pypdfium2 = "~4.17.0" @@ -193,6 +186,9 @@ zhipuai = "1.0.7" # Related transparent dependencies with pinned verion # required by main implementations ############################################################ +azure-ai-ml = "^1.19.0" +azure-ai-inference = "^1.0.0b3" +volcengine-python-sdk = {extras = ["ark"], version = "^1.0.98"} [tool.poetry.group.indriect.dependencies] kaleido = "0.2.1" rank-bm25 = "~0.2.2" @@ -216,7 +212,7 @@ twilio = "~9.0.4" vanna = { version = "0.5.5", extras = ["postgres", "mysql", "clickhouse", "duckdb"] } wikipedia = "1.4.0" yfinance = "~0.2.40" - +nltk = "3.8.1" ############################################################ # VDB dependencies required by vector store clients ############################################################ @@ -245,7 +241,7 @@ optional = true [tool.poetry.group.dev.dependencies] coverage = "~7.2.4" -pytest = "~8.1.1" +pytest = "~8.3.2" pytest-benchmark = "~4.0.0" pytest-env = "~1.1.3" pytest-mock = "~3.14.0" diff --git a/api/services/__init__.py b/api/services/__init__.py index 6891436314..5163862cc1 100644 --- a/api/services/__init__.py +++ b/api/services/__init__.py @@ -1,3 +1,3 @@ from . import errors -__all__ = ['errors'] +__all__ = ["errors"] diff --git a/api/services/account_service.py b/api/services/account_service.py index d73cec2697..e1b70fc9ed 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -39,12 +39,7 @@ from tasks.mail_reset_password_task import send_reset_password_mail_task class AccountService: - - reset_password_rate_limiter = RateLimiter( - prefix="reset_password_rate_limit", - max_attempts=5, - time_window=60 * 60 - ) + reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=5, time_window=60 * 60) @staticmethod def load_user(user_id: str) -> None | Account: @@ -55,12 +50,15 @@ class AccountService: if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]: raise Unauthorized("Account is banned or closed.") - current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first() + current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by( + account_id=account.id, current=True + ).first() if current_tenant: account.current_tenant_id = current_tenant.tenant_id else: - available_ta = TenantAccountJoin.query.filter_by(account_id=account.id) \ - .order_by(TenantAccountJoin.id.asc()).first() + available_ta = ( + TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first() + ) if not available_ta: return None @@ -74,14 +72,13 @@ class AccountService: return account - @staticmethod def get_account_jwt_token(account, *, exp: timedelta = timedelta(days=30)): payload = { "user_id": account.id, "exp": datetime.now(timezone.utc).replace(tzinfo=None) + exp, "iss": dify_config.EDITION, - "sub": 'Console API Passport', + "sub": "Console API Passport", } token = PassportService().issue(payload) @@ -93,10 +90,10 @@ class AccountService: account = Account.query.filter_by(email=email).first() if not account: - raise AccountLoginError('Invalid email or password.') + raise AccountLoginError("Invalid email or password.") if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: - raise AccountLoginError('Account is banned or closed.') + raise AccountLoginError("Account is banned or closed.") if account.status == AccountStatus.PENDING.value: account.status = AccountStatus.ACTIVE.value @@ -104,7 +101,7 @@ class AccountService: db.session.commit() if account.password is None or not compare_password(password, account.password, account.password_salt): - raise AccountLoginError('Invalid email or password.') + raise AccountLoginError("Invalid email or password.") return account @staticmethod @@ -129,11 +126,9 @@ class AccountService: return account @staticmethod - def create_account(email: str, - name: str, - interface_language: str, - password: Optional[str] = None, - interface_theme: str = 'light') -> Account: + def create_account( + email: str, name: str, interface_language: str, password: Optional[str] = None, interface_theme: str = "light" + ) -> Account: """create account""" account = Account() account.email = email @@ -155,7 +150,7 @@ class AccountService: account.interface_theme = interface_theme # Set timezone based on language - account.timezone = language_timezone_mapping.get(interface_language, 'UTC') + account.timezone = language_timezone_mapping.get(interface_language, "UTC") db.session.add(account) db.session.commit() @@ -166,8 +161,9 @@ class AccountService: """Link account integrate""" try: # Query whether there is an existing binding record for the same provider - account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by(account_id=account.id, - provider=provider).first() + account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by( + account_id=account.id, provider=provider + ).first() if account_integrate: # If it exists, update the record @@ -176,15 +172,16 @@ class AccountService: account_integrate.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) else: # If it does not exist, create a new record - account_integrate = AccountIntegrate(account_id=account.id, provider=provider, open_id=open_id, - encrypted_token="") + account_integrate = AccountIntegrate( + account_id=account.id, provider=provider, open_id=open_id, encrypted_token="" + ) db.session.add(account_integrate) db.session.commit() - logging.info(f'Account {account.id} linked {provider} account {open_id}.') + logging.info(f"Account {account.id} linked {provider} account {open_id}.") except Exception as e: - logging.exception(f'Failed to link {provider} account {open_id} to Account {account.id}') - raise LinkAccountIntegrateError('Failed to link account.') from e + logging.exception(f"Failed to link {provider} account {open_id} to Account {account.id}") + raise LinkAccountIntegrateError("Failed to link account.") from e @staticmethod def close_account(account: Account) -> None: @@ -218,7 +215,7 @@ class AccountService: AccountService.update_last_login(account, ip_address=ip_address) exp = timedelta(days=30) token = AccountService.get_account_jwt_token(account, exp=exp) - redis_client.set(_get_login_cache_key(account_id=account.id, token=token), '1', ex=int(exp.total_seconds())) + redis_client.set(_get_login_cache_key(account_id=account.id, token=token), "1", ex=int(exp.total_seconds())) return token @staticmethod @@ -236,22 +233,18 @@ class AccountService: if cls.reset_password_rate_limiter.is_rate_limited(account.email): raise RateLimitExceededError(f"Rate limit exceeded for email: {account.email}. Please try again later.") - token = TokenManager.generate_token(account, 'reset_password') - send_reset_password_mail_task.delay( - language=account.interface_language, - to=account.email, - token=token - ) + token = TokenManager.generate_token(account, "reset_password") + send_reset_password_mail_task.delay(language=account.interface_language, to=account.email, token=token) cls.reset_password_rate_limiter.increment_rate_limit(account.email) return token @classmethod def revoke_reset_password_token(cls, token: str): - TokenManager.revoke_token(token, 'reset_password') + TokenManager.revoke_token(token, "reset_password") @classmethod def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]: - return TokenManager.get_token_data(token, 'reset_password') + return TokenManager.get_token_data(token, "reset_password") def _get_login_cache_key(*, account_id: str, token: str): @@ -259,7 +252,6 @@ def _get_login_cache_key(*, account_id: str, token: str): class TenantService: - @staticmethod def create_tenant(name: str) -> Tenant: """Create tenant""" @@ -273,33 +265,33 @@ class TenantService: return tenant @staticmethod - def create_owner_tenant_if_not_exist(account: Account): + def create_owner_tenant_if_not_exist(account: Account, name: Optional[str] = None): """Create owner tenant if not exist""" - available_ta = TenantAccountJoin.query.filter_by(account_id=account.id) \ - .order_by(TenantAccountJoin.id.asc()).first() + available_ta = ( + TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first() + ) if available_ta: return - tenant = TenantService.create_tenant(f"{account.name}'s Workspace") - TenantService.create_tenant_member(tenant, account, role='owner') + if name: + tenant = TenantService.create_tenant(name) + else: + tenant = TenantService.create_tenant(f"{account.name}'s Workspace") + TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant db.session.commit() tenant_was_created.send(tenant) @staticmethod - def create_tenant_member(tenant: Tenant, account: Account, role: str = 'normal') -> TenantAccountJoin: + def create_tenant_member(tenant: Tenant, account: Account, role: str = "normal") -> TenantAccountJoin: """Create tenant member""" if role == TenantAccountJoinRole.OWNER.value: if TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER]): - logging.error(f'Tenant {tenant.id} has already an owner.') - raise Exception('Tenant already has an owner.') + logging.error(f"Tenant {tenant.id} has already an owner.") + raise Exception("Tenant already has an owner.") - ta = TenantAccountJoin( - tenant_id=tenant.id, - account_id=account.id, - role=role - ) + ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role) db.session.add(ta) db.session.commit() return ta @@ -307,9 +299,12 @@ class TenantService: @staticmethod def get_join_tenants(account: Account) -> list[Tenant]: """Get account join tenants""" - return db.session.query(Tenant).join( - TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id - ).filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL).all() + return ( + db.session.query(Tenant) + .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id) + .filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL) + .all() + ) @staticmethod def get_current_tenant_by_account(account: Account): @@ -333,16 +328,23 @@ class TenantService: if tenant_id is None: raise ValueError("Tenant ID must be provided.") - tenant_account_join = db.session.query(TenantAccountJoin).join(Tenant, TenantAccountJoin.tenant_id == Tenant.id).filter( - TenantAccountJoin.account_id == account.id, - TenantAccountJoin.tenant_id == tenant_id, - Tenant.status == TenantStatus.NORMAL, - ).first() + tenant_account_join = ( + db.session.query(TenantAccountJoin) + .join(Tenant, TenantAccountJoin.tenant_id == Tenant.id) + .filter( + TenantAccountJoin.account_id == account.id, + TenantAccountJoin.tenant_id == tenant_id, + Tenant.status == TenantStatus.NORMAL, + ) + .first() + ) if not tenant_account_join: raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") else: - TenantAccountJoin.query.filter(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id).update({'current': False}) + TenantAccountJoin.query.filter( + TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id + ).update({"current": False}) tenant_account_join.current = True # Set the current tenant for the account account.current_tenant_id = tenant_account_join.tenant_id @@ -354,9 +356,7 @@ class TenantService: query = ( db.session.query(Account, TenantAccountJoin.role) .select_from(Account) - .join( - TenantAccountJoin, Account.id == TenantAccountJoin.account_id - ) + .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) .filter(TenantAccountJoin.tenant_id == tenant.id) ) @@ -375,11 +375,9 @@ class TenantService: query = ( db.session.query(Account, TenantAccountJoin.role) .select_from(Account) - .join( - TenantAccountJoin, Account.id == TenantAccountJoin.account_id - ) + .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) .filter(TenantAccountJoin.tenant_id == tenant.id) - .filter(TenantAccountJoin.role == 'dataset_operator') + .filter(TenantAccountJoin.role == "dataset_operator") ) # Initialize an empty list to store the updated accounts @@ -395,20 +393,25 @@ class TenantService: def has_roles(tenant: Tenant, roles: list[TenantAccountJoinRole]) -> bool: """Check if user has any of the given roles for a tenant""" if not all(isinstance(role, TenantAccountJoinRole) for role in roles): - raise ValueError('all roles must be TenantAccountJoinRole') + raise ValueError("all roles must be TenantAccountJoinRole") - return db.session.query(TenantAccountJoin).filter( - TenantAccountJoin.tenant_id == tenant.id, - TenantAccountJoin.role.in_([role.value for role in roles]) - ).first() is not None + return ( + db.session.query(TenantAccountJoin) + .filter( + TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles]) + ) + .first() + is not None + ) @staticmethod def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountJoinRole]: """Get the role of the current account for a given tenant""" - join = db.session.query(TenantAccountJoin).filter( - TenantAccountJoin.tenant_id == tenant.id, - TenantAccountJoin.account_id == account.id - ).first() + join = ( + db.session.query(TenantAccountJoin) + .filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) + .first() + ) return join.role if join else None @staticmethod @@ -420,29 +423,26 @@ class TenantService: def check_member_permission(tenant: Tenant, operator: Account, member: Account, action: str) -> None: """Check member permission""" perms = { - 'add': [TenantAccountRole.OWNER, TenantAccountRole.ADMIN], - 'remove': [TenantAccountRole.OWNER], - 'update': [TenantAccountRole.OWNER] + "add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN], + "remove": [TenantAccountRole.OWNER], + "update": [TenantAccountRole.OWNER], } - if action not in ['add', 'remove', 'update']: + if action not in ["add", "remove", "update"]: raise InvalidActionError("Invalid action.") if member: if operator.id == member.id: raise CannotOperateSelfError("Cannot operate self.") - ta_operator = TenantAccountJoin.query.filter_by( - tenant_id=tenant.id, - account_id=operator.id - ).first() + ta_operator = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=operator.id).first() if not ta_operator or ta_operator.role not in perms[action]: - raise NoPermissionError(f'No permission to {action} member.') + raise NoPermissionError(f"No permission to {action} member.") @staticmethod def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None: """Remove member from tenant""" - if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, 'remove'): + if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, "remove"): raise CannotOperateSelfError("Cannot operate self.") ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() @@ -455,23 +455,17 @@ class TenantService: @staticmethod def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None: """Update member role""" - TenantService.check_member_permission(tenant, operator, member, 'update') + TenantService.check_member_permission(tenant, operator, member, "update") - target_member_join = TenantAccountJoin.query.filter_by( - tenant_id=tenant.id, - account_id=member.id - ).first() + target_member_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=member.id).first() if target_member_join.role == new_role: raise RoleAlreadyAssignedError("The provided role is already assigned to the member.") - if new_role == 'owner': + if new_role == "owner": # Find the current owner and change their role to 'admin' - current_owner_join = TenantAccountJoin.query.filter_by( - tenant_id=tenant.id, - role='owner' - ).first() - current_owner_join.role = 'admin' + current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first() + current_owner_join.role = "admin" # Update the role of the target member target_member_join.role = new_role @@ -480,8 +474,8 @@ class TenantService: @staticmethod def dissolve_tenant(tenant: Tenant, operator: Account) -> None: """Dissolve tenant""" - if not TenantService.check_member_permission(tenant, operator, operator, 'remove'): - raise NoPermissionError('No permission to dissolve tenant.') + if not TenantService.check_member_permission(tenant, operator, operator, "remove"): + raise NoPermissionError("No permission to dissolve tenant.") db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete() db.session.delete(tenant) db.session.commit() @@ -494,10 +488,9 @@ class TenantService: class RegisterService: - @classmethod def _get_invitation_token_key(cls, token: str) -> str: - return f'member_invite:token:{token}' + return f"member_invite:token:{token}" @classmethod def setup(cls, email: str, name: str, password: str, ip_address: str) -> None: @@ -523,9 +516,7 @@ class RegisterService: TenantService.create_owner_tenant_if_not_exist(account) - dify_setup = DifySetup( - version=dify_config.CURRENT_VERSION - ) + dify_setup = DifySetup(version=dify_config.CURRENT_VERSION) db.session.add(dify_setup) db.session.commit() except Exception as e: @@ -535,34 +526,35 @@ class RegisterService: db.session.query(Tenant).delete() db.session.commit() - logging.exception(f'Setup failed: {e}') - raise ValueError(f'Setup failed: {e}') + logging.exception(f"Setup failed: {e}") + raise ValueError(f"Setup failed: {e}") @classmethod - def register(cls, email, name, - password: Optional[str] = None, - open_id: Optional[str] = None, - provider: Optional[str] = None, - language: Optional[str] = None, - status: Optional[AccountStatus] = None) -> Account: + def register( + cls, + email, + name, + password: Optional[str] = None, + open_id: Optional[str] = None, + provider: Optional[str] = None, + language: Optional[str] = None, + status: Optional[AccountStatus] = None, + ) -> Account: db.session.begin_nested() """Register account""" try: account = AccountService.create_account( - email=email, - name=name, - interface_language=language if language else languages[0], - password=password + email=email, name=name, interface_language=language if language else languages[0], password=password ) account.status = AccountStatus.ACTIVE.value if not status else status.value account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) if open_id is not None or provider is not None: AccountService.link_account_integrate(provider, open_id, account) - if dify_config.EDITION != 'SELF_HOSTED': + if dify_config.EDITION != "SELF_HOSTED": tenant = TenantService.create_tenant(f"{account.name}'s Workspace") - TenantService.create_tenant_member(tenant, account, role='owner') + TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant tenant_was_created.send(tenant) @@ -570,30 +562,29 @@ class RegisterService: db.session.commit() except Exception as e: db.session.rollback() - logging.error(f'Register failed: {e}') - raise AccountRegisterError(f'Registration failed: {e}') from e + logging.error(f"Register failed: {e}") + raise AccountRegisterError(f"Registration failed: {e}") from e return account @classmethod - def invite_new_member(cls, tenant: Tenant, email: str, language: str, role: str = 'normal', inviter: Account = None) -> str: + def invite_new_member( + cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account = None + ) -> str: """Invite new member""" account = Account.query.filter_by(email=email).first() if not account: - TenantService.check_member_permission(tenant, inviter, None, 'add') - name = email.split('@')[0] + TenantService.check_member_permission(tenant, inviter, None, "add") + name = email.split("@")[0] account = cls.register(email=email, name=name, language=language, status=AccountStatus.PENDING) # Create new tenant member for invited tenant TenantService.create_tenant_member(tenant, account, role) TenantService.switch_tenant(account, tenant.id) else: - TenantService.check_member_permission(tenant, inviter, account, 'add') - ta = TenantAccountJoin.query.filter_by( - tenant_id=tenant.id, - account_id=account.id - ).first() + TenantService.check_member_permission(tenant, inviter, account, "add") + ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() if not ta: TenantService.create_tenant_member(tenant, account, role) @@ -609,7 +600,7 @@ class RegisterService: language=account.interface_language, to=email, token=token, - inviter_name=inviter.name if inviter else 'Dify', + inviter_name=inviter.name if inviter else "Dify", workspace_name=tenant.name, ) @@ -619,23 +610,19 @@ class RegisterService: def generate_invite_token(cls, tenant: Tenant, account: Account) -> str: token = str(uuid.uuid4()) invitation_data = { - 'account_id': account.id, - 'email': account.email, - 'workspace_id': tenant.id, + "account_id": account.id, + "email": account.email, + "workspace_id": tenant.id, } expiryHours = dify_config.INVITE_EXPIRY_HOURS - redis_client.setex( - cls._get_invitation_token_key(token), - expiryHours * 60 * 60, - json.dumps(invitation_data) - ) + redis_client.setex(cls._get_invitation_token_key(token), expiryHours * 60 * 60, json.dumps(invitation_data)) return token @classmethod def revoke_token(cls, workspace_id: str, email: str, token: str): if workspace_id and email: email_hash = sha256(email.encode()).hexdigest() - cache_key = 'member_invite_token:{}, {}:{}'.format(workspace_id, email_hash, token) + cache_key = "member_invite_token:{}, {}:{}".format(workspace_id, email_hash, token) redis_client.delete(cache_key) else: redis_client.delete(cls._get_invitation_token_key(token)) @@ -646,17 +633,21 @@ class RegisterService: if not invitation_data: return None - tenant = db.session.query(Tenant).filter( - Tenant.id == invitation_data['workspace_id'], - Tenant.status == 'normal' - ).first() + tenant = ( + db.session.query(Tenant) + .filter(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal") + .first() + ) if not tenant: return None - tenant_account = db.session.query(Account, TenantAccountJoin.role).join( - TenantAccountJoin, Account.id == TenantAccountJoin.account_id - ).filter(Account.email == invitation_data['email'], TenantAccountJoin.tenant_id == tenant.id).first() + tenant_account = ( + db.session.query(Account, TenantAccountJoin.role) + .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) + .filter(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id) + .first() + ) if not tenant_account: return None @@ -665,29 +656,29 @@ class RegisterService: if not account: return None - if invitation_data['account_id'] != str(account.id): + if invitation_data["account_id"] != str(account.id): return None return { - 'account': account, - 'data': invitation_data, - 'tenant': tenant, + "account": account, + "data": invitation_data, + "tenant": tenant, } @classmethod def _get_invitation_by_token(cls, token: str, workspace_id: str, email: str) -> Optional[dict[str, str]]: if workspace_id is not None and email is not None: email_hash = sha256(email.encode()).hexdigest() - cache_key = f'member_invite_token:{workspace_id}, {email_hash}:{token}' + cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}" account_id = redis_client.get(cache_key) if not account_id: return None return { - 'account_id': account_id.decode('utf-8'), - 'email': email, - 'workspace_id': workspace_id, + "account_id": account_id.decode("utf-8"), + "email": email, + "workspace_id": workspace_id, } else: data = redis_client.get(cls._get_invitation_token_key(token)) diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index 213df26222..d2cd7bea67 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -1,4 +1,3 @@ - import copy from core.prompt.prompt_templates.advanced_prompt_templates import ( @@ -17,59 +16,78 @@ from models.model import AppMode class AdvancedPromptTemplateService: - @classmethod def get_prompt(cls, args: dict) -> dict: - app_mode = args['app_mode'] - model_mode = args['model_mode'] - model_name = args['model_name'] - has_context = args['has_context'] + app_mode = args["app_mode"] + model_mode = args["model_mode"] + model_name = args["model_name"] + has_context = args["has_context"] - if 'baichuan' in model_name.lower(): + if "baichuan" in model_name.lower(): return cls.get_baichuan_prompt(app_mode, model_mode, has_context) else: return cls.get_common_prompt(app_mode, model_mode, has_context) @classmethod - def get_common_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict: + def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict: context_prompt = copy.deepcopy(CONTEXT) if app_mode == AppMode.CHAT.value: if model_mode == "completion": - return cls.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt) + return cls.get_completion_prompt( + copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt + ) elif model_mode == "chat": return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt) elif app_mode == AppMode.COMPLETION.value: if model_mode == "completion": - return cls.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt) + return cls.get_completion_prompt( + copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt + ) elif model_mode == "chat": - return cls.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt) - + return cls.get_chat_prompt( + copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt + ) + @classmethod def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict: - if has_context == 'true': - prompt_template['completion_prompt_config']['prompt']['text'] = context + prompt_template['completion_prompt_config']['prompt']['text'] - + if has_context == "true": + prompt_template["completion_prompt_config"]["prompt"]["text"] = ( + context + prompt_template["completion_prompt_config"]["prompt"]["text"] + ) + return prompt_template @classmethod def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict: - if has_context == 'true': - prompt_template['chat_prompt_config']['prompt'][0]['text'] = context + prompt_template['chat_prompt_config']['prompt'][0]['text'] - + if has_context == "true": + prompt_template["chat_prompt_config"]["prompt"][0]["text"] = ( + context + prompt_template["chat_prompt_config"]["prompt"][0]["text"] + ) + return prompt_template @classmethod - def get_baichuan_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict: + def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict: baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT) if app_mode == AppMode.CHAT.value: if model_mode == "completion": - return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt) + return cls.get_completion_prompt( + copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt + ) elif model_mode == "chat": - return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt) + return cls.get_chat_prompt( + copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt + ) elif app_mode == AppMode.COMPLETION.value: if model_mode == "completion": - return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt) + return cls.get_completion_prompt( + copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), + has_context, + baichuan_context_prompt, + ) elif model_mode == "chat": - return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt) \ No newline at end of file + return cls.get_chat_prompt( + copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt + ) diff --git a/api/services/agent_service.py b/api/services/agent_service.py index ba5fd93326..887fb878b9 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -10,59 +10,65 @@ from models.model import App, Conversation, EndUser, Message, MessageAgentThough class AgentService: @classmethod - def get_agent_logs(cls, app_model: App, - conversation_id: str, - message_id: str) -> dict: + def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) -> dict: """ Service to get agent logs """ - conversation: Conversation = db.session.query(Conversation).filter( - Conversation.id == conversation_id, - Conversation.app_id == app_model.id, - ).first() + conversation: Conversation = ( + db.session.query(Conversation) + .filter( + Conversation.id == conversation_id, + Conversation.app_id == app_model.id, + ) + .first() + ) if not conversation: raise ValueError(f"Conversation not found: {conversation_id}") - message: Message = db.session.query(Message).filter( - Message.id == message_id, - Message.conversation_id == conversation_id, - ).first() + message: Message = ( + db.session.query(Message) + .filter( + Message.id == message_id, + Message.conversation_id == conversation_id, + ) + .first() + ) if not message: raise ValueError(f"Message not found: {message_id}") - + agent_thoughts: list[MessageAgentThought] = message.agent_thoughts if conversation.from_end_user_id: # only select name field - executor = db.session.query(EndUser, EndUser.name).filter( - EndUser.id == conversation.from_end_user_id - ).first() + executor = ( + db.session.query(EndUser, EndUser.name).filter(EndUser.id == conversation.from_end_user_id).first() + ) else: - executor = db.session.query(Account, Account.name).filter( - Account.id == conversation.from_account_id - ).first() - + executor = ( + db.session.query(Account, Account.name).filter(Account.id == conversation.from_account_id).first() + ) + if executor: executor = executor.name else: - executor = 'Unknown' + executor = "Unknown" timezone = pytz.timezone(current_user.timezone) result = { - 'meta': { - 'status': 'success', - 'executor': executor, - 'start_time': message.created_at.astimezone(timezone).isoformat(), - 'elapsed_time': message.provider_response_latency, - 'total_tokens': message.answer_tokens + message.message_tokens, - 'agent_mode': app_model.app_model_config.agent_mode_dict.get('strategy', 'react'), - 'iterations': len(agent_thoughts), + "meta": { + "status": "success", + "executor": executor, + "start_time": message.created_at.astimezone(timezone).isoformat(), + "elapsed_time": message.provider_response_latency, + "total_tokens": message.answer_tokens + message.message_tokens, + "agent_mode": app_model.app_model_config.agent_mode_dict.get("strategy", "react"), + "iterations": len(agent_thoughts), }, - 'iterations': [], - 'files': message.files, + "iterations": [], + "files": message.files, } agent_config = AgentConfigManager.convert(app_model.app_model_config.to_dict()) @@ -86,12 +92,12 @@ class AgentService: tool_input = tool_inputs.get(tool_name, {}) tool_output = tool_outputs.get(tool_name, {}) tool_meta_data = tool_meta.get(tool_name, {}) - tool_config = tool_meta_data.get('tool_config', {}) - if tool_config.get('tool_provider_type', '') != 'dataset-retrieval': + tool_config = tool_meta_data.get("tool_config", {}) + if tool_config.get("tool_provider_type", "") != "dataset-retrieval": tool_icon = ToolManager.get_tool_icon( tenant_id=app_model.tenant_id, - provider_type=tool_config.get('tool_provider_type', ''), - provider_id=tool_config.get('tool_provider', ''), + provider_type=tool_config.get("tool_provider_type", ""), + provider_id=tool_config.get("tool_provider", ""), ) if not tool_icon: tool_entity = find_agent_tool(tool_name) @@ -102,30 +108,34 @@ class AgentService: provider_id=tool_entity.provider_id, ) else: - tool_icon = '' + tool_icon = "" - tool_calls.append({ - 'status': 'success' if not tool_meta_data.get('error') else 'error', - 'error': tool_meta_data.get('error'), - 'time_cost': tool_meta_data.get('time_cost', 0), - 'tool_name': tool_name, - 'tool_label': tool_label, - 'tool_input': tool_input, - 'tool_output': tool_output, - 'tool_parameters': tool_meta_data.get('tool_parameters', {}), - 'tool_icon': tool_icon, - }) + tool_calls.append( + { + "status": "success" if not tool_meta_data.get("error") else "error", + "error": tool_meta_data.get("error"), + "time_cost": tool_meta_data.get("time_cost", 0), + "tool_name": tool_name, + "tool_label": tool_label, + "tool_input": tool_input, + "tool_output": tool_output, + "tool_parameters": tool_meta_data.get("tool_parameters", {}), + "tool_icon": tool_icon, + } + ) - result['iterations'].append({ - 'tokens': agent_thought.tokens, - 'tool_calls': tool_calls, - 'tool_raw': { - 'inputs': agent_thought.tool_input, - 'outputs': agent_thought.observation, - }, - 'thought': agent_thought.thought, - 'created_at': agent_thought.created_at.isoformat(), - 'files': agent_thought.files, - }) + result["iterations"].append( + { + "tokens": agent_thought.tokens, + "tool_calls": tool_calls, + "tool_raw": { + "inputs": agent_thought.tool_input, + "outputs": agent_thought.observation, + }, + "thought": agent_thought.thought, + "created_at": agent_thought.created_at.isoformat(), + "files": agent_thought.files, + } + ) - return result \ No newline at end of file + return result diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index addcde44ed..3cc6c51c2d 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -23,21 +23,18 @@ class AppAnnotationService: @classmethod def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation: # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") - if args.get('message_id'): - message_id = str(args['message_id']) + if args.get("message_id"): + message_id = str(args["message_id"]) # get message info - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app.id - ).first() + message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app.id).first() if not message: raise NotFound("Message Not Exists.") @@ -45,159 +42,166 @@ class AppAnnotationService: annotation = message.annotation # save the message annotation if annotation: - annotation.content = args['answer'] - annotation.question = args['question'] + annotation.content = args["answer"] + annotation.question = args["question"] else: annotation = MessageAnnotation( app_id=app.id, conversation_id=message.conversation_id, message_id=message.id, - content=args['answer'], - question=args['question'], - account_id=current_user.id + content=args["answer"], + question=args["question"], + account_id=current_user.id, ) else: annotation = MessageAnnotation( - app_id=app.id, - content=args['answer'], - question=args['question'], - account_id=current_user.id + app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id ) db.session.add(annotation) db.session.commit() # if annotation reply is enabled , add annotation to index - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if annotation_setting: - add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id, - app_id, annotation_setting.collection_binding_id) + add_annotation_to_index_task.delay( + annotation.id, + args["question"], + current_user.current_tenant_id, + app_id, + annotation_setting.collection_binding_id, + ) return annotation @classmethod def enable_app_annotation(cls, args: dict, app_id: str) -> dict: - enable_app_annotation_key = 'enable_app_annotation_{}'.format(str(app_id)) + enable_app_annotation_key = "enable_app_annotation_{}".format(str(app_id)) cache_result = redis_client.get(enable_app_annotation_key) if cache_result is not None: - return { - 'job_id': cache_result, - 'job_status': 'processing' - } + return {"job_id": cache_result, "job_status": "processing"} # async job job_id = str(uuid.uuid4()) - enable_app_annotation_job_key = 'enable_app_annotation_job_{}'.format(str(job_id)) + enable_app_annotation_job_key = "enable_app_annotation_job_{}".format(str(job_id)) # send batch add segments task - redis_client.setnx(enable_app_annotation_job_key, 'waiting') - enable_annotation_reply_task.delay(str(job_id), app_id, current_user.id, current_user.current_tenant_id, - args['score_threshold'], - args['embedding_provider_name'], args['embedding_model_name']) - return { - 'job_id': job_id, - 'job_status': 'waiting' - } + redis_client.setnx(enable_app_annotation_job_key, "waiting") + enable_annotation_reply_task.delay( + str(job_id), + app_id, + current_user.id, + current_user.current_tenant_id, + args["score_threshold"], + args["embedding_provider_name"], + args["embedding_model_name"], + ) + return {"job_id": job_id, "job_status": "waiting"} @classmethod def disable_app_annotation(cls, app_id: str) -> dict: - disable_app_annotation_key = 'disable_app_annotation_{}'.format(str(app_id)) + disable_app_annotation_key = "disable_app_annotation_{}".format(str(app_id)) cache_result = redis_client.get(disable_app_annotation_key) if cache_result is not None: - return { - 'job_id': cache_result, - 'job_status': 'processing' - } + return {"job_id": cache_result, "job_status": "processing"} # async job job_id = str(uuid.uuid4()) - disable_app_annotation_job_key = 'disable_app_annotation_job_{}'.format(str(job_id)) + disable_app_annotation_job_key = "disable_app_annotation_job_{}".format(str(job_id)) # send batch add segments task - redis_client.setnx(disable_app_annotation_job_key, 'waiting') + redis_client.setnx(disable_app_annotation_job_key, "waiting") disable_annotation_reply_task.delay(str(job_id), app_id, current_user.current_tenant_id) - return { - 'job_id': job_id, - 'job_status': 'waiting' - } + return {"job_id": job_id, "job_status": "waiting"} @classmethod def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") if keyword: - annotations = (db.session.query(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) - .filter( - or_( - MessageAnnotation.question.ilike('%{}%'.format(keyword)), - MessageAnnotation.content.ilike('%{}%'.format(keyword)) + annotations = ( + db.session.query(MessageAnnotation) + .filter(MessageAnnotation.app_id == app_id) + .filter( + or_( + MessageAnnotation.question.ilike("%{}%".format(keyword)), + MessageAnnotation.content.ilike("%{}%".format(keyword)), + ) ) + .order_by(MessageAnnotation.created_at.desc()) + .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) ) - .order_by(MessageAnnotation.created_at.desc()) - .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)) else: - annotations = (db.session.query(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) - .order_by(MessageAnnotation.created_at.desc()) - .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)) + annotations = ( + db.session.query(MessageAnnotation) + .filter(MessageAnnotation.app_id == app_id) + .order_by(MessageAnnotation.created_at.desc()) + .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + ) return annotations.items, annotations.total @classmethod def export_annotation_list_by_app_id(cls, app_id: str): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") - annotations = (db.session.query(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) - .order_by(MessageAnnotation.created_at.desc()).all()) + annotations = ( + db.session.query(MessageAnnotation) + .filter(MessageAnnotation.app_id == app_id) + .order_by(MessageAnnotation.created_at.desc()) + .all() + ) return annotations @classmethod def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation: # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") annotation = MessageAnnotation( - app_id=app.id, - content=args['answer'], - question=args['question'], - account_id=current_user.id + app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id ) db.session.add(annotation) db.session.commit() # if annotation reply is enabled , add annotation to index - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if annotation_setting: - add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id, - app_id, annotation_setting.collection_binding_id) + add_annotation_to_index_task.delay( + annotation.id, + args["question"], + current_user.current_tenant_id, + app_id, + annotation_setting.collection_binding_id, + ) return annotation @classmethod def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") @@ -207,30 +211,34 @@ class AppAnnotationService: if not annotation: raise NotFound("Annotation not found") - annotation.content = args['answer'] - annotation.question = args['question'] + annotation.content = args["answer"] + annotation.question = args["question"] db.session.commit() # if annotation reply is enabled , add annotation to index - app_annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id - ).first() + app_annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if app_annotation_setting: - update_annotation_to_index_task.delay(annotation.id, annotation.question, - current_user.current_tenant_id, - app_id, app_annotation_setting.collection_binding_id) + update_annotation_to_index_task.delay( + annotation.id, + annotation.question, + current_user.current_tenant_id, + app_id, + app_annotation_setting.collection_binding_id, + ) return annotation @classmethod def delete_app_annotation(cls, app_id: str, annotation_id: str): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") @@ -242,33 +250,34 @@ class AppAnnotationService: db.session.delete(annotation) - annotation_hit_histories = (db.session.query(AppAnnotationHitHistory) - .filter(AppAnnotationHitHistory.annotation_id == annotation_id) - .all() - ) + annotation_hit_histories = ( + db.session.query(AppAnnotationHitHistory) + .filter(AppAnnotationHitHistory.annotation_id == annotation_id) + .all() + ) if annotation_hit_histories: for annotation_hit_history in annotation_hit_histories: db.session.delete(annotation_hit_history) db.session.commit() # if annotation reply is enabled , delete annotation index - app_annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id - ).first() + app_annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if app_annotation_setting: - delete_annotation_index_task.delay(annotation.id, app_id, - current_user.current_tenant_id, - app_annotation_setting.collection_binding_id) + delete_annotation_index_task.delay( + annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id + ) @classmethod def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict: # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") @@ -278,10 +287,7 @@ class AppAnnotationService: df = pd.read_csv(file) result = [] for index, row in df.iterrows(): - content = { - 'question': row[0], - 'answer': row[1] - } + content = {"question": row[0], "answer": row[1]} result.append(content) if len(result) == 0: raise ValueError("The CSV file is empty.") @@ -293,28 +299,24 @@ class AppAnnotationService: raise ValueError("The number of annotations exceeds the limit of your subscription.") # async job job_id = str(uuid.uuid4()) - indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id)) + indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id)) # send batch add segments task - redis_client.setnx(indexing_cache_key, 'waiting') - batch_import_annotations_task.delay(str(job_id), result, app_id, - current_user.current_tenant_id, current_user.id) + redis_client.setnx(indexing_cache_key, "waiting") + batch_import_annotations_task.delay( + str(job_id), result, app_id, current_user.current_tenant_id, current_user.id + ) except Exception as e: - return { - 'error_msg': str(e) - } - return { - 'job_id': job_id, - 'job_status': 'waiting' - } + return {"error_msg": str(e)} + return {"job_id": job_id, "job_status": "waiting"} @classmethod def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") @@ -324,12 +326,15 @@ class AppAnnotationService: if not annotation: raise NotFound("Annotation not found") - annotation_hit_histories = (db.session.query(AppAnnotationHitHistory) - .filter(AppAnnotationHitHistory.app_id == app_id, - AppAnnotationHitHistory.annotation_id == annotation_id, - ) - .order_by(AppAnnotationHitHistory.created_at.desc()) - .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)) + annotation_hit_histories = ( + db.session.query(AppAnnotationHitHistory) + .filter( + AppAnnotationHitHistory.app_id == app_id, + AppAnnotationHitHistory.annotation_id == annotation_id, + ) + .order_by(AppAnnotationHitHistory.created_at.desc()) + .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + ) return annotation_hit_histories.items, annotation_hit_histories.total @classmethod @@ -341,15 +346,21 @@ class AppAnnotationService: return annotation @classmethod - def add_annotation_history(cls, annotation_id: str, app_id: str, annotation_question: str, - annotation_content: str, query: str, user_id: str, - message_id: str, from_source: str, score: float): + def add_annotation_history( + cls, + annotation_id: str, + app_id: str, + annotation_question: str, + annotation_content: str, + query: str, + user_id: str, + message_id: str, + from_source: str, + score: float, + ): # add hit count to annotation - db.session.query(MessageAnnotation).filter( - MessageAnnotation.id == annotation_id - ).update( - {MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, - synchronize_session=False + db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).update( + {MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, synchronize_session=False ) annotation_hit_history = AppAnnotationHitHistory( @@ -361,7 +372,7 @@ class AppAnnotationService: score=score, message_id=message_id, annotation_question=annotation_question, - annotation_content=annotation_content + annotation_content=annotation_content, ) db.session.add(annotation_hit_history) db.session.commit() @@ -369,17 +380,18 @@ class AppAnnotationService: @classmethod def get_app_annotation_setting_by_app_id(cls, app_id: str): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if annotation_setting: collection_binding_detail = annotation_setting.collection_binding_detail return { @@ -388,32 +400,34 @@ class AppAnnotationService: "score_threshold": annotation_setting.score_threshold, "embedding_model": { "embedding_provider_name": collection_binding_detail.provider_name, - "embedding_model_name": collection_binding_detail.model_name - } + "embedding_model_name": collection_binding_detail.model_name, + }, } - return { - "enabled": False - } + return {"enabled": False} @classmethod def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id, - AppAnnotationSetting.id == annotation_setting_id, - ).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting) + .filter( + AppAnnotationSetting.app_id == app_id, + AppAnnotationSetting.id == annotation_setting_id, + ) + .first() + ) if not annotation_setting: raise NotFound("App annotation not found") - annotation_setting.score_threshold = args['score_threshold'] + annotation_setting.score_threshold = args["score_threshold"] annotation_setting.updated_user_id = current_user.id annotation_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(annotation_setting) @@ -427,6 +441,6 @@ class AppAnnotationService: "score_threshold": annotation_setting.score_threshold, "embedding_model": { "embedding_provider_name": collection_binding_detail.provider_name, - "embedding_model_name": collection_binding_detail.model_name - } + "embedding_model_name": collection_binding_detail.model_name, + }, } diff --git a/api/services/api_based_extension_service.py b/api/services/api_based_extension_service.py index 8441bbedb3..601d67d2fb 100644 --- a/api/services/api_based_extension_service.py +++ b/api/services/api_based_extension_service.py @@ -5,13 +5,14 @@ from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint class APIBasedExtensionService: - @staticmethod def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]: - extension_list = db.session.query(APIBasedExtension) \ - .filter_by(tenant_id=tenant_id) \ - .order_by(APIBasedExtension.created_at.desc()) \ - .all() + extension_list = ( + db.session.query(APIBasedExtension) + .filter_by(tenant_id=tenant_id) + .order_by(APIBasedExtension.created_at.desc()) + .all() + ) for extension in extension_list: extension.api_key = decrypt_token(extension.tenant_id, extension.api_key) @@ -35,10 +36,12 @@ class APIBasedExtensionService: @staticmethod def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: - extension = db.session.query(APIBasedExtension) \ - .filter_by(tenant_id=tenant_id) \ - .filter_by(id=api_based_extension_id) \ + extension = ( + db.session.query(APIBasedExtension) + .filter_by(tenant_id=tenant_id) + .filter_by(id=api_based_extension_id) .first() + ) if not extension: raise ValueError("API based extension is not found") @@ -55,20 +58,24 @@ class APIBasedExtensionService: if not extension_data.id: # case one: check new data, name must be unique - is_name_existed = db.session.query(APIBasedExtension) \ - .filter_by(tenant_id=extension_data.tenant_id) \ - .filter_by(name=extension_data.name) \ + is_name_existed = ( + db.session.query(APIBasedExtension) + .filter_by(tenant_id=extension_data.tenant_id) + .filter_by(name=extension_data.name) .first() + ) if is_name_existed: raise ValueError("name must be unique, it is already existed") else: # case two: check existing data, name must be unique - is_name_existed = db.session.query(APIBasedExtension) \ - .filter_by(tenant_id=extension_data.tenant_id) \ - .filter_by(name=extension_data.name) \ - .filter(APIBasedExtension.id != extension_data.id) \ + is_name_existed = ( + db.session.query(APIBasedExtension) + .filter_by(tenant_id=extension_data.tenant_id) + .filter_by(name=extension_data.name) + .filter(APIBasedExtension.id != extension_data.id) .first() + ) if is_name_existed: raise ValueError("name must be unique, it is already existed") @@ -92,7 +99,7 @@ class APIBasedExtensionService: try: client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key) resp = client.request(point=APIBasedExtensionPoint.PING, params={}) - if resp.get('result') != 'pong': + if resp.get("result") != "pong": raise ValueError(resp) except Exception as e: raise ValueError("connection error: {}".format(e)) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 737def3366..895855a9c8 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -75,43 +75,47 @@ class AppDslService: # check or repair dsl version import_data = cls._check_or_fix_dsl(import_data) - app_data = import_data.get('app') + app_data = import_data.get("app") if not app_data: raise ValueError("Missing app in data argument") # get app basic info - name = args.get("name") if args.get("name") else app_data.get('name') - description = args.get("description") if args.get("description") else app_data.get('description', '') - icon_type = args.get("icon_type") if args.get("icon_type") else app_data.get('icon_type') - icon = args.get("icon") if args.get("icon") else app_data.get('icon') - icon_background = args.get("icon_background") if args.get("icon_background") \ - else app_data.get('icon_background') + name = args.get("name") if args.get("name") else app_data.get("name") + description = args.get("description") if args.get("description") else app_data.get("description", "") + icon_type = args.get("icon_type") if args.get("icon_type") else app_data.get("icon_type") + icon = args.get("icon") if args.get("icon") else app_data.get("icon") + icon_background = ( + args.get("icon_background") if args.get("icon_background") else app_data.get("icon_background") + ) + use_icon_as_answer_icon = app_data.get("use_icon_as_answer_icon", False) # import dsl and create app - app_mode = AppMode.value_of(app_data.get('mode')) + app_mode = AppMode.value_of(app_data.get("mode")) if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: app = cls._import_and_create_new_workflow_based_app( tenant_id=tenant_id, app_mode=app_mode, - workflow_data=import_data.get('workflow'), + workflow_data=import_data.get("workflow"), account=account, name=name, description=description, icon_type=icon_type, icon=icon, - icon_background=icon_background + icon_background=icon_background, + use_icon_as_answer_icon=use_icon_as_answer_icon, ) elif app_mode in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]: app = cls._import_and_create_new_model_config_based_app( tenant_id=tenant_id, app_mode=app_mode, - model_config_data=import_data.get('model_config'), + model_config_data=import_data.get("model_config"), account=account, name=name, description=description, icon_type=icon_type, icon=icon, - icon_background=icon_background + icon_background=icon_background, + use_icon_as_answer_icon=use_icon_as_answer_icon, ) else: raise ValueError("Invalid app mode") @@ -134,27 +138,26 @@ class AppDslService: # check or repair dsl version import_data = cls._check_or_fix_dsl(import_data) - app_data = import_data.get('app') + app_data = import_data.get("app") if not app_data: raise ValueError("Missing app in data argument") # import dsl and overwrite app - app_mode = AppMode.value_of(app_data.get('mode')) + app_mode = AppMode.value_of(app_data.get("mode")) if app_mode not in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: raise ValueError("Only support import workflow in advanced-chat or workflow app.") - if app_data.get('mode') != app_model.mode: - raise ValueError( - f"App mode {app_data.get('mode')} is not matched with current app mode {app_mode.value}") + if app_data.get("mode") != app_model.mode: + raise ValueError(f"App mode {app_data.get('mode')} is not matched with current app mode {app_mode.value}") return cls._import_and_overwrite_workflow_based_app( app_model=app_model, - workflow_data=import_data.get('workflow'), + workflow_data=import_data.get("workflow"), account=account, ) @classmethod - def export_dsl(cls, app_model: App, include_secret:bool = False) -> str: + def export_dsl(cls, app_model: App, include_secret: bool = False) -> str: """ Export app :param app_model: App instance @@ -168,14 +171,17 @@ class AppDslService: "app": { "name": app_model.name, "mode": app_model.mode, - "icon": '🤖' if app_model.icon_type == 'image' else app_model.icon, - "icon_background": '#FFEAD5' if app_model.icon_type == 'image' else app_model.icon_background, - "description": app_model.description - } + "icon": "🤖" if app_model.icon_type == "image" else app_model.icon, + "icon_background": "#FFEAD5" if app_model.icon_type == "image" else app_model.icon_background, + "description": app_model.description, + "use_icon_as_answer_icon": app_model.use_icon_as_answer_icon, + }, } if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: - cls._append_workflow_export_data(export_data=export_data, app_model=app_model, include_secret=include_secret) + cls._append_workflow_export_data( + export_data=export_data, app_model=app_model, include_secret=include_secret + ) else: cls._append_model_config_export_data(export_data, app_model) @@ -188,31 +194,36 @@ class AppDslService: :param import_data: import data """ - if not import_data.get('version'): - import_data['version'] = "0.1.0" + if not import_data.get("version"): + import_data["version"] = "0.1.0" - if not import_data.get('kind') or import_data.get('kind') != "app": - import_data['kind'] = "app" + if not import_data.get("kind") or import_data.get("kind") != "app": + import_data["kind"] = "app" - if import_data.get('version') != current_dsl_version: + if import_data.get("version") != current_dsl_version: # Currently only one DSL version, so no difference checks or compatibility fixes will be performed. - logger.warning(f"DSL version {import_data.get('version')} is not compatible " - f"with current version {current_dsl_version}, related to " - f"Dify version {dsl_to_dify_version_mapping.get(current_dsl_version)}.") + logger.warning( + f"DSL version {import_data.get('version')} is not compatible " + f"with current version {current_dsl_version}, related to " + f"Dify version {dsl_to_dify_version_mapping.get(current_dsl_version)}." + ) return import_data @classmethod - def _import_and_create_new_workflow_based_app(cls, - tenant_id: str, - app_mode: AppMode, - workflow_data: dict, - account: Account, - name: str, - description: str, - icon_type: str, - icon: str, - icon_background: str) -> App: + def _import_and_create_new_workflow_based_app( + cls, + tenant_id: str, + app_mode: AppMode, + workflow_data: dict, + account: Account, + name: str, + description: str, + icon_type: str, + icon: str, + icon_background: str, + use_icon_as_answer_icon: bool, + ) -> App: """ Import app dsl and create new workflow based app @@ -225,10 +236,10 @@ class AppDslService: :param icon_type: app icon type, "emoji" or "image" :param icon: app icon :param icon_background: app icon background + :param use_icon_as_answer_icon: use app icon as answer icon """ if not workflow_data: - raise ValueError("Missing workflow in data argument " - "when app mode is advanced-chat or workflow") + raise ValueError("Missing workflow in data argument " "when app mode is advanced-chat or workflow") app = cls._create_app( tenant_id=tenant_id, @@ -238,37 +249,33 @@ class AppDslService: description=description, icon_type=icon_type, icon=icon, - icon_background=icon_background + icon_background=icon_background, + use_icon_as_answer_icon=use_icon_as_answer_icon, ) # init draft workflow - environment_variables_list = workflow_data.get('environment_variables') or [] + environment_variables_list = workflow_data.get("environment_variables") or [] environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] - conversation_variables_list = workflow_data.get('conversation_variables') or [] + conversation_variables_list = workflow_data.get("conversation_variables") or [] conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] workflow_service = WorkflowService() draft_workflow = workflow_service.sync_draft_workflow( app_model=app, - graph=workflow_data.get('graph', {}), - features=workflow_data.get('../core/app/features', {}), + graph=workflow_data.get("graph", {}), + features=workflow_data.get("../core/app/features", {}), unique_hash=None, account=account, environment_variables=environment_variables, conversation_variables=conversation_variables, ) - workflow_service.publish_workflow( - app_model=app, - account=account, - draft_workflow=draft_workflow - ) + workflow_service.publish_workflow(app_model=app, account=account, draft_workflow=draft_workflow) return app @classmethod - def _import_and_overwrite_workflow_based_app(cls, - app_model: App, - workflow_data: dict, - account: Account) -> Workflow: + def _import_and_overwrite_workflow_based_app( + cls, app_model: App, workflow_data: dict, account: Account + ) -> Workflow: """ Import app dsl and overwrite workflow based app @@ -277,8 +284,7 @@ class AppDslService: :param account: Account instance """ if not workflow_data: - raise ValueError("Missing workflow in data argument " - "when app mode is advanced-chat or workflow") + raise ValueError("Missing workflow in data argument " "when app mode is advanced-chat or workflow") # fetch draft workflow by app_model workflow_service = WorkflowService() @@ -289,14 +295,14 @@ class AppDslService: unique_hash = None # sync draft workflow - environment_variables_list = workflow_data.get('environment_variables') or [] + environment_variables_list = workflow_data.get("environment_variables") or [] environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] - conversation_variables_list = workflow_data.get('conversation_variables') or [] + conversation_variables_list = workflow_data.get("conversation_variables") or [] conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] draft_workflow = workflow_service.sync_draft_workflow( app_model=app_model, - graph=workflow_data.get('graph', {}), - features=workflow_data.get('features', {}), + graph=workflow_data.get("graph", {}), + features=workflow_data.get("features", {}), unique_hash=unique_hash, account=account, environment_variables=environment_variables, @@ -306,16 +312,19 @@ class AppDslService: return draft_workflow @classmethod - def _import_and_create_new_model_config_based_app(cls, - tenant_id: str, - app_mode: AppMode, - model_config_data: dict, - account: Account, - name: str, - description: str, - icon_type: str, - icon: str, - icon_background: str) -> App: + def _import_and_create_new_model_config_based_app( + cls, + tenant_id: str, + app_mode: AppMode, + model_config_data: dict, + account: Account, + name: str, + description: str, + icon_type: str, + icon: str, + icon_background: str, + use_icon_as_answer_icon: bool, + ) -> App: """ Import app dsl and create new model config based app @@ -329,8 +338,7 @@ class AppDslService: :param icon_background: app icon background """ if not model_config_data: - raise ValueError("Missing model_config in data argument " - "when app mode is chat, agent-chat or completion") + raise ValueError("Missing model_config in data argument " "when app mode is chat, agent-chat or completion") app = cls._create_app( tenant_id=tenant_id, @@ -340,35 +348,38 @@ class AppDslService: description=description, icon_type=icon_type, icon=icon, - icon_background=icon_background + icon_background=icon_background, + use_icon_as_answer_icon=use_icon_as_answer_icon, ) app_model_config = AppModelConfig() app_model_config = app_model_config.from_model_config_dict(model_config_data) app_model_config.app_id = app.id + app_model_config.created_by = account.id + app_model_config.updated_by = account.id db.session.add(app_model_config) db.session.commit() app.app_model_config_id = app_model_config.id - app_model_config_was_updated.send( - app, - app_model_config=app_model_config - ) + app_model_config_was_updated.send(app, app_model_config=app_model_config) return app @classmethod - def _create_app(cls, - tenant_id: str, - app_mode: AppMode, - account: Account, - name: str, - description: str, - icon_type: str, - icon: str, - icon_background: str) -> App: + def _create_app( + cls, + tenant_id: str, + app_mode: AppMode, + account: Account, + name: str, + description: str, + icon_type: str, + icon: str, + icon_background: str, + use_icon_as_answer_icon: bool, + ) -> App: """ Create new app @@ -380,6 +391,7 @@ class AppDslService: :param icon_type: app icon type, "emoji" or "image" :param icon: app icon :param icon_background: app icon background + :param use_icon_as_answer_icon: use app icon as answer icon """ app = App( tenant_id=tenant_id, @@ -390,7 +402,10 @@ class AppDslService: icon=icon, icon_background=icon_background, enable_site=True, - enable_api=True + enable_api=True, + use_icon_as_answer_icon=use_icon_as_answer_icon, + created_by=account.id, + updated_by=account.id, ) db.session.add(app) @@ -412,7 +427,7 @@ class AppDslService: if not workflow: raise ValueError("Missing draft workflow configuration, please check.") - export_data['workflow'] = workflow.to_dict(include_secret=include_secret) + export_data["workflow"] = workflow.to_dict(include_secret=include_secret) @classmethod def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> None: @@ -425,4 +440,4 @@ class AppDslService: if not app_model_config: raise ValueError("Missing app configuration, please check.") - export_data['model_config'] = app_model_config.to_dict() + export_data["model_config"] = app_model_config.to_dict() diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index cff4ba8af9..747505977f 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -1,6 +1,8 @@ from collections.abc import Generator from typing import Any, Union +from openai._exceptions import RateLimitError + from configs import dify_config from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator @@ -10,18 +12,20 @@ from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.entities.app_invoke_entities import InvokeFrom from core.app.features.rate_limiting import RateLimit from models.model import Account, App, AppMode, EndUser +from services.errors.llm import InvokeRateLimitError from services.workflow_service import WorkflowService class AppGenerateService: - @classmethod - def generate(cls, app_model: App, - user: Union[Account, EndUser], - args: Any, - invoke_from: InvokeFrom, - streaming: bool = True, - ): + def generate( + cls, + app_model: App, + user: Union[Account, EndUser], + args: Any, + invoke_from: InvokeFrom, + streaming: bool = True, + ): """ App Content Generate :param app_model: app model @@ -37,51 +41,56 @@ class AppGenerateService: try: request_id = rate_limit.enter(request_id) if app_model.mode == AppMode.COMPLETION.value: - return rate_limit.generate(CompletionAppGenerator().generate( - app_model=app_model, - user=user, - args=args, - invoke_from=invoke_from, - stream=streaming - ), request_id) + return rate_limit.generate( + CompletionAppGenerator().generate( + app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming + ), + request_id, + ) elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: - return rate_limit.generate(AgentChatAppGenerator().generate( - app_model=app_model, - user=user, - args=args, - invoke_from=invoke_from, - stream=streaming - ), request_id) + return rate_limit.generate( + AgentChatAppGenerator().generate( + app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming + ), + request_id, + ) elif app_model.mode == AppMode.CHAT.value: - return rate_limit.generate(ChatAppGenerator().generate( - app_model=app_model, - user=user, - args=args, - invoke_from=invoke_from, - stream=streaming - ), request_id) + return rate_limit.generate( + ChatAppGenerator().generate( + app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming + ), + request_id, + ) elif app_model.mode == AppMode.ADVANCED_CHAT.value: workflow = cls._get_workflow(app_model, invoke_from) - return rate_limit.generate(AdvancedChatAppGenerator().generate( - app_model=app_model, - workflow=workflow, - user=user, - args=args, - invoke_from=invoke_from, - stream=streaming - ), request_id) + return rate_limit.generate( + AdvancedChatAppGenerator().generate( + app_model=app_model, + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + stream=streaming, + ), + request_id, + ) elif app_model.mode == AppMode.WORKFLOW.value: workflow = cls._get_workflow(app_model, invoke_from) - return rate_limit.generate(WorkflowAppGenerator().generate( - app_model=app_model, - workflow=workflow, - user=user, - args=args, - invoke_from=invoke_from, - stream=streaming - ), request_id) + return rate_limit.generate( + WorkflowAppGenerator().generate( + app_model=app_model, + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + stream=streaming, + ), + request_id, + ) else: - raise ValueError(f'Invalid app mode {app_model.mode}') + raise ValueError(f"Invalid app mode {app_model.mode}") + except RateLimitError as e: + raise InvokeRateLimitError(str(e)) finally: if not streaming: rate_limit.exit(request_id) @@ -94,38 +103,31 @@ class AppGenerateService: return max_active_requests @classmethod - def generate_single_iteration(cls, app_model: App, - user: Union[Account, EndUser], - node_id: str, - args: Any, - streaming: bool = True): + def generate_single_iteration( + cls, app_model: App, user: Union[Account, EndUser], node_id: str, args: Any, streaming: bool = True + ): if app_model.mode == AppMode.ADVANCED_CHAT.value: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) return AdvancedChatAppGenerator().single_iteration_generate( - app_model=app_model, - workflow=workflow, - node_id=node_id, - user=user, - args=args, - stream=streaming + app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, stream=streaming ) elif app_model.mode == AppMode.WORKFLOW.value: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) return WorkflowAppGenerator().single_iteration_generate( - app_model=app_model, - workflow=workflow, - node_id=node_id, - user=user, - args=args, - stream=streaming + app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, stream=streaming ) else: - raise ValueError(f'Invalid app mode {app_model.mode}') + raise ValueError(f"Invalid app mode {app_model.mode}") @classmethod - def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser], - message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \ - -> Union[dict, Generator]: + def generate_more_like_this( + cls, + app_model: App, + user: Union[Account, EndUser], + message_id: str, + invoke_from: InvokeFrom, + streaming: bool = True, + ) -> Union[dict, Generator]: """ Generate more like this :param app_model: app model @@ -136,11 +138,7 @@ class AppGenerateService: :return: """ return CompletionAppGenerator().generate_more_like_this( - app_model=app_model, - message_id=message_id, - user=user, - invoke_from=invoke_from, - stream=streaming + app_model=app_model, message_id=message_id, user=user, invoke_from=invoke_from, stream=streaming ) @classmethod @@ -157,12 +155,12 @@ class AppGenerateService: workflow = workflow_service.get_draft_workflow(app_model=app_model) if not workflow: - raise ValueError('Workflow not initialized') + raise ValueError("Workflow not initialized") else: # fetch published workflow by app_model workflow = workflow_service.get_published_workflow(app_model=app_model) if not workflow: - raise ValueError('Workflow not published') + raise ValueError("Workflow not published") return workflow diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index c84f6fbf45..a1ad271053 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -5,7 +5,6 @@ from models.model import AppMode class AppModelConfigService: - @classmethod def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict: if app_mode == AppMode.CHAT: diff --git a/api/services/app_service.py b/api/services/app_service.py index f33ef9b001..1dacfea246 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -33,27 +33,22 @@ class AppService: :param args: request args :return: """ - filters = [ - App.tenant_id == tenant_id, - App.is_universal == False - ] + filters = [App.tenant_id == tenant_id, App.is_universal == False] - if args['mode'] == 'workflow': + if args["mode"] == "workflow": filters.append(App.mode.in_([AppMode.WORKFLOW.value, AppMode.COMPLETION.value])) - elif args['mode'] == 'chat': + elif args["mode"] == "chat": filters.append(App.mode.in_([AppMode.CHAT.value, AppMode.ADVANCED_CHAT.value])) - elif args['mode'] == 'agent-chat': + elif args["mode"] == "agent-chat": filters.append(App.mode == AppMode.AGENT_CHAT.value) - elif args['mode'] == 'channel': + elif args["mode"] == "channel": filters.append(App.mode == AppMode.CHANNEL.value) - if args.get('name'): - name = args['name'][:30] - filters.append(App.name.ilike(f'%{name}%')) - if args.get('tag_ids'): - target_ids = TagService.get_target_ids_by_tag_ids('app', - tenant_id, - args['tag_ids']) + if args.get("name"): + name = args["name"][:30] + filters.append(App.name.ilike(f"%{name}%")) + if args.get("tag_ids"): + target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"]) if target_ids: filters.append(App.id.in_(target_ids)) else: @@ -61,9 +56,9 @@ class AppService: app_models = db.paginate( db.select(App).where(*filters).order_by(App.created_at.desc()), - page=args['page'], - per_page=args['limit'], - error_out=False + page=args["page"], + per_page=args["limit"], + error_out=False, ) return app_models @@ -75,21 +70,20 @@ class AppService: :param args: request args :param account: Account instance """ - app_mode = AppMode.value_of(args['mode']) + app_mode = AppMode.value_of(args["mode"]) app_template = default_app_templates[app_mode] # get model config - default_model_config = app_template.get('model_config') + default_model_config = app_template.get("model_config") default_model_config = default_model_config.copy() if default_model_config else None - if default_model_config and 'model' in default_model_config: + if default_model_config and "model" in default_model_config: # get model provider model_manager = ModelManager() # get default model instance try: model_instance = model_manager.get_default_model_instance( - tenant_id=account.current_tenant_id, - model_type=ModelType.LLM + tenant_id=account.current_tenant_id, model_type=ModelType.LLM ) except (ProviderTokenNotInitError, LLMBadRequestError): model_instance = None @@ -98,33 +92,43 @@ class AppService: model_instance = None if model_instance: - if model_instance.model == default_model_config['model']['name'] and model_instance.provider == default_model_config['model']['provider']: - default_model_dict = default_model_config['model'] + if ( + model_instance.model == default_model_config["model"]["name"] + and model_instance.provider == default_model_config["model"]["provider"] + ): + default_model_dict = default_model_config["model"] else: llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) default_model_dict = { - 'provider': model_instance.provider, - 'name': model_instance.model, - 'mode': model_schema.model_properties.get(ModelPropertyKey.MODE), - 'completion_params': {} + "provider": model_instance.provider, + "name": model_instance.model, + "mode": model_schema.model_properties.get(ModelPropertyKey.MODE), + "completion_params": {}, } else: - default_model_dict = default_model_config['model'] + provider, model = model_manager.get_default_provider_model_name( + tenant_id=account.current_tenant_id, model_type=ModelType.LLM + ) + default_model_config["model"]["provider"] = provider + default_model_config["model"]["name"] = model + default_model_dict = default_model_config["model"] - default_model_config['model'] = json.dumps(default_model_dict) + default_model_config["model"] = json.dumps(default_model_dict) - app = App(**app_template['app']) - app.name = args['name'] - app.description = args.get('description', '') - app.mode = args['mode'] - app.icon_type = args.get('icon_type', 'emoji') - app.icon = args['icon'] - app.icon_background = args['icon_background'] + app = App(**app_template["app"]) + app.name = args["name"] + app.description = args.get("description", "") + app.mode = args["mode"] + app.icon_type = args.get("icon_type", "emoji") + app.icon = args["icon"] + app.icon_background = args["icon_background"] app.tenant_id = tenant_id - app.api_rph = args.get('api_rph', 0) - app.api_rpm = args.get('api_rpm', 0) + app.api_rph = args.get("api_rph", 0) + app.api_rpm = args.get("api_rpm", 0) + app.created_by = account.id + app.updated_by = account.id db.session.add(app) db.session.flush() @@ -132,6 +136,8 @@ class AppService: if default_model_config: app_model_config = AppModelConfig(**default_model_config) app_model_config.app_id = app.id + app_model_config.created_by = account.id + app_model_config.updated_by = account.id db.session.add(app_model_config) db.session.flush() @@ -152,7 +158,7 @@ class AppService: model_config: AppModelConfig = app.app_model_config agent_mode = model_config.agent_mode_dict # decrypt agent tool parameters if it's secret-input - for tool in agent_mode.get('tools') or []: + for tool in agent_mode.get("tools") or []: if not isinstance(tool, dict) or len(tool.keys()) <= 3: continue agent_tool_entity = AgentToolEntity(**tool) @@ -168,7 +174,7 @@ class AppService: tool_runtime=tool_runtime, provider_name=agent_tool_entity.provider_id, provider_type=agent_tool_entity.provider_type, - identity_id=f'AGENT.{app.id}' + identity_id=f"AGENT.{app.id}", ) # get decrypted parameters @@ -179,7 +185,7 @@ class AppService: masked_parameter = {} # override tool parameters - tool['tool_parameters'] = masked_parameter + tool["tool_parameters"] = masked_parameter except Exception as e: pass @@ -190,13 +196,14 @@ class AppService: """ Modified App class """ + def __init__(self, app): self.__dict__.update(app.__dict__) @property def app_model_config(self): return model_config - + app = ModifiedApp(app) return app @@ -208,12 +215,14 @@ class AppService: :param args: request args :return: App instance """ - app.name = args.get('name') - app.description = args.get('description', '') - app.max_active_requests = args.get('max_active_requests') - app.icon_type = args.get('icon_type', 'emoji') - app.icon = args.get('icon') - app.icon_background = args.get('icon_background') + app.name = args.get("name") + app.description = args.get("description", "") + app.max_active_requests = args.get("max_active_requests") + app.icon_type = args.get("icon_type", "emoji") + app.icon = args.get("icon") + app.icon_background = args.get("icon_background") + app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False) + app.updated_by = current_user.id app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() @@ -230,6 +239,7 @@ class AppService: :return: App instance """ app.name = name + app.updated_by = current_user.id app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() @@ -245,6 +255,7 @@ class AppService: """ app.icon = icon app.icon_background = icon_background + app.updated_by = current_user.id app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() @@ -261,6 +272,7 @@ class AppService: return app app.enable_site = enable_site + app.updated_by = current_user.id app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() @@ -277,6 +289,7 @@ class AppService: return app app.enable_api = enable_api + app.updated_by = current_user.id app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() @@ -291,10 +304,7 @@ class AppService: db.session.commit() # Trigger asynchronous deletion of app and related data - remove_app_and_related_data_task.delay( - tenant_id=app.tenant_id, - app_id=app.id - ) + remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id) def get_app_meta(self, app_model: App) -> dict: """ @@ -304,9 +314,7 @@ class AppService: """ app_mode = AppMode.value_of(app_model.mode) - meta = { - 'tool_icons': {} - } + meta = {"tool_icons": {}} if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: workflow = app_model.workflow @@ -314,17 +322,19 @@ class AppService: return meta graph = workflow.graph_dict - nodes = graph.get('nodes', []) + nodes = graph.get("nodes", []) tools = [] for node in nodes: - if node.get('data', {}).get('type') == 'tool': - node_data = node.get('data', {}) - tools.append({ - 'provider_type': node_data.get('provider_type'), - 'provider_id': node_data.get('provider_id'), - 'tool_name': node_data.get('tool_name'), - 'tool_parameters': {} - }) + if node.get("data", {}).get("type") == "tool": + node_data = node.get("data", {}) + tools.append( + { + "provider_type": node_data.get("provider_type"), + "provider_id": node_data.get("provider_id"), + "tool_name": node_data.get("tool_name"), + "tool_parameters": {}, + } + ) else: app_model_config: AppModelConfig = app_model.app_model_config @@ -334,30 +344,26 @@ class AppService: agent_config = app_model_config.agent_mode_dict or {} # get all tools - tools = agent_config.get('tools', []) + tools = agent_config.get("tools", []) - url_prefix = (dify_config.CONSOLE_API_URL - + "/console/api/workspaces/current/tool-provider/builtin/") + url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/" for tool in tools: keys = list(tool.keys()) if len(keys) >= 4: # current tool standard - provider_type = tool.get('provider_type') - provider_id = tool.get('provider_id') - tool_name = tool.get('tool_name') - if provider_type == 'builtin': - meta['tool_icons'][tool_name] = url_prefix + provider_id + '/icon' - elif provider_type == 'api': + provider_type = tool.get("provider_type") + provider_id = tool.get("provider_id") + tool_name = tool.get("tool_name") + if provider_type == "builtin": + meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon" + elif provider_type == "api": try: - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.id == provider_id - ).first() - meta['tool_icons'][tool_name] = json.loads(provider.icon) + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider).filter(ApiToolProvider.id == provider_id).first() + ) + meta["tool_icons"][tool_name] = json.loads(provider.icon) except: - meta['tool_icons'][tool_name] = { - "background": "#252525", - "content": "\ud83d\ude01" - } + meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"} return meta diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 58c950816f..05cd1c96a1 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -17,7 +17,7 @@ from services.errors.audio import ( FILE_SIZE = 30 FILE_SIZE_LIMIT = FILE_SIZE * 1024 * 1024 -ALLOWED_EXTENSIONS = ['mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'wav', 'webm', 'amr'] +ALLOWED_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm", "amr"] logger = logging.getLogger(__name__) @@ -31,19 +31,19 @@ class AudioService: raise ValueError("Speech to text is not enabled") features_dict = workflow.features_dict - if 'speech_to_text' not in features_dict or not features_dict['speech_to_text'].get('enabled'): + if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"): raise ValueError("Speech to text is not enabled") else: app_model_config: AppModelConfig = app_model.app_model_config - if not app_model_config.speech_to_text_dict['enabled']: + if not app_model_config.speech_to_text_dict["enabled"]: raise ValueError("Speech to text is not enabled") if file is None: raise NoAudioUploadedServiceError() extension = file.mimetype - if extension not in [f'audio/{ext}' for ext in ALLOWED_EXTENSIONS]: + if extension not in [f"audio/{ext}" for ext in ALLOWED_EXTENSIONS]: raise UnsupportedAudioTypeServiceError() file_content = file.read() @@ -55,20 +55,25 @@ class AudioService: model_manager = ModelManager() model_instance = model_manager.get_default_model_instance( - tenant_id=app_model.tenant_id, - model_type=ModelType.SPEECH2TEXT + tenant_id=app_model.tenant_id, model_type=ModelType.SPEECH2TEXT ) if model_instance is None: raise ProviderNotSupportSpeechToTextServiceError() buffer = io.BytesIO(file_content) - buffer.name = 'temp.mp3' + buffer.name = "temp.mp3" return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)} @classmethod - def transcript_tts(cls, app_model: App, text: Optional[str] = None, - voice: Optional[str] = None, end_user: Optional[str] = None, message_id: Optional[str] = None): + def transcript_tts( + cls, + app_model: App, + text: Optional[str] = None, + voice: Optional[str] = None, + end_user: Optional[str] = None, + message_id: Optional[str] = None, + ): from collections.abc import Generator from flask import Response, stream_with_context @@ -84,65 +89,56 @@ class AudioService: raise ValueError("TTS is not enabled") features_dict = workflow.features_dict - if 'text_to_speech' not in features_dict or not features_dict['text_to_speech'].get('enabled'): + if "text_to_speech" not in features_dict or not features_dict["text_to_speech"].get("enabled"): raise ValueError("TTS is not enabled") - voice = features_dict['text_to_speech'].get('voice') if voice is None else voice + voice = features_dict["text_to_speech"].get("voice") if voice is None else voice else: text_to_speech_dict = app_model.app_model_config.text_to_speech_dict - if not text_to_speech_dict.get('enabled'): + if not text_to_speech_dict.get("enabled"): raise ValueError("TTS is not enabled") - voice = text_to_speech_dict.get('voice') if voice is None else voice + voice = text_to_speech_dict.get("voice") if voice is None else voice model_manager = ModelManager() model_instance = model_manager.get_default_model_instance( - tenant_id=app_model.tenant_id, - model_type=ModelType.TTS + tenant_id=app_model.tenant_id, model_type=ModelType.TTS ) try: if not voice: voices = model_instance.get_tts_voices() if voices: - voice = voices[0].get('value') + voice = voices[0].get("value") else: raise ValueError("Sorry, no voice available.") return model_instance.invoke_tts( - content_text=text_content.strip(), - user=end_user, - tenant_id=app_model.tenant_id, - voice=voice + content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice ) except Exception as e: raise e if message_id: - message = db.session.query(Message).filter( - Message.id == message_id - ).first() - if message.answer == '' and message.status == 'normal': + message = db.session.query(Message).filter(Message.id == message_id).first() + if message.answer == "" and message.status == "normal": return None else: response = invoke_tts(message.answer, app_model=app_model, voice=voice) if isinstance(response, Generator): - return Response(stream_with_context(response), content_type='audio/mpeg') + return Response(stream_with_context(response), content_type="audio/mpeg") return response else: response = invoke_tts(text, app_model, voice) if isinstance(response, Generator): - return Response(stream_with_context(response), content_type='audio/mpeg') + return Response(stream_with_context(response), content_type="audio/mpeg") return response @classmethod def transcript_tts_voices(cls, tenant_id: str, language: str): model_manager = ModelManager() - model_instance = model_manager.get_default_model_instance( - tenant_id=tenant_id, - model_type=ModelType.TTS - ) + model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.TTS) if model_instance is None: raise ProviderNotSupportTextToSpeechServiceError() diff --git a/api/services/auth/api_key_auth_factory.py b/api/services/auth/api_key_auth_factory.py index ccd0023c44..ae5b953b47 100644 --- a/api/services/auth/api_key_auth_factory.py +++ b/api/services/auth/api_key_auth_factory.py @@ -1,14 +1,12 @@ - from services.auth.firecrawl import FirecrawlAuth class ApiKeyAuthFactory: - def __init__(self, provider: str, credentials: dict): - if provider == 'firecrawl': + if provider == "firecrawl": self.auth = FirecrawlAuth(credentials) else: - raise ValueError('Invalid provider') + raise ValueError("Invalid provider") def validate_credentials(self): return self.auth.validate_credentials() diff --git a/api/services/auth/api_key_auth_service.py b/api/services/auth/api_key_auth_service.py index 43d0fbf98f..e5f4a3ef6e 100644 --- a/api/services/auth/api_key_auth_service.py +++ b/api/services/auth/api_key_auth_service.py @@ -7,39 +7,43 @@ from services.auth.api_key_auth_factory import ApiKeyAuthFactory class ApiKeyAuthService: - @staticmethod def get_provider_auth_list(tenant_id: str) -> list: - data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter( - DataSourceApiKeyAuthBinding.tenant_id == tenant_id, - DataSourceApiKeyAuthBinding.disabled.is_(False) - ).all() + data_source_api_key_bindings = ( + db.session.query(DataSourceApiKeyAuthBinding) + .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)) + .all() + ) return data_source_api_key_bindings @staticmethod def create_provider_auth(tenant_id: str, args: dict): - auth_result = ApiKeyAuthFactory(args['provider'], args['credentials']).validate_credentials() + auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials() if auth_result: # Encrypt the api key - api_key = encrypter.encrypt_token(tenant_id, args['credentials']['config']['api_key']) - args['credentials']['config']['api_key'] = api_key + api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"]) + args["credentials"]["config"]["api_key"] = api_key data_source_api_key_binding = DataSourceApiKeyAuthBinding() data_source_api_key_binding.tenant_id = tenant_id - data_source_api_key_binding.category = args['category'] - data_source_api_key_binding.provider = args['provider'] - data_source_api_key_binding.credentials = json.dumps(args['credentials'], ensure_ascii=False) + data_source_api_key_binding.category = args["category"] + data_source_api_key_binding.provider = args["provider"] + data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False) db.session.add(data_source_api_key_binding) db.session.commit() @staticmethod def get_auth_credentials(tenant_id: str, category: str, provider: str): - data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter( - DataSourceApiKeyAuthBinding.tenant_id == tenant_id, - DataSourceApiKeyAuthBinding.category == category, - DataSourceApiKeyAuthBinding.provider == provider, - DataSourceApiKeyAuthBinding.disabled.is_(False) - ).first() + data_source_api_key_bindings = ( + db.session.query(DataSourceApiKeyAuthBinding) + .filter( + DataSourceApiKeyAuthBinding.tenant_id == tenant_id, + DataSourceApiKeyAuthBinding.category == category, + DataSourceApiKeyAuthBinding.provider == provider, + DataSourceApiKeyAuthBinding.disabled.is_(False), + ) + .first() + ) if not data_source_api_key_bindings: return None credentials = json.loads(data_source_api_key_bindings.credentials) @@ -47,24 +51,24 @@ class ApiKeyAuthService: @staticmethod def delete_provider_auth(tenant_id: str, binding_id: str): - data_source_api_key_binding = db.session.query(DataSourceApiKeyAuthBinding).filter( - DataSourceApiKeyAuthBinding.tenant_id == tenant_id, - DataSourceApiKeyAuthBinding.id == binding_id - ).first() + data_source_api_key_binding = ( + db.session.query(DataSourceApiKeyAuthBinding) + .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id) + .first() + ) if data_source_api_key_binding: db.session.delete(data_source_api_key_binding) db.session.commit() @classmethod def validate_api_key_auth_args(cls, args): - if 'category' not in args or not args['category']: - raise ValueError('category is required') - if 'provider' not in args or not args['provider']: - raise ValueError('provider is required') - if 'credentials' not in args or not args['credentials']: - raise ValueError('credentials is required') - if not isinstance(args['credentials'], dict): - raise ValueError('credentials must be a dictionary') - if 'auth_type' not in args['credentials'] or not args['credentials']['auth_type']: - raise ValueError('auth_type is required') - + if "category" not in args or not args["category"]: + raise ValueError("category is required") + if "provider" not in args or not args["provider"]: + raise ValueError("provider is required") + if "credentials" not in args or not args["credentials"]: + raise ValueError("credentials is required") + if not isinstance(args["credentials"], dict): + raise ValueError("credentials must be a dictionary") + if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]: + raise ValueError("auth_type is required") diff --git a/api/services/auth/firecrawl.py b/api/services/auth/firecrawl.py index 69e3fb43c7..30e4ee57c0 100644 --- a/api/services/auth/firecrawl.py +++ b/api/services/auth/firecrawl.py @@ -8,49 +8,40 @@ from services.auth.api_key_auth_base import ApiKeyAuthBase class FirecrawlAuth(ApiKeyAuthBase): def __init__(self, credentials: dict): super().__init__(credentials) - auth_type = credentials.get('auth_type') - if auth_type != 'bearer': - raise ValueError('Invalid auth type, Firecrawl auth type must be Bearer') - self.api_key = credentials.get('config').get('api_key', None) - self.base_url = credentials.get('config').get('base_url', 'https://api.firecrawl.dev') + auth_type = credentials.get("auth_type") + if auth_type != "bearer": + raise ValueError("Invalid auth type, Firecrawl auth type must be Bearer") + self.api_key = credentials.get("config").get("api_key", None) + self.base_url = credentials.get("config").get("base_url", "https://api.firecrawl.dev") if not self.api_key: - raise ValueError('No API key provided') + raise ValueError("No API key provided") def validate_credentials(self): headers = self._prepare_headers() options = { - 'url': 'https://example.com', - 'crawlerOptions': { - 'excludes': [], - 'includes': [], - 'limit': 1 - }, - 'pageOptions': { - 'onlyMainContent': True - } + "url": "https://example.com", + "crawlerOptions": {"excludes": [], "includes": [], "limit": 1}, + "pageOptions": {"onlyMainContent": True}, } - response = self._post_request(f'{self.base_url}/v0/crawl', options, headers) + response = self._post_request(f"{self.base_url}/v0/crawl", options, headers) if response.status_code == 200: return True else: self._handle_error(response) def _prepare_headers(self): - return { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}' - } + return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} def _post_request(self, url, data, headers): return requests.post(url, headers=headers, json=data) def _handle_error(self, response): if response.status_code in [402, 409, 500]: - error_message = response.json().get('error', 'Unknown error occurred') - raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}') + error_message = response.json().get("error", "Unknown error occurred") + raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") else: if response.text: - error_message = json.loads(response.text).get('error', 'Unknown error occurred') - raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}') - raise Exception(f'Unexpected error occurred while trying to authorize. Status code: {response.status_code}') + error_message = json.loads(response.text).get("error", "Unknown error occurred") + raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") + raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}") diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 539f2712bb..911d234641 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -7,58 +7,40 @@ from models.account import TenantAccountJoin, TenantAccountRole class BillingService: - base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL') - secret_key = os.environ.get('BILLING_API_SECRET_KEY', 'BILLING_API_SECRET_KEY') + base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL") + secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY") @classmethod def get_info(cls, tenant_id: str): - params = {'tenant_id': tenant_id} + params = {"tenant_id": tenant_id} - billing_info = cls._send_request('GET', '/subscription/info', params=params) + billing_info = cls._send_request("GET", "/subscription/info", params=params) return billing_info @classmethod - def get_subscription(cls, plan: str, - interval: str, - prefilled_email: str = '', - tenant_id: str = ''): - params = { - 'plan': plan, - 'interval': interval, - 'prefilled_email': prefilled_email, - 'tenant_id': tenant_id - } - return cls._send_request('GET', '/subscription/payment-link', params=params) + def get_subscription(cls, plan: str, interval: str, prefilled_email: str = "", tenant_id: str = ""): + params = {"plan": plan, "interval": interval, "prefilled_email": prefilled_email, "tenant_id": tenant_id} + return cls._send_request("GET", "/subscription/payment-link", params=params) @classmethod - def get_model_provider_payment_link(cls, - provider_name: str, - tenant_id: str, - account_id: str, - prefilled_email: str): + def get_model_provider_payment_link(cls, provider_name: str, tenant_id: str, account_id: str, prefilled_email: str): params = { - 'provider_name': provider_name, - 'tenant_id': tenant_id, - 'account_id': account_id, - 'prefilled_email': prefilled_email + "provider_name": provider_name, + "tenant_id": tenant_id, + "account_id": account_id, + "prefilled_email": prefilled_email, } - return cls._send_request('GET', '/model-provider/payment-link', params=params) + return cls._send_request("GET", "/model-provider/payment-link", params=params) @classmethod - def get_invoices(cls, prefilled_email: str = '', tenant_id: str = ''): - params = { - 'prefilled_email': prefilled_email, - 'tenant_id': tenant_id - } - return cls._send_request('GET', '/invoices', params=params) + def get_invoices(cls, prefilled_email: str = "", tenant_id: str = ""): + params = {"prefilled_email": prefilled_email, "tenant_id": tenant_id} + return cls._send_request("GET", "/invoices", params=params) @classmethod def _send_request(cls, method, endpoint, json=None, params=None): - headers = { - "Content-Type": "application/json", - "Billing-Api-Secret-Key": cls.secret_key - } + headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} url = f"{cls.base_url}{endpoint}" response = requests.request(method, url, json=json, params=params, headers=headers) @@ -69,10 +51,11 @@ class BillingService: def is_tenant_owner_or_admin(current_user): tenant_id = current_user.current_tenant_id - join = db.session.query(TenantAccountJoin).filter( - TenantAccountJoin.tenant_id == tenant_id, - TenantAccountJoin.account_id == current_user.id - ).first() + join = ( + db.session.query(TenantAccountJoin) + .filter(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id) + .first() + ) if not TenantAccountRole.is_privileged_role(join.role): - raise ValueError('Only team owner or team admin can perform this action') + raise ValueError("Only team owner or team admin can perform this action") diff --git a/api/services/code_based_extension_service.py b/api/services/code_based_extension_service.py index 7b0d50a835..f7597b7f1f 100644 --- a/api/services/code_based_extension_service.py +++ b/api/services/code_based_extension_service.py @@ -2,12 +2,15 @@ from extensions.ext_code_based_extension import code_based_extension class CodeBasedExtensionService: - @staticmethod def get_code_based_extension(module: str) -> list[dict]: module_extensions = code_based_extension.module_extensions(module) - return [{ - 'name': module_extension.name, - 'label': module_extension.label, - 'form_schema': module_extension.form_schema - } for module_extension in module_extensions if not module_extension.builtin] + return [ + { + "name": module_extension.name, + "label": module_extension.label, + "form_schema": module_extension.form_schema, + } + for module_extension in module_extensions + if not module_extension.builtin + ] diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 82ee10ee78..7bfe59afa0 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -1,6 +1,7 @@ +from datetime import datetime, timezone from typing import Optional, Union -from sqlalchemy import or_ +from sqlalchemy import asc, desc, or_ from core.app.entities.app_invoke_entities import InvokeFrom from core.llm_generator.llm_generator import LLMGenerator @@ -14,21 +15,27 @@ from services.errors.message import MessageNotExistsError class ConversationService: @classmethod - def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], - last_id: Optional[str], limit: int, - invoke_from: InvokeFrom, - include_ids: Optional[list] = None, - exclude_ids: Optional[list] = None) -> InfiniteScrollPagination: + def pagination_by_last_id( + cls, + app_model: App, + user: Optional[Union[Account, EndUser]], + last_id: Optional[str], + limit: int, + invoke_from: InvokeFrom, + include_ids: Optional[list] = None, + exclude_ids: Optional[list] = None, + sort_by: str = "-updated_at", + ) -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) base_query = db.session.query(Conversation).filter( Conversation.is_deleted == False, Conversation.app_id == app_model.id, - Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'), + Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"), Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None), Conversation.from_account_id == (user.id if isinstance(user, Account) else None), - or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value) + or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value), ) if include_ids is not None: @@ -37,47 +44,67 @@ class ConversationService: if exclude_ids is not None: base_query = base_query.filter(~Conversation.id.in_(exclude_ids)) - if last_id: - last_conversation = base_query.filter( - Conversation.id == last_id, - ).first() + # define sort fields and directions + sort_field, sort_direction = cls._get_sort_params(sort_by) + if last_id: + last_conversation = base_query.filter(Conversation.id == last_id).first() if not last_conversation: raise LastConversationNotExistsError() - conversations = base_query.filter( - Conversation.created_at < last_conversation.created_at, - Conversation.id != last_conversation.id - ).order_by(Conversation.created_at.desc()).limit(limit).all() - else: - conversations = base_query.order_by(Conversation.created_at.desc()).limit(limit).all() + # build filters based on sorting + filter_condition = cls._build_filter_condition(sort_field, sort_direction, last_conversation) + base_query = base_query.filter(filter_condition) + + base_query = base_query.order_by(sort_direction(getattr(Conversation, sort_field))) + + conversations = base_query.limit(limit).all() has_more = False if len(conversations) == limit: - current_page_first_conversation = conversations[-1] - rest_count = base_query.filter( - Conversation.created_at < current_page_first_conversation.created_at, - Conversation.id != current_page_first_conversation.id - ).count() + current_page_last_conversation = conversations[-1] + rest_filter_condition = cls._build_filter_condition( + sort_field, sort_direction, current_page_last_conversation, is_next_page=True + ) + rest_count = base_query.filter(rest_filter_condition).count() if rest_count > 0: has_more = True - return InfiniteScrollPagination( - data=conversations, - limit=limit, - has_more=has_more - ) + return InfiniteScrollPagination(data=conversations, limit=limit, has_more=has_more) @classmethod - def rename(cls, app_model: App, conversation_id: str, - user: Optional[Union[Account, EndUser]], name: str, auto_generate: bool): + def _get_sort_params(cls, sort_by: str) -> tuple[str, callable]: + if sort_by.startswith("-"): + return sort_by[1:], desc + return sort_by, asc + + @classmethod + def _build_filter_condition( + cls, sort_field: str, sort_direction: callable, reference_conversation: Conversation, is_next_page: bool = False + ): + field_value = getattr(reference_conversation, sort_field) + if (sort_direction == desc and not is_next_page) or (sort_direction == asc and is_next_page): + return getattr(Conversation, sort_field) < field_value + else: + return getattr(Conversation, sort_field) > field_value + + @classmethod + def rename( + cls, + app_model: App, + conversation_id: str, + user: Optional[Union[Account, EndUser]], + name: str, + auto_generate: bool, + ): conversation = cls.get_conversation(app_model, conversation_id, user) if auto_generate: return cls.auto_generate_name(app_model, conversation) else: conversation.name = name + conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() return conversation @@ -85,11 +112,12 @@ class ConversationService: @classmethod def auto_generate_name(cls, app_model: App, conversation: Conversation): # get conversation first message - message = db.session.query(Message) \ - .filter( - Message.app_id == app_model.id, - Message.conversation_id == conversation.id - ).order_by(Message.created_at.asc()).first() + message = ( + db.session.query(Message) + .filter(Message.app_id == app_model.id, Message.conversation_id == conversation.id) + .order_by(Message.created_at.asc()) + .first() + ) if not message: raise MessageNotExistsError() @@ -109,15 +137,18 @@ class ConversationService: @classmethod def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): - conversation = db.session.query(Conversation) \ + conversation = ( + db.session.query(Conversation) .filter( - Conversation.id == conversation_id, - Conversation.app_id == app_model.id, - Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'), - Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None), - Conversation.from_account_id == (user.id if isinstance(user, Account) else None), - Conversation.is_deleted == False - ).first() + Conversation.id == conversation_id, + Conversation.app_id == app_model.id, + Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"), + Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None), + Conversation.from_account_id == (user.id if isinstance(user, Account) else None), + Conversation.is_deleted == False, + ) + .first() + ) if not conversation: raise ConversationNotExistsError() diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 12ae0e39a8..ad552e1bab 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -27,6 +27,7 @@ from models.dataset import ( Dataset, DatasetCollectionBinding, DatasetPermission, + DatasetPermissionEnum, DatasetProcessRule, DatasetQuery, Document, @@ -54,7 +55,6 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde class DatasetService: - @staticmethod def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None, search=None, tag_ids=None): query = Dataset.query.filter(Dataset.provider == provider, Dataset.tenant_id == tenant_id).order_by( @@ -63,10 +63,7 @@ class DatasetService: if user: # get permitted dataset ids - dataset_permission = DatasetPermission.query.filter_by( - account_id=user.id, - tenant_id=tenant_id - ).all() + dataset_permission = DatasetPermission.query.filter_by(account_id=user.id, tenant_id=tenant_id).all() permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None if user.current_role == TenantAccountRole.DATASET_OPERATOR: @@ -80,83 +77,76 @@ class DatasetService: if permitted_dataset_ids: query = query.filter( db.or_( - Dataset.permission == 'all_team_members', - db.and_(Dataset.permission == 'only_me', Dataset.created_by == user.id), - db.and_(Dataset.permission == 'partial_members', Dataset.id.in_(permitted_dataset_ids)) + Dataset.permission == DatasetPermissionEnum.ALL_TEAM, + db.and_(Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id), + db.and_( + Dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM, + Dataset.id.in_(permitted_dataset_ids), + ), ) ) else: query = query.filter( db.or_( - Dataset.permission == 'all_team_members', - db.and_(Dataset.permission == 'only_me', Dataset.created_by == user.id) + Dataset.permission == DatasetPermissionEnum.ALL_TEAM, + db.and_(Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id), ) ) else: # if no user, only show datasets that are shared with all team members - query = query.filter(Dataset.permission == 'all_team_members') + query = query.filter(Dataset.permission == DatasetPermissionEnum.ALL_TEAM) if search: - query = query.filter(Dataset.name.ilike(f'%{search}%')) + query = query.filter(Dataset.name.ilike(f"%{search}%")) if tag_ids: - target_ids = TagService.get_target_ids_by_tag_ids('knowledge', tenant_id, tag_ids) + target_ids = TagService.get_target_ids_by_tag_ids("knowledge", tenant_id, tag_ids) if target_ids: query = query.filter(Dataset.id.in_(target_ids)) else: return [], 0 - datasets = query.paginate( - page=page, - per_page=per_page, - max_per_page=100, - error_out=False - ) + datasets = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) return datasets.items, datasets.total @staticmethod def get_process_rules(dataset_id): # get the latest process rule - dataset_process_rule = db.session.query(DatasetProcessRule). \ - filter(DatasetProcessRule.dataset_id == dataset_id). \ - order_by(DatasetProcessRule.created_at.desc()). \ - limit(1). \ - one_or_none() + dataset_process_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.dataset_id == dataset_id) + .order_by(DatasetProcessRule.created_at.desc()) + .limit(1) + .one_or_none() + ) if dataset_process_rule: mode = dataset_process_rule.mode rules = dataset_process_rule.rules_dict else: - mode = DocumentService.DEFAULT_RULES['mode'] - rules = DocumentService.DEFAULT_RULES['rules'] - return { - 'mode': mode, - 'rules': rules - } + mode = DocumentService.DEFAULT_RULES["mode"] + rules = DocumentService.DEFAULT_RULES["rules"] + return {"mode": mode, "rules": rules} @staticmethod def get_datasets_by_ids(ids, tenant_id): - datasets = Dataset.query.filter( - Dataset.id.in_(ids), - Dataset.tenant_id == tenant_id - ).paginate( + datasets = Dataset.query.filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id).paginate( page=1, per_page=len(ids), max_per_page=len(ids), error_out=False ) return datasets.items, datasets.total @staticmethod - def create_empty_dataset(tenant_id: str, name: str, indexing_technique: Optional[str], account: Account): + def create_empty_dataset( + tenant_id: str, name: str, indexing_technique: Optional[str], account: Account, permission: Optional[str] = None + ): # check if dataset name already exists if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first(): - raise DatasetNameDuplicateError( - f'Dataset with name {name} already exists.' - ) + raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.") embedding_model = None - if indexing_technique == 'high_quality': + if indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_default_model_instance( - tenant_id=tenant_id, - model_type=ModelType.TEXT_EMBEDDING + tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING ) dataset = Dataset(name=name, indexing_technique=indexing_technique) # dataset = Dataset(name=name, provider=provider, config=config) @@ -165,26 +155,25 @@ class DatasetService: dataset.tenant_id = tenant_id dataset.embedding_model_provider = embedding_model.provider if embedding_model else None dataset.embedding_model = embedding_model.model if embedding_model else None + dataset.permission = permission if permission else DatasetPermissionEnum.ONLY_ME db.session.add(dataset) db.session.commit() return dataset @staticmethod def get_dataset(dataset_id): - return Dataset.query.filter_by( - id=dataset_id - ).first() + return Dataset.query.filter_by(id=dataset_id).first() @staticmethod def check_dataset_model_setting(dataset): - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": try: model_manager = ModelManager() model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ValueError( @@ -192,65 +181,56 @@ class DatasetService: "in the Settings -> Model Provider." ) except ProviderTokenNotInitError as ex: - raise ValueError( - f"The dataset in unavailable, due to: " - f"{ex.description}" - ) + raise ValueError(f"The dataset in unavailable, due to: " f"{ex.description}") @staticmethod - def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model:str): + def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str): try: model_manager = ModelManager() model_manager.get_model_instance( tenant_id=tenant_id, provider=embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=embedding_model + model=embedding_model, ) except LLMBadRequestError: raise ValueError( - "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider." + "No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider." ) except ProviderTokenNotInitError as ex: - raise ValueError( - f"The dataset in unavailable, due to: " - f"{ex.description}" - ) - + raise ValueError(f"The dataset in unavailable, due to: " f"{ex.description}") @staticmethod def update_dataset(dataset_id, data, user): - data.pop('partial_member_list', None) - filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'} + data.pop("partial_member_list", None) + filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"} dataset = DatasetService.get_dataset(dataset_id) DatasetService.check_dataset_permission(dataset, user) action = None - if dataset.indexing_technique != data['indexing_technique']: + if dataset.indexing_technique != data["indexing_technique"]: # if update indexing_technique - if data['indexing_technique'] == 'economy': - action = 'remove' - filtered_data['embedding_model'] = None - filtered_data['embedding_model_provider'] = None - filtered_data['collection_binding_id'] = None - elif data['indexing_technique'] == 'high_quality': - action = 'add' + if data["indexing_technique"] == "economy": + action = "remove" + filtered_data["embedding_model"] = None + filtered_data["embedding_model_provider"] = None + filtered_data["collection_binding_id"] = None + elif data["indexing_technique"] == "high_quality": + action = "add" # get embedding model setting try: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, - provider=data['embedding_model_provider'], + provider=data["embedding_model_provider"], model_type=ModelType.TEXT_EMBEDDING, - model=data['embedding_model'] + model=data["embedding_model"], ) - filtered_data['embedding_model'] = embedding_model.model - filtered_data['embedding_model_provider'] = embedding_model.provider + filtered_data["embedding_model"] = embedding_model.model + filtered_data["embedding_model_provider"] = embedding_model.provider dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, - embedding_model.model + embedding_model.provider, embedding_model.model ) - filtered_data['collection_binding_id'] = dataset_collection_binding.id + filtered_data["collection_binding_id"] = dataset_collection_binding.id except LLMBadRequestError: raise ValueError( "No Embedding Model available. Please configure a valid provider " @@ -259,24 +239,25 @@ class DatasetService: except ProviderTokenNotInitError as ex: raise ValueError(ex.description) else: - if data['embedding_model_provider'] != dataset.embedding_model_provider or \ - data['embedding_model'] != dataset.embedding_model: - action = 'update' + if ( + data["embedding_model_provider"] != dataset.embedding_model_provider + or data["embedding_model"] != dataset.embedding_model + ): + action = "update" try: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, - provider=data['embedding_model_provider'], + provider=data["embedding_model_provider"], model_type=ModelType.TEXT_EMBEDDING, - model=data['embedding_model'] + model=data["embedding_model"], ) - filtered_data['embedding_model'] = embedding_model.model - filtered_data['embedding_model_provider'] = embedding_model.provider + filtered_data["embedding_model"] = embedding_model.model + filtered_data["embedding_model_provider"] = embedding_model.provider dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, - embedding_model.model + embedding_model.provider, embedding_model.model ) - filtered_data['collection_binding_id'] = dataset_collection_binding.id + filtered_data["collection_binding_id"] = dataset_collection_binding.id except LLMBadRequestError: raise ValueError( "No Embedding Model available. Please configure a valid provider " @@ -285,11 +266,11 @@ class DatasetService: except ProviderTokenNotInitError as ex: raise ValueError(ex.description) - filtered_data['updated_by'] = user.id - filtered_data['updated_at'] = datetime.datetime.now() + filtered_data["updated_by"] = user.id + filtered_data["updated_at"] = datetime.datetime.now() # update Retrieval model - filtered_data['retrieval_model'] = data['retrieval_model'] + filtered_data["retrieval_model"] = data["retrieval_model"] dataset.query.filter_by(id=dataset_id).update(filtered_data) @@ -300,7 +281,6 @@ class DatasetService: @staticmethod def delete_dataset(dataset_id, user): - dataset = DatasetService.get_dataset(dataset_id) if dataset is None: @@ -324,72 +304,57 @@ class DatasetService: @staticmethod def check_dataset_permission(dataset, user): if dataset.tenant_id != user.current_tenant_id: - logging.debug( - f'User {user.id} does not have permission to access dataset {dataset.id}' - ) - raise NoPermissionError( - 'You do not have permission to access this dataset.' - ) - if dataset.permission == 'only_me' and dataset.created_by != user.id: - logging.debug( - f'User {user.id} does not have permission to access dataset {dataset.id}' - ) - raise NoPermissionError( - 'You do not have permission to access this dataset.' - ) - if dataset.permission == 'partial_members': - user_permission = DatasetPermission.query.filter_by( - dataset_id=dataset.id, account_id=user.id - ).first() + logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") + raise NoPermissionError("You do not have permission to access this dataset.") + if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id: + logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") + raise NoPermissionError("You do not have permission to access this dataset.") + if dataset.permission == "partial_members": + user_permission = DatasetPermission.query.filter_by(dataset_id=dataset.id, account_id=user.id).first() if not user_permission and dataset.tenant_id != user.current_tenant_id and dataset.created_by != user.id: - logging.debug( - f'User {user.id} does not have permission to access dataset {dataset.id}' - ) - raise NoPermissionError( - 'You do not have permission to access this dataset.' - ) + logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") + raise NoPermissionError("You do not have permission to access this dataset.") @staticmethod def check_dataset_operator_permission(user: Account = None, dataset: Dataset = None): - if dataset.permission == 'only_me': + if dataset.permission == DatasetPermissionEnum.ONLY_ME: if dataset.created_by != user.id: - raise NoPermissionError('You do not have permission to access this dataset.') + raise NoPermissionError("You do not have permission to access this dataset.") - elif dataset.permission == 'partial_members': + elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM: if not any( dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all() ): - raise NoPermissionError('You do not have permission to access this dataset.') + raise NoPermissionError("You do not have permission to access this dataset.") @staticmethod def get_dataset_queries(dataset_id: str, page: int, per_page: int): - dataset_queries = DatasetQuery.query.filter_by(dataset_id=dataset_id) \ - .order_by(db.desc(DatasetQuery.created_at)) \ - .paginate( - page=page, per_page=per_page, max_per_page=100, error_out=False + dataset_queries = ( + DatasetQuery.query.filter_by(dataset_id=dataset_id) + .order_by(db.desc(DatasetQuery.created_at)) + .paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) ) return dataset_queries.items, dataset_queries.total @staticmethod def get_related_apps(dataset_id: str): - return AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) \ - .order_by(db.desc(AppDatasetJoin.created_at)).all() + return ( + AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) + .order_by(db.desc(AppDatasetJoin.created_at)) + .all() + ) class DocumentService: DEFAULT_RULES = { - 'mode': 'custom', - 'rules': { - 'pre_processing_rules': [ - {'id': 'remove_extra_spaces', 'enabled': True}, - {'id': 'remove_urls_emails', 'enabled': False} + "mode": "custom", + "rules": { + "pre_processing_rules": [ + {"id": "remove_extra_spaces", "enabled": True}, + {"id": "remove_urls_emails", "enabled": False}, ], - 'segmentation': { - 'delimiter': '\n', - 'max_tokens': 500, - 'chunk_overlap': 50 - } - } + "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50}, + }, } DOCUMENT_METADATA_SCHEMA = { @@ -482,58 +447,55 @@ class DocumentService: "commit_date": str, "commit_author": str, }, - "others": dict + "others": dict, } @staticmethod def get_document(dataset_id: str, document_id: str) -> Optional[Document]: - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) return document @staticmethod def get_document_by_id(document_id: str) -> Optional[Document]: - document = db.session.query(Document).filter( - Document.id == document_id - ).first() + document = db.session.query(Document).filter(Document.id == document_id).first() return document @staticmethod def get_document_by_dataset_id(dataset_id: str) -> list[Document]: - documents = db.session.query(Document).filter( - Document.dataset_id == dataset_id, - Document.enabled == True - ).all() + documents = db.session.query(Document).filter(Document.dataset_id == dataset_id, Document.enabled == True).all() return documents @staticmethod def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]: - documents = db.session.query(Document).filter( - Document.dataset_id == dataset_id, - Document.indexing_status.in_(['error', 'paused']) - ).all() + documents = ( + db.session.query(Document) + .filter(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"])) + .all() + ) return documents @staticmethod def get_batch_documents(dataset_id: str, batch: str) -> list[Document]: - documents = db.session.query(Document).filter( - Document.batch == batch, - Document.dataset_id == dataset_id, - Document.tenant_id == current_user.current_tenant_id - ).all() + documents = ( + db.session.query(Document) + .filter( + Document.batch == batch, + Document.dataset_id == dataset_id, + Document.tenant_id == current_user.current_tenant_id, + ) + .all() + ) return documents @staticmethod def get_document_file_detail(file_id: str): - file_detail = db.session.query(UploadFile). \ - filter(UploadFile.id == file_id). \ - one_or_none() + file_detail = db.session.query(UploadFile).filter(UploadFile.id == file_id).one_or_none() return file_detail @staticmethod @@ -547,13 +509,14 @@ class DocumentService: def delete_document(document): # trigger document_was_deleted signal file_id = None - if document.data_source_type == 'upload_file': + if document.data_source_type == "upload_file": if document.data_source_info: data_source_info = document.data_source_info_dict - if data_source_info and 'upload_file_id' in data_source_info: - file_id = data_source_info['upload_file_id'] - document_was_deleted.send(document.id, dataset_id=document.dataset_id, - doc_form=document.doc_form, file_id=file_id) + if data_source_info and "upload_file_id" in data_source_info: + file_id = data_source_info["upload_file_id"] + document_was_deleted.send( + document.id, dataset_id=document.dataset_id, doc_form=document.doc_form, file_id=file_id + ) db.session.delete(document) db.session.commit() @@ -562,15 +525,15 @@ class DocumentService: def rename_document(dataset_id: str, document_id: str, name: str) -> Document: dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise ValueError('Dataset not found.') + raise ValueError("Dataset not found.") document = DocumentService.get_document(dataset_id, document_id) if not document: - raise ValueError('Document not found.') + raise ValueError("Document not found.") if document.tenant_id != current_user.current_tenant_id: - raise ValueError('No permission.') + raise ValueError("No permission.") document.name = name @@ -591,7 +554,7 @@ class DocumentService: db.session.add(document) db.session.commit() # set document paused flag - indexing_cache_key = 'document_{}_is_paused'.format(document.id) + indexing_cache_key = "document_{}_is_paused".format(document.id) redis_client.setnx(indexing_cache_key, "True") @staticmethod @@ -606,7 +569,7 @@ class DocumentService: db.session.add(document) db.session.commit() # delete paused flag - indexing_cache_key = 'document_{}_is_paused'.format(document.id) + indexing_cache_key = "document_{}_is_paused".format(document.id) redis_client.delete(indexing_cache_key) # trigger async task recover_document_indexing_task.delay(document.dataset_id, document.id) @@ -615,12 +578,12 @@ class DocumentService: def retry_document(dataset_id: str, documents: list[Document]): for document in documents: # add retry flag - retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id) + retry_indexing_cache_key = "document_{}_is_retried".format(document.id) cache_result = redis_client.get(retry_indexing_cache_key) if cache_result is not None: raise ValueError("Document is being retried, please try again later") # retry document indexing - document.indexing_status = 'waiting' + document.indexing_status = "waiting" db.session.add(document) db.session.commit() @@ -632,14 +595,14 @@ class DocumentService: @staticmethod def sync_website_document(dataset_id: str, document: Document): # add sync flag - sync_indexing_cache_key = 'document_{}_is_sync'.format(document.id) + sync_indexing_cache_key = "document_{}_is_sync".format(document.id) cache_result = redis_client.get(sync_indexing_cache_key) if cache_result is not None: raise ValueError("Document is being synced, please try again later") # sync document indexing - document.indexing_status = 'waiting' + document.indexing_status = "waiting" data_source_info = document.data_source_info_dict - data_source_info['mode'] = 'scrape' + data_source_info["mode"] = "scrape" document.data_source_info = json.dumps(data_source_info, ensure_ascii=False) db.session.add(document) db.session.commit() @@ -658,27 +621,28 @@ class DocumentService: @staticmethod def save_document_with_dataset_id( - dataset: Dataset, document_data: dict, - account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None, - created_from: str = 'web' + dataset: Dataset, + document_data: dict, + account: Account, + dataset_process_rule: Optional[DatasetProcessRule] = None, + created_from: str = "web", ): - # check document limit features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: - if 'original_document_id' not in document_data or not document_data['original_document_id']: + if "original_document_id" not in document_data or not document_data["original_document_id"]: count = 0 if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids'] + upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] count = len(upload_file_list) elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] + notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] for notion_info in notion_info_list: - count = count + len(notion_info['pages']) + count = count + len(notion_info["pages"]) elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]['info_list']['website_info_list'] - count = len(website_info['urls']) + website_info = document_data["data_source"]["info_list"]["website_info_list"] + count = len(website_info["urls"]) batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") @@ -690,42 +654,41 @@ class DocumentService: dataset.data_source_type = document_data["data_source"]["type"] if not dataset.indexing_technique: - if 'indexing_technique' not in document_data \ - or document_data['indexing_technique'] not in Dataset.INDEXING_TECHNIQUE_LIST: + if ( + "indexing_technique" not in document_data + or document_data["indexing_technique"] not in Dataset.INDEXING_TECHNIQUE_LIST + ): raise ValueError("Indexing technique is required") dataset.indexing_technique = document_data["indexing_technique"] - if document_data["indexing_technique"] == 'high_quality': + if document_data["indexing_technique"] == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.TEXT_EMBEDDING + tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING ) dataset.embedding_model = embedding_model.model dataset.embedding_model_provider = embedding_model.provider dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, - embedding_model.model + embedding_model.provider, embedding_model.model ) dataset.collection_binding_id = dataset_collection_binding.id if not dataset.retrieval_model: default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } - dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get( - 'retrieval_model' - ) else default_retrieval_model + dataset.retrieval_model = ( + document_data.get("retrieval_model") + if document_data.get("retrieval_model") + else default_retrieval_model + ) documents = [] - batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)) + batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) if document_data.get("original_document_id"): document = DocumentService.update_document_with_dataset_id(dataset, document_data, account) documents.append(document) @@ -738,14 +701,14 @@ class DocumentService: dataset_id=dataset.id, mode=process_rule["mode"], rules=json.dumps(process_rule["rules"]), - created_by=account.id + created_by=account.id, ) elif process_rule["mode"] == "automatic": dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, mode=process_rule["mode"], rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), - created_by=account.id + created_by=account.id, ) db.session.add(dataset_process_rule) db.session.commit() @@ -753,12 +716,13 @@ class DocumentService: document_ids = [] duplicate_document_ids = [] if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids'] + upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] for file_id in upload_file_list: - file = db.session.query(UploadFile).filter( - UploadFile.tenant_id == dataset.tenant_id, - UploadFile.id == file_id - ).first() + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) + .first() + ) # raise error if file not found if not file: @@ -769,34 +733,39 @@ class DocumentService: "upload_file_id": file_id, } # check duplicate - if document_data.get('duplicate', False): + if document_data.get("duplicate", False): document = Document.query.filter_by( dataset_id=dataset.id, tenant_id=current_user.current_tenant_id, - data_source_type='upload_file', + data_source_type="upload_file", enabled=True, - name=file_name + name=file_name, ).first() if document: document.dataset_process_rule_id = dataset_process_rule.id document.updated_at = datetime.datetime.utcnow() document.created_from = created_from - document.doc_form = document_data['doc_form'] - document.doc_language = document_data['doc_language'] + document.doc_form = document_data["doc_form"] + document.doc_language = document_data["doc_language"] document.data_source_info = json.dumps(data_source_info) document.batch = batch - document.indexing_status = 'waiting' + document.indexing_status = "waiting" db.session.add(document) documents.append(document) duplicate_document_ids.append(document.id) continue document = DocumentService.build_document( - dataset, dataset_process_rule.id, + dataset, + dataset_process_rule.id, document_data["data_source"]["type"], document_data["doc_form"], document_data["doc_language"], - data_source_info, created_from, position, - account, file_name, batch + data_source_info, + created_from, + position, + account, + file_name, + batch, ) db.session.add(document) db.session.flush() @@ -804,47 +773,52 @@ class DocumentService: documents.append(document) position += 1 elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] + notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] exist_page_ids = [] exist_document = {} documents = Document.query.filter_by( dataset_id=dataset.id, tenant_id=current_user.current_tenant_id, - data_source_type='notion_import', - enabled=True + data_source_type="notion_import", + enabled=True, ).all() if documents: for document in documents: data_source_info = json.loads(document.data_source_info) - exist_page_ids.append(data_source_info['notion_page_id']) - exist_document[data_source_info['notion_page_id']] = document.id + exist_page_ids.append(data_source_info["notion_page_id"]) + exist_document[data_source_info["notion_page_id"]] = document.id for notion_info in notion_info_list: - workspace_id = notion_info['workspace_id'] + workspace_id = notion_info["workspace_id"] data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"' + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', ) ).first() if not data_source_binding: - raise ValueError('Data source binding not found.') - for page in notion_info['pages']: - if page['page_id'] not in exist_page_ids: + raise ValueError("Data source binding not found.") + for page in notion_info["pages"]: + if page["page_id"] not in exist_page_ids: data_source_info = { "notion_workspace_id": workspace_id, - "notion_page_id": page['page_id'], - "notion_page_icon": page['page_icon'], - "type": page['type'] + "notion_page_id": page["page_id"], + "notion_page_icon": page["page_icon"], + "type": page["type"], } document = DocumentService.build_document( - dataset, dataset_process_rule.id, + dataset, + dataset_process_rule.id, document_data["data_source"]["type"], document_data["doc_form"], document_data["doc_language"], - data_source_info, created_from, position, - account, page['page_name'], batch + data_source_info, + created_from, + position, + account, + page["page_name"], + batch, ) db.session.add(document) db.session.flush() @@ -852,32 +826,37 @@ class DocumentService: documents.append(document) position += 1 else: - exist_document.pop(page['page_id']) + exist_document.pop(page["page_id"]) # delete not selected documents if len(exist_document) > 0: clean_notion_document_task.delay(list(exist_document.values()), dataset.id) elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]['info_list']['website_info_list'] - urls = website_info['urls'] + website_info = document_data["data_source"]["info_list"]["website_info_list"] + urls = website_info["urls"] for url in urls: data_source_info = { - 'url': url, - 'provider': website_info['provider'], - 'job_id': website_info['job_id'], - 'only_main_content': website_info.get('only_main_content', False), - 'mode': 'crawl', + "url": url, + "provider": website_info["provider"], + "job_id": website_info["job_id"], + "only_main_content": website_info.get("only_main_content", False), + "mode": "crawl", } if len(url) > 255: - document_name = url[:200] + '...' + document_name = url[:200] + "..." else: document_name = url document = DocumentService.build_document( - dataset, dataset_process_rule.id, + dataset, + dataset_process_rule.id, document_data["data_source"]["type"], document_data["doc_form"], document_data["doc_language"], - data_source_info, created_from, position, - account, document_name, batch + data_source_info, + created_from, + position, + account, + document_name, + batch, ) db.session.add(document) db.session.flush() @@ -899,15 +878,22 @@ class DocumentService: can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size if count > can_upload_size: raise ValueError( - f'You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded.' + f"You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded." ) @staticmethod def build_document( - dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str, - document_language: str, data_source_info: dict, created_from: str, position: int, + dataset: Dataset, + process_rule_id: str, + data_source_type: str, + document_form: str, + document_language: str, + data_source_info: dict, + created_from: str, + position: int, account: Account, - name: str, batch: str + name: str, + batch: str, ): document = Document( tenant_id=dataset.tenant_id, @@ -921,7 +907,7 @@ class DocumentService: created_from=created_from, created_by=account.id, doc_form=document_form, - doc_language=document_language + doc_language=document_language, ) return document @@ -931,54 +917,57 @@ class DocumentService: Document.completed_at.isnot(None), Document.enabled == True, Document.archived == False, - Document.tenant_id == current_user.current_tenant_id + Document.tenant_id == current_user.current_tenant_id, ).count() return documents_count @staticmethod def update_document_with_dataset_id( - dataset: Dataset, document_data: dict, - account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None, - created_from: str = 'web' + dataset: Dataset, + document_data: dict, + account: Account, + dataset_process_rule: Optional[DatasetProcessRule] = None, + created_from: str = "web", ): DatasetService.check_dataset_model_setting(dataset) document = DocumentService.get_document(dataset.id, document_data["original_document_id"]) - if document.display_status != 'available': + if document.display_status != "available": raise ValueError("Document is not available") # update document name - if document_data.get('name'): - document.name = document_data['name'] + if document_data.get("name"): + document.name = document_data["name"] # save process rule - if document_data.get('process_rule'): + if document_data.get("process_rule"): process_rule = document_data["process_rule"] if process_rule["mode"] == "custom": dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, mode=process_rule["mode"], rules=json.dumps(process_rule["rules"]), - created_by=account.id + created_by=account.id, ) elif process_rule["mode"] == "automatic": dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, mode=process_rule["mode"], rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), - created_by=account.id + created_by=account.id, ) db.session.add(dataset_process_rule) db.session.commit() document.dataset_process_rule_id = dataset_process_rule.id # update document data source - if document_data.get('data_source'): - file_name = '' + if document_data.get("data_source"): + file_name = "" data_source_info = {} if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids'] + upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] for file_id in upload_file_list: - file = db.session.query(UploadFile).filter( - UploadFile.tenant_id == dataset.tenant_id, - UploadFile.id == file_id - ).first() + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) + .first() + ) # raise error if file not found if not file: @@ -989,42 +978,42 @@ class DocumentService: "upload_file_id": file_id, } elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] + notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] for notion_info in notion_info_list: - workspace_id = notion_info['workspace_id'] + workspace_id = notion_info["workspace_id"] data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"' + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', ) ).first() if not data_source_binding: - raise ValueError('Data source binding not found.') - for page in notion_info['pages']: + raise ValueError("Data source binding not found.") + for page in notion_info["pages"]: data_source_info = { "notion_workspace_id": workspace_id, - "notion_page_id": page['page_id'], - "notion_page_icon": page['page_icon'], - "type": page['type'] + "notion_page_id": page["page_id"], + "notion_page_icon": page["page_icon"], + "type": page["type"], } elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]['info_list']['website_info_list'] - urls = website_info['urls'] + website_info = document_data["data_source"]["info_list"]["website_info_list"] + urls = website_info["urls"] for url in urls: data_source_info = { - 'url': url, - 'provider': website_info['provider'], - 'job_id': website_info['job_id'], - 'only_main_content': website_info.get('only_main_content', False), - 'mode': 'crawl', + "url": url, + "provider": website_info["provider"], + "job_id": website_info["job_id"], + "only_main_content": website_info.get("only_main_content", False), + "mode": "crawl", } document.data_source_type = document_data["data_source"]["type"] document.data_source_info = json.dumps(data_source_info) document.name = file_name # update document to be waiting - document.indexing_status = 'waiting' + document.indexing_status = "waiting" document.completed_at = None document.processing_started_at = None document.parsing_completed_at = None @@ -1032,13 +1021,11 @@ class DocumentService: document.splitting_completed_at = None document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) document.created_from = created_from - document.doc_form = document_data['doc_form'] + document.doc_form = document_data["doc_form"] db.session.add(document) db.session.commit() # update document segment - update_params = { - DocumentSegment.status: 're_segment' - } + update_params = {DocumentSegment.status: "re_segment"} DocumentSegment.query.filter_by(document_id=document.id).update(update_params) db.session.commit() # trigger async task @@ -1052,60 +1039,54 @@ class DocumentService: if features.billing.enabled: count = 0 if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids'] + upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] count = len(upload_file_list) elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] + notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] for notion_info in notion_info_list: - count = count + len(notion_info['pages']) + count = count + len(notion_info["pages"]) elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]['info_list']['website_info_list'] - count = len(website_info['urls']) + website_info = document_data["data_source"]["info_list"]["website_info_list"] + count = len(website_info["urls"]) batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") DocumentService.check_documents_upload_quota(count, features) - embedding_model = None dataset_collection_binding_id = None retrieval_model = None - if document_data['indexing_technique'] == 'high_quality': + if document_data["indexing_technique"] == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.TEXT_EMBEDDING + tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING ) dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, - embedding_model.model + embedding_model.provider, embedding_model.model ) dataset_collection_binding_id = dataset_collection_binding.id - if document_data.get('retrieval_model'): - retrieval_model = document_data['retrieval_model'] + if document_data.get("retrieval_model"): + retrieval_model = document_data["retrieval_model"] else: default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } retrieval_model = default_retrieval_model # save dataset dataset = Dataset( tenant_id=tenant_id, - name='', + name="", data_source_type=document_data["data_source"]["type"], - indexing_technique=document_data["indexing_technique"], + indexing_technique=document_data.get("indexing_technique", "high_quality"), created_by=account.id, - embedding_model=embedding_model.model if embedding_model else None, - embedding_model_provider=embedding_model.provider if embedding_model else None, + embedding_model=document_data.get("embedding_model"), + embedding_model_provider=document_data.get("embedding_model_provider"), collection_binding_id=dataset_collection_binding_id, - retrieval_model=retrieval_model + retrieval_model=retrieval_model, ) db.session.add(dataset) @@ -1115,236 +1096,259 @@ class DocumentService: cut_length = 18 cut_name = documents[0].name[:cut_length] - dataset.name = cut_name + '...' - dataset.description = 'useful for when you want to answer queries about the ' + documents[0].name + dataset.name = cut_name + "..." + dataset.description = "useful for when you want to answer queries about the " + documents[0].name db.session.commit() return dataset, documents, batch @classmethod def document_create_args_validate(cls, args: dict): - if 'original_document_id' not in args or not args['original_document_id']: + if "original_document_id" not in args or not args["original_document_id"]: DocumentService.data_source_args_validate(args) DocumentService.process_rule_args_validate(args) else: - if ('data_source' not in args and not args['data_source']) \ - and ('process_rule' not in args and not args['process_rule']): + if ("data_source" not in args and not args["data_source"]) and ( + "process_rule" not in args and not args["process_rule"] + ): raise ValueError("Data source or Process rule is required") else: - if args.get('data_source'): + if args.get("data_source"): DocumentService.data_source_args_validate(args) - if args.get('process_rule'): + if args.get("process_rule"): DocumentService.process_rule_args_validate(args) @classmethod def data_source_args_validate(cls, args: dict): - if 'data_source' not in args or not args['data_source']: + if "data_source" not in args or not args["data_source"]: raise ValueError("Data source is required") - if not isinstance(args['data_source'], dict): + if not isinstance(args["data_source"], dict): raise ValueError("Data source is invalid") - if 'type' not in args['data_source'] or not args['data_source']['type']: + if "type" not in args["data_source"] or not args["data_source"]["type"]: raise ValueError("Data source type is required") - if args['data_source']['type'] not in Document.DATA_SOURCES: + if args["data_source"]["type"] not in Document.DATA_SOURCES: raise ValueError("Data source type is invalid") - if 'info_list' not in args['data_source'] or not args['data_source']['info_list']: + if "info_list" not in args["data_source"] or not args["data_source"]["info_list"]: raise ValueError("Data source info is required") - if args['data_source']['type'] == 'upload_file': - if 'file_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][ - 'file_info_list']: + if args["data_source"]["type"] == "upload_file": + if ( + "file_info_list" not in args["data_source"]["info_list"] + or not args["data_source"]["info_list"]["file_info_list"] + ): raise ValueError("File source info is required") - if args['data_source']['type'] == 'notion_import': - if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][ - 'notion_info_list']: + if args["data_source"]["type"] == "notion_import": + if ( + "notion_info_list" not in args["data_source"]["info_list"] + or not args["data_source"]["info_list"]["notion_info_list"] + ): raise ValueError("Notion source info is required") - if args['data_source']['type'] == 'website_crawl': - if 'website_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][ - 'website_info_list']: + if args["data_source"]["type"] == "website_crawl": + if ( + "website_info_list" not in args["data_source"]["info_list"] + or not args["data_source"]["info_list"]["website_info_list"] + ): raise ValueError("Website source info is required") @classmethod def process_rule_args_validate(cls, args: dict): - if 'process_rule' not in args or not args['process_rule']: + if "process_rule" not in args or not args["process_rule"]: raise ValueError("Process rule is required") - if not isinstance(args['process_rule'], dict): + if not isinstance(args["process_rule"], dict): raise ValueError("Process rule is invalid") - if 'mode' not in args['process_rule'] or not args['process_rule']['mode']: + if "mode" not in args["process_rule"] or not args["process_rule"]["mode"]: raise ValueError("Process rule mode is required") - if args['process_rule']['mode'] not in DatasetProcessRule.MODES: + if args["process_rule"]["mode"] not in DatasetProcessRule.MODES: raise ValueError("Process rule mode is invalid") - if args['process_rule']['mode'] == 'automatic': - args['process_rule']['rules'] = {} + if args["process_rule"]["mode"] == "automatic": + args["process_rule"]["rules"] = {} else: - if 'rules' not in args['process_rule'] or not args['process_rule']['rules']: + if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]: raise ValueError("Process rule rules is required") - if not isinstance(args['process_rule']['rules'], dict): + if not isinstance(args["process_rule"]["rules"], dict): raise ValueError("Process rule rules is invalid") - if 'pre_processing_rules' not in args['process_rule']['rules'] \ - or args['process_rule']['rules']['pre_processing_rules'] is None: + if ( + "pre_processing_rules" not in args["process_rule"]["rules"] + or args["process_rule"]["rules"]["pre_processing_rules"] is None + ): raise ValueError("Process rule pre_processing_rules is required") - if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list): + if not isinstance(args["process_rule"]["rules"]["pre_processing_rules"], list): raise ValueError("Process rule pre_processing_rules is invalid") unique_pre_processing_rule_dicts = {} - for pre_processing_rule in args['process_rule']['rules']['pre_processing_rules']: - if 'id' not in pre_processing_rule or not pre_processing_rule['id']: + for pre_processing_rule in args["process_rule"]["rules"]["pre_processing_rules"]: + if "id" not in pre_processing_rule or not pre_processing_rule["id"]: raise ValueError("Process rule pre_processing_rules id is required") - if pre_processing_rule['id'] not in DatasetProcessRule.PRE_PROCESSING_RULES: + if pre_processing_rule["id"] not in DatasetProcessRule.PRE_PROCESSING_RULES: raise ValueError("Process rule pre_processing_rules id is invalid") - if 'enabled' not in pre_processing_rule or pre_processing_rule['enabled'] is None: + if "enabled" not in pre_processing_rule or pre_processing_rule["enabled"] is None: raise ValueError("Process rule pre_processing_rules enabled is required") - if not isinstance(pre_processing_rule['enabled'], bool): + if not isinstance(pre_processing_rule["enabled"], bool): raise ValueError("Process rule pre_processing_rules enabled is invalid") - unique_pre_processing_rule_dicts[pre_processing_rule['id']] = pre_processing_rule + unique_pre_processing_rule_dicts[pre_processing_rule["id"]] = pre_processing_rule - args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values()) + args["process_rule"]["rules"]["pre_processing_rules"] = list(unique_pre_processing_rule_dicts.values()) - if 'segmentation' not in args['process_rule']['rules'] \ - or args['process_rule']['rules']['segmentation'] is None: + if ( + "segmentation" not in args["process_rule"]["rules"] + or args["process_rule"]["rules"]["segmentation"] is None + ): raise ValueError("Process rule segmentation is required") - if not isinstance(args['process_rule']['rules']['segmentation'], dict): + if not isinstance(args["process_rule"]["rules"]["segmentation"], dict): raise ValueError("Process rule segmentation is invalid") - if 'separator' not in args['process_rule']['rules']['segmentation'] \ - or not args['process_rule']['rules']['segmentation']['separator']: + if ( + "separator" not in args["process_rule"]["rules"]["segmentation"] + or not args["process_rule"]["rules"]["segmentation"]["separator"] + ): raise ValueError("Process rule segmentation separator is required") - if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str): + if not isinstance(args["process_rule"]["rules"]["segmentation"]["separator"], str): raise ValueError("Process rule segmentation separator is invalid") - if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \ - or not args['process_rule']['rules']['segmentation']['max_tokens']: + if ( + "max_tokens" not in args["process_rule"]["rules"]["segmentation"] + or not args["process_rule"]["rules"]["segmentation"]["max_tokens"] + ): raise ValueError("Process rule segmentation max_tokens is required") - if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int): + if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int): raise ValueError("Process rule segmentation max_tokens is invalid") @classmethod def estimate_args_validate(cls, args: dict): - if 'info_list' not in args or not args['info_list']: + if "info_list" not in args or not args["info_list"]: raise ValueError("Data source info is required") - if not isinstance(args['info_list'], dict): + if not isinstance(args["info_list"], dict): raise ValueError("Data info is invalid") - if 'process_rule' not in args or not args['process_rule']: + if "process_rule" not in args or not args["process_rule"]: raise ValueError("Process rule is required") - if not isinstance(args['process_rule'], dict): + if not isinstance(args["process_rule"], dict): raise ValueError("Process rule is invalid") - if 'mode' not in args['process_rule'] or not args['process_rule']['mode']: + if "mode" not in args["process_rule"] or not args["process_rule"]["mode"]: raise ValueError("Process rule mode is required") - if args['process_rule']['mode'] not in DatasetProcessRule.MODES: + if args["process_rule"]["mode"] not in DatasetProcessRule.MODES: raise ValueError("Process rule mode is invalid") - if args['process_rule']['mode'] == 'automatic': - args['process_rule']['rules'] = {} + if args["process_rule"]["mode"] == "automatic": + args["process_rule"]["rules"] = {} else: - if 'rules' not in args['process_rule'] or not args['process_rule']['rules']: + if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]: raise ValueError("Process rule rules is required") - if not isinstance(args['process_rule']['rules'], dict): + if not isinstance(args["process_rule"]["rules"], dict): raise ValueError("Process rule rules is invalid") - if 'pre_processing_rules' not in args['process_rule']['rules'] \ - or args['process_rule']['rules']['pre_processing_rules'] is None: + if ( + "pre_processing_rules" not in args["process_rule"]["rules"] + or args["process_rule"]["rules"]["pre_processing_rules"] is None + ): raise ValueError("Process rule pre_processing_rules is required") - if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list): + if not isinstance(args["process_rule"]["rules"]["pre_processing_rules"], list): raise ValueError("Process rule pre_processing_rules is invalid") unique_pre_processing_rule_dicts = {} - for pre_processing_rule in args['process_rule']['rules']['pre_processing_rules']: - if 'id' not in pre_processing_rule or not pre_processing_rule['id']: + for pre_processing_rule in args["process_rule"]["rules"]["pre_processing_rules"]: + if "id" not in pre_processing_rule or not pre_processing_rule["id"]: raise ValueError("Process rule pre_processing_rules id is required") - if pre_processing_rule['id'] not in DatasetProcessRule.PRE_PROCESSING_RULES: + if pre_processing_rule["id"] not in DatasetProcessRule.PRE_PROCESSING_RULES: raise ValueError("Process rule pre_processing_rules id is invalid") - if 'enabled' not in pre_processing_rule or pre_processing_rule['enabled'] is None: + if "enabled" not in pre_processing_rule or pre_processing_rule["enabled"] is None: raise ValueError("Process rule pre_processing_rules enabled is required") - if not isinstance(pre_processing_rule['enabled'], bool): + if not isinstance(pre_processing_rule["enabled"], bool): raise ValueError("Process rule pre_processing_rules enabled is invalid") - unique_pre_processing_rule_dicts[pre_processing_rule['id']] = pre_processing_rule + unique_pre_processing_rule_dicts[pre_processing_rule["id"]] = pre_processing_rule - args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values()) + args["process_rule"]["rules"]["pre_processing_rules"] = list(unique_pre_processing_rule_dicts.values()) - if 'segmentation' not in args['process_rule']['rules'] \ - or args['process_rule']['rules']['segmentation'] is None: + if ( + "segmentation" not in args["process_rule"]["rules"] + or args["process_rule"]["rules"]["segmentation"] is None + ): raise ValueError("Process rule segmentation is required") - if not isinstance(args['process_rule']['rules']['segmentation'], dict): + if not isinstance(args["process_rule"]["rules"]["segmentation"], dict): raise ValueError("Process rule segmentation is invalid") - if 'separator' not in args['process_rule']['rules']['segmentation'] \ - or not args['process_rule']['rules']['segmentation']['separator']: + if ( + "separator" not in args["process_rule"]["rules"]["segmentation"] + or not args["process_rule"]["rules"]["segmentation"]["separator"] + ): raise ValueError("Process rule segmentation separator is required") - if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str): + if not isinstance(args["process_rule"]["rules"]["segmentation"]["separator"], str): raise ValueError("Process rule segmentation separator is invalid") - if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \ - or not args['process_rule']['rules']['segmentation']['max_tokens']: + if ( + "max_tokens" not in args["process_rule"]["rules"]["segmentation"] + or not args["process_rule"]["rules"]["segmentation"]["max_tokens"] + ): raise ValueError("Process rule segmentation max_tokens is required") - if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int): + if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int): raise ValueError("Process rule segmentation max_tokens is invalid") class SegmentService: @classmethod def segment_create_args_validate(cls, args: dict, document: Document): - if document.doc_form == 'qa_model': - if 'answer' not in args or not args['answer']: + if document.doc_form == "qa_model": + if "answer" not in args or not args["answer"]: raise ValueError("Answer is required") - if not args['answer'].strip(): + if not args["answer"].strip(): raise ValueError("Answer is empty") - if 'content' not in args or not args['content'] or not args['content'].strip(): + if "content" not in args or not args["content"] or not args["content"].strip(): raise ValueError("Content is empty") @classmethod def create_segment(cls, args: dict, document: Document, dataset: Dataset): - content = args['content'] + content = args["content"] doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) # calc embedding use tokens - tokens = embedding_model.get_text_embedding_num_tokens( - texts=[content] - ) - lock_name = 'add_segment_lock_document_id_{}'.format(document.id) + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) + lock_name = "add_segment_lock_document_id_{}".format(document.id) with redis_client.lock(lock_name, timeout=600): - max_position = db.session.query(func.max(DocumentSegment.position)).filter( - DocumentSegment.document_id == document.id - ).scalar() + max_position = ( + db.session.query(func.max(DocumentSegment.position)) + .filter(DocumentSegment.document_id == document.id) + .scalar() + ) segment_document = DocumentSegment( tenant_id=current_user.current_tenant_id, dataset_id=document.dataset_id, @@ -1355,25 +1359,25 @@ class SegmentService: content=content, word_count=len(content), tokens=tokens, - status='completed', + status="completed", indexing_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), - created_by=current_user.id + created_by=current_user.id, ) - if document.doc_form == 'qa_model': - segment_document.answer = args['answer'] + if document.doc_form == "qa_model": + segment_document.answer = args["answer"] db.session.add(segment_document) db.session.commit() # save vector index try: - VectorService.create_segments_vector([args['keywords']], [segment_document], dataset) + VectorService.create_segments_vector([args["keywords"]], [segment_document], dataset) except Exception as e: logging.exception("create segment index failed") segment_document.enabled = False segment_document.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - segment_document.status = 'error' + segment_document.status = "error" segment_document.error = str(e) db.session.commit() segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first() @@ -1381,33 +1385,33 @@ class SegmentService: @classmethod def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset): - lock_name = 'multi_add_segment_lock_document_id_{}'.format(document.id) + lock_name = "multi_add_segment_lock_document_id_{}".format(document.id) with redis_client.lock(lock_name, timeout=600): embedding_model = None - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) - max_position = db.session.query(func.max(DocumentSegment.position)).filter( - DocumentSegment.document_id == document.id - ).scalar() + max_position = ( + db.session.query(func.max(DocumentSegment.position)) + .filter(DocumentSegment.document_id == document.id) + .scalar() + ) pre_segment_data_list = [] segment_data_list = [] keywords_list = [] for segment_item in segments: - content = segment_item['content'] + content = segment_item["content"] doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == 'high_quality' and embedding_model: + if dataset.indexing_technique == "high_quality" and embedding_model: # calc embedding use tokens - tokens = embedding_model.get_text_embedding_num_tokens( - texts=[content] - ) + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) segment_document = DocumentSegment( tenant_id=current_user.current_tenant_id, dataset_id=document.dataset_id, @@ -1418,19 +1422,19 @@ class SegmentService: content=content, word_count=len(content), tokens=tokens, - status='completed', + status="completed", indexing_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), - created_by=current_user.id + created_by=current_user.id, ) - if document.doc_form == 'qa_model': - segment_document.answer = segment_item['answer'] + if document.doc_form == "qa_model": + segment_document.answer = segment_item["answer"] db.session.add(segment_document) segment_data_list.append(segment_document) pre_segment_data_list.append(segment_document) - if 'keywords' in segment_item: - keywords_list.append(segment_item['keywords']) + if "keywords" in segment_item: + keywords_list.append(segment_item["keywords"]) else: keywords_list.append(None) @@ -1442,19 +1446,19 @@ class SegmentService: for segment_document in segment_data_list: segment_document.enabled = False segment_document.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - segment_document.status = 'error' + segment_document.status = "error" segment_document.error = str(e) db.session.commit() return segment_data_list @classmethod def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset): - indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + indexing_cache_key = "segment_{}_indexing".format(segment.id) cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise ValueError("Segment is indexing, please try again later") - if 'enabled' in args and args['enabled'] is not None: - action = args['enabled'] + if "enabled" in args and args["enabled"] is not None: + action = args["enabled"] if segment.enabled != action: if not action: segment.enabled = action @@ -1467,25 +1471,25 @@ class SegmentService: disable_segment_from_index_task.delay(segment.id) return segment if not segment.enabled: - if 'enabled' in args and args['enabled'] is not None: - if not args['enabled']: + if "enabled" in args and args["enabled"] is not None: + if not args["enabled"]: raise ValueError("Can't update disabled segment") else: raise ValueError("Can't update disabled segment") try: - content = args['content'] + content = args["content"] if segment.content == content: - if document.doc_form == 'qa_model': - segment.answer = args['answer'] - if args.get('keywords'): - segment.keywords = args['keywords'] + if document.doc_form == "qa_model": + segment.answer = args["answer"] + if args.get("keywords"): + segment.keywords = args["keywords"] segment.enabled = True segment.disabled_at = None segment.disabled_by = None db.session.add(segment) db.session.commit() # update segment index task - if 'keywords' in args: + if "keywords" in args: keyword = Keyword(dataset) keyword.delete_by_ids([segment.index_node_id]) document = RAGDocument( @@ -1495,30 +1499,28 @@ class SegmentService: "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) - keyword.add_texts([document], keywords_list=[args['keywords']]) + keyword.add_texts([document], keywords_list=[args["keywords"]]) else: segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) # calc embedding use tokens - tokens = embedding_model.get_text_embedding_num_tokens( - texts=[content] - ) + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) segment.content = content segment.index_node_hash = segment_hash segment.word_count = len(content) segment.tokens = tokens - segment.status = 'completed' + segment.status = "completed" segment.indexing_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) segment.completed_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) segment.updated_by = current_user.id @@ -1526,18 +1528,18 @@ class SegmentService: segment.enabled = True segment.disabled_at = None segment.disabled_by = None - if document.doc_form == 'qa_model': - segment.answer = args['answer'] + if document.doc_form == "qa_model": + segment.answer = args["answer"] db.session.add(segment) db.session.commit() # update segment vector index - VectorService.update_segment_vector(args['keywords'], segment, dataset) + VectorService.update_segment_vector(args["keywords"], segment, dataset) except Exception as e: logging.exception("update segment index failed") segment.enabled = False segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - segment.status = 'error' + segment.status = "error" segment.error = str(e) db.session.commit() segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first() @@ -1545,7 +1547,7 @@ class SegmentService: @classmethod def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset): - indexing_cache_key = 'segment_{}_delete_indexing'.format(segment.id) + indexing_cache_key = "segment_{}_delete_indexing".format(segment.id) cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise ValueError("Segment is deleting.") @@ -1562,24 +1564,25 @@ class SegmentService: class DatasetCollectionBindingService: @classmethod def get_dataset_collection_binding( - cls, provider_name: str, model_name: str, - collection_type: str = 'dataset' + cls, provider_name: str, model_name: str, collection_type: str = "dataset" ) -> DatasetCollectionBinding: - dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ - filter( - DatasetCollectionBinding.provider_name == provider_name, - DatasetCollectionBinding.model_name == model_name, - DatasetCollectionBinding.type == collection_type - ). \ - order_by(DatasetCollectionBinding.created_at). \ - first() + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter( + DatasetCollectionBinding.provider_name == provider_name, + DatasetCollectionBinding.model_name == model_name, + DatasetCollectionBinding.type == collection_type, + ) + .order_by(DatasetCollectionBinding.created_at) + .first() + ) if not dataset_collection_binding: dataset_collection_binding = DatasetCollectionBinding( provider_name=provider_name, model_name=model_name, collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())), - type=collection_type + type=collection_type, ) db.session.add(dataset_collection_binding) db.session.commit() @@ -1587,16 +1590,16 @@ class DatasetCollectionBindingService: @classmethod def get_dataset_collection_binding_by_id_and_type( - cls, collection_binding_id: str, - collection_type: str = 'dataset' + cls, collection_binding_id: str, collection_type: str = "dataset" ) -> DatasetCollectionBinding: - dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ - filter( - DatasetCollectionBinding.id == collection_binding_id, - DatasetCollectionBinding.type == collection_type - ). \ - order_by(DatasetCollectionBinding.created_at). \ - first() + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter( + DatasetCollectionBinding.id == collection_binding_id, DatasetCollectionBinding.type == collection_type + ) + .order_by(DatasetCollectionBinding.created_at) + .first() + ) return dataset_collection_binding @@ -1604,11 +1607,13 @@ class DatasetCollectionBindingService: class DatasetPermissionService: @classmethod def get_dataset_partial_member_list(cls, dataset_id): - user_list_query = db.session.query( - DatasetPermission.account_id, - ).filter( - DatasetPermission.dataset_id == dataset_id - ).all() + user_list_query = ( + db.session.query( + DatasetPermission.account_id, + ) + .filter(DatasetPermission.dataset_id == dataset_id) + .all() + ) user_list = [] for user in user_list_query: @@ -1625,7 +1630,7 @@ class DatasetPermissionService: permission = DatasetPermission( tenant_id=tenant_id, dataset_id=dataset_id, - account_id=user['user_id'], + account_id=user["user_id"], ) permissions.append(permission) @@ -1638,19 +1643,19 @@ class DatasetPermissionService: @classmethod def check_permission(cls, user, dataset, requested_permission, requested_partial_member_list): if not user.is_dataset_editor: - raise NoPermissionError('User does not have permission to edit this dataset.') + raise NoPermissionError("User does not have permission to edit this dataset.") if user.is_dataset_operator and dataset.permission != requested_permission: - raise NoPermissionError('Dataset operators cannot change the dataset permissions.') + raise NoPermissionError("Dataset operators cannot change the dataset permissions.") - if user.is_dataset_operator and requested_permission == 'partial_members': + if user.is_dataset_operator and requested_permission == "partial_members": if not requested_partial_member_list: - raise ValueError('Partial member list is required when setting to partial members.') + raise ValueError("Partial member list is required when setting to partial members.") local_member_list = cls.get_dataset_partial_member_list(dataset.id) - request_member_list = [user['user_id'] for user in requested_partial_member_list] + request_member_list = [user["user_id"] for user in requested_partial_member_list] if set(local_member_list) != set(request_member_list): - raise ValueError('Dataset operators cannot change the dataset permissions.') + raise ValueError("Dataset operators cannot change the dataset permissions.") @classmethod def clear_partial_member_list(cls, dataset_id): diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index c483d28152..ddee52164b 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -4,15 +4,12 @@ import requests class EnterpriseRequest: - base_url = os.environ.get('ENTERPRISE_API_URL', 'ENTERPRISE_API_URL') - secret_key = os.environ.get('ENTERPRISE_API_SECRET_KEY', 'ENTERPRISE_API_SECRET_KEY') + base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL") + secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY") @classmethod def send_request(cls, method, endpoint, json=None, params=None): - headers = { - "Content-Type": "application/json", - "Enterprise-Api-Secret-Key": cls.secret_key - } + headers = {"Content-Type": "application/json", "Enterprise-Api-Secret-Key": cls.secret_key} url = f"{cls.base_url}{endpoint}" response = requests.request(method, url, json=json, params=params, headers=headers) diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 115d0d5523..abc01ddf8f 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -2,7 +2,10 @@ from services.enterprise.base import EnterpriseRequest class EnterpriseService: - @classmethod def get_info(cls): - return EnterpriseRequest.send_request('GET', '/info') + return EnterpriseRequest.send_request("GET", "/info") + + @classmethod + def get_app_web_sso_enabled(cls, app_code): + return EnterpriseRequest.send_request("GET", f"/app-sso-setting?appCode={app_code}") diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index e5e4d7e235..c519f0b0e5 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -22,14 +22,16 @@ class CustomConfigurationStatus(Enum): """ Enum class for custom configuration status. """ - ACTIVE = 'active' - NO_CONFIGURE = 'no-configure' + + ACTIVE = "active" + NO_CONFIGURE = "no-configure" class CustomConfigurationResponse(BaseModel): """ Model class for provider custom configuration response. """ + status: CustomConfigurationStatus @@ -37,6 +39,7 @@ class SystemConfigurationResponse(BaseModel): """ Model class for provider system configuration response. """ + enabled: bool current_quota_type: Optional[ProviderQuotaType] = None quota_configurations: list[QuotaConfiguration] = [] @@ -46,6 +49,7 @@ class ProviderResponse(BaseModel): """ Model class for provider response. """ + provider: str label: I18nObject description: Optional[I18nObject] = None @@ -67,18 +71,15 @@ class ProviderResponse(BaseModel): def __init__(self, **data) -> None: super().__init__(**data) - url_prefix = (dify_config.CONSOLE_API_URL - + f"/console/api/workspaces/current/model-providers/{self.provider}") + url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}" if self.icon_small is not None: self.icon_small = I18nObject( - en_US=f"{url_prefix}/icon_small/en_US", - zh_Hans=f"{url_prefix}/icon_small/zh_Hans" + en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" ) if self.icon_large is not None: self.icon_large = I18nObject( - en_US=f"{url_prefix}/icon_large/en_US", - zh_Hans=f"{url_prefix}/icon_large/zh_Hans" + en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" ) @@ -86,6 +87,7 @@ class ProviderWithModelsResponse(BaseModel): """ Model class for provider with models response. """ + provider: str label: I18nObject icon_small: Optional[I18nObject] = None @@ -96,18 +98,15 @@ class ProviderWithModelsResponse(BaseModel): def __init__(self, **data) -> None: super().__init__(**data) - url_prefix = (dify_config.CONSOLE_API_URL - + f"/console/api/workspaces/current/model-providers/{self.provider}") + url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}" if self.icon_small is not None: self.icon_small = I18nObject( - en_US=f"{url_prefix}/icon_small/en_US", - zh_Hans=f"{url_prefix}/icon_small/zh_Hans" + en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" ) if self.icon_large is not None: self.icon_large = I18nObject( - en_US=f"{url_prefix}/icon_large/en_US", - zh_Hans=f"{url_prefix}/icon_large/zh_Hans" + en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" ) @@ -119,18 +118,15 @@ class SimpleProviderEntityResponse(SimpleProviderEntity): def __init__(self, **data) -> None: super().__init__(**data) - url_prefix = (dify_config.CONSOLE_API_URL - + f"/console/api/workspaces/current/model-providers/{self.provider}") + url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}" if self.icon_small is not None: self.icon_small = I18nObject( - en_US=f"{url_prefix}/icon_small/en_US", - zh_Hans=f"{url_prefix}/icon_small/zh_Hans" + en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" ) if self.icon_large is not None: self.icon_large = I18nObject( - en_US=f"{url_prefix}/icon_large/en_US", - zh_Hans=f"{url_prefix}/icon_large/zh_Hans" + en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" ) @@ -138,6 +134,7 @@ class DefaultModelResponse(BaseModel): """ Default model entity. """ + model: str model_type: ModelType provider: SimpleProviderEntityResponse @@ -150,6 +147,7 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity): """ Model with provider entity. """ + provider: SimpleProviderEntityResponse def __init__(self, model: ModelWithProviderEntity) -> None: diff --git a/api/services/errors/account.py b/api/services/errors/account.py index ddc2dbdea8..cae31c5066 100644 --- a/api/services/errors/account.py +++ b/api/services/errors/account.py @@ -55,4 +55,3 @@ class RoleAlreadyAssignedError(BaseServiceError): class RateLimitExceededError(BaseServiceError): pass - diff --git a/api/services/errors/base.py b/api/services/errors/base.py index f5d41e17f1..1fed71cf9e 100644 --- a/api/services/errors/base.py +++ b/api/services/errors/base.py @@ -1,3 +1,3 @@ class BaseServiceError(Exception): def __init__(self, description: str = None): - self.description = description \ No newline at end of file + self.description = description diff --git a/api/services/errors/llm.py b/api/services/errors/llm.py new file mode 100644 index 0000000000..e4fac6f745 --- /dev/null +++ b/api/services/errors/llm.py @@ -0,0 +1,19 @@ +from typing import Optional + + +class InvokeError(Exception): + """Base class for all LLM exceptions.""" + + description: Optional[str] = None + + def __init__(self, description: Optional[str] = None) -> None: + self.description = description + + def __str__(self): + return self.description or self.__class__.__name__ + + +class InvokeRateLimitError(InvokeError): + """Raised when the Invoke returns rate limit error.""" + + description = "Rate Limit Error" diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 83e675a9d2..4d5812c6c6 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -6,8 +6,8 @@ from services.enterprise.enterprise_service import EnterpriseService class SubscriptionModel(BaseModel): - plan: str = 'sandbox' - interval: str = '' + plan: str = "sandbox" + interval: str = "" class BillingModel(BaseModel): @@ -27,7 +27,7 @@ class FeatureModel(BaseModel): vector_space: LimitationModel = LimitationModel(size=0, limit=5) annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10) documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50) - docs_processing: str = 'standard' + docs_processing: str = "standard" can_replace_logo: bool = False model_load_balancing_enabled: bool = False dataset_operator_enabled: bool = False @@ -38,13 +38,13 @@ class FeatureModel(BaseModel): class SystemFeatureModel(BaseModel): sso_enforced_for_signin: bool = False - sso_enforced_for_signin_protocol: str = '' + sso_enforced_for_signin_protocol: str = "" sso_enforced_for_web: bool = False - sso_enforced_for_web_protocol: str = '' + sso_enforced_for_web_protocol: str = "" + enable_web_sso_switch_component: bool = False class FeatureService: - @classmethod def get_features(cls, tenant_id: str) -> FeatureModel: features = FeatureModel() @@ -61,6 +61,7 @@ class FeatureService: system_features = SystemFeatureModel() if dify_config.ENTERPRISE_ENABLED: + system_features.enable_web_sso_switch_component = True cls._fulfill_params_from_enterprise(system_features) return system_features @@ -75,44 +76,44 @@ class FeatureService: def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str): billing_info = BillingService.get_info(tenant_id) - features.billing.enabled = billing_info['enabled'] - features.billing.subscription.plan = billing_info['subscription']['plan'] - features.billing.subscription.interval = billing_info['subscription']['interval'] + features.billing.enabled = billing_info["enabled"] + features.billing.subscription.plan = billing_info["subscription"]["plan"] + features.billing.subscription.interval = billing_info["subscription"]["interval"] - if 'members' in billing_info: - features.members.size = billing_info['members']['size'] - features.members.limit = billing_info['members']['limit'] + if "members" in billing_info: + features.members.size = billing_info["members"]["size"] + features.members.limit = billing_info["members"]["limit"] - if 'apps' in billing_info: - features.apps.size = billing_info['apps']['size'] - features.apps.limit = billing_info['apps']['limit'] + if "apps" in billing_info: + features.apps.size = billing_info["apps"]["size"] + features.apps.limit = billing_info["apps"]["limit"] - if 'vector_space' in billing_info: - features.vector_space.size = billing_info['vector_space']['size'] - features.vector_space.limit = billing_info['vector_space']['limit'] + if "vector_space" in billing_info: + features.vector_space.size = billing_info["vector_space"]["size"] + features.vector_space.limit = billing_info["vector_space"]["limit"] - if 'documents_upload_quota' in billing_info: - features.documents_upload_quota.size = billing_info['documents_upload_quota']['size'] - features.documents_upload_quota.limit = billing_info['documents_upload_quota']['limit'] + if "documents_upload_quota" in billing_info: + features.documents_upload_quota.size = billing_info["documents_upload_quota"]["size"] + features.documents_upload_quota.limit = billing_info["documents_upload_quota"]["limit"] - if 'annotation_quota_limit' in billing_info: - features.annotation_quota_limit.size = billing_info['annotation_quota_limit']['size'] - features.annotation_quota_limit.limit = billing_info['annotation_quota_limit']['limit'] + if "annotation_quota_limit" in billing_info: + features.annotation_quota_limit.size = billing_info["annotation_quota_limit"]["size"] + features.annotation_quota_limit.limit = billing_info["annotation_quota_limit"]["limit"] - if 'docs_processing' in billing_info: - features.docs_processing = billing_info['docs_processing'] + if "docs_processing" in billing_info: + features.docs_processing = billing_info["docs_processing"] - if 'can_replace_logo' in billing_info: - features.can_replace_logo = billing_info['can_replace_logo'] + if "can_replace_logo" in billing_info: + features.can_replace_logo = billing_info["can_replace_logo"] - if 'model_load_balancing_enabled' in billing_info: - features.model_load_balancing_enabled = billing_info['model_load_balancing_enabled'] + if "model_load_balancing_enabled" in billing_info: + features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"] @classmethod def _fulfill_params_from_enterprise(cls, features): enterprise_info = EnterpriseService.get_info() - features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin'] - features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol'] - features.sso_enforced_for_web = enterprise_info['sso_enforced_for_web'] - features.sso_enforced_for_web_protocol = enterprise_info['sso_enforced_for_web_protocol'] + features.sso_enforced_for_signin = enterprise_info["sso_enforced_for_signin"] + features.sso_enforced_for_signin_protocol = enterprise_info["sso_enforced_for_signin_protocol"] + features.sso_enforced_for_web = enterprise_info["sso_enforced_for_web"] + features.sso_enforced_for_web_protocol = enterprise_info["sso_enforced_for_web_protocol"] diff --git a/api/services/file_service.py b/api/services/file_service.py index 9139962240..5780abb2be 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -17,27 +17,45 @@ from models.account import Account from models.model import EndUser, UploadFile from services.errors.file import FileTooLargeError, UnsupportedFileTypeError -IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg'] +IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) -ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'xls', 'docx', 'csv'] -UNSTRUCTURED_ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'xls', - 'docx', 'csv', 'eml', 'msg', 'pptx', 'ppt', 'xml', 'epub'] +ALLOWED_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"] +UNSTRUCTURED_ALLOWED_EXTENSIONS = [ + "txt", + "markdown", + "md", + "pdf", + "html", + "htm", + "xlsx", + "xls", + "docx", + "csv", + "eml", + "msg", + "pptx", + "ppt", + "xml", + "epub", +] PREVIEW_WORDS_LIMIT = 3000 class FileService: - @staticmethod def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile: filename = file.filename - extension = file.filename.split('.')[-1] + extension = file.filename.split(".")[-1] if len(filename) > 200: - filename = filename.split('.')[0][:200] + '.' + extension + filename = filename.split(".")[0][:200] + "." + extension etl_type = dify_config.ETL_TYPE - allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS if etl_type == 'Unstructured' \ + allowed_extensions = ( + UNSTRUCTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS + if etl_type == "Unstructured" else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS + ) if extension.lower() not in allowed_extensions: raise UnsupportedFileTypeError() elif only_image and extension.lower() not in IMAGE_EXTENSIONS: @@ -55,7 +73,7 @@ class FileService: file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024 if file_size > file_size_limit: - message = f'File size exceeded. {file_size} > {file_size_limit}' + message = f"File size exceeded. {file_size} > {file_size_limit}" raise FileTooLargeError(message) # user uuid as file name @@ -67,7 +85,7 @@ class FileService: # end_user current_tenant_id = user.tenant_id - file_key = 'upload_files/' + current_tenant_id + '/' + file_uuid + '.' + extension + file_key = "upload_files/" + current_tenant_id + "/" + file_uuid + "." + extension # save file to storage storage.save(file_key, file_content) @@ -81,11 +99,11 @@ class FileService: size=file_size, extension=extension, mime_type=file.mimetype, - created_by_role=('account' if isinstance(user, Account) else 'end_user'), + created_by_role=("account" if isinstance(user, Account) else "end_user"), created_by=user.id, created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), used=False, - hash=hashlib.sha3_256(file_content).hexdigest() + hash=hashlib.sha3_256(file_content).hexdigest(), ) db.session.add(upload_file) @@ -99,10 +117,10 @@ class FileService: text_name = text_name[:200] # user uuid as file name file_uuid = str(uuid.uuid4()) - file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.txt' + file_key = "upload_files/" + current_user.current_tenant_id + "/" + file_uuid + ".txt" # save file to storage - storage.save(file_key, text.encode('utf-8')) + storage.save(file_key, text.encode("utf-8")) # save file to db upload_file = UploadFile( @@ -111,13 +129,13 @@ class FileService: key=file_key, name=text_name, size=len(text), - extension='txt', - mime_type='text/plain', + extension="txt", + mime_type="text/plain", created_by=current_user.id, created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), used=True, used_by=current_user.id, - used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), ) db.session.add(upload_file) @@ -127,9 +145,7 @@ class FileService: @staticmethod def get_file_preview(file_id: str) -> str: - upload_file = db.session.query(UploadFile) \ - .filter(UploadFile.id == file_id) \ - .first() + upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() if not upload_file: raise NotFound("File not found") @@ -137,12 +153,12 @@ class FileService: # extract text from file extension = upload_file.extension etl_type = dify_config.ETL_TYPE - allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS + allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS if extension.lower() not in allowed_extensions: raise UnsupportedFileTypeError() text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True) - text = text[0:PREVIEW_WORDS_LIMIT] if text else '' + text = text[0:PREVIEW_WORDS_LIMIT] if text else "" return text @@ -152,9 +168,7 @@ class FileService: if not result: raise NotFound("File not found or signature is invalid") - upload_file = db.session.query(UploadFile) \ - .filter(UploadFile.id == file_id) \ - .first() + upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() if not upload_file: raise NotFound("File not found or signature is invalid") @@ -170,9 +184,7 @@ class FileService: @staticmethod def get_public_image_preview(file_id: str) -> tuple[Generator, str]: - upload_file = db.session.query(UploadFile) \ - .filter(UploadFile.id == file_id) \ - .first() + upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() if not upload_file: raise NotFound("File not found or signature is invalid") diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index de5f6994b0..db99064814 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -9,14 +9,11 @@ from models.account import Account from models.dataset import Dataset, DatasetQuery, DocumentSegment default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } @@ -27,9 +24,9 @@ class HitTestingService: return { "query": { "content": query, - "tsne_position": {'x': 0, 'y': 0}, + "tsne_position": {"x": 0, "y": 0}, }, - "records": [] + "records": [], } start = time.perf_counter() @@ -38,28 +35,28 @@ class HitTestingService: if not retrieval_model: retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model - all_documents = RetrievalService.retrieve(retrival_method=retrieval_model.get('search_method', 'semantic_search'), - dataset_id=dataset.id, - query=cls.escape_query_for_search(query), - top_k=retrieval_model.get('top_k', 2), - score_threshold=retrieval_model.get('score_threshold', .0) - if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model.get('reranking_model', None) - if retrieval_model['reranking_enable'] else None, - reranking_mode=retrieval_model.get('reranking_mode') - if retrieval_model.get('reranking_mode') else 'reranking_model', - weights=retrieval_model.get('weights', None), - ) + all_documents = RetrievalService.retrieve( + retrival_method=retrieval_model.get("search_method", "semantic_search"), + dataset_id=dataset.id, + query=cls.escape_query_for_search(query), + top_k=retrieval_model.get("top_k", 2), + score_threshold=retrieval_model.get("score_threshold", 0.0) + if retrieval_model["score_threshold_enabled"] + else None, + reranking_model=retrieval_model.get("reranking_model", None) + if retrieval_model["reranking_enable"] + else None, + reranking_mode=retrieval_model.get("reranking_mode") + if retrieval_model.get("reranking_mode") + else "reranking_model", + weights=retrieval_model.get("weights", None), + ) end = time.perf_counter() logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds") dataset_query = DatasetQuery( - dataset_id=dataset.id, - content=query, - source='hit_testing', - created_by_role='account', - created_by=account.id + dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id ) db.session.add(dataset_query) @@ -72,14 +69,18 @@ class HitTestingService: i = 0 records = [] for document in documents: - index_node_id = document.metadata['doc_id'] + index_node_id = document.metadata["doc_id"] - segment = db.session.query(DocumentSegment).filter( - DocumentSegment.dataset_id == dataset.id, - DocumentSegment.enabled == True, - DocumentSegment.status == 'completed', - DocumentSegment.index_node_id == index_node_id - ).first() + segment = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.index_node_id == index_node_id, + ) + .first() + ) if not segment: i += 1 @@ -87,7 +88,7 @@ class HitTestingService: record = { "segment": segment, - "score": document.metadata.get('score', None), + "score": document.metadata.get("score", None), } records.append(record) @@ -98,15 +99,15 @@ class HitTestingService: "query": { "content": query, }, - "records": records + "records": records, } @classmethod def hit_testing_args_check(cls, args): - query = args['query'] + query = args["query"] if not query or len(query) > 250: - raise ValueError('Query is required and cannot exceed 250 characters') + raise ValueError("Query is required and cannot exceed 250 characters") @staticmethod def escape_query_for_search(query: str) -> str: diff --git a/api/services/message_service.py b/api/services/message_service.py index 491a914c77..ecb121c36e 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -27,8 +27,14 @@ from services.workflow_service import WorkflowService class MessageService: @classmethod - def pagination_by_first_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], - conversation_id: str, first_id: Optional[str], limit: int) -> InfiniteScrollPagination: + def pagination_by_first_id( + cls, + app_model: App, + user: Optional[Union[Account, EndUser]], + conversation_id: str, + first_id: Optional[str], + limit: int, + ) -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) @@ -36,52 +42,69 @@ class MessageService: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) conversation = ConversationService.get_conversation( - app_model=app_model, - user=user, - conversation_id=conversation_id + app_model=app_model, user=user, conversation_id=conversation_id ) if first_id: - first_message = db.session.query(Message) \ - .filter(Message.conversation_id == conversation.id, Message.id == first_id).first() + first_message = ( + db.session.query(Message) + .filter(Message.conversation_id == conversation.id, Message.id == first_id) + .first() + ) if not first_message: raise FirstMessageNotExistsError() - history_messages = db.session.query(Message).filter( - Message.conversation_id == conversation.id, - Message.created_at < first_message.created_at, - Message.id != first_message.id - ) \ - .order_by(Message.created_at.desc()).limit(limit).all() + history_messages = ( + db.session.query(Message) + .filter( + Message.conversation_id == conversation.id, + Message.created_at < first_message.created_at, + Message.id != first_message.id, + ) + .order_by(Message.created_at.desc()) + .limit(limit) + .all() + ) else: - history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \ - .order_by(Message.created_at.desc()).limit(limit).all() + history_messages = ( + db.session.query(Message) + .filter(Message.conversation_id == conversation.id) + .order_by(Message.created_at.desc()) + .limit(limit) + .all() + ) has_more = False if len(history_messages) == limit: current_page_first_message = history_messages[-1] - rest_count = db.session.query(Message).filter( - Message.conversation_id == conversation.id, - Message.created_at < current_page_first_message.created_at, - Message.id != current_page_first_message.id - ).count() + rest_count = ( + db.session.query(Message) + .filter( + Message.conversation_id == conversation.id, + Message.created_at < current_page_first_message.created_at, + Message.id != current_page_first_message.id, + ) + .count() + ) if rest_count > 0: has_more = True history_messages = list(reversed(history_messages)) - return InfiniteScrollPagination( - data=history_messages, - limit=limit, - has_more=has_more - ) + return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more) @classmethod - def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], - last_id: Optional[str], limit: int, conversation_id: Optional[str] = None, - include_ids: Optional[list] = None) -> InfiniteScrollPagination: + def pagination_by_last_id( + cls, + app_model: App, + user: Optional[Union[Account, EndUser]], + last_id: Optional[str], + limit: int, + conversation_id: Optional[str] = None, + include_ids: Optional[list] = None, + ) -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) @@ -89,9 +112,7 @@ class MessageService: if conversation_id is not None: conversation = ConversationService.get_conversation( - app_model=app_model, - user=user, - conversation_id=conversation_id + app_model=app_model, user=user, conversation_id=conversation_id ) base_query = base_query.filter(Message.conversation_id == conversation.id) @@ -105,10 +126,12 @@ class MessageService: if not last_message: raise LastMessageNotExistsError() - history_messages = base_query.filter( - Message.created_at < last_message.created_at, - Message.id != last_message.id - ).order_by(Message.created_at.desc()).limit(limit).all() + history_messages = ( + base_query.filter(Message.created_at < last_message.created_at, Message.id != last_message.id) + .order_by(Message.created_at.desc()) + .limit(limit) + .all() + ) else: history_messages = base_query.order_by(Message.created_at.desc()).limit(limit).all() @@ -116,30 +139,22 @@ class MessageService: if len(history_messages) == limit: current_page_first_message = history_messages[-1] rest_count = base_query.filter( - Message.created_at < current_page_first_message.created_at, - Message.id != current_page_first_message.id + Message.created_at < current_page_first_message.created_at, Message.id != current_page_first_message.id ).count() if rest_count > 0: has_more = True - return InfiniteScrollPagination( - data=history_messages, - limit=limit, - has_more=has_more - ) + return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more) @classmethod - def create_feedback(cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]], - rating: Optional[str]) -> MessageFeedback: + def create_feedback( + cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]], rating: Optional[str] + ) -> MessageFeedback: if not user: - raise ValueError('user cannot be None') + raise ValueError("user cannot be None") - message = cls.get_message( - app_model=app_model, - user=user, - message_id=message_id - ) + message = cls.get_message(app_model=app_model, user=user, message_id=message_id) feedback = message.user_feedback if isinstance(user, EndUser) else message.admin_feedback @@ -148,14 +163,14 @@ class MessageService: elif rating and feedback: feedback.rating = rating elif not rating and not feedback: - raise ValueError('rating cannot be None when feedback not exists') + raise ValueError("rating cannot be None when feedback not exists") else: feedback = MessageFeedback( app_id=app_model.id, conversation_id=message.conversation_id, message_id=message.id, rating=rating, - from_source=('user' if isinstance(user, EndUser) else 'admin'), + from_source=("user" if isinstance(user, EndUser) else "admin"), from_end_user_id=(user.id if isinstance(user, EndUser) else None), from_account_id=(user.id if isinstance(user, Account) else None), ) @@ -167,13 +182,17 @@ class MessageService: @classmethod def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app_model.id, - Message.from_source == ('api' if isinstance(user, EndUser) else 'console'), - Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), - Message.from_account_id == (user.id if isinstance(user, Account) else None), - ).first() + message = ( + db.session.query(Message) + .filter( + Message.id == message_id, + Message.app_id == app_model.id, + Message.from_source == ("api" if isinstance(user, EndUser) else "console"), + Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), + Message.from_account_id == (user.id if isinstance(user, Account) else None), + ) + .first() + ) if not message: raise MessageNotExistsError() @@ -181,27 +200,22 @@ class MessageService: return message @classmethod - def get_suggested_questions_after_answer(cls, app_model: App, user: Optional[Union[Account, EndUser]], - message_id: str, invoke_from: InvokeFrom) -> list[Message]: + def get_suggested_questions_after_answer( + cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str, invoke_from: InvokeFrom + ) -> list[Message]: if not user: - raise ValueError('user cannot be None') + raise ValueError("user cannot be None") - message = cls.get_message( - app_model=app_model, - user=user, - message_id=message_id - ) + message = cls.get_message(app_model=app_model, user=user, message_id=message_id) conversation = ConversationService.get_conversation( - app_model=app_model, - conversation_id=message.conversation_id, - user=user + app_model=app_model, conversation_id=message.conversation_id, user=user ) if not conversation: raise ConversationNotExistsError() - if conversation.status != 'normal': + if conversation.status != "normal": raise ConversationCompletedError() model_manager = ModelManager() @@ -216,24 +230,23 @@ class MessageService: if workflow is None: return [] - app_config = AdvancedChatAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow - ) + app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) if not app_config.additional_features.suggested_questions_after_answer: raise SuggestedQuestionsAfterAnswerDisabledError() model_instance = model_manager.get_default_model_instance( - tenant_id=app_model.tenant_id, - model_type=ModelType.LLM + tenant_id=app_model.tenant_id, model_type=ModelType.LLM ) else: if not conversation.override_model_configs: - app_model_config = db.session.query(AppModelConfig).filter( - AppModelConfig.id == conversation.app_model_config_id, - AppModelConfig.app_id == app_model.id - ).first() + app_model_config = ( + db.session.query(AppModelConfig) + .filter( + AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id + ) + .first() + ) else: conversation_override_model_configs = json.loads(conversation.override_model_configs) app_model_config = AppModelConfig( @@ -249,16 +262,13 @@ class MessageService: model_instance = model_manager.get_model_instance( tenant_id=app_model.tenant_id, - provider=app_model_config.model_dict['provider'], + provider=app_model_config.model_dict["provider"], model_type=ModelType.LLM, - model=app_model_config.model_dict['name'] + model=app_model_config.model_dict["name"], ) # get memory of conversation (read-only) - memory = TokenBufferMemory( - conversation=conversation, - model_instance=model_instance - ) + memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) histories = memory.get_history_prompt_text( max_token_limit=3000, @@ -267,18 +277,14 @@ class MessageService: with measure_time() as timer: questions = LLMGenerator.generate_suggested_questions_after_answer( - tenant_id=app_model.tenant_id, - histories=histories + tenant_id=app_model.tenant_id, histories=histories ) # get tracing instance trace_manager = TraceQueueManager(app_id=app_model.id) trace_manager.add_trace_task( TraceTask( - TraceTaskName.SUGGESTED_QUESTION_TRACE, - message_id=message_id, - suggested_question=questions, - timer=timer + TraceTaskName.SUGGESTED_QUESTION_TRACE, message_id=message_id, suggested_question=questions, timer=timer ) ) diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 80eb72140d..e7b9422cfe 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -23,7 +23,6 @@ logger = logging.getLogger(__name__) class ModelLoadBalancingService: - def __init__(self) -> None: self.provider_manager = ProviderManager() @@ -46,10 +45,7 @@ class ModelLoadBalancingService: raise ValueError(f"Provider {provider} does not exist.") # Enable model load balancing - provider_configuration.enable_model_load_balancing( - model=model, - model_type=ModelType.value_of(model_type) - ) + provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type)) def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: """ @@ -70,13 +66,11 @@ class ModelLoadBalancingService: raise ValueError(f"Provider {provider} does not exist.") # disable model load balancing - provider_configuration.disable_model_load_balancing( - model=model, - model_type=ModelType.value_of(model_type) - ) + provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type)) - def get_load_balancing_configs(self, tenant_id: str, provider: str, model: str, model_type: str) \ - -> tuple[bool, list[dict]]: + def get_load_balancing_configs( + self, tenant_id: str, provider: str, model: str, model_type: str + ) -> tuple[bool, list[dict]]: """ Get load balancing configurations. :param tenant_id: workspace id @@ -107,20 +101,24 @@ class ModelLoadBalancingService: is_load_balancing_enabled = True # Get load balancing configurations - load_balancing_configs = db.session.query(LoadBalancingModelConfig) \ + load_balancing_configs = ( + db.session.query(LoadBalancingModelConfig) .filter( - LoadBalancingModelConfig.tenant_id == tenant_id, - LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), - LoadBalancingModelConfig.model_name == model - ).order_by(LoadBalancingModelConfig.created_at).all() + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + ) + .order_by(LoadBalancingModelConfig.created_at) + .all() + ) if provider_configuration.custom_configuration.provider: # check if the inherit configuration exists, # inherit is represented for the provider or model custom credentials inherit_config_exists = False for load_balancing_config in load_balancing_configs: - if load_balancing_config.name == '__inherit__': + if load_balancing_config.name == "__inherit__": inherit_config_exists = True break @@ -133,7 +131,7 @@ class ModelLoadBalancingService: else: # move the inherit configuration to the first for i, load_balancing_config in enumerate(load_balancing_configs[:]): - if load_balancing_config.name == '__inherit__': + if load_balancing_config.name == "__inherit__": inherit_config = load_balancing_configs.pop(i) load_balancing_configs.insert(0, inherit_config) @@ -151,7 +149,7 @@ class ModelLoadBalancingService: provider=provider, model=model, model_type=model_type, - config_id=load_balancing_config.id + config_id=load_balancing_config.id, ) try: @@ -172,32 +170,32 @@ class ModelLoadBalancingService: if variable in credentials: try: credentials[variable] = encrypter.decrypt_token_with_decoding( - credentials.get(variable), - decoding_rsa_key, - decoding_cipher_rsa + credentials.get(variable), decoding_rsa_key, decoding_cipher_rsa ) except ValueError: pass # Obfuscate credentials credentials = provider_configuration.obfuscated_credentials( - credentials=credentials, - credential_form_schemas=credential_schemas.credential_form_schemas + credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas ) - datas.append({ - 'id': load_balancing_config.id, - 'name': load_balancing_config.name, - 'credentials': credentials, - 'enabled': load_balancing_config.enabled, - 'in_cooldown': in_cooldown, - 'ttl': ttl - }) + datas.append( + { + "id": load_balancing_config.id, + "name": load_balancing_config.name, + "credentials": credentials, + "enabled": load_balancing_config.enabled, + "in_cooldown": in_cooldown, + "ttl": ttl, + } + ) return is_load_balancing_enabled, datas - def get_load_balancing_config(self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str) \ - -> Optional[dict]: + def get_load_balancing_config( + self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str + ) -> Optional[dict]: """ Get load balancing configuration. :param tenant_id: workspace id @@ -219,14 +217,17 @@ class ModelLoadBalancingService: model_type = ModelType.value_of(model_type) # Get load balancing configurations - load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \ + load_balancing_model_config = ( + db.session.query(LoadBalancingModelConfig) .filter( - LoadBalancingModelConfig.tenant_id == tenant_id, - LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), - LoadBalancingModelConfig.model_name == model, - LoadBalancingModelConfig.id == config_id - ).first() + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + LoadBalancingModelConfig.id == config_id, + ) + .first() + ) if not load_balancing_model_config: return None @@ -244,19 +245,19 @@ class ModelLoadBalancingService: # Obfuscate credentials credentials = provider_configuration.obfuscated_credentials( - credentials=credentials, - credential_form_schemas=credential_schemas.credential_form_schemas + credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas ) return { - 'id': load_balancing_model_config.id, - 'name': load_balancing_model_config.name, - 'credentials': credentials, - 'enabled': load_balancing_model_config.enabled + "id": load_balancing_model_config.id, + "name": load_balancing_model_config.name, + "credentials": credentials, + "enabled": load_balancing_model_config.enabled, } - def _init_inherit_config(self, tenant_id: str, provider: str, model: str, model_type: ModelType) \ - -> LoadBalancingModelConfig: + def _init_inherit_config( + self, tenant_id: str, provider: str, model: str, model_type: ModelType + ) -> LoadBalancingModelConfig: """ Initialize the inherit configuration. :param tenant_id: workspace id @@ -271,18 +272,16 @@ class ModelLoadBalancingService: provider_name=provider, model_type=model_type.to_origin_model_type(), model_name=model, - name='__inherit__' + name="__inherit__", ) db.session.add(inherit_config) db.session.commit() return inherit_config - def update_load_balancing_configs(self, tenant_id: str, - provider: str, - model: str, - model_type: str, - configs: list[dict]) -> None: + def update_load_balancing_configs( + self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict] + ) -> None: """ Update load balancing configurations. :param tenant_id: workspace id @@ -304,15 +303,18 @@ class ModelLoadBalancingService: model_type = ModelType.value_of(model_type) if not isinstance(configs, list): - raise ValueError('Invalid load balancing configs') + raise ValueError("Invalid load balancing configs") - current_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \ + current_load_balancing_configs = ( + db.session.query(LoadBalancingModelConfig) .filter( - LoadBalancingModelConfig.tenant_id == tenant_id, - LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), - LoadBalancingModelConfig.model_name == model - ).all() + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + ) + .all() + ) # id as key, config as value current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs} @@ -320,25 +322,25 @@ class ModelLoadBalancingService: for config in configs: if not isinstance(config, dict): - raise ValueError('Invalid load balancing config') + raise ValueError("Invalid load balancing config") - config_id = config.get('id') - name = config.get('name') - credentials = config.get('credentials') - enabled = config.get('enabled') + config_id = config.get("id") + name = config.get("name") + credentials = config.get("credentials") + enabled = config.get("enabled") if not name: - raise ValueError('Invalid load balancing config name') + raise ValueError("Invalid load balancing config name") if enabled is None: - raise ValueError('Invalid load balancing config enabled') + raise ValueError("Invalid load balancing config enabled") # is config exists if config_id: config_id = str(config_id) if config_id not in current_load_balancing_configs_dict: - raise ValueError('Invalid load balancing config id: {}'.format(config_id)) + raise ValueError("Invalid load balancing config id: {}".format(config_id)) updated_config_ids.add(config_id) @@ -347,11 +349,11 @@ class ModelLoadBalancingService: # check duplicate name for current_load_balancing_config in current_load_balancing_configs: if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name: - raise ValueError('Load balancing config name {} already exists'.format(name)) + raise ValueError("Load balancing config name {} already exists".format(name)) if credentials: if not isinstance(credentials, dict): - raise ValueError('Invalid load balancing config credentials') + raise ValueError("Invalid load balancing config credentials") # validate custom provider config credentials = self._custom_credentials_validate( @@ -361,7 +363,7 @@ class ModelLoadBalancingService: model=model, credentials=credentials, load_balancing_model_config=load_balancing_config, - validate=False + validate=False, ) # update load balancing config @@ -375,19 +377,19 @@ class ModelLoadBalancingService: self._clear_credentials_cache(tenant_id, config_id) else: # create load balancing config - if name == '__inherit__': - raise ValueError('Invalid load balancing config name') + if name == "__inherit__": + raise ValueError("Invalid load balancing config name") # check duplicate name for current_load_balancing_config in current_load_balancing_configs: if current_load_balancing_config.name == name: - raise ValueError('Load balancing config name {} already exists'.format(name)) + raise ValueError("Load balancing config name {} already exists".format(name)) if not credentials: - raise ValueError('Invalid load balancing config credentials') + raise ValueError("Invalid load balancing config credentials") if not isinstance(credentials, dict): - raise ValueError('Invalid load balancing config credentials') + raise ValueError("Invalid load balancing config credentials") # validate custom provider config credentials = self._custom_credentials_validate( @@ -396,7 +398,7 @@ class ModelLoadBalancingService: model_type=model_type, model=model, credentials=credentials, - validate=False + validate=False, ) # create load balancing config @@ -406,7 +408,7 @@ class ModelLoadBalancingService: model_type=model_type.to_origin_model_type(), model_name=model, name=name, - encrypted_config=json.dumps(credentials) + encrypted_config=json.dumps(credentials), ) db.session.add(load_balancing_model_config) @@ -420,12 +422,15 @@ class ModelLoadBalancingService: self._clear_credentials_cache(tenant_id, config_id) - def validate_load_balancing_credentials(self, tenant_id: str, - provider: str, - model: str, - model_type: str, - credentials: dict, - config_id: Optional[str] = None) -> None: + def validate_load_balancing_credentials( + self, + tenant_id: str, + provider: str, + model: str, + model_type: str, + credentials: dict, + config_id: Optional[str] = None, + ) -> None: """ Validate load balancing credentials. :param tenant_id: workspace id @@ -450,14 +455,17 @@ class ModelLoadBalancingService: load_balancing_model_config = None if config_id: # Get load balancing config - load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \ + load_balancing_model_config = ( + db.session.query(LoadBalancingModelConfig) .filter( - LoadBalancingModelConfig.tenant_id == tenant_id, - LoadBalancingModelConfig.provider_name == provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), - LoadBalancingModelConfig.model_name == model, - LoadBalancingModelConfig.id == config_id - ).first() + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + LoadBalancingModelConfig.id == config_id, + ) + .first() + ) if not load_balancing_model_config: raise ValueError(f"Load balancing config {config_id} does not exist.") @@ -469,16 +477,19 @@ class ModelLoadBalancingService: model_type=model_type, model=model, credentials=credentials, - load_balancing_model_config=load_balancing_model_config + load_balancing_model_config=load_balancing_model_config, ) - def _custom_credentials_validate(self, tenant_id: str, - provider_configuration: ProviderConfiguration, - model_type: ModelType, - model: str, - credentials: dict, - load_balancing_model_config: Optional[LoadBalancingModelConfig] = None, - validate: bool = True) -> dict: + def _custom_credentials_validate( + self, + tenant_id: str, + provider_configuration: ProviderConfiguration, + model_type: ModelType, + model: str, + credentials: dict, + load_balancing_model_config: Optional[LoadBalancingModelConfig] = None, + validate: bool = True, + ) -> dict: """ Validate custom credentials. :param tenant_id: workspace id @@ -521,12 +532,11 @@ class ModelLoadBalancingService: provider=provider_configuration.provider.provider, model_type=model_type, model=model, - credentials=credentials + credentials=credentials, ) else: credentials = model_provider_factory.provider_credentials_validate( - provider=provider_configuration.provider.provider, - credentials=credentials + provider=provider_configuration.provider.provider, credentials=credentials ) for key, value in credentials.items(): @@ -535,8 +545,9 @@ class ModelLoadBalancingService: return credentials - def _get_credential_schema(self, provider_configuration: ProviderConfiguration) \ - -> ModelCredentialSchema | ProviderCredentialSchema: + def _get_credential_schema( + self, provider_configuration: ProviderConfiguration + ) -> ModelCredentialSchema | ProviderCredentialSchema: """ Get form schemas. :param provider_configuration: provider configuration @@ -558,9 +569,7 @@ class ModelLoadBalancingService: :return: """ provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=tenant_id, - identity_id=config_id, - cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL + tenant_id=tenant_id, identity_id=config_id, cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL ) provider_model_credentials_cache.delete() diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 385af685f9..c0f3c40762 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -30,6 +30,7 @@ class ModelProviderService: """ Model Provider Service """ + def __init__(self) -> None: self.provider_manager = ProviderManager() @@ -72,8 +73,8 @@ class ModelProviderService: system_configuration=SystemConfigurationResponse( enabled=provider_configuration.system_configuration.enabled, current_quota_type=provider_configuration.system_configuration.current_quota_type, - quota_configurations=provider_configuration.system_configuration.quota_configurations - ) + quota_configurations=provider_configuration.system_configuration.quota_configurations, + ), ) provider_responses.append(provider_response) @@ -94,9 +95,9 @@ class ModelProviderService: provider_configurations = self.provider_manager.get_configurations(tenant_id) # Get provider available models - return [ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models( - provider=provider - )] + return [ + ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(provider=provider) + ] def get_provider_credentials(self, tenant_id: str, provider: str) -> dict: """ @@ -194,13 +195,12 @@ class ModelProviderService: # Get model custom credentials from ProviderModel if exists return provider_configuration.get_custom_model_credentials( - model_type=ModelType.value_of(model_type), - model=model, - obfuscated=True + model_type=ModelType.value_of(model_type), model=model, obfuscated=True ) - def model_credentials_validate(self, tenant_id: str, provider: str, model_type: str, model: str, - credentials: dict) -> None: + def model_credentials_validate( + self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict + ) -> None: """ validate model credentials. @@ -221,13 +221,12 @@ class ModelProviderService: # Validate model credentials provider_configuration.custom_model_credentials_validate( - model_type=ModelType.value_of(model_type), - model=model, - credentials=credentials + model_type=ModelType.value_of(model_type), model=model, credentials=credentials ) - def save_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str, - credentials: dict) -> None: + def save_model_credentials( + self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict + ) -> None: """ save model credentials. @@ -248,9 +247,7 @@ class ModelProviderService: # Add or update custom model credentials provider_configuration.add_or_update_custom_model_credentials( - model_type=ModelType.value_of(model_type), - model=model, - credentials=credentials + model_type=ModelType.value_of(model_type), model=model, credentials=credentials ) def remove_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> None: @@ -272,10 +269,7 @@ class ModelProviderService: raise ValueError(f"Provider {provider} does not exist.") # Remove custom model credentials - provider_configuration.delete_custom_model_credentials( - model_type=ModelType.value_of(model_type), - model=model - ) + provider_configuration.delete_custom_model_credentials(model_type=ModelType.value_of(model_type), model=model) def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]: """ @@ -289,9 +283,7 @@ class ModelProviderService: provider_configurations = self.provider_manager.get_configurations(tenant_id) # Get provider available models - models = provider_configurations.get_models( - model_type=ModelType.value_of(model_type) - ) + models = provider_configurations.get_models(model_type=ModelType.value_of(model_type)) # Group models by provider provider_models = {} @@ -322,16 +314,19 @@ class ModelProviderService: icon_small=first_model.provider.icon_small, icon_large=first_model.provider.icon_large, status=CustomConfigurationStatus.ACTIVE, - models=[ProviderModelWithStatusEntity( - model=model.model, - label=model.label, - model_type=model.model_type, - features=model.features, - fetch_from=model.fetch_from, - model_properties=model.model_properties, - status=model.status, - load_balancing_enabled=model.load_balancing_enabled - ) for model in models] + models=[ + ProviderModelWithStatusEntity( + model=model.model, + label=model.label, + model_type=model.model_type, + features=model.features, + fetch_from=model.fetch_from, + model_properties=model.model_properties, + status=model.status, + load_balancing_enabled=model.load_balancing_enabled, + ) + for model in models + ], ) ) @@ -360,19 +355,13 @@ class ModelProviderService: model_type_instance = cast(LargeLanguageModel, model_type_instance) # fetch credentials - credentials = provider_configuration.get_current_credentials( - model_type=ModelType.LLM, - model=model - ) + credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model) if not credentials: return [] # Call get_parameter_rules method of model instance to get model parameter rules - return model_type_instance.get_parameter_rules( - model=model, - credentials=credentials - ) + return model_type_instance.get_parameter_rules(model=model, credentials=credentials) def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]: """ @@ -383,22 +372,26 @@ class ModelProviderService: :return: """ model_type_enum = ModelType.value_of(model_type) - result = self.provider_manager.get_default_model( - tenant_id=tenant_id, - model_type=model_type_enum - ) - - return DefaultModelResponse( - model=result.model, - model_type=result.model_type, - provider=SimpleProviderEntityResponse( - provider=result.provider.provider, - label=result.provider.label, - icon_small=result.provider.icon_small, - icon_large=result.provider.icon_large, - supported_model_types=result.provider.supported_model_types + result = self.provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type_enum) + try: + return ( + DefaultModelResponse( + model=result.model, + model_type=result.model_type, + provider=SimpleProviderEntityResponse( + provider=result.provider.provider, + label=result.provider.label, + icon_small=result.provider.icon_small, + icon_large=result.provider.icon_large, + supported_model_types=result.provider.supported_model_types, + ), + ) + if result + else None ) - ) if result else None + except Exception as e: + logger.info(f"get_default_model_of_model_type error: {e}") + return None def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None: """ @@ -412,13 +405,12 @@ class ModelProviderService: """ model_type_enum = ModelType.value_of(model_type) self.provider_manager.update_default_model_record( - tenant_id=tenant_id, - model_type=model_type_enum, - provider=provider, - model=model + tenant_id=tenant_id, model_type=model_type_enum, provider=provider, model=model ) - def get_model_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[Optional[bytes], Optional[str]]: + def get_model_provider_icon( + self, provider: str, icon_type: str, lang: str + ) -> tuple[Optional[bytes], Optional[str]]: """ get model provider icon. @@ -430,11 +422,11 @@ class ModelProviderService: provider_instance = model_provider_factory.get_provider_instance(provider) provider_schema = provider_instance.get_provider_schema() - if icon_type.lower() == 'icon_small': + if icon_type.lower() == "icon_small": if not provider_schema.icon_small: raise ValueError(f"Provider {provider} does not have small icon.") - if lang.lower() == 'zh_hans': + if lang.lower() == "zh_hans": file_name = provider_schema.icon_small.zh_Hans else: file_name = provider_schema.icon_small.en_US @@ -442,13 +434,15 @@ class ModelProviderService: if not provider_schema.icon_large: raise ValueError(f"Provider {provider} does not have large icon.") - if lang.lower() == 'zh_hans': + if lang.lower() == "zh_hans": file_name = provider_schema.icon_large.zh_Hans else: file_name = provider_schema.icon_large.en_US root_path = current_app.root_path - provider_instance_path = os.path.dirname(os.path.join(root_path, provider_instance.__class__.__module__.replace('.', '/'))) + provider_instance_path = os.path.dirname( + os.path.join(root_path, provider_instance.__class__.__module__.replace(".", "/")) + ) file_path = os.path.join(provider_instance_path, "_assets") file_path = os.path.join(file_path, file_name) @@ -456,10 +450,10 @@ class ModelProviderService: return None, None mimetype, _ = mimetypes.guess_type(file_path) - mimetype = mimetype or 'application/octet-stream' + mimetype = mimetype or "application/octet-stream" # read binary from file - with open(file_path, 'rb') as f: + with open(file_path, "rb") as f: byte_data = f.read() return byte_data, mimetype @@ -505,10 +499,7 @@ class ModelProviderService: raise ValueError(f"Provider {provider} does not exist.") # Enable model - provider_configuration.enable_model( - model=model, - model_type=ModelType.value_of(model_type) - ) + provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type)) def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: """ @@ -529,78 +520,49 @@ class ModelProviderService: raise ValueError(f"Provider {provider} does not exist.") # Enable model - provider_configuration.disable_model( - model=model, - model_type=ModelType.value_of(model_type) - ) + provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type)) def free_quota_submit(self, tenant_id: str, provider: str): api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY") api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL") - api_url = api_base_url + '/api/v1/providers/apply' + api_url = api_base_url + "/api/v1/providers/apply" - headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {api_key}" - } - response = requests.post(api_url, headers=headers, json={'workspace_id': tenant_id, 'provider_name': provider}) + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} + response = requests.post(api_url, headers=headers, json={"workspace_id": tenant_id, "provider_name": provider}) if not response.ok: logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ") raise ValueError(f"Error: {response.status_code} ") - if response.json()["code"] != 'success': - raise ValueError( - f"error: {response.json()['message']}" - ) + if response.json()["code"] != "success": + raise ValueError(f"error: {response.json()['message']}") rst = response.json() - if rst['type'] == 'redirect': - return { - 'type': rst['type'], - 'redirect_url': rst['redirect_url'] - } + if rst["type"] == "redirect": + return {"type": rst["type"], "redirect_url": rst["redirect_url"]} else: - return { - 'type': rst['type'], - 'result': 'success' - } + return {"type": rst["type"], "result": "success"} def free_quota_qualification_verify(self, tenant_id: str, provider: str, token: Optional[str]): api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY") api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL") - api_url = api_base_url + '/api/v1/providers/qualification-verify' + api_url = api_base_url + "/api/v1/providers/qualification-verify" - headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {api_key}" - } - json_data = {'workspace_id': tenant_id, 'provider_name': provider} + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} + json_data = {"workspace_id": tenant_id, "provider_name": provider} if token: - json_data['token'] = token - response = requests.post(api_url, headers=headers, - json=json_data) + json_data["token"] = token + response = requests.post(api_url, headers=headers, json=json_data) if not response.ok: logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ") raise ValueError(f"Error: {response.status_code} ") rst = response.json() - if rst["code"] != 'success': - raise ValueError( - f"error: {rst['message']}" - ) + if rst["code"] != "success": + raise ValueError(f"error: {rst['message']}") - data = rst['data'] - if data['qualified'] is True: - return { - 'result': 'success', - 'provider_name': provider, - 'flag': True - } + data = rst["data"] + if data["qualified"] is True: + return {"result": "success", "provider_name": provider, "flag": True} else: - return { - 'result': 'success', - 'provider_name': provider, - 'flag': False, - 'reason': data['reason'] - } + return {"result": "success", "provider_name": provider, "flag": False, "reason": data["reason"]} diff --git a/api/services/moderation_service.py b/api/services/moderation_service.py index d472f8cfbc..dfb21e767f 100644 --- a/api/services/moderation_service.py +++ b/api/services/moderation_service.py @@ -4,17 +4,18 @@ from models.model import App, AppModelConfig class ModerationService: - def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult: app_model_config: AppModelConfig = None - app_model_config = db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() + app_model_config = ( + db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() + ) if not app_model_config: raise ValueError("app model config not found") - name = app_model_config.sensitive_word_avoidance_dict['type'] - config = app_model_config.sensitive_word_avoidance_dict['config'] + name = app_model_config.sensitive_word_avoidance_dict["type"] + config = app_model_config.sensitive_word_avoidance_dict["config"] moderation = ModerationFactory(name, app_id, app_model.tenant_id, config) return moderation.moderation_for_outputs(text) diff --git a/api/services/operation_service.py b/api/services/operation_service.py index 39f249dc24..8c8b64bcd5 100644 --- a/api/services/operation_service.py +++ b/api/services/operation_service.py @@ -4,15 +4,12 @@ import requests class OperationService: - base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL') - secret_key = os.environ.get('BILLING_API_SECRET_KEY', 'BILLING_API_SECRET_KEY') + base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL") + secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY") @classmethod def _send_request(cls, method, endpoint, json=None, params=None): - headers = { - "Content-Type": "application/json", - "Billing-Api-Secret-Key": cls.secret_key - } + headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} url = f"{cls.base_url}{endpoint}" response = requests.request(method, url, json=json, params=params, headers=headers) @@ -22,11 +19,11 @@ class OperationService: @classmethod def record_utm(cls, tenant_id: str, utm_info: dict): params = { - 'tenant_id': tenant_id, - 'utm_source': utm_info.get('utm_source', ''), - 'utm_medium': utm_info.get('utm_medium', ''), - 'utm_campaign': utm_info.get('utm_campaign', ''), - 'utm_content': utm_info.get('utm_content', ''), - 'utm_term': utm_info.get('utm_term', '') + "tenant_id": tenant_id, + "utm_source": utm_info.get("utm_source", ""), + "utm_medium": utm_info.get("utm_medium", ""), + "utm_campaign": utm_info.get("utm_campaign", ""), + "utm_content": utm_info.get("utm_content", ""), + "utm_term": utm_info.get("utm_term", ""), } - return cls._send_request('POST', '/tenant_utms', params=params) + return cls._send_request("POST", "/tenant_utms", params=params) diff --git a/api/services/ops_service.py b/api/services/ops_service.py index ffc12a9acd..35aa6817e1 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -12,20 +12,29 @@ class OpsService: :param tracing_provider: tracing provider :return: """ - trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter( - TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider - ).first() + trace_config_data: TraceAppConfig = ( + db.session.query(TraceAppConfig) + .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .first() + ) if not trace_config_data: return None # decrypt_token and obfuscated_token tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id - decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(tenant_id, tracing_provider, trace_config_data.tracing_config) - decrypt_tracing_config = OpsTraceManager.obfuscated_decrypt_token(tracing_provider, decrypt_tracing_config) + decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config( + tenant_id, tracing_provider, trace_config_data.tracing_config + ) + new_decrypt_tracing_config = OpsTraceManager.obfuscated_decrypt_token(tracing_provider, decrypt_tracing_config) - trace_config_data.tracing_config = decrypt_tracing_config + if tracing_provider == "langfuse" and ( + "project_key" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_key") + ): + project_key = OpsTraceManager.get_trace_config_project_key(decrypt_tracing_config, tracing_provider) + new_decrypt_tracing_config.update({"project_key": project_key}) + trace_config_data.tracing_config = new_decrypt_tracing_config return trace_config_data.to_dict() @classmethod @@ -37,11 +46,13 @@ class OpsService: :param tracing_config: tracing config :return: """ - if tracing_provider not in provider_config_map.keys() and tracing_provider != None: + if tracing_provider not in provider_config_map.keys() and tracing_provider: return {"error": f"Invalid tracing provider: {tracing_provider}"} - config_class, other_keys = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['other_keys'] + config_class, other_keys = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["other_keys"], + ) default_config_instance = config_class(**tracing_config) for key in other_keys: if key in tracing_config and tracing_config[key] == "": @@ -51,10 +62,15 @@ class OpsService: if not OpsTraceManager.check_trace_config_is_effective(tracing_config, tracing_provider): return {"error": "Invalid Credentials"} + # get project key + project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider) + # check if trace config already exists - trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter( - TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider - ).first() + trace_config_data: TraceAppConfig = ( + db.session.query(TraceAppConfig) + .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .first() + ) if trace_config_data: return None @@ -62,6 +78,8 @@ class OpsService: # get tenant id tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id tracing_config = OpsTraceManager.encrypt_tracing_config(tenant_id, tracing_provider, tracing_config) + if tracing_provider == "langfuse" and project_key: + tracing_config["project_key"] = project_key trace_config_data = TraceAppConfig( app_id=app_id, tracing_provider=tracing_provider, @@ -85,9 +103,11 @@ class OpsService: raise ValueError(f"Invalid tracing provider: {tracing_provider}") # check if trace config already exists - current_trace_config = db.session.query(TraceAppConfig).filter( - TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider - ).first() + current_trace_config = ( + db.session.query(TraceAppConfig) + .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .first() + ) if not current_trace_config: return None @@ -117,9 +137,11 @@ class OpsService: :param tracing_provider: tracing provider :return: """ - trace_config = db.session.query(TraceAppConfig).filter( - TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider - ).first() + trace_config = ( + db.session.query(TraceAppConfig) + .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .first() + ) if not trace_config: return None diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py index 1c1c5be17c..10abf0a764 100644 --- a/api/services/recommended_app_service.py +++ b/api/services/recommended_app_service.py @@ -16,7 +16,6 @@ logger = logging.getLogger(__name__) class RecommendedAppService: - builtin_data: Optional[dict] = None @classmethod @@ -27,21 +26,21 @@ class RecommendedAppService: :return: """ mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE - if mode == 'remote': + if mode == "remote": try: result = cls._fetch_recommended_apps_from_dify_official(language) except Exception as e: - logger.warning(f'fetch recommended apps from dify official failed: {e}, switch to built-in.') + logger.warning(f"fetch recommended apps from dify official failed: {e}, switch to built-in.") result = cls._fetch_recommended_apps_from_builtin(language) - elif mode == 'db': + elif mode == "db": result = cls._fetch_recommended_apps_from_db(language) - elif mode == 'builtin': + elif mode == "builtin": result = cls._fetch_recommended_apps_from_builtin(language) else: - raise ValueError(f'invalid fetch recommended apps mode: {mode}') + raise ValueError(f"invalid fetch recommended apps mode: {mode}") - if not result.get('recommended_apps') and language != 'en-US': - result = cls._fetch_recommended_apps_from_builtin('en-US') + if not result.get("recommended_apps") and language != "en-US": + result = cls._fetch_recommended_apps_from_builtin("en-US") return result @@ -52,16 +51,18 @@ class RecommendedAppService: :param language: language :return: """ - recommended_apps = db.session.query(RecommendedApp).filter( - RecommendedApp.is_listed == True, - RecommendedApp.language == language - ).all() + recommended_apps = ( + db.session.query(RecommendedApp) + .filter(RecommendedApp.is_listed == True, RecommendedApp.language == language) + .all() + ) if len(recommended_apps) == 0: - recommended_apps = db.session.query(RecommendedApp).filter( - RecommendedApp.is_listed == True, - RecommendedApp.language == languages[0] - ).all() + recommended_apps = ( + db.session.query(RecommendedApp) + .filter(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0]) + .all() + ) categories = set() recommended_apps_result = [] @@ -75,28 +76,28 @@ class RecommendedAppService: continue recommended_app_result = { - 'id': recommended_app.id, - 'app': { - 'id': app.id, - 'name': app.name, - 'mode': app.mode, - 'icon': app.icon, - 'icon_background': app.icon_background + "id": recommended_app.id, + "app": { + "id": app.id, + "name": app.name, + "mode": app.mode, + "icon": app.icon, + "icon_background": app.icon_background, }, - 'app_id': recommended_app.app_id, - 'description': site.description, - 'copyright': site.copyright, - 'privacy_policy': site.privacy_policy, - 'custom_disclaimer': site.custom_disclaimer, - 'category': recommended_app.category, - 'position': recommended_app.position, - 'is_listed': recommended_app.is_listed + "app_id": recommended_app.app_id, + "description": site.description, + "copyright": site.copyright, + "privacy_policy": site.privacy_policy, + "custom_disclaimer": site.custom_disclaimer, + "category": recommended_app.category, + "position": recommended_app.position, + "is_listed": recommended_app.is_listed, } recommended_apps_result.append(recommended_app_result) categories.add(recommended_app.category) # add category to categories - return {'recommended_apps': recommended_apps_result, 'categories': sorted(categories)} + return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)} @classmethod def _fetch_recommended_apps_from_dify_official(cls, language: str) -> dict: @@ -106,16 +107,16 @@ class RecommendedAppService: :return: """ domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN - url = f'{domain}/apps?language={language}' + url = f"{domain}/apps?language={language}" response = requests.get(url, timeout=(3, 10)) if response.status_code != 200: - raise ValueError(f'fetch recommended apps failed, status code: {response.status_code}') + raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}") result = response.json() if "categories" in result: result["categories"] = sorted(result["categories"]) - + return result @classmethod @@ -126,7 +127,7 @@ class RecommendedAppService: :return: """ builtin_data = cls._get_builtin_data() - return builtin_data.get('recommended_apps', {}).get(language) + return builtin_data.get("recommended_apps", {}).get(language) @classmethod def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]: @@ -136,18 +137,18 @@ class RecommendedAppService: :return: """ mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE - if mode == 'remote': + if mode == "remote": try: result = cls._fetch_recommended_app_detail_from_dify_official(app_id) except Exception as e: - logger.warning(f'fetch recommended app detail from dify official failed: {e}, switch to built-in.') + logger.warning(f"fetch recommended app detail from dify official failed: {e}, switch to built-in.") result = cls._fetch_recommended_app_detail_from_builtin(app_id) - elif mode == 'db': + elif mode == "db": result = cls._fetch_recommended_app_detail_from_db(app_id) - elif mode == 'builtin': + elif mode == "builtin": result = cls._fetch_recommended_app_detail_from_builtin(app_id) else: - raise ValueError(f'invalid fetch recommended app detail mode: {mode}') + raise ValueError(f"invalid fetch recommended app detail mode: {mode}") return result @@ -159,7 +160,7 @@ class RecommendedAppService: :return: """ domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN - url = f'{domain}/apps/{app_id}' + url = f"{domain}/apps/{app_id}" response = requests.get(url, timeout=(3, 10)) if response.status_code != 200: return None @@ -174,10 +175,11 @@ class RecommendedAppService: :return: """ # is in public recommended list - recommended_app = db.session.query(RecommendedApp).filter( - RecommendedApp.is_listed == True, - RecommendedApp.app_id == app_id - ).first() + recommended_app = ( + db.session.query(RecommendedApp) + .filter(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id) + .first() + ) if not recommended_app: return None @@ -188,12 +190,12 @@ class RecommendedAppService: return None return { - 'id': app_model.id, - 'name': app_model.name, - 'icon': app_model.icon, - 'icon_background': app_model.icon_background, - 'mode': app_model.mode, - 'export_data': AppDslService.export_dsl(app_model=app_model) + "id": app_model.id, + "name": app_model.name, + "icon": app_model.icon, + "icon_background": app_model.icon_background, + "mode": app_model.mode, + "export_data": AppDslService.export_dsl(app_model=app_model), } @classmethod @@ -204,7 +206,7 @@ class RecommendedAppService: :return: """ builtin_data = cls._get_builtin_data() - return builtin_data.get('app_details', {}).get(app_id) + return builtin_data.get("app_details", {}).get(app_id) @classmethod def _get_builtin_data(cls) -> dict: @@ -216,7 +218,7 @@ class RecommendedAppService: return cls.builtin_data root_path = current_app.root_path - with open(path.join(root_path, 'constants', 'recommended_apps.json'), encoding='utf-8') as f: + with open(path.join(root_path, "constants", "recommended_apps.json"), encoding="utf-8") as f: json_data = f.read() data = json.loads(json_data) cls.builtin_data = data @@ -229,27 +231,24 @@ class RecommendedAppService: Fetch all recommended apps and export datas :return: """ - templates = { - "recommended_apps": {}, - "app_details": {} - } + templates = {"recommended_apps": {}, "app_details": {}} for language in languages: try: result = cls._fetch_recommended_apps_from_dify_official(language) except Exception as e: - logger.warning(f'fetch recommended apps from dify official failed: {e}, skip.') + logger.warning(f"fetch recommended apps from dify official failed: {e}, skip.") continue - templates['recommended_apps'][language] = result + templates["recommended_apps"][language] = result - for recommended_app in result.get('recommended_apps'): - app_id = recommended_app.get('app_id') + for recommended_app in result.get("recommended_apps"): + app_id = recommended_app.get("app_id") # get app detail app_detail = cls._fetch_recommended_app_detail_from_dify_official(app_id) if not app_detail: continue - templates['app_details'][app_id] = app_detail + templates["app_details"][app_id] = app_detail return templates diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py index f1113c1505..9fe3cecce7 100644 --- a/api/services/saved_message_service.py +++ b/api/services/saved_message_service.py @@ -10,46 +10,48 @@ from services.message_service import MessageService class SavedMessageService: @classmethod - def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], - last_id: Optional[str], limit: int) -> InfiniteScrollPagination: - saved_messages = db.session.query(SavedMessage).filter( - SavedMessage.app_id == app_model.id, - SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - SavedMessage.created_by == user.id - ).order_by(SavedMessage.created_at.desc()).all() + def pagination_by_last_id( + cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int + ) -> InfiniteScrollPagination: + saved_messages = ( + db.session.query(SavedMessage) + .filter( + SavedMessage.app_id == app_model.id, + SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + SavedMessage.created_by == user.id, + ) + .order_by(SavedMessage.created_at.desc()) + .all() + ) message_ids = [sm.message_id for sm in saved_messages] return MessageService.pagination_by_last_id( - app_model=app_model, - user=user, - last_id=last_id, - limit=limit, - include_ids=message_ids + app_model=app_model, user=user, last_id=last_id, limit=limit, include_ids=message_ids ) @classmethod def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): - saved_message = db.session.query(SavedMessage).filter( - SavedMessage.app_id == app_model.id, - SavedMessage.message_id == message_id, - SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - SavedMessage.created_by == user.id - ).first() + saved_message = ( + db.session.query(SavedMessage) + .filter( + SavedMessage.app_id == app_model.id, + SavedMessage.message_id == message_id, + SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + SavedMessage.created_by == user.id, + ) + .first() + ) if saved_message: return - message = MessageService.get_message( - app_model=app_model, - user=user, - message_id=message_id - ) + message = MessageService.get_message(app_model=app_model, user=user, message_id=message_id) saved_message = SavedMessage( app_id=app_model.id, message_id=message.id, - created_by_role='account' if isinstance(user, Account) else 'end_user', - created_by=user.id + created_by_role="account" if isinstance(user, Account) else "end_user", + created_by=user.id, ) db.session.add(saved_message) @@ -57,12 +59,16 @@ class SavedMessageService: @classmethod def delete(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): - saved_message = db.session.query(SavedMessage).filter( - SavedMessage.app_id == app_model.id, - SavedMessage.message_id == message_id, - SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - SavedMessage.created_by == user.id - ).first() + saved_message = ( + db.session.query(SavedMessage) + .filter( + SavedMessage.app_id == app_model.id, + SavedMessage.message_id == message_id, + SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + SavedMessage.created_by == user.id, + ) + .first() + ) if not saved_message: return diff --git a/api/services/tag_service.py b/api/services/tag_service.py index d6eba38fbd..0c17485a9f 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -12,38 +12,32 @@ from models.model import App, Tag, TagBinding class TagService: @staticmethod def get_tags(tag_type: str, current_tenant_id: str, keyword: str = None) -> list: - query = db.session.query( - Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label('binding_count') - ).outerjoin( - TagBinding, Tag.id == TagBinding.tag_id - ).filter( - Tag.type == tag_type, - Tag.tenant_id == current_tenant_id + query = ( + db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count")) + .outerjoin(TagBinding, Tag.id == TagBinding.tag_id) + .filter(Tag.type == tag_type, Tag.tenant_id == current_tenant_id) ) if keyword: - query = query.filter(db.and_(Tag.name.ilike(f'%{keyword}%'))) - query = query.group_by( - Tag.id - ) + query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%"))) + query = query.group_by(Tag.id) results = query.order_by(Tag.created_at.desc()).all() return results @staticmethod def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list: - tags = db.session.query(Tag).filter( - Tag.id.in_(tag_ids), - Tag.tenant_id == current_tenant_id, - Tag.type == tag_type - ).all() + tags = ( + db.session.query(Tag) + .filter(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type) + .all() + ) if not tags: return [] tag_ids = [tag.id for tag in tags] - tag_bindings = db.session.query( - TagBinding.target_id - ).filter( - TagBinding.tag_id.in_(tag_ids), - TagBinding.tenant_id == current_tenant_id - ).all() + tag_bindings = ( + db.session.query(TagBinding.target_id) + .filter(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id) + .all() + ) if not tag_bindings: return [] results = [tag_binding.target_id for tag_binding in tag_bindings] @@ -51,27 +45,28 @@ class TagService: @staticmethod def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list: - tags = db.session.query(Tag).join( - TagBinding, - Tag.id == TagBinding.tag_id - ).filter( - TagBinding.target_id == target_id, - TagBinding.tenant_id == current_tenant_id, - Tag.tenant_id == current_tenant_id, - Tag.type == tag_type - ).all() + tags = ( + db.session.query(Tag) + .join(TagBinding, Tag.id == TagBinding.tag_id) + .filter( + TagBinding.target_id == target_id, + TagBinding.tenant_id == current_tenant_id, + Tag.tenant_id == current_tenant_id, + Tag.type == tag_type, + ) + .all() + ) return tags if tags else [] - @staticmethod def save_tags(args: dict) -> Tag: tag = Tag( id=str(uuid.uuid4()), - name=args['name'], - type=args['type'], + name=args["name"], + type=args["type"], created_by=current_user.id, - tenant_id=current_user.current_tenant_id + tenant_id=current_user.current_tenant_id, ) db.session.add(tag) db.session.commit() @@ -82,7 +77,7 @@ class TagService: tag = db.session.query(Tag).filter(Tag.id == tag_id).first() if not tag: raise NotFound("Tag not found") - tag.name = args['name'] + tag.name = args["name"] db.session.commit() return tag @@ -107,20 +102,21 @@ class TagService: @staticmethod def save_tag_binding(args): # check if target exists - TagService.check_target_exists(args['type'], args['target_id']) + TagService.check_target_exists(args["type"], args["target_id"]) # save tag binding - for tag_id in args['tag_ids']: - tag_binding = db.session.query(TagBinding).filter( - TagBinding.tag_id == tag_id, - TagBinding.target_id == args['target_id'] - ).first() + for tag_id in args["tag_ids"]: + tag_binding = ( + db.session.query(TagBinding) + .filter(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"]) + .first() + ) if tag_binding: continue new_tag_binding = TagBinding( tag_id=tag_id, - target_id=args['target_id'], + target_id=args["target_id"], tenant_id=current_user.current_tenant_id, - created_by=current_user.id + created_by=current_user.id, ) db.session.add(new_tag_binding) db.session.commit() @@ -128,34 +124,34 @@ class TagService: @staticmethod def delete_tag_binding(args): # check if target exists - TagService.check_target_exists(args['type'], args['target_id']) + TagService.check_target_exists(args["type"], args["target_id"]) # delete tag binding - tag_bindings = db.session.query(TagBinding).filter( - TagBinding.target_id == args['target_id'], - TagBinding.tag_id == (args['tag_id']) - ).first() + tag_bindings = ( + db.session.query(TagBinding) + .filter(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"])) + .first() + ) if tag_bindings: db.session.delete(tag_bindings) db.session.commit() - - @staticmethod def check_target_exists(type: str, target_id: str): - if type == 'knowledge': - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == current_user.current_tenant_id, - Dataset.id == target_id - ).first() + if type == "knowledge": + dataset = ( + db.session.query(Dataset) + .filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id) + .first() + ) if not dataset: raise NotFound("Dataset not found") - elif type == 'app': - app = db.session.query(App).filter( - App.tenant_id == current_user.current_tenant_id, - App.id == target_id - ).first() + elif type == "app": + app = ( + db.session.query(App) + .filter(App.tenant_id == current_user.current_tenant_id, App.id == target_id) + .first() + ) if not app: raise NotFound("App not found") else: raise NotFound("Invalid binding type") - diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index ecc065d521..3ded9c0989 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -29,111 +29,107 @@ class ApiToolManageService: @staticmethod def parser_api_schema(schema: str) -> list[ApiToolBundle]: """ - parse api schema to tool bundle + parse api schema to tool bundle """ try: warnings = {} try: tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings) except Exception as e: - raise ValueError(f'invalid schema: {str(e)}') - + raise ValueError(f"invalid schema: {str(e)}") + credentials_schema = [ ToolProviderCredentials( - name='auth_type', + name="auth_type", type=ToolProviderCredentials.CredentialsType.SELECT, required=True, - default='none', + default="none", options=[ - ToolCredentialsOption(value='none', label=I18nObject( - en_US='None', - zh_Hans='无' - )), - ToolCredentialsOption(value='api_key', label=I18nObject( - en_US='Api Key', - zh_Hans='Api Key' - )), + ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="无")), + ToolCredentialsOption(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")), ], - placeholder=I18nObject( - en_US='Select auth type', - zh_Hans='选择认证方式' - ) + placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"), ), ToolProviderCredentials( - name='api_key_header', + name="api_key_header", type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, required=False, - placeholder=I18nObject( - en_US='Enter api key header', - zh_Hans='输入 api key header,如:X-API-KEY' - ), - default='api_key', - help=I18nObject( - en_US='HTTP header name for api key', - zh_Hans='HTTP 头部字段名,用于传递 api key' - ) + placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key header,如:X-API-KEY"), + default="api_key", + help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"), ), ToolProviderCredentials( - name='api_key_value', + name="api_key_value", type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, required=False, - placeholder=I18nObject( - en_US='Enter api key', - zh_Hans='输入 api key' - ), - default='' + placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"), + default="", ), ] - return jsonable_encoder({ - 'schema_type': schema_type, - 'parameters_schema': tool_bundles, - 'credentials_schema': credentials_schema, - 'warning': warnings - }) + return jsonable_encoder( + { + "schema_type": schema_type, + "parameters_schema": tool_bundles, + "credentials_schema": credentials_schema, + "warning": warnings, + } + ) except Exception as e: - raise ValueError(f'invalid schema: {str(e)}') + raise ValueError(f"invalid schema: {str(e)}") @staticmethod def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiToolBundle]: """ - convert schema to tool bundles + convert schema to tool bundles - :return: the list of tool bundles, description + :return: the list of tool bundles, description """ try: tool_bundles = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info) return tool_bundles except Exception as e: - raise ValueError(f'invalid schema: {str(e)}') + raise ValueError(f"invalid schema: {str(e)}") @staticmethod def create_api_tool_provider( - user_id: str, tenant_id: str, provider_name: str, icon: dict, credentials: dict, - schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str] + user_id: str, + tenant_id: str, + provider_name: str, + icon: dict, + credentials: dict, + schema_type: str, + schema: str, + privacy_policy: str, + custom_disclaimer: str, + labels: list[str], ): """ - create api tool provider + create api tool provider """ if schema_type not in [member.value for member in ApiProviderSchemaType]: - raise ValueError(f'invalid schema type {schema}') - + raise ValueError(f"invalid schema type {schema}") + # check if the provider exists - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .first() + ) if provider is not None: - raise ValueError(f'provider {provider_name} already exists') + raise ValueError(f"provider {provider_name} already exists") # parse openapi to tool bundle extra_info = {} # extra info like description will be set here tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) - + if len(tool_bundles) > 100: - raise ValueError('the number of apis should be less than 100') + raise ValueError("the number of apis should be less than 100") # create db provider db_provider = ApiToolProvider( @@ -142,19 +138,19 @@ class ApiToolManageService: name=provider_name, icon=json.dumps(icon), schema=schema, - description=extra_info.get('description', ''), + description=extra_info.get("description", ""), schema_type_str=schema_type, tools_str=json.dumps(jsonable_encoder(tool_bundles)), credentials_str={}, privacy_policy=privacy_policy, - custom_disclaimer=custom_disclaimer + custom_disclaimer=custom_disclaimer, ) - if 'auth_type' not in credentials: - raise ValueError('auth_type is required') + if "auth_type" not in credentials: + raise ValueError("auth_type is required") # get auth type, none or api key - auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) # create provider entity provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) @@ -172,14 +168,12 @@ class ApiToolManageService: # update labels ToolLabelManager.update_tool_labels(provider_controller, labels) - return { 'result': 'success' } - + return {"result": "success"} + @staticmethod - def get_api_tool_provider_remote_schema( - user_id: str, tenant_id: str, url: str - ): + def get_api_tool_provider_remote_schema(user_id: str, tenant_id: str, url: str): """ - get api tool provider remote schema + get api tool provider remote schema """ headers = { "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0", @@ -189,84 +183,98 @@ class ApiToolManageService: try: response = get(url, headers=headers, timeout=10) if response.status_code != 200: - raise ValueError(f'Got status code {response.status_code}') + raise ValueError(f"Got status code {response.status_code}") schema = response.text # try to parse schema, avoid SSRF attack ApiToolManageService.parser_api_schema(schema) except Exception as e: logger.error(f"parse api schema error: {str(e)}") - raise ValueError('invalid schema, please check the url you provided') - - return { - 'schema': schema - } + raise ValueError("invalid schema, please check the url you provided") + + return {"schema": schema} @staticmethod - def list_api_tool_provider_tools( - user_id: str, tenant_id: str, provider: str - ) -> list[UserTool]: + def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]: """ - list api tool provider tools + list api tool provider tools """ - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider, - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider, + ) + .first() + ) if provider is None: - raise ValueError(f'you have not added provider {provider}') - + raise ValueError(f"you have not added provider {provider}") + controller = ToolTransformService.api_provider_to_controller(db_provider=provider) labels = ToolLabelManager.get_tool_labels(controller) - + return [ ToolTransformService.tool_to_user_tool( tool_bundle, labels=labels, - ) for tool_bundle in provider.tools + ) + for tool_bundle in provider.tools ] @staticmethod def update_api_tool_provider( - user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: dict, - schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str] + user_id: str, + tenant_id: str, + provider_name: str, + original_provider: str, + icon: dict, + credentials: dict, + schema_type: str, + schema: str, + privacy_policy: str, + custom_disclaimer: str, + labels: list[str], ): """ - update api tool provider + update api tool provider """ if schema_type not in [member.value for member in ApiProviderSchemaType]: - raise ValueError(f'invalid schema type {schema}') - + raise ValueError(f"invalid schema type {schema}") + # check if the provider exists - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == original_provider, - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == original_provider, + ) + .first() + ) if provider is None: - raise ValueError(f'api provider {provider_name} does not exists') + raise ValueError(f"api provider {provider_name} does not exists") # parse openapi to tool bundle extra_info = {} # extra info like description will be set here tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) - + # update db provider provider.name = provider_name provider.icon = json.dumps(icon) provider.schema = schema - provider.description = extra_info.get('description', '') + provider.description = extra_info.get("description", "") provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) provider.privacy_policy = privacy_policy provider.custom_disclaimer = custom_disclaimer - if 'auth_type' not in credentials: - raise ValueError('auth_type is required') + if "auth_type" not in credentials: + raise ValueError("auth_type is required") # get auth type, none or api key - auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) # create provider entity provider_controller = ApiToolProviderController.from_db(provider, auth_type) @@ -295,84 +303,91 @@ class ApiToolManageService: # update labels ToolLabelManager.update_tool_labels(provider_controller, labels) - return { 'result': 'success' } - + return {"result": "success"} + @staticmethod - def delete_api_tool_provider( - user_id: str, tenant_id: str, provider_name: str - ): + def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str): """ - delete tool provider + delete tool provider """ - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .first() + ) if provider is None: - raise ValueError(f'you have not added provider {provider_name}') - + raise ValueError(f"you have not added provider {provider_name}") + db.session.delete(provider) db.session.commit() - return { 'result': 'success' } - + return {"result": "success"} + @staticmethod - def get_api_tool_provider( - user_id: str, tenant_id: str, provider: str - ): + def get_api_tool_provider(user_id: str, tenant_id: str, provider: str): """ - get api tool provider + get api tool provider """ return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id) - + @staticmethod def test_api_tool_preview( - tenant_id: str, + tenant_id: str, provider_name: str, - tool_name: str, - credentials: dict, - parameters: dict, - schema_type: str, - schema: str + tool_name: str, + credentials: dict, + parameters: dict, + schema_type: str, + schema: str, ): """ - test api tool before adding api tool provider + test api tool before adding api tool provider """ if schema_type not in [member.value for member in ApiProviderSchemaType]: - raise ValueError(f'invalid schema type {schema_type}') - + raise ValueError(f"invalid schema type {schema_type}") + try: tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema) except Exception as e: - raise ValueError('invalid schema') - + raise ValueError("invalid schema") + # get tool bundle tool_bundle = next(filter(lambda tb: tb.operation_id == tool_name, tool_bundles), None) if tool_bundle is None: - raise ValueError(f'invalid tool name {tool_name}') - - db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, - ).first() + raise ValueError(f"invalid tool name {tool_name}") + + db_provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .first() + ) if not db_provider: # create a fake db provider db_provider = ApiToolProvider( - tenant_id='', user_id='', name='', icon='', + tenant_id="", + user_id="", + name="", + icon="", schema=schema, - description='', + description="", schema_type_str=ApiProviderSchemaType.OPENAPI.value, tools_str=json.dumps(jsonable_encoder(tool_bundles)), credentials_str=json.dumps(credentials), ) - if 'auth_type' not in credentials: - raise ValueError('auth_type is required') + if "auth_type" not in credentials: + raise ValueError("auth_type is required") # get auth type, none or api key - auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) # create provider entity provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) @@ -381,10 +396,7 @@ class ApiToolManageService: # decrypt credentials if db_provider.id: - tool_configuration = ToolConfigurationManager( - tenant_id=tenant_id, - provider_controller=provider_controller - ) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) # check if the credential has changed, save the original credential masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) @@ -396,27 +408,27 @@ class ApiToolManageService: provider_controller.validate_credentials_format(credentials) # get tool tool = provider_controller.get_tool(tool_name) - tool = tool.fork_tool_runtime(runtime={ - 'credentials': credentials, - 'tenant_id': tenant_id, - }) + tool = tool.fork_tool_runtime( + runtime={ + "credentials": credentials, + "tenant_id": tenant_id, + } + ) result = tool.validate_credentials(credentials, parameters) except Exception as e: - return { 'error': str(e) } - - return { 'result': result or 'empty response' } - + return {"error": str(e)} + + return {"result": result or "empty response"} + @staticmethod - def list_api_tools( - user_id: str, tenant_id: str - ) -> list[UserToolProvider]: + def list_api_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: """ - list api tools + list api tools """ # get all api providers - db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id - ).all() or [] + db_providers: list[ApiToolProvider] = ( + db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() or [] + ) result: list[UserToolProvider] = [] @@ -425,26 +437,21 @@ class ApiToolManageService: provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider) labels = ToolLabelManager.get_tool_labels(provider_controller) user_provider = ToolTransformService.api_provider_to_user_provider( - provider_controller, - db_provider=provider, - decrypt_credentials=True + provider_controller, db_provider=provider, decrypt_credentials=True ) user_provider.labels = labels # add icon ToolTransformService.repack_provider(user_provider) - tools = provider_controller.get_tools( - user_id=user_id, tenant_id=tenant_id - ) + tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id) for tool in tools: - user_provider.tools.append(ToolTransformService.tool_to_user_tool( - tenant_id=tenant_id, - tool=tool, - credentials=user_provider.original_credentials, - labels=labels - )) + user_provider.tools.append( + ToolTransformService.tool_to_user_tool( + tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels + ) + ) result.append(user_provider) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index ea6ecf0c69..dc8cebb587 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -1,6 +1,8 @@ import json import logging +from configs import dify_config +from core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.entities.api_entities import UserTool, UserToolProvider from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError @@ -18,21 +20,25 @@ logger = logging.getLogger(__name__) class BuiltinToolManageService: @staticmethod - def list_builtin_tool_provider_tools( - user_id: str, tenant_id: str, provider: str - ) -> list[UserTool]: + def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]: """ - list builtin tool provider tools + list builtin tool provider tools """ provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider) tools = provider_controller.get_tools() - tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) + tool_provider_configurations = ToolConfigurationManager( + tenant_id=tenant_id, provider_controller=provider_controller + ) # check if user has added the provider - builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider, - ).first() + builtin_provider: BuiltinToolProvider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + ) + .first() + ) credentials = {} if builtin_provider is not None: @@ -42,47 +48,47 @@ class BuiltinToolManageService: result = [] for tool in tools: - result.append(ToolTransformService.tool_to_user_tool( - tool=tool, - credentials=credentials, - tenant_id=tenant_id, - labels=ToolLabelManager.get_tool_labels(provider_controller) - )) + result.append( + ToolTransformService.tool_to_user_tool( + tool=tool, + credentials=credentials, + tenant_id=tenant_id, + labels=ToolLabelManager.get_tool_labels(provider_controller), + ) + ) return result - - @staticmethod - def list_builtin_provider_credentials_schema( - provider_name - ): - """ - list builtin provider credentials schema - :return: the list of tool providers + @staticmethod + def list_builtin_provider_credentials_schema(provider_name): + """ + list builtin provider credentials schema + + :return: the list of tool providers """ provider = ToolManager.get_builtin_provider(provider_name) - return jsonable_encoder([ - v for _, v in (provider.credentials_schema or {}).items() - ]) + return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()]) @staticmethod - def update_builtin_tool_provider( - user_id: str, tenant_id: str, provider_name: str, credentials: dict - ): + def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict): """ - update builtin tool provider + update builtin tool provider """ # get if the provider exists - provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider_name, - ).first() + provider: BuiltinToolProvider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider_name, + ) + .first() + ) - try: + try: # get provider provider_controller = ToolManager.get_builtin_provider(provider_name) if not provider_controller.need_credentials: - raise ValueError(f'provider {provider_name} does not need credentials') + raise ValueError(f"provider {provider_name} does not need credentials") tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) # get original credentials if exists if provider is not None: @@ -119,23 +125,25 @@ class BuiltinToolManageService: # delete cache tool_configuration.delete_tool_credentials_cache() - return { 'result': 'success' } - + return {"result": "success"} + @staticmethod - def get_builtin_tool_provider_credentials( - user_id: str, tenant_id: str, provider: str - ): + def get_builtin_tool_provider_credentials(user_id: str, tenant_id: str, provider: str): """ - get builtin tool provider credentials + get builtin tool provider credentials """ - provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider, - ).first() + provider: BuiltinToolProvider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + ) + .first() + ) if provider is None: return {} - + provider_controller = ToolManager.get_builtin_provider(provider.provider) tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) @@ -143,20 +151,22 @@ class BuiltinToolManageService: return credentials @staticmethod - def delete_builtin_tool_provider( - user_id: str, tenant_id: str, provider_name: str - ): + def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str): """ - delete tool provider + delete tool provider """ - provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider_name, - ).first() + provider: BuiltinToolProvider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider_name, + ) + .first() + ) if provider is None: - raise ValueError(f'you have not added provider {provider_name}') - + raise ValueError(f"you have not added provider {provider_name}") + db.session.delete(provider) db.session.commit() @@ -165,48 +175,55 @@ class BuiltinToolManageService: tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) tool_configuration.delete_tool_credentials_cache() - return { 'result': 'success' } - + return {"result": "success"} + @staticmethod - def get_builtin_tool_provider_icon( - provider: str - ): + def get_builtin_tool_provider_icon(provider: str): """ - get tool provider icon and it's mimetype + get tool provider icon and it's mimetype """ icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider) - with open(icon_path, 'rb') as f: + with open(icon_path, "rb") as f: icon_bytes = f.read() return icon_bytes, mime_type - + @staticmethod - def list_builtin_tools( - user_id: str, tenant_id: str - ) -> list[UserToolProvider]: + def list_builtin_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: """ - list builtin tools + list builtin tools """ # get all builtin providers provider_controllers = ToolManager.list_builtin_providers() # get all user added providers - db_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id - ).all() or [] + db_providers: list[BuiltinToolProvider] = ( + db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or [] + ) # find provider - find_provider = lambda provider: next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None) + find_provider = lambda provider: next( + filter(lambda db_provider: db_provider.provider == provider, db_providers), None + ) result: list[UserToolProvider] = [] for provider_controller in provider_controllers: try: + # handle include, exclude + if is_filtered( + include_set=dify_config.POSITION_TOOL_INCLUDES_SET, + exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, + data=provider_controller, + name_func=lambda x: x.identity.name, + ): + continue + # convert provider controller to user provider user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider_controller, db_provider=find_provider(provider_controller.identity.name), - decrypt_credentials=True + decrypt_credentials=True, ) # add icon @@ -214,16 +231,17 @@ class BuiltinToolManageService: tools = provider_controller.get_tools() for tool in tools: - user_builtin_provider.tools.append(ToolTransformService.tool_to_user_tool( - tenant_id=tenant_id, - tool=tool, - credentials=user_builtin_provider.original_credentials, - labels=ToolLabelManager.get_tool_labels(provider_controller) - )) + user_builtin_provider.tools.append( + ToolTransformService.tool_to_user_tool( + tenant_id=tenant_id, + tool=tool, + credentials=user_builtin_provider.original_credentials, + labels=ToolLabelManager.get_tool_labels(provider_controller), + ) + ) result.append(user_builtin_provider) except Exception as e: raise e return BuiltinToolProviderSort.sort(result) - \ No newline at end of file diff --git a/api/services/tools/tool_labels_service.py b/api/services/tools/tool_labels_service.py index 8a6aa025f2..35e58b5ade 100644 --- a/api/services/tools/tool_labels_service.py +++ b/api/services/tools/tool_labels_service.py @@ -5,4 +5,4 @@ from core.tools.entities.values import default_tool_labels class ToolLabelsService: @classmethod def list_tool_labels(cls) -> list[ToolLabel]: - return default_tool_labels \ No newline at end of file + return default_tool_labels diff --git a/api/services/tools/tools_manage_service.py b/api/services/tools/tools_manage_service.py index 76d2f53ae8..1c67f7648c 100644 --- a/api/services/tools/tools_manage_service.py +++ b/api/services/tools/tools_manage_service.py @@ -11,13 +11,11 @@ class ToolCommonService: @staticmethod def list_tool_providers(user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral = None): """ - list tool providers + list tool providers - :return: the list of tool providers + :return: the list of tool providers """ - providers = ToolManager.user_list_providers( - user_id, tenant_id, typ - ) + providers = ToolManager.user_list_providers(user_id, tenant_id, typ) # add icon for provider in providers: @@ -26,4 +24,3 @@ class ToolCommonService: result = [provider.to_dict() for provider in providers] return result - \ No newline at end of file diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index cfce3fbd01..6fb0f2f517 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -22,46 +22,39 @@ from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvi logger = logging.getLogger(__name__) + class ToolTransformService: @staticmethod def get_tool_provider_icon_url(provider_type: str, provider_name: str, icon: str) -> Union[str, dict]: """ - get tool provider icon url + get tool provider icon url """ - url_prefix = (dify_config.CONSOLE_API_URL - + "/console/api/workspaces/current/tool-provider/") - + url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/" + if provider_type == ToolProviderType.BUILT_IN.value: - return url_prefix + 'builtin/' + provider_name + '/icon' + return url_prefix + "builtin/" + provider_name + "/icon" elif provider_type in [ToolProviderType.API.value, ToolProviderType.WORKFLOW.value]: try: return json.loads(icon) except: - return { - "background": "#252525", - "content": "\ud83d\ude01" - } - - return '' - + return {"background": "#252525", "content": "\ud83d\ude01"} + + return "" + @staticmethod def repack_provider(provider: Union[dict, UserToolProvider]): """ - repack provider + repack provider - :param provider: the provider dict + :param provider: the provider dict """ - if isinstance(provider, dict) and 'icon' in provider: - provider['icon'] = ToolTransformService.get_tool_provider_icon_url( - provider_type=provider['type'], - provider_name=provider['name'], - icon=provider['icon'] + if isinstance(provider, dict) and "icon" in provider: + provider["icon"] = ToolTransformService.get_tool_provider_icon_url( + provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"] ) elif isinstance(provider, UserToolProvider): provider.icon = ToolTransformService.get_tool_provider_icon_url( - provider_type=provider.type.value, - provider_name=provider.name, - icon=provider.icon + provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon ) @staticmethod @@ -92,14 +85,13 @@ class ToolTransformService: masked_credentials={}, is_team_authorization=False, tools=[], - labels=provider_controller.tool_labels + labels=provider_controller.tool_labels, ) # get credentials schema schema = provider_controller.get_credentials_schema() for name, value in schema.items(): - result.masked_credentials[name] = \ - ToolProviderCredentials.CredentialsType.default(value.type) + result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(value.type) # check if the provider need credentials if not provider_controller.need_credentials: @@ -113,8 +105,7 @@ class ToolTransformService: # init tool configuration tool_configuration = ToolConfigurationManager( - tenant_id=db_provider.tenant_id, - provider_controller=provider_controller + tenant_id=db_provider.tenant_id, provider_controller=provider_controller ) # decrypt the credentials and mask the credentials decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials) @@ -124,7 +115,7 @@ class ToolTransformService: result.original_credentials = decrypted_credentials return result - + @staticmethod def api_provider_to_controller( db_provider: ApiToolProvider, @@ -135,25 +126,23 @@ class ToolTransformService: # package tool provider controller controller = ApiToolProviderController.from_db( db_provider=db_provider, - auth_type=ApiProviderAuthType.API_KEY if db_provider.credentials['auth_type'] == 'api_key' else - ApiProviderAuthType.NONE + auth_type=ApiProviderAuthType.API_KEY + if db_provider.credentials["auth_type"] == "api_key" + else ApiProviderAuthType.NONE, ) return controller - + @staticmethod - def workflow_provider_to_controller( - db_provider: WorkflowToolProvider - ) -> WorkflowToolProviderController: + def workflow_provider_to_controller(db_provider: WorkflowToolProvider) -> WorkflowToolProviderController: """ convert provider controller to provider """ return WorkflowToolProviderController.from_db(db_provider) - + @staticmethod def workflow_provider_to_user_provider( - provider_controller: WorkflowToolProviderController, - labels: list[str] = None + provider_controller: WorkflowToolProviderController, labels: list[str] = None ): """ convert provider controller to user provider @@ -175,7 +164,7 @@ class ToolTransformService: masked_credentials={}, is_team_authorization=True, tools=[], - labels=labels or [] + labels=labels or [], ) @staticmethod @@ -183,16 +172,16 @@ class ToolTransformService: provider_controller: ApiToolProviderController, db_provider: ApiToolProvider, decrypt_credentials: bool = True, - labels: list[str] = None + labels: list[str] = None, ) -> UserToolProvider: """ convert provider controller to user provider """ - username = 'Anonymous' + username = "Anonymous" try: username = db_provider.user.name except Exception as e: - logger.error(f'failed to get user name for api provider {db_provider.id}: {str(e)}') + logger.error(f"failed to get user name for api provider {db_provider.id}: {str(e)}") # add provider into providers credentials = db_provider.credentials result = UserToolProvider( @@ -212,14 +201,13 @@ class ToolTransformService: masked_credentials={}, is_team_authorization=True, tools=[], - labels=labels or [] + labels=labels or [], ) if decrypt_credentials: # init tool configuration tool_configuration = ToolConfigurationManager( - tenant_id=db_provider.tenant_id, - provider_controller=provider_controller + tenant_id=db_provider.tenant_id, provider_controller=provider_controller ) # decrypt the credentials and mask the credentials @@ -229,23 +217,25 @@ class ToolTransformService: result.masked_credentials = masked_credentials return result - + @staticmethod def tool_to_user_tool( - tool: Union[ApiToolBundle, WorkflowTool, Tool], - credentials: dict = None, + tool: Union[ApiToolBundle, WorkflowTool, Tool], + credentials: dict = None, tenant_id: str = None, - labels: list[str] = None + labels: list[str] = None, ) -> UserTool: """ convert tool to user tool """ if isinstance(tool, Tool): # fork tool runtime - tool = tool.fork_tool_runtime(runtime={ - 'credentials': credentials, - 'tenant_id': tenant_id, - }) + tool = tool.fork_tool_runtime( + runtime={ + "credentials": credentials, + "tenant_id": tenant_id, + } + ) # get tool parameters parameters = tool.parameters or [] @@ -270,20 +260,14 @@ class ToolTransformService: label=tool.identity.label, description=tool.description.human, parameters=current_parameters, - labels=labels + labels=labels, ) if isinstance(tool, ApiToolBundle): return UserTool( author=tool.author, name=tool.operation_id, - label=I18nObject( - en_US=tool.operation_id, - zh_Hans=tool.operation_id - ), - description=I18nObject( - en_US=tool.summary or '', - zh_Hans=tool.summary or '' - ), + label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id), + description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""), parameters=tool.parameters, - labels=labels - ) \ No newline at end of file + labels=labels, + ) diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 185483a71c..3830e75339 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -19,10 +19,21 @@ class WorkflowToolManageService: """ Service class for managing workflow tools. """ + @classmethod - def create_workflow_tool(cls, user_id: str, tenant_id: str, workflow_app_id: str, name: str, - label: str, icon: dict, description: str, - parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict: + def create_workflow_tool( + cls, + user_id: str, + tenant_id: str, + workflow_app_id: str, + name: str, + label: str, + icon: dict, + description: str, + parameters: list[dict], + privacy_policy: str = "", + labels: list[str] = None, + ) -> dict: """ Create a workflow tool. :param user_id: the user id @@ -38,27 +49,28 @@ class WorkflowToolManageService: WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) # check if the name is unique - existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - # name or app_id - or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id) - ).first() + existing_workflow_tool_provider = ( + db.session.query(WorkflowToolProvider) + .filter( + WorkflowToolProvider.tenant_id == tenant_id, + # name or app_id + or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id), + ) + .first() + ) if existing_workflow_tool_provider is not None: - raise ValueError(f'Tool with name {name} or app_id {workflow_app_id} already exists') - - app: App = db.session.query(App).filter( - App.id == workflow_app_id, - App.tenant_id == tenant_id - ).first() + raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists") + + app: App = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first() if app is None: - raise ValueError(f'App {workflow_app_id} not found') - + raise ValueError(f"App {workflow_app_id} not found") + workflow: Workflow = app.workflow if workflow is None: - raise ValueError(f'Workflow not found for app {workflow_app_id}') - + raise ValueError(f"Workflow not found for app {workflow_app_id}") + workflow_tool_provider = WorkflowToolProvider( tenant_id=tenant_id, user_id=user_id, @@ -76,19 +88,26 @@ class WorkflowToolManageService: WorkflowToolProviderController.from_db(workflow_tool_provider) except Exception as e: raise ValueError(str(e)) - + db.session.add(workflow_tool_provider) db.session.commit() - return { - 'result': 'success' - } - + return {"result": "success"} @classmethod - def update_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str, - name: str, label: str, icon: dict, description: str, - parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict: + def update_workflow_tool( + cls, + user_id: str, + tenant_id: str, + workflow_tool_id: str, + name: str, + label: str, + icon: dict, + description: str, + parameters: list[dict], + privacy_policy: str = "", + labels: list[str] = None, + ) -> dict: """ Update a workflow tool. :param user_id: the user id @@ -106,35 +125,39 @@ class WorkflowToolManageService: WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) # check if the name is unique - existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.name == name, - WorkflowToolProvider.id != workflow_tool_id - ).first() + existing_workflow_tool_provider = ( + db.session.query(WorkflowToolProvider) + .filter( + WorkflowToolProvider.tenant_id == tenant_id, + WorkflowToolProvider.name == name, + WorkflowToolProvider.id != workflow_tool_id, + ) + .first() + ) if existing_workflow_tool_provider is not None: - raise ValueError(f'Tool with name {name} already exists') - - workflow_tool_provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == workflow_tool_id - ).first() + raise ValueError(f"Tool with name {name} already exists") + + workflow_tool_provider: WorkflowToolProvider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .first() + ) if workflow_tool_provider is None: - raise ValueError(f'Tool {workflow_tool_id} not found') - - app: App = db.session.query(App).filter( - App.id == workflow_tool_provider.app_id, - App.tenant_id == tenant_id - ).first() + raise ValueError(f"Tool {workflow_tool_id} not found") + + app: App = ( + db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first() + ) if app is None: - raise ValueError(f'App {workflow_tool_provider.app_id} not found') - + raise ValueError(f"App {workflow_tool_provider.app_id} not found") + workflow: Workflow = app.workflow if workflow is None: - raise ValueError(f'Workflow not found for app {workflow_tool_provider.app_id}') - + raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}") + workflow_tool_provider.name = name workflow_tool_provider.label = label workflow_tool_provider.icon = json.dumps(icon) @@ -154,13 +177,10 @@ class WorkflowToolManageService: if labels is not None: ToolLabelManager.update_tool_labels( - ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), - labels + ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels ) - return { - 'result': 'success' - } + return {"result": "success"} @classmethod def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]: @@ -170,9 +190,7 @@ class WorkflowToolManageService: :param tenant_id: the tenant id :return: the list of tools """ - db_tools = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id - ).all() + db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() tools = [] for provider in db_tools: @@ -188,14 +206,12 @@ class WorkflowToolManageService: for tool in tools: user_tool_provider = ToolTransformService.workflow_provider_to_user_provider( - provider_controller=tool, - labels=labels.get(tool.provider_id, []) + provider_controller=tool, labels=labels.get(tool.provider_id, []) ) ToolTransformService.repack_provider(user_tool_provider) user_tool_provider.tools = [ ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], - labels=labels.get(tool.provider_id, []) + tool.get_tools(user_id, tenant_id)[0], labels=labels.get(tool.provider_id, []) ) ] result.append(user_tool_provider) @@ -211,15 +227,12 @@ class WorkflowToolManageService: :param workflow_app_id: the workflow app id """ db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == workflow_tool_id + WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id ).delete() db.session.commit() - return { - 'result': 'success' - } + return {"result": "success"} @classmethod def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict: @@ -230,40 +243,37 @@ class WorkflowToolManageService: :param workflow_app_id: the workflow app id :return: the tool """ - db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == workflow_tool_id - ).first() + db_tool: WorkflowToolProvider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .first() + ) if db_tool is None: - raise ValueError(f'Tool {workflow_tool_id} not found') - - workflow_app: App = db.session.query(App).filter( - App.id == db_tool.app_id, - App.tenant_id == tenant_id - ).first() + raise ValueError(f"Tool {workflow_tool_id} not found") + + workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() if workflow_app is None: - raise ValueError(f'App {db_tool.app_id} not found') + raise ValueError(f"App {db_tool.app_id} not found") tool = ToolTransformService.workflow_provider_to_controller(db_tool) return { - 'name': db_tool.name, - 'label': db_tool.label, - 'workflow_tool_id': db_tool.id, - 'workflow_app_id': db_tool.app_id, - 'icon': json.loads(db_tool.icon), - 'description': db_tool.description, - 'parameters': jsonable_encoder(db_tool.parameter_configurations), - 'tool': ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], - labels=ToolLabelManager.get_tool_labels(tool) + "name": db_tool.name, + "label": db_tool.label, + "workflow_tool_id": db_tool.id, + "workflow_app_id": db_tool.app_id, + "icon": json.loads(db_tool.icon), + "description": db_tool.description, + "parameters": jsonable_encoder(db_tool.parameter_configurations), + "tool": ToolTransformService.tool_to_user_tool( + tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) ), - 'synced': workflow_app.workflow.version == db_tool.version, - 'privacy_policy': db_tool.privacy_policy, + "synced": workflow_app.workflow.version == db_tool.version, + "privacy_policy": db_tool.privacy_policy, } - + @classmethod def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict: """ @@ -273,40 +283,37 @@ class WorkflowToolManageService: :param workflow_app_id: the workflow app id :return: the tool """ - db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.app_id == workflow_app_id - ).first() + db_tool: WorkflowToolProvider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id) + .first() + ) if db_tool is None: - raise ValueError(f'Tool {workflow_app_id} not found') - - workflow_app: App = db.session.query(App).filter( - App.id == db_tool.app_id, - App.tenant_id == tenant_id - ).first() + raise ValueError(f"Tool {workflow_app_id} not found") + + workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() if workflow_app is None: - raise ValueError(f'App {db_tool.app_id} not found') + raise ValueError(f"App {db_tool.app_id} not found") tool = ToolTransformService.workflow_provider_to_controller(db_tool) return { - 'name': db_tool.name, - 'label': db_tool.label, - 'workflow_tool_id': db_tool.id, - 'workflow_app_id': db_tool.app_id, - 'icon': json.loads(db_tool.icon), - 'description': db_tool.description, - 'parameters': jsonable_encoder(db_tool.parameter_configurations), - 'tool': ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], - labels=ToolLabelManager.get_tool_labels(tool) + "name": db_tool.name, + "label": db_tool.label, + "workflow_tool_id": db_tool.id, + "workflow_app_id": db_tool.app_id, + "icon": json.loads(db_tool.icon), + "description": db_tool.description, + "parameters": jsonable_encoder(db_tool.parameter_configurations), + "tool": ToolTransformService.tool_to_user_tool( + tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) ), - 'synced': workflow_app.workflow.version == db_tool.version, - 'privacy_policy': db_tool.privacy_policy + "synced": workflow_app.workflow.version == db_tool.version, + "privacy_policy": db_tool.privacy_policy, } - + @classmethod def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[dict]: """ @@ -316,19 +323,19 @@ class WorkflowToolManageService: :param workflow_app_id: the workflow app id :return: the list of tools """ - db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == workflow_tool_id - ).first() + db_tool: WorkflowToolProvider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .first() + ) if db_tool is None: - raise ValueError(f'Tool {workflow_tool_id} not found') + raise ValueError(f"Tool {workflow_tool_id} not found") tool = ToolTransformService.workflow_provider_to_controller(db_tool) return [ ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], - labels=ToolLabelManager.get_tool_labels(tool) + tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) ) - ] \ No newline at end of file + ] diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 232d294325..3c67351335 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -7,10 +7,10 @@ from models.dataset import Dataset, DocumentSegment class VectorService: - @classmethod - def create_segments_vector(cls, keywords_list: Optional[list[list[str]]], - segments: list[DocumentSegment], dataset: Dataset): + def create_segments_vector( + cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset + ): documents = [] for segment in segments: document = Document( @@ -20,14 +20,12 @@ class VectorService: "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) documents.append(document) - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": # save vector index - vector = Vector( - dataset=dataset - ) + vector = Vector(dataset=dataset) vector.add_texts(documents, duplicate_check=True) # save keyword index @@ -50,13 +48,11 @@ class VectorService: "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": # update vector index - vector = Vector( - dataset=dataset - ) + vector = Vector(dataset=dataset) vector.delete_by_ids([segment.index_node_id]) vector.add_texts([document], duplicate_check=True) diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index cba048ccdb..d7ccc964cb 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -11,17 +11,29 @@ from services.conversation_service import ConversationService class WebConversationService: @classmethod - def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], - last_id: Optional[str], limit: int, invoke_from: InvokeFrom, - pinned: Optional[bool] = None) -> InfiniteScrollPagination: + def pagination_by_last_id( + cls, + app_model: App, + user: Optional[Union[Account, EndUser]], + last_id: Optional[str], + limit: int, + invoke_from: InvokeFrom, + pinned: Optional[bool] = None, + sort_by="-updated_at", + ) -> InfiniteScrollPagination: include_ids = None exclude_ids = None if pinned is not None: - pinned_conversations = db.session.query(PinnedConversation).filter( - PinnedConversation.app_id == app_model.id, - PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - PinnedConversation.created_by == user.id - ).order_by(PinnedConversation.created_at.desc()).all() + pinned_conversations = ( + db.session.query(PinnedConversation) + .filter( + PinnedConversation.app_id == app_model.id, + PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + PinnedConversation.created_by == user.id, + ) + .order_by(PinnedConversation.created_at.desc()) + .all() + ) pinned_conversation_ids = [pc.conversation_id for pc in pinned_conversations] if pinned: include_ids = pinned_conversation_ids @@ -36,31 +48,34 @@ class WebConversationService: invoke_from=invoke_from, include_ids=include_ids, exclude_ids=exclude_ids, + sort_by=sort_by, ) @classmethod def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): - pinned_conversation = db.session.query(PinnedConversation).filter( - PinnedConversation.app_id == app_model.id, - PinnedConversation.conversation_id == conversation_id, - PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - PinnedConversation.created_by == user.id - ).first() + pinned_conversation = ( + db.session.query(PinnedConversation) + .filter( + PinnedConversation.app_id == app_model.id, + PinnedConversation.conversation_id == conversation_id, + PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + PinnedConversation.created_by == user.id, + ) + .first() + ) if pinned_conversation: return conversation = ConversationService.get_conversation( - app_model=app_model, - conversation_id=conversation_id, - user=user + app_model=app_model, conversation_id=conversation_id, user=user ) pinned_conversation = PinnedConversation( app_id=app_model.id, conversation_id=conversation.id, - created_by_role='account' if isinstance(user, Account) else 'end_user', - created_by=user.id + created_by_role="account" if isinstance(user, Account) else "end_user", + created_by=user.id, ) db.session.add(pinned_conversation) @@ -68,12 +83,16 @@ class WebConversationService: @classmethod def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): - pinned_conversation = db.session.query(PinnedConversation).filter( - PinnedConversation.app_id == app_model.id, - PinnedConversation.conversation_id == conversation_id, - PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - PinnedConversation.created_by == user.id - ).first() + pinned_conversation = ( + db.session.query(PinnedConversation) + .filter( + PinnedConversation.app_id == app_model.id, + PinnedConversation.conversation_id == conversation_id, + PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + PinnedConversation.created_by == user.id, + ) + .first() + ) if not pinned_conversation: return diff --git a/api/services/website_service.py b/api/services/website_service.py index c166b01237..6dff35d63f 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -11,161 +11,126 @@ from services.auth.api_key_auth_service import ApiKeyAuthService class WebsiteService: - @classmethod def document_create_args_validate(cls, args: dict): - if 'url' not in args or not args['url']: - raise ValueError('url is required') - if 'options' not in args or not args['options']: - raise ValueError('options is required') - if 'limit' not in args['options'] or not args['options']['limit']: - raise ValueError('limit is required') + if "url" not in args or not args["url"]: + raise ValueError("url is required") + if "options" not in args or not args["options"]: + raise ValueError("options is required") + if "limit" not in args["options"] or not args["options"]["limit"]: + raise ValueError("limit is required") @classmethod def crawl_url(cls, args: dict) -> dict: - provider = args.get('provider') - url = args.get('url') - options = args.get('options') - credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, - 'website', - provider) - if provider == 'firecrawl': + provider = args.get("provider") + url = args.get("url") + options = args.get("options") + credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider) + if provider == "firecrawl": # decrypt api_key api_key = encrypter.decrypt_token( - tenant_id=current_user.current_tenant_id, - token=credentials.get('config').get('api_key') + tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") ) - firecrawl_app = FirecrawlApp(api_key=api_key, - base_url=credentials.get('config').get('base_url', None)) - crawl_sub_pages = options.get('crawl_sub_pages', False) - only_main_content = options.get('only_main_content', False) + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) + crawl_sub_pages = options.get("crawl_sub_pages", False) + only_main_content = options.get("only_main_content", False) if not crawl_sub_pages: params = { - 'crawlerOptions': { + "crawlerOptions": { "includes": [], "excludes": [], "generateImgAltText": True, "limit": 1, - 'returnOnlyUrls': False, - 'pageOptions': { - 'onlyMainContent': only_main_content, - "includeHtml": False - } + "returnOnlyUrls": False, + "pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False}, } } else: - includes = options.get('includes').split(',') if options.get('includes') else [] - excludes = options.get('excludes').split(',') if options.get('excludes') else [] + includes = options.get("includes").split(",") if options.get("includes") else [] + excludes = options.get("excludes").split(",") if options.get("excludes") else [] params = { - 'crawlerOptions': { + "crawlerOptions": { "includes": includes if includes else [], "excludes": excludes if excludes else [], "generateImgAltText": True, - "limit": options.get('limit', 1), - 'returnOnlyUrls': False, - 'pageOptions': { - 'onlyMainContent': only_main_content, - "includeHtml": False - } + "limit": options.get("limit", 1), + "returnOnlyUrls": False, + "pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False}, } } - if options.get('max_depth'): - params['crawlerOptions']['maxDepth'] = options.get('max_depth') + if options.get("max_depth"): + params["crawlerOptions"]["maxDepth"] = options.get("max_depth") job_id = firecrawl_app.crawl_url(url, params) - website_crawl_time_cache_key = f'website_crawl_{job_id}' + website_crawl_time_cache_key = f"website_crawl_{job_id}" time = str(datetime.datetime.now().timestamp()) redis_client.setex(website_crawl_time_cache_key, 3600, time) - return { - 'status': 'active', - 'job_id': job_id - } + return {"status": "active", "job_id": job_id} else: - raise ValueError('Invalid provider') + raise ValueError("Invalid provider") @classmethod def get_crawl_status(cls, job_id: str, provider: str) -> dict: - credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, - 'website', - provider) - if provider == 'firecrawl': + credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider) + if provider == "firecrawl": # decrypt api_key api_key = encrypter.decrypt_token( - tenant_id=current_user.current_tenant_id, - token=credentials.get('config').get('api_key') + tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") ) - firecrawl_app = FirecrawlApp(api_key=api_key, - base_url=credentials.get('config').get('base_url', None)) + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) result = firecrawl_app.check_crawl_status(job_id) crawl_status_data = { - 'status': result.get('status', 'active'), - 'job_id': job_id, - 'total': result.get('total', 0), - 'current': result.get('current', 0), - 'data': result.get('data', []) + "status": result.get("status", "active"), + "job_id": job_id, + "total": result.get("total", 0), + "current": result.get("current", 0), + "data": result.get("data", []), } - if crawl_status_data['status'] == 'completed': - website_crawl_time_cache_key = f'website_crawl_{job_id}' + if crawl_status_data["status"] == "completed": + website_crawl_time_cache_key = f"website_crawl_{job_id}" start_time = redis_client.get(website_crawl_time_cache_key) if start_time: end_time = datetime.datetime.now().timestamp() time_consuming = abs(end_time - float(start_time)) - crawl_status_data['time_consuming'] = f"{time_consuming:.2f}" + crawl_status_data["time_consuming"] = f"{time_consuming:.2f}" redis_client.delete(website_crawl_time_cache_key) else: - raise ValueError('Invalid provider') + raise ValueError("Invalid provider") return crawl_status_data @classmethod def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None: - credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, - 'website', - provider) - if provider == 'firecrawl': - file_key = 'website_files/' + job_id + '.txt' + credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) + if provider == "firecrawl": + file_key = "website_files/" + job_id + ".txt" if storage.exists(file_key): data = storage.load_once(file_key) if data: - data = json.loads(data.decode('utf-8')) + data = json.loads(data.decode("utf-8")) else: # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=tenant_id, - token=credentials.get('config').get('api_key') - ) - firecrawl_app = FirecrawlApp(api_key=api_key, - base_url=credentials.get('config').get('base_url', None)) + api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) result = firecrawl_app.check_crawl_status(job_id) - if result.get('status') != 'completed': - raise ValueError('Crawl job is not completed') - data = result.get('data') + if result.get("status") != "completed": + raise ValueError("Crawl job is not completed") + data = result.get("data") if data: for item in data: - if item.get('source_url') == url: + if item.get("source_url") == url: return item return None else: - raise ValueError('Invalid provider') + raise ValueError("Invalid provider") @classmethod def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict | None: - credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, - 'website', - provider) - if provider == 'firecrawl': + credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) + if provider == "firecrawl": # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=tenant_id, - token=credentials.get('config').get('api_key') - ) - firecrawl_app = FirecrawlApp(api_key=api_key, - base_url=credentials.get('config').get('base_url', None)) - params = { - 'pageOptions': { - 'onlyMainContent': only_main_content, - "includeHtml": False - } - } + api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) + params = {"pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False}} result = firecrawl_app.scrape_url(url, params) return result else: - raise ValueError('Invalid provider') + raise ValueError("Invalid provider") diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 9578ffc49b..4b845be2f4 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -32,12 +32,9 @@ class WorkflowConverter: App Convert to Workflow Mode """ - def convert_to_workflow(self, app_model: App, - account: Account, - name: str, - icon_type: str, - icon: str, - icon_background: str) -> App: + def convert_to_workflow( + self, app_model: App, account: Account, name: str, icon_type: str, icon: str, icon_background: str + ): """ Convert app to workflow @@ -56,18 +53,18 @@ class WorkflowConverter: :return: new App instance """ # convert app model config + if not app_model.app_model_config: + raise ValueError("App model config is required") + workflow = self.convert_app_model_config_to_workflow( - app_model=app_model, - app_model_config=app_model.app_model_config, - account_id=account.id + app_model=app_model, app_model_config=app_model.app_model_config, account_id=account.id ) # create new app new_app = App() new_app.tenant_id = app_model.tenant_id - new_app.name = name if name else app_model.name + '(workflow)' - new_app.mode = AppMode.ADVANCED_CHAT.value \ - if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value + new_app.name = name if name else app_model.name + "(workflow)" + new_app.mode = AppMode.ADVANCED_CHAT.value if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value new_app.icon_type = icon_type if icon_type else app_model.icon_type new_app.icon = icon if icon else app_model.icon new_app.icon_background = icon_background if icon_background else app_model.icon_background @@ -77,6 +74,8 @@ class WorkflowConverter: new_app.api_rph = app_model.api_rph new_app.is_demo = False new_app.is_public = app_model.is_public + new_app.created_by = account.id + new_app.updated_by = account.id db.session.add(new_app) db.session.flush() db.session.commit() @@ -88,30 +87,21 @@ class WorkflowConverter: return new_app - def convert_app_model_config_to_workflow(self, app_model: App, - app_model_config: AppModelConfig, - account_id: str) -> Workflow: + def convert_app_model_config_to_workflow(self, app_model: App, app_model_config: AppModelConfig, account_id: str): """ Convert app model config to workflow mode :param app_model: App instance :param app_model_config: AppModelConfig instance :param account_id: Account ID - :return: """ # get new app mode new_app_mode = self._get_new_app_mode(app_model) # convert app model config - app_config = self._convert_to_app_config( - app_model=app_model, - app_model_config=app_model_config - ) + app_config = self._convert_to_app_config(app_model=app_model, app_model_config=app_model_config) # init workflow graph - graph = { - "nodes": [], - "edges": [] - } + graph = {"nodes": [], "edges": []} # Convert list: # - variables -> start @@ -123,11 +113,9 @@ class WorkflowConverter: # - show_retrieve_source -> knowledge-retrieval # convert to start node - start_node = self._convert_to_start_node( - variables=app_config.variables - ) + start_node = self._convert_to_start_node(variables=app_config.variables) - graph['nodes'].append(start_node) + graph["nodes"].append(start_node) # convert to http request node external_data_variable_node_mapping = {} @@ -135,7 +123,7 @@ class WorkflowConverter: http_request_nodes, external_data_variable_node_mapping = self._convert_to_http_request_node( app_model=app_model, variables=app_config.variables, - external_data_variables=app_config.external_data_variables + external_data_variables=app_config.external_data_variables, ) for http_request_node in http_request_nodes: @@ -144,9 +132,7 @@ class WorkflowConverter: # convert to knowledge retrieval node if app_config.dataset: knowledge_retrieval_node = self._convert_to_knowledge_retrieval_node( - new_app_mode=new_app_mode, - dataset_config=app_config.dataset, - model_config=app_config.model + new_app_mode=new_app_mode, dataset_config=app_config.dataset, model_config=app_config.model ) if knowledge_retrieval_node: @@ -160,7 +146,7 @@ class WorkflowConverter: model_config=app_config.model, prompt_template=app_config.prompt_template, file_upload=app_config.additional_features.file_upload, - external_data_variable_node_mapping=external_data_variable_node_mapping + external_data_variable_node_mapping=external_data_variable_node_mapping, ) graph = self._append_node(graph, llm_node) @@ -199,7 +185,7 @@ class WorkflowConverter: tenant_id=app_model.tenant_id, app_id=app_model.id, type=WorkflowType.from_app_mode(new_app_mode).value, - version='draft', + version="draft", graph=json.dumps(graph), features=json.dumps(features), created_by=account_id, @@ -212,24 +198,18 @@ class WorkflowConverter: return workflow - def _convert_to_app_config(self, app_model: App, - app_model_config: AppModelConfig) -> EasyUIBasedAppConfig: + def _convert_to_app_config(self, app_model: App, app_model_config: AppModelConfig) -> EasyUIBasedAppConfig: app_mode = AppMode.value_of(app_model.mode) if app_mode == AppMode.AGENT_CHAT or app_model.is_agent: app_model.mode = AppMode.AGENT_CHAT.value app_config = AgentChatAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config + app_model=app_model, app_model_config=app_model_config ) elif app_mode == AppMode.CHAT: - app_config = ChatAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config - ) + app_config = ChatAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config) elif app_mode == AppMode.COMPLETION: app_config = CompletionAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config + app_model=app_model, app_model_config=app_model_config ) else: raise ValueError("Invalid app mode") @@ -248,14 +228,13 @@ class WorkflowConverter: "data": { "title": "START", "type": NodeType.START.value, - "variables": [jsonable_encoder(v) for v in variables] - } + "variables": [jsonable_encoder(v) for v in variables], + }, } - def _convert_to_http_request_node(self, app_model: App, - variables: list[VariableEntity], - external_data_variables: list[ExternalDataVariableEntity]) \ - -> tuple[list[dict], dict[str, str]]: + def _convert_to_http_request_node( + self, app_model: App, variables: list[VariableEntity], external_data_variables: list[ExternalDataVariableEntity] + ) -> tuple[list[dict], dict[str, str]]: """ Convert API Based Extension to HTTP Request Node :param app_model: App instance @@ -277,40 +256,33 @@ class WorkflowConverter: # get params from config api_based_extension_id = tool_config.get("api_based_extension_id") + if not api_based_extension_id: + continue # get api_based_extension api_based_extension = self._get_api_based_extension( - tenant_id=tenant_id, - api_based_extension_id=api_based_extension_id + tenant_id=tenant_id, api_based_extension_id=api_based_extension_id ) - if not api_based_extension: - raise ValueError("[External data tool] API query failed, variable: {}, " - "error: api_based_extension_id is invalid" - .format(tool_variable)) - # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=tenant_id, - token=api_based_extension.api_key - ) + api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=api_based_extension.api_key) inputs = {} for v in variables: - inputs[v.variable] = '{{#start.' + v.variable + '#}}' + inputs[v.variable] = "{{#start." + v.variable + "#}}" request_body = { - 'point': APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value, - 'params': { - 'app_id': app_model.id, - 'tool_variable': tool_variable, - 'inputs': inputs, - 'query': '{{#sys.query#}}' if app_model.mode == AppMode.CHAT.value else '' - } + "point": APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value, + "params": { + "app_id": app_model.id, + "tool_variable": tool_variable, + "inputs": inputs, + "query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT.value else "", + }, } request_body_json = json.dumps(request_body) - request_body_json = request_body_json.replace(r'\{\{', '{{').replace(r'\}\}', '}}') + request_body_json = request_body_json.replace(r"\{\{", "{{").replace(r"\}\}", "}}") http_request_node = { "id": f"http_request_{index}", @@ -320,20 +292,11 @@ class WorkflowConverter: "type": NodeType.HTTP_REQUEST.value, "method": "post", "url": api_based_extension.api_endpoint, - "authorization": { - "type": "api-key", - "config": { - "type": "bearer", - "api_key": api_key - } - }, + "authorization": {"type": "api-key", "config": {"type": "bearer", "api_key": api_key}}, "headers": "", "params": "", - "body": { - "type": "json", - "data": request_body_json - } - } + "body": {"type": "json", "data": request_body_json}, + }, } nodes.append(http_request_node) @@ -345,32 +308,24 @@ class WorkflowConverter: "data": { "title": f"Parse {api_based_extension.name} Response", "type": NodeType.CODE.value, - "variables": [{ - "variable": "response_json", - "value_selector": [http_request_node['id'], "body"] - }], + "variables": [{"variable": "response_json", "value_selector": [http_request_node["id"], "body"]}], "code_language": "python3", "code": "import json\n\ndef main(response_json: str) -> str:\n response_body = json.loads(" - "response_json)\n return {\n \"result\": response_body[\"result\"]\n }", - "outputs": { - "result": { - "type": "string" - } - } - } + 'response_json)\n return {\n "result": response_body["result"]\n }', + "outputs": {"result": {"type": "string"}}, + }, } nodes.append(code_node) - external_data_variable_node_mapping[external_data_variable.variable] = code_node['id'] + external_data_variable_node_mapping[external_data_variable.variable] = code_node["id"] index += 1 return nodes, external_data_variable_node_mapping - def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode, - dataset_config: DatasetEntity, - model_config: ModelConfigEntity) \ - -> Optional[dict]: + def _convert_to_knowledge_retrieval_node( + self, new_app_mode: AppMode, dataset_config: DatasetEntity, model_config: ModelConfigEntity + ) -> Optional[dict]: """ Convert datasets to Knowledge Retrieval Node :param new_app_mode: new app mode @@ -404,7 +359,7 @@ class WorkflowConverter: "completion_params": { **model_config.parameters, "stop": model_config.stop, - } + }, } } if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE @@ -412,20 +367,23 @@ class WorkflowConverter: "multiple_retrieval_config": { "top_k": retrieve_config.top_k, "score_threshold": retrieve_config.score_threshold, - "reranking_model": retrieve_config.reranking_model + "reranking_model": retrieve_config.reranking_model, } if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE else None, - } + }, } - def _convert_to_llm_node(self, original_app_mode: AppMode, - new_app_mode: AppMode, - graph: dict, - model_config: ModelConfigEntity, - prompt_template: PromptTemplateEntity, - file_upload: Optional[FileExtraConfig] = None, - external_data_variable_node_mapping: dict[str, str] = None) -> dict: + def _convert_to_llm_node( + self, + original_app_mode: AppMode, + new_app_mode: AppMode, + graph: dict, + model_config: ModelConfigEntity, + prompt_template: PromptTemplateEntity, + file_upload: Optional[FileExtraConfig] = None, + external_data_variable_node_mapping: dict[str, str] | None = None, + ) -> dict: """ Convert to LLM Node :param original_app_mode: original app mode @@ -437,17 +395,18 @@ class WorkflowConverter: :param external_data_variable_node_mapping: external data variable node mapping """ # fetch start and knowledge retrieval node - start_node = next(filter(lambda n: n['data']['type'] == NodeType.START.value, graph['nodes'])) - knowledge_retrieval_node = next(filter( - lambda n: n['data']['type'] == NodeType.KNOWLEDGE_RETRIEVAL.value, - graph['nodes'] - ), None) + start_node = next(filter(lambda n: n["data"]["type"] == NodeType.START.value, graph["nodes"])) + knowledge_retrieval_node = next( + filter(lambda n: n["data"]["type"] == NodeType.KNOWLEDGE_RETRIEVAL.value, graph["nodes"]), None + ) role_prefix = None # Chat Model if model_config.mode == LLMMode.CHAT.value: if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: + if not prompt_template.simple_prompt_template: + raise ValueError("Simple prompt template is required") # get prompt template prompt_transform = SimplePromptTransform() prompt_template_config = prompt_transform.get_prompt_template( @@ -456,45 +415,35 @@ class WorkflowConverter: model=model_config.model, pre_prompt=prompt_template.simple_prompt_template, has_context=knowledge_retrieval_node is not None, - query_in_prompt=False + query_in_prompt=False, ) - template = prompt_template_config['prompt_template'].template + template = prompt_template_config["prompt_template"].template if not template: prompts = [] else: template = self._replace_template_variables( - template, - start_node['data']['variables'], - external_data_variable_node_mapping + template, start_node["data"]["variables"], external_data_variable_node_mapping ) - prompts = [ - { - "role": 'user', - "text": template - } - ] + prompts = [{"role": "user", "text": template}] else: advanced_chat_prompt_template = prompt_template.advanced_chat_prompt_template prompts = [] - for m in advanced_chat_prompt_template.messages: - if advanced_chat_prompt_template: + if advanced_chat_prompt_template: + for m in advanced_chat_prompt_template.messages: text = m.text text = self._replace_template_variables( - text, - start_node['data']['variables'], - external_data_variable_node_mapping + text, start_node["data"]["variables"], external_data_variable_node_mapping ) - prompts.append({ - "role": m.role.value, - "text": text - }) + prompts.append({"role": m.role.value, "text": text}) # Completion Model else: if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: + if not prompt_template.simple_prompt_template: + raise ValueError("Simple prompt template is required") # get prompt template prompt_transform = SimplePromptTransform() prompt_template_config = prompt_transform.get_prompt_template( @@ -503,57 +452,50 @@ class WorkflowConverter: model=model_config.model, pre_prompt=prompt_template.simple_prompt_template, has_context=knowledge_retrieval_node is not None, - query_in_prompt=False + query_in_prompt=False, ) - template = prompt_template_config['prompt_template'].template + template = prompt_template_config["prompt_template"].template template = self._replace_template_variables( - template, - start_node['data']['variables'], - external_data_variable_node_mapping + template=template, + variables=start_node["data"]["variables"], + external_data_variable_node_mapping=external_data_variable_node_mapping, ) - prompts = { - "text": template - } + prompts = {"text": template} - prompt_rules = prompt_template_config['prompt_rules'] + prompt_rules = prompt_template_config["prompt_rules"] role_prefix = { - "user": prompt_rules.get('human_prefix', 'Human'), - "assistant": prompt_rules.get('assistant_prefix', 'Assistant') + "user": prompt_rules.get("human_prefix", "Human"), + "assistant": prompt_rules.get("assistant_prefix", "Assistant"), } else: advanced_completion_prompt_template = prompt_template.advanced_completion_prompt_template if advanced_completion_prompt_template: text = advanced_completion_prompt_template.prompt text = self._replace_template_variables( - text, - start_node['data']['variables'], - external_data_variable_node_mapping + template=text, + variables=start_node["data"]["variables"], + external_data_variable_node_mapping=external_data_variable_node_mapping, ) else: text = "" - text = text.replace('{{#query#}}', '{{#sys.query#}}') + text = text.replace("{{#query#}}", "{{#sys.query#}}") prompts = { "text": text, } - if advanced_completion_prompt_template.role_prefix: + if advanced_completion_prompt_template and advanced_completion_prompt_template.role_prefix: role_prefix = { "user": advanced_completion_prompt_template.role_prefix.user, - "assistant": advanced_completion_prompt_template.role_prefix.assistant + "assistant": advanced_completion_prompt_template.role_prefix.assistant, } memory = None if new_app_mode == AppMode.ADVANCED_CHAT: - memory = { - "role_prefix": role_prefix, - "window": { - "enabled": False - } - } + memory = {"role_prefix": role_prefix, "window": {"enabled": False}} completion_params = model_config.parameters completion_params.update({"stop": model_config.stop}) @@ -567,28 +509,29 @@ class WorkflowConverter: "provider": model_config.provider, "name": model_config.model, "mode": model_config.mode, - "completion_params": completion_params + "completion_params": completion_params, }, "prompt_template": prompts, "memory": memory, "context": { "enabled": knowledge_retrieval_node is not None, "variable_selector": ["knowledge_retrieval", "result"] - if knowledge_retrieval_node is not None else None + if knowledge_retrieval_node is not None + else None, }, "vision": { "enabled": file_upload is not None, "variable_selector": ["sys", "files"] if file_upload is not None else None, - "configs": { - "detail": file_upload.image_config['detail'] - } if file_upload is not None else None - } - } + "configs": {"detail": file_upload.image_config["detail"]} + if file_upload is not None and file_upload.image_config is not None + else None, + }, + }, } - def _replace_template_variables(self, template: str, - variables: list[dict], - external_data_variable_node_mapping: dict[str, str] = None) -> str: + def _replace_template_variables( + self, template: str, variables: list[dict], external_data_variable_node_mapping: dict[str, str] | None = None + ) -> str: """ Replace Template Variables :param template: template @@ -597,12 +540,11 @@ class WorkflowConverter: :return: """ for v in variables: - template = template.replace('{{' + v['variable'] + '}}', '{{#start.' + v['variable'] + '#}}') + template = template.replace("{{" + v["variable"] + "}}", "{{#start." + v["variable"] + "#}}") if external_data_variable_node_mapping: for variable, code_node_id in external_data_variable_node_mapping.items(): - template = template.replace('{{' + variable + '}}', - '{{#' + code_node_id + '.result#}}') + template = template.replace("{{" + variable + "}}", "{{#" + code_node_id + ".result#}}") return template @@ -618,11 +560,8 @@ class WorkflowConverter: "data": { "title": "END", "type": NodeType.END.value, - "outputs": [{ - "variable": "result", - "value_selector": ["llm", "text"] - }] - } + "outputs": [{"variable": "result", "value_selector": ["llm", "text"]}], + }, } def _convert_to_answer_node(self) -> dict: @@ -634,11 +573,7 @@ class WorkflowConverter: return { "id": "answer", "position": None, - "data": { - "title": "ANSWER", - "type": NodeType.ANSWER.value, - "answer": "{{#llm.text#}}" - } + "data": {"title": "ANSWER", "type": NodeType.ANSWER.value, "answer": "{{#llm.text#}}"}, } def _create_edge(self, source: str, target: str) -> dict: @@ -648,11 +583,7 @@ class WorkflowConverter: :param target: target node id :return: """ - return { - "id": f"{source}-{target}", - "source": source, - "target": target - } + return {"id": f"{source}-{target}", "source": source, "target": target} def _append_node(self, graph: dict, node: dict) -> dict: """ @@ -662,9 +593,9 @@ class WorkflowConverter: :param node: Node to append :return: """ - previous_node = graph['nodes'][-1] - graph['nodes'].append(node) - graph['edges'].append(self._create_edge(previous_node['id'], node['id'])) + previous_node = graph["nodes"][-1] + graph["nodes"].append(node) + graph["edges"].append(self._create_edge(previous_node["id"], node["id"])) return graph def _get_new_app_mode(self, app_model: App) -> AppMode: @@ -678,14 +609,20 @@ class WorkflowConverter: else: return AppMode.ADVANCED_CHAT - def _get_api_based_extension(self, tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: + def _get_api_based_extension(self, tenant_id: str, api_based_extension_id: str): """ Get API Based Extension :param tenant_id: tenant id :param api_based_extension_id: api based extension id :return: """ - return db.session.query(APIBasedExtension).filter( - APIBasedExtension.tenant_id == tenant_id, - APIBasedExtension.id == api_based_extension_id - ).first() + api_based_extension = ( + db.session.query(APIBasedExtension) + .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) + .first() + ) + + if not api_based_extension: + raise ValueError(f"API Based Extension not found, id: {api_based_extension_id}") + + return api_based_extension diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index c4d3d27631..b4f0882a3a 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -10,7 +10,6 @@ from models.workflow import WorkflowAppLog, WorkflowRun, WorkflowRunStatus class WorkflowAppService: - def get_paginate_workflow_app_logs(self, app_model: App, args: dict) -> Pagination: """ Get paginate workflow app logs @@ -18,20 +17,14 @@ class WorkflowAppService: :param args: request args :return: """ - query = ( - db.select(WorkflowAppLog) - .where( - WorkflowAppLog.tenant_id == app_model.tenant_id, - WorkflowAppLog.app_id == app_model.id - ) + query = db.select(WorkflowAppLog).where( + WorkflowAppLog.tenant_id == app_model.tenant_id, WorkflowAppLog.app_id == app_model.id ) - status = WorkflowRunStatus.value_of(args.get('status')) if args.get('status') else None - keyword = args['keyword'] + status = WorkflowRunStatus.value_of(args.get("status")) if args.get("status") else None + keyword = args["keyword"] if keyword or status: - query = query.join( - WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id - ) + query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id) if keyword: keyword_like_val = f"%{args['keyword'][:30]}%" @@ -39,7 +32,7 @@ class WorkflowAppService: WorkflowRun.inputs.ilike(keyword_like_val), WorkflowRun.outputs.ilike(keyword_like_val), # filter keyword by end user session id if created by end user role - and_(WorkflowRun.created_by_role == 'end_user', EndUser.session_id.ilike(keyword_like_val)) + and_(WorkflowRun.created_by_role == "end_user", EndUser.session_id.ilike(keyword_like_val)), ] # filter keyword by workflow run id @@ -49,23 +42,16 @@ class WorkflowAppService: query = query.outerjoin( EndUser, - and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER.value) + and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER.value), ).filter(or_(*keyword_conditions)) if status: # join with workflow_run and filter by status - query = query.filter( - WorkflowRun.status == status.value - ) + query = query.filter(WorkflowRun.status == status.value) query = query.order_by(WorkflowAppLog.created_at.desc()) - pagination = db.paginate( - query, - page=args['page'], - per_page=args['limit'], - error_out=False - ) + pagination = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False) return pagination diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index ccce38ada0..b7b3abeaa2 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -18,6 +18,7 @@ class WorkflowRunService: :param app_model: app model :param args: request args """ + class WorkflowWithMessage: message_id: str conversation_id: str @@ -33,9 +34,7 @@ class WorkflowRunService: with_message_workflow_runs = [] for workflow_run in pagination.data: message = workflow_run.message - with_message_workflow_run = WorkflowWithMessage( - workflow_run=workflow_run - ) + with_message_workflow_run = WorkflowWithMessage(workflow_run=workflow_run) if message: with_message_workflow_run.message_id = message.id with_message_workflow_run.conversation_id = message.conversation_id @@ -53,26 +52,30 @@ class WorkflowRunService: :param app_model: app model :param args: request args """ - limit = int(args.get('limit', 20)) + limit = int(args.get("limit", 20)) base_query = db.session.query(WorkflowRun).filter( WorkflowRun.tenant_id == app_model.tenant_id, WorkflowRun.app_id == app_model.id, - WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value + WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value, ) - if args.get('last_id'): + if args.get("last_id"): last_workflow_run = base_query.filter( - WorkflowRun.id == args.get('last_id'), + WorkflowRun.id == args.get("last_id"), ).first() if not last_workflow_run: - raise ValueError('Last workflow run not exists') + raise ValueError("Last workflow run not exists") - workflow_runs = base_query.filter( - WorkflowRun.created_at < last_workflow_run.created_at, - WorkflowRun.id != last_workflow_run.id - ).order_by(WorkflowRun.created_at.desc()).limit(limit).all() + workflow_runs = ( + base_query.filter( + WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id + ) + .order_by(WorkflowRun.created_at.desc()) + .limit(limit) + .all() + ) else: workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all() @@ -81,17 +84,13 @@ class WorkflowRunService: current_page_first_workflow_run = workflow_runs[-1] rest_count = base_query.filter( WorkflowRun.created_at < current_page_first_workflow_run.created_at, - WorkflowRun.id != current_page_first_workflow_run.id + WorkflowRun.id != current_page_first_workflow_run.id, ).count() if rest_count > 0: has_more = True - return InfiniteScrollPagination( - data=workflow_runs, - limit=limit, - has_more=has_more - ) + return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more) def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun: """ @@ -100,11 +99,15 @@ class WorkflowRunService: :param app_model: app model :param run_id: workflow run id """ - workflow_run = db.session.query(WorkflowRun).filter( - WorkflowRun.tenant_id == app_model.tenant_id, - WorkflowRun.app_id == app_model.id, - WorkflowRun.id == run_id, - ).first() + workflow_run = ( + db.session.query(WorkflowRun) + .filter( + WorkflowRun.tenant_id == app_model.tenant_id, + WorkflowRun.app_id == app_model.id, + WorkflowRun.id == run_id, + ) + .first() + ) return workflow_run @@ -117,12 +120,17 @@ class WorkflowRunService: if not workflow_run: return [] - node_executions = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.tenant_id == app_model.tenant_id, - WorkflowNodeExecution.app_id == app_model.id, - WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, - WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - WorkflowNodeExecution.workflow_run_id == run_id, - ).order_by(WorkflowNodeExecution.index.desc()).all() + node_executions = ( + db.session.query(WorkflowNodeExecution) + .filter( + WorkflowNodeExecution.tenant_id == app_model.tenant_id, + WorkflowNodeExecution.app_id == app_model.id, + WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + WorkflowNodeExecution.workflow_run_id == run_id, + ) + .order_by(WorkflowNodeExecution.index.desc()) + .all() + ) return node_executions diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index c593b66f36..4c3ded14ad 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -37,11 +37,13 @@ class WorkflowService: Get draft workflow """ # fetch draft workflow by app_model - workflow = db.session.query(Workflow).filter( - Workflow.tenant_id == app_model.tenant_id, - Workflow.app_id == app_model.id, - Workflow.version == 'draft' - ).first() + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.version == "draft" + ) + .first() + ) # return draft workflow return workflow @@ -55,11 +57,15 @@ class WorkflowService: return None # fetch published workflow by workflow_id - workflow = db.session.query(Workflow).filter( - Workflow.tenant_id == app_model.tenant_id, - Workflow.app_id == app_model.id, - Workflow.id == app_model.workflow_id - ).first() + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.id == app_model.workflow_id, + ) + .first() + ) return workflow @@ -85,10 +91,7 @@ class WorkflowService: raise WorkflowHashNotEqualError() # validate features structure - self.validate_features_structure( - app_model=app_model, - features=features - ) + self.validate_features_structure(app_model=app_model, features=features) # create draft workflow if not found if not workflow: @@ -96,7 +99,7 @@ class WorkflowService: tenant_id=app_model.tenant_id, app_id=app_model.id, type=WorkflowType.from_app_mode(app_model.mode).value, - version='draft', + version="draft", graph=json.dumps(graph), features=json.dumps(features), created_by=account.id, @@ -122,9 +125,7 @@ class WorkflowService: # return draft workflow return workflow - def publish_workflow(self, app_model: App, - account: Account, - draft_workflow: Optional[Workflow] = None) -> Workflow: + def publish_workflow(self, app_model: App, account: Account, draft_workflow: Optional[Workflow] = None) -> Workflow: """ Publish workflow from draft @@ -137,7 +138,7 @@ class WorkflowService: draft_workflow = self.get_draft_workflow(app_model=app_model) if not draft_workflow: - raise ValueError('No valid workflow found.') + raise ValueError("No valid workflow found.") # create new workflow workflow = Workflow( @@ -187,17 +188,16 @@ class WorkflowService: workflow_engine_manager = WorkflowEngineManager() return workflow_engine_manager.get_default_config(node_type, filters) - def run_draft_workflow_node(self, app_model: App, - node_id: str, - user_inputs: dict, - account: Account) -> WorkflowNodeExecution: + def run_draft_workflow_node( + self, app_model: App, node_id: str, user_inputs: dict, account: Account + ) -> WorkflowNodeExecution: """ Run draft workflow node """ # fetch draft workflow by app_model draft_workflow = self.get_draft_workflow(app_model=app_model) if not draft_workflow: - raise ValueError('Workflow not initialized') + raise ValueError("Workflow not initialized") # run draft workflow node workflow_engine_manager = WorkflowEngineManager() @@ -226,7 +226,7 @@ class WorkflowService: created_by_role=CreatedByRole.ACCOUNT.value, created_by=account.id, created_at=datetime.now(timezone.utc).replace(tzinfo=None), - finished_at=datetime.now(timezone.utc).replace(tzinfo=None) + finished_at=datetime.now(timezone.utc).replace(tzinfo=None), ) db.session.add(workflow_node_execution) db.session.commit() @@ -247,14 +247,15 @@ class WorkflowService: inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None, process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None, outputs=json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None, - execution_metadata=(json.dumps(jsonable_encoder(node_run_result.metadata)) - if node_run_result.metadata else None), + execution_metadata=( + json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None + ), status=WorkflowNodeExecutionStatus.SUCCEEDED.value, elapsed_time=time.perf_counter() - start_at, created_by_role=CreatedByRole.ACCOUNT.value, created_by=account.id, created_at=datetime.now(timezone.utc).replace(tzinfo=None), - finished_at=datetime.now(timezone.utc).replace(tzinfo=None) + finished_at=datetime.now(timezone.utc).replace(tzinfo=None), ) else: # create workflow node execution @@ -273,7 +274,7 @@ class WorkflowService: created_by_role=CreatedByRole.ACCOUNT.value, created_by=account.id, created_at=datetime.now(timezone.utc).replace(tzinfo=None), - finished_at=datetime.now(timezone.utc).replace(tzinfo=None) + finished_at=datetime.now(timezone.utc).replace(tzinfo=None), ) db.session.add(workflow_node_execution) @@ -295,16 +296,16 @@ class WorkflowService: workflow_converter = WorkflowConverter() if app_model.mode not in [AppMode.CHAT.value, AppMode.COMPLETION.value]: - raise ValueError(f'Current App mode: {app_model.mode} is not supported convert to workflow.') + raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.") # convert to workflow new_app = workflow_converter.convert_to_workflow( app_model=app_model, account=account, - name=args.get('name'), - icon_type=args.get('icon_type'), - icon=args.get('icon'), - icon_background=args.get('icon_background'), + name=args.get("name"), + icon_type=args.get("icon_type"), + icon=args.get("icon"), + icon_background=args.get("icon_background"), ) return new_app @@ -312,15 +313,11 @@ class WorkflowService: def validate_features_structure(self, app_model: App, features: dict) -> dict: if app_model.mode == AppMode.ADVANCED_CHAT.value: return AdvancedChatAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, - config=features, - only_structure_validate=True + tenant_id=app_model.tenant_id, config=features, only_structure_validate=True ) elif app_model.mode == AppMode.WORKFLOW.value: return WorkflowAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, - config=features, - only_structure_validate=True + tenant_id=app_model.tenant_id, config=features, only_structure_validate=True ) else: raise ValueError(f"Invalid app mode: {app_model.mode}") diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index 2bcbe5c6f6..8fcb12b1cb 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -1,4 +1,3 @@ - from flask_login import current_user from configs import dify_config @@ -14,34 +13,40 @@ class WorkspaceService: if not tenant: return None tenant_info = { - 'id': tenant.id, - 'name': tenant.name, - 'plan': tenant.plan, - 'status': tenant.status, - 'created_at': tenant.created_at, - 'in_trail': True, - 'trial_end_reason': None, - 'role': 'normal', + "id": tenant.id, + "name": tenant.name, + "plan": tenant.plan, + "status": tenant.status, + "created_at": tenant.created_at, + "in_trail": True, + "trial_end_reason": None, + "role": "normal", } # Get role of user - tenant_account_join = db.session.query(TenantAccountJoin).filter( - TenantAccountJoin.tenant_id == tenant.id, - TenantAccountJoin.account_id == current_user.id - ).first() - tenant_info['role'] = tenant_account_join.role + tenant_account_join = ( + db.session.query(TenantAccountJoin) + .filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id) + .first() + ) + tenant_info["role"] = tenant_account_join.role - can_replace_logo = FeatureService.get_features(tenant_info['id']).can_replace_logo + can_replace_logo = FeatureService.get_features(tenant_info["id"]).can_replace_logo - if can_replace_logo and TenantService.has_roles(tenant, - [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]): + if can_replace_logo and TenantService.has_roles( + tenant, [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN] + ): base_url = dify_config.FILES_URL - replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None - remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False) + replace_webapp_logo = ( + f"{base_url}/files/workspaces/{tenant.id}/webapp-logo" + if tenant.custom_config_dict.get("replace_webapp_logo") + else None + ) + remove_webapp_brand = tenant.custom_config_dict.get("remove_webapp_brand", False) - tenant_info['custom_config'] = { - 'remove_webapp_brand': remove_webapp_brand, - 'replace_webapp_logo': replace_webapp_logo, + tenant_info["custom_config"] = { + "remove_webapp_brand": remove_webapp_brand, + "replace_webapp_logo": replace_webapp_logo, } return tenant_info diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index e0a1b21909..b50876cc79 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -14,7 +14,7 @@ from models.dataset import Document as DatasetDocument from models.dataset import DocumentSegment -@shared_task(queue='dataset') +@shared_task(queue="dataset") def add_document_to_index_task(dataset_document_id: str): """ Async Add document to index @@ -22,24 +22,25 @@ def add_document_to_index_task(dataset_document_id: str): Usage: add_document_to_index.delay(document_id) """ - logging.info(click.style('Start add document to index: {}'.format(dataset_document_id), fg='green')) + logging.info(click.style("Start add document to index: {}".format(dataset_document_id), fg="green")) start_at = time.perf_counter() dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document_id).first() if not dataset_document: - raise NotFound('Document not found') + raise NotFound("Document not found") - if dataset_document.indexing_status != 'completed': + if dataset_document.indexing_status != "completed": return - indexing_cache_key = 'document_{}_indexing'.format(dataset_document.id) + indexing_cache_key = "document_{}_indexing".format(dataset_document.id) try: - segments = db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.enabled == True - ) \ - .order_by(DocumentSegment.position.asc()).all() + segments = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) + .order_by(DocumentSegment.position.asc()) + .all() + ) documents = [] for segment in segments: @@ -50,7 +51,7 @@ def add_document_to_index_task(dataset_document_id: str): "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) documents.append(document) @@ -58,7 +59,7 @@ def add_document_to_index_task(dataset_document_id: str): dataset = dataset_document.dataset if not dataset: - raise Exception('Document has no dataset') + raise Exception("Document has no dataset") index_type = dataset.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() @@ -66,12 +67,15 @@ def add_document_to_index_task(dataset_document_id: str): end_at = time.perf_counter() logging.info( - click.style('Document added to index: {} latency: {}'.format(dataset_document.id, end_at - start_at), fg='green')) + click.style( + "Document added to index: {} latency: {}".format(dataset_document.id, end_at - start_at), fg="green" + ) + ) except Exception as e: logging.exception("add document to index failed") dataset_document.enabled = False dataset_document.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - dataset_document.status = 'error' + dataset_document.status = "error" dataset_document.error = str(e) db.session.commit() finally: diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py index b3aa8b596c..25c55bcfaf 100644 --- a/api/tasks/annotation/add_annotation_to_index_task.py +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -10,9 +10,10 @@ from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService -@shared_task(queue='dataset') -def add_annotation_to_index_task(annotation_id: str, question: str, tenant_id: str, app_id: str, - collection_binding_id: str): +@shared_task(queue="dataset") +def add_annotation_to_index_task( + annotation_id: str, question: str, tenant_id: str, app_id: str, collection_binding_id: str +): """ Add annotation to index. :param annotation_id: annotation id @@ -23,38 +24,34 @@ def add_annotation_to_index_task(annotation_id: str, question: str, tenant_id: s Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct) """ - logging.info(click.style('Start build index for annotation: {}'.format(annotation_id), fg='green')) + logging.info(click.style("Start build index for annotation: {}".format(annotation_id), fg="green")) start_at = time.perf_counter() try: dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - collection_binding_id, - 'annotation' + collection_binding_id, "annotation" ) dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique='high_quality', + indexing_technique="high_quality", embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, - collection_binding_id=dataset_collection_binding.id + collection_binding_id=dataset_collection_binding.id, ) document = Document( - page_content=question, - metadata={ - "annotation_id": annotation_id, - "app_id": app_id, - "doc_id": annotation_id - } + page_content=question, metadata={"annotation_id": annotation_id, "app_id": app_id, "doc_id": annotation_id} ) - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) vector.create([document], duplicate_check=True) end_at = time.perf_counter() logging.info( click.style( - 'Build index successful for annotation: {} latency: {}'.format(annotation_id, end_at - start_at), - fg='green')) + "Build index successful for annotation: {} latency: {}".format(annotation_id, end_at - start_at), + fg="green", + ) + ) except Exception: logging.exception("Build index for annotation failed") diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index 6e6b16045d..fa7e5ac919 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -14,9 +14,8 @@ from models.model import App, AppAnnotationSetting, MessageAnnotation from services.dataset_service import DatasetCollectionBindingService -@shared_task(queue='dataset') -def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: str, tenant_id: str, - user_id: str): +@shared_task(queue="dataset") +def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: str, tenant_id: str, user_id: str): """ Add annotation to index. :param job_id: job_id @@ -26,72 +25,66 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: :param user_id: user_id """ - logging.info(click.style('Start batch import annotation: {}'.format(job_id), fg='green')) + logging.info(click.style("Start batch import annotation: {}".format(job_id), fg="green")) start_at = time.perf_counter() - indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id)) + indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id)) # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == tenant_id, - App.status == 'normal' - ).first() + app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() if app: try: documents = [] for content in content_list: annotation = MessageAnnotation( - app_id=app.id, - content=content['answer'], - question=content['question'], - account_id=user_id + app_id=app.id, content=content["answer"], question=content["question"], account_id=user_id ) db.session.add(annotation) db.session.flush() document = Document( - page_content=content['question'], - metadata={ - "annotation_id": annotation.id, - "app_id": app_id, - "doc_id": annotation.id - } + page_content=content["question"], + metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id}, ) documents.append(document) # if annotation reply is enabled , batch add annotations' index - app_annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id - ).first() + app_annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if app_annotation_setting: - dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - app_annotation_setting.collection_binding_id, - 'annotation' + dataset_collection_binding = ( + DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( + app_annotation_setting.collection_binding_id, "annotation" + ) ) if not dataset_collection_binding: raise NotFound("App annotation setting not found") dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique='high_quality', + indexing_technique="high_quality", embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, - collection_binding_id=dataset_collection_binding.id + collection_binding_id=dataset_collection_binding.id, ) - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) vector.create(documents, duplicate_check=True) db.session.commit() - redis_client.setex(indexing_cache_key, 600, 'completed') + redis_client.setex(indexing_cache_key, 600, "completed") end_at = time.perf_counter() logging.info( click.style( - 'Build index successful for batch import annotation: {} latency: {}'.format(job_id, end_at - start_at), - fg='green')) + "Build index successful for batch import annotation: {} latency: {}".format( + job_id, end_at - start_at + ), + fg="green", + ) + ) except Exception as e: db.session.rollback() - redis_client.setex(indexing_cache_key, 600, 'error') - indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id)) + redis_client.setex(indexing_cache_key, 600, "error") + indexing_error_msg_key = "app_annotation_batch_import_error_msg_{}".format(str(job_id)) redis_client.setex(indexing_error_msg_key, 600, str(e)) logging.exception("Build index for batch import annotations failed") diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py index 81155a35e4..5758db53de 100644 --- a/api/tasks/annotation/delete_annotation_index_task.py +++ b/api/tasks/annotation/delete_annotation_index_task.py @@ -9,36 +9,33 @@ from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService -@shared_task(queue='dataset') -def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str, - collection_binding_id: str): +@shared_task(queue="dataset") +def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str, collection_binding_id: str): """ Async delete annotation index task """ - logging.info(click.style('Start delete app annotation index: {}'.format(app_id), fg='green')) + logging.info(click.style("Start delete app annotation index: {}".format(app_id), fg="green")) start_at = time.perf_counter() try: dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - collection_binding_id, - 'annotation' + collection_binding_id, "annotation" ) dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique='high_quality', - collection_binding_id=dataset_collection_binding.id + indexing_technique="high_quality", + collection_binding_id=dataset_collection_binding.id, ) try: - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) - vector.delete_by_metadata_field('annotation_id', annotation_id) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) + vector.delete_by_metadata_field("annotation_id", annotation_id) except Exception: logging.exception("Delete annotation index failed when annotation deleted.") end_at = time.perf_counter() logging.info( - click.style('App annotations index deleted : {} latency: {}'.format(app_id, end_at - start_at), - fg='green')) + click.style("App annotations index deleted : {} latency: {}".format(app_id, end_at - start_at), fg="green") + ) except Exception as e: logging.exception("Annotation deleted index failed:{}".format(str(e))) - diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index e5e03c9b51..0f83dfdbd4 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -12,49 +12,44 @@ from models.dataset import Dataset from models.model import App, AppAnnotationSetting, MessageAnnotation -@shared_task(queue='dataset') +@shared_task(queue="dataset") def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): """ Async enable annotation reply task """ - logging.info(click.style('Start delete app annotations index: {}'.format(app_id), fg='green')) + logging.info(click.style("Start delete app annotations index: {}".format(app_id), fg="green")) start_at = time.perf_counter() # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == tenant_id, - App.status == 'normal' - ).first() + app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() annotations_count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).count() if not app: raise NotFound("App not found") - app_annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id - ).first() + app_annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if not app_annotation_setting: raise NotFound("App annotation setting not found") - disable_app_annotation_key = 'disable_app_annotation_{}'.format(str(app_id)) - disable_app_annotation_job_key = 'disable_app_annotation_job_{}'.format(str(job_id)) + disable_app_annotation_key = "disable_app_annotation_{}".format(str(app_id)) + disable_app_annotation_job_key = "disable_app_annotation_job_{}".format(str(job_id)) try: - dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique='high_quality', - collection_binding_id=app_annotation_setting.collection_binding_id + indexing_technique="high_quality", + collection_binding_id=app_annotation_setting.collection_binding_id, ) try: if annotations_count > 0: - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) - vector.delete_by_metadata_field('app_id', app_id) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) + vector.delete_by_metadata_field("app_id", app_id) except Exception: logging.exception("Delete annotation index failed when annotation deleted.") - redis_client.setex(disable_app_annotation_job_key, 600, 'completed') + redis_client.setex(disable_app_annotation_job_key, 600, "completed") # delete annotation setting db.session.delete(app_annotation_setting) @@ -62,12 +57,12 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): end_at = time.perf_counter() logging.info( - click.style('App annotations index deleted : {} latency: {}'.format(app_id, end_at - start_at), - fg='green')) + click.style("App annotations index deleted : {} latency: {}".format(app_id, end_at - start_at), fg="green") + ) except Exception as e: logging.exception("Annotation batch deleted index failed:{}".format(str(e))) - redis_client.setex(disable_app_annotation_job_key, 600, 'error') - disable_app_annotation_error_key = 'disable_app_annotation_error_{}'.format(str(job_id)) + redis_client.setex(disable_app_annotation_job_key, 600, "error") + disable_app_annotation_error_key = "disable_app_annotation_error_{}".format(str(job_id)) redis_client.setex(disable_app_annotation_error_key, 600, str(e)) finally: redis_client.delete(disable_app_annotation_key) diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index fda8b7a250..82b70f6b71 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -15,37 +15,39 @@ from models.model import App, AppAnnotationSetting, MessageAnnotation from services.dataset_service import DatasetCollectionBindingService -@shared_task(queue='dataset') -def enable_annotation_reply_task(job_id: str, app_id: str, user_id: str, tenant_id: str, score_threshold: float, - embedding_provider_name: str, embedding_model_name: str): +@shared_task(queue="dataset") +def enable_annotation_reply_task( + job_id: str, + app_id: str, + user_id: str, + tenant_id: str, + score_threshold: float, + embedding_provider_name: str, + embedding_model_name: str, +): """ Async enable annotation reply task """ - logging.info(click.style('Start add app annotation to index: {}'.format(app_id), fg='green')) + logging.info(click.style("Start add app annotation to index: {}".format(app_id), fg="green")) start_at = time.perf_counter() # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == tenant_id, - App.status == 'normal' - ).first() + app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() if not app: raise NotFound("App not found") annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).all() - enable_app_annotation_key = 'enable_app_annotation_{}'.format(str(app_id)) - enable_app_annotation_job_key = 'enable_app_annotation_job_{}'.format(str(job_id)) + enable_app_annotation_key = "enable_app_annotation_{}".format(str(app_id)) + enable_app_annotation_job_key = "enable_app_annotation_job_{}".format(str(job_id)) try: documents = [] dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_provider_name, - embedding_model_name, - 'annotation' + embedding_provider_name, embedding_model_name, "annotation" + ) + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() ) - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id).first() if annotation_setting: annotation_setting.score_threshold = score_threshold annotation_setting.collection_binding_id = dataset_collection_binding.id @@ -58,48 +60,42 @@ def enable_annotation_reply_task(job_id: str, app_id: str, user_id: str, tenant_ score_threshold=score_threshold, collection_binding_id=dataset_collection_binding.id, created_user_id=user_id, - updated_user_id=user_id + updated_user_id=user_id, ) db.session.add(new_app_annotation_setting) dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique='high_quality', + indexing_technique="high_quality", embedding_model_provider=embedding_provider_name, embedding_model=embedding_model_name, - collection_binding_id=dataset_collection_binding.id + collection_binding_id=dataset_collection_binding.id, ) if annotations: for annotation in annotations: document = Document( page_content=annotation.question, - metadata={ - "annotation_id": annotation.id, - "app_id": app_id, - "doc_id": annotation.id - } + metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id}, ) documents.append(document) - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) try: - vector.delete_by_metadata_field('app_id', app_id) + vector.delete_by_metadata_field("app_id", app_id) except Exception as e: - logging.info( - click.style('Delete annotation index error: {}'.format(str(e)), - fg='red')) + logging.info(click.style("Delete annotation index error: {}".format(str(e)), fg="red")) vector.create(documents) db.session.commit() - redis_client.setex(enable_app_annotation_job_key, 600, 'completed') + redis_client.setex(enable_app_annotation_job_key, 600, "completed") end_at = time.perf_counter() logging.info( - click.style('App annotations added to index: {} latency: {}'.format(app_id, end_at - start_at), - fg='green')) + click.style("App annotations added to index: {} latency: {}".format(app_id, end_at - start_at), fg="green") + ) except Exception as e: logging.exception("Annotation batch created index failed:{}".format(str(e))) - redis_client.setex(enable_app_annotation_job_key, 600, 'error') - enable_app_annotation_error_key = 'enable_app_annotation_error_{}'.format(str(job_id)) + redis_client.setex(enable_app_annotation_job_key, 600, "error") + enable_app_annotation_error_key = "enable_app_annotation_error_{}".format(str(job_id)) redis_client.setex(enable_app_annotation_error_key, 600, str(e)) db.session.rollback() finally: diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py index 7219abd3cd..b685d84d07 100644 --- a/api/tasks/annotation/update_annotation_to_index_task.py +++ b/api/tasks/annotation/update_annotation_to_index_task.py @@ -10,9 +10,10 @@ from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService -@shared_task(queue='dataset') -def update_annotation_to_index_task(annotation_id: str, question: str, tenant_id: str, app_id: str, - collection_binding_id: str): +@shared_task(queue="dataset") +def update_annotation_to_index_task( + annotation_id: str, question: str, tenant_id: str, app_id: str, collection_binding_id: str +): """ Update annotation to index. :param annotation_id: annotation id @@ -23,39 +24,35 @@ def update_annotation_to_index_task(annotation_id: str, question: str, tenant_id Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct) """ - logging.info(click.style('Start update index for annotation: {}'.format(annotation_id), fg='green')) + logging.info(click.style("Start update index for annotation: {}".format(annotation_id), fg="green")) start_at = time.perf_counter() try: dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - collection_binding_id, - 'annotation' + collection_binding_id, "annotation" ) dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique='high_quality', + indexing_technique="high_quality", embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, - collection_binding_id=dataset_collection_binding.id + collection_binding_id=dataset_collection_binding.id, ) document = Document( - page_content=question, - metadata={ - "annotation_id": annotation_id, - "app_id": app_id, - "doc_id": annotation_id - } + page_content=question, metadata={"annotation_id": annotation_id, "app_id": app_id, "doc_id": annotation_id} ) - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) - vector.delete_by_metadata_field('annotation_id', annotation_id) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) + vector.delete_by_metadata_field("annotation_id", annotation_id) vector.add_texts([document]) end_at = time.perf_counter() logging.info( click.style( - 'Build index successful for annotation: {} latency: {}'.format(annotation_id, end_at - start_at), - fg='green')) + "Build index successful for annotation: {} latency: {}".format(annotation_id, end_at - start_at), + fg="green", + ) + ) except Exception: logging.exception("Build index for annotation failed") diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 67cc03bdeb..de7f0ddec1 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -16,9 +16,10 @@ from libs import helper from models.dataset import Dataset, Document, DocumentSegment -@shared_task(queue='dataset') -def batch_create_segment_to_index_task(job_id: str, content: list, dataset_id: str, document_id: str, - tenant_id: str, user_id: str): +@shared_task(queue="dataset") +def batch_create_segment_to_index_task( + job_id: str, content: list, dataset_id: str, document_id: str, tenant_id: str, user_id: str +): """ Async batch create segment to index :param job_id: @@ -30,44 +31,44 @@ def batch_create_segment_to_index_task(job_id: str, content: list, dataset_id: s Usage: batch_create_segment_to_index_task.delay(segment_id) """ - logging.info(click.style('Start batch create segment jobId: {}'.format(job_id), fg='green')) + logging.info(click.style("Start batch create segment jobId: {}".format(job_id), fg="green")) start_at = time.perf_counter() - indexing_cache_key = 'segment_batch_import_{}'.format(job_id) + indexing_cache_key = "segment_batch_import_{}".format(job_id) try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: - raise ValueError('Dataset not exist.') + raise ValueError("Dataset not exist.") dataset_document = db.session.query(Document).filter(Document.id == document_id).first() if not dataset_document: - raise ValueError('Document not exist.') + raise ValueError("Document not exist.") - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': - raise ValueError('Document is not available.') + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + raise ValueError("Document is not available.") document_segments = [] embedding_model = None - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) for segment in content: - content = segment['content'] + content = segment["content"] doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) # calc embedding use tokens - tokens = embedding_model.get_text_embedding_num_tokens( - texts=[content] - ) if embedding_model else 0 - max_position = db.session.query(func.max(DocumentSegment.position)).filter( - DocumentSegment.document_id == dataset_document.id - ).scalar() + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) if embedding_model else 0 + max_position = ( + db.session.query(func.max(DocumentSegment.position)) + .filter(DocumentSegment.document_id == dataset_document.id) + .scalar() + ) segment_document = DocumentSegment( tenant_id=tenant_id, dataset_id=dataset_id, @@ -80,20 +81,22 @@ def batch_create_segment_to_index_task(job_id: str, content: list, dataset_id: s tokens=tokens, created_by=user_id, indexing_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), - status='completed', - completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + status="completed", + completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), ) - if dataset_document.doc_form == 'qa_model': - segment_document.answer = segment['answer'] + if dataset_document.doc_form == "qa_model": + segment_document.answer = segment["answer"] db.session.add(segment_document) document_segments.append(segment_document) # add index to db indexing_runner = IndexingRunner() indexing_runner.batch_add_segments(document_segments, dataset) db.session.commit() - redis_client.setex(indexing_cache_key, 600, 'completed') + redis_client.setex(indexing_cache_key, 600, "completed") end_at = time.perf_counter() - logging.info(click.style('Segment batch created job: {} latency: {}'.format(job_id, end_at - start_at), fg='green')) + logging.info( + click.style("Segment batch created job: {} latency: {}".format(job_id, end_at - start_at), fg="green") + ) except Exception as e: logging.exception("Segments batch created index failed:{}".format(str(e))) - redis_client.setex(indexing_cache_key, 600, 'error') + redis_client.setex(indexing_cache_key, 600, "error") diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index 1f26c966c4..3624903801 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -19,9 +19,15 @@ from models.model import UploadFile # Add import statement for ValueError -@shared_task(queue='dataset') -def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, - index_struct: str, collection_binding_id: str, doc_form: str): +@shared_task(queue="dataset") +def clean_dataset_task( + dataset_id: str, + tenant_id: str, + indexing_technique: str, + index_struct: str, + collection_binding_id: str, + doc_form: str, +): """ Clean dataset when dataset deleted. :param dataset_id: dataset id @@ -33,7 +39,7 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct) """ - logging.info(click.style('Start clean dataset when dataset deleted: {}'.format(dataset_id), fg='green')) + logging.info(click.style("Start clean dataset when dataset deleted: {}".format(dataset_id), fg="green")) start_at = time.perf_counter() try: @@ -48,9 +54,9 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all() if documents is None or len(documents) == 0: - logging.info(click.style('No documents found for dataset: {}'.format(dataset_id), fg='green')) + logging.info(click.style("No documents found for dataset: {}".format(dataset_id), fg="green")) else: - logging.info(click.style('Cleaning documents for dataset: {}'.format(dataset_id), fg='green')) + logging.info(click.style("Cleaning documents for dataset: {}".format(dataset_id), fg="green")) # Specify the index type before initializing the index processor if doc_form is None: raise ValueError("Index type must be specified.") @@ -71,15 +77,16 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, if documents: for document in documents: try: - if document.data_source_type == 'upload_file': + if document.data_source_type == "upload_file": if document.data_source_info: data_source_info = document.data_source_info_dict - if data_source_info and 'upload_file_id' in data_source_info: - file_id = data_source_info['upload_file_id'] - file = db.session.query(UploadFile).filter( - UploadFile.tenant_id == document.tenant_id, - UploadFile.id == file_id - ).first() + if data_source_info and "upload_file_id" in data_source_info: + file_id = data_source_info["upload_file_id"] + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) + .first() + ) if not file: continue storage.delete(file.key) @@ -90,6 +97,9 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, db.session.commit() end_at = time.perf_counter() logging.info( - click.style('Cleaned dataset when dataset deleted: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) + click.style( + "Cleaned dataset when dataset deleted: {} latency: {}".format(dataset_id, end_at - start_at), fg="green" + ) + ) except Exception: logging.exception("Cleaned dataset when dataset deleted failed") diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 0fd05615b6..ae2855aa2e 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -12,7 +12,7 @@ from models.dataset import Dataset, DocumentSegment from models.model import UploadFile -@shared_task(queue='dataset') +@shared_task(queue="dataset") def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_id: Optional[str]): """ Clean document when document deleted. @@ -23,14 +23,14 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i Usage: clean_document_task.delay(document_id, dataset_id) """ - logging.info(click.style('Start clean document when document deleted: {}'.format(document_id), fg='green')) + logging.info(click.style("Start clean document when document deleted: {}".format(document_id), fg="green")) start_at = time.perf_counter() try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: - raise Exception('Document has no dataset') + raise Exception("Document has no dataset") segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() # check segment is exist @@ -44,9 +44,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i db.session.commit() if file_id: - file = db.session.query(UploadFile).filter( - UploadFile.id == file_id - ).first() + file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() if file: try: storage.delete(file.key) @@ -57,6 +55,10 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i end_at = time.perf_counter() logging.info( - click.style('Cleaned document when document deleted: {} latency: {}'.format(document_id, end_at - start_at), fg='green')) + click.style( + "Cleaned document when document deleted: {} latency: {}".format(document_id, end_at - start_at), + fg="green", + ) + ) except Exception: logging.exception("Cleaned document when document deleted failed") diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index 9b697b6351..75d9e03130 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -9,7 +9,7 @@ from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment -@shared_task(queue='dataset') +@shared_task(queue="dataset") def clean_notion_document_task(document_ids: list[str], dataset_id: str): """ Clean document when document deleted. @@ -18,20 +18,20 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): Usage: clean_notion_document_task.delay(document_ids, dataset_id) """ - logging.info(click.style('Start clean document when import form notion document deleted: {}'.format(dataset_id), fg='green')) + logging.info( + click.style("Start clean document when import form notion document deleted: {}".format(dataset_id), fg="green") + ) start_at = time.perf_counter() try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: - raise Exception('Document has no dataset') + raise Exception("Document has no dataset") index_type = dataset.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() for document_id in document_ids: - document = db.session.query(Document).filter( - Document.id == document_id - ).first() + document = db.session.query(Document).filter(Document.id == document_id).first() db.session.delete(document) segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() @@ -44,8 +44,12 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): db.session.commit() end_at = time.perf_counter() logging.info( - click.style('Clean document when import form notion document deleted end :: {} latency: {}'.format( - dataset_id, end_at - start_at), - fg='green')) + click.style( + "Clean document when import form notion document deleted end :: {} latency: {}".format( + dataset_id, end_at - start_at + ), + fg="green", + ) + ) except Exception: logging.exception("Cleaned document when import form notion document deleted failed") diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py index d31286e4cc..26375743b6 100644 --- a/api/tasks/create_segment_to_index_task.py +++ b/api/tasks/create_segment_to_index_task.py @@ -14,7 +14,7 @@ from extensions.ext_redis import redis_client from models.dataset import DocumentSegment -@shared_task(queue='dataset') +@shared_task(queue="dataset") def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] = None): """ Async create segment to index @@ -22,23 +22,23 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] :param keywords: Usage: create_segment_to_index_task.delay(segment_id) """ - logging.info(click.style('Start create segment to index: {}'.format(segment_id), fg='green')) + logging.info(click.style("Start create segment to index: {}".format(segment_id), fg="green")) start_at = time.perf_counter() segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() if not segment: - raise NotFound('Segment not found') + raise NotFound("Segment not found") - if segment.status != 'waiting': + if segment.status != "waiting": return - indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + indexing_cache_key = "segment_{}_indexing".format(segment.id) try: # update segment status to indexing update_params = { DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), } DocumentSegment.query.filter_by(id=segment.id).update(update_params) db.session.commit() @@ -49,23 +49,23 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) dataset = segment.dataset if not dataset: - logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan')) + logging.info(click.style("Segment {} has no dataset, pass.".format(segment.id), fg="cyan")) return dataset_document = segment.document if not dataset_document: - logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan')) + logging.info(click.style("Segment {} has no document, pass.".format(segment.id), fg="cyan")) return - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': - logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan')) + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logging.info(click.style("Segment {} document status is invalid, pass.".format(segment.id), fg="cyan")) return index_type = dataset.doc_form @@ -75,18 +75,20 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] # update segment to completed update_params = { DocumentSegment.status: "completed", - DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), } DocumentSegment.query.filter_by(id=segment.id).update(update_params) db.session.commit() end_at = time.perf_counter() - logging.info(click.style('Segment created to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) + logging.info( + click.style("Segment created to index: {} latency: {}".format(segment.id, end_at - start_at), fg="green") + ) except Exception as e: logging.exception("create segment to index failed") segment.enabled = False segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - segment.status = 'error' + segment.status = "error" segment.error = str(e) db.session.commit() finally: diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index ce93e111e5..cfc54920e2 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -11,7 +11,7 @@ from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument -@shared_task(queue='dataset') +@shared_task(queue="dataset") def deal_dataset_vector_index_task(dataset_id: str, action: str): """ Async deal dataset from index @@ -19,41 +19,46 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): :param action: action Usage: deal_dataset_vector_index_task.delay(dataset_id, action) """ - logging.info(click.style('Start deal dataset vector index: {}'.format(dataset_id), fg='green')) + logging.info(click.style("Start deal dataset vector index: {}".format(dataset_id), fg="green")) start_at = time.perf_counter() try: - dataset = Dataset.query.filter_by( - id=dataset_id - ).first() + dataset = Dataset.query.filter_by(id=dataset_id).first() if not dataset: - raise Exception('Dataset not found') + raise Exception("Dataset not found") index_type = dataset.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() if action == "remove": index_processor.clean(dataset, None, with_keywords=False) elif action == "add": - dataset_documents = db.session.query(DatasetDocument).filter( - DatasetDocument.dataset_id == dataset_id, - DatasetDocument.indexing_status == 'completed', - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).all() + dataset_documents = ( + db.session.query(DatasetDocument) + .filter( + DatasetDocument.dataset_id == dataset_id, + DatasetDocument.indexing_status == "completed", + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + .all() + ) if dataset_documents: dataset_documents_ids = [doc.id for doc in dataset_documents] - db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)) \ - .update({"indexing_status": "indexing"}, synchronize_session=False) + db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update( + {"indexing_status": "indexing"}, synchronize_session=False + ) db.session.commit() for dataset_document in dataset_documents: try: # add from vector index - segments = db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.enabled == True - ) .order_by(DocumentSegment.position.asc()).all() + segments = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) + .order_by(DocumentSegment.position.asc()) + .all() + ) if segments: documents = [] for segment in segments: @@ -64,32 +69,39 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) documents.append(document) # save vector index index_processor.load(dataset, documents, with_keywords=False) - db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id) \ - .update({"indexing_status": "completed"}, synchronize_session=False) + db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "completed"}, synchronize_session=False + ) db.session.commit() except Exception as e: - db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id) \ - .update({"indexing_status": "error", "error": str(e)}, synchronize_session=False) + db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "error", "error": str(e)}, synchronize_session=False + ) db.session.commit() - elif action == 'update': - dataset_documents = db.session.query(DatasetDocument).filter( - DatasetDocument.dataset_id == dataset_id, - DatasetDocument.indexing_status == 'completed', - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).all() + elif action == "update": + dataset_documents = ( + db.session.query(DatasetDocument) + .filter( + DatasetDocument.dataset_id == dataset_id, + DatasetDocument.indexing_status == "completed", + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + .all() + ) # add new index if dataset_documents: # update document status dataset_documents_ids = [doc.id for doc in dataset_documents] - db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)) \ - .update({"indexing_status": "indexing"}, synchronize_session=False) + db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update( + {"indexing_status": "indexing"}, synchronize_session=False + ) db.session.commit() # clean index @@ -98,10 +110,12 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): for dataset_document in dataset_documents: # update from vector index try: - segments = db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.enabled == True - ).order_by(DocumentSegment.position.asc()).all() + segments = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) + .order_by(DocumentSegment.position.asc()) + .all() + ) if segments: documents = [] for segment in segments: @@ -112,23 +126,25 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) documents.append(document) # save vector index index_processor.load(dataset, documents, with_keywords=False) - db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id) \ - .update({"indexing_status": "completed"}, synchronize_session=False) + db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "completed"}, synchronize_session=False + ) db.session.commit() except Exception as e: - db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id) \ - .update({"indexing_status": "error", "error": str(e)}, synchronize_session=False) + db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "error", "error": str(e)}, synchronize_session=False + ) db.session.commit() - end_at = time.perf_counter() logging.info( - click.style('Deal dataset vector index: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) + click.style("Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at), fg="green") + ) except Exception: logging.exception("Deal dataset vector index failed") diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index d79286cf3d..c3e0ea5d9f 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -10,7 +10,7 @@ from extensions.ext_redis import redis_client from models.dataset import Dataset, Document -@shared_task(queue='dataset') +@shared_task(queue="dataset") def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_id: str, document_id: str): """ Async Remove segment from index @@ -21,22 +21,22 @@ def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_ Usage: delete_segment_from_index_task.delay(segment_id) """ - logging.info(click.style('Start delete segment from index: {}'.format(segment_id), fg='green')) + logging.info(click.style("Start delete segment from index: {}".format(segment_id), fg="green")) start_at = time.perf_counter() - indexing_cache_key = 'segment_{}_delete_indexing'.format(segment_id) + indexing_cache_key = "segment_{}_delete_indexing".format(segment_id) try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: - logging.info(click.style('Segment {} has no dataset, pass.'.format(segment_id), fg='cyan')) + logging.info(click.style("Segment {} has no dataset, pass.".format(segment_id), fg="cyan")) return dataset_document = db.session.query(Document).filter(Document.id == document_id).first() if not dataset_document: - logging.info(click.style('Segment {} has no document, pass.'.format(segment_id), fg='cyan')) + logging.info(click.style("Segment {} has no document, pass.".format(segment_id), fg="cyan")) return - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': - logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment_id), fg='cyan')) + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logging.info(click.style("Segment {} document status is invalid, pass.".format(segment_id), fg="cyan")) return index_type = dataset_document.doc_form @@ -44,7 +44,9 @@ def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_ index_processor.clean(dataset, [index_node_id]) end_at = time.perf_counter() - logging.info(click.style('Segment deleted from index: {} latency: {}'.format(segment_id, end_at - start_at), fg='green')) + logging.info( + click.style("Segment deleted from index: {} latency: {}".format(segment_id, end_at - start_at), fg="green") + ) except Exception: logging.exception("delete segment from index failed") finally: diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py index 4788bf4e4b..15e1e50076 100644 --- a/api/tasks/disable_segment_from_index_task.py +++ b/api/tasks/disable_segment_from_index_task.py @@ -11,7 +11,7 @@ from extensions.ext_redis import redis_client from models.dataset import DocumentSegment -@shared_task(queue='dataset') +@shared_task(queue="dataset") def disable_segment_from_index_task(segment_id: str): """ Async disable segment from index @@ -19,33 +19,33 @@ def disable_segment_from_index_task(segment_id: str): Usage: disable_segment_from_index_task.delay(segment_id) """ - logging.info(click.style('Start disable segment from index: {}'.format(segment_id), fg='green')) + logging.info(click.style("Start disable segment from index: {}".format(segment_id), fg="green")) start_at = time.perf_counter() segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() if not segment: - raise NotFound('Segment not found') + raise NotFound("Segment not found") - if segment.status != 'completed': - raise NotFound('Segment is not completed , disable action is not allowed.') + if segment.status != "completed": + raise NotFound("Segment is not completed , disable action is not allowed.") - indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + indexing_cache_key = "segment_{}_indexing".format(segment.id) try: dataset = segment.dataset if not dataset: - logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan')) + logging.info(click.style("Segment {} has no dataset, pass.".format(segment.id), fg="cyan")) return dataset_document = segment.document if not dataset_document: - logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan')) + logging.info(click.style("Segment {} has no document, pass.".format(segment.id), fg="cyan")) return - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': - logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan')) + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logging.info(click.style("Segment {} document status is invalid, pass.".format(segment.id), fg="cyan")) return index_type = dataset_document.doc_form @@ -53,7 +53,9 @@ def disable_segment_from_index_task(segment_id: str): index_processor.clean(dataset, [segment.index_node_id]) end_at = time.perf_counter() - logging.info(click.style('Segment removed from index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) + logging.info( + click.style("Segment removed from index: {} latency: {}".format(segment.id, end_at - start_at), fg="green") + ) except Exception: logging.exception("remove segment from index failed") segment.enabled = True diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 4cced36ecd..9ea4c99649 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -14,7 +14,7 @@ from models.dataset import Dataset, Document, DocumentSegment from models.source import DataSourceOauthBinding -@shared_task(queue='dataset') +@shared_task(queue="dataset") def document_indexing_sync_task(dataset_id: str, document_id: str): """ Async update document @@ -23,50 +23,50 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): Usage: document_indexing_sync_task.delay(dataset_id, document_id) """ - logging.info(click.style('Start sync document: {}'.format(document_id), fg='green')) + logging.info(click.style("Start sync document: {}".format(document_id), fg="green")) start_at = time.perf_counter() - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() if not document: - raise NotFound('Document not found') + raise NotFound("Document not found") data_source_info = document.data_source_info_dict - if document.data_source_type == 'notion_import': - if not data_source_info or 'notion_page_id' not in data_source_info \ - or 'notion_workspace_id' not in data_source_info: + if document.data_source_type == "notion_import": + if ( + not data_source_info + or "notion_page_id" not in data_source_info + or "notion_workspace_id" not in data_source_info + ): raise ValueError("no notion page found") - workspace_id = data_source_info['notion_workspace_id'] - page_id = data_source_info['notion_page_id'] - page_type = data_source_info['type'] - page_edited_time = data_source_info['last_edited_time'] + workspace_id = data_source_info["notion_workspace_id"] + page_id = data_source_info["notion_page_id"] + page_type = data_source_info["type"] + page_edited_time = data_source_info["last_edited_time"] data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == document.tenant_id, - DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"' + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', ) ).first() if not data_source_binding: - raise ValueError('Data source binding not found.') + raise ValueError("Data source binding not found.") loader = NotionExtractor( notion_workspace_id=workspace_id, notion_obj_id=page_id, notion_page_type=page_type, notion_access_token=data_source_binding.access_token, - tenant_id=document.tenant_id + tenant_id=document.tenant_id, ) last_edited_time = loader.get_notion_last_edited_time() # check the page is updated if last_edited_time != page_edited_time: - document.indexing_status = 'parsing' + document.indexing_status = "parsing" document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() @@ -74,7 +74,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: - raise Exception('Dataset not found') + raise Exception("Dataset not found") index_type = document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() @@ -89,7 +89,13 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): end_at = time.perf_counter() logging.info( - click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green')) + click.style( + "Cleaned document when document update data source or process rule: {} latency: {}".format( + document_id, end_at - start_at + ), + fg="green", + ) + ) except Exception: logging.exception("Cleaned document when document update data source or process rule failed") @@ -97,8 +103,10 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): indexing_runner = IndexingRunner() indexing_runner.run([document]) end_at = time.perf_counter() - logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) + logging.info( + click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green") + ) except DocumentIsPausedException as ex: - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index cc93a1341e..e0da5f9ed0 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -12,7 +12,7 @@ from models.dataset import Dataset, Document from services.feature_service import FeatureService -@shared_task(queue='dataset') +@shared_task(queue="dataset") def document_indexing_task(dataset_id: str, document_ids: list): """ Async process document @@ -36,16 +36,17 @@ def document_indexing_task(dataset_id: str, document_ids: list): if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") if 0 < vector_space.limit <= vector_space.size: - raise ValueError("Your total number of documents plus the number of uploads have over the limit of " - "your subscription.") + raise ValueError( + "Your total number of documents plus the number of uploads have over the limit of " + "your subscription." + ) except Exception as e: for document_id in document_ids: - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) if document: - document.indexing_status = 'error' + document.indexing_status = "error" document.error = str(e) document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) @@ -53,15 +54,14 @@ def document_indexing_task(dataset_id: str, document_ids: list): return for document_id in document_ids: - logging.info(click.style('Start process document: {}'.format(document_id), fg='green')) + logging.info(click.style("Start process document: {}".format(document_id), fg="green")) - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) if document: - document.indexing_status = 'parsing' + document.indexing_status = "parsing" document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) documents.append(document) db.session.add(document) @@ -71,8 +71,8 @@ def document_indexing_task(dataset_id: str, document_ids: list): indexing_runner = IndexingRunner() indexing_runner.run(documents) end_at = time.perf_counter() - logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) + logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) except DocumentIsPausedException as ex: - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index f129d93de8..6e681bcf4f 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -12,7 +12,7 @@ from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment -@shared_task(queue='dataset') +@shared_task(queue="dataset") def document_indexing_update_task(dataset_id: str, document_id: str): """ Async update document @@ -21,18 +21,15 @@ def document_indexing_update_task(dataset_id: str, document_id: str): Usage: document_indexing_update_task.delay(dataset_id, document_id) """ - logging.info(click.style('Start update document: {}'.format(document_id), fg='green')) + logging.info(click.style("Start update document: {}".format(document_id), fg="green")) start_at = time.perf_counter() - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() if not document: - raise NotFound('Document not found') + raise NotFound("Document not found") - document.indexing_status = 'parsing' + document.indexing_status = "parsing" document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() @@ -40,7 +37,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: - raise Exception('Dataset not found') + raise Exception("Dataset not found") index_type = document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() @@ -57,7 +54,13 @@ def document_indexing_update_task(dataset_id: str, document_id: str): db.session.commit() end_at = time.perf_counter() logging.info( - click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green')) + click.style( + "Cleaned document when document update data source or process rule: {} latency: {}".format( + document_id, end_at - start_at + ), + fg="green", + ) + ) except Exception: logging.exception("Cleaned document when document update data source or process rule failed") @@ -65,8 +68,8 @@ def document_indexing_update_task(dataset_id: str, document_id: str): indexing_runner = IndexingRunner() indexing_runner.run([document]) end_at = time.perf_counter() - logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) + logging.info(click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green")) except DocumentIsPausedException as ex: - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index 884e222d1b..0a7568c385 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -13,7 +13,7 @@ from models.dataset import Dataset, Document, DocumentSegment from services.feature_service import FeatureService -@shared_task(queue='dataset') +@shared_task(queue="dataset") def duplicate_document_indexing_task(dataset_id: str, document_ids: list): """ Async process document @@ -37,16 +37,17 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") if 0 < vector_space.limit <= vector_space.size: - raise ValueError("Your total number of documents plus the number of uploads have over the limit of " - "your subscription.") + raise ValueError( + "Your total number of documents plus the number of uploads have over the limit of " + "your subscription." + ) except Exception as e: for document_id in document_ids: - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) if document: - document.indexing_status = 'error' + document.indexing_status = "error" document.error = str(e) document.stopped_at = datetime.datetime.utcnow() db.session.add(document) @@ -54,12 +55,11 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): return for document_id in document_ids: - logging.info(click.style('Start process document: {}'.format(document_id), fg='green')) + logging.info(click.style("Start process document: {}".format(document_id), fg="green")) - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) if document: # clean old data @@ -77,7 +77,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): db.session.delete(segment) db.session.commit() - document.indexing_status = 'parsing' + document.indexing_status = "parsing" document.processing_started_at = datetime.datetime.utcnow() documents.append(document) db.session.add(document) @@ -87,8 +87,8 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): indexing_runner = IndexingRunner() indexing_runner.run(documents) end_at = time.perf_counter() - logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) + logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) except DocumentIsPausedException as ex: - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index e37c06855d..1412ad9ec7 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -13,7 +13,7 @@ from extensions.ext_redis import redis_client from models.dataset import DocumentSegment -@shared_task(queue='dataset') +@shared_task(queue="dataset") def enable_segment_to_index_task(segment_id: str): """ Async enable segment to index @@ -21,17 +21,17 @@ def enable_segment_to_index_task(segment_id: str): Usage: enable_segment_to_index_task.delay(segment_id) """ - logging.info(click.style('Start enable segment to index: {}'.format(segment_id), fg='green')) + logging.info(click.style("Start enable segment to index: {}".format(segment_id), fg="green")) start_at = time.perf_counter() segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() if not segment: - raise NotFound('Segment not found') + raise NotFound("Segment not found") - if segment.status != 'completed': - raise NotFound('Segment is not completed, enable action is not allowed.') + if segment.status != "completed": + raise NotFound("Segment is not completed, enable action is not allowed.") - indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + indexing_cache_key = "segment_{}_indexing".format(segment.id) try: document = Document( @@ -41,23 +41,23 @@ def enable_segment_to_index_task(segment_id: str): "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) dataset = segment.dataset if not dataset: - logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan')) + logging.info(click.style("Segment {} has no dataset, pass.".format(segment.id), fg="cyan")) return dataset_document = segment.document if not dataset_document: - logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan')) + logging.info(click.style("Segment {} has no document, pass.".format(segment.id), fg="cyan")) return - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': - logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan')) + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logging.info(click.style("Segment {} document status is invalid, pass.".format(segment.id), fg="cyan")) return index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() @@ -65,12 +65,14 @@ def enable_segment_to_index_task(segment_id: str): index_processor.load(dataset, [document]) end_at = time.perf_counter() - logging.info(click.style('Segment enabled to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) + logging.info( + click.style("Segment enabled to index: {} latency: {}".format(segment.id, end_at - start_at), fg="green") + ) except Exception as e: logging.exception("enable segment to index failed") segment.enabled = False segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - segment.status = 'error' + segment.status = "error" segment.error = str(e) db.session.commit() finally: diff --git a/api/tasks/mail_invite_member_task.py b/api/tasks/mail_invite_member_task.py index a46eafa797..c7dfb9bf60 100644 --- a/api/tasks/mail_invite_member_task.py +++ b/api/tasks/mail_invite_member_task.py @@ -9,7 +9,7 @@ from configs import dify_config from extensions.ext_mail import mail -@shared_task(queue='mail') +@shared_task(queue="mail") def send_invite_member_mail_task(language: str, to: str, token: str, inviter_name: str, workspace_name: str): """ Async Send invite member mail @@ -19,36 +19,43 @@ def send_invite_member_mail_task(language: str, to: str, token: str, inviter_nam :param inviter_name :param workspace_name - Usage: send_invite_member_mail_task.delay(langauge, to, token, inviter_name, workspace_name) + Usage: send_invite_member_mail_task.delay(language, to, token, inviter_name, workspace_name) """ if not mail.is_inited(): return - logging.info(click.style('Start send invite member mail to {} in workspace {}'.format(to, workspace_name), - fg='green')) + logging.info( + click.style("Start send invite member mail to {} in workspace {}".format(to, workspace_name), fg="green") + ) start_at = time.perf_counter() # send invite member mail using different languages try: - url = f'{dify_config.CONSOLE_WEB_URL}/activate?token={token}' - if language == 'zh-Hans': - html_content = render_template('invite_member_mail_template_zh-CN.html', - to=to, - inviter_name=inviter_name, - workspace_name=workspace_name, - url=url) + url = f"{dify_config.CONSOLE_WEB_URL}/activate?token={token}" + if language == "zh-Hans": + html_content = render_template( + "invite_member_mail_template_zh-CN.html", + to=to, + inviter_name=inviter_name, + workspace_name=workspace_name, + url=url, + ) mail.send(to=to, subject="立即加入 Dify 工作空间", html=html_content) else: - html_content = render_template('invite_member_mail_template_en-US.html', - to=to, - inviter_name=inviter_name, - workspace_name=workspace_name, - url=url) + html_content = render_template( + "invite_member_mail_template_en-US.html", + to=to, + inviter_name=inviter_name, + workspace_name=workspace_name, + url=url, + ) mail.send(to=to, subject="Join Dify Workspace Now", html=html_content) end_at = time.perf_counter() logging.info( - click.style('Send invite member mail to {} succeeded: latency: {}'.format(to, end_at - start_at), - fg='green')) + click.style( + "Send invite member mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green" + ) + ) except Exception: - logging.exception("Send invite member mail to {} failed".format(to)) \ No newline at end of file + logging.exception("Send invite member mail to {} failed".format(to)) diff --git a/api/tasks/mail_reset_password_task.py b/api/tasks/mail_reset_password_task.py index 4e1b8a8913..cbb78976ca 100644 --- a/api/tasks/mail_reset_password_task.py +++ b/api/tasks/mail_reset_password_task.py @@ -9,7 +9,7 @@ from configs import dify_config from extensions.ext_mail import mail -@shared_task(queue='mail') +@shared_task(queue="mail") def send_reset_password_mail_task(language: str, to: str, token: str): """ Async Send reset password mail @@ -20,26 +20,24 @@ def send_reset_password_mail_task(language: str, to: str, token: str): if not mail.is_inited(): return - logging.info(click.style('Start password reset mail to {}'.format(to), fg='green')) + logging.info(click.style("Start password reset mail to {}".format(to), fg="green")) start_at = time.perf_counter() # send reset password mail using different languages try: - url = f'{dify_config.CONSOLE_WEB_URL}/forgot-password?token={token}' - if language == 'zh-Hans': - html_content = render_template('reset_password_mail_template_zh-CN.html', - to=to, - url=url) + url = f"{dify_config.CONSOLE_WEB_URL}/forgot-password?token={token}" + if language == "zh-Hans": + html_content = render_template("reset_password_mail_template_zh-CN.html", to=to, url=url) mail.send(to=to, subject="重置您的 Dify 密码", html=html_content) else: - html_content = render_template('reset_password_mail_template_en-US.html', - to=to, - url=url) + html_content = render_template("reset_password_mail_template_en-US.html", to=to, url=url) mail.send(to=to, subject="Reset Your Dify Password", html=html_content) end_at = time.perf_counter() logging.info( - click.style('Send password reset mail to {} succeeded: latency: {}'.format(to, end_at - start_at), - fg='green')) + click.style( + "Send password reset mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green" + ) + ) except Exception: logging.exception("Send password reset mail to {} failed".format(to)) diff --git a/api/tasks/ops_trace_task.py b/api/tasks/ops_trace_task.py index 6b4cab55b3..260069c6e2 100644 --- a/api/tasks/ops_trace_task.py +++ b/api/tasks/ops_trace_task.py @@ -10,7 +10,7 @@ from models.model import Message from models.workflow import WorkflowRun -@shared_task(queue='ops_trace') +@shared_task(queue="ops_trace") def process_trace_tasks(tasks_data): """ Async process trace tasks @@ -20,17 +20,17 @@ def process_trace_tasks(tasks_data): """ from core.ops.ops_trace_manager import OpsTraceManager - trace_info = tasks_data.get('trace_info') - app_id = tasks_data.get('app_id') - trace_info_type = tasks_data.get('trace_info_type') + trace_info = tasks_data.get("trace_info") + app_id = tasks_data.get("app_id") + trace_info_type = tasks_data.get("trace_info_type") trace_instance = OpsTraceManager.get_ops_trace_instance(app_id) - if trace_info.get('message_data'): - trace_info['message_data'] = Message.from_dict(data=trace_info['message_data']) - if trace_info.get('workflow_data'): - trace_info['workflow_data'] = WorkflowRun.from_dict(data=trace_info['workflow_data']) - if trace_info.get('documents'): - trace_info['documents'] = [Document(**doc) for doc in trace_info['documents']] + if trace_info.get("message_data"): + trace_info["message_data"] = Message.from_dict(data=trace_info["message_data"]) + if trace_info.get("workflow_data"): + trace_info["workflow_data"] = WorkflowRun.from_dict(data=trace_info["workflow_data"]) + if trace_info.get("documents"): + trace_info["documents"] = [Document(**doc) for doc in trace_info["documents"]] try: if trace_instance: diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py index 02278f512b..18bae14ffa 100644 --- a/api/tasks/recover_document_indexing_task.py +++ b/api/tasks/recover_document_indexing_task.py @@ -10,7 +10,7 @@ from extensions.ext_database import db from models.dataset import Document -@shared_task(queue='dataset') +@shared_task(queue="dataset") def recover_document_indexing_task(dataset_id: str, document_id: str): """ Async recover document @@ -19,16 +19,13 @@ def recover_document_indexing_task(dataset_id: str, document_id: str): Usage: recover_document_indexing_task.delay(dataset_id, document_id) """ - logging.info(click.style('Recover document: {}'.format(document_id), fg='green')) + logging.info(click.style("Recover document: {}".format(document_id), fg="green")) start_at = time.perf_counter() - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() if not document: - raise NotFound('Document not found') + raise NotFound("Document not found") try: indexing_runner = IndexingRunner() @@ -39,8 +36,10 @@ def recover_document_indexing_task(dataset_id: str, document_id: str): elif document.indexing_status == "indexing": indexing_runner.run_in_indexing_status(document) end_at = time.perf_counter() - logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) + logging.info( + click.style("Processed document: {} latency: {}".format(document.id, end_at - start_at), fg="green") + ) except DocumentIsPausedException as ex: - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 4efe7ee38c..66f78636ec 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -33,9 +33,9 @@ from models.web import PinnedConversation, SavedMessage from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecution, WorkflowRun -@shared_task(queue='app_deletion', bind=True, max_retries=3) +@shared_task(queue="app_deletion", bind=True, max_retries=3) def remove_app_and_related_data_task(self, tenant_id: str, app_id: str): - logging.info(click.style(f'Start deleting app and related data: {tenant_id}:{app_id}', fg='green')) + logging.info(click.style(f"Start deleting app and related data: {tenant_id}:{app_id}", fg="green")) start_at = time.perf_counter() try: # Delete related data @@ -59,13 +59,14 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str): _delete_conversation_variables(app_id=app_id) end_at = time.perf_counter() - logging.info(click.style(f'App and related data deleted: {app_id} latency: {end_at - start_at}', fg='green')) + logging.info(click.style(f"App and related data deleted: {app_id} latency: {end_at - start_at}", fg="green")) except SQLAlchemyError as e: logging.exception( - click.style(f"Database error occurred while deleting app {app_id} and related data", fg='red')) + click.style(f"Database error occurred while deleting app {app_id} and related data", fg="red") + ) raise self.retry(exc=e, countdown=60) # Retry after 60 seconds except Exception as e: - logging.exception(click.style(f"Error occurred while deleting app {app_id} and related data", fg='red')) + logging.exception(click.style(f"Error occurred while deleting app {app_id} and related data", fg="red")) raise self.retry(exc=e, countdown=60) # Retry after 60 seconds @@ -77,7 +78,7 @@ def _delete_app_model_configs(tenant_id: str, app_id: str): """select id from app_model_configs where app_id=:app_id limit 1000""", {"app_id": app_id}, del_model_config, - "app model config" + "app model config", ) @@ -85,12 +86,7 @@ def _delete_app_site(tenant_id: str, app_id: str): def del_site(site_id: str): db.session.query(Site).filter(Site.id == site_id).delete(synchronize_session=False) - _delete_records( - """select id from sites where app_id=:app_id limit 1000""", - {"app_id": app_id}, - del_site, - "site" - ) + _delete_records("""select id from sites where app_id=:app_id limit 1000""", {"app_id": app_id}, del_site, "site") def _delete_app_api_tokens(tenant_id: str, app_id: str): @@ -98,10 +94,7 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str): db.session.query(ApiToken).filter(ApiToken.id == api_token_id).delete(synchronize_session=False) _delete_records( - """select id from api_tokens where app_id=:app_id limit 1000""", - {"app_id": app_id}, - del_api_token, - "api token" + """select id from api_tokens where app_id=:app_id limit 1000""", {"app_id": app_id}, del_api_token, "api token" ) @@ -113,44 +106,47 @@ def _delete_installed_apps(tenant_id: str, app_id: str): """select id from installed_apps where tenant_id=:tenant_id and app_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_installed_app, - "installed app" + "installed app", ) def _delete_recommended_apps(tenant_id: str, app_id: str): def del_recommended_app(recommended_app_id: str): db.session.query(RecommendedApp).filter(RecommendedApp.id == recommended_app_id).delete( - synchronize_session=False) + synchronize_session=False + ) _delete_records( """select id from recommended_apps where app_id=:app_id limit 1000""", {"app_id": app_id}, del_recommended_app, - "recommended app" + "recommended app", ) def _delete_app_annotation_data(tenant_id: str, app_id: str): def del_annotation_hit_history(annotation_hit_history_id: str): db.session.query(AppAnnotationHitHistory).filter( - AppAnnotationHitHistory.id == annotation_hit_history_id).delete(synchronize_session=False) + AppAnnotationHitHistory.id == annotation_hit_history_id + ).delete(synchronize_session=False) _delete_records( """select id from app_annotation_hit_histories where app_id=:app_id limit 1000""", {"app_id": app_id}, del_annotation_hit_history, - "annotation hit history" + "annotation hit history", ) def del_annotation_setting(annotation_setting_id: str): db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.id == annotation_setting_id).delete( - synchronize_session=False) + synchronize_session=False + ) _delete_records( """select id from app_annotation_settings where app_id=:app_id limit 1000""", {"app_id": app_id}, del_annotation_setting, - "annotation setting" + "annotation setting", ) @@ -162,7 +158,7 @@ def _delete_app_dataset_joins(tenant_id: str, app_id: str): """select id from app_dataset_joins where app_id=:app_id limit 1000""", {"app_id": app_id}, del_dataset_join, - "dataset join" + "dataset join", ) @@ -174,7 +170,7 @@ def _delete_app_workflows(tenant_id: str, app_id: str): """select id from workflows where tenant_id=:tenant_id and app_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_workflow, - "workflow" + "workflow", ) @@ -186,89 +182,93 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str): """select id from workflow_runs where tenant_id=:tenant_id and app_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_workflow_run, - "workflow run" + "workflow run", ) def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): def del_workflow_node_execution(workflow_node_execution_id: str): - db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == workflow_node_execution_id).delete(synchronize_session=False) + db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).delete( + synchronize_session=False + ) _delete_records( """select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_workflow_node_execution, - "workflow node execution" + "workflow node execution", ) def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): def del_workflow_app_log(workflow_app_log_id: str): - db.session.query(WorkflowAppLog).filter(WorkflowAppLog.id == workflow_app_log_id).delete(synchronize_session=False) + db.session.query(WorkflowAppLog).filter(WorkflowAppLog.id == workflow_app_log_id).delete( + synchronize_session=False + ) _delete_records( """select id from workflow_app_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_workflow_app_log, - "workflow app log" + "workflow app log", ) def _delete_app_conversations(tenant_id: str, app_id: str): def del_conversation(conversation_id: str): db.session.query(PinnedConversation).filter(PinnedConversation.conversation_id == conversation_id).delete( - synchronize_session=False) + synchronize_session=False + ) db.session.query(Conversation).filter(Conversation.id == conversation_id).delete(synchronize_session=False) _delete_records( """select id from conversations where app_id=:app_id limit 1000""", {"app_id": app_id}, del_conversation, - "conversation" + "conversation", ) + def _delete_conversation_variables(*, app_id: str): stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id) with db.engine.connect() as conn: conn.execute(stmt) conn.commit() - logging.info(click.style(f"Deleted conversation variables for app {app_id}", fg='green')) + logging.info(click.style(f"Deleted conversation variables for app {app_id}", fg="green")) def _delete_app_messages(tenant_id: str, app_id: str): def del_message(message_id: str): db.session.query(MessageFeedback).filter(MessageFeedback.message_id == message_id).delete( - synchronize_session=False) + synchronize_session=False + ) db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id == message_id).delete( - synchronize_session=False) - db.session.query(MessageChain).filter(MessageChain.message_id == message_id).delete( - synchronize_session=False) + synchronize_session=False + ) + db.session.query(MessageChain).filter(MessageChain.message_id == message_id).delete(synchronize_session=False) db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message_id).delete( - synchronize_session=False) + synchronize_session=False + ) db.session.query(MessageFile).filter(MessageFile.message_id == message_id).delete(synchronize_session=False) - db.session.query(SavedMessage).filter(SavedMessage.message_id == message_id).delete( - synchronize_session=False) + db.session.query(SavedMessage).filter(SavedMessage.message_id == message_id).delete(synchronize_session=False) db.session.query(Message).filter(Message.id == message_id).delete() _delete_records( - """select id from messages where app_id=:app_id limit 1000""", - {"app_id": app_id}, - del_message, - "message" + """select id from messages where app_id=:app_id limit 1000""", {"app_id": app_id}, del_message, "message" ) def _delete_workflow_tool_providers(tenant_id: str, app_id: str): def del_tool_provider(tool_provider_id: str): db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.id == tool_provider_id).delete( - synchronize_session=False) + synchronize_session=False + ) _delete_records( """select id from tool_workflow_providers where tenant_id=:tenant_id and app_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_tool_provider, - "tool workflow provider" + "tool workflow provider", ) @@ -280,7 +280,7 @@ def _delete_app_tag_bindings(tenant_id: str, app_id: str): """select id from tag_bindings where tenant_id=:tenant_id and target_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_tag_binding, - "tag binding" + "tag binding", ) @@ -292,20 +292,21 @@ def _delete_end_users(tenant_id: str, app_id: str): """select id from end_users where tenant_id=:tenant_id and app_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_end_user, - "end user" + "end user", ) def _delete_trace_app_configs(tenant_id: str, app_id: str): def del_trace_app_config(trace_app_config_id: str): db.session.query(TraceAppConfig).filter(TraceAppConfig.id == trace_app_config_id).delete( - synchronize_session=False) + synchronize_session=False + ) _delete_records( """select id from trace_app_config where app_id=:app_id limit 1000""", {"app_id": app_id}, del_trace_app_config, - "trace app config" + "trace app config", ) @@ -321,7 +322,7 @@ def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: s try: delete_func(record_id) db.session.commit() - logging.info(click.style(f"Deleted {name} {record_id}", fg='green')) + logging.info(click.style(f"Deleted {name} {record_id}", fg="green")) except Exception: logging.exception(f"Error occurred while deleting {name} {record_id}") continue diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py index cff8dddc53..1909eaf341 100644 --- a/api/tasks/remove_document_from_index_task.py +++ b/api/tasks/remove_document_from_index_task.py @@ -11,7 +11,7 @@ from extensions.ext_redis import redis_client from models.dataset import Document, DocumentSegment -@shared_task(queue='dataset') +@shared_task(queue="dataset") def remove_document_from_index_task(document_id: str): """ Async Remove document from index @@ -19,23 +19,23 @@ def remove_document_from_index_task(document_id: str): Usage: remove_document_from_index.delay(document_id) """ - logging.info(click.style('Start remove document segments from index: {}'.format(document_id), fg='green')) + logging.info(click.style("Start remove document segments from index: {}".format(document_id), fg="green")) start_at = time.perf_counter() document = db.session.query(Document).filter(Document.id == document_id).first() if not document: - raise NotFound('Document not found') + raise NotFound("Document not found") - if document.indexing_status != 'completed': + if document.indexing_status != "completed": return - indexing_cache_key = 'document_{}_indexing'.format(document.id) + indexing_cache_key = "document_{}_indexing".format(document.id) try: dataset = document.dataset if not dataset: - raise Exception('Document has no dataset') + raise Exception("Document has no dataset") index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() @@ -49,7 +49,10 @@ def remove_document_from_index_task(document_id: str): end_at = time.perf_counter() logging.info( - click.style('Document removed from index: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) + click.style( + "Document removed from index: {} latency: {}".format(document.id, end_at - start_at), fg="green" + ) + ) except Exception: logging.exception("remove document from index failed") if not document.archived: diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py index 1114809b30..73471fd6e7 100644 --- a/api/tasks/retry_document_indexing_task.py +++ b/api/tasks/retry_document_indexing_task.py @@ -13,7 +13,7 @@ from models.dataset import Dataset, Document, DocumentSegment from services.feature_service import FeatureService -@shared_task(queue='dataset') +@shared_task(queue="dataset") def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): """ Async process document @@ -27,22 +27,23 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() for document_id in document_ids: - retry_indexing_cache_key = 'document_{}_is_retried'.format(document_id) + retry_indexing_cache_key = "document_{}_is_retried".format(document_id) # check document limit features = FeatureService.get_features(dataset.tenant_id) try: if features.billing.enabled: vector_space = features.vector_space if 0 < vector_space.limit <= vector_space.size: - raise ValueError("Your total number of documents plus the number of uploads have over the limit of " - "your subscription.") + raise ValueError( + "Your total number of documents plus the number of uploads have over the limit of " + "your subscription." + ) except Exception as e: - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) if document: - document.indexing_status = 'error' + document.indexing_status = "error" document.error = str(e) document.stopped_at = datetime.datetime.utcnow() db.session.add(document) @@ -50,11 +51,10 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): redis_client.delete(retry_indexing_cache_key) return - logging.info(click.style('Start retry document: {}'.format(document_id), fg='green')) - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + logging.info(click.style("Start retry document: {}".format(document_id), fg="green")) + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) try: if document: # clean old data @@ -70,7 +70,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): db.session.delete(segment) db.session.commit() - document.indexing_status = 'parsing' + document.indexing_status = "parsing" document.processing_started_at = datetime.datetime.utcnow() db.session.add(document) db.session.commit() @@ -79,13 +79,13 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): indexing_runner.run([document]) redis_client.delete(retry_indexing_cache_key) except Exception as ex: - document.indexing_status = 'error' + document.indexing_status = "error" document.error = str(ex) document.stopped_at = datetime.datetime.utcnow() db.session.add(document) db.session.commit() - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style(str(ex), fg="yellow")) redis_client.delete(retry_indexing_cache_key) pass end_at = time.perf_counter() - logging.info(click.style('Retry dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) + logging.info(click.style("Retry dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py index 320da8718a..99fb66e1f3 100644 --- a/api/tasks/sync_website_document_indexing_task.py +++ b/api/tasks/sync_website_document_indexing_task.py @@ -13,7 +13,7 @@ from models.dataset import Dataset, Document, DocumentSegment from services.feature_service import FeatureService -@shared_task(queue='dataset') +@shared_task(queue="dataset") def sync_website_document_indexing_task(dataset_id: str, document_id: str): """ Async process document @@ -26,22 +26,23 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() - sync_indexing_cache_key = 'document_{}_is_sync'.format(document_id) + sync_indexing_cache_key = "document_{}_is_sync".format(document_id) # check document limit features = FeatureService.get_features(dataset.tenant_id) try: if features.billing.enabled: vector_space = features.vector_space if 0 < vector_space.limit <= vector_space.size: - raise ValueError("Your total number of documents plus the number of uploads have over the limit of " - "your subscription.") + raise ValueError( + "Your total number of documents plus the number of uploads have over the limit of " + "your subscription." + ) except Exception as e: - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) if document: - document.indexing_status = 'error' + document.indexing_status = "error" document.error = str(e) document.stopped_at = datetime.datetime.utcnow() db.session.add(document) @@ -49,11 +50,8 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): redis_client.delete(sync_indexing_cache_key) return - logging.info(click.style('Start sync website document: {}'.format(document_id), fg='green')) - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + logging.info(click.style("Start sync website document: {}".format(document_id), fg="green")) + document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() try: if document: # clean old data @@ -69,7 +67,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): db.session.delete(segment) db.session.commit() - document.indexing_status = 'parsing' + document.indexing_status = "parsing" document.processing_started_at = datetime.datetime.utcnow() db.session.add(document) db.session.commit() @@ -78,13 +76,13 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): indexing_runner.run([document]) redis_client.delete(sync_indexing_cache_key) except Exception as ex: - document.indexing_status = 'error' + document.indexing_status = "error" document.error = str(ex) document.stopped_at = datetime.datetime.utcnow() db.session.add(document) db.session.commit() - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style(str(ex), fg="yellow")) redis_client.delete(sync_indexing_cache_key) pass end_at = time.perf_counter() - logging.info(click.style('Sync document: {} latency: {}'.format(document_id, end_at - start_at), fg='green')) + logging.info(click.style("Sync document: {} latency: {}".format(document_id, end_at - start_at), fg="green")) diff --git a/api/tests/integration_tests/model_runtime/__mock/anthropic.py b/api/tests/integration_tests/model_runtime/__mock/anthropic.py index 3326f874b0..79a3dc0394 100644 --- a/api/tests/integration_tests/model_runtime/__mock/anthropic.py +++ b/api/tests/integration_tests/model_runtime/__mock/anthropic.py @@ -22,23 +22,20 @@ from anthropic.types import ( ) from anthropic.types.message_delta_event import Delta -MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true' +MOCK = os.getenv("MOCK_SWITCH", "false") == "true" class MockAnthropicClass: @staticmethod def mocked_anthropic_chat_create_sync(model: str) -> Message: return Message( - id='msg-123', - type='message', - role='assistant', - content=[ContentBlock(text='hello, I\'m a chatbot from anthropic', type='text')], + id="msg-123", + type="message", + role="assistant", + content=[ContentBlock(text="hello, I'm a chatbot from anthropic", type="text")], model=model, - stop_reason='stop_sequence', - usage=Usage( - input_tokens=1, - output_tokens=1 - ) + stop_reason="stop_sequence", + usage=Usage(input_tokens=1, output_tokens=1), ) @staticmethod @@ -46,52 +43,43 @@ class MockAnthropicClass: full_response_text = "hello, I'm a chatbot from anthropic" yield MessageStartEvent( - type='message_start', + type="message_start", message=Message( - id='msg-123', + id="msg-123", content=[], - role='assistant', + role="assistant", model=model, stop_reason=None, - type='message', - usage=Usage( - input_tokens=1, - output_tokens=1 - ) - ) + type="message", + usage=Usage(input_tokens=1, output_tokens=1), + ), ) index = 0 for i in range(0, len(full_response_text)): yield ContentBlockDeltaEvent( - type='content_block_delta', - delta=TextDelta(text=full_response_text[i], type='text_delta'), - index=index + type="content_block_delta", delta=TextDelta(text=full_response_text[i], type="text_delta"), index=index ) index += 1 yield MessageDeltaEvent( - type='message_delta', - delta=Delta( - stop_reason='stop_sequence' - ), - usage=MessageDeltaUsage( - output_tokens=1 - ) + type="message_delta", delta=Delta(stop_reason="stop_sequence"), usage=MessageDeltaUsage(output_tokens=1) ) - yield MessageStopEvent(type='message_stop') + yield MessageStopEvent(type="message_stop") - def mocked_anthropic(self: Messages, *, - max_tokens: int, - messages: Iterable[MessageParam], - model: str, - stream: Literal[True], - **kwargs: Any - ) -> Union[Message, Stream[MessageStreamEvent]]: + def mocked_anthropic( + self: Messages, + *, + max_tokens: int, + messages: Iterable[MessageParam], + model: str, + stream: Literal[True], + **kwargs: Any, + ) -> Union[Message, Stream[MessageStreamEvent]]: if len(self._client.api_key) < 18: - raise anthropic.AuthenticationError('Invalid API key') + raise anthropic.AuthenticationError("Invalid API key") if stream: return MockAnthropicClass.mocked_anthropic_chat_create_stream(model=model) @@ -102,7 +90,7 @@ class MockAnthropicClass: @pytest.fixture def setup_anthropic_mock(request, monkeypatch: MonkeyPatch): if MOCK: - monkeypatch.setattr(Messages, 'create', MockAnthropicClass.mocked_anthropic) + monkeypatch.setattr(Messages, "create", MockAnthropicClass.mocked_anthropic) yield diff --git a/api/tests/integration_tests/model_runtime/__mock/google.py b/api/tests/integration_tests/model_runtime/__mock/google.py index d838e9890f..bc0684086f 100644 --- a/api/tests/integration_tests/model_runtime/__mock/google.py +++ b/api/tests/integration_tests/model_runtime/__mock/google.py @@ -12,63 +12,46 @@ from google.generativeai.client import _ClientManager, configure from google.generativeai.types import GenerateContentResponse from google.generativeai.types.generation_types import BaseGenerateContentResponse -current_api_key = '' +current_api_key = "" + class MockGoogleResponseClass: _done = False def __iter__(self): - full_response_text = 'it\'s google!' + full_response_text = "it's google!" for i in range(0, len(full_response_text) + 1, 1): if i == len(full_response_text): self._done = True yield GenerateContentResponse( - done=True, - iterator=None, - result=glm.GenerateContentResponse({ - - }), - chunks=[] + done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[] ) else: yield GenerateContentResponse( - done=False, - iterator=None, - result=glm.GenerateContentResponse({ - - }), - chunks=[] + done=False, iterator=None, result=glm.GenerateContentResponse({}), chunks=[] ) + class MockGoogleResponseCandidateClass: - finish_reason = 'stop' + finish_reason = "stop" @property def content(self) -> gag_content.Content: - return gag_content.Content( - parts=[ - gag_content.Part(text='it\'s google!') - ] - ) + return gag_content.Content(parts=[gag_content.Part(text="it's google!")]) + class MockGoogleClass: @staticmethod def generate_content_sync() -> GenerateContentResponse: - return GenerateContentResponse( - done=True, - iterator=None, - result=glm.GenerateContentResponse({ - - }), - chunks=[] - ) + return GenerateContentResponse(done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]) @staticmethod def generate_content_stream() -> Generator[GenerateContentResponse, None, None]: return MockGoogleResponseClass() - def generate_content(self: GenerativeModel, + def generate_content( + self: GenerativeModel, contents: content_types.ContentsType, *, generation_config: generation_config_types.GenerationConfigType | None = None, @@ -79,21 +62,21 @@ class MockGoogleClass: global current_api_key if len(current_api_key) < 16: - raise Exception('Invalid API key') + raise Exception("Invalid API key") if stream: return MockGoogleClass.generate_content_stream() - + return MockGoogleClass.generate_content_sync() - + @property def generative_response_text(self) -> str: - return 'it\'s google!' - + return "it's google!" + @property def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]: return [MockGoogleResponseCandidateClass()] - + def make_client(self: _ClientManager, name: str): global current_api_key @@ -121,7 +104,8 @@ class MockGoogleClass: if not self.default_metadata: return client - + + @pytest.fixture def setup_google_mock(request, monkeypatch: MonkeyPatch): monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text) @@ -131,4 +115,4 @@ def setup_google_mock(request, monkeypatch: MonkeyPatch): yield - monkeypatch.undo() \ No newline at end of file + monkeypatch.undo() diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface.py b/api/tests/integration_tests/model_runtime/__mock/huggingface.py index a75b058d92..97038ef596 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface.py @@ -6,14 +6,15 @@ from huggingface_hub import InferenceClient from tests.integration_tests.model_runtime.__mock.huggingface_chat import MockHuggingfaceChatClass -MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + @pytest.fixture def setup_huggingface_mock(request, monkeypatch: MonkeyPatch): if MOCK: monkeypatch.setattr(InferenceClient, "text_generation", MockHuggingfaceChatClass.text_generation) - + yield if MOCK: - monkeypatch.undo() \ No newline at end of file + monkeypatch.undo() diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py b/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py index 1607624c3c..9ee76c935c 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py @@ -22,10 +22,8 @@ class MockHuggingfaceChatClass: details=Details( finish_reason="length", generated_tokens=6, - tokens=[ - Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6) - ] - ) + tokens=[Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6)], + ), ) return response @@ -36,26 +34,23 @@ class MockHuggingfaceChatClass: for i in range(0, len(full_text)): response = TextGenerationStreamResponse( - token = Token(id=i, text=full_text[i], logprob=0.0, special=False), + token=Token(id=i, text=full_text[i], logprob=0.0, special=False), ) response.generated_text = full_text[i] - response.details = StreamDetails(finish_reason='stop_sequence', generated_tokens=1) + response.details = StreamDetails(finish_reason="stop_sequence", generated_tokens=1) yield response - def text_generation(self: InferenceClient, prompt: str, *, - stream: Literal[False] = ..., - model: Optional[str] = None, - **kwargs: Any + def text_generation( + self: InferenceClient, prompt: str, *, stream: Literal[False] = ..., model: Optional[str] = None, **kwargs: Any ) -> Union[TextGenerationResponse, Generator[TextGenerationStreamResponse, None, None]]: # check if key is valid - if not re.match(r'Bearer\shf\-[a-zA-Z0-9]{16,}', self.headers['authorization']): - raise BadRequestError('Invalid API key') - + if not re.match(r"Bearer\shf\-[a-zA-Z0-9]{16,}", self.headers["authorization"]): + raise BadRequestError("Invalid API key") + if model is None: - raise BadRequestError('Invalid model') - + raise BadRequestError("Invalid model") + if stream: return MockHuggingfaceChatClass.generate_create_stream(model) return MockHuggingfaceChatClass.generate_create_sync(model) - diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py b/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py index c2fe95974b..b37b109eba 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py @@ -5,10 +5,10 @@ class MockTEIClass: @staticmethod def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter: # During mock, we don't have a real server to query, so we just return a dummy value - if 'rerank' in model_name: - model_type = 'reranker' + if "rerank" in model_name: + model_type = "reranker" else: - model_type = 'embedding' + model_type = "embedding" return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1) @@ -17,16 +17,16 @@ class MockTEIClass: # Use space as token separator, and split the text into tokens tokenized_texts = [] for text in texts: - tokens = text.split(' ') + tokens = text.split(" ") current_index = 0 tokenized_text = [] for idx, token in enumerate(tokens): s_token = { - 'id': idx, - 'text': token, - 'special': False, - 'start': current_index, - 'stop': current_index + len(token), + "id": idx, + "text": token, + "special": False, + "start": current_index, + "stop": current_index + len(token), } current_index += len(token) + 1 tokenized_text.append(s_token) @@ -55,18 +55,18 @@ class MockTEIClass: embedding = [0.1] * 768 embeddings.append( { - 'object': 'embedding', - 'embedding': embedding, - 'index': idx, + "object": "embedding", + "embedding": embedding, + "index": idx, } ) return { - 'object': 'list', - 'data': embeddings, - 'model': 'MODEL_NAME', - 'usage': { - 'prompt_tokens': sum(len(text.split(' ')) for text in texts), - 'total_tokens': sum(len(text.split(' ')) for text in texts), + "object": "list", + "data": embeddings, + "model": "MODEL_NAME", + "usage": { + "prompt_tokens": sum(len(text.split(" ")) for text in texts), + "total_tokens": sum(len(text.split(" ")) for text in texts), }, } @@ -83,9 +83,9 @@ class MockTEIClass: for idx, text in enumerate(texts): reranked_docs.append( { - 'index': idx, - 'text': text, - 'score': 0.9, + "index": idx, + "text": text, + "score": 0.9, } ) # For mock, only return the first document diff --git a/api/tests/integration_tests/model_runtime/__mock/openai.py b/api/tests/integration_tests/model_runtime/__mock/openai.py index 0d3f0fbbea..6637f4f212 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai.py @@ -21,13 +21,17 @@ from tests.integration_tests.model_runtime.__mock.openai_remote import MockModel from tests.integration_tests.model_runtime.__mock.openai_speech2text import MockSpeech2TextClass -def mock_openai(monkeypatch: MonkeyPatch, methods: list[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]]) -> Callable[[], None]: +def mock_openai( + monkeypatch: MonkeyPatch, + methods: list[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]], +) -> Callable[[], None]: """ - mock openai module + mock openai module - :param monkeypatch: pytest monkeypatch fixture - :return: unpatch function + :param monkeypatch: pytest monkeypatch fixture + :return: unpatch function """ + def unpatch() -> None: monkeypatch.undo() @@ -52,15 +56,16 @@ def mock_openai(monkeypatch: MonkeyPatch, methods: list[Literal["completion", "c return unpatch -MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + @pytest.fixture def setup_openai_mock(request, monkeypatch): - methods = request.param if hasattr(request, 'param') else [] + methods = request.param if hasattr(request, "param") else [] if MOCK: unpatch = mock_openai(monkeypatch, methods=methods) - + yield if MOCK: - unpatch() \ No newline at end of file + unpatch() diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_chat.py b/api/tests/integration_tests/model_runtime/__mock/openai_chat.py index ba902e32ea..d9cd7b046e 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_chat.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_chat.py @@ -43,62 +43,64 @@ class MockChatClass: if not functions or len(functions) == 0: return None function: completion_create_params.Function = functions[0] - function_name = function['name'] - function_description = function['description'] - function_parameters = function['parameters'] - function_parameters_type = function_parameters['type'] - if function_parameters_type != 'object': + function_name = function["name"] + function_description = function["description"] + function_parameters = function["parameters"] + function_parameters_type = function_parameters["type"] + if function_parameters_type != "object": return None - function_parameters_properties = function_parameters['properties'] - function_parameters_required = function_parameters['required'] + function_parameters_properties = function_parameters["properties"] + function_parameters_required = function_parameters["required"] parameters = {} for parameter_name, parameter in function_parameters_properties.items(): if parameter_name not in function_parameters_required: continue - parameter_type = parameter['type'] - if parameter_type == 'string': - if 'enum' in parameter: - if len(parameter['enum']) == 0: + parameter_type = parameter["type"] + if parameter_type == "string": + if "enum" in parameter: + if len(parameter["enum"]) == 0: continue - parameters[parameter_name] = parameter['enum'][0] + parameters[parameter_name] = parameter["enum"][0] else: - parameters[parameter_name] = 'kawaii' - elif parameter_type == 'integer': + parameters[parameter_name] = "kawaii" + elif parameter_type == "integer": parameters[parameter_name] = 114514 - elif parameter_type == 'number': + elif parameter_type == "number": parameters[parameter_name] = 1919810.0 - elif parameter_type == 'boolean': + elif parameter_type == "boolean": parameters[parameter_name] = True return FunctionCall(name=function_name, arguments=dumps(parameters)) - + @staticmethod - def generate_tool_calls(tools = NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]: + def generate_tool_calls(tools=NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]: list_tool_calls = [] if not tools or len(tools) == 0: return None tool = tools[0] - if 'type' in tools and tools['type'] != 'function': + if "type" in tools and tools["type"] != "function": return None - function = tool['function'] + function = tool["function"] function_call = MockChatClass.generate_function_call(functions=[function]) if function_call is None: return None - - list_tool_calls.append(ChatCompletionMessageToolCall( - id='sakurajima-mai', - function=Function( - name=function_call.name, - arguments=function_call.arguments, - ), - type='function' - )) + + list_tool_calls.append( + ChatCompletionMessageToolCall( + id="sakurajima-mai", + function=Function( + name=function_call.name, + arguments=function_call.arguments, + ), + type="function", + ) + ) return list_tool_calls - + @staticmethod def mocked_openai_chat_create_sync( model: str, @@ -111,30 +113,27 @@ class MockChatClass: tool_calls = MockChatClass.generate_tool_calls(tools=tools) return _ChatCompletion( - id='cmpl-3QJQa5jXJ5Z5X', + id="cmpl-3QJQa5jXJ5Z5X", choices=[ _ChatCompletionChoice( - finish_reason='content_filter', + finish_reason="content_filter", index=0, message=ChatCompletionMessage( - content='elaina', - role='assistant', - function_call=function_call, - tool_calls=tool_calls - ) + content="elaina", role="assistant", function_call=function_call, tool_calls=tool_calls + ), ) ], created=int(time()), model=model, - object='chat.completion', - system_fingerprint='', + object="chat.completion", + system_fingerprint="", usage=CompletionUsage( prompt_tokens=2, completion_tokens=1, total_tokens=3, - ) + ), ) - + @staticmethod def mocked_openai_chat_create_stream( model: str, @@ -150,36 +149,40 @@ class MockChatClass: for i in range(0, len(full_text) + 1): if i == len(full_text): yield ChatCompletionChunk( - id='cmpl-3QJQa5jXJ5Z5X', + id="cmpl-3QJQa5jXJ5Z5X", choices=[ Choice( delta=ChoiceDelta( - content='', + content="", function_call=ChoiceDeltaFunctionCall( name=function_call.name, arguments=function_call.arguments, - ) if function_call else None, - role='assistant', + ) + if function_call + else None, + role="assistant", tool_calls=[ ChoiceDeltaToolCall( index=0, - id='misaka-mikoto', + id="misaka-mikoto", function=ChoiceDeltaToolCallFunction( name=tool_calls[0].function.name, arguments=tool_calls[0].function.arguments, ), - type='function' + type="function", ) - ] if tool_calls and len(tool_calls) > 0 else None + ] + if tool_calls and len(tool_calls) > 0 + else None, ), - finish_reason='function_call', + finish_reason="function_call", index=0, ) ], created=int(time()), model=model, - object='chat.completion.chunk', - system_fingerprint='', + object="chat.completion.chunk", + system_fingerprint="", usage=CompletionUsage( prompt_tokens=2, completion_tokens=17, @@ -188,30 +191,45 @@ class MockChatClass: ) else: yield ChatCompletionChunk( - id='cmpl-3QJQa5jXJ5Z5X', + id="cmpl-3QJQa5jXJ5Z5X", choices=[ Choice( delta=ChoiceDelta( content=full_text[i], - role='assistant', + role="assistant", ), - finish_reason='content_filter', + finish_reason="content_filter", index=0, ) ], created=int(time()), model=model, - object='chat.completion.chunk', - system_fingerprint='', + object="chat.completion.chunk", + system_fingerprint="", ) - def chat_create(self: Completions, *, + def chat_create( + self: Completions, + *, messages: list[ChatCompletionMessageParam], - model: Union[str,Literal[ - "gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613", - "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613", - "gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0301", - "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613"], + model: Union[ + str, + Literal[ + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0301", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k-0613", + ], ], functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN, response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, @@ -220,24 +238,32 @@ class MockChatClass: **kwargs: Any, ): openai_models = [ - "gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613", - "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613", - "gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0301", - "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0301", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k-0613", ] - azure_openai_models = [ - "gpt35", "gpt-4v", "gpt-35-turbo" - ] - if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()): - raise InvokeAuthorizationError('Invalid base url') + azure_openai_models = ["gpt35", "gpt-4v", "gpt-35-turbo"] + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + raise InvokeAuthorizationError("Invalid base url") if model in openai_models + azure_openai_models: - if not re.match(r'sk-[a-zA-Z0-9]{24,}$', self._client.api_key) and type(self._client) == OpenAI: + if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI: # sometime, provider use OpenAI compatible API will not have api key or have different api key format # so we only check if model is in openai_models - raise InvokeAuthorizationError('Invalid api key') + raise InvokeAuthorizationError("Invalid api key") if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI: - raise InvokeAuthorizationError('Invalid api key') + raise InvokeAuthorizationError("Invalid api key") if stream: return MockChatClass.mocked_openai_chat_create_stream(model=model, functions=functions, tools=tools) - - return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools) \ No newline at end of file + + return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools) diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_completion.py b/api/tests/integration_tests/model_runtime/__mock/openai_completion.py index b0d2675905..c27e89248f 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_completion.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_completion.py @@ -17,9 +17,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError class MockCompletionsClass: @staticmethod - def mocked_openai_completion_create_sync( - model: str - ) -> CompletionMessage: + def mocked_openai_completion_create_sync(model: str) -> CompletionMessage: return CompletionMessage( id="cmpl-3QJQa5jXJ5Z5X", object="text_completion", @@ -38,13 +36,11 @@ class MockCompletionsClass: prompt_tokens=2, completion_tokens=1, total_tokens=3, - ) + ), ) - + @staticmethod - def mocked_openai_completion_create_stream( - model: str - ) -> Generator[CompletionMessage, None, None]: + def mocked_openai_completion_create_stream(model: str) -> Generator[CompletionMessage, None, None]: full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```" for i in range(0, len(full_text) + 1): if i == len(full_text): @@ -76,46 +72,59 @@ class MockCompletionsClass: model=model, system_fingerprint="", choices=[ - CompletionChoice( - text=full_text[i], - index=0, - logprobs=None, - finish_reason="content_filter" - ) + CompletionChoice(text=full_text[i], index=0, logprobs=None, finish_reason="content_filter") ], ) - def completion_create(self: Completions, *, model: Union[ - str, Literal["babbage-002", "davinci-002", "gpt-3.5-turbo-instruct", - "text-davinci-003", "text-davinci-002", "text-davinci-001", - "code-davinci-002", "text-curie-001", "text-babbage-001", - "text-ada-001"], + def completion_create( + self: Completions, + *, + model: Union[ + str, + Literal[ + "babbage-002", + "davinci-002", + "gpt-3.5-turbo-instruct", + "text-davinci-003", + "text-davinci-002", + "text-davinci-001", + "code-davinci-002", + "text-curie-001", + "text-babbage-001", + "text-ada-001", + ], ], prompt: Union[str, list[str], list[int], list[list[int]], None], stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN, - **kwargs: Any + **kwargs: Any, ): openai_models = [ - "babbage-002", "davinci-002", "gpt-3.5-turbo-instruct", "text-davinci-003", "text-davinci-002", "text-davinci-001", - "code-davinci-002", "text-curie-001", "text-babbage-001", "text-ada-001", - ] - azure_openai_models = [ - "gpt-35-turbo-instruct" + "babbage-002", + "davinci-002", + "gpt-3.5-turbo-instruct", + "text-davinci-003", + "text-davinci-002", + "text-davinci-001", + "code-davinci-002", + "text-curie-001", + "text-babbage-001", + "text-ada-001", ] + azure_openai_models = ["gpt-35-turbo-instruct"] - if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()): - raise InvokeAuthorizationError('Invalid base url') + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + raise InvokeAuthorizationError("Invalid base url") if model in openai_models + azure_openai_models: - if not re.match(r'sk-[a-zA-Z0-9]{24,}$', self._client.api_key) and type(self._client) == OpenAI: + if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI: # sometime, provider use OpenAI compatible API will not have api key or have different api key format # so we only check if model is in openai_models - raise InvokeAuthorizationError('Invalid api key') + raise InvokeAuthorizationError("Invalid api key") if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI: - raise InvokeAuthorizationError('Invalid api key') - + raise InvokeAuthorizationError("Invalid api key") + if not prompt: - raise BadRequestError('Invalid prompt') + raise BadRequestError("Invalid prompt") if stream: return MockCompletionsClass.mocked_openai_completion_create_stream(model=model) - - return MockCompletionsClass.mocked_openai_completion_create_sync(model=model) \ No newline at end of file + + return MockCompletionsClass.mocked_openai_completion_create_sync(model=model) diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py b/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py index eccdbd3479..4138cdd40d 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py @@ -12,48 +12,39 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError class MockEmbeddingsClass: def create_embeddings( - self: Embeddings, *, + self: Embeddings, + *, input: Union[str, list[str], list[int], list[list[int]]], model: Union[str, Literal["text-embedding-ada-002"]], encoding_format: Literal["float", "base64"] | NotGiven = NOT_GIVEN, - **kwargs: Any + **kwargs: Any, ) -> CreateEmbeddingResponse: if isinstance(input, str): input = [input] - if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()): - raise InvokeAuthorizationError('Invalid base url') - + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + raise InvokeAuthorizationError("Invalid base url") + if len(self._client.api_key) < 18: - raise InvokeAuthorizationError('Invalid API key') - - if encoding_format == 'float': + raise InvokeAuthorizationError("Invalid API key") + + if encoding_format == "float": return CreateEmbeddingResponse( data=[ - Embedding( - embedding=[0.23333 for _ in range(233)], - index=i, - object='embedding' - ) for i in range(len(input)) + Embedding(embedding=[0.23333 for _ in range(233)], index=i, object="embedding") + for i in range(len(input)) ], model=model, - object='list', + object="list", # marked: usage of embeddings should equal the number of testcase - usage=Usage( - prompt_tokens=2, - total_tokens=2 - ) + usage=Usage(prompt_tokens=2, total_tokens=2), ) - - embeddings = 'VEfNvMLUnrwFleO8hcj9vEE/yrzyjOA84E1MvNfoCrxjrI+8sZUKvNgrBT17uY07gJ/IvNvhHLrUemc8KXXGumalIT3YKwU7ZsnbPMhATrwTt6u8JEwRPNMmCjxGREW7TRKvu6/MG7zAyDU8wXLkuuMDZDsXsL28zHzaOw0IArzOiMO8LtASvPKM4Dul5l+80V0bPGVDZ7wYNrI89ucsvJZdYztzRm+8P8ysOyGbc7zrdgK9sdiEPKQ8sbulKdq7KIgdvKIMDj25dNc8k0AXPBn/oLzrdgK8IXe5uz0Dvrt50V68tTjLO4ZOcjoG9x29oGfZufiwmzwMDXy8EL6ZPHvdx7nKjzE8+LCbPG22hTs3EZq7TM+0POrRzTxVZo084wPkO8Nak7z8cpw8pDwxvA2T8LvBC7C72fltvC8Atjp3fYE8JHDLvEYgC7xAdls8YiabPPkEeTzPUbK8gOLCPEBSIbyt5Oy8CpreusNakzywUhA824vLPHRlr7zAhTs7IZtzvHd9AT2xY/O6ok8IvOihqrql5l88K4EvuknWorvYKwW9iXkbvGMTRLw5qPG7onPCPLgNIzwAbK67ftbZPMxYILvAyDW9TLB0vIid1buzCKi7u+d0u8iDSLxNVam8PZyJPNxnETvVANw8Oi5mu9nVszzl65I7DIKNvLGVirxsMJE7tPXQu2PvCT1zRm87p1l9uyRMkbsdfqe8U52ePHRlr7wt9Mw8/C8ivTu02rwJFGq8tpoFPWnC7blWumq7sfy+vG1zCzy9Nlg8iv+PuvxT3DuLU228kVhoOkmTqDrv1kg8ocmTu1WpBzsKml48DzglvI8ECzxwTd27I+pWvIWkQ7xUR007GqlPPBFEDrzGECu865q8PI7BkDwNxYc8tgG6ullMSLsIajs84lk1PNLjD70mv648ZmInO2tnIjzvb5Q8o5KCPLo9xrwKMyq9QqGEvI8ECzxO2508ATUdPRAlTry5kxc8KVGMPJyBHjxIUC476KGqvIU9DzwX87c88PUIParrWrzdlzS/G3K+uzEw2TxB2BU86AhfPAMiRj2dK808a85WPPCft7xU4Bg95Q9NPDxZjzwrpek7yNkZvHa0EjyQ0nM6Nq9fuyjvUbsRq8I7CAMHO3VSWLyuauE7U1qkvPkEeTxs7ZY7B6FMO48Eizy75/S7ieBPvB07rTxmyVu8onPCO5rc6Tu7XIa7oEMfPYngT7u24vk7/+W5PE8eGDxJ1iI9t4cuvBGHiLyH1GY7jfghu+oUSDwa7Mk7iXmbuut2grrq8I2563v8uyofdTxRTrs44lm1vMeWnzukf6s7r4khvEKhhDyhyZO8G5Z4Oy56wTz4sBs81Zknuz3fg7wnJuO74n1vvASEADu98128gUl3vBtyvrtZCU47yep8u5FYaDx2G0e8a85WO5cmUjz3kds8qgqbPCUaerx50d67WKIZPI7BkDua3Om74vKAvL3zXbzXpRA9CI51vLo9xryKzXg7tXtFO9RWLTwnJuM854LqPEIs8zuO5cq8d8V1u9P0cjrQ++C8cGwdPDdUlLoOGeW8auEtu8Z337nlzFK8aRg/vFCkDD0nRSM879bIvKUFID1iStU8EL6ZvLufgLtKgNE7KVEMvJOnSzwahRU895HbvJiIjLvc8n88bmC0PPLP2rywM9C7jTscOoS3mjy/Znu7dhvHuu5Q1Dyq61o6CI71u09hkry0jhw8gb6IPI8EC7uoVAM8gs9rvGM3fjx2G8e81FYtu/ojubyYRRK72Riuu83elDtNNmk70/TyuzUFsbvgKZI7onNCvAehzLumr8679R6+urr6SztX2So8Bl5SOwSEgLv5NpA8LwC2PGPvibzJ6vw7H2tQvOtXwrzXpRC8j0z/uxwcbTy2vr+8VWYNu+t2ArwKmt68NKN2O3XrIzw9A747UU47vaavzjwU+qW8YBqyvE02aTyEt5o8cCmjOxtyPrxs7ZY775NOu+SJWLxMJQY8/bWWu6IMDrzSSsQ7GSPbPLlQnbpVzcE7Pka4PJ96sLycxJg8v/9GPO2HZTyeW3C8Vpawtx2iYTwWBg87/qI/OviwGzxyWcY7M9WNPIA4FD32C2e8tNGWPJ43trxCoYS8FGHavItTbbu7n4C80NemPLm30Ty1OMu7vG1pvG3aPztBP0o75Q/NPJhFEj2V9i683PL/O97+aLz6iu27cdPRum/mKLwvVgc89fqDu3LA+jvm2Ls8mVZ1PIuFBD3ZGK47Cpreut7+aLziWTU8XSEgPMvSKzzO73e5040+vBlmVTxS1K+8mQ4BPZZ8o7w8FpW6OR0DPSSPCz21Vwu99fqDOjMYiDy7XAY8oYaZO+aVwTyX49c84OaXOqdZfTunEQk7B8AMvMDs7zo/D6e8OP5CvN9gIzwNCII8FefOPE026TpzIjU8XsvOO+J9b7rkIiQ8is34O+e0AbxBpv67hcj9uiPq1jtCoQQ8JfY/u86nAz0Wkf28LnrBPJlW9Tt8P4K7BbSjO9grhbyAOJS8G3K+vJLe3LzXpZA7NQUxPJs+JDz6vAS8QHZbvYNVYDrj3yk88PWIPOJ97zuSIVc8ZUPnPMqPsbx2cZi7QfzPOxYGDz2hqtO6H2tQO543NjyFPY+7JRUAOt0wgDyJeZu8MpKTu6AApTtg1ze82JI5vKllZjvrV0I7HX6nu7vndDxg1ze8jwQLu1ZTNjuJvBU7BXGpvAP+C7xJk6g8j2u/vBABlLzlqBi8M9WNutRWLTx0zGM9sHbKPLoZDDtmyVu8tpqFOvPumjyuRqe87lBUvFU0drxs7Za8ejMZOzJPGbyC7qu863v8PDPVjTxJ1iI7Ca01PLuAQLuNHFy7At9LOwP+i7tYxlO80NemO9elkDx45LU8h9TmuzxZjzz/5bk8p84OurvndLwAkGi7XL9luCSzRTwMgg08vrxMPKIwyDwdomG8K6VpPGPvCTxkmTi7M/lHPGxUSzxwKSM8wQuwvOqtkzrLFSa8SbdivAMixjw2r9+7xWt2vAyCDT1NEi87B8CMvG1zi7xpwm27MrbNO9R6Z7xJt+K7jNnhu9ZiFrve/ug55CKkvCwHJLqsOr47+ortvPwvIr2v8NW8YmmVOE+FTLywUhA8MTBZvMiDyLtx8hG8OEE9vMDsbzroCF88DelBOobnPbx+b6U8sbnEOywr3ro93wO9dMzjup2xwbwnRaO7cRZMu8Z337vS44+7VpYwvFWphzxKgNE8L1aHPLPFLbunzo66zFggPN+jHbs7tFo8nW7HO9JKRLyoeD28Fm1DPGZip7u5dNe7KMsXvFnlkzxQpAw7MrZNPHpX0zwSyoK7ayQovPR0Dz3gClK8/juLPDjaCLvqrZO7a4vcO9HEzzvife88KKzXvDmocbwpMkw7t2huvaIMjjznguo7Gy/EOzxZjzoLuZ48qi5VvCjLFzuDmNo654LquyrXgDy7XAa8e7mNvJ7QAb0Rq8K7ojBIvBN0MTuOfha8GoUVveb89bxMsHS8jV9WPPKM4LyAOJS8me9AvZv7qbsbcr47tuL5uaXmXzweKNa7rkYnPINV4Lxcv+W8tVcLvI8oxbzvbxS7oYaZu9+jHT0cHO08c7uAPCSzRTywUhA85xu2u+wBcTuJvJU8PBYVusTghzsnAim8acJtPFQE0zzFIwI9C7meO1DIRry7XAY8MKpkPJZd47suN0e5JTm6u6BDn7zfx1e8AJDoOr9CQbwaQps7x/1TPLTRFryqLtU8JybjPIXI/Tz6I7k6mVb1PMWKNryd1fs8Ok0mPHt2kzy9Ep48TTZpvPS3ibwGOpi8Ns4fPBqFlbr3Kqc8+QR5vHLA+rt7uY289YXyPI6iULxL4gu8Tv/XuycCKbwCnFG8C7kevVG1b7zIXw68GoWVO4rNeDnrM4i8MxgIPUNLs7zSoJW86ScfO+rRzbs6Cqw8NxGautP0cjw0wjY8CGq7vAkU6rxKgNG5+uA+vJXXbrwKM6o86vCNOu+yjjoQAZS8xATCOQVxKbynzo68wxcZvMhATjzS4488ArsRvNEaobwRh4i7t4euvAvd2DwnAik8UtQvvBFEDrz4sJs79gtnvOknnzy+vEy8D3sfPLH8vjzmLo28KVGMvOtXwjvpapm8HBxtPH3K8Lu753Q8/l9FvLvn9DomoG48fET8u9zy/7wMpke8zmQJu3oU2TzlD828KteAPAwNfLu+mBI5ldduPNZDVjq+vEy8eEvqvDHJpLwUPaC6qi7VPABsLjwFcSm72sJcu+bYO7v41NW8RiALvYB7DjzL0is7qLs3us1FSbzaf2K8MnNTuxABFDzF8Wo838fXvOBNzDzre3w8afQEvQE1nbulBaC78zEVvG5B9LzH/VM82Riuuwu5nrwsByQ8Y6yPvHXro7yQ0nM8nStNPJkyOzwnJmM80m7+O1VmjTzqrZM8dhvHOyAQBbz3baG8KTJMPOlqmbxsVEs8Pq3suy56QbzUVq08X3CDvAE1nTwUHuA7hue9vF8tCbvwOAO6F7A9ugd9kryqLtW7auEtu9ONPryPa7+8o9r2O570OzyFpEO8ntCBPOqtk7sykhO7lC1AOw2TcLswhiq6vx4HvP5fRbwuesG7Mk8ZvA4Z5TlfcAM9DrIwPL//xrzMm5q8JEwRPHBsnbxL4gu8jyjFu99gozrkZZ483GeRPLuAwDuYiIw8iv8PvK5Gpzx+b6W87Yflu3NGbzyE+hQ8a4tcPItT7bsoy5e8L1YHvWQyBDwrga86kPEzvBQ9oDxtl0W8lwKYvGpIYrxQ5wY8AJDovOLyALyw3f489JjJvMdTpTkKMyo8V9mqvH3K8LpyNYy8JHDLOixu2LpQ54Y8Q0uzu8LUnrs0wrY84vIAveihqjwfihA8DIKNvLDd/jywM1C7FB7gOxsLirxAUqE7sulnvH3K8DkAkGg8jsGQvO+TzrynWf287CCxvK4Drbwg8UQ8JRr6vFEqAbskjwu76q2TPNP0cjopDhK8dVJYvFIXKrxLn5G8AK8oPAb3HbxbOXE8Bvedun5Q5ThHyjk8QdiVvBXDlLw0o/Y7aLGKupkOgTxKPdc81kNWPtUAXLxUR827X1FDPf47izxsEVE8akhiPIhaWzxYX5+7hT0PPSrXgLxQC0E8i4WEvKUp2jtCLHM8DcWHO768zLxnK5a89R6+vH9czrorpem73h0pvAnwr7yKzXi8gDgUPf47Czq9zyO8728UOf34EDy6PUY76OSkvKZIGr2ZDgE8gzEmPG3av7v77Ce7/oP/O3MiNTtas/w8x1OlO/D1CDvDfs27ll1jO2Ufrbv1hXK8WINZuxN0sbuxlYq8OYS3uia/rjyiTwi9O7TaO+/WyDyiDA49E7erO3fF9bj6I7k7qHi9O3SoKbyBSfc7drSSvGPvCT2pQay7t2huPGnC7byUCQY8CEaBu6rHoDhx8hE8/fgQvCjLl7zdeHS8x/3TO0Isc7tas3y8jwQLvUKhhDz+foU8fCDCPC+ZgTywD5Y7ZR8tOla66rtCCLm8gWg3vDoKrLxbWDE76SefPBkj2zrlqJi7pebfuv6Df7zWQ9a7lHA6PGDXtzzMv1Q8mtxpOwJ4lzxKGZ28mGnMPDw6z7yxY/O7m2Leu7juYjwvVge8zFigPGpIYjtWumo5xs2wOgyCjbxrZ6K8bbaFvKzTCbsks8W7C7mePIU9DzxQyEY8posUvAW0ozrHlh88CyBTPJRwursxySQ757SBuqcRCbwNCIK8EL6ZvIG+iLsIRgE8rF74vOJZtbuUcDq8r/DVPMpMt7sL3Vi8eWqquww/kzqj2vY5auGtu85kiTwMPxM66KGqvBIxNzuwUpA8v2b7u09C0rx7ms08NUirvFYQPLxKPdc68mimvP5fRTtoPPm7XuqOOgOJ+jxfLYm7u58AvXz8B72PR4W6ldfuuys+tbvYKwW7pkiaPLB2SjvKj7G875POvA6yML7qFEg9Eu68O6Up2rz77Kc84CmSPP6ivzz4sJu6/C+iOaUpWjwq14A84E3MOYB7Dr2d1Xu775NOvC6e+7spUYw8PzPhO5TGizt29ww9yNkZPY7lyrz020M7QRsQu3z8BzwkCZe79YXyO8jZmTzvGUM8HgQcO9kYrrzxBmy8hLeaPLYBOjz+oj88flBlO6GqUzuiMMi8fxlUvCr7ujz41NU8DA38PBeMAzx7uY28TTZpvFG1bzxtc4s89ucsPEereTwfipC82p4iPKtNFbzo5KQ7pcKlOW5gtDzO73c7B6FMOzRbgjxCXoo8v0JBOSl1RrwxDJ+7XWSaPD3Aw7sOsjA8tuJ5vKw6Pry5k5c8ZUNnvG/H6DyVTAA8Shkdvd7+aDvtpiW9qUGsPFTgmDwbcr68TTbpO1DnhryNX9a7mrivvIqpPjxsqhy81HrnOzv31Dvth+U6UtQvPBz4MrvtpqW84OYXvRz4sjxwkFe8zSGPuycCqbyFPY8818nKOw84JTy8bWk8USqBvBGHiLtosQo8BOs0u9skl7xQ54Y8uvrLPOknn7w705o8Jny0PAd9EjxhoKa8Iv2tu2M3/jtsVEs8DcUHPQSEADs3eE48GkKbupRR+rvdeHQ7Xy2JvO1jKz0xMFm8sWPzux07LbyrTZW7bdq/O6Pa9r0ahRW9CyDTOjSjdjyQ8bO8yaIIPfupLTz/CfQ7xndfvJs+JD0zPEK8KO/RvMpw8bwObzY7fm+lPJtiXrz5BHm8WmsIvKlBrLuDdKA7hWHJOgd9Ers0o/Y7nlvwu5NAl7u8BrW6utYRO2SZuDxyNYw8CppevAY6GDxVqQe9oGdZPFa6ary3RLS70NcmO2PQSb36ZrM86q2TPML42LwewaE8k2RRPDmocTsi/S29o/k2PHRlr7zjnC+8gHsOPUpcFzxtl8W6tuL5vHw/gry/2wy9yaIIvINV4Dx3fQG7ISFoPO7pnzwGXlK8HPiyPGAaMjzBC7A7MQyfu+eC6jyV1+67pDyxvBWkVLxrJKg754LqOScCKbwpUQy8KIgdOJDSc7zDfk08tLLWvNZDVjyh7c28ShmdvMnlgjs2NdS8ISHovP5+hbxGIIs8ayQouyKnXDzBcmS6zw44u86IQ7yl5l+7cngGvWvOVrsEhIC7yNkZPJODkbuAn0g8XN6lPOaVwbuTgxG8OR2DPAb3HTzlqJi8nUoNvCAVf73Mmxo9afSEu4FotzveHSk8c0ZvOMFOqjwP9Sq87iwavIEBg7xIUK68IbozuozZ4btg17c7vx4Hvarr2rtp9IQ8Rt0QO+1jqzyeNzY8kNLzO8sVpry98108OCL9uyisV7vhr4Y8FgaPvLFjczw42og8gWg3vPX6gzsNk/C83GeRPCUVgDy0jpw7yNkZu2VD5zvh93o81h+cuw3Fhzyl5t+86Y7TvHa0EjyzCCi7WmsIPIy1Jzy00Ra6NUiru50rTTx50d47/HKcO2wwETw0f7y8sFIQvNxnkbzS4w855pVBu9FdGzx9yvC6TM80vFQjkzy/Zvs7BhtYPLjKKLqPa787A/6LOyiInbzooSq8728UPIFJ97wq+7q8R6v5u1tYMbwdomG6iSPKPAb3HTx3oTu7fGO8POqtk7ze/ug84wNkPMnq/DsB8iK9ogwOu6lBrDznguo8NQUxvHKcwDo28tm7yNmZPN1UurxCoYS80m7+Oy+9OzzGzTC836MdvCDNCrtaawi7dVLYPEfKuTxzRm88cCmjOyXSBbwGOpi879ZIO8dTJbtqnrO8NMI2vR1+J7xwTV087umfPFG17zsC30s8oYaZPKllZrzZGK47zss9vP21FryZywa9bbYFPVNapDt2G0e7E3SxPMUjgry5dNc895Hbu0H8z7ueN7a7OccxPFhfH7vC1B48n3owvEhQLrzu6Z+8HTutvEBSITw6Taa5g1XgPCzEqbxfLYk9OYQ3vBlm1bvPUTI8wIU7PIy1pzyFyP07gzGmO3NGb7yS3ty7O5CguyEhaLyWoF28pmxUOaZImrz+g/87mnU1vFbsgTxvo668PFmPO2KNTzy09VC8LG5YPHhL6rsvJPC7kTQuvEGCxDlhB9s6u58AvfCAd7z0t4k7kVjoOCkOkrxMjDq8iPOmPL0SnrxsMJG7OEG9vCUa+rvx4rE7cpxAPDCGqjukf6u8TEnAvNn57TweBBw7JdKFvIy1p7vIg8i7' + + embeddings = "VEfNvMLUnrwFleO8hcj9vEE/yrzyjOA84E1MvNfoCrxjrI+8sZUKvNgrBT17uY07gJ/IvNvhHLrUemc8KXXGumalIT3YKwU7ZsnbPMhATrwTt6u8JEwRPNMmCjxGREW7TRKvu6/MG7zAyDU8wXLkuuMDZDsXsL28zHzaOw0IArzOiMO8LtASvPKM4Dul5l+80V0bPGVDZ7wYNrI89ucsvJZdYztzRm+8P8ysOyGbc7zrdgK9sdiEPKQ8sbulKdq7KIgdvKIMDj25dNc8k0AXPBn/oLzrdgK8IXe5uz0Dvrt50V68tTjLO4ZOcjoG9x29oGfZufiwmzwMDXy8EL6ZPHvdx7nKjzE8+LCbPG22hTs3EZq7TM+0POrRzTxVZo084wPkO8Nak7z8cpw8pDwxvA2T8LvBC7C72fltvC8Atjp3fYE8JHDLvEYgC7xAdls8YiabPPkEeTzPUbK8gOLCPEBSIbyt5Oy8CpreusNakzywUhA824vLPHRlr7zAhTs7IZtzvHd9AT2xY/O6ok8IvOihqrql5l88K4EvuknWorvYKwW9iXkbvGMTRLw5qPG7onPCPLgNIzwAbK67ftbZPMxYILvAyDW9TLB0vIid1buzCKi7u+d0u8iDSLxNVam8PZyJPNxnETvVANw8Oi5mu9nVszzl65I7DIKNvLGVirxsMJE7tPXQu2PvCT1zRm87p1l9uyRMkbsdfqe8U52ePHRlr7wt9Mw8/C8ivTu02rwJFGq8tpoFPWnC7blWumq7sfy+vG1zCzy9Nlg8iv+PuvxT3DuLU228kVhoOkmTqDrv1kg8ocmTu1WpBzsKml48DzglvI8ECzxwTd27I+pWvIWkQ7xUR007GqlPPBFEDrzGECu865q8PI7BkDwNxYc8tgG6ullMSLsIajs84lk1PNLjD70mv648ZmInO2tnIjzvb5Q8o5KCPLo9xrwKMyq9QqGEvI8ECzxO2508ATUdPRAlTry5kxc8KVGMPJyBHjxIUC476KGqvIU9DzwX87c88PUIParrWrzdlzS/G3K+uzEw2TxB2BU86AhfPAMiRj2dK808a85WPPCft7xU4Bg95Q9NPDxZjzwrpek7yNkZvHa0EjyQ0nM6Nq9fuyjvUbsRq8I7CAMHO3VSWLyuauE7U1qkvPkEeTxs7ZY7B6FMO48Eizy75/S7ieBPvB07rTxmyVu8onPCO5rc6Tu7XIa7oEMfPYngT7u24vk7/+W5PE8eGDxJ1iI9t4cuvBGHiLyH1GY7jfghu+oUSDwa7Mk7iXmbuut2grrq8I2563v8uyofdTxRTrs44lm1vMeWnzukf6s7r4khvEKhhDyhyZO8G5Z4Oy56wTz4sBs81Zknuz3fg7wnJuO74n1vvASEADu98128gUl3vBtyvrtZCU47yep8u5FYaDx2G0e8a85WO5cmUjz3kds8qgqbPCUaerx50d67WKIZPI7BkDua3Om74vKAvL3zXbzXpRA9CI51vLo9xryKzXg7tXtFO9RWLTwnJuM854LqPEIs8zuO5cq8d8V1u9P0cjrQ++C8cGwdPDdUlLoOGeW8auEtu8Z337nlzFK8aRg/vFCkDD0nRSM879bIvKUFID1iStU8EL6ZvLufgLtKgNE7KVEMvJOnSzwahRU895HbvJiIjLvc8n88bmC0PPLP2rywM9C7jTscOoS3mjy/Znu7dhvHuu5Q1Dyq61o6CI71u09hkry0jhw8gb6IPI8EC7uoVAM8gs9rvGM3fjx2G8e81FYtu/ojubyYRRK72Riuu83elDtNNmk70/TyuzUFsbvgKZI7onNCvAehzLumr8679R6+urr6SztX2So8Bl5SOwSEgLv5NpA8LwC2PGPvibzJ6vw7H2tQvOtXwrzXpRC8j0z/uxwcbTy2vr+8VWYNu+t2ArwKmt68NKN2O3XrIzw9A747UU47vaavzjwU+qW8YBqyvE02aTyEt5o8cCmjOxtyPrxs7ZY775NOu+SJWLxMJQY8/bWWu6IMDrzSSsQ7GSPbPLlQnbpVzcE7Pka4PJ96sLycxJg8v/9GPO2HZTyeW3C8Vpawtx2iYTwWBg87/qI/OviwGzxyWcY7M9WNPIA4FD32C2e8tNGWPJ43trxCoYS8FGHavItTbbu7n4C80NemPLm30Ty1OMu7vG1pvG3aPztBP0o75Q/NPJhFEj2V9i683PL/O97+aLz6iu27cdPRum/mKLwvVgc89fqDu3LA+jvm2Ls8mVZ1PIuFBD3ZGK47Cpreut7+aLziWTU8XSEgPMvSKzzO73e5040+vBlmVTxS1K+8mQ4BPZZ8o7w8FpW6OR0DPSSPCz21Vwu99fqDOjMYiDy7XAY8oYaZO+aVwTyX49c84OaXOqdZfTunEQk7B8AMvMDs7zo/D6e8OP5CvN9gIzwNCII8FefOPE026TpzIjU8XsvOO+J9b7rkIiQ8is34O+e0AbxBpv67hcj9uiPq1jtCoQQ8JfY/u86nAz0Wkf28LnrBPJlW9Tt8P4K7BbSjO9grhbyAOJS8G3K+vJLe3LzXpZA7NQUxPJs+JDz6vAS8QHZbvYNVYDrj3yk88PWIPOJ97zuSIVc8ZUPnPMqPsbx2cZi7QfzPOxYGDz2hqtO6H2tQO543NjyFPY+7JRUAOt0wgDyJeZu8MpKTu6AApTtg1ze82JI5vKllZjvrV0I7HX6nu7vndDxg1ze8jwQLu1ZTNjuJvBU7BXGpvAP+C7xJk6g8j2u/vBABlLzlqBi8M9WNutRWLTx0zGM9sHbKPLoZDDtmyVu8tpqFOvPumjyuRqe87lBUvFU0drxs7Za8ejMZOzJPGbyC7qu863v8PDPVjTxJ1iI7Ca01PLuAQLuNHFy7At9LOwP+i7tYxlO80NemO9elkDx45LU8h9TmuzxZjzz/5bk8p84OurvndLwAkGi7XL9luCSzRTwMgg08vrxMPKIwyDwdomG8K6VpPGPvCTxkmTi7M/lHPGxUSzxwKSM8wQuwvOqtkzrLFSa8SbdivAMixjw2r9+7xWt2vAyCDT1NEi87B8CMvG1zi7xpwm27MrbNO9R6Z7xJt+K7jNnhu9ZiFrve/ug55CKkvCwHJLqsOr47+ortvPwvIr2v8NW8YmmVOE+FTLywUhA8MTBZvMiDyLtx8hG8OEE9vMDsbzroCF88DelBOobnPbx+b6U8sbnEOywr3ro93wO9dMzjup2xwbwnRaO7cRZMu8Z337vS44+7VpYwvFWphzxKgNE8L1aHPLPFLbunzo66zFggPN+jHbs7tFo8nW7HO9JKRLyoeD28Fm1DPGZip7u5dNe7KMsXvFnlkzxQpAw7MrZNPHpX0zwSyoK7ayQovPR0Dz3gClK8/juLPDjaCLvqrZO7a4vcO9HEzzvife88KKzXvDmocbwpMkw7t2huvaIMjjznguo7Gy/EOzxZjzoLuZ48qi5VvCjLFzuDmNo654LquyrXgDy7XAa8e7mNvJ7QAb0Rq8K7ojBIvBN0MTuOfha8GoUVveb89bxMsHS8jV9WPPKM4LyAOJS8me9AvZv7qbsbcr47tuL5uaXmXzweKNa7rkYnPINV4Lxcv+W8tVcLvI8oxbzvbxS7oYaZu9+jHT0cHO08c7uAPCSzRTywUhA85xu2u+wBcTuJvJU8PBYVusTghzsnAim8acJtPFQE0zzFIwI9C7meO1DIRry7XAY8MKpkPJZd47suN0e5JTm6u6BDn7zfx1e8AJDoOr9CQbwaQps7x/1TPLTRFryqLtU8JybjPIXI/Tz6I7k6mVb1PMWKNryd1fs8Ok0mPHt2kzy9Ep48TTZpvPS3ibwGOpi8Ns4fPBqFlbr3Kqc8+QR5vHLA+rt7uY289YXyPI6iULxL4gu8Tv/XuycCKbwCnFG8C7kevVG1b7zIXw68GoWVO4rNeDnrM4i8MxgIPUNLs7zSoJW86ScfO+rRzbs6Cqw8NxGautP0cjw0wjY8CGq7vAkU6rxKgNG5+uA+vJXXbrwKM6o86vCNOu+yjjoQAZS8xATCOQVxKbynzo68wxcZvMhATjzS4488ArsRvNEaobwRh4i7t4euvAvd2DwnAik8UtQvvBFEDrz4sJs79gtnvOknnzy+vEy8D3sfPLH8vjzmLo28KVGMvOtXwjvpapm8HBxtPH3K8Lu753Q8/l9FvLvn9DomoG48fET8u9zy/7wMpke8zmQJu3oU2TzlD828KteAPAwNfLu+mBI5ldduPNZDVjq+vEy8eEvqvDHJpLwUPaC6qi7VPABsLjwFcSm72sJcu+bYO7v41NW8RiALvYB7DjzL0is7qLs3us1FSbzaf2K8MnNTuxABFDzF8Wo838fXvOBNzDzre3w8afQEvQE1nbulBaC78zEVvG5B9LzH/VM82Riuuwu5nrwsByQ8Y6yPvHXro7yQ0nM8nStNPJkyOzwnJmM80m7+O1VmjTzqrZM8dhvHOyAQBbz3baG8KTJMPOlqmbxsVEs8Pq3suy56QbzUVq08X3CDvAE1nTwUHuA7hue9vF8tCbvwOAO6F7A9ugd9kryqLtW7auEtu9ONPryPa7+8o9r2O570OzyFpEO8ntCBPOqtk7sykhO7lC1AOw2TcLswhiq6vx4HvP5fRbwuesG7Mk8ZvA4Z5TlfcAM9DrIwPL//xrzMm5q8JEwRPHBsnbxL4gu8jyjFu99gozrkZZ483GeRPLuAwDuYiIw8iv8PvK5Gpzx+b6W87Yflu3NGbzyE+hQ8a4tcPItT7bsoy5e8L1YHvWQyBDwrga86kPEzvBQ9oDxtl0W8lwKYvGpIYrxQ5wY8AJDovOLyALyw3f489JjJvMdTpTkKMyo8V9mqvH3K8LpyNYy8JHDLOixu2LpQ54Y8Q0uzu8LUnrs0wrY84vIAveihqjwfihA8DIKNvLDd/jywM1C7FB7gOxsLirxAUqE7sulnvH3K8DkAkGg8jsGQvO+TzrynWf287CCxvK4Drbwg8UQ8JRr6vFEqAbskjwu76q2TPNP0cjopDhK8dVJYvFIXKrxLn5G8AK8oPAb3HbxbOXE8Bvedun5Q5ThHyjk8QdiVvBXDlLw0o/Y7aLGKupkOgTxKPdc81kNWPtUAXLxUR827X1FDPf47izxsEVE8akhiPIhaWzxYX5+7hT0PPSrXgLxQC0E8i4WEvKUp2jtCLHM8DcWHO768zLxnK5a89R6+vH9czrorpem73h0pvAnwr7yKzXi8gDgUPf47Czq9zyO8728UOf34EDy6PUY76OSkvKZIGr2ZDgE8gzEmPG3av7v77Ce7/oP/O3MiNTtas/w8x1OlO/D1CDvDfs27ll1jO2Ufrbv1hXK8WINZuxN0sbuxlYq8OYS3uia/rjyiTwi9O7TaO+/WyDyiDA49E7erO3fF9bj6I7k7qHi9O3SoKbyBSfc7drSSvGPvCT2pQay7t2huPGnC7byUCQY8CEaBu6rHoDhx8hE8/fgQvCjLl7zdeHS8x/3TO0Isc7tas3y8jwQLvUKhhDz+foU8fCDCPC+ZgTywD5Y7ZR8tOla66rtCCLm8gWg3vDoKrLxbWDE76SefPBkj2zrlqJi7pebfuv6Df7zWQ9a7lHA6PGDXtzzMv1Q8mtxpOwJ4lzxKGZ28mGnMPDw6z7yxY/O7m2Leu7juYjwvVge8zFigPGpIYjtWumo5xs2wOgyCjbxrZ6K8bbaFvKzTCbsks8W7C7mePIU9DzxQyEY8posUvAW0ozrHlh88CyBTPJRwursxySQ757SBuqcRCbwNCIK8EL6ZvIG+iLsIRgE8rF74vOJZtbuUcDq8r/DVPMpMt7sL3Vi8eWqquww/kzqj2vY5auGtu85kiTwMPxM66KGqvBIxNzuwUpA8v2b7u09C0rx7ms08NUirvFYQPLxKPdc68mimvP5fRTtoPPm7XuqOOgOJ+jxfLYm7u58AvXz8B72PR4W6ldfuuys+tbvYKwW7pkiaPLB2SjvKj7G875POvA6yML7qFEg9Eu68O6Up2rz77Kc84CmSPP6ivzz4sJu6/C+iOaUpWjwq14A84E3MOYB7Dr2d1Xu775NOvC6e+7spUYw8PzPhO5TGizt29ww9yNkZPY7lyrz020M7QRsQu3z8BzwkCZe79YXyO8jZmTzvGUM8HgQcO9kYrrzxBmy8hLeaPLYBOjz+oj88flBlO6GqUzuiMMi8fxlUvCr7ujz41NU8DA38PBeMAzx7uY28TTZpvFG1bzxtc4s89ucsPEereTwfipC82p4iPKtNFbzo5KQ7pcKlOW5gtDzO73c7B6FMOzRbgjxCXoo8v0JBOSl1RrwxDJ+7XWSaPD3Aw7sOsjA8tuJ5vKw6Pry5k5c8ZUNnvG/H6DyVTAA8Shkdvd7+aDvtpiW9qUGsPFTgmDwbcr68TTbpO1DnhryNX9a7mrivvIqpPjxsqhy81HrnOzv31Dvth+U6UtQvPBz4MrvtpqW84OYXvRz4sjxwkFe8zSGPuycCqbyFPY8818nKOw84JTy8bWk8USqBvBGHiLtosQo8BOs0u9skl7xQ54Y8uvrLPOknn7w705o8Jny0PAd9EjxhoKa8Iv2tu2M3/jtsVEs8DcUHPQSEADs3eE48GkKbupRR+rvdeHQ7Xy2JvO1jKz0xMFm8sWPzux07LbyrTZW7bdq/O6Pa9r0ahRW9CyDTOjSjdjyQ8bO8yaIIPfupLTz/CfQ7xndfvJs+JD0zPEK8KO/RvMpw8bwObzY7fm+lPJtiXrz5BHm8WmsIvKlBrLuDdKA7hWHJOgd9Ers0o/Y7nlvwu5NAl7u8BrW6utYRO2SZuDxyNYw8CppevAY6GDxVqQe9oGdZPFa6ary3RLS70NcmO2PQSb36ZrM86q2TPML42LwewaE8k2RRPDmocTsi/S29o/k2PHRlr7zjnC+8gHsOPUpcFzxtl8W6tuL5vHw/gry/2wy9yaIIvINV4Dx3fQG7ISFoPO7pnzwGXlK8HPiyPGAaMjzBC7A7MQyfu+eC6jyV1+67pDyxvBWkVLxrJKg754LqOScCKbwpUQy8KIgdOJDSc7zDfk08tLLWvNZDVjyh7c28ShmdvMnlgjs2NdS8ISHovP5+hbxGIIs8ayQouyKnXDzBcmS6zw44u86IQ7yl5l+7cngGvWvOVrsEhIC7yNkZPJODkbuAn0g8XN6lPOaVwbuTgxG8OR2DPAb3HTzlqJi8nUoNvCAVf73Mmxo9afSEu4FotzveHSk8c0ZvOMFOqjwP9Sq87iwavIEBg7xIUK68IbozuozZ4btg17c7vx4Hvarr2rtp9IQ8Rt0QO+1jqzyeNzY8kNLzO8sVpry98108OCL9uyisV7vhr4Y8FgaPvLFjczw42og8gWg3vPX6gzsNk/C83GeRPCUVgDy0jpw7yNkZu2VD5zvh93o81h+cuw3Fhzyl5t+86Y7TvHa0EjyzCCi7WmsIPIy1Jzy00Ra6NUiru50rTTx50d47/HKcO2wwETw0f7y8sFIQvNxnkbzS4w855pVBu9FdGzx9yvC6TM80vFQjkzy/Zvs7BhtYPLjKKLqPa787A/6LOyiInbzooSq8728UPIFJ97wq+7q8R6v5u1tYMbwdomG6iSPKPAb3HTx3oTu7fGO8POqtk7ze/ug84wNkPMnq/DsB8iK9ogwOu6lBrDznguo8NQUxvHKcwDo28tm7yNmZPN1UurxCoYS80m7+Oy+9OzzGzTC836MdvCDNCrtaawi7dVLYPEfKuTxzRm88cCmjOyXSBbwGOpi879ZIO8dTJbtqnrO8NMI2vR1+J7xwTV087umfPFG17zsC30s8oYaZPKllZrzZGK47zss9vP21FryZywa9bbYFPVNapDt2G0e7E3SxPMUjgry5dNc895Hbu0H8z7ueN7a7OccxPFhfH7vC1B48n3owvEhQLrzu6Z+8HTutvEBSITw6Taa5g1XgPCzEqbxfLYk9OYQ3vBlm1bvPUTI8wIU7PIy1pzyFyP07gzGmO3NGb7yS3ty7O5CguyEhaLyWoF28pmxUOaZImrz+g/87mnU1vFbsgTxvo668PFmPO2KNTzy09VC8LG5YPHhL6rsvJPC7kTQuvEGCxDlhB9s6u58AvfCAd7z0t4k7kVjoOCkOkrxMjDq8iPOmPL0SnrxsMJG7OEG9vCUa+rvx4rE7cpxAPDCGqjukf6u8TEnAvNn57TweBBw7JdKFvIy1p7vIg8i7" data = [] for i, text in enumerate(input): - obj = Embedding( - embedding=[], - index=i, - object='embedding' - ) + obj = Embedding(embedding=[], index=i, object="embedding") obj.embedding = embeddings data.append(obj) @@ -61,10 +52,7 @@ class MockEmbeddingsClass: return CreateEmbeddingResponse( data=data, model=model, - object='list', + object="list", # marked: usage of embeddings should equal the number of testcase - usage=Usage( - prompt_tokens=2, - total_tokens=2 - ) - ) \ No newline at end of file + usage=Usage(prompt_tokens=2, total_tokens=2), + ) diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py b/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py index 9466f4bfb8..270a88e85f 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py @@ -10,58 +10,92 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError class MockModerationClass: - def moderation_create(self: Moderations,*, + def moderation_create( + self: Moderations, + *, input: Union[str, list[str]], model: Union[str, Literal["text-moderation-latest", "text-moderation-stable"]] | NotGiven = NOT_GIVEN, - **kwargs: Any + **kwargs: Any, ) -> ModerationCreateResponse: if isinstance(input, str): input = [input] - if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()): - raise InvokeAuthorizationError('Invalid base url') - + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + raise InvokeAuthorizationError("Invalid base url") + if len(self._client.api_key) < 18: - raise InvokeAuthorizationError('Invalid API key') + raise InvokeAuthorizationError("Invalid API key") for text in input: result = [] - if 'kill' in text: + if "kill" in text: moderation_categories = { - 'harassment': False, 'harassment/threatening': False, 'hate': False, 'hate/threatening': False, - 'self-harm': False, 'self-harm/instructions': False, 'self-harm/intent': False, 'sexual': False, - 'sexual/minors': False, 'violence': False, 'violence/graphic': False + "harassment": False, + "harassment/threatening": False, + "hate": False, + "hate/threatening": False, + "self-harm": False, + "self-harm/instructions": False, + "self-harm/intent": False, + "sexual": False, + "sexual/minors": False, + "violence": False, + "violence/graphic": False, } moderation_categories_scores = { - 'harassment': 1.0, 'harassment/threatening': 1.0, 'hate': 1.0, 'hate/threatening': 1.0, - 'self-harm': 1.0, 'self-harm/instructions': 1.0, 'self-harm/intent': 1.0, 'sexual': 1.0, - 'sexual/minors': 1.0, 'violence': 1.0, 'violence/graphic': 1.0 + "harassment": 1.0, + "harassment/threatening": 1.0, + "hate": 1.0, + "hate/threatening": 1.0, + "self-harm": 1.0, + "self-harm/instructions": 1.0, + "self-harm/intent": 1.0, + "sexual": 1.0, + "sexual/minors": 1.0, + "violence": 1.0, + "violence/graphic": 1.0, } - result.append(Moderation( - flagged=True, - categories=Categories(**moderation_categories), - category_scores=CategoryScores(**moderation_categories_scores) - )) + result.append( + Moderation( + flagged=True, + categories=Categories(**moderation_categories), + category_scores=CategoryScores(**moderation_categories_scores), + ) + ) else: moderation_categories = { - 'harassment': False, 'harassment/threatening': False, 'hate': False, 'hate/threatening': False, - 'self-harm': False, 'self-harm/instructions': False, 'self-harm/intent': False, 'sexual': False, - 'sexual/minors': False, 'violence': False, 'violence/graphic': False + "harassment": False, + "harassment/threatening": False, + "hate": False, + "hate/threatening": False, + "self-harm": False, + "self-harm/instructions": False, + "self-harm/intent": False, + "sexual": False, + "sexual/minors": False, + "violence": False, + "violence/graphic": False, } moderation_categories_scores = { - 'harassment': 0.0, 'harassment/threatening': 0.0, 'hate': 0.0, 'hate/threatening': 0.0, - 'self-harm': 0.0, 'self-harm/instructions': 0.0, 'self-harm/intent': 0.0, 'sexual': 0.0, - 'sexual/minors': 0.0, 'violence': 0.0, 'violence/graphic': 0.0 + "harassment": 0.0, + "harassment/threatening": 0.0, + "hate": 0.0, + "hate/threatening": 0.0, + "self-harm": 0.0, + "self-harm/instructions": 0.0, + "self-harm/intent": 0.0, + "sexual": 0.0, + "sexual/minors": 0.0, + "violence": 0.0, + "violence/graphic": 0.0, } - result.append(Moderation( - flagged=False, - categories=Categories(**moderation_categories), - category_scores=CategoryScores(**moderation_categories_scores) - )) + result.append( + Moderation( + flagged=False, + categories=Categories(**moderation_categories), + category_scores=CategoryScores(**moderation_categories_scores), + ) + ) - return ModerationCreateResponse( - id='shiroii kuloko', - model=model, - results=result - ) \ No newline at end of file + return ModerationCreateResponse(id="shiroii kuloko", model=model, results=result) diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_remote.py b/api/tests/integration_tests/model_runtime/__mock/openai_remote.py index 0124ac045b..cb8f249543 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_remote.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_remote.py @@ -6,17 +6,18 @@ from openai.types.model import Model class MockModelClass: """ - mock class for openai.models.Models + mock class for openai.models.Models """ + def list( self, **kwargs, ) -> list[Model]: return [ Model( - id='ft:gpt-3.5-turbo-0613:personal::8GYJLPDQ', + id="ft:gpt-3.5-turbo-0613:personal::8GYJLPDQ", created=int(time()), - object='model', - owned_by='organization:org-123', + object="model", + owned_by="organization:org-123", ) - ] \ No newline at end of file + ] diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py b/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py index 755fec4c1f..ef361e8613 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py @@ -9,7 +9,8 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError class MockSpeech2TextClass: - def speech2text_create(self: Transcriptions, + def speech2text_create( + self: Transcriptions, *, file: FileTypes, model: Union[str, Literal["whisper-1"]], @@ -17,14 +18,12 @@ class MockSpeech2TextClass: prompt: str | NotGiven = NOT_GIVEN, response_format: Literal["json", "text", "srt", "verbose_json", "vtt"] | NotGiven = NOT_GIVEN, temperature: float | NotGiven = NOT_GIVEN, - **kwargs: Any + **kwargs: Any, ) -> Transcription: - if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()): - raise InvokeAuthorizationError('Invalid base url') - + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + raise InvokeAuthorizationError("Invalid base url") + if len(self._client.api_key) < 18: - raise InvokeAuthorizationError('Invalid API key') - - return Transcription( - text='1, 2, 3, 4, 5, 6, 7, 8, 9, 10' - ) \ No newline at end of file + raise InvokeAuthorizationError("Invalid API key") + + return Transcription(text="1, 2, 3, 4, 5, 6, 7, 8, 9, 10") diff --git a/api/tests/integration_tests/model_runtime/__mock/xinference.py b/api/tests/integration_tests/model_runtime/__mock/xinference.py index 7cb0a1318e..777737187e 100644 --- a/api/tests/integration_tests/model_runtime/__mock/xinference.py +++ b/api/tests/integration_tests/model_runtime/__mock/xinference.py @@ -19,40 +19,43 @@ from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage class MockXinferenceClass: - def get_chat_model(self: Client, model_uid: str) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]: - if not re.match(r'https?:\/\/[^\s\/$.?#].[^\s]*$', self.base_url): - raise RuntimeError('404 Not Found') - - if 'generate' == model_uid: + def get_chat_model( + self: Client, model_uid: str + ) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]: + if not re.match(r"https?:\/\/[^\s\/$.?#].[^\s]*$", self.base_url): + raise RuntimeError("404 Not Found") + + if "generate" == model_uid: return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={}) - if 'chat' == model_uid: + if "chat" == model_uid: return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={}) - if 'embedding' == model_uid: + if "embedding" == model_uid: return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={}) - if 'rerank' == model_uid: + if "rerank" == model_uid: return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={}) - raise RuntimeError('404 Not Found') - + raise RuntimeError("404 Not Found") + def get(self: Session, url: str, **kwargs): response = Response() - if 'v1/models/' in url: + if "v1/models/" in url: # get model uid - model_uid = url.split('/')[-1] or '' - if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \ - model_uid not in ['generate', 'chat', 'embedding', 'rerank']: + model_uid = url.split("/")[-1] or "" + if not re.match( + r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", model_uid + ) and model_uid not in ["generate", "chat", "embedding", "rerank"]: response.status_code = 404 - response._content = b'{}' + response._content = b"{}" return response # check if url is valid - if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url): + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", url): response.status_code = 404 - response._content = b'{}' + response._content = b"{}" return response - - if model_uid in ['generate', 'chat']: + + if model_uid in ["generate", "chat"]: response.status_code = 200 - response._content = b'''{ + response._content = b"""{ "model_type": "LLM", "address": "127.0.0.1:43877", "accelerators": [ @@ -75,12 +78,12 @@ class MockXinferenceClass: "revision": null, "context_length": 2048, "replica": 1 - }''' + }""" return response - - elif model_uid == 'embedding': + + elif model_uid == "embedding": response.status_code = 200 - response._content = b'''{ + response._content = b"""{ "model_type": "embedding", "address": "127.0.0.1:43877", "accelerators": [ @@ -93,51 +96,48 @@ class MockXinferenceClass: ], "revision": null, "max_tokens": 512 - }''' + }""" return response - - elif 'v1/cluster/auth' in url: + + elif "v1/cluster/auth" in url: response.status_code = 200 - response._content = b'''{ + response._content = b"""{ "auth": true - }''' + }""" return response - + def _check_cluster_authenticated(self): self._cluster_authed = True - - def rerank(self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool) -> dict: + + def rerank( + self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool + ) -> dict: # check if self._model_uid is a valid uuid - if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \ - self._model_uid != 'rerank': - raise RuntimeError('404 Not Found') - - if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._base_url): - raise RuntimeError('404 Not Found') + if ( + not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid) + and self._model_uid != "rerank" + ): + raise RuntimeError("404 Not Found") + + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._base_url): + raise RuntimeError("404 Not Found") if top_n is None: top_n = 1 return { - 'results': [ - { - 'index': i, - 'document': doc, - 'relevance_score': 0.9 - } - for i, doc in enumerate(documents[:top_n]) + "results": [ + {"index": i, "document": doc, "relevance_score": 0.9} for i, doc in enumerate(documents[:top_n]) ] } - - def create_embedding( - self: RESTfulGenerateModelHandle, - input: Union[str, list[str]], - **kwargs - ) -> dict: + + def create_embedding(self: RESTfulGenerateModelHandle, input: Union[str, list[str]], **kwargs) -> dict: # check if self._model_uid is a valid uuid - if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \ - self._model_uid != 'embedding': - raise RuntimeError('404 Not Found') + if ( + not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid) + and self._model_uid != "embedding" + ): + raise RuntimeError("404 Not Found") if isinstance(input, str): input = [input] @@ -147,32 +147,27 @@ class MockXinferenceClass: object="list", model=self._model_uid, data=[ - EmbeddingData( - index=i, - object="embedding", - embedding=[1919.810 for _ in range(768)] - ) + EmbeddingData(index=i, object="embedding", embedding=[1919.810 for _ in range(768)]) for i in range(ipt_len) ], - usage=EmbeddingUsage( - prompt_tokens=ipt_len, - total_tokens=ipt_len - ) + usage=EmbeddingUsage(prompt_tokens=ipt_len, total_tokens=ipt_len), ) return embedding -MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + @pytest.fixture def setup_xinference_mock(request, monkeypatch: MonkeyPatch): if MOCK: - monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model) - monkeypatch.setattr(Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated) - monkeypatch.setattr(Session, 'get', MockXinferenceClass.get) - monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding) - monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank) + monkeypatch.setattr(Client, "get_model", MockXinferenceClass.get_chat_model) + monkeypatch.setattr(Client, "_check_cluster_authenticated", MockXinferenceClass._check_cluster_authenticated) + monkeypatch.setattr(Session, "get", MockXinferenceClass.get) + monkeypatch.setattr(RESTfulEmbeddingModelHandle, "create_embedding", MockXinferenceClass.create_embedding) + monkeypatch.setattr(RESTfulRerankModelHandle, "rerank", MockXinferenceClass.rerank) yield if MOCK: - monkeypatch.undo() \ No newline at end of file + monkeypatch.undo() diff --git a/api/tests/integration_tests/model_runtime/anthropic/test_llm.py b/api/tests/integration_tests/model_runtime/anthropic/test_llm.py index 0d54d97daa..8f7e9ec487 100644 --- a/api/tests/integration_tests/model_runtime/anthropic/test_llm.py +++ b/api/tests/integration_tests/model_runtime/anthropic/test_llm.py @@ -10,79 +10,60 @@ from core.model_runtime.model_providers.anthropic.llm.llm import AnthropicLargeL from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock -@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) def test_validate_credentials(setup_anthropic_mock): model = AnthropicLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='claude-instant-1.2', - credentials={ - 'anthropic_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="claude-instant-1.2", credentials={"anthropic_api_key": "invalid_key"}) model.validate_credentials( - model='claude-instant-1.2', - credentials={ - 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') - } + model="claude-instant-1.2", credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")} ) -@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) def test_invoke_model(setup_anthropic_mock): model = AnthropicLargeLanguageModel() response = model.invoke( - model='claude-instant-1.2', + model="claude-instant-1.2", credentials={ - 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY'), - 'anthropic_api_url': os.environ.get('ANTHROPIC_API_URL') + "anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY"), + "anthropic_api_url": os.environ.get("ANTHROPIC_API_URL"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'top_p': 1.0, - 'max_tokens': 10 - }, - stop=['How'], + model_parameters={"temperature": 0.0, "top_p": 1.0, "max_tokens": 10}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 -@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) def test_invoke_stream_model(setup_anthropic_mock): model = AnthropicLargeLanguageModel() response = model.invoke( - model='claude-instant-1.2', - credentials={ - 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') - }, + model="claude-instant-1.2", + credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -98,18 +79,14 @@ def test_get_num_tokens(): model = AnthropicLargeLanguageModel() num_tokens = model.get_num_tokens( - model='claude-instant-1.2', - credentials={ - 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') - }, + model="claude-instant-1.2", + credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 18 diff --git a/api/tests/integration_tests/model_runtime/anthropic/test_provider.py b/api/tests/integration_tests/model_runtime/anthropic/test_provider.py index 7eaa40dfdd..6f1e50f431 100644 --- a/api/tests/integration_tests/model_runtime/anthropic/test_provider.py +++ b/api/tests/integration_tests/model_runtime/anthropic/test_provider.py @@ -7,17 +7,11 @@ from core.model_runtime.model_providers.anthropic.anthropic import AnthropicProv from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock -@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) def test_validate_provider_credentials(setup_anthropic_mock): provider = AnthropicProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/azure_ai_studio/__init__.py b/api/tests/integration_tests/model_runtime/azure_ai_studio/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py new file mode 100644 index 0000000000..8655b43d8f --- /dev/null +++ b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py @@ -0,0 +1,113 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.azure_ai_studio.llm.llm import AzureAIStudioLargeLanguageModel +from tests.integration_tests.model_runtime.__mock.azure_ai_studio import setup_azure_ai_studio_mock + + +@pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True) +def test_validate_credentials(setup_azure_ai_studio_mock): + model = AzureAIStudioLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="gpt-35-turbo", + credentials={"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")}, + ) + + model.validate_credentials( + model="gpt-35-turbo", + credentials={ + "api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"), + "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"), + }, + ) + + +@pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True) +def test_invoke_model(setup_azure_ai_studio_mock): + model = AzureAIStudioLargeLanguageModel() + + result = model.invoke( + model="gpt-35-turbo", + credentials={ + "api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"), + "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"), + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + stream=False, + user="abc-123", + ) + + assert isinstance(result, LLMResult) + assert len(result.message.content) > 0 + + +@pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True) +def test_invoke_stream_model(setup_azure_ai_studio_mock): + model = AzureAIStudioLargeLanguageModel() + + result = model.invoke( + model="gpt-35-turbo", + credentials={ + "api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"), + "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"), + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + stream=True, + user="abc-123", + ) + + assert isinstance(result, Generator) + + for chunk in result: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + if chunk.delta.finish_reason is not None: + assert chunk.delta.usage is not None + assert chunk.delta.usage.completion_tokens > 0 + + +def test_get_num_tokens(): + model = AzureAIStudioLargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="gpt-35-turbo", + credentials={ + "api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"), + "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"), + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/azure_ai_studio/test_provider.py b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_provider.py new file mode 100644 index 0000000000..8afe38b09b --- /dev/null +++ b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_provider.py @@ -0,0 +1,17 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.azure_ai_studio.azure_ai_studio import AzureAIStudioProvider + + +def test_validate_provider_credentials(): + provider = AzureAIStudioProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={}) + + provider.validate_provider_credentials( + credentials={"api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"), "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")} + ) diff --git a/api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py new file mode 100644 index 0000000000..466facc5ff --- /dev/null +++ b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py @@ -0,0 +1,50 @@ +import os + +import pytest + +from core.model_runtime.entities.rerank_entities import RerankResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureAIStudioRerankModel + + +def test_validate_credentials(): + model = AzureAIStudioRerankModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="azure-ai-studio-rerank-v1", + credentials={"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")}, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + score_threshold=0.8, + ) + + +def test_invoke_model(): + model = AzureAIStudioRerankModel() + + result = model.invoke( + model="azure-ai-studio-rerank-v1", + credentials={ + "api_key": os.getenv("AZURE_AI_STUDIO_JWT_TOKEN"), + "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"), + }, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + score_threshold=0.8, + ) + + assert isinstance(result, RerankResult) + assert len(result.docs) == 1 + assert result.docs[0].index == 1 + assert result.docs[0].score >= 0.8 diff --git a/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py b/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py index 6afec540ad..8f50ebf7a6 100644 --- a/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py @@ -17,101 +17,90 @@ from core.model_runtime.model_providers.azure_openai.llm.llm import AzureOpenAIL from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_validate_credentials_for_chat_model(setup_openai_mock): model = AzureOpenAILargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='gpt35', + model="gpt35", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': 'invalid_key', - 'base_model_name': 'gpt-35-turbo' - } + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": "invalid_key", + "base_model_name": "gpt-35-turbo", + }, ) model.validate_credentials( - model='gpt35', + model="gpt35", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-35-turbo' - } + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo", + }, ) -@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) def test_validate_credentials_for_completion_model(setup_openai_mock): model = AzureOpenAILargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='gpt-35-turbo-instruct', + model="gpt-35-turbo-instruct", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': 'invalid_key', - 'base_model_name': 'gpt-35-turbo-instruct' - } + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": "invalid_key", + "base_model_name": "gpt-35-turbo-instruct", + }, ) model.validate_credentials( - model='gpt-35-turbo-instruct', + model="gpt-35-turbo-instruct", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-35-turbo-instruct' - } + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo-instruct", + }, ) -@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) def test_invoke_completion_model(setup_openai_mock): model = AzureOpenAILargeLanguageModel() result = model.invoke( - model='gpt-35-turbo-instruct', + model="gpt-35-turbo-instruct", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-35-turbo-instruct' - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 1 + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo-instruct", }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 1}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) def test_invoke_stream_completion_model(setup_openai_mock): model = AzureOpenAILargeLanguageModel() result = model.invoke( - model='gpt-35-turbo-instruct', + model="gpt-35-turbo-instruct", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-35-turbo-instruct' - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo-instruct", }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(result, Generator) @@ -122,66 +111,60 @@ def test_invoke_stream_completion_model(setup_openai_mock): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model(setup_openai_mock): model = AzureOpenAILargeLanguageModel() result = model.invoke( - model='gpt35', + model="gpt35", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-35-turbo' + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.0, - 'top_p': 1.0, - 'presence_penalty': 0.0, - 'frequency_penalty': 0.0, - 'max_tokens': 10 + "temperature": 0.0, + "top_p": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_stream_chat_model(setup_openai_mock): model = AzureOpenAILargeLanguageModel() result = model.invoke( - model='gpt35', + model="gpt35", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-35-turbo' + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(result, Generator) @@ -194,109 +177,87 @@ def test_invoke_stream_chat_model(setup_openai_mock): assert chunk.delta.usage is not None assert chunk.delta.usage.completion_tokens > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model_with_vision(setup_openai_mock): model = AzureOpenAILargeLanguageModel() result = model.invoke( - model='gpt-4v', + model="gpt-4v", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-4-vision-preview' + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-4-vision-preview", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( content=[ TextPromptMessageContent( - data='Hello World!', + data="Hello World!", ), ImagePromptMessageContent( - data='' - ) + data="" + ), ] - ) + ), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model_with_tools(setup_openai_mock): model = AzureOpenAILargeLanguageModel() result = model.invoke( - model='gpt-35-turbo', + model="gpt-35-turbo", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-35-turbo' + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( content="what's the weather today in London?", - ) + ), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, tools=[ PromptMessageTool( - name='get_weather', - description='Determine weather in my location', + name="get_weather", + description="Determine weather in my location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ), PromptMessageTool( - name='get_stock_price', - description='Get the current stock price', + name="get_stock_price", + description="Get the current stock price", parameters={ "type": "object", - "properties": { - "symbol": { - "type": "string", - "description": "The stock symbol" - } - }, - "required": [ - "symbol" - ] - } - ) + "properties": {"symbol": {"type": "string", "description": "The stock symbol"}}, + "required": ["symbol"], + }, + ), ], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) @@ -308,32 +269,22 @@ def test_get_num_tokens(): model = AzureOpenAILargeLanguageModel() num_tokens = model.get_num_tokens( - model='gpt-35-turbo-instruct', - credentials={ - 'base_model_name': 'gpt-35-turbo-instruct' - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ] + model="gpt-35-turbo-instruct", + credentials={"base_model_name": "gpt-35-turbo-instruct"}, + prompt_messages=[UserPromptMessage(content="Hello World!")], ) assert num_tokens == 3 num_tokens = model.get_num_tokens( - model='gpt35', - credentials={ - 'base_model_name': 'gpt-35-turbo' - }, + model="gpt35", + credentials={"base_model_name": "gpt-35-turbo"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py index 8b838eb8fc..a1ae2b2e5b 100644 --- a/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py @@ -8,45 +8,43 @@ from core.model_runtime.model_providers.azure_openai.text_embedding.text_embeddi from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) def test_validate_credentials(setup_openai_mock): model = AzureOpenAITextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='embedding', + model="embedding", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': 'invalid_key', - 'base_model_name': 'text-embedding-ada-002' - } + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": "invalid_key", + "base_model_name": "text-embedding-ada-002", + }, ) model.validate_credentials( - model='embedding', + model="embedding", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'text-embedding-ada-002' - } + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "text-embedding-ada-002", + }, ) -@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) def test_invoke_model(setup_openai_mock): model = AzureOpenAITextEmbeddingModel() result = model.invoke( - model='embedding', + model="embedding", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'text-embedding-ada-002' + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "text-embedding-ada-002", }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -58,14 +56,7 @@ def test_get_num_tokens(): model = AzureOpenAITextEmbeddingModel() num_tokens = model.get_num_tokens( - model='embedding', - credentials={ - 'base_model_name': 'text-embedding-ada-002' - }, - texts=[ - "hello", - "world" - ] + model="embedding", credentials={"base_model_name": "text-embedding-ada-002"}, texts=["hello", "world"] ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/baichuan/test_llm.py b/api/tests/integration_tests/model_runtime/baichuan/test_llm.py index 1cae9a6dd0..ad58610287 100644 --- a/api/tests/integration_tests/model_runtime/baichuan/test_llm.py +++ b/api/tests/integration_tests/model_runtime/baichuan/test_llm.py @@ -17,111 +17,99 @@ def test_predefined_models(): assert len(model_schemas) >= 1 assert isinstance(model_schemas[0], AIModelEntity) + def test_validate_credentials_for_chat_model(): sleep(3) model = BaichuanLarguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='baichuan2-turbo', - credentials={ - 'api_key': 'invalid_key', - 'secret_key': 'invalid_key' - } + model="baichuan2-turbo", credentials={"api_key": "invalid_key", "secret_key": "invalid_key"} ) model.validate_credentials( - model='baichuan2-turbo', + model="baichuan2-turbo", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), - 'secret_key': os.environ.get('BAICHUAN_SECRET_KEY') - } + "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), + }, ) + def test_invoke_model(): sleep(3) model = BaichuanLarguageModel() response = model.invoke( - model='baichuan2-turbo', + model="baichuan2-turbo", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), - 'secret_key': os.environ.get('BAICHUAN_SECRET_KEY') + "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_model_with_system_message(): sleep(3) model = BaichuanLarguageModel() response = model.invoke( - model='baichuan2-turbo', + model="baichuan2-turbo", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), - 'secret_key': os.environ.get('BAICHUAN_SECRET_KEY') + "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), }, prompt_messages=[ - SystemPromptMessage( - content='请记住你是Kasumi。' - ), - UserPromptMessage( - content='现在告诉我你是谁?' - ) + SystemPromptMessage(content="请记住你是Kasumi。"), + UserPromptMessage(content="现在告诉我你是谁?"), ], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_stream_model(): sleep(3) model = BaichuanLarguageModel() response = model.invoke( - model='baichuan2-turbo', + model="baichuan2-turbo", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), - 'secret_key': os.environ.get('BAICHUAN_SECRET_KEY') + "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -131,34 +119,31 @@ def test_invoke_stream_model(): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_invoke_with_search(): sleep(3) model = BaichuanLarguageModel() response = model.invoke( - model='baichuan2-turbo', + model="baichuan2-turbo", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), - 'secret_key': os.environ.get('BAICHUAN_SECRET_KEY') + "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), }, - prompt_messages=[ - UserPromptMessage( - content='北京今天的天气怎么样' - ) - ], + prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, - 'with_search_enhance': True, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, + "with_search_enhance": True, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) - total_message = '' + total_message = "" for chunk in response: assert isinstance(chunk, LLMResultChunk) assert isinstance(chunk.delta, LLMResultChunkDelta) @@ -166,25 +151,22 @@ def test_invoke_with_search(): assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True total_message += chunk.delta.message.content - assert '不' not in total_message + assert "不" not in total_message + def test_get_num_tokens(): sleep(3) model = BaichuanLarguageModel() response = model.get_num_tokens( - model='baichuan2-turbo', + model="baichuan2-turbo", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), - 'secret_key': os.environ.get('BAICHUAN_SECRET_KEY') + "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - tools=[] + prompt_messages=[UserPromptMessage(content="Hello World!")], + tools=[], ) assert isinstance(response, int) - assert response == 9 \ No newline at end of file + assert response == 9 diff --git a/api/tests/integration_tests/model_runtime/baichuan/test_provider.py b/api/tests/integration_tests/model_runtime/baichuan/test_provider.py index 87b3d9a609..4036edfb7a 100644 --- a/api/tests/integration_tests/model_runtime/baichuan/test_provider.py +++ b/api/tests/integration_tests/model_runtime/baichuan/test_provider.py @@ -10,14 +10,6 @@ def test_validate_provider_credentials(): provider = BaichuanProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={ - 'api_key': 'hahahaha' - } - ) + provider.validate_provider_credentials(credentials={"api_key": "hahahaha"}) - provider.validate_provider_credentials( - credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("BAICHUAN_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py b/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py index 1210ebc53d..cbc63f3978 100644 --- a/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py @@ -11,18 +11,10 @@ def test_validate_credentials(): model = BaichuanTextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='baichuan-text-embedding', - credentials={ - 'api_key': 'invalid_key' - } - ) + model.validate_credentials(model="baichuan-text-embedding", credentials={"api_key": "invalid_key"}) model.validate_credentials( - model='baichuan-text-embedding', - credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY') - } + model="baichuan-text-embedding", credentials={"api_key": os.environ.get("BAICHUAN_API_KEY")} ) @@ -30,44 +22,40 @@ def test_invoke_model(): model = BaichuanTextEmbeddingModel() result = model.invoke( - model='baichuan-text-embedding', + model="baichuan-text-embedding", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), + "api_key": os.environ.get("BAICHUAN_API_KEY"), }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) assert len(result.embeddings) == 2 assert result.usage.total_tokens == 6 + def test_get_num_tokens(): model = BaichuanTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='baichuan-text-embedding', + model="baichuan-text-embedding", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), + "api_key": os.environ.get("BAICHUAN_API_KEY"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 + def test_max_chunks(): model = BaichuanTextEmbeddingModel() result = model.invoke( - model='baichuan-text-embedding', + model="baichuan-text-embedding", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), + "api_key": os.environ.get("BAICHUAN_API_KEY"), }, texts=[ "hello", @@ -92,8 +80,8 @@ def test_max_chunks(): "world", "hello", "world", - ] + ], ) assert isinstance(result, TextEmbeddingResult) - assert len(result.embeddings) == 22 \ No newline at end of file + assert len(result.embeddings) == 22 diff --git a/api/tests/integration_tests/model_runtime/bedrock/test_llm.py b/api/tests/integration_tests/model_runtime/bedrock/test_llm.py index 20dc11151a..c19ec35a6e 100644 --- a/api/tests/integration_tests/model_runtime/bedrock/test_llm.py +++ b/api/tests/integration_tests/model_runtime/bedrock/test_llm.py @@ -13,77 +13,63 @@ def test_validate_credentials(): model = BedrockLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='meta.llama2-13b-chat-v1', - credentials={ - 'anthropic_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="meta.llama2-13b-chat-v1", credentials={"anthropic_api_key": "invalid_key"}) model.validate_credentials( - model='meta.llama2-13b-chat-v1', + model="meta.llama2-13b-chat-v1", credentials={ "aws_region": os.getenv("AWS_REGION"), "aws_access_key": os.getenv("AWS_ACCESS_KEY"), - "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") - } + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), + }, ) + def test_invoke_model(): model = BedrockLargeLanguageModel() response = model.invoke( - model='meta.llama2-13b-chat-v1', + model="meta.llama2-13b-chat-v1", credentials={ "aws_region": os.getenv("AWS_REGION"), "aws_access_key": os.getenv("AWS_ACCESS_KEY"), - "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'top_p': 1.0, - 'max_tokens_to_sample': 10 - }, - stop=['How'], + model_parameters={"temperature": 0.0, "top_p": 1.0, "max_tokens_to_sample": 10}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 + def test_invoke_stream_model(): model = BedrockLargeLanguageModel() response = model.invoke( - model='meta.llama2-13b-chat-v1', + model="meta.llama2-13b-chat-v1", credentials={ "aws_region": os.getenv("AWS_REGION"), "aws_access_key": os.getenv("AWS_ACCESS_KEY"), - "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens_to_sample': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens_to_sample": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -100,20 +86,18 @@ def test_get_num_tokens(): model = BedrockLargeLanguageModel() num_tokens = model.get_num_tokens( - model='meta.llama2-13b-chat-v1', - credentials = { + model="meta.llama2-13b-chat-v1", + credentials={ "aws_region": os.getenv("AWS_REGION"), "aws_access_key": os.getenv("AWS_ACCESS_KEY"), - "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), }, messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 18 diff --git a/api/tests/integration_tests/model_runtime/bedrock/test_provider.py b/api/tests/integration_tests/model_runtime/bedrock/test_provider.py index e53d4c1db2..080727829e 100644 --- a/api/tests/integration_tests/model_runtime/bedrock/test_provider.py +++ b/api/tests/integration_tests/model_runtime/bedrock/test_provider.py @@ -10,14 +10,12 @@ def test_validate_provider_credentials(): provider = BedrockProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) provider.validate_provider_credentials( credentials={ "aws_region": os.getenv("AWS_REGION"), "aws_access_key": os.getenv("AWS_ACCESS_KEY"), - "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), } ) diff --git a/api/tests/integration_tests/model_runtime/chatglm/test_llm.py b/api/tests/integration_tests/model_runtime/chatglm/test_llm.py index e32f01a315..418e88874d 100644 --- a/api/tests/integration_tests/model_runtime/chatglm/test_llm.py +++ b/api/tests/integration_tests/model_runtime/chatglm/test_llm.py @@ -23,79 +23,64 @@ def test_predefined_models(): assert len(model_schemas) >= 1 assert isinstance(model_schemas[0], AIModelEntity) -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_validate_credentials_for_chat_model(setup_openai_mock): model = ChatGLMLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='chatglm2-6b', - credentials={ - 'api_base': 'invalid_key' - } - ) + model.validate_credentials(model="chatglm2-6b", credentials={"api_base": "invalid_key"}) - model.validate_credentials( - model='chatglm2-6b', - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - } - ) + model.validate_credentials(model="chatglm2-6b", credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}) -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_model(setup_openai_mock): model = ChatGLMLargeLanguageModel() response = model.invoke( - model='chatglm2-6b', - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - }, + model="chatglm2-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_stream_model(setup_openai_mock): model = ChatGLMLargeLanguageModel() response = model.invoke( - model='chatglm2-6b', - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - }, + model="chatglm2-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -105,56 +90,45 @@ def test_invoke_stream_model(setup_openai_mock): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_stream_model_with_functions(setup_openai_mock): model = ChatGLMLargeLanguageModel() response = model.invoke( - model='chatglm3-6b', - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - }, + model="chatglm3-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, prompt_messages=[ SystemPromptMessage( - content='你是一个天气机器人,你不知道今天的天气怎么样,你需要通过调用一个函数来获取天气信息。' + content="你是一个天气机器人,你不知道今天的天气怎么样,你需要通过调用一个函数来获取天气信息。" ), - UserPromptMessage( - content='波士顿天气如何?' - ) + UserPromptMessage(content="波士顿天气如何?"), ], model_parameters={ - 'temperature': 0, - 'top_p': 1.0, + "temperature": 0, + "top_p": 1.0, }, - stop=['you'], - user='abc-123', + stop=["you"], + user="abc-123", stream=True, tools=[ PromptMessageTool( - name='get_current_weather', - description='Get the current weather in a given location', + name="get_current_weather", + description="Get the current weather in a given location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ) - ] + ], ) assert isinstance(response, Generator) - + call: LLMResultChunk = None chunks = [] @@ -170,122 +144,87 @@ def test_invoke_stream_model_with_functions(setup_openai_mock): break assert call is not None - assert call.delta.message.tool_calls[0].function.name == 'get_current_weather' + assert call.delta.message.tool_calls[0].function.name == "get_current_weather" -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_model_with_functions(setup_openai_mock): model = ChatGLMLargeLanguageModel() response = model.invoke( - model='chatglm3-6b', - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - }, - prompt_messages=[ - UserPromptMessage( - content='What is the weather like in San Francisco?' - ) - ], + model="chatglm3-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, + prompt_messages=[UserPromptMessage(content="What is the weather like in San Francisco?")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], - user='abc-123', + stop=["you"], + user="abc-123", stream=False, tools=[ PromptMessageTool( - name='get_current_weather', - description='Get the current weather in a given location', + name="get_current_weather", + description="Get the current weather in a given location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ) - ] + ], ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 - assert response.message.tool_calls[0].function.name == 'get_current_weather' + assert response.message.tool_calls[0].function.name == "get_current_weather" def test_get_num_tokens(): model = ChatGLMLargeLanguageModel() num_tokens = model.get_num_tokens( - model='chatglm2-6b', - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - }, + model="chatglm2-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], tools=[ PromptMessageTool( - name='get_current_weather', - description='Get the current weather in a given location', + name="get_current_weather", + description="Get the current weather in a given location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ) - ] + ], ) assert isinstance(num_tokens, int) assert num_tokens == 77 num_tokens = model.get_num_tokens( - model='chatglm2-6b', - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - }, + model="chatglm2-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], ) assert isinstance(num_tokens, int) - assert num_tokens == 21 \ No newline at end of file + assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/chatglm/test_provider.py b/api/tests/integration_tests/model_runtime/chatglm/test_provider.py index e9c5c4da75..7907805d07 100644 --- a/api/tests/integration_tests/model_runtime/chatglm/test_provider.py +++ b/api/tests/integration_tests/model_runtime/chatglm/test_provider.py @@ -7,19 +7,11 @@ from core.model_runtime.model_providers.chatglm.chatglm import ChatGLMProvider from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_validate_provider_credentials(setup_openai_mock): provider = ChatGLMProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={ - 'api_base': 'hahahaha' - } - ) + provider.validate_provider_credentials(credentials={"api_base": "hahahaha"}) - provider.validate_provider_credentials( - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - } - ) + provider.validate_provider_credentials(credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}) diff --git a/api/tests/integration_tests/model_runtime/cohere/test_llm.py b/api/tests/integration_tests/model_runtime/cohere/test_llm.py index 5ce4f8ecfe..b7f707e935 100644 --- a/api/tests/integration_tests/model_runtime/cohere/test_llm.py +++ b/api/tests/integration_tests/model_runtime/cohere/test_llm.py @@ -13,87 +13,49 @@ def test_validate_credentials_for_chat_model(): model = CohereLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='command-light-chat', - credentials={ - 'api_key': 'invalid_key' - } - ) + model.validate_credentials(model="command-light-chat", credentials={"api_key": "invalid_key"}) - model.validate_credentials( - model='command-light-chat', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - } - ) + model.validate_credentials(model="command-light-chat", credentials={"api_key": os.environ.get("COHERE_API_KEY")}) def test_validate_credentials_for_completion_model(): model = CohereLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='command-light', - credentials={ - 'api_key': 'invalid_key' - } - ) + model.validate_credentials(model="command-light", credentials={"api_key": "invalid_key"}) - model.validate_credentials( - model='command-light', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - } - ) + model.validate_credentials(model="command-light", credentials={"api_key": os.environ.get("COHERE_API_KEY")}) def test_invoke_completion_model(): model = CohereLargeLanguageModel() - credentials = { - 'api_key': os.environ.get('COHERE_API_KEY') - } + credentials = {"api_key": os.environ.get("COHERE_API_KEY")} result = model.invoke( - model='command-light', + model="command-light", credentials=credentials, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 1 - }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 1}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 - assert model._num_tokens_from_string('command-light', credentials, result.message.content) == 1 + assert model._num_tokens_from_string("command-light", credentials, result.message.content) == 1 def test_invoke_stream_completion_model(): model = CohereLargeLanguageModel() result = model.invoke( - model='command-light', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model="command-light", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(result, Generator) @@ -109,28 +71,24 @@ def test_invoke_chat_model(): model = CohereLargeLanguageModel() result = model.invoke( - model='command-light-chat', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, + model="command-light-chat", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.0, - 'p': 0.99, - 'presence_penalty': 0.0, - 'frequency_penalty': 0.0, - 'max_tokens': 10 + "temperature": 0.0, + "p": 0.99, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) @@ -141,24 +99,17 @@ def test_invoke_stream_chat_model(): model = CohereLargeLanguageModel() result = model.invoke( - model='command-light-chat', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, + model="command-light-chat", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(result, Generator) @@ -177,32 +128,22 @@ def test_get_num_tokens(): model = CohereLargeLanguageModel() num_tokens = model.get_num_tokens( - model='command-light', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ] + model="command-light", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], ) assert num_tokens == 3 num_tokens = model.get_num_tokens( - model='command-light-chat', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, + model="command-light-chat", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 15 @@ -213,25 +154,17 @@ def test_fine_tuned_model(): # test invoke result = model.invoke( - model='85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY'), - 'mode': 'completion' - }, + model="85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft", + credentials={"api_key": os.environ.get("COHERE_API_KEY"), "mode": "completion"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) @@ -242,25 +175,17 @@ def test_fine_tuned_chat_model(): # test invoke result = model.invoke( - model='94f2d55a-4c79-4c00-bde4-23962e74b170-ft', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY'), - 'mode': 'chat' - }, + model="94f2d55a-4c79-4c00-bde4-23962e74b170-ft", + credentials={"api_key": os.environ.get("COHERE_API_KEY"), "mode": "chat"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) diff --git a/api/tests/integration_tests/model_runtime/cohere/test_provider.py b/api/tests/integration_tests/model_runtime/cohere/test_provider.py index a8f56b6194..fb7e6d3498 100644 --- a/api/tests/integration_tests/model_runtime/cohere/test_provider.py +++ b/api/tests/integration_tests/model_runtime/cohere/test_provider.py @@ -10,12 +10,6 @@ def test_validate_provider_credentials(): provider = CohereProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("COHERE_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/cohere/test_rerank.py b/api/tests/integration_tests/model_runtime/cohere/test_rerank.py index 415c5fbfda..a1b6922128 100644 --- a/api/tests/integration_tests/model_runtime/cohere/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/cohere/test_rerank.py @@ -11,29 +11,17 @@ def test_validate_credentials(): model = CohereRerankModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='rerank-english-v2.0', - credentials={ - 'api_key': 'invalid_key' - } - ) + model.validate_credentials(model="rerank-english-v2.0", credentials={"api_key": "invalid_key"}) - model.validate_credentials( - model='rerank-english-v2.0', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - } - ) + model.validate_credentials(model="rerank-english-v2.0", credentials={"api_key": os.environ.get("COHERE_API_KEY")}) def test_invoke_model(): model = CohereRerankModel() result = model.invoke( - model='rerank-english-v2.0', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, + model="rerank-english-v2.0", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, query="What is the capital of the United States?", docs=[ "Carson City is the capital city of the American state of Nevada. At the 2010 United States " @@ -41,9 +29,9 @@ def test_invoke_model(): "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) " "is the capital of the United States. It is a federal district. The President of the USA and many major " "national government offices are in the territory. This makes it the political center of the United " - "States of America." + "States of America.", ], - score_threshold=0.8 + score_threshold=0.8, ) assert isinstance(result, RerankResult) diff --git a/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py b/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py index 5017ba47e1..ae26d36635 100644 --- a/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py @@ -11,18 +11,10 @@ def test_validate_credentials(): model = CohereTextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='embed-multilingual-v3.0', - credentials={ - 'api_key': 'invalid_key' - } - ) + model.validate_credentials(model="embed-multilingual-v3.0", credentials={"api_key": "invalid_key"}) model.validate_credentials( - model='embed-multilingual-v3.0', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - } + model="embed-multilingual-v3.0", credentials={"api_key": os.environ.get("COHERE_API_KEY")} ) @@ -30,17 +22,10 @@ def test_invoke_model(): model = CohereTextEmbeddingModel() result = model.invoke( - model='embed-multilingual-v3.0', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, - texts=[ - "hello", - "world", - " ".join(["long_text"] * 100), - " ".join(["another_long_text"] * 100) - ], - user="abc-123" + model="embed-multilingual-v3.0", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, + texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -52,14 +37,9 @@ def test_get_num_tokens(): model = CohereTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='embed-multilingual-v3.0', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, - texts=[ - "hello", - "world" - ] + model="embed-multilingual-v3.0", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, + texts=["hello", "world"], ) assert num_tokens == 3 diff --git a/api/tests/integration_tests/model_runtime/google/test_llm.py b/api/tests/integration_tests/model_runtime/google/test_llm.py index 00d907d19e..4d9d490a87 100644 --- a/api/tests/integration_tests/model_runtime/google/test_llm.py +++ b/api/tests/integration_tests/model_runtime/google/test_llm.py @@ -16,103 +16,73 @@ from core.model_runtime.model_providers.google.llm.llm import GoogleLargeLanguag from tests.integration_tests.model_runtime.__mock.google import setup_google_mock -@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) def test_validate_credentials(setup_google_mock): model = GoogleLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='gemini-pro', - credentials={ - 'google_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="gemini-pro", credentials={"google_api_key": "invalid_key"}) - model.validate_credentials( - model='gemini-pro', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - } - ) + model.validate_credentials(model="gemini-pro", credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}) -@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) def test_invoke_model(setup_google_mock): model = GoogleLargeLanguageModel() response = model.invoke( - model='gemini-pro', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - }, + model="gemini-pro", + credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', - ), - UserPromptMessage( - content='Give me your worst dad joke or i will unplug you' + content="You are a helpful AI assistant.", ), + UserPromptMessage(content="Give me your worst dad joke or i will unplug you"), AssistantPromptMessage( - content='Why did the scarecrow win an award? Because he was outstanding in his field!' + content="Why did the scarecrow win an award? Because he was outstanding in his field!" ), UserPromptMessage( content=[ - TextPromptMessageContent( - data="ok something snarkier pls" - ), - TextPromptMessageContent( - data="i may still unplug you" - )] - ) + TextPromptMessageContent(data="ok something snarkier pls"), + TextPromptMessageContent(data="i may still unplug you"), + ] + ), ], - model_parameters={ - 'temperature': 0.5, - 'top_p': 1.0, - 'max_tokens_to_sample': 2048 - }, - stop=['How'], + model_parameters={"temperature": 0.5, "top_p": 1.0, "max_tokens_to_sample": 2048}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 -@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) def test_invoke_stream_model(setup_google_mock): model = GoogleLargeLanguageModel() response = model.invoke( - model='gemini-pro', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - }, + model="gemini-pro", + credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', - ), - UserPromptMessage( - content='Give me your worst dad joke or i will unplug you' + content="You are a helpful AI assistant.", ), + UserPromptMessage(content="Give me your worst dad joke or i will unplug you"), AssistantPromptMessage( - content='Why did the scarecrow win an award? Because he was outstanding in his field!' + content="Why did the scarecrow win an award? Because he was outstanding in his field!" ), UserPromptMessage( content=[ - TextPromptMessageContent( - data="ok something snarkier pls" - ), - TextPromptMessageContent( - data="i may still unplug you" - )] - ) + TextPromptMessageContent(data="ok something snarkier pls"), + TextPromptMessageContent(data="i may still unplug you"), + ] + ), ], - model_parameters={ - 'temperature': 0.2, - 'top_k': 5, - 'max_tokens_to_sample': 2048 - }, + model_parameters={"temperature": 0.2, "top_k": 5, "max_tokens_to_sample": 2048}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -123,88 +93,66 @@ def test_invoke_stream_model(setup_google_mock): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True -@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) def test_invoke_chat_model_with_vision(setup_google_mock): model = GoogleLargeLanguageModel() result = model.invoke( - model='gemini-pro-vision', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - }, + model="gemini-pro-vision", + credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( - content=[ - TextPromptMessageContent( - data="what do you see?" - ), + content=[ + TextPromptMessageContent(data="what do you see?"), ImagePromptMessageContent( - data='' - ) + data="" + ), ] - ) + ), ], - model_parameters={ - 'temperature': 0.3, - 'top_p': 0.2, - 'top_k': 3, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.3, "top_p": 0.2, "top_k": 3, "max_tokens": 100}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 -@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock): model = GoogleLargeLanguageModel() result = model.invoke( - model='gemini-pro-vision', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - }, + model="gemini-pro-vision", + credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}, prompt_messages=[ - SystemPromptMessage( - content='You are a helpful AI assistant.' - ), - UserPromptMessage( - content=[ - TextPromptMessageContent( - data="what do you see?" - ), - ImagePromptMessageContent( - data='' - ) - ] - ), - AssistantPromptMessage( - content="I see a blue letter 'D' with a gradient from light blue to dark blue." - ), + SystemPromptMessage(content="You are a helpful AI assistant."), UserPromptMessage( content=[ - TextPromptMessageContent( - data="what about now?" - ), + TextPromptMessageContent(data="what do you see?"), ImagePromptMessageContent( - data='' - ) + data="" + ), ] - ) + ), + AssistantPromptMessage(content="I see a blue letter 'D' with a gradient from light blue to dark blue."), + UserPromptMessage( + content=[ + TextPromptMessageContent(data="what about now?"), + ImagePromptMessageContent( + data="" + ), + ] + ), ], - model_parameters={ - 'temperature': 0.3, - 'top_p': 0.2, - 'top_k': 3, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.3, "top_p": 0.2, "top_k": 3, "max_tokens": 100}, stream=False, - user="abc-123" + user="abc-123", ) print(f"resultz: {result.message.content}") @@ -212,23 +160,18 @@ def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock): assert len(result.message.content) > 0 - def test_get_num_tokens(): model = GoogleLargeLanguageModel() num_tokens = model.get_num_tokens( - model='gemini-pro', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - }, + model="gemini-pro", + credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens > 0 # The exact number of tokens may vary based on the model's tokenization diff --git a/api/tests/integration_tests/model_runtime/google/test_provider.py b/api/tests/integration_tests/model_runtime/google/test_provider.py index 103107ed5a..c217e4fe05 100644 --- a/api/tests/integration_tests/model_runtime/google/test_provider.py +++ b/api/tests/integration_tests/model_runtime/google/test_provider.py @@ -7,17 +7,11 @@ from core.model_runtime.model_providers.google.google import GoogleProvider from tests.integration_tests.model_runtime.__mock.google import setup_google_mock -@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) def test_validate_provider_credentials(setup_google_mock): provider = GoogleProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py b/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py index 28cd0955b3..6a6cc874fa 100644 --- a/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py +++ b/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py @@ -10,87 +10,75 @@ from core.model_runtime.model_providers.huggingface_hub.llm.llm import Huggingfa from tests.integration_tests.model_runtime.__mock.huggingface import setup_huggingface_mock -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_hosted_inference_api_validate_credentials(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='HuggingFaceH4/zephyr-7b-beta', - credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': 'invalid_key' - } + model="HuggingFaceH4/zephyr-7b-beta", + credentials={"huggingfacehub_api_type": "hosted_inference_api", "huggingfacehub_api_token": "invalid_key"}, ) with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='fake-model', - credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': 'invalid_key' - } + model="fake-model", + credentials={"huggingfacehub_api_type": "hosted_inference_api", "huggingfacehub_api_token": "invalid_key"}, ) model.validate_credentials( - model='HuggingFaceH4/zephyr-7b-beta', + model="HuggingFaceH4/zephyr-7b-beta", credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY') - } + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + }, ) -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_hosted_inference_api_invoke_model(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() response = model.invoke( - model='HuggingFaceH4/zephyr-7b-beta', + model="HuggingFaceH4/zephyr-7b-beta", credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY') + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], + prompt_messages=[UserPromptMessage(content="Who are you?")], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() response = model.invoke( - model='HuggingFaceH4/zephyr-7b-beta', + model="HuggingFaceH4/zephyr-7b-beta", credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY') + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], + prompt_messages=[UserPromptMessage(content="Who are you?")], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -101,86 +89,81 @@ def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_inference_endpoints_text_generation_validate_credentials(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='openchat/openchat_3.5', + model="openchat/openchat_3.5", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': 'invalid_key', - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text-generation' - } + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": "invalid_key", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"), + "task_type": "text-generation", + }, ) model.validate_credentials( - model='openchat/openchat_3.5', + model="openchat/openchat_3.5", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text-generation' - } + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"), + "task_type": "text-generation", + }, ) -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_inference_endpoints_text_generation_invoke_model(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() response = model.invoke( - model='openchat/openchat_3.5', + model="openchat/openchat_3.5", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text-generation' + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"), + "task_type": "text-generation", }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], + prompt_messages=[UserPromptMessage(content="Who are you?")], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() response = model.invoke( - model='openchat/openchat_3.5', + model="openchat/openchat_3.5", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text-generation' + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"), + "task_type": "text-generation", }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], + prompt_messages=[UserPromptMessage(content="Who are you?")], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -191,86 +174,81 @@ def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingfa assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_inference_endpoints_text2text_generation_validate_credentials(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='google/mt5-base', + model="google/mt5-base", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': 'invalid_key', - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text2text-generation' - } + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": "invalid_key", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"), + "task_type": "text2text-generation", + }, ) model.validate_credentials( - model='google/mt5-base', + model="google/mt5-base", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text2text-generation' - } + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"), + "task_type": "text2text-generation", + }, ) -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_inference_endpoints_text2text_generation_invoke_model(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() response = model.invoke( - model='google/mt5-base', + model="google/mt5-base", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text2text-generation' + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"), + "task_type": "text2text-generation", }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], + prompt_messages=[UserPromptMessage(content="Who are you?")], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_inference_endpoints_text2text_generation_invoke_stream_model(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() response = model.invoke( - model='google/mt5-base', + model="google/mt5-base", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text2text-generation' + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"), + "task_type": "text2text-generation", }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], + prompt_messages=[UserPromptMessage(content="Who are you?")], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -286,18 +264,14 @@ def test_get_num_tokens(): model = HuggingfaceHubLargeLanguageModel() num_tokens = model.get_num_tokens( - model='google/mt5-base', + model="google/mt5-base", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text2text-generation' + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"), + "task_type": "text2text-generation", }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ] + prompt_messages=[UserPromptMessage(content="Hello World!")], ) assert num_tokens == 7 diff --git a/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py b/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py index d03b3186cb..0ee593f38a 100644 --- a/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py @@ -14,19 +14,19 @@ def test_hosted_inference_api_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='facebook/bart-base', + model="facebook/bart-base", credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': 'invalid_key', - } + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": "invalid_key", + }, ) model.validate_credentials( - model='facebook/bart-base', + model="facebook/bart-base", credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - } + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + }, ) @@ -34,15 +34,12 @@ def test_hosted_inference_api_invoke_model(): model = HuggingfaceHubTextEmbeddingModel() result = model.invoke( - model='facebook/bart-base', + model="facebook/bart-base", credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert isinstance(result, TextEmbeddingResult) @@ -55,25 +52,25 @@ def test_inference_endpoints_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='all-MiniLM-L6-v2', + model="all-MiniLM-L6-v2", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': 'invalid_key', - 'huggingface_namespace': 'Dify-AI', - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'), - 'task_type': 'feature-extraction' - } + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": "invalid_key", + "huggingface_namespace": "Dify-AI", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"), + "task_type": "feature-extraction", + }, ) model.validate_credentials( - model='all-MiniLM-L6-v2', + model="all-MiniLM-L6-v2", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingface_namespace': 'Dify-AI', - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'), - 'task_type': 'feature-extraction' - } + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingface_namespace": "Dify-AI", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"), + "task_type": "feature-extraction", + }, ) @@ -81,18 +78,15 @@ def test_inference_endpoints_invoke_model(): model = HuggingfaceHubTextEmbeddingModel() result = model.invoke( - model='all-MiniLM-L6-v2', + model="all-MiniLM-L6-v2", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingface_namespace': 'Dify-AI', - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'), - 'task_type': 'feature-extraction' + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingface_namespace": "Dify-AI", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"), + "task_type": "feature-extraction", }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert isinstance(result, TextEmbeddingResult) @@ -104,18 +98,15 @@ def test_get_num_tokens(): model = HuggingfaceHubTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='all-MiniLM-L6-v2', + model="all-MiniLM-L6-v2", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingface_namespace': 'Dify-AI', - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'), - 'task_type': 'feature-extraction' + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingface_namespace": "Dify-AI", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"), + "task_type": "feature-extraction", }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py b/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py index ed371fbc07..b1fa9d5ca5 100644 --- a/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py +++ b/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py @@ -10,61 +10,59 @@ from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embe ) from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass -MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" @pytest.fixture def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch): if MOCK: - monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter) - monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize) - monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings) - monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank) + monkeypatch.setattr(TeiHelper, "get_tei_extra_parameter", MockTEIClass.get_tei_extra_parameter) + monkeypatch.setattr(TeiHelper, "invoke_tokenize", MockTEIClass.invoke_tokenize) + monkeypatch.setattr(TeiHelper, "invoke_embeddings", MockTEIClass.invoke_embeddings) + monkeypatch.setattr(TeiHelper, "invoke_rerank", MockTEIClass.invoke_rerank) yield if MOCK: monkeypatch.undo() -@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True) def test_validate_credentials(setup_tei_mock): model = HuggingfaceTeiTextEmbeddingModel() # model name is only used in mock - model_name = 'embedding' + model_name = "embedding" if MOCK: # TEI Provider will check model type by API endpoint, at real server, the model type is correct. # So we dont need to check model type here. Only check in mock with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='reranker', + model="reranker", credentials={ - 'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""), - } + "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""), + }, ) model.validate_credentials( model=model_name, credentials={ - 'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""), - } + "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""), + }, ) -@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True) def test_invoke_model(setup_tei_mock): model = HuggingfaceTeiTextEmbeddingModel() - model_name = 'embedding' + model_name = "embedding" result = model.invoke( model=model_name, credentials={ - 'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""), + "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""), }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) diff --git a/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py b/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py index 57e229e6be..45370d9fba 100644 --- a/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py @@ -11,63 +11,65 @@ from core.model_runtime.model_providers.huggingface_tei.rerank.rerank import ( from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import TeiHelper from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass -MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" @pytest.fixture def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch): if MOCK: - monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter) - monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize) - monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings) - monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank) + monkeypatch.setattr(TeiHelper, "get_tei_extra_parameter", MockTEIClass.get_tei_extra_parameter) + monkeypatch.setattr(TeiHelper, "invoke_tokenize", MockTEIClass.invoke_tokenize) + monkeypatch.setattr(TeiHelper, "invoke_embeddings", MockTEIClass.invoke_embeddings) + monkeypatch.setattr(TeiHelper, "invoke_rerank", MockTEIClass.invoke_rerank) yield if MOCK: monkeypatch.undo() -@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True) def test_validate_credentials(setup_tei_mock): model = HuggingfaceTeiRerankModel() # model name is only used in mock - model_name = 'reranker' + model_name = "reranker" if MOCK: # TEI Provider will check model type by API endpoint, at real server, the model type is correct. # So we dont need to check model type here. Only check in mock with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='embedding', + model="embedding", credentials={ - 'server_url': os.environ.get('TEI_RERANK_SERVER_URL'), - } + "server_url": os.environ.get("TEI_RERANK_SERVER_URL"), + }, ) model.validate_credentials( model=model_name, credentials={ - 'server_url': os.environ.get('TEI_RERANK_SERVER_URL'), - } + "server_url": os.environ.get("TEI_RERANK_SERVER_URL"), + }, ) -@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True) def test_invoke_model(setup_tei_mock): model = HuggingfaceTeiRerankModel() # model name is only used in mock - model_name = 'reranker' + model_name = "reranker" result = model.invoke( model=model_name, credentials={ - 'server_url': os.environ.get('TEI_RERANK_SERVER_URL'), + "server_url": os.environ.get("TEI_RERANK_SERVER_URL"), }, query="Who is Kasumi?", docs=[ - "Kasumi is a girl's name of Japanese origin meaning \"mist\".", + 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", - "and she leads a team named PopiParty." + "and she leads a team named PopiParty.", ], - score_threshold=0.8 + score_threshold=0.8, ) assert isinstance(result, RerankResult) diff --git a/api/tests/integration_tests/model_runtime/hunyuan/test_llm.py b/api/tests/integration_tests/model_runtime/hunyuan/test_llm.py index 305f967ef0..b3049a06d9 100644 --- a/api/tests/integration_tests/model_runtime/hunyuan/test_llm.py +++ b/api/tests/integration_tests/model_runtime/hunyuan/test_llm.py @@ -14,19 +14,15 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='hunyuan-standard', - credentials={ - 'secret_id': 'invalid_key', - 'secret_key': 'invalid_key' - } + model="hunyuan-standard", credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"} ) model.validate_credentials( - model='hunyuan-standard', + model="hunyuan-standard", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') - } + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), + }, ) @@ -34,23 +30,16 @@ def test_invoke_model(): model = HunyuanLargeLanguageModel() response = model.invoke( - model='hunyuan-standard', + model="hunyuan-standard", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), }, - prompt_messages=[ - UserPromptMessage( - content='Hi' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 10 - }, - stop=['How'], + prompt_messages=[UserPromptMessage(content="Hi")], + model_parameters={"temperature": 0.5, "max_tokens": 10}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -61,23 +50,15 @@ def test_invoke_stream_model(): model = HunyuanLargeLanguageModel() response = model.invoke( - model='hunyuan-standard', + model="hunyuan-standard", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hi' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 100, - 'seed': 1234 + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), }, + prompt_messages=[UserPromptMessage(content="Hi")], + model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -93,19 +74,17 @@ def test_get_num_tokens(): model = HunyuanLargeLanguageModel() num_tokens = model.get_num_tokens( - model='hunyuan-standard', + model="hunyuan-standard", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 14 diff --git a/api/tests/integration_tests/model_runtime/hunyuan/test_provider.py b/api/tests/integration_tests/model_runtime/hunyuan/test_provider.py index bdec3d0e22..e3748c2ce7 100644 --- a/api/tests/integration_tests/model_runtime/hunyuan/test_provider.py +++ b/api/tests/integration_tests/model_runtime/hunyuan/test_provider.py @@ -10,16 +10,11 @@ def test_validate_provider_credentials(): provider = HunyuanProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={ - 'secret_id': 'invalid_key', - 'secret_key': 'invalid_key' - } - ) + provider.validate_provider_credentials(credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"}) provider.validate_provider_credentials( credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), } ) diff --git a/api/tests/integration_tests/model_runtime/hunyuan/test_text_embedding.py b/api/tests/integration_tests/model_runtime/hunyuan/test_text_embedding.py index 7ae6c0e456..69d14dffee 100644 --- a/api/tests/integration_tests/model_runtime/hunyuan/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/hunyuan/test_text_embedding.py @@ -12,19 +12,15 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='hunyuan-embedding', - credentials={ - 'secret_id': 'invalid_key', - 'secret_key': 'invalid_key' - } + model="hunyuan-embedding", credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"} ) model.validate_credentials( - model='hunyuan-embedding', + model="hunyuan-embedding", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') - } + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), + }, ) @@ -32,47 +28,43 @@ def test_invoke_model(): model = HunyuanTextEmbeddingModel() result = model.invoke( - model='hunyuan-embedding', + model="hunyuan-embedding", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) assert len(result.embeddings) == 2 assert result.usage.total_tokens == 6 + def test_get_num_tokens(): model = HunyuanTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='hunyuan-embedding', + model="hunyuan-embedding", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 + def test_max_chunks(): model = HunyuanTextEmbeddingModel() result = model.invoke( - model='hunyuan-embedding', + model="hunyuan-embedding", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), }, texts=[ "hello", @@ -97,8 +89,8 @@ def test_max_chunks(): "world", "hello", "world", - ] + ], ) assert isinstance(result, TextEmbeddingResult) - assert len(result.embeddings) == 22 \ No newline at end of file + assert len(result.embeddings) == 22 diff --git a/api/tests/integration_tests/model_runtime/jina/test_provider.py b/api/tests/integration_tests/model_runtime/jina/test_provider.py index 2b43248388..e3b6128c59 100644 --- a/api/tests/integration_tests/model_runtime/jina/test_provider.py +++ b/api/tests/integration_tests/model_runtime/jina/test_provider.py @@ -10,14 +10,6 @@ def test_validate_provider_credentials(): provider = JinaProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={ - 'api_key': 'hahahaha' - } - ) + provider.validate_provider_credentials(credentials={"api_key": "hahahaha"}) - provider.validate_provider_credentials( - credentials={ - 'api_key': os.environ.get('JINA_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("JINA_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py b/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py index ac17566174..290735ec49 100644 --- a/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py @@ -11,18 +11,10 @@ def test_validate_credentials(): model = JinaTextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='jina-embeddings-v2-base-en', - credentials={ - 'api_key': 'invalid_key' - } - ) + model.validate_credentials(model="jina-embeddings-v2-base-en", credentials={"api_key": "invalid_key"}) model.validate_credentials( - model='jina-embeddings-v2-base-en', - credentials={ - 'api_key': os.environ.get('JINA_API_KEY') - } + model="jina-embeddings-v2-base-en", credentials={"api_key": os.environ.get("JINA_API_KEY")} ) @@ -30,15 +22,12 @@ def test_invoke_model(): model = JinaTextEmbeddingModel() result = model.invoke( - model='jina-embeddings-v2-base-en', + model="jina-embeddings-v2-base-en", credentials={ - 'api_key': os.environ.get('JINA_API_KEY'), + "api_key": os.environ.get("JINA_API_KEY"), }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -50,14 +39,11 @@ def test_get_num_tokens(): model = JinaTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='jina-embeddings-v2-base-en', + model="jina-embeddings-v2-base-en", credentials={ - 'api_key': os.environ.get('JINA_API_KEY'), + "api_key": os.environ.get("JINA_API_KEY"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 6 diff --git a/api/tests/integration_tests/model_runtime/localai/test_embedding.py b/api/tests/integration_tests/model_runtime/localai/test_embedding.py index e05345ee56..7fd9f2b300 100644 --- a/api/tests/integration_tests/model_runtime/localai/test_embedding.py +++ b/api/tests/integration_tests/model_runtime/localai/test_embedding.py @@ -1,4 +1,4 @@ """ - LocalAI Embedding Interface is temporarily unavailable due to - we could not find a way to test it for now. -""" \ No newline at end of file +LocalAI Embedding Interface is temporarily unavailable due to +we could not find a way to test it for now. +""" diff --git a/api/tests/integration_tests/model_runtime/localai/test_llm.py b/api/tests/integration_tests/model_runtime/localai/test_llm.py index 6f421403d4..aa5436c34f 100644 --- a/api/tests/integration_tests/model_runtime/localai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/localai/test_llm.py @@ -21,99 +21,78 @@ def test_validate_credentials_for_chat_model(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='chinese-llama-2-7b', + model="chinese-llama-2-7b", credentials={ - 'server_url': 'hahahaha', - 'completion_type': 'completion', - } + "server_url": "hahahaha", + "completion_type": "completion", + }, ) model.validate_credentials( - model='chinese-llama-2-7b', + model="chinese-llama-2-7b", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'completion', - } + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "completion", + }, ) + def test_invoke_completion_model(): model = LocalAILanguageModel() response = model.invoke( - model='chinese-llama-2-7b', + model="chinese-llama-2-7b", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'completion', - }, - prompt_messages=[ - UserPromptMessage( - content='ping' - ) - ], - model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'max_tokens': 10 + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "completion", }, + prompt_messages=[UserPromptMessage(content="ping")], + model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, stop=[], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_chat_model(): model = LocalAILanguageModel() response = model.invoke( - model='chinese-llama-2-7b', + model="chinese-llama-2-7b", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'chat_completion', - }, - prompt_messages=[ - UserPromptMessage( - content='ping' - ) - ], - model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'max_tokens': 10 + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "chat_completion", }, + prompt_messages=[UserPromptMessage(content="ping")], + model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, stop=[], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_stream_completion_model(): model = LocalAILanguageModel() response = model.invoke( - model='chinese-llama-2-7b', + model="chinese-llama-2-7b", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'completion', + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "completion", }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'max_tokens': 10 - }, - stop=['you'], + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -123,28 +102,21 @@ def test_invoke_stream_completion_model(): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_invoke_stream_chat_model(): model = LocalAILanguageModel() response = model.invoke( - model='chinese-llama-2-7b', + model="chinese-llama-2-7b", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'chat_completion', + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "chat_completion", }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'max_tokens': 10 - }, - stop=['you'], + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -154,64 +126,48 @@ def test_invoke_stream_chat_model(): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_get_num_tokens(): model = LocalAILanguageModel() num_tokens = model.get_num_tokens( - model='????', + model="????", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'chat_completion', + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "chat_completion", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], tools=[ PromptMessageTool( - name='get_current_weather', - description='Get the current weather in a given location', + name="get_current_weather", + description="Get the current weather in a given location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ) - ] + ], ) assert isinstance(num_tokens, int) assert num_tokens == 77 num_tokens = model.get_num_tokens( - model='????', + model="????", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'chat_completion', + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "chat_completion", }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], ) assert isinstance(num_tokens, int) diff --git a/api/tests/integration_tests/model_runtime/localai/test_rerank.py b/api/tests/integration_tests/model_runtime/localai/test_rerank.py index 99847bc852..13c7df6d14 100644 --- a/api/tests/integration_tests/model_runtime/localai/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/localai/test_rerank.py @@ -12,30 +12,29 @@ def test_validate_credentials_for_chat_model(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='bge-reranker-v2-m3', + model="bge-reranker-v2-m3", credentials={ - 'server_url': 'hahahaha', - 'completion_type': 'completion', - } + "server_url": "hahahaha", + "completion_type": "completion", + }, ) model.validate_credentials( - model='bge-reranker-base', + model="bge-reranker-base", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'completion', - } + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "completion", + }, ) + def test_invoke_rerank_model(): model = LocalaiRerankModel() response = model.invoke( - model='bge-reranker-base', - credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL') - }, - query='Organic skincare products for sensitive skin', + model="bge-reranker-base", + credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")}, + query="Organic skincare products for sensitive skin", docs=[ "Eco-friendly kitchenware for modern homes", "Biodegradable cleaning supplies for eco-conscious consumers", @@ -45,43 +44,38 @@ def test_invoke_rerank_model(): "Sustainable gardening tools and compost solutions", "Sensitive skin-friendly facial cleansers and toners", "Organic food wraps and storage solutions", - "Yoga mats made from recycled materials" + "Yoga mats made from recycled materials", ], top_n=3, score_threshold=0.75, - user="abc-123" + user="abc-123", ) assert isinstance(response, RerankResult) assert len(response.docs) == 3 + def test__invoke(): model = LocalaiRerankModel() # Test case 1: Empty docs result = model._invoke( - model='bge-reranker-base', - credentials={ - 'server_url': 'https://example.com', - 'api_key': '1234567890' - }, - query='Organic skincare products for sensitive skin', + model="bge-reranker-base", + credentials={"server_url": "https://example.com", "api_key": "1234567890"}, + query="Organic skincare products for sensitive skin", docs=[], top_n=3, score_threshold=0.75, - user="abc-123" + user="abc-123", ) assert isinstance(result, RerankResult) assert len(result.docs) == 0 # Test case 2: Valid invocation result = model._invoke( - model='bge-reranker-base', - credentials={ - 'server_url': 'https://example.com', - 'api_key': '1234567890' - }, - query='Organic skincare products for sensitive skin', + model="bge-reranker-base", + credentials={"server_url": "https://example.com", "api_key": "1234567890"}, + query="Organic skincare products for sensitive skin", docs=[ "Eco-friendly kitchenware for modern homes", "Biodegradable cleaning supplies for eco-conscious consumers", @@ -91,12 +85,12 @@ def test__invoke(): "Sustainable gardening tools and compost solutions", "Sensitive skin-friendly facial cleansers and toners", "Organic food wraps and storage solutions", - "Yoga mats made from recycled materials" + "Yoga mats made from recycled materials", ], top_n=3, score_threshold=0.75, - user="abc-123" + user="abc-123", ) assert isinstance(result, RerankResult) assert len(result.docs) == 3 - assert all(isinstance(doc, RerankDocument) for doc in result.docs) \ No newline at end of file + assert all(isinstance(doc, RerankDocument) for doc in result.docs) diff --git a/api/tests/integration_tests/model_runtime/localai/test_speech2text.py b/api/tests/integration_tests/model_runtime/localai/test_speech2text.py index 3fd2ebed4f..91b7a5752c 100644 --- a/api/tests/integration_tests/model_runtime/localai/test_speech2text.py +++ b/api/tests/integration_tests/model_runtime/localai/test_speech2text.py @@ -10,19 +10,9 @@ def test_validate_credentials(): model = LocalAISpeech2text() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='whisper-1', - credentials={ - 'server_url': 'invalid_url' - } - ) + model.validate_credentials(model="whisper-1", credentials={"server_url": "invalid_url"}) - model.validate_credentials( - model='whisper-1', - credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL') - } - ) + model.validate_credentials(model="whisper-1", credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")}) def test_invoke_model(): @@ -32,23 +22,21 @@ def test_invoke_model(): current_dir = os.path.dirname(os.path.abspath(__file__)) # Get assets directory - assets_dir = os.path.join(os.path.dirname(current_dir), 'assets') + assets_dir = os.path.join(os.path.dirname(current_dir), "assets") # Construct the path to the audio file - audio_file_path = os.path.join(assets_dir, 'audio.mp3') + audio_file_path = os.path.join(assets_dir, "audio.mp3") # Open the file and get the file object - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: file = audio_file result = model.invoke( - model='whisper-1', - credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL') - }, + model="whisper-1", + credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")}, file=file, - user="abc-123" + user="abc-123", ) assert isinstance(result, str) - assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10' \ No newline at end of file + assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10" diff --git a/api/tests/integration_tests/model_runtime/minimax/test_embedding.py b/api/tests/integration_tests/model_runtime/minimax/test_embedding.py index 6f4b8a163f..cf2a28eb9e 100644 --- a/api/tests/integration_tests/model_runtime/minimax/test_embedding.py +++ b/api/tests/integration_tests/model_runtime/minimax/test_embedding.py @@ -12,54 +12,47 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='embo-01', - credentials={ - 'minimax_api_key': 'invalid_key', - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') - } + model="embo-01", + credentials={"minimax_api_key": "invalid_key", "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID")}, ) model.validate_credentials( - model='embo-01', + model="embo-01", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') - } + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), + }, ) + def test_invoke_model(): model = MinimaxTextEmbeddingModel() result = model.invoke( - model='embo-01', + model="embo-01", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) assert len(result.embeddings) == 2 assert result.usage.total_tokens == 16 + def test_get_num_tokens(): model = MinimaxTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='embo-01', + model="embo-01", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/minimax/test_llm.py b/api/tests/integration_tests/model_runtime/minimax/test_llm.py index 570e4901a9..aacde04d32 100644 --- a/api/tests/integration_tests/model_runtime/minimax/test_llm.py +++ b/api/tests/integration_tests/model_runtime/minimax/test_llm.py @@ -17,79 +17,70 @@ def test_predefined_models(): assert len(model_schemas) >= 1 assert isinstance(model_schemas[0], AIModelEntity) + def test_validate_credentials_for_chat_model(): sleep(3) model = MinimaxLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='abab5.5-chat', - credentials={ - 'minimax_api_key': 'invalid_key', - 'minimax_group_id': 'invalid_key' - } + model="abab5.5-chat", credentials={"minimax_api_key": "invalid_key", "minimax_group_id": "invalid_key"} ) model.validate_credentials( - model='abab5.5-chat', + model="abab5.5-chat", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') - } + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), + }, ) + def test_invoke_model(): sleep(3) model = MinimaxLargeLanguageModel() response = model.invoke( - model='abab5-chat', + model="abab5-chat", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_stream_model(): sleep(3) model = MinimaxLargeLanguageModel() response = model.invoke( - model='abab5.5-chat', + model="abab5.5-chat", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -99,34 +90,31 @@ def test_invoke_stream_model(): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_invoke_with_search(): sleep(3) model = MinimaxLargeLanguageModel() response = model.invoke( - model='abab5.5-chat', + model="abab5.5-chat", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), }, - prompt_messages=[ - UserPromptMessage( - content='北京今天的天气怎么样' - ) - ], + prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, - 'plugin_web_search': True, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, + "plugin_web_search": True, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) - total_message = '' + total_message = "" for chunk in response: assert isinstance(chunk, LLMResultChunk) assert isinstance(chunk.delta, LLMResultChunkDelta) @@ -134,25 +122,22 @@ def test_invoke_with_search(): total_message += chunk.delta.message.content assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True - assert '参考资料' in total_message + assert "参考资料" in total_message + def test_get_num_tokens(): sleep(3) model = MinimaxLargeLanguageModel() response = model.get_num_tokens( - model='abab5.5-chat', + model="abab5.5-chat", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - tools=[] + prompt_messages=[UserPromptMessage(content="Hello World!")], + tools=[], ) assert isinstance(response, int) - assert response == 30 \ No newline at end of file + assert response == 30 diff --git a/api/tests/integration_tests/model_runtime/minimax/test_provider.py b/api/tests/integration_tests/model_runtime/minimax/test_provider.py index 4c5462c6df..575ed13eef 100644 --- a/api/tests/integration_tests/model_runtime/minimax/test_provider.py +++ b/api/tests/integration_tests/model_runtime/minimax/test_provider.py @@ -12,14 +12,14 @@ def test_validate_provider_credentials(): with pytest.raises(CredentialsValidateFailedError): provider.validate_provider_credentials( credentials={ - 'minimax_api_key': 'hahahaha', - 'minimax_group_id': '123', + "minimax_api_key": "hahahaha", + "minimax_group_id": "123", } ) provider.validate_provider_credentials( credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID'), + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), } ) diff --git a/api/tests/integration_tests/model_runtime/novita/test_llm.py b/api/tests/integration_tests/model_runtime/novita/test_llm.py index 4ebc68493f..35fa0dc190 100644 --- a/api/tests/integration_tests/model_runtime/novita/test_llm.py +++ b/api/tests/integration_tests/model_runtime/novita/test_llm.py @@ -19,19 +19,12 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='meta-llama/llama-3-8b-instruct', - credentials={ - 'api_key': 'invalid_key', - 'mode': 'chat' - } + model="meta-llama/llama-3-8b-instruct", credentials={"api_key": "invalid_key", "mode": "chat"} ) model.validate_credentials( - model='meta-llama/llama-3-8b-instruct', - credentials={ - 'api_key': os.environ.get('NOVITA_API_KEY'), - 'mode': 'chat' - } + model="meta-llama/llama-3-8b-instruct", + credentials={"api_key": os.environ.get("NOVITA_API_KEY"), "mode": "chat"}, ) @@ -39,27 +32,22 @@ def test_invoke_model(): model = NovitaLargeLanguageModel() response = model.invoke( - model='meta-llama/llama-3-8b-instruct', - credentials={ - 'api_key': os.environ.get('NOVITA_API_KEY'), - 'mode': 'completion' - }, + model="meta-llama/llama-3-8b-instruct", + credentials={"api_key": os.environ.get("NOVITA_API_KEY"), "mode": "completion"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_p': 0.5, - 'max_tokens': 10, + "temperature": 1.0, + "top_p": 0.5, + "max_tokens": 10, }, - stop=['How'], + stop=["How"], stream=False, - user="novita" + user="novita", ) assert isinstance(response, LLMResult) @@ -70,27 +58,17 @@ def test_invoke_stream_model(): model = NovitaLargeLanguageModel() response = model.invoke( - model='meta-llama/llama-3-8b-instruct', - credentials={ - 'api_key': os.environ.get('NOVITA_API_KEY'), - 'mode': 'chat' - }, + model="meta-llama/llama-3-8b-instruct", + credentials={"api_key": os.environ.get("NOVITA_API_KEY"), "mode": "chat"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], - model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, - 'max_tokens': 100 - }, + model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, "max_tokens": 100}, stream=True, - user="novita" + user="novita", ) assert isinstance(response, Generator) @@ -105,18 +83,16 @@ def test_get_num_tokens(): model = NovitaLargeLanguageModel() num_tokens = model.get_num_tokens( - model='meta-llama/llama-3-8b-instruct', + model="meta-llama/llama-3-8b-instruct", credentials={ - 'api_key': os.environ.get('NOVITA_API_KEY'), + "api_key": os.environ.get("NOVITA_API_KEY"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert isinstance(num_tokens, int) diff --git a/api/tests/integration_tests/model_runtime/novita/test_provider.py b/api/tests/integration_tests/model_runtime/novita/test_provider.py index bb3f19dc85..191af99db2 100644 --- a/api/tests/integration_tests/model_runtime/novita/test_provider.py +++ b/api/tests/integration_tests/model_runtime/novita/test_provider.py @@ -10,12 +10,10 @@ def test_validate_provider_credentials(): provider = NovitaProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) provider.validate_provider_credentials( credentials={ - 'api_key': os.environ.get('NOVITA_API_KEY'), + "api_key": os.environ.get("NOVITA_API_KEY"), } ) diff --git a/api/tests/integration_tests/model_runtime/ollama/test_llm.py b/api/tests/integration_tests/model_runtime/ollama/test_llm.py index 272e639a8a..58a1339f50 100644 --- a/api/tests/integration_tests/model_runtime/ollama/test_llm.py +++ b/api/tests/integration_tests/model_runtime/ollama/test_llm.py @@ -20,23 +20,23 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': 'http://localhost:21434', - 'mode': 'chat', - 'context_size': 2048, - 'max_tokens': 2048, - } + "base_url": "http://localhost:21434", + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, + }, ) model.validate_credentials( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 2048, - 'max_tokens': 2048, - } + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, + }, ) @@ -44,26 +44,17 @@ def test_invoke_model(): model = OllamaLargeLanguageModel() response = model.invoke( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 2048, - 'max_tokens': 2048, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], - model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, - 'num_predict': 10 - }, - stop=['How'], - stream=False + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, "num_predict": 10}, + stop=["How"], + stream=False, ) assert isinstance(response, LLMResult) @@ -74,29 +65,22 @@ def test_invoke_stream_model(): model = OllamaLargeLanguageModel() response = model.invoke( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 2048, - 'max_tokens': 2048, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], - model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, - 'num_predict': 10 - }, - stop=['How'], - stream=True + model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, "num_predict": 10}, + stop=["How"], + stream=True, ) assert isinstance(response, Generator) @@ -111,26 +95,17 @@ def test_invoke_completion_model(): model = OllamaLargeLanguageModel() response = model.invoke( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'completion', - 'context_size': 2048, - 'max_tokens': 2048, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "completion", + "context_size": 2048, + "max_tokens": 2048, }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], - model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, - 'num_predict': 10 - }, - stop=['How'], - stream=False + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, "num_predict": 10}, + stop=["How"], + stream=False, ) assert isinstance(response, LLMResult) @@ -141,29 +116,22 @@ def test_invoke_stream_completion_model(): model = OllamaLargeLanguageModel() response = model.invoke( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'completion', - 'context_size': 2048, - 'max_tokens': 2048, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "completion", + "context_size": 2048, + "max_tokens": 2048, }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], - model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, - 'num_predict': 10 - }, - stop=['How'], - stream=True + model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, "num_predict": 10}, + stop=["How"], + stream=True, ) assert isinstance(response, Generator) @@ -178,29 +146,26 @@ def test_invoke_completion_model_with_vision(): model = OllamaLargeLanguageModel() result = model.invoke( - model='llava', + model="llava", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'completion', - 'context_size': 2048, - 'max_tokens': 2048, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "completion", + "context_size": 2048, + "max_tokens": 2048, }, prompt_messages=[ UserPromptMessage( content=[ TextPromptMessageContent( - data='What is this in this picture?', + data="What is this in this picture?", ), ImagePromptMessageContent( - data='' - ) + data="" + ), ] ) ], - model_parameters={ - 'temperature': 0.1, - 'num_predict': 100 - }, + model_parameters={"temperature": 0.1, "num_predict": 100}, stream=False, ) @@ -212,29 +177,26 @@ def test_invoke_chat_model_with_vision(): model = OllamaLargeLanguageModel() result = model.invoke( - model='llava', + model="llava", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 2048, - 'max_tokens': 2048, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, }, prompt_messages=[ UserPromptMessage( content=[ TextPromptMessageContent( - data='What is this in this picture?', + data="What is this in this picture?", ), ImagePromptMessageContent( - data='' - ) + data="" + ), ] ) ], - model_parameters={ - 'temperature': 0.1, - 'num_predict': 100 - }, + model_parameters={"temperature": 0.1, "num_predict": 100}, stream=False, ) @@ -246,18 +208,14 @@ def test_get_num_tokens(): model = OllamaLargeLanguageModel() num_tokens = model.get_num_tokens( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 2048, - 'max_tokens': 2048, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ] + prompt_messages=[UserPromptMessage(content="Hello World!")], ) assert isinstance(num_tokens, int) diff --git a/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py b/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py index c5f5918235..3c4f740a4f 100644 --- a/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py @@ -12,21 +12,21 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': 'http://localhost:21434', - 'mode': 'chat', - 'context_size': 4096, - } + "base_url": "http://localhost:21434", + "mode": "chat", + "context_size": 4096, + }, ) model.validate_credentials( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 4096, - } + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 4096, + }, ) @@ -34,17 +34,14 @@ def test_invoke_model(): model = OllamaEmbeddingModel() result = model.invoke( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 4096, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 4096, }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -56,16 +53,13 @@ def test_get_num_tokens(): model = OllamaEmbeddingModel() num_tokens = model.get_num_tokens( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 4096, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 4096, }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/openai/test_llm.py b/api/tests/integration_tests/model_runtime/openai/test_llm.py index bf4ac53579..3b3ea9ec80 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/openai/test_llm.py @@ -28,92 +28,61 @@ def test_predefined_models(): assert len(model_schemas) >= 1 assert isinstance(model_schemas[0], AIModelEntity) -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_validate_credentials_for_chat_model(setup_openai_mock): model = OpenAILargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='gpt-3.5-turbo', - credentials={ - 'openai_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="gpt-3.5-turbo", credentials={"openai_api_key": "invalid_key"}) - model.validate_credentials( - model='gpt-3.5-turbo', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } - ) + model.validate_credentials(model="gpt-3.5-turbo", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) -@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) def test_validate_credentials_for_completion_model(setup_openai_mock): model = OpenAILargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='text-davinci-003', - credentials={ - 'openai_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="text-davinci-003", credentials={"openai_api_key": "invalid_key"}) model.validate_credentials( - model='text-davinci-003', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + model="text-davinci-003", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")} ) -@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) def test_invoke_completion_model(setup_openai_mock): model = OpenAILargeLanguageModel() result = model.invoke( - model='gpt-3.5-turbo-instruct', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY'), - 'openai_api_base': 'https://api.openai.com' - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 1 - }, + model="gpt-3.5-turbo-instruct", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY"), "openai_api_base": "https://api.openai.com"}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 1}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 - assert model._num_tokens_from_string('gpt-3.5-turbo-instruct', result.message.content) == 1 + assert model._num_tokens_from_string("gpt-3.5-turbo-instruct", result.message.content) == 1 -@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) def test_invoke_stream_completion_model(setup_openai_mock): model = OpenAILargeLanguageModel() result = model.invoke( - model='gpt-3.5-turbo-instruct', + model="gpt-3.5-turbo-instruct", credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY'), - 'openai_organization': os.environ.get('OPENAI_ORGANIZATION'), - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 + "openai_api_key": os.environ.get("OPENAI_API_KEY"), + "openai_organization": os.environ.get("OPENAI_ORGANIZATION"), }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(result, Generator) @@ -124,166 +93,131 @@ def test_invoke_stream_completion_model(setup_openai_mock): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model(setup_openai_mock): model = OpenAILargeLanguageModel() result = model.invoke( - model='gpt-3.5-turbo', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="gpt-3.5-turbo", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.0, - 'top_p': 1.0, - 'presence_penalty': 0.0, - 'frequency_penalty': 0.0, - 'max_tokens': 10 + "temperature": 0.0, + "top_p": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model_with_vision(setup_openai_mock): model = OpenAILargeLanguageModel() result = model.invoke( - model='gpt-4-vision-preview', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="gpt-4-vision-preview", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( content=[ TextPromptMessageContent( - data='Hello World!', + data="Hello World!", ), ImagePromptMessageContent( - data='' - ) + data="" + ), ] - ) + ), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model_with_tools(setup_openai_mock): model = OpenAILargeLanguageModel() result = model.invoke( - model='gpt-3.5-turbo', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="gpt-3.5-turbo", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( content="what's the weather today in London?", - ) + ), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, tools=[ PromptMessageTool( - name='get_weather', - description='Determine weather in my location', + name="get_weather", + description="Determine weather in my location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ), PromptMessageTool( - name='get_stock_price', - description='Get the current stock price', + name="get_stock_price", + description="Get the current stock price", parameters={ "type": "object", - "properties": { - "symbol": { - "type": "string", - "description": "The stock symbol" - } - }, - "required": [ - "symbol" - ] - } - ) + "properties": {"symbol": {"type": "string", "description": "The stock symbol"}}, + "required": ["symbol"], + }, + ), ], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert isinstance(result.message, AssistantPromptMessage) assert len(result.message.tool_calls) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_stream_chat_model(setup_openai_mock): model = OpenAILargeLanguageModel() result = model.invoke( - model='gpt-3.5-turbo', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="gpt-3.5-turbo", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(result, Generator) @@ -302,68 +236,46 @@ def test_get_num_tokens(): model = OpenAILargeLanguageModel() num_tokens = model.get_num_tokens( - model='gpt-3.5-turbo-instruct', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ] + model="gpt-3.5-turbo-instruct", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], ) assert num_tokens == 3 num_tokens = model.get_num_tokens( - model='gpt-3.5-turbo', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="gpt-3.5-turbo", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], tools=[ PromptMessageTool( - name='get_weather', - description='Determine weather in my location', + name="get_weather", + description="Determine weather in my location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ), - ] + ], ) assert num_tokens == 72 -@pytest.mark.parametrize('setup_openai_mock', [['chat', 'remote']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat", "remote"]], indirect=True) def test_fine_tuned_models(setup_openai_mock): model = OpenAILargeLanguageModel() - remote_models = model.remote_models(credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }) + remote_models = model.remote_models(credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) if not remote_models: assert isinstance(remote_models, list) @@ -379,29 +291,23 @@ def test_fine_tuned_models(setup_openai_mock): # test invoke result = model.invoke( model=llm_model.model, - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) + def test__get_num_tokens_by_gpt2(): model = OpenAILargeLanguageModel() - num_tokens = model._get_num_tokens_by_gpt2('Hello World!') + num_tokens = model._get_num_tokens_by_gpt2("Hello World!") assert num_tokens == 3 diff --git a/api/tests/integration_tests/model_runtime/openai/test_moderation.py b/api/tests/integration_tests/model_runtime/openai/test_moderation.py index 04f9b9f33b..6de2624717 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_moderation.py +++ b/api/tests/integration_tests/model_runtime/openai/test_moderation.py @@ -7,48 +7,37 @@ from core.model_runtime.model_providers.openai.moderation.moderation import Open from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['moderation']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["moderation"]], indirect=True) def test_validate_credentials(setup_openai_mock): model = OpenAIModerationModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='text-moderation-stable', - credentials={ - 'openai_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="text-moderation-stable", credentials={"openai_api_key": "invalid_key"}) model.validate_credentials( - model='text-moderation-stable', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + model="text-moderation-stable", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")} ) -@pytest.mark.parametrize('setup_openai_mock', [['moderation']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["moderation"]], indirect=True) def test_invoke_model(setup_openai_mock): model = OpenAIModerationModel() result = model.invoke( - model='text-moderation-stable', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="text-moderation-stable", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, text="hello", - user="abc-123" + user="abc-123", ) assert isinstance(result, bool) assert result is False result = model.invoke( - model='text-moderation-stable', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="text-moderation-stable", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, text="i will kill you", - user="abc-123" + user="abc-123", ) assert isinstance(result, bool) diff --git a/api/tests/integration_tests/model_runtime/openai/test_provider.py b/api/tests/integration_tests/model_runtime/openai/test_provider.py index 5314bffbdf..4d56cfcf3c 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_provider.py +++ b/api/tests/integration_tests/model_runtime/openai/test_provider.py @@ -7,17 +7,11 @@ from core.model_runtime.model_providers.openai.openai import OpenAIProvider from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_validate_provider_credentials(setup_openai_mock): provider = OpenAIProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/openai/test_speech2text.py b/api/tests/integration_tests/model_runtime/openai/test_speech2text.py index f1a5c4fd23..aa92c8b61f 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_speech2text.py +++ b/api/tests/integration_tests/model_runtime/openai/test_speech2text.py @@ -7,26 +7,17 @@ from core.model_runtime.model_providers.openai.speech2text.speech2text import Op from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['speech2text']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["speech2text"]], indirect=True) def test_validate_credentials(setup_openai_mock): model = OpenAISpeech2TextModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='whisper-1', - credentials={ - 'openai_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="whisper-1", credentials={"openai_api_key": "invalid_key"}) - model.validate_credentials( - model='whisper-1', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } - ) + model.validate_credentials(model="whisper-1", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) -@pytest.mark.parametrize('setup_openai_mock', [['speech2text']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["speech2text"]], indirect=True) def test_invoke_model(setup_openai_mock): model = OpenAISpeech2TextModel() @@ -34,23 +25,21 @@ def test_invoke_model(setup_openai_mock): current_dir = os.path.dirname(os.path.abspath(__file__)) # Get assets directory - assets_dir = os.path.join(os.path.dirname(current_dir), 'assets') + assets_dir = os.path.join(os.path.dirname(current_dir), "assets") # Construct the path to the audio file - audio_file_path = os.path.join(assets_dir, 'audio.mp3') + audio_file_path = os.path.join(assets_dir, "audio.mp3") # Open the file and get the file object - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: file = audio_file result = model.invoke( - model='whisper-1', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="whisper-1", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, file=file, - user="abc-123" + user="abc-123", ) assert isinstance(result, str) - assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10' + assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10" diff --git a/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py index e2c4c74ee7..f5dd73f2d4 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py @@ -8,42 +8,27 @@ from core.model_runtime.model_providers.openai.text_embedding.text_embedding imp from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) def test_validate_credentials(setup_openai_mock): model = OpenAITextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='text-embedding-ada-002', - credentials={ - 'openai_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="text-embedding-ada-002", credentials={"openai_api_key": "invalid_key"}) model.validate_credentials( - model='text-embedding-ada-002', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + model="text-embedding-ada-002", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")} ) -@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) def test_invoke_model(setup_openai_mock): model = OpenAITextEmbeddingModel() result = model.invoke( - model='text-embedding-ada-002', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY'), - 'openai_api_base': 'https://api.openai.com' - }, - texts=[ - "hello", - "world", - " ".join(["long_text"] * 100), - " ".join(["another_long_text"] * 100) - ], - user="abc-123" + model="text-embedding-ada-002", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY"), "openai_api_base": "https://api.openai.com"}, + texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -55,15 +40,9 @@ def test_get_num_tokens(): model = OpenAITextEmbeddingModel() num_tokens = model.get_num_tokens( - model='text-embedding-ada-002', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY'), - 'openai_api_base': 'https://api.openai.com' - }, - texts=[ - "hello", - "world" - ] + model="text-embedding-ada-002", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY"), "openai_api_base": "https://api.openai.com"}, + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py index c833508569..f2302ef05e 100644 --- a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py +++ b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py @@ -23,21 +23,17 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', - credentials={ - 'api_key': 'invalid_key', - 'endpoint_url': 'https://api.together.xyz/v1/', - 'mode': 'chat' - } + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={"api_key": "invalid_key", "endpoint_url": "https://api.together.xyz/v1/", "mode": "chat"}, ) model.validate_credentials( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', + model="mistralai/Mixtral-8x7B-Instruct-v0.1", credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'endpoint_url': 'https://api.together.xyz/v1/', - 'mode': 'chat' - } + "api_key": os.environ.get("TOGETHER_API_KEY"), + "endpoint_url": "https://api.together.xyz/v1/", + "mode": "chat", + }, ) @@ -45,28 +41,26 @@ def test_invoke_model(): model = OAIAPICompatLargeLanguageModel() response = model.invoke( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', + model="mistralai/Mixtral-8x7B-Instruct-v0.1", credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'endpoint_url': 'https://api.together.xyz/v1/', - 'mode': 'completion' + "api_key": os.environ.get("TOGETHER_API_KEY"), + "endpoint_url": "https://api.together.xyz/v1/", + "mode": "completion", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -77,29 +71,27 @@ def test_invoke_stream_model(): model = OAIAPICompatLargeLanguageModel() response = model.invoke( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', + model="mistralai/Mixtral-8x7B-Instruct-v0.1", credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'endpoint_url': 'https://api.together.xyz/v1/', - 'mode': 'chat', - 'stream_mode_delimiter': '\\n\\n' + "api_key": os.environ.get("TOGETHER_API_KEY"), + "endpoint_url": "https://api.together.xyz/v1/", + "mode": "chat", + "stream_mode_delimiter": "\\n\\n", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -114,28 +106,26 @@ def test_invoke_stream_model_without_delimiter(): model = OAIAPICompatLargeLanguageModel() response = model.invoke( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', + model="mistralai/Mixtral-8x7B-Instruct-v0.1", credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'endpoint_url': 'https://api.together.xyz/v1/', - 'mode': 'chat' + "api_key": os.environ.get("TOGETHER_API_KEY"), + "endpoint_url": "https://api.together.xyz/v1/", + "mode": "chat", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -151,51 +141,37 @@ def test_invoke_chat_model_with_tools(): model = OAIAPICompatLargeLanguageModel() result = model.invoke( - model='gpt-3.5-turbo', + model="gpt-3.5-turbo", credentials={ - 'api_key': os.environ.get('OPENAI_API_KEY'), - 'endpoint_url': 'https://api.openai.com/v1/', - 'mode': 'chat' + "api_key": os.environ.get("OPENAI_API_KEY"), + "endpoint_url": "https://api.openai.com/v1/", + "mode": "chat", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( content="what's the weather today in London?", - ) + ), ], tools=[ PromptMessageTool( - name='get_weather', - description='Determine weather in my location', + name="get_weather", + description="Determine weather in my location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "celsius", - "fahrenheit" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 1024 - }, + model_parameters={"temperature": 0.0, "max_tokens": 1024}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) @@ -207,19 +183,14 @@ def test_get_num_tokens(): model = OAIAPICompatLargeLanguageModel() num_tokens = model.get_num_tokens( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', - credentials={ - 'api_key': os.environ.get('OPENAI_API_KEY'), - 'endpoint_url': 'https://api.openai.com/v1/' - }, + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={"api_key": os.environ.get("OPENAI_API_KEY"), "endpoint_url": "https://api.openai.com/v1/"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert isinstance(num_tokens, int) diff --git a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_speech2text.py b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_speech2text.py index 61079104dc..cf805eafff 100644 --- a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_speech2text.py +++ b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_speech2text.py @@ -14,18 +14,12 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( model="whisper-1", - credentials={ - "api_key": "invalid_key", - "endpoint_url": "https://api.openai.com/v1/" - }, + credentials={"api_key": "invalid_key", "endpoint_url": "https://api.openai.com/v1/"}, ) model.validate_credentials( model="whisper-1", - credentials={ - "api_key": os.environ.get("OPENAI_API_KEY"), - "endpoint_url": "https://api.openai.com/v1/" - }, + credentials={"api_key": os.environ.get("OPENAI_API_KEY"), "endpoint_url": "https://api.openai.com/v1/"}, ) @@ -47,13 +41,10 @@ def test_invoke_model(): result = model.invoke( model="whisper-1", - credentials={ - "api_key": os.environ.get("OPENAI_API_KEY"), - "endpoint_url": "https://api.openai.com/v1/" - }, + credentials={"api_key": os.environ.get("OPENAI_API_KEY"), "endpoint_url": "https://api.openai.com/v1/"}, file=file, user="abc-123", ) assert isinstance(result, str) - assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10' + assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10" diff --git a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py index 77d27ec161..052b41605f 100644 --- a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py @@ -12,27 +12,23 @@ from core.model_runtime.model_providers.openai_api_compatible.text_embedding.tex Using OpenAI's API as testing endpoint """ + def test_validate_credentials(): model = OAICompatEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='text-embedding-ada-002', - credentials={ - 'api_key': 'invalid_key', - 'endpoint_url': 'https://api.openai.com/v1/', - 'context_size': 8184 - - } + model="text-embedding-ada-002", + credentials={"api_key": "invalid_key", "endpoint_url": "https://api.openai.com/v1/", "context_size": 8184}, ) model.validate_credentials( - model='text-embedding-ada-002', + model="text-embedding-ada-002", credentials={ - 'api_key': os.environ.get('OPENAI_API_KEY'), - 'endpoint_url': 'https://api.openai.com/v1/', - 'context_size': 8184 - } + "api_key": os.environ.get("OPENAI_API_KEY"), + "endpoint_url": "https://api.openai.com/v1/", + "context_size": 8184, + }, ) @@ -40,19 +36,14 @@ def test_invoke_model(): model = OAICompatEmbeddingModel() result = model.invoke( - model='text-embedding-ada-002', + model="text-embedding-ada-002", credentials={ - 'api_key': os.environ.get('OPENAI_API_KEY'), - 'endpoint_url': 'https://api.openai.com/v1/', - 'context_size': 8184 + "api_key": os.environ.get("OPENAI_API_KEY"), + "endpoint_url": "https://api.openai.com/v1/", + "context_size": 8184, }, - texts=[ - "hello", - "world", - " ".join(["long_text"] * 100), - " ".join(["another_long_text"] * 100) - ], - user="abc-123" + texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -64,16 +55,13 @@ def test_get_num_tokens(): model = OAICompatEmbeddingModel() num_tokens = model.get_num_tokens( - model='text-embedding-ada-002', + model="text-embedding-ada-002", credentials={ - 'api_key': os.environ.get('OPENAI_API_KEY'), - 'endpoint_url': 'https://api.openai.com/v1/embeddings', - 'context_size': 8184 + "api_key": os.environ.get("OPENAI_API_KEY"), + "endpoint_url": "https://api.openai.com/v1/embeddings", + "context_size": 8184, }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) - assert num_tokens == 2 \ No newline at end of file + assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/openllm/test_embedding.py b/api/tests/integration_tests/model_runtime/openllm/test_embedding.py index 9eb05a111d..14d47217af 100644 --- a/api/tests/integration_tests/model_runtime/openllm/test_embedding.py +++ b/api/tests/integration_tests/model_runtime/openllm/test_embedding.py @@ -12,17 +12,17 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': 'ww' + os.environ.get('OPENLLM_SERVER_URL'), - } + "server_url": "ww" + os.environ.get("OPENLLM_SERVER_URL"), + }, ) model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': os.environ.get('OPENLLM_SERVER_URL'), - } + "server_url": os.environ.get("OPENLLM_SERVER_URL"), + }, ) @@ -30,33 +30,28 @@ def test_invoke_model(): model = OpenLLMTextEmbeddingModel() result = model.invoke( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': os.environ.get('OPENLLM_SERVER_URL'), + "server_url": os.environ.get("OPENLLM_SERVER_URL"), }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) assert len(result.embeddings) == 2 assert result.usage.total_tokens > 0 + def test_get_num_tokens(): model = OpenLLMTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': os.environ.get('OPENLLM_SERVER_URL'), + "server_url": os.environ.get("OPENLLM_SERVER_URL"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/openllm/test_llm.py b/api/tests/integration_tests/model_runtime/openllm/test_llm.py index 853a0fbe3c..35939e3cfe 100644 --- a/api/tests/integration_tests/model_runtime/openllm/test_llm.py +++ b/api/tests/integration_tests/model_runtime/openllm/test_llm.py @@ -14,67 +14,61 @@ def test_validate_credentials_for_chat_model(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': 'invalid_key', - } + "server_url": "invalid_key", + }, ) model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': os.environ.get('OPENLLM_SERVER_URL'), - } + "server_url": os.environ.get("OPENLLM_SERVER_URL"), + }, ) + def test_invoke_model(): model = OpenLLMLargeLanguageModel() response = model.invoke( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': os.environ.get('OPENLLM_SERVER_URL'), + "server_url": os.environ.get("OPENLLM_SERVER_URL"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_stream_model(): model = OpenLLMLargeLanguageModel() response = model.invoke( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': os.environ.get('OPENLLM_SERVER_URL'), + "server_url": os.environ.get("OPENLLM_SERVER_URL"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -84,21 +78,18 @@ def test_invoke_stream_model(): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_get_num_tokens(): model = OpenLLMLargeLanguageModel() response = model.get_num_tokens( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': os.environ.get('OPENLLM_SERVER_URL'), + "server_url": os.environ.get("OPENLLM_SERVER_URL"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - tools=[] + prompt_messages=[UserPromptMessage(content="Hello World!")], + tools=[], ) assert isinstance(response, int) - assert response == 3 \ No newline at end of file + assert response == 3 diff --git a/api/tests/integration_tests/model_runtime/openrouter/test_llm.py b/api/tests/integration_tests/model_runtime/openrouter/test_llm.py index 8f1fb4c4ad..ce4876a73a 100644 --- a/api/tests/integration_tests/model_runtime/openrouter/test_llm.py +++ b/api/tests/integration_tests/model_runtime/openrouter/test_llm.py @@ -19,19 +19,12 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='mistralai/mixtral-8x7b-instruct', - credentials={ - 'api_key': 'invalid_key', - 'mode': 'chat' - } + model="mistralai/mixtral-8x7b-instruct", credentials={"api_key": "invalid_key", "mode": "chat"} ) model.validate_credentials( - model='mistralai/mixtral-8x7b-instruct', - credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'mode': 'chat' - } + model="mistralai/mixtral-8x7b-instruct", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"}, ) @@ -39,27 +32,22 @@ def test_invoke_model(): model = OpenRouterLargeLanguageModel() response = model.invoke( - model='mistralai/mixtral-8x7b-instruct', - credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'mode': 'completion' - }, + model="mistralai/mixtral-8x7b-instruct", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "completion"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -70,27 +58,22 @@ def test_invoke_stream_model(): model = OpenRouterLargeLanguageModel() response = model.invoke( - model='mistralai/mixtral-8x7b-instruct', - credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'mode': 'chat' - }, + model="mistralai/mixtral-8x7b-instruct", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -105,18 +88,16 @@ def test_get_num_tokens(): model = OpenRouterLargeLanguageModel() num_tokens = model.get_num_tokens( - model='mistralai/mixtral-8x7b-instruct', + model="mistralai/mixtral-8x7b-instruct", credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), + "api_key": os.environ.get("TOGETHER_API_KEY"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert isinstance(num_tokens, int) diff --git a/api/tests/integration_tests/model_runtime/replicate/test_llm.py b/api/tests/integration_tests/model_runtime/replicate/test_llm.py index e248f064c0..b940005b71 100644 --- a/api/tests/integration_tests/model_runtime/replicate/test_llm.py +++ b/api/tests/integration_tests/model_runtime/replicate/test_llm.py @@ -14,19 +14,19 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='meta/llama-2-13b-chat', + model="meta/llama-2-13b-chat", credentials={ - 'replicate_api_token': 'invalid_key', - 'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d' - } + "replicate_api_token": "invalid_key", + "model_version": "f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d", + }, ) model.validate_credentials( - model='meta/llama-2-13b-chat', + model="meta/llama-2-13b-chat", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d' - } + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d", + }, ) @@ -34,27 +34,25 @@ def test_invoke_model(): model = ReplicateLargeLanguageModel() response = model.invoke( - model='meta/llama-2-13b-chat', + model="meta/llama-2-13b-chat", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -65,27 +63,25 @@ def test_invoke_stream_model(): model = ReplicateLargeLanguageModel() response = model.invoke( - model='mistralai/mixtral-8x7b-instruct-v0.1', + model="mistralai/mixtral-8x7b-instruct-v0.1", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': '2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -100,19 +96,17 @@ def test_get_num_tokens(): model = ReplicateLargeLanguageModel() num_tokens = model.get_num_tokens( - model='', + model="", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': '2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 14 diff --git a/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py b/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py index 5708ec9e5a..397715f225 100644 --- a/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py @@ -12,19 +12,19 @@ def test_validate_credentials_one(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='replicate/all-mpnet-base-v2', + model="replicate/all-mpnet-base-v2", credentials={ - 'replicate_api_token': 'invalid_key', - 'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305' - } + "replicate_api_token": "invalid_key", + "model_version": "b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305", + }, ) model.validate_credentials( - model='replicate/all-mpnet-base-v2', + model="replicate/all-mpnet-base-v2", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305' - } + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305", + }, ) @@ -33,19 +33,19 @@ def test_validate_credentials_two(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='nateraw/bge-large-en-v1.5', + model="nateraw/bge-large-en-v1.5", credentials={ - 'replicate_api_token': 'invalid_key', - 'model_version': '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1' - } + "replicate_api_token": "invalid_key", + "model_version": "9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1", + }, ) model.validate_credentials( - model='nateraw/bge-large-en-v1.5', + model="nateraw/bge-large-en-v1.5", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1' - } + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1", + }, ) @@ -53,16 +53,13 @@ def test_invoke_model_one(): model = ReplicateEmbeddingModel() result = model.invoke( - model='nateraw/bge-large-en-v1.5', + model="nateraw/bge-large-en-v1.5", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1", }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -74,16 +71,13 @@ def test_invoke_model_two(): model = ReplicateEmbeddingModel() result = model.invoke( - model='andreasjansson/clip-features', + model="andreasjansson/clip-features", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': '75b33f253f7714a281ad3e9b28f63e3232d583716ef6718f2e46641077ea040a' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "75b33f253f7714a281ad3e9b28f63e3232d583716ef6718f2e46641077ea040a", }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -95,16 +89,13 @@ def test_invoke_model_three(): model = ReplicateEmbeddingModel() result = model.invoke( - model='replicate/all-mpnet-base-v2', + model="replicate/all-mpnet-base-v2", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305", }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -116,16 +107,13 @@ def test_invoke_model_four(): model = ReplicateEmbeddingModel() result = model.invoke( - model='nateraw/jina-embeddings-v2-base-en', + model="nateraw/jina-embeddings-v2-base-en", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': 'f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e", }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -137,15 +125,12 @@ def test_get_num_tokens(): model = ReplicateEmbeddingModel() num_tokens = model.get_num_tokens( - model='nateraw/jina-embeddings-v2-base-en', + model="nateraw/jina-embeddings-v2-base-en", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': 'f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e", }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/sagemaker/test_provider.py b/api/tests/integration_tests/model_runtime/sagemaker/test_provider.py index 639227e745..9f0b439d6c 100644 --- a/api/tests/integration_tests/model_runtime/sagemaker/test_provider.py +++ b/api/tests/integration_tests/model_runtime/sagemaker/test_provider.py @@ -10,10 +10,6 @@ def test_validate_provider_credentials(): provider = SageMakerProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) diff --git a/api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py b/api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py index c67849dd79..d5a6798a1e 100644 --- a/api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py @@ -12,11 +12,11 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='bge-m3-rerank-v2', + model="bge-m3-rerank-v2", credentials={ "aws_region": os.getenv("AWS_REGION"), "aws_access_key": os.getenv("AWS_ACCESS_KEY"), - "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), }, query="What is the capital of the United States?", docs=[ @@ -25,7 +25,7 @@ def test_validate_credentials(): "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) @@ -33,11 +33,11 @@ def test_invoke_model(): model = SageMakerRerankModel() result = model.invoke( - model='bge-m3-rerank-v2', + model="bge-m3-rerank-v2", credentials={ "aws_region": os.getenv("AWS_REGION"), "aws_access_key": os.getenv("AWS_ACCESS_KEY"), - "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), }, query="What is the capital of the United States?", docs=[ @@ -46,7 +46,7 @@ def test_invoke_model(): "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) assert isinstance(result, RerankResult) diff --git a/api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py b/api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py index e817e8f04a..e4e404c7a8 100644 --- a/api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py @@ -11,45 +11,23 @@ def test_validate_credentials(): model = SageMakerEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='bge-m3', - credentials={ - } - ) + model.validate_credentials(model="bge-m3", credentials={}) - model.validate_credentials( - model='bge-m3-embedding', - credentials={ - } - ) + model.validate_credentials(model="bge-m3-embedding", credentials={}) def test_invoke_model(): model = SageMakerEmbeddingModel() - result = model.invoke( - model='bge-m3-embedding', - credentials={ - }, - texts=[ - "hello", - "world" - ], - user="abc-123" - ) + result = model.invoke(model="bge-m3-embedding", credentials={}, texts=["hello", "world"], user="abc-123") assert isinstance(result, TextEmbeddingResult) assert len(result.embeddings) == 2 + def test_get_num_tokens(): model = SageMakerEmbeddingModel() - num_tokens = model.get_num_tokens( - model='bge-m3-embedding', - credentials={ - }, - texts=[ - ] - ) + num_tokens = model.get_num_tokens(model="bge-m3-embedding", credentials={}, texts=[]) assert num_tokens == 0 diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_llm.py b/api/tests/integration_tests/model_runtime/siliconflow/test_llm.py index befdd82352..f47c9c5588 100644 --- a/api/tests/integration_tests/model_runtime/siliconflow/test_llm.py +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_llm.py @@ -13,41 +13,22 @@ def test_validate_credentials(): model = SiliconflowLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='deepseek-ai/DeepSeek-V2-Chat', - credentials={ - 'api_key': 'invalid_key' - } - ) + model.validate_credentials(model="deepseek-ai/DeepSeek-V2-Chat", credentials={"api_key": "invalid_key"}) - model.validate_credentials( - model='deepseek-ai/DeepSeek-V2-Chat', - credentials={ - 'api_key': os.environ.get('API_KEY') - } - ) + model.validate_credentials(model="deepseek-ai/DeepSeek-V2-Chat", credentials={"api_key": os.environ.get("API_KEY")}) def test_invoke_model(): model = SiliconflowLargeLanguageModel() response = model.invoke( - model='deepseek-ai/DeepSeek-V2-Chat', - credentials={ - 'api_key': os.environ.get('API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 10 - }, - stop=['How'], + model="deepseek-ai/DeepSeek-V2-Chat", + credentials={"api_key": os.environ.get("API_KEY")}, + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 0.5, "max_tokens": 10}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -58,22 +39,12 @@ def test_invoke_stream_model(): model = SiliconflowLargeLanguageModel() response = model.invoke( - model='deepseek-ai/DeepSeek-V2-Chat', - credentials={ - 'api_key': os.environ.get('API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 100, - 'seed': 1234 - }, + model="deepseek-ai/DeepSeek-V2-Chat", + credentials={"api_key": os.environ.get("API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -89,18 +60,14 @@ def test_get_num_tokens(): model = SiliconflowLargeLanguageModel() num_tokens = model.get_num_tokens( - model='deepseek-ai/DeepSeek-V2-Chat', - credentials={ - 'api_key': os.environ.get('API_KEY') - }, + model="deepseek-ai/DeepSeek-V2-Chat", + credentials={"api_key": os.environ.get("API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 12 diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_provider.py b/api/tests/integration_tests/model_runtime/siliconflow/test_provider.py index 7b9211a5db..8f70210b7a 100644 --- a/api/tests/integration_tests/model_runtime/siliconflow/test_provider.py +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_provider.py @@ -10,12 +10,6 @@ def test_validate_provider_credentials(): provider = SiliconflowProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'api_key': os.environ.get('API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_rerank.py b/api/tests/integration_tests/model_runtime/siliconflow/test_rerank.py index 7b3ff82727..ad794613f9 100644 --- a/api/tests/integration_tests/model_runtime/siliconflow/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_rerank.py @@ -13,9 +13,7 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( model="BAAI/bge-reranker-v2-m3", - credentials={ - "api_key": "invalid_key" - }, + credentials={"api_key": "invalid_key"}, ) model.validate_credentials( @@ -30,17 +28,17 @@ def test_invoke_model(): model = SiliconflowRerankModel() result = model.invoke( - model='BAAI/bge-reranker-v2-m3', + model="BAAI/bge-reranker-v2-m3", credentials={ "api_key": os.environ.get("API_KEY"), }, query="Who is Kasumi?", docs=[ - "Kasumi is a girl's name of Japanese origin meaning \"mist\".", + 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", - "and she leads a team named PopiParty." + "and she leads a team named PopiParty.", ], - score_threshold=0.8 + score_threshold=0.8, ) assert isinstance(result, RerankResult) diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_speech2text.py b/api/tests/integration_tests/model_runtime/siliconflow/test_speech2text.py index 82b7921c85..0502ba5ab4 100644 --- a/api/tests/integration_tests/model_runtime/siliconflow/test_speech2text.py +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_speech2text.py @@ -12,16 +12,12 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( model="iic/SenseVoiceSmall", - credentials={ - "api_key": "invalid_key" - }, + credentials={"api_key": "invalid_key"}, ) model.validate_credentials( model="iic/SenseVoiceSmall", - credentials={ - "api_key": os.environ.get("API_KEY") - }, + credentials={"api_key": os.environ.get("API_KEY")}, ) @@ -42,12 +38,8 @@ def test_invoke_model(): file = audio_file result = model.invoke( - model="iic/SenseVoiceSmall", - credentials={ - "api_key": os.environ.get("API_KEY") - }, - file=file + model="iic/SenseVoiceSmall", credentials={"api_key": os.environ.get("API_KEY")}, file=file ) assert isinstance(result, str) - assert result == '1,2,3,4,5,6,7,8,9,10.' + assert result == "1,2,3,4,5,6,7,8,9,10." diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_text_embedding.py b/api/tests/integration_tests/model_runtime/siliconflow/test_text_embedding.py index 18bd2e893a..ab143c1061 100644 --- a/api/tests/integration_tests/model_runtime/siliconflow/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_text_embedding.py @@ -15,9 +15,7 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( model="BAAI/bge-large-zh-v1.5", - credentials={ - "api_key": "invalid_key" - }, + credentials={"api_key": "invalid_key"}, ) model.validate_credentials( diff --git a/api/tests/integration_tests/model_runtime/spark/test_llm.py b/api/tests/integration_tests/model_runtime/spark/test_llm.py index 706316449d..4fe2fd8c0a 100644 --- a/api/tests/integration_tests/model_runtime/spark/test_llm.py +++ b/api/tests/integration_tests/model_runtime/spark/test_llm.py @@ -13,20 +13,15 @@ def test_validate_credentials(): model = SparkLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='spark-1.5', - credentials={ - 'app_id': 'invalid_key' - } - ) + model.validate_credentials(model="spark-1.5", credentials={"app_id": "invalid_key"}) model.validate_credentials( - model='spark-1.5', + model="spark-1.5", credentials={ - 'app_id': os.environ.get('SPARK_APP_ID'), - 'api_secret': os.environ.get('SPARK_API_SECRET'), - 'api_key': os.environ.get('SPARK_API_KEY') - } + "app_id": os.environ.get("SPARK_APP_ID"), + "api_secret": os.environ.get("SPARK_API_SECRET"), + "api_key": os.environ.get("SPARK_API_KEY"), + }, ) @@ -34,24 +29,17 @@ def test_invoke_model(): model = SparkLargeLanguageModel() response = model.invoke( - model='spark-1.5', + model="spark-1.5", credentials={ - 'app_id': os.environ.get('SPARK_APP_ID'), - 'api_secret': os.environ.get('SPARK_API_SECRET'), - 'api_key': os.environ.get('SPARK_API_KEY') + "app_id": os.environ.get("SPARK_APP_ID"), + "api_secret": os.environ.get("SPARK_API_SECRET"), + "api_key": os.environ.get("SPARK_API_KEY"), }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 10 - }, - stop=['How'], + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 0.5, "max_tokens": 10}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -62,23 +50,16 @@ def test_invoke_stream_model(): model = SparkLargeLanguageModel() response = model.invoke( - model='spark-1.5', + model="spark-1.5", credentials={ - 'app_id': os.environ.get('SPARK_APP_ID'), - 'api_secret': os.environ.get('SPARK_API_SECRET'), - 'api_key': os.environ.get('SPARK_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 100 + "app_id": os.environ.get("SPARK_APP_ID"), + "api_secret": os.environ.get("SPARK_API_SECRET"), + "api_key": os.environ.get("SPARK_API_KEY"), }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.5, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -94,20 +75,18 @@ def test_get_num_tokens(): model = SparkLargeLanguageModel() num_tokens = model.get_num_tokens( - model='spark-1.5', + model="spark-1.5", credentials={ - 'app_id': os.environ.get('SPARK_APP_ID'), - 'api_secret': os.environ.get('SPARK_API_SECRET'), - 'api_key': os.environ.get('SPARK_API_KEY') + "app_id": os.environ.get("SPARK_APP_ID"), + "api_secret": os.environ.get("SPARK_API_SECRET"), + "api_key": os.environ.get("SPARK_API_KEY"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 14 diff --git a/api/tests/integration_tests/model_runtime/spark/test_provider.py b/api/tests/integration_tests/model_runtime/spark/test_provider.py index 8e22815a86..9da0df6bb3 100644 --- a/api/tests/integration_tests/model_runtime/spark/test_provider.py +++ b/api/tests/integration_tests/model_runtime/spark/test_provider.py @@ -10,14 +10,12 @@ def test_validate_provider_credentials(): provider = SparkProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) provider.validate_provider_credentials( credentials={ - 'app_id': os.environ.get('SPARK_APP_ID'), - 'api_secret': os.environ.get('SPARK_API_SECRET'), - 'api_key': os.environ.get('SPARK_API_KEY') + "app_id": os.environ.get("SPARK_APP_ID"), + "api_secret": os.environ.get("SPARK_API_SECRET"), + "api_key": os.environ.get("SPARK_API_KEY"), } ) diff --git a/api/tests/integration_tests/model_runtime/stepfun/test_llm.py b/api/tests/integration_tests/model_runtime/stepfun/test_llm.py index d703147d63..c03b1bae1f 100644 --- a/api/tests/integration_tests/model_runtime/stepfun/test_llm.py +++ b/api/tests/integration_tests/model_runtime/stepfun/test_llm.py @@ -21,40 +21,22 @@ def test_validate_credentials(): model = StepfunLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='step-1-8k', - credentials={ - 'api_key': 'invalid_key' - } - ) + model.validate_credentials(model="step-1-8k", credentials={"api_key": "invalid_key"}) + + model.validate_credentials(model="step-1-8k", credentials={"api_key": os.environ.get("STEPFUN_API_KEY")}) - model.validate_credentials( - model='step-1-8k', - credentials={ - 'api_key': os.environ.get('STEPFUN_API_KEY') - } - ) def test_invoke_model(): model = StepfunLargeLanguageModel() response = model.invoke( - model='step-1-8k', - credentials={ - 'api_key': os.environ.get('STEPFUN_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.9, - 'top_p': 0.7 - }, - stop=['Hi'], + model="step-1-8k", + credentials={"api_key": os.environ.get("STEPFUN_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.9, "top_p": 0.7}, + stop=["Hi"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -65,24 +47,17 @@ def test_invoke_stream_model(): model = StepfunLargeLanguageModel() response = model.invoke( - model='step-1-8k', - credentials={ - 'api_key': os.environ.get('STEPFUN_API_KEY') - }, + model="step-1-8k", + credentials={"api_key": os.environ.get("STEPFUN_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.9, - 'top_p': 0.7 - }, + model_parameters={"temperature": 0.9, "top_p": 0.7}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -98,10 +73,7 @@ def test_get_customizable_model_schema(): model = StepfunLargeLanguageModel() schema = model.get_customizable_model_schema( - model='step-1-8k', - credentials={ - 'api_key': os.environ.get('STEPFUN_API_KEY') - } + model="step-1-8k", credentials={"api_key": os.environ.get("STEPFUN_API_KEY")} ) assert isinstance(schema, AIModelEntity) @@ -110,67 +82,44 @@ def test_invoke_chat_model_with_tools(): model = StepfunLargeLanguageModel() result = model.invoke( - model='step-1-8k', - credentials={ - 'api_key': os.environ.get('STEPFUN_API_KEY') - }, + model="step-1-8k", + credentials={"api_key": os.environ.get("STEPFUN_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( content="what's the weather today in Shanghai?", - ) + ), ], - model_parameters={ - 'temperature': 0.9, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.9, "max_tokens": 100}, tools=[ PromptMessageTool( - name='get_weather', - description='Determine weather in my location', + name="get_weather", + description="Determine weather in my location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ), PromptMessageTool( - name='get_stock_price', - description='Get the current stock price', + name="get_stock_price", + description="Get the current stock price", parameters={ "type": "object", - "properties": { - "symbol": { - "type": "string", - "description": "The stock symbol" - } - }, - "required": [ - "symbol" - ] - } - ) + "properties": {"symbol": {"type": "string", "description": "The stock symbol"}}, + "required": ["symbol"], + }, + ), ], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert isinstance(result.message, AssistantPromptMessage) - assert len(result.message.tool_calls) > 0 \ No newline at end of file + assert len(result.message.tool_calls) > 0 diff --git a/api/tests/integration_tests/model_runtime/test_model_provider_factory.py b/api/tests/integration_tests/model_runtime/test_model_provider_factory.py index fd8aa3f610..0ec4b0b724 100644 --- a/api/tests/integration_tests/model_runtime/test_model_provider_factory.py +++ b/api/tests/integration_tests/model_runtime/test_model_provider_factory.py @@ -24,13 +24,8 @@ def test_get_models(): providers = factory.get_models( model_type=ModelType.LLM, provider_configs=[ - ProviderConfig( - provider='openai', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } - ) - ] + ProviderConfig(provider="openai", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) + ], ) logger.debug(providers) @@ -44,29 +39,21 @@ def test_get_models(): assert provider_model.model_type == ModelType.LLM providers = factory.get_models( - provider='openai', + provider="openai", provider_configs=[ - ProviderConfig( - provider='openai', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } - ) - ] + ProviderConfig(provider="openai", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) + ], ) assert len(providers) == 1 assert isinstance(providers[0], SimpleProviderEntity) - assert providers[0].provider == 'openai' + assert providers[0].provider == "openai" def test_provider_credentials_validate(): factory = ModelProviderFactory() factory.provider_credentials_validate( - provider='openai', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + provider="openai", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")} ) @@ -79,4 +66,4 @@ def test__get_model_provider_map(): logger.debug(model_provider.provider_instance) assert len(model_providers) >= 1 - assert isinstance(model_providers['openai'], ModelProviderExtension) + assert isinstance(model_providers["openai"], ModelProviderExtension) diff --git a/api/tests/integration_tests/model_runtime/togetherai/test_llm.py b/api/tests/integration_tests/model_runtime/togetherai/test_llm.py index 698f534517..06ebc2a82d 100644 --- a/api/tests/integration_tests/model_runtime/togetherai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/togetherai/test_llm.py @@ -19,76 +19,61 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', - credentials={ - 'api_key': 'invalid_key', - 'mode': 'chat' - } + model="mistralai/Mixtral-8x7B-Instruct-v0.1", credentials={"api_key": "invalid_key", "mode": "chat"} ) model.validate_credentials( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', - credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'mode': 'chat' - } + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"}, ) + def test_invoke_model(): model = TogetherAILargeLanguageModel() response = model.invoke( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', - credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'mode': 'completion' - }, + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "completion"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 + def test_invoke_stream_model(): model = TogetherAILargeLanguageModel() response = model.invoke( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', - credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'mode': 'chat' - }, + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -98,22 +83,21 @@ def test_invoke_stream_model(): assert isinstance(chunk.delta, LLMResultChunkDelta) assert isinstance(chunk.delta.message, AssistantPromptMessage) + def test_get_num_tokens(): model = TogetherAILargeLanguageModel() num_tokens = model.get_num_tokens( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', + model="mistralai/Mixtral-8x7B-Instruct-v0.1", credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), + "api_key": os.environ.get("TOGETHER_API_KEY"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert isinstance(num_tokens, int) diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_llm.py b/api/tests/integration_tests/model_runtime/tongyi/test_llm.py index 81fb676018..61650735f2 100644 --- a/api/tests/integration_tests/model_runtime/tongyi/test_llm.py +++ b/api/tests/integration_tests/model_runtime/tongyi/test_llm.py @@ -13,18 +13,10 @@ def test_validate_credentials(): model = TongyiLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='qwen-turbo', - credentials={ - 'dashscope_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="qwen-turbo", credentials={"dashscope_api_key": "invalid_key"}) model.validate_credentials( - model='qwen-turbo', - credentials={ - 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') - } + model="qwen-turbo", credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")} ) @@ -32,22 +24,13 @@ def test_invoke_model(): model = TongyiLargeLanguageModel() response = model.invoke( - model='qwen-turbo', - credentials={ - 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 10 - }, - stop=['How'], + model="qwen-turbo", + credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 0.5, "max_tokens": 10}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -58,22 +41,12 @@ def test_invoke_stream_model(): model = TongyiLargeLanguageModel() response = model.invoke( - model='qwen-turbo', - credentials={ - 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 100, - 'seed': 1234 - }, + model="qwen-turbo", + credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -89,18 +62,14 @@ def test_get_num_tokens(): model = TongyiLargeLanguageModel() num_tokens = model.get_num_tokens( - model='qwen-turbo', - credentials={ - 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') - }, + model="qwen-turbo", + credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 12 diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_provider.py b/api/tests/integration_tests/model_runtime/tongyi/test_provider.py index 6145c1dc37..0bc96c84e7 100644 --- a/api/tests/integration_tests/model_runtime/tongyi/test_provider.py +++ b/api/tests/integration_tests/model_runtime/tongyi/test_provider.py @@ -10,12 +10,8 @@ def test_validate_provider_credentials(): provider = TongyiProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) provider.validate_provider_credentials( - credentials={ - 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') - } + credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")} ) diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_response_format.py b/api/tests/integration_tests/model_runtime/tongyi/test_response_format.py index 1b0a38d5d1..905e7907fd 100644 --- a/api/tests/integration_tests/model_runtime/tongyi/test_response_format.py +++ b/api/tests/integration_tests/model_runtime/tongyi/test_response_format.py @@ -39,21 +39,17 @@ def invoke_model_with_json_response(model_name="qwen-max-0403"): response = model.invoke( model=model_name, - credentials={ - 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') - }, + credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}, prompt_messages=[ - UserPromptMessage( - content='output json data with format `{"data": "test", "code": 200, "msg": "success"}' - ) + UserPromptMessage(content='output json data with format `{"data": "test", "code": 200, "msg": "success"}') ], model_parameters={ - 'temperature': 0.5, - 'max_tokens': 50, - 'response_format': 'JSON', + "temperature": 0.5, + "max_tokens": 50, + "response_format": "JSON", }, stream=True, - user="abc-123" + user="abc-123", ) print("=====================================") print(response) @@ -81,4 +77,4 @@ def is_json(s): json.loads(s) except ValueError: return False - return True \ No newline at end of file + return True diff --git a/api/tests/integration_tests/model_runtime/upstage/test_llm.py b/api/tests/integration_tests/model_runtime/upstage/test_llm.py index c35580a8b1..bc7517acbe 100644 --- a/api/tests/integration_tests/model_runtime/upstage/test_llm.py +++ b/api/tests/integration_tests/model_runtime/upstage/test_llm.py @@ -26,151 +26,113 @@ def test_predefined_models(): assert len(model_schemas) >= 1 assert isinstance(model_schemas[0], AIModelEntity) -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_validate_credentials_for_chat_model(setup_openai_mock): model = UpstageLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): # model name to gpt-3.5-turbo because of mocking - model.validate_credentials( - model='gpt-3.5-turbo', - credentials={ - 'upstage_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="gpt-3.5-turbo", credentials={"upstage_api_key": "invalid_key"}) model.validate_credentials( - model='solar-1-mini-chat', - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - } + model="solar-1-mini-chat", credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")} ) -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model(setup_openai_mock): model = UpstageLargeLanguageModel() result = model.invoke( - model='solar-1-mini-chat', - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - }, + model="solar-1-mini-chat", + credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.0, - 'top_p': 1.0, - 'presence_penalty': 0.0, - 'frequency_penalty': 0.0, - 'max_tokens': 10 + "temperature": 0.0, + "top_p": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model_with_tools(setup_openai_mock): model = UpstageLargeLanguageModel() result = model.invoke( - model='solar-1-mini-chat', - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - }, + model="solar-1-mini-chat", + credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( content="what's the weather today in London?", - ) + ), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, tools=[ PromptMessageTool( - name='get_weather', - description='Determine weather in my location', + name="get_weather", + description="Determine weather in my location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ), PromptMessageTool( - name='get_stock_price', - description='Get the current stock price', + name="get_stock_price", + description="Get the current stock price", parameters={ "type": "object", - "properties": { - "symbol": { - "type": "string", - "description": "The stock symbol" - } - }, - "required": [ - "symbol" - ] - } - ) + "properties": {"symbol": {"type": "string", "description": "The stock symbol"}}, + "required": ["symbol"], + }, + ), ], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert isinstance(result.message, AssistantPromptMessage) assert len(result.message.tool_calls) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_stream_chat_model(setup_openai_mock): model = UpstageLargeLanguageModel() result = model.invoke( - model='solar-1-mini-chat', - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - }, + model="solar-1-mini-chat", + credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(result, Generator) @@ -189,57 +151,36 @@ def test_get_num_tokens(): model = UpstageLargeLanguageModel() num_tokens = model.get_num_tokens( - model='solar-1-mini-chat', - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ] + model="solar-1-mini-chat", + credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], ) assert num_tokens == 13 num_tokens = model.get_num_tokens( - model='solar-1-mini-chat', - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - }, + model="solar-1-mini-chat", + credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], tools=[ PromptMessageTool( - name='get_weather', - description='Determine weather in my location', + name="get_weather", + description="Determine weather in my location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ), - ] + ], ) assert num_tokens == 106 diff --git a/api/tests/integration_tests/model_runtime/upstage/test_provider.py b/api/tests/integration_tests/model_runtime/upstage/test_provider.py index c33eef49b2..9d83779aa0 100644 --- a/api/tests/integration_tests/model_runtime/upstage/test_provider.py +++ b/api/tests/integration_tests/model_runtime/upstage/test_provider.py @@ -7,17 +7,11 @@ from core.model_runtime.model_providers.upstage.upstage import UpstageProvider from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_validate_provider_credentials(setup_openai_mock): provider = UpstageProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/upstage/test_text_embedding.py b/api/tests/integration_tests/model_runtime/upstage/test_text_embedding.py index 54135a0e74..8c83172fa3 100644 --- a/api/tests/integration_tests/model_runtime/upstage/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/upstage/test_text_embedding.py @@ -8,41 +8,31 @@ from core.model_runtime.model_providers.upstage.text_embedding.text_embedding im from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) def test_validate_credentials(setup_openai_mock): model = UpstageTextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='solar-embedding-1-large-passage', - credentials={ - 'upstage_api_key': 'invalid_key' - } + model="solar-embedding-1-large-passage", credentials={"upstage_api_key": "invalid_key"} ) model.validate_credentials( - model='solar-embedding-1-large-passage', - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - } + model="solar-embedding-1-large-passage", credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")} ) -@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) def test_invoke_model(setup_openai_mock): model = UpstageTextEmbeddingModel() result = model.invoke( - model='solar-embedding-1-large-passage', + model="solar-embedding-1-large-passage", credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY'), + "upstage_api_key": os.environ.get("UPSTAGE_API_KEY"), }, - texts=[ - "hello", - "world", - " ".join(["long_text"] * 100), - " ".join(["another_long_text"] * 100) - ], - user="abc-123" + texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -54,14 +44,11 @@ def test_get_num_tokens(): model = UpstageTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='solar-embedding-1-large-passage', + model="solar-embedding-1-large-passage", credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY'), + "upstage_api_key": os.environ.get("UPSTAGE_API_KEY"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 5 diff --git a/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py b/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py index 3b399d604e..f831c063a4 100644 --- a/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py +++ b/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py @@ -14,26 +14,26 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': 'INVALID', - 'volc_secret_access_key': 'INVALID', - 'endpoint_id': 'INVALID', - 'base_model_name': 'Doubao-embedding', - } + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": "INVALID", + "volc_secret_access_key": "INVALID", + "endpoint_id": "INVALID", + "base_model_name": "Doubao-embedding", + }, ) model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), - 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), - 'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'), - 'base_model_name': 'Doubao-embedding', + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_EMBEDDING_ENDPOINT_ID"), + "base_model_name": "Doubao-embedding", }, ) @@ -42,20 +42,17 @@ def test_invoke_model(): model = VolcengineMaaSTextEmbeddingModel() result = model.invoke( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), - 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), - 'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'), - 'base_model_name': 'Doubao-embedding', + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_EMBEDDING_ENDPOINT_ID"), + "base_model_name": "Doubao-embedding", }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -67,19 +64,16 @@ def test_get_num_tokens(): model = VolcengineMaaSTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), - 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), - 'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'), - 'base_model_name': 'Doubao-embedding', + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_EMBEDDING_ENDPOINT_ID"), + "base_model_name": "Doubao-embedding", }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py b/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py index 63835d0263..8ff9c41404 100644 --- a/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py +++ b/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py @@ -14,25 +14,25 @@ def test_validate_credentials_for_chat_model(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': 'INVALID', - 'volc_secret_access_key': 'INVALID', - 'endpoint_id': 'INVALID', - } + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": "INVALID", + "volc_secret_access_key": "INVALID", + "endpoint_id": "INVALID", + }, ) model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), - 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), - 'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'), - } + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"), + }, ) @@ -40,28 +40,24 @@ def test_invoke_model(): model = VolcengineMaaSLargeLanguageModel() response = model.invoke( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), - 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), - 'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'), - 'base_model_name': 'Skylark2-pro-4k', + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"), + "base_model_name": "Skylark2-pro-4k", }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) @@ -73,28 +69,24 @@ def test_invoke_stream_model(): model = VolcengineMaaSLargeLanguageModel() response = model.invoke( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), - 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), - 'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'), - 'base_model_name': 'Skylark2-pro-4k', + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"), + "base_model_name": "Skylark2-pro-4k", }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -102,29 +94,24 @@ def test_invoke_stream_model(): assert isinstance(chunk, LLMResultChunk) assert isinstance(chunk.delta, LLMResultChunkDelta) assert isinstance(chunk.delta.message, AssistantPromptMessage) - assert len( - chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True def test_get_num_tokens(): model = VolcengineMaaSLargeLanguageModel() response = model.get_num_tokens( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), - 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), - 'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'), - 'base_model_name': 'Skylark2-pro-4k', + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"), + "base_model_name": "Skylark2-pro-4k", }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - tools=[] + prompt_messages=[UserPromptMessage(content="Hello World!")], + tools=[], ) assert isinstance(response, int) diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_embedding.py b/api/tests/integration_tests/model_runtime/wenxin/test_embedding.py index d886226cf9..ac38340aec 100644 --- a/api/tests/integration_tests/model_runtime/wenxin/test_embedding.py +++ b/api/tests/integration_tests/model_runtime/wenxin/test_embedding.py @@ -10,13 +10,10 @@ def test_invoke_embedding_v1(): model = WenxinTextEmbeddingModel() response = model.invoke( - model='embedding-v1', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - texts=['hello', '你好', 'xxxxx'], - user="abc-123" + model="embedding-v1", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + texts=["hello", "你好", "xxxxx"], + user="abc-123", ) assert isinstance(response, TextEmbeddingResult) @@ -29,13 +26,10 @@ def test_invoke_embedding_bge_large_en(): model = WenxinTextEmbeddingModel() response = model.invoke( - model='bge-large-en', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - texts=['hello', '你好', 'xxxxx'], - user="abc-123" + model="bge-large-en", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + texts=["hello", "你好", "xxxxx"], + user="abc-123", ) assert isinstance(response, TextEmbeddingResult) @@ -48,13 +42,10 @@ def test_invoke_embedding_bge_large_zh(): model = WenxinTextEmbeddingModel() response = model.invoke( - model='bge-large-zh', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - texts=['hello', '你好', 'xxxxx'], - user="abc-123" + model="bge-large-zh", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + texts=["hello", "你好", "xxxxx"], + user="abc-123", ) assert isinstance(response, TextEmbeddingResult) @@ -67,13 +58,10 @@ def test_invoke_embedding_tao_8k(): model = WenxinTextEmbeddingModel() response = model.invoke( - model='tao-8k', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - texts=['hello', '你好', 'xxxxx'], - user="abc-123" + model="tao-8k", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + texts=["hello", "你好", "xxxxx"], + user="abc-123", ) assert isinstance(response, TextEmbeddingResult) diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_llm.py b/api/tests/integration_tests/model_runtime/wenxin/test_llm.py index 164e8253d9..e2e58f15e0 100644 --- a/api/tests/integration_tests/model_runtime/wenxin/test_llm.py +++ b/api/tests/integration_tests/model_runtime/wenxin/test_llm.py @@ -17,161 +17,125 @@ def test_predefined_models(): assert len(model_schemas) >= 1 assert isinstance(model_schemas[0], AIModelEntity) + def test_validate_credentials_for_chat_model(): sleep(3) model = ErnieBotLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='ernie-bot', - credentials={ - 'api_key': 'invalid_key', - 'secret_key': 'invalid_key' - } + model="ernie-bot", credentials={"api_key": "invalid_key", "secret_key": "invalid_key"} ) model.validate_credentials( - model='ernie-bot', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - } + model="ernie-bot", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, ) + def test_invoke_model_ernie_bot(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.invoke( - model='ernie-bot', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + model="ernie-bot", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_model_ernie_bot_turbo(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.invoke( - model='ernie-bot-turbo', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + model="ernie-bot-turbo", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_model_ernie_8k(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.invoke( - model='ernie-bot-8k', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + model="ernie-bot-8k", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_model_ernie_bot_4(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.invoke( - model='ernie-bot-4', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + model="ernie-bot-4", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_stream_model(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.invoke( - model='ernie-3.5-8k', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + model="ernie-3.5-8k", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -181,63 +145,48 @@ def test_invoke_stream_model(): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_invoke_model_with_system(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.invoke( - model='ernie-bot', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - SystemPromptMessage( - content='你是Kasumi' - ), - UserPromptMessage( - content='你是谁?' - ) - ], + model="ernie-bot", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[SystemPromptMessage(content="你是Kasumi"), UserPromptMessage(content="你是谁?")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) - assert 'kasumi' in response.message.content.lower() + assert "kasumi" in response.message.content.lower() + def test_invoke_with_search(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.invoke( - model='ernie-bot', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='北京今天的天气怎么样' - ) - ], + model="ernie-bot", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'disable_search': True, + "temperature": 0.7, + "top_p": 1.0, + "disable_search": True, }, stop=[], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) - total_message = '' + total_message = "" for chunk in response: assert isinstance(chunk, LLMResultChunk) assert isinstance(chunk.delta, LLMResultChunkDelta) @@ -247,25 +196,19 @@ def test_invoke_with_search(): assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True # there should be 对不起、我不能、不支持…… - assert ('不' in total_message or '抱歉' in total_message or '无法' in total_message) + assert "不" in total_message or "抱歉" in total_message or "无法" in total_message + def test_get_num_tokens(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.get_num_tokens( - model='ernie-bot', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - tools=[] + model="ernie-bot", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + tools=[], ) assert isinstance(response, int) - assert response == 10 \ No newline at end of file + assert response == 10 diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_provider.py b/api/tests/integration_tests/model_runtime/wenxin/test_provider.py index 8922aa1868..337c3d2a80 100644 --- a/api/tests/integration_tests/model_runtime/wenxin/test_provider.py +++ b/api/tests/integration_tests/model_runtime/wenxin/test_provider.py @@ -10,16 +10,8 @@ def test_validate_provider_credentials(): provider = WenxinProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={ - 'api_key': 'hahahaha', - 'secret_key': 'hahahaha' - } - ) + provider.validate_provider_credentials(credentials={"api_key": "hahahaha", "secret_key": "hahahaha"}) provider.validate_provider_credentials( - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - } + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")} ) diff --git a/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py b/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py index f0a5151f3d..8e778d005a 100644 --- a/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py +++ b/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py @@ -8,61 +8,57 @@ from core.model_runtime.model_providers.xinference.text_embedding.text_embedding from tests.integration_tests.model_runtime.__mock.xinference import MOCK, setup_xinference_mock -@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True) def test_validate_credentials(setup_xinference_mock): model = XinferenceTextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='bge-base-en', + model="bge-base-en", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': 'www ' + os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID') - } + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": "www " + os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"), + }, ) model.validate_credentials( - model='bge-base-en', + model="bge-base-en", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID') - } + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"), + }, ) -@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True) def test_invoke_model(setup_xinference_mock): model = XinferenceTextEmbeddingModel() result = model.invoke( - model='bge-base-en', + model="bge-base-en", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"), }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) assert len(result.embeddings) == 2 assert result.usage.total_tokens > 0 + def test_get_num_tokens(): model = XinferenceTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='bge-base-en', + model="bge-base-en", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/xinference/test_llm.py b/api/tests/integration_tests/model_runtime/xinference/test_llm.py index 47730406de..48d1ae323d 100644 --- a/api/tests/integration_tests/model_runtime/xinference/test_llm.py +++ b/api/tests/integration_tests/model_runtime/xinference/test_llm.py @@ -20,92 +20,84 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_moc from tests.integration_tests.model_runtime.__mock.xinference import setup_xinference_mock -@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["chat", "none"]], indirect=True) def test_validate_credentials_for_chat_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='ChatGLM3', + model="ChatGLM3", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': 'www ' + os.environ.get('XINFERENCE_CHAT_MODEL_UID') - } + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": "www " + os.environ.get("XINFERENCE_CHAT_MODEL_UID"), + }, ) with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='aaaaa', - credentials={ - 'server_url': '', - 'model_uid': '' - } - ) + model.validate_credentials(model="aaaaa", credentials={"server_url": "", "model_uid": ""}) model.validate_credentials( - model='ChatGLM3', + model="ChatGLM3", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID') - } + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_CHAT_MODEL_UID"), + }, ) -@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["chat", "none"]], indirect=True) def test_invoke_chat_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() response = model.invoke( - model='ChatGLM3', + model="ChatGLM3", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_CHAT_MODEL_UID"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 -@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["chat", "none"]], indirect=True) def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() response = model.invoke( - model='ChatGLM3', + model="ChatGLM3", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_CHAT_MODEL_UID"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -114,6 +106,8 @@ def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock): assert isinstance(chunk.delta, LLMResultChunkDelta) assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + """ Funtion calling of xinference does not support stream mode currently """ @@ -168,7 +162,7 @@ def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock): # ) # assert isinstance(response, Generator) - + # call: LLMResultChunk = None # chunks = [] @@ -241,86 +235,75 @@ def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock): # assert response.usage.total_tokens > 0 # assert response.message.tool_calls[0].function.name == 'get_current_weather' -@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["completion", "none"]], indirect=True) def test_validate_credentials_for_generation_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='alapaca', + model="alapaca", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': 'www ' + os.environ.get('XINFERENCE_GENERATION_MODEL_UID') - } + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": "www " + os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), + }, ) with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='alapaca', - credentials={ - 'server_url': '', - 'model_uid': '' - } - ) + model.validate_credentials(model="alapaca", credentials={"server_url": "", "model_uid": ""}) model.validate_credentials( - model='alapaca', + model="alapaca", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID') - } + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), + }, ) -@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["completion", "none"]], indirect=True) def test_invoke_generation_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() response = model.invoke( - model='alapaca', + model="alapaca", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), }, - prompt_messages=[ - UserPromptMessage( - content='the United States is' - ) - ], + prompt_messages=[UserPromptMessage(content="the United States is")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 -@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["completion", "none"]], indirect=True) def test_invoke_stream_generation_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() response = model.invoke( - model='alapaca', + model="alapaca", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), }, - prompt_messages=[ - UserPromptMessage( - content='the United States is' - ) - ], + prompt_messages=[UserPromptMessage(content="the United States is")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -330,68 +313,54 @@ def test_invoke_stream_generation_model(setup_openai_mock, setup_xinference_mock assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_get_num_tokens(): model = XinferenceAILargeLanguageModel() num_tokens = model.get_num_tokens( - model='ChatGLM3', + model="ChatGLM3", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], tools=[ PromptMessageTool( - name='get_current_weather', - description='Get the current weather in a given location', + name="get_current_weather", + description="Get the current weather in a given location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ) - ] + ], ) assert isinstance(num_tokens, int) assert num_tokens == 77 num_tokens = model.get_num_tokens( - model='ChatGLM3', + model="ChatGLM3", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], ) assert isinstance(num_tokens, int) - assert num_tokens == 21 \ No newline at end of file + assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/xinference/test_rerank.py b/api/tests/integration_tests/model_runtime/xinference/test_rerank.py index 9012c16a7e..71ac4eef7c 100644 --- a/api/tests/integration_tests/model_runtime/xinference/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/xinference/test_rerank.py @@ -8,44 +8,42 @@ from core.model_runtime.model_providers.xinference.rerank.rerank import Xinferen from tests.integration_tests.model_runtime.__mock.xinference import MOCK, setup_xinference_mock -@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True) def test_validate_credentials(setup_xinference_mock): model = XinferenceRerankModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='bge-reranker-base', - credentials={ - 'server_url': 'awdawdaw', - 'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID') - } + model="bge-reranker-base", + credentials={"server_url": "awdawdaw", "model_uid": os.environ.get("XINFERENCE_RERANK_MODEL_UID")}, ) model.validate_credentials( - model='bge-reranker-base', + model="bge-reranker-base", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID') - } + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_RERANK_MODEL_UID"), + }, ) -@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True) def test_invoke_model(setup_xinference_mock): model = XinferenceRerankModel() result = model.invoke( - model='bge-reranker-base', + model="bge-reranker-base", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_RERANK_MODEL_UID"), }, query="Who is Kasumi?", docs=[ - "Kasumi is a girl's name of Japanese origin meaning \"mist\".", + 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", - "and she leads a team named PopiParty." + "and she leads a team named PopiParty.", ], - score_threshold=0.8 + score_threshold=0.8, ) assert isinstance(result, RerankResult) diff --git a/api/tests/integration_tests/model_runtime/zhinao/test_llm.py b/api/tests/integration_tests/model_runtime/zhinao/test_llm.py index 47a5b6cae2..4ca1b86476 100644 --- a/api/tests/integration_tests/model_runtime/zhinao/test_llm.py +++ b/api/tests/integration_tests/model_runtime/zhinao/test_llm.py @@ -13,41 +13,22 @@ def test_validate_credentials(): model = ZhinaoLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='360gpt2-pro', - credentials={ - 'api_key': 'invalid_key' - } - ) + model.validate_credentials(model="360gpt2-pro", credentials={"api_key": "invalid_key"}) - model.validate_credentials( - model='360gpt2-pro', - credentials={ - 'api_key': os.environ.get('ZHINAO_API_KEY') - } - ) + model.validate_credentials(model="360gpt2-pro", credentials={"api_key": os.environ.get("ZHINAO_API_KEY")}) def test_invoke_model(): model = ZhinaoLargeLanguageModel() response = model.invoke( - model='360gpt2-pro', - credentials={ - 'api_key': os.environ.get('ZHINAO_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 10 - }, - stop=['How'], + model="360gpt2-pro", + credentials={"api_key": os.environ.get("ZHINAO_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 0.5, "max_tokens": 10}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -58,22 +39,12 @@ def test_invoke_stream_model(): model = ZhinaoLargeLanguageModel() response = model.invoke( - model='360gpt2-pro', - credentials={ - 'api_key': os.environ.get('ZHINAO_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 100, - 'seed': 1234 - }, + model="360gpt2-pro", + credentials={"api_key": os.environ.get("ZHINAO_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -89,18 +60,14 @@ def test_get_num_tokens(): model = ZhinaoLargeLanguageModel() num_tokens = model.get_num_tokens( - model='360gpt2-pro', - credentials={ - 'api_key': os.environ.get('ZHINAO_API_KEY') - }, + model="360gpt2-pro", + credentials={"api_key": os.environ.get("ZHINAO_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/zhinao/test_provider.py b/api/tests/integration_tests/model_runtime/zhinao/test_provider.py index 87b0e6c2d9..c22f797919 100644 --- a/api/tests/integration_tests/model_runtime/zhinao/test_provider.py +++ b/api/tests/integration_tests/model_runtime/zhinao/test_provider.py @@ -10,12 +10,6 @@ def test_validate_provider_credentials(): provider = ZhinaoProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'api_key': os.environ.get('ZHINAO_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("ZHINAO_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py b/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py index 0f92b50cb0..20380513ea 100644 --- a/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py @@ -18,41 +18,22 @@ def test_validate_credentials(): model = ZhipuAILargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='chatglm_turbo', - credentials={ - 'api_key': 'invalid_key' - } - ) + model.validate_credentials(model="chatglm_turbo", credentials={"api_key": "invalid_key"}) - model.validate_credentials( - model='chatglm_turbo', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - } - ) + model.validate_credentials(model="chatglm_turbo", credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}) def test_invoke_model(): model = ZhipuAILargeLanguageModel() response = model.invoke( - model='chatglm_turbo', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], - model_parameters={ - 'temperature': 0.9, - 'top_p': 0.7 - }, - stop=['How'], + model="chatglm_turbo", + credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 0.9, "top_p": 0.7}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -63,21 +44,12 @@ def test_invoke_stream_model(): model = ZhipuAILargeLanguageModel() response = model.invoke( - model='chatglm_turbo', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.9, - 'top_p': 0.7 - }, + model="chatglm_turbo", + credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.9, "top_p": 0.7}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -93,63 +65,45 @@ def test_get_num_tokens(): model = ZhipuAILargeLanguageModel() num_tokens = model.get_num_tokens( - model='chatglm_turbo', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - }, + model="chatglm_turbo", + credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 14 + def test_get_tools_num_tokens(): model = ZhipuAILargeLanguageModel() num_tokens = model.get_num_tokens( - model='tools', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - }, + model="tools", + credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, tools=[ PromptMessageTool( - name='get_current_weather', - description='Get the current weather in a given location', + name="get_current_weather", + description="Get the current weather in a given location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ) ], prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 88 diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py b/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py index 51b9cccf2e..cb5bc0b20a 100644 --- a/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py @@ -10,12 +10,6 @@ def test_validate_provider_credentials(): provider = ZhipuaiProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py index 7308c57296..9c97c91ecb 100644 --- a/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py @@ -11,34 +11,19 @@ def test_validate_credentials(): model = ZhipuAITextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='text_embedding', - credentials={ - 'api_key': 'invalid_key' - } - ) + model.validate_credentials(model="text_embedding", credentials={"api_key": "invalid_key"}) - model.validate_credentials( - model='text_embedding', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - } - ) + model.validate_credentials(model="text_embedding", credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}) def test_invoke_model(): model = ZhipuAITextEmbeddingModel() result = model.invoke( - model='text_embedding', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - }, - texts=[ - "hello", - "world" - ], - user="abc-123" + model="text_embedding", + credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -50,14 +35,7 @@ def test_get_num_tokens(): model = ZhipuAITextEmbeddingModel() num_tokens = model.get_num_tokens( - model='text_embedding', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - }, - texts=[ - "hello", - "world" - ] + model="text_embedding", credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, texts=["hello", "world"] ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/tools/__mock/http.py b/api/tests/integration_tests/tools/__mock/http.py index 41bb3daeb5..4dfc530010 100644 --- a/api/tests/integration_tests/tools/__mock/http.py +++ b/api/tests/integration_tests/tools/__mock/http.py @@ -7,20 +7,17 @@ from _pytest.monkeypatch import MonkeyPatch class MockedHttp: - def httpx_request(method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD'], - url: str, **kwargs) -> httpx.Response: + def httpx_request( + method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs + ) -> httpx.Response: """ Mocked httpx.request """ request = httpx.Request( - method, - url, - params=kwargs.get('params'), - headers=kwargs.get('headers'), - cookies=kwargs.get('cookies') + method, url, params=kwargs.get("params"), headers=kwargs.get("headers"), cookies=kwargs.get("cookies") ) - data = kwargs.get('data', None) - resp = json.dumps(data).encode('utf-8') if data else b'OK' + data = kwargs.get("data", None) + resp = json.dumps(data).encode("utf-8") if data else b"OK" response = httpx.Response( status_code=200, request=request, diff --git a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py index ba14d365c5..83f4d70ce9 100644 --- a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py +++ b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py @@ -10,6 +10,7 @@ todos_data = { "user1": ["Go for a run", "Read a book"], } + class TodosResource(Resource): def get(self, username): todos = todos_data.get(username, []) @@ -32,7 +33,8 @@ class TodosResource(Resource): return {"error": "Invalid todo index"}, 400 -api.add_resource(TodosResource, '/todos/') -if __name__ == '__main__': +api.add_resource(TodosResource, "/todos/") + +if __name__ == "__main__": app.run(port=5003, debug=True) diff --git a/api/tests/integration_tests/tools/api_tool/test_api_tool.py b/api/tests/integration_tests/tools/api_tool/test_api_tool.py index f6e7b153dd..09729a961e 100644 --- a/api/tests/integration_tests/tools/api_tool/test_api_tool.py +++ b/api/tests/integration_tests/tools/api_tool/test_api_tool.py @@ -3,37 +3,40 @@ from core.tools.tool.tool import Tool from tests.integration_tests.tools.__mock.http import setup_http_mock tool_bundle = { - 'server_url': 'http://www.example.com/{path_param}', - 'method': 'post', - 'author': '', - 'openapi': {'parameters': [{'in': 'path', 'name': 'path_param'}, - {'in': 'query', 'name': 'query_param'}, - {'in': 'cookie', 'name': 'cookie_param'}, - {'in': 'header', 'name': 'header_param'}, - ], - 'requestBody': { - 'content': {'application/json': {'schema': {'properties': {'body_param': {'type': 'string'}}}}}} - }, - 'parameters': [] + "server_url": "http://www.example.com/{path_param}", + "method": "post", + "author": "", + "openapi": { + "parameters": [ + {"in": "path", "name": "path_param"}, + {"in": "query", "name": "query_param"}, + {"in": "cookie", "name": "cookie_param"}, + {"in": "header", "name": "header_param"}, + ], + "requestBody": { + "content": {"application/json": {"schema": {"properties": {"body_param": {"type": "string"}}}}} + }, + }, + "parameters": [], } parameters = { - 'path_param': 'p_param', - 'query_param': 'q_param', - 'cookie_param': 'c_param', - 'header_param': 'h_param', - 'body_param': 'b_param', + "path_param": "p_param", + "query_param": "q_param", + "cookie_param": "c_param", + "header_param": "h_param", + "body_param": "b_param", } def test_api_tool(setup_http_mock): - tool = ApiTool(api_bundle=tool_bundle, runtime=Tool.Runtime(credentials={'auth_type': 'none'})) + tool = ApiTool(api_bundle=tool_bundle, runtime=Tool.Runtime(credentials={"auth_type": "none"})) headers = tool.assembling_request(parameters) response = tool.do_http_request(tool.api_bundle.server_url, tool.api_bundle.method, headers, parameters) assert response.status_code == 200 - assert '/p_param' == response.request.url.path - assert b'query_param=q_param' == response.request.url.query - assert 'h_param' == response.request.headers.get('header_param') - assert 'application/json' == response.request.headers.get('content-type') - assert 'cookie_param=c_param' == response.request.headers.get('cookie') - assert 'b_param' in response.content.decode() + assert "/p_param" == response.request.url.path + assert b"query_param=q_param" == response.request.url.query + assert "h_param" == response.request.headers.get("header_param") + assert "application/json" == response.request.headers.get("content-type") + assert "cookie_param=c_param" == response.request.headers.get("cookie") + assert "b_param" in response.content.decode() diff --git a/api/tests/integration_tests/tools/test_all_provider.py b/api/tests/integration_tests/tools/test_all_provider.py index 2811bc816d..2dfce749b3 100644 --- a/api/tests/integration_tests/tools/test_all_provider.py +++ b/api/tests/integration_tests/tools/test_all_provider.py @@ -7,16 +7,17 @@ provider_names = [provider.identity.name for provider in provider_generator] ToolManager.clear_builtin_providers_cache() provider_generator = ToolManager.list_builtin_providers() -@pytest.mark.parametrize('name', provider_names) + +@pytest.mark.parametrize("name", provider_names) def test_tool_providers(benchmark, name): """ Test that all tool providers can be loaded """ - + def test(generator): try: return next(generator) except StopIteration: return None - - benchmark.pedantic(test, args=(provider_generator,), iterations=1, rounds=1) \ No newline at end of file + + benchmark.pedantic(test, args=(provider_generator,), iterations=1, rounds=1) diff --git a/api/tests/integration_tests/utils/parent_class.py b/api/tests/integration_tests/utils/parent_class.py index 39fc95256e..6a6de1cc41 100644 --- a/api/tests/integration_tests/utils/parent_class.py +++ b/api/tests/integration_tests/utils/parent_class.py @@ -3,4 +3,4 @@ class ParentClass: self.name = name def get_name(self): - return self.name \ No newline at end of file + return self.name diff --git a/api/tests/integration_tests/utils/test_module_import_helper.py b/api/tests/integration_tests/utils/test_module_import_helper.py index 256c9a911f..7d32f5ae66 100644 --- a/api/tests/integration_tests/utils/test_module_import_helper.py +++ b/api/tests/integration_tests/utils/test_module_import_helper.py @@ -7,26 +7,26 @@ from tests.integration_tests.utils.parent_class import ParentClass def test_loading_subclass_from_source(): current_path = os.getcwd() module = load_single_subclass_from_source( - module_name='ChildClass', - script_path=os.path.join(current_path, 'child_class.py'), - parent_type=ParentClass) - assert module and module.__name__ == 'ChildClass' + module_name="ChildClass", script_path=os.path.join(current_path, "child_class.py"), parent_type=ParentClass + ) + assert module and module.__name__ == "ChildClass" def test_load_import_module_from_source(): current_path = os.getcwd() module = import_module_from_source( - module_name='ChildClass', - py_file_path=os.path.join(current_path, 'child_class.py')) - assert module and module.__name__ == 'ChildClass' + module_name="ChildClass", py_file_path=os.path.join(current_path, "child_class.py") + ) + assert module and module.__name__ == "ChildClass" def test_lazy_loading_subclass_from_source(): current_path = os.getcwd() clz = load_single_subclass_from_source( - module_name='LazyLoadChildClass', - script_path=os.path.join(current_path, 'lazy_load_class.py'), + module_name="LazyLoadChildClass", + script_path=os.path.join(current_path, "lazy_load_class.py"), parent_type=ParentClass, - use_lazy_loader=True) - instance = clz('dify') - assert instance.get_name() == 'dify' + use_lazy_loader=True, + ) + instance = clz("dify") + assert instance.get_name() == "dify" diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py index f8165cba94..571c1e3d44 100644 --- a/api/tests/integration_tests/vdb/__mock/tcvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py @@ -13,11 +13,15 @@ from xinference_client.types import Embedding class MockTcvectordbClass: - - def VectorDBClient(self, url=None, username='', key='', - read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY, - timeout=5, - adapter: HTTPAdapter = None): + def VectorDBClient( + self, + url=None, + username="", + key="", + read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY, + timeout=5, + adapter: HTTPAdapter = None, + ): self._conn = None self._read_consistency = read_consistency @@ -26,105 +30,96 @@ class MockTcvectordbClass: Database( conn=self._conn, read_consistency=self._read_consistency, - name='dify', - )] + name="dify", + ) + ] def list_collections(self, timeout: Optional[float] = None) -> list[Collection]: return [] def drop_collection(self, name: str, timeout: Optional[float] = None): - return { - "code": 0, - "msg": "operation success" - } + return {"code": 0, "msg": "operation success"} def create_collection( - self, - name: str, - shard: int, - replicas: int, - description: str, - index: Index, - embedding: Embedding = None, - timeout: float = None, + self, + name: str, + shard: int, + replicas: int, + description: str, + index: Index, + embedding: Embedding = None, + timeout: float = None, ) -> Collection: - return Collection(self, name, shard, replicas, description, index, embedding=embedding, - read_consistency=self._read_consistency, timeout=timeout) - - def describe_collection(self, name: str, timeout: Optional[float] = None) -> Collection: - collection = Collection( + return Collection( self, name, - shard=1, - replicas=2, - description=name, - timeout=timeout + shard, + replicas, + description, + index, + embedding=embedding, + read_consistency=self._read_consistency, + timeout=timeout, ) + + def describe_collection(self, name: str, timeout: Optional[float] = None) -> Collection: + collection = Collection(self, name, shard=1, replicas=2, description=name, timeout=timeout) return collection def collection_upsert( - self, - documents: list[Document], - timeout: Optional[float] = None, - build_index: bool = True, - **kwargs + self, documents: list[Document], timeout: Optional[float] = None, build_index: bool = True, **kwargs ): - return { - "code": 0, - "msg": "operation success" - } + return {"code": 0, "msg": "operation success"} def collection_search( - self, - vectors: list[list[float]], - filter: Filter = None, - params=None, - retrieve_vector: bool = False, - limit: int = 10, - output_fields: Optional[list[str]] = None, - timeout: Optional[float] = None, + self, + vectors: list[list[float]], + filter: Filter = None, + params=None, + retrieve_vector: bool = False, + limit: int = 10, + output_fields: Optional[list[str]] = None, + timeout: Optional[float] = None, ) -> list[list[dict]]: - return [[{'metadata': '{"doc_id":"foo1"}', 'text': 'text', 'doc_id': 'foo1', 'score': 0.1}]] + return [[{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]] def collection_query( - self, - document_ids: Optional[list] = None, - retrieve_vector: bool = False, - limit: Optional[int] = None, - offset: Optional[int] = None, - filter: Optional[Filter] = None, - output_fields: Optional[list[str]] = None, - timeout: Optional[float] = None, + self, + document_ids: Optional[list] = None, + retrieve_vector: bool = False, + limit: Optional[int] = None, + offset: Optional[int] = None, + filter: Optional[Filter] = None, + output_fields: Optional[list[str]] = None, + timeout: Optional[float] = None, ) -> list[dict]: - return [{'metadata': '{"doc_id":"foo1"}', 'text': 'text', 'doc_id': 'foo1', 'score': 0.1}] + return [{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}] def collection_delete( - self, - document_ids: list[str] = None, - filter: Filter = None, - timeout: float = None, + self, + document_ids: list[str] = None, + filter: Filter = None, + timeout: float = None, ): - return { - "code": 0, - "msg": "operation success" - } + return {"code": 0, "msg": "operation success"} -MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + @pytest.fixture def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch): if MOCK: - monkeypatch.setattr(VectorDBClient, '__init__', MockTcvectordbClass.VectorDBClient) - monkeypatch.setattr(VectorDBClient, 'list_databases', MockTcvectordbClass.list_databases) - monkeypatch.setattr(Database, 'collection', MockTcvectordbClass.describe_collection) - monkeypatch.setattr(Database, 'list_collections', MockTcvectordbClass.list_collections) - monkeypatch.setattr(Database, 'drop_collection', MockTcvectordbClass.drop_collection) - monkeypatch.setattr(Database, 'create_collection', MockTcvectordbClass.create_collection) - monkeypatch.setattr(Collection, 'upsert', MockTcvectordbClass.collection_upsert) - monkeypatch.setattr(Collection, 'search', MockTcvectordbClass.collection_search) - monkeypatch.setattr(Collection, 'query', MockTcvectordbClass.collection_query) - monkeypatch.setattr(Collection, 'delete', MockTcvectordbClass.collection_delete) + monkeypatch.setattr(VectorDBClient, "__init__", MockTcvectordbClass.VectorDBClient) + monkeypatch.setattr(VectorDBClient, "list_databases", MockTcvectordbClass.list_databases) + monkeypatch.setattr(Database, "collection", MockTcvectordbClass.describe_collection) + monkeypatch.setattr(Database, "list_collections", MockTcvectordbClass.list_collections) + monkeypatch.setattr(Database, "drop_collection", MockTcvectordbClass.drop_collection) + monkeypatch.setattr(Database, "create_collection", MockTcvectordbClass.create_collection) + monkeypatch.setattr(Collection, "upsert", MockTcvectordbClass.collection_upsert) + monkeypatch.setattr(Collection, "search", MockTcvectordbClass.collection_search) + monkeypatch.setattr(Collection, "query", MockTcvectordbClass.collection_query) + monkeypatch.setattr(Collection, "delete", MockTcvectordbClass.collection_delete) yield diff --git a/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py b/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py index d6067af73b..970b98edc3 100644 --- a/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py +++ b/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py @@ -26,6 +26,7 @@ class AnalyticdbVectorTest(AbstractVectorTest): def run_all_tests(self): self.vector.delete() return super().run_all_tests() - + + def test_chroma_vector(setup_mock_redis): - AnalyticdbVectorTest().run_all_tests() \ No newline at end of file + AnalyticdbVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/vdb/chroma/test_chroma.py b/api/tests/integration_tests/vdb/chroma/test_chroma.py index 033f9a54da..ac7b5cbda4 100644 --- a/api/tests/integration_tests/vdb/chroma/test_chroma.py +++ b/api/tests/integration_tests/vdb/chroma/test_chroma.py @@ -14,13 +14,13 @@ class ChromaVectorTest(AbstractVectorTest): self.vector = ChromaVector( collection_name=self.collection_name, config=ChromaConfig( - host='localhost', + host="localhost", port=8000, tenant=chromadb.DEFAULT_TENANT, database=chromadb.DEFAULT_DATABASE, auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider", auth_credentials="difyai123456", - ) + ), ) def search_by_full_text(self): diff --git a/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py b/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py index b1c1cc10d9..2a0c1bb038 100644 --- a/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py +++ b/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py @@ -8,16 +8,11 @@ from tests.integration_tests.vdb.test_vector_store import ( class ElasticSearchVectorTest(AbstractVectorTest): def __init__(self): super().__init__() - self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash'] + self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"] self.vector = ElasticSearchVector( index_name=self.collection_name.lower(), - config=ElasticSearchConfig( - host='http://localhost', - port='9200', - username='elastic', - password='elastic' - ), - attributes=self.attributes + config=ElasticSearchConfig(host="http://localhost", port="9200", username="elastic", password="elastic"), + attributes=self.attributes, ) diff --git a/api/tests/integration_tests/vdb/milvus/test_milvus.py b/api/tests/integration_tests/vdb/milvus/test_milvus.py index 9c0917ef30..7b5f19ea62 100644 --- a/api/tests/integration_tests/vdb/milvus/test_milvus.py +++ b/api/tests/integration_tests/vdb/milvus/test_milvus.py @@ -12,11 +12,11 @@ class MilvusVectorTest(AbstractVectorTest): self.vector = MilvusVector( collection_name=self.collection_name, config=MilvusConfig( - host='localhost', + host="localhost", port=19530, - user='root', - password='Milvus', - ) + user="root", + password="Milvus", + ), ) def search_by_full_text(self): @@ -25,7 +25,7 @@ class MilvusVectorTest(AbstractVectorTest): assert len(hits_by_full_text) == 0 def get_ids_by_metadata_field(self): - ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 1 diff --git a/api/tests/integration_tests/vdb/myscale/test_myscale.py b/api/tests/integration_tests/vdb/myscale/test_myscale.py index b6260d549a..55b2fde427 100644 --- a/api/tests/integration_tests/vdb/myscale/test_myscale.py +++ b/api/tests/integration_tests/vdb/myscale/test_myscale.py @@ -21,7 +21,7 @@ class MyScaleVectorTest(AbstractVectorTest): ) def get_ids_by_metadata_field(self): - ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 1 diff --git a/api/tests/integration_tests/vdb/opensearch/test_opensearch.py b/api/tests/integration_tests/vdb/opensearch/test_opensearch.py index ea1e05da90..a99b81d41e 100644 --- a/api/tests/integration_tests/vdb/opensearch/test_opensearch.py +++ b/api/tests/integration_tests/vdb/opensearch/test_opensearch.py @@ -29,54 +29,55 @@ class TestOpenSearchVector: self.example_doc_id = "example_doc_id" self.vector = OpenSearchVector( collection_name=self.collection_name, - config=OpenSearchConfig( - host='localhost', - port=9200, - user='admin', - password='password', - secure=False - ) + config=OpenSearchConfig(host="localhost", port=9200, user="admin", password="password", secure=False), ) self.vector._client = MagicMock() - @pytest.mark.parametrize("search_response, expected_length, expected_doc_id", [ - ({ - 'hits': { - 'total': {'value': 1}, - 'hits': [ - {'_source': {'page_content': get_example_text(), 'metadata': {"document_id": "example_doc_id"}}} - ] - } - }, 1, "example_doc_id"), - ({ - 'hits': { - 'total': {'value': 0}, - 'hits': [] - } - }, 0, None) - ]) + @pytest.mark.parametrize( + "search_response, expected_length, expected_doc_id", + [ + ( + { + "hits": { + "total": {"value": 1}, + "hits": [ + { + "_source": { + "page_content": get_example_text(), + "metadata": {"document_id": "example_doc_id"}, + } + } + ], + } + }, + 1, + "example_doc_id", + ), + ({"hits": {"total": {"value": 0}, "hits": []}}, 0, None), + ], + ) def test_search_by_full_text(self, search_response, expected_length, expected_doc_id): self.vector._client.search.return_value = search_response hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) assert len(hits_by_full_text) == expected_length if expected_length > 0: - assert hits_by_full_text[0].metadata['document_id'] == expected_doc_id + assert hits_by_full_text[0].metadata["document_id"] == expected_doc_id def test_search_by_vector(self): vector = [0.1] * 128 mock_response = { - 'hits': { - 'total': {'value': 1}, - 'hits': [ + "hits": { + "total": {"value": 1}, + "hits": [ { - '_source': { + "_source": { Field.CONTENT_KEY.value: get_example_text(), - Field.METADATA_KEY.value: {"document_id": self.example_doc_id} + Field.METADATA_KEY.value: {"document_id": self.example_doc_id}, }, - '_score': 1.0 + "_score": 1.0, } - ] + ], } } self.vector._client.search.return_value = mock_response @@ -85,53 +86,45 @@ class TestOpenSearchVector: print("Hits by vector:", hits_by_vector) print("Expected document ID:", self.example_doc_id) - print("Actual document ID:", hits_by_vector[0].metadata['document_id'] if hits_by_vector else "No hits") + print("Actual document ID:", hits_by_vector[0].metadata["document_id"] if hits_by_vector else "No hits") assert len(hits_by_vector) > 0, f"Expected at least one hit, got {len(hits_by_vector)}" - assert hits_by_vector[0].metadata['document_id'] == self.example_doc_id, \ - f"Expected document ID {self.example_doc_id}, got {hits_by_vector[0].metadata['document_id']}" + assert ( + hits_by_vector[0].metadata["document_id"] == self.example_doc_id + ), f"Expected document ID {self.example_doc_id}, got {hits_by_vector[0].metadata['document_id']}" def test_get_ids_by_metadata_field(self): - mock_response = { - 'hits': { - 'total': {'value': 1}, - 'hits': [{'_id': 'mock_id'}] - } - } + mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}} self.vector._client.search.return_value = mock_response doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id}) embedding = [0.1] * 128 - with patch('opensearchpy.helpers.bulk') as mock_bulk: + with patch("opensearchpy.helpers.bulk") as mock_bulk: mock_bulk.return_value = ([], []) self.vector.add_texts([doc], [embedding]) - ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 1 - assert ids[0] == 'mock_id' + assert ids[0] == "mock_id" def test_add_texts(self): - self.vector._client.index.return_value = {'result': 'created'} + self.vector._client.index.return_value = {"result": "created"} doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id}) embedding = [0.1] * 128 - with patch('opensearchpy.helpers.bulk') as mock_bulk: + with patch("opensearchpy.helpers.bulk") as mock_bulk: mock_bulk.return_value = ([], []) self.vector.add_texts([doc], [embedding]) - mock_response = { - 'hits': { - 'total': {'value': 1}, - 'hits': [{'_id': 'mock_id'}] - } - } + mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}} self.vector._client.search.return_value = mock_response - ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 1 - assert ids[0] == 'mock_id' + assert ids[0] == "mock_id" + @pytest.mark.usefixtures("setup_mock_redis") class TestOpenSearchVectorWithRedis: @@ -141,11 +134,11 @@ class TestOpenSearchVectorWithRedis: def test_search_by_full_text(self): self.tester.setup_method() search_response = { - 'hits': { - 'total': {'value': 1}, - 'hits': [ - {'_source': {'page_content': get_example_text(), 'metadata': {"document_id": "example_doc_id"}}} - ] + "hits": { + "total": {"value": 1}, + "hits": [ + {"_source": {"page_content": get_example_text(), "metadata": {"document_id": "example_doc_id"}}} + ], } } expected_length = 1 diff --git a/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py b/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py index e6ce8aab3d..6b33217d15 100644 --- a/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py +++ b/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py @@ -12,13 +12,13 @@ class PGVectoRSVectorTest(AbstractVectorTest): self.vector = PGVectoRS( collection_name=self.collection_name.lower(), config=PgvectoRSConfig( - host='localhost', + host="localhost", port=5431, - user='postgres', - password='difyai123456', - database='dify', + user="postgres", + password="difyai123456", + database="dify", ), - dim=128 + dim=128, ) def search_by_full_text(self): @@ -27,8 +27,9 @@ class PGVectoRSVectorTest(AbstractVectorTest): assert len(hits_by_full_text) == 0 def get_ids_by_metadata_field(self): - ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 1 + def test_pgvecot_rs(setup_mock_redis): PGVectoRSVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py b/api/tests/integration_tests/vdb/qdrant/test_qdrant.py index 34beb25d45..61d9a9e712 100644 --- a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py +++ b/api/tests/integration_tests/vdb/qdrant/test_qdrant.py @@ -8,14 +8,14 @@ from tests.integration_tests.vdb.test_vector_store import ( class QdrantVectorTest(AbstractVectorTest): def __init__(self): super().__init__() - self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash'] + self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"] self.vector = QdrantVector( collection_name=self.collection_name, group_id=self.dataset_id, config=QdrantConfig( - endpoint='http://localhost:6333', - api_key='difyai123456', - ) + endpoint="http://localhost:6333", + api_key="difyai123456", + ), ) diff --git a/api/tests/integration_tests/vdb/tcvectordb/test_tencent.py b/api/tests/integration_tests/vdb/tcvectordb/test_tencent.py index 8937fe0ea1..1b9466e27f 100644 --- a/api/tests/integration_tests/vdb/tcvectordb/test_tencent.py +++ b/api/tests/integration_tests/vdb/tcvectordb/test_tencent.py @@ -7,18 +7,22 @@ from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, ge mock_client = MagicMock() mock_client.list_databases.return_value = [{"name": "test"}] + class TencentVectorTest(AbstractVectorTest): def __init__(self): super().__init__() - self.vector = TencentVector("dify", TencentConfig( - url="http://127.0.0.1", - api_key="dify", - timeout=30, - username="dify", - database="dify", - shard=1, - replicas=2, - )) + self.vector = TencentVector( + "dify", + TencentConfig( + url="http://127.0.0.1", + api_key="dify", + timeout=30, + username="dify", + database="dify", + shard=1, + replicas=2, + ), + ) def search_by_vector(self): hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding) @@ -28,8 +32,6 @@ class TencentVectorTest(AbstractVectorTest): hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) assert len(hits_by_full_text) == 0 -def test_tencent_vector(setup_mock_redis,setup_tcvectordb_mock): + +def test_tencent_vector(setup_mock_redis, setup_tcvectordb_mock): TencentVectorTest().run_all_tests() - - - diff --git a/api/tests/integration_tests/vdb/test_vector_store.py b/api/tests/integration_tests/vdb/test_vector_store.py index cb35822709..a11cd225b3 100644 --- a/api/tests/integration_tests/vdb/test_vector_store.py +++ b/api/tests/integration_tests/vdb/test_vector_store.py @@ -10,7 +10,7 @@ from models.dataset import Dataset def get_example_text() -> str: - return 'test_text' + return "test_text" def get_example_document(doc_id: str) -> Document: @@ -21,7 +21,7 @@ def get_example_document(doc_id: str) -> Document: "doc_hash": doc_id, "document_id": doc_id, "dataset_id": doc_id, - } + }, ) return doc @@ -45,7 +45,7 @@ class AbstractVectorTest: def __init__(self): self.vector = None self.dataset_id = str(uuid.uuid4()) - self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + '_test' + self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + "_test" self.example_doc_id = str(uuid.uuid4()) self.example_embedding = [1.001 * i for i in range(128)] @@ -58,12 +58,12 @@ class AbstractVectorTest: def search_by_vector(self): hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding) assert len(hits_by_vector) == 1 - assert hits_by_vector[0].metadata['doc_id'] == self.example_doc_id + assert hits_by_vector[0].metadata["doc_id"] == self.example_doc_id def search_by_full_text(self): hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text()) assert len(hits_by_full_text) == 1 - assert hits_by_full_text[0].metadata['doc_id'] == self.example_doc_id + assert hits_by_full_text[0].metadata["doc_id"] == self.example_doc_id def delete_vector(self): self.vector.delete() @@ -76,14 +76,14 @@ class AbstractVectorTest: documents = [get_example_document(doc_id=str(uuid.uuid4())) for _ in range(batch_size)] embeddings = [self.example_embedding] * batch_size self.vector.add_texts(documents=documents, embeddings=embeddings) - return [doc.metadata['doc_id'] for doc in documents] + return [doc.metadata["doc_id"] for doc in documents] def text_exists(self): assert self.vector.text_exists(self.example_doc_id) def get_ids_by_metadata_field(self): with pytest.raises(NotImplementedError): - self.vector.get_ids_by_metadata_field(key='key', value='value') + self.vector.get_ids_by_metadata_field(key="key", value="value") def run_all_tests(self): self.create_vector() diff --git a/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py b/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py index 18e00dbedd..2a5320c7d5 100644 --- a/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py +++ b/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py @@ -10,15 +10,15 @@ from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, ge @pytest.fixture def tidb_vector(): return TiDBVector( - collection_name='test_collection', + collection_name="test_collection", config=TiDBVectorConfig( host="xxx.eu-central-1.xxx.aws.tidbcloud.com", port="4000", user="xxx.root", password="xxxxxx", database="dify", - program_name="langgenius/dify" - ) + program_name="langgenius/dify", + ), ) @@ -40,7 +40,7 @@ class TiDBVectorTest(AbstractVectorTest): assert len(hits_by_full_text) == 0 def get_ids_by_metadata_field(self): - ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 0 @@ -50,12 +50,12 @@ def test_tidb_vector(setup_mock_redis, setup_tidbvector_mock, tidb_vector, mock_ @pytest.fixture def mock_session(): - with patch('core.rag.datasource.vdb.tidb_vector.tidb_vector.Session', new_callable=MagicMock) as mock_session: + with patch("core.rag.datasource.vdb.tidb_vector.tidb_vector.Session", new_callable=MagicMock) as mock_session: yield mock_session @pytest.fixture def setup_tidbvector_mock(tidb_vector, mock_session): - with patch('core.rag.datasource.vdb.tidb_vector.tidb_vector.create_engine'): - with patch.object(tidb_vector._engine, 'connect'): + with patch("core.rag.datasource.vdb.tidb_vector.tidb_vector.create_engine"): + with patch.object(tidb_vector._engine, "connect"): yield tidb_vector diff --git a/api/tests/integration_tests/vdb/weaviate/test_weaviate.py b/api/tests/integration_tests/vdb/weaviate/test_weaviate.py index 3d540cee32..a6f55420d3 100644 --- a/api/tests/integration_tests/vdb/weaviate/test_weaviate.py +++ b/api/tests/integration_tests/vdb/weaviate/test_weaviate.py @@ -8,14 +8,14 @@ from tests.integration_tests.vdb.test_vector_store import ( class WeaviateVectorTest(AbstractVectorTest): def __init__(self): super().__init__() - self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash'] + self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"] self.vector = WeaviateVector( collection_name=self.collection_name, config=WeaviateConfig( - endpoint='http://localhost:8080', - api_key='WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih', + endpoint="http://localhost:8080", + api_key="WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih", ), - attributes=self.attributes + attributes=self.attributes, ) diff --git a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py index 13f992136e..6fb8c86b82 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py @@ -6,24 +6,22 @@ from _pytest.monkeypatch import MonkeyPatch from jinja2 import Template from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage -from core.helper.code_executor.entities import CodeDependency -MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true' +MOCK = os.getenv("MOCK_SWITCH", "false") == "true" + class MockedCodeExecutor: @classmethod - def invoke(cls, language: Literal['python3', 'javascript', 'jinja2'], - code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict: + def invoke(cls, language: Literal["python3", "javascript", "jinja2"], code: str, inputs: dict) -> dict: # invoke directly match language: case CodeLanguage.PYTHON3: - return { - "result": 3 - } + return {"result": 3} case CodeLanguage.JINJA2: - return { - "result": Template(code).render(inputs) - } + return {"result": Template(code).render(inputs)} + case _: + raise Exception("Language not supported") + @pytest.fixture def setup_code_executor_mock(request, monkeypatch: MonkeyPatch): diff --git a/api/tests/integration_tests/workflow/nodes/__mock/http.py b/api/tests/integration_tests/workflow/nodes/__mock/http.py index beb5c04009..cfc47bcad4 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/http.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/http.py @@ -6,38 +6,32 @@ import httpx import pytest from _pytest.monkeypatch import MonkeyPatch -MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true' +MOCK = os.getenv("MOCK_SWITCH", "false") == "true" class MockedHttp: - def httpx_request(method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD'], - url: str, **kwargs) -> httpx.Response: + def httpx_request( + method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs + ) -> httpx.Response: """ Mocked httpx.request """ - if url == 'http://404.com': - response = httpx.Response( - status_code=404, - request=httpx.Request(method, url), - content=b'Not Found' - ) + if url == "http://404.com": + response = httpx.Response(status_code=404, request=httpx.Request(method, url), content=b"Not Found") return response # get data, files - data = kwargs.get('data', None) - files = kwargs.get('files', None) + data = kwargs.get("data", None) + files = kwargs.get("files", None) if data is not None: - resp = dumps(data).encode('utf-8') + resp = dumps(data).encode("utf-8") elif files is not None: - resp = dumps(files).encode('utf-8') + resp = dumps(files).encode("utf-8") else: - resp = b'OK' + resp = b"OK" response = httpx.Response( - status_code=200, - request=httpx.Request(method, url), - headers=kwargs.get('headers', {}), - content=resp + status_code=200, request=httpx.Request(method, url), headers=kwargs.get("headers", {}), content=resp ) return response diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py index ae6e7ceaa7..44dcf9a10f 100644 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py @@ -2,10 +2,10 @@ import pytest from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor -CODE_LANGUAGE = 'unsupported_language' +CODE_LANGUAGE = "unsupported_language" def test_unsupported_with_code_template(): with pytest.raises(CodeExecutionException) as e: - CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code='', inputs={}) - assert str(e.value) == f'Unsupported language {CODE_LANGUAGE}' + CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={}) + assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}" diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py index 2d798eb9c2..09fcb68cf0 100644 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py @@ -9,8 +9,8 @@ CODE_LANGUAGE = CodeLanguage.JAVASCRIPT def test_javascript_plain(): code = 'console.log("Hello World")' - result_message = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload='', code=code) - assert result_message == 'Hello World\n' + result_message = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) + assert result_message == "Hello World\n" def test_javascript_json(): @@ -18,22 +18,17 @@ def test_javascript_json(): obj = {'Hello': 'World'} console.log(JSON.stringify(obj)) """) - result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload='', code=code) + result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) assert result == '{"Hello":"World"}\n' def test_javascript_with_code_template(): result = CodeExecutor.execute_workflow_code_template( - language=CODE_LANGUAGE, code=JavascriptCodeProvider.get_default_code(), - inputs={'arg1': 'Hello', 'arg2': 'World'}) - assert result == {'result': 'HelloWorld'} - - -def test_javascript_list_default_available_packages(): - packages = JavascriptCodeProvider.get_default_available_packages() - - # no default packages available for javascript - assert len(packages) == 0 + language=CODE_LANGUAGE, + code=JavascriptCodeProvider.get_default_code(), + inputs={"arg1": "Hello", "arg2": "World"}, + ) + assert result == {"result": "HelloWorld"} def test_javascript_get_runner_script(): diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py index 425f4cbdd4..94903cf796 100644 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py @@ -7,21 +7,24 @@ CODE_LANGUAGE = CodeLanguage.JINJA2 def test_jinja2(): - template = 'Hello {{template}}' - inputs = base64.b64encode(b'{"template": "World"}').decode('utf-8') - code = (Jinja2TemplateTransformer.get_runner_script() - .replace(Jinja2TemplateTransformer._code_placeholder, template) - .replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)) - result = CodeExecutor.execute_code(language=CODE_LANGUAGE, - preload=Jinja2TemplateTransformer.get_preload_script(), - code=code) - assert result == '<>Hello World<>\n' + template = "Hello {{template}}" + inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8") + code = ( + Jinja2TemplateTransformer.get_runner_script() + .replace(Jinja2TemplateTransformer._code_placeholder, template) + .replace(Jinja2TemplateTransformer._inputs_placeholder, inputs) + ) + result = CodeExecutor.execute_code( + language=CODE_LANGUAGE, preload=Jinja2TemplateTransformer.get_preload_script(), code=code + ) + assert result == "<>Hello World<>\n" def test_jinja2_with_code_template(): result = CodeExecutor.execute_workflow_code_template( - language=CODE_LANGUAGE, code='Hello {{template}}', inputs={'template': 'World'}) - assert result == {'result': 'Hello World'} + language=CODE_LANGUAGE, code="Hello {{template}}", inputs={"template": "World"} + ) + assert result == {"result": "Hello World"} def test_jinja2_get_runner_script(): diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py index d265011d4c..cbe4a5d335 100644 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py @@ -10,8 +10,8 @@ CODE_LANGUAGE = CodeLanguage.PYTHON3 def test_python3_plain(): code = 'print("Hello World")' - result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload='', code=code) - assert result == 'Hello World\n' + result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) + assert result == "Hello World\n" def test_python3_json(): @@ -19,23 +19,15 @@ def test_python3_json(): import json print(json.dumps({'Hello': 'World'})) """) - result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload='', code=code) + result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) assert result == '{"Hello": "World"}\n' def test_python3_with_code_template(): result = CodeExecutor.execute_workflow_code_template( - language=CODE_LANGUAGE, code=Python3CodeProvider.get_default_code(), inputs={'arg1': 'Hello', 'arg2': 'World'}) - assert result == {'result': 'HelloWorld'} - - -def test_python3_list_default_available_packages(): - packages = Python3CodeProvider.get_default_available_packages() - assert len(packages) > 0 - assert {'requests', 'httpx'}.issubset(p['name'] for p in packages) - - # check JSON serializable - assert len(str(json.dumps(packages))) > 0 + language=CODE_LANGUAGE, code=Python3CodeProvider.get_default_code(), inputs={"arg1": "Hello", "arg2": "World"} + ) + assert result == {"result": "HelloWorld"} def test_python3_get_runner_script(): diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 5c95258520..6f5421e108 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -9,137 +9,134 @@ from core.workflow.nodes.code.code_node import CodeNode from models.workflow import WorkflowNodeExecutionStatus from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock -CODE_MAX_STRING_LENGTH = int(getenv('CODE_MAX_STRING_LENGTH', '10000')) +CODE_MAX_STRING_LENGTH = int(getenv("CODE_MAX_STRING_LENGTH", "10000")) -@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) def test_execute_code(setup_code_executor_mock): - code = ''' + code = """ def main(args1: int, args2: int) -> dict: return { "result": args1 + args2, } - ''' + """ # trim first 4 spaces at the beginning of each line - code = '\n'.join([line[4:] for line in code.split('\n')]) + code = "\n".join([line[4:] for line in code.split("\n")]) node = CodeNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, config={ - 'id': '1', - 'data': { - 'outputs': { - 'result': { - 'type': 'number', + "id": "1", + "data": { + "outputs": { + "result": { + "type": "number", }, }, - 'title': '123', - 'variables': [ + "title": "123", + "variables": [ { - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], + "variable": "args1", + "value_selector": ["1", "123", "args1"], }, - { - 'variable': 'args2', - 'value_selector': ['1', '123', 'args2'] - } + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, ], - 'answer': '123', - 'code_language': 'python3', - 'code': code - } - } + "answer": "123", + "code_language": "python3", + "code": code, + }, + }, ) # construct variable pool pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(['1', '123', 'args1'], 1) - pool.add(['1', '123', 'args2'], 2) - + pool.add(["1", "123", "args1"], 1) + pool.add(["1", "123", "args2"], 2) + # execute node result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs['result'] == 3 + assert result.outputs["result"] == 3 assert result.error is None -@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) def test_execute_code_output_validator(setup_code_executor_mock): - code = ''' + code = """ def main(args1: int, args2: int) -> dict: return { "result": args1 + args2, } - ''' + """ # trim first 4 spaces at the beginning of each line - code = '\n'.join([line[4:] for line in code.split('\n')]) + code = "\n".join([line[4:] for line in code.split("\n")]) node = CodeNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, config={ - 'id': '1', - 'data': { + "id": "1", + "data": { "outputs": { "result": { "type": "string", }, }, - 'title': '123', - 'variables': [ + "title": "123", + "variables": [ { - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], + "variable": "args1", + "value_selector": ["1", "123", "args1"], }, - { - 'variable': 'args2', - 'value_selector': ['1', '123', 'args2'] - } + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, ], - 'answer': '123', - 'code_language': 'python3', - 'code': code - } - } + "answer": "123", + "code_language": "python3", + "code": code, + }, + }, ) # construct variable pool pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(['1', '123', 'args1'], 1) - pool.add(['1', '123', 'args2'], 2) - + pool.add(["1", "123", "args1"], 1) + pool.add(["1", "123", "args2"], 2) + # execute node result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.FAILED - assert result.error == 'Output variable `result` must be a string' + assert result.error == "Output variable `result` must be a string" + def test_execute_code_output_validator_depth(): - code = ''' + code = """ def main(args1: int, args2: int) -> dict: return { "result": { "result": args1 + args2, } } - ''' + """ # trim first 4 spaces at the beginning of each line - code = '\n'.join([line[4:] for line in code.split('\n')]) + code = "\n".join([line[4:] for line in code.split("\n")]) node = CodeNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, config={ - 'id': '1', - 'data': { + "id": "1", + "data": { "outputs": { "string_validator": { "type": "string", @@ -168,29 +165,26 @@ def test_execute_code_output_validator_depth(): "depth": { "type": "number", } - } + }, } - } - } - } + }, + }, + }, }, }, - 'title': '123', - 'variables': [ + "title": "123", + "variables": [ { - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], + "variable": "args1", + "value_selector": ["1", "123", "args1"], }, - { - 'variable': 'args2', - 'value_selector': ['1', '123', 'args2'] - } + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, ], - 'answer': '123', - 'code_language': 'python3', - 'code': code - } - } + "answer": "123", + "code_language": "python3", + "code": code, + }, + }, ) # construct result @@ -199,14 +193,7 @@ def test_execute_code_output_validator_depth(): "string_validator": "1", "number_array_validator": [1, 2, 3, 3.333], "string_array_validator": ["1", "2", "3"], - "object_validator": { - "result": 1, - "depth": { - "depth": { - "depth": 1 - } - } - } + "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, } # validate @@ -218,14 +205,7 @@ def test_execute_code_output_validator_depth(): "string_validator": 1, "number_array_validator": ["1", "2", "3", "3.333"], "string_array_validator": [1, 2, 3], - "object_validator": { - "result": "1", - "depth": { - "depth": { - "depth": "1" - } - } - } + "object_validator": {"result": "1", "depth": {"depth": {"depth": "1"}}}, } # validate @@ -238,34 +218,20 @@ def test_execute_code_output_validator_depth(): "string_validator": (CODE_MAX_STRING_LENGTH + 1) * "1", "number_array_validator": [1, 2, 3, 3.333], "string_array_validator": ["1", "2", "3"], - "object_validator": { - "result": 1, - "depth": { - "depth": { - "depth": 1 - } - } - } + "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, } # validate with pytest.raises(ValueError): node._transform_result(result, node.node_data.outputs) - + # construct result result = { "number_validator": 1, "string_validator": "1", "number_array_validator": [1, 2, 3, 3.333] * 2000, "string_array_validator": ["1", "2", "3"], - "object_validator": { - "result": 1, - "depth": { - "depth": { - "depth": 1 - } - } - } + "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, } # validate @@ -274,58 +240,59 @@ def test_execute_code_output_validator_depth(): def test_execute_code_output_object_list(): - code = ''' + code = """ def main(args1: int, args2: int) -> dict: return { "result": { "result": args1 + args2, } } - ''' + """ # trim first 4 spaces at the beginning of each line - code = '\n'.join([line[4:] for line in code.split('\n')]) + code = "\n".join([line[4:] for line in code.split("\n")]) node = CodeNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': '1', - 'data': { + "id": "1", + "data": { "outputs": { "object_list": { "type": "array[object]", }, }, - 'title': '123', - 'variables': [ + "title": "123", + "variables": [ { - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], + "variable": "args1", + "value_selector": ["1", "123", "args1"], }, - { - 'variable': 'args2', - 'value_selector': ['1', '123', 'args2'] - } + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, ], - 'answer': '123', - 'code_language': 'python3', - 'code': code - } - } + "answer": "123", + "code_language": "python3", + "code": code, + }, + }, ) # construct result result = { - "object_list": [{ - "result": 1, - }, { - "result": 2, - }, { - "result": [1, 2, 3], - }] + "object_list": [ + { + "result": 1, + }, + { + "result": 2, + }, + { + "result": [1, 2, 3], + }, + ] } # validate @@ -333,13 +300,18 @@ def test_execute_code_output_object_list(): # construct result result = { - "object_list": [{ - "result": 1, - }, { - "result": 2, - }, { - "result": [1, 2, 3], - }, 1] + "object_list": [ + { + "result": 1, + }, + { + "result": 2, + }, + { + "result": [1, 2, 3], + }, + 1, + ] } # validate diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index a1354bd6a5..acb616b325 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -9,322 +9,337 @@ from core.workflow.nodes.http_request.http_request_node import HttpRequestNode from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock BASIC_NODE_DATA = { - 'tenant_id': '1', - 'app_id': '1', - 'workflow_id': '1', - 'user_id': '1', - 'user_from': UserFrom.ACCOUNT, - 'invoke_from': InvokeFrom.WEB_APP, + "tenant_id": "1", + "app_id": "1", + "workflow_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.WEB_APP, } # construct variable pool pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) -pool.add(['a', 'b123', 'args1'], 1) -pool.add(['a', 'b123', 'args2'], 2) +pool.add(["a", "b123", "args1"], 1) +pool.add(["a", "b123", "args2"], 2) -@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_get(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'get', - 'url': 'http://example.com', - 'authorization': { - 'type': 'api-key', - 'config': { - 'type': 'basic', - 'api_key': 'ak-xxx', - 'header': 'api-key', - } - }, - 'headers': 'X-Header:123', - 'params': 'A:b', - 'body': None, - } - }, **BASIC_NODE_DATA) - - result = node.run(pool) - - data = result.process_data.get('request', '') - - assert '?A=b' in data - assert 'X-Header: 123' in data - - -@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) -def test_no_auth(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'get', - 'url': 'http://example.com', - 'authorization': { - 'type': 'no-auth', - 'config': None, - }, - 'headers': 'X-Header:123', - 'params': 'A:b', - 'body': None, - } - }, **BASIC_NODE_DATA) - - result = node.run(pool) - - data = result.process_data.get('request', '') - - assert '?A=b' in data - assert 'X-Header: 123' in data - - -@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) -def test_custom_authorization_header(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'get', - 'url': 'http://example.com', - 'authorization': { - 'type': 'api-key', - 'config': { - 'type': 'custom', - 'api_key': 'Auth', - 'header': 'X-Auth', + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, }, + "headers": "X-Header:123", + "params": "A:b", + "body": None, }, - 'headers': 'X-Header:123', - 'params': 'A:b', - 'body': None, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) - data = result.process_data.get('request', '') + data = result.process_data.get("request", "") - assert '?A=b' in data - assert 'X-Header: 123' in data + assert "?A=b" in data + assert "X-Header: 123" in data -@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) +def test_no_auth(setup_http_mock): + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com", + "authorization": { + "type": "no-auth", + "config": None, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": None, + }, + }, + **BASIC_NODE_DATA, + ) + + result = node.run(pool) + + data = result.process_data.get("request", "") + + assert "?A=b" in data + assert "X-Header: 123" in data + + +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) +def test_custom_authorization_header(setup_http_mock): + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "custom", + "api_key": "Auth", + "header": "X-Auth", + }, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": None, + }, + }, + **BASIC_NODE_DATA, + ) + + result = node.run(pool) + + data = result.process_data.get("request", "") + + assert "?A=b" in data + assert "X-Header: 123" in data + + +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_template(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'get', - 'url': 'http://example.com/{{#a.b123.args2#}}', - 'authorization': { - 'type': 'api-key', - 'config': { - 'type': 'basic', - 'api_key': 'ak-xxx', - 'header': 'api-key', - } + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com/{{#a.b123.args2#}}", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:123\nX-Header2:{{#a.b123.args2#}}", + "params": "A:b\nTemplate:{{#a.b123.args2#}}", + "body": None, }, - 'headers': 'X-Header:123\nX-Header2:{{#a.b123.args2#}}', - 'params': 'A:b\nTemplate:{{#a.b123.args2#}}', - 'body': None, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) - data = result.process_data.get('request', '') + data = result.process_data.get("request", "") - assert '?A=b' in data - assert 'Template=2' in data - assert 'X-Header: 123' in data - assert 'X-Header2: 2' in data + assert "?A=b" in data + assert "Template=2" in data + assert "X-Header: 123" in data + assert "X-Header2: 2" in data -@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_json(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'post', - 'url': 'http://example.com', - 'authorization': { - 'type': 'api-key', - 'config': { - 'type': 'basic', - 'api_key': 'ak-xxx', - 'header': 'api-key', - } + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "post", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": {"type": "json", "data": '{"a": "{{#a.b123.args1#}}"}'}, }, - 'headers': 'X-Header:123', - 'params': 'A:b', - 'body': { - 'type': 'json', - 'data': '{"a": "{{#a.b123.args1#}}"}' - }, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) - data = result.process_data.get('request', '') + data = result.process_data.get("request", "") assert '{"a": "1"}' in data - assert 'X-Header: 123' in data + assert "X-Header: 123" in data def test_x_www_form_urlencoded(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'post', - 'url': 'http://example.com', - 'authorization': { - 'type': 'api-key', - 'config': { - 'type': 'basic', - 'api_key': 'ak-xxx', - 'header': 'api-key', - } + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "post", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": {"type": "x-www-form-urlencoded", "data": "a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}"}, }, - 'headers': 'X-Header:123', - 'params': 'A:b', - 'body': { - 'type': 'x-www-form-urlencoded', - 'data': 'a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}' - }, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) - data = result.process_data.get('request', '') + data = result.process_data.get("request", "") - assert 'a=1&b=2' in data - assert 'X-Header: 123' in data + assert "a=1&b=2" in data + assert "X-Header: 123" in data def test_form_data(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'post', - 'url': 'http://example.com', - 'authorization': { - 'type': 'api-key', - 'config': { - 'type': 'basic', - 'api_key': 'ak-xxx', - 'header': 'api-key', - } + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "post", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": {"type": "form-data", "data": "a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}"}, }, - 'headers': 'X-Header:123', - 'params': 'A:b', - 'body': { - 'type': 'form-data', - 'data': 'a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}' - }, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) - data = result.process_data.get('request', '') + data = result.process_data.get("request", "") assert 'form-data; name="a"' in data - assert '1' in data + assert "1" in data assert 'form-data; name="b"' in data - assert '2' in data - assert 'X-Header: 123' in data + assert "2" in data + assert "X-Header: 123" in data def test_none_data(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'post', - 'url': 'http://example.com', - 'authorization': { - 'type': 'api-key', - 'config': { - 'type': 'basic', - 'api_key': 'ak-xxx', - 'header': 'api-key', - } + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "post", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": {"type": "none", "data": "123123123"}, }, - 'headers': 'X-Header:123', - 'params': 'A:b', - 'body': { - 'type': 'none', - 'data': '123123123' - }, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) - data = result.process_data.get('request', '') + data = result.process_data.get("request", "") - assert 'X-Header: 123' in data - assert '123123123' not in data + assert "X-Header: 123" in data + assert "123123123" not in data def test_mock_404(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'get', - 'url': 'http://404.com', - 'authorization': { - 'type': 'no-auth', - 'config': None, + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://404.com", + "authorization": { + "type": "no-auth", + "config": None, + }, + "body": None, + "params": "", + "headers": "X-Header:123", }, - 'body': None, - 'params': '', - 'headers': 'X-Header:123', - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) resp = result.outputs - assert 404 == resp.get('status_code') - assert 'Not Found' in resp.get('body') + assert 404 == resp.get("status_code") + assert "Not Found" in resp.get("body") def test_multi_colons_parse(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'get', - 'url': 'http://example.com', - 'authorization': { - 'type': 'no-auth', - 'config': None, + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com", + "authorization": { + "type": "no-auth", + "config": None, + }, + "params": "Referer:http://example1.com\nRedirect:http://example2.com", + "headers": "Referer:http://example3.com\nRedirect:http://example4.com", + "body": {"type": "form-data", "data": "Referer:http://example5.com\nRedirect:http://example6.com"}, }, - 'params': 'Referer:http://example1.com\nRedirect:http://example2.com', - 'headers': 'Referer:http://example3.com\nRedirect:http://example4.com', - 'body': { - 'type': 'form-data', - 'data': 'Referer:http://example5.com\nRedirect:http://example6.com' - }, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) resp = result.outputs - assert urlencode({'Redirect': 'http://example2.com'}) in result.process_data.get('request') - assert 'form-data; name="Redirect"\n\nhttp://example6.com' in result.process_data.get('request') - assert 'http://example3.com' == resp.get('headers').get('referer') + assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request") + assert 'form-data; name="Redirect"\n\nhttp://example6.com' in result.process_data.get("request") + assert "http://example3.com" == resp.get("headers").get("referer") diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 4686ce0675..6bab83a019 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -11,7 +11,7 @@ from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers import ModelProviderFactory from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.llm.llm_node import LLMNode from extensions.ext_database import db @@ -23,90 +23,71 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_moc from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_execute_llm(setup_openai_mock): node = LLMNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'llm', - 'model': { - 'provider': 'openai', - 'name': 'gpt-3.5-turbo', - 'mode': 'chat', - 'completion_params': {} - }, - 'prompt_template': [ - { - 'role': 'system', - 'text': 'you are a helpful assistant.\ntoday\'s weather is {{#abc.output#}}.' - }, - { - 'role': 'user', - 'text': '{{#sys.query#}}' - } + "id": "llm", + "data": { + "title": "123", + "type": "llm", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "prompt_template": [ + {"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."}, + {"role": "user", "text": "{{#sys.query#}}"}, ], - 'memory': None, - 'context': { - 'enabled': False - }, - 'vision': { - 'enabled': False - } - } - } + "memory": None, + "context": {"enabled": False}, + "vision": {"enabled": False}, + }, + }, ) # construct variable pool - pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather today?', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) - pool.add(['abc', 'output'], 'sunny') + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather today?", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) + pool.add(["abc", "output"], "sunny") - credentials = { - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} - provider_instance = ModelProviderFactory().get_provider_instance('openai') + provider_instance = ModelProviderFactory().get_provider_instance("openai") model_type_instance = provider_instance.get_model_instance(ModelType.LLM) provider_model_bundle = ProviderModelBundle( configuration=ProviderConfiguration( - tenant_id='1', + tenant_id="1", provider=provider_instance.get_provider_schema(), preferred_provider_type=ProviderType.CUSTOM, using_provider_type=ProviderType.CUSTOM, - system_configuration=SystemConfiguration( - enabled=False - ), - custom_configuration=CustomConfiguration( - provider=CustomProviderConfiguration( - credentials=credentials - ) - ), - model_settings=[] + system_configuration=SystemConfiguration(enabled=False), + custom_configuration=CustomConfiguration(provider=CustomProviderConfiguration(credentials=credentials)), + model_settings=[], ), provider_instance=provider_instance, - model_type_instance=model_type_instance + model_type_instance=model_type_instance, ) - model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo') + model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model="gpt-3.5-turbo") model_config = ModelConfigWithCredentialsEntity( - model='gpt-3.5-turbo', - provider='openai', - mode='chat', + model="gpt-3.5-turbo", + provider="openai", + mode="chat", credentials=credentials, parameters={}, - model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'), - provider_model_bundle=provider_model_bundle + model_schema=model_type_instance.get_model_schema("gpt-3.5-turbo"), + provider_model_bundle=provider_model_bundle, ) # Mock db.session.close() @@ -118,112 +99,97 @@ def test_execute_llm(setup_openai_mock): result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs['text'] is not None - assert result.outputs['usage']['total_tokens'] > 0 + assert result.outputs["text"] is not None + assert result.outputs["usage"]["total_tokens"] > 0 -@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): """ Test execute LLM node with jinja2 """ node = LLMNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'llm', - 'model': { - 'provider': 'openai', - 'name': 'gpt-3.5-turbo', - 'mode': 'chat', - 'completion_params': {} + "id": "llm", + "data": { + "title": "123", + "type": "llm", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "prompt_config": { + "jinja2_variables": [ + {"variable": "sys_query", "value_selector": ["sys", "query"]}, + {"variable": "output", "value_selector": ["abc", "output"]}, + ] }, - 'prompt_config': { - 'jinja2_variables': [{ - 'variable': 'sys_query', - 'value_selector': ['sys', 'query'] - }, { - 'variable': 'output', - 'value_selector': ['abc', 'output'] - }] - }, - 'prompt_template': [ + "prompt_template": [ { - 'role': 'system', - 'text': 'you are a helpful assistant.\ntoday\'s weather is {{#abc.output#}}', - 'jinja2_text': 'you are a helpful assistant.\ntoday\'s weather is {{output}}.', - 'edition_type': 'jinja2' + "role": "system", + "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}", + "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.", + "edition_type": "jinja2", }, { - 'role': 'user', - 'text': '{{#sys.query#}}', - 'jinja2_text': '{{sys_query}}', - 'edition_type': 'basic' - } + "role": "user", + "text": "{{#sys.query#}}", + "jinja2_text": "{{sys_query}}", + "edition_type": "basic", + }, ], - 'memory': None, - 'context': { - 'enabled': False - }, - 'vision': { - 'enabled': False - } - } - } + "memory": None, + "context": {"enabled": False}, + "vision": {"enabled": False}, + }, + }, ) # construct variable pool - pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather today?', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) - pool.add(['abc', 'output'], 'sunny') + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather today?", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) + pool.add(["abc", "output"], "sunny") - credentials = { - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} - provider_instance = ModelProviderFactory().get_provider_instance('openai') + provider_instance = ModelProviderFactory().get_provider_instance("openai") model_type_instance = provider_instance.get_model_instance(ModelType.LLM) provider_model_bundle = ProviderModelBundle( configuration=ProviderConfiguration( - tenant_id='1', + tenant_id="1", provider=provider_instance.get_provider_schema(), preferred_provider_type=ProviderType.CUSTOM, using_provider_type=ProviderType.CUSTOM, - system_configuration=SystemConfiguration( - enabled=False - ), - custom_configuration=CustomConfiguration( - provider=CustomProviderConfiguration( - credentials=credentials - ) - ), - model_settings=[] + system_configuration=SystemConfiguration(enabled=False), + custom_configuration=CustomConfiguration(provider=CustomProviderConfiguration(credentials=credentials)), + model_settings=[], ), provider_instance=provider_instance, model_type_instance=model_type_instance, ) - model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo') + model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model="gpt-3.5-turbo") model_config = ModelConfigWithCredentialsEntity( - model='gpt-3.5-turbo', - provider='openai', - mode='chat', + model="gpt-3.5-turbo", + provider="openai", + mode="chat", credentials=credentials, parameters={}, - model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'), - provider_model_bundle=provider_model_bundle + model_schema=model_type_instance.get_model_schema("gpt-3.5-turbo"), + provider_model_bundle=provider_model_bundle, ) # Mock db.session.close() @@ -235,5 +201,5 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert 'sunny' in json.dumps(result.process_data) - assert 'what\'s the weather today?' in json.dumps(result.process_data) + assert "sunny" in json.dumps(result.process_data) + assert "what's the weather today?" in json.dumps(result.process_data) diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index adf5ffe3ca..ca2bae5c53 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -13,7 +13,7 @@ from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode from extensions.ext_database import db @@ -26,29 +26,25 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_moc def get_mocked_fetch_model_config( - provider: str, model: str, mode: str, + provider: str, + model: str, + mode: str, credentials: dict, ): provider_instance = ModelProviderFactory().get_provider_instance(provider) model_type_instance = provider_instance.get_model_instance(ModelType.LLM) provider_model_bundle = ProviderModelBundle( configuration=ProviderConfiguration( - tenant_id='1', + tenant_id="1", provider=provider_instance.get_provider_schema(), preferred_provider_type=ProviderType.CUSTOM, using_provider_type=ProviderType.CUSTOM, - system_configuration=SystemConfiguration( - enabled=False - ), - custom_configuration=CustomConfiguration( - provider=CustomProviderConfiguration( - credentials=credentials - ) - ), - model_settings=[] + system_configuration=SystemConfiguration(enabled=False), + custom_configuration=CustomConfiguration(provider=CustomProviderConfiguration(credentials=credentials)), + model_settings=[], ), provider_instance=provider_instance, - model_type_instance=model_type_instance + model_type_instance=model_type_instance, ) model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model=model) model_config = ModelConfigWithCredentialsEntity( @@ -58,268 +54,268 @@ def get_mocked_fetch_model_config( credentials=credentials, parameters={}, model_schema=model_type_instance.get_model_schema(model), - provider_model_bundle=provider_model_bundle + provider_model_bundle=provider_model_bundle, ) return MagicMock(return_value=(model_instance, model_config)) + def get_mocked_fetch_memory(memory_text: str): class MemoryMock: - def get_history_prompt_text(self, human_prefix: str = "Human", - ai_prefix: str = "Assistant", - max_token_limit: int = 2000, - message_limit: Optional[int] = None): + def get_history_prompt_text( + self, + human_prefix: str = "Human", + ai_prefix: str = "Assistant", + max_token_limit: int = 2000, + message_limit: Optional[int] = None, + ): return memory_text return MagicMock(return_value=MemoryMock()) -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_function_calling_parameter_extractor(setup_openai_mock): """ Test function calling for parameter extractor. """ node = ParameterExtractorNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'parameter-extractor', - 'model': { - 'provider': 'openai', - 'name': 'gpt-3.5-turbo', - 'mode': 'chat', - 'completion_params': {} - }, - 'query': ['sys', 'query'], - 'parameters': [{ - 'name': 'location', - 'type': 'string', - 'description': 'location', - 'required': True - }], - 'instruction': '', - 'reasoning_mode': 'function_call', - 'memory': None, - } - } + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "instruction": "", + "reasoning_mode": "function_call", + "memory": None, + }, + }, ) node._fetch_model_config = get_mocked_fetch_model_config( - provider='openai', model='gpt-3.5-turbo', mode='chat', credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + provider="openai", + model="gpt-3.5-turbo", + mode="chat", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, ) db.session.close = MagicMock() # construct variable pool - pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather in SF', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs.get('location') == 'kawaii' - assert result.outputs.get('__reason') == None + assert result.outputs.get("location") == "kawaii" + assert result.outputs.get("__reason") == None -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_instructions(setup_openai_mock): """ Test chat parameter extractor. """ node = ParameterExtractorNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'parameter-extractor', - 'model': { - 'provider': 'openai', - 'name': 'gpt-3.5-turbo', - 'mode': 'chat', - 'completion_params': {} - }, - 'query': ['sys', 'query'], - 'parameters': [{ - 'name': 'location', - 'type': 'string', - 'description': 'location', - 'required': True - }], - 'reasoning_mode': 'function_call', - 'instruction': '{{#sys.query#}}', - 'memory': None, - } - } + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "reasoning_mode": "function_call", + "instruction": "{{#sys.query#}}", + "memory": None, + }, + }, ) node._fetch_model_config = get_mocked_fetch_model_config( - provider='openai', model='gpt-3.5-turbo', mode='chat', credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + provider="openai", + model="gpt-3.5-turbo", + mode="chat", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, ) db.session.close = MagicMock() # construct variable pool - pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather in SF', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs.get('location') == 'kawaii' - assert result.outputs.get('__reason') == None + assert result.outputs.get("location") == "kawaii" + assert result.outputs.get("__reason") == None process_data = result.process_data - process_data.get('prompts') + process_data.get("prompts") - for prompt in process_data.get('prompts'): - if prompt.get('role') == 'system': - assert 'what\'s the weather in SF' in prompt.get('text') + for prompt in process_data.get("prompts"): + if prompt.get("role") == "system": + assert "what's the weather in SF" in prompt.get("text") -@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) def test_chat_parameter_extractor(setup_anthropic_mock): """ Test chat parameter extractor. """ node = ParameterExtractorNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'parameter-extractor', - 'model': { - 'provider': 'anthropic', - 'name': 'claude-2', - 'mode': 'chat', - 'completion_params': {} - }, - 'query': ['sys', 'query'], - 'parameters': [{ - 'name': 'location', - 'type': 'string', - 'description': 'location', - 'required': True - }], - 'reasoning_mode': 'prompt', - 'instruction': '', - 'memory': None, - } - } + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": {"provider": "anthropic", "name": "claude-2", "mode": "chat", "completion_params": {}}, + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "reasoning_mode": "prompt", + "instruction": "", + "memory": None, + }, + }, ) node._fetch_model_config = get_mocked_fetch_model_config( - provider='anthropic', model='claude-2', mode='chat', credentials={ - 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') - } + provider="anthropic", + model="claude-2", + mode="chat", + credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}, ) db.session.close = MagicMock() # construct variable pool - pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather in SF', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs.get('location') == '' - assert result.outputs.get('__reason') == 'Failed to extract result from function call or text response, using empty result.' - prompts = result.process_data.get('prompts') + assert result.outputs.get("location") == "" + assert ( + result.outputs.get("__reason") + == "Failed to extract result from function call or text response, using empty result." + ) + prompts = result.process_data.get("prompts") for prompt in prompts: - if prompt.get('role') == 'user': - if '' in prompt.get('text'): - assert '\n{"type": "object"' in prompt.get('text') + if prompt.get("role") == "user": + if "" in prompt.get("text"): + assert '\n{"type": "object"' in prompt.get("text") -@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) def test_completion_parameter_extractor(setup_openai_mock): """ Test completion parameter extractor. """ node = ParameterExtractorNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'parameter-extractor', - 'model': { - 'provider': 'openai', - 'name': 'gpt-3.5-turbo-instruct', - 'mode': 'completion', - 'completion_params': {} + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": { + "provider": "openai", + "name": "gpt-3.5-turbo-instruct", + "mode": "completion", + "completion_params": {}, }, - 'query': ['sys', 'query'], - 'parameters': [{ - 'name': 'location', - 'type': 'string', - 'description': 'location', - 'required': True - }], - 'reasoning_mode': 'prompt', - 'instruction': '{{#sys.query#}}', - 'memory': None, - } - } + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "reasoning_mode": "prompt", + "instruction": "{{#sys.query#}}", + "memory": None, + }, + }, ) node._fetch_model_config = get_mocked_fetch_model_config( - provider='openai', model='gpt-3.5-turbo-instruct', mode='completion', credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + provider="openai", + model="gpt-3.5-turbo-instruct", + mode="completion", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, ) db.session.close = MagicMock() # construct variable pool - pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather in SF', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs.get('location') == '' - assert result.outputs.get('__reason') == 'Failed to extract result from function call or text response, using empty result.' - assert len(result.process_data.get('prompts')) == 1 - assert 'SF' in result.process_data.get('prompts')[0].get('text') + assert result.outputs.get("location") == "" + assert ( + result.outputs.get("__reason") + == "Failed to extract result from function call or text response, using empty result." + ) + assert len(result.process_data.get("prompts")) == 1 + assert "SF" in result.process_data.get("prompts")[0].get("text") + def test_extract_json_response(): """ @@ -327,35 +323,30 @@ def test_extract_json_response(): """ node = ParameterExtractorNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'parameter-extractor', - 'model': { - 'provider': 'openai', - 'name': 'gpt-3.5-turbo-instruct', - 'mode': 'completion', - 'completion_params': {} + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": { + "provider": "openai", + "name": "gpt-3.5-turbo-instruct", + "mode": "completion", + "completion_params": {}, }, - 'query': ['sys', 'query'], - 'parameters': [{ - 'name': 'location', - 'type': 'string', - 'description': 'location', - 'required': True - }], - 'reasoning_mode': 'prompt', - 'instruction': '{{#sys.query#}}', - 'memory': None, - } - } + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "reasoning_mode": "prompt", + "instruction": "{{#sys.query#}}", + "memory": None, + }, + }, ) result = node._extract_complete_json_response(""" @@ -366,83 +357,77 @@ def test_extract_json_response(): hello world. """) - assert result['location'] == 'kawaii' + assert result["location"] == "kawaii" -@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) def test_chat_parameter_extractor_with_memory(setup_anthropic_mock): """ Test chat parameter extractor with memory. """ node = ParameterExtractorNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'parameter-extractor', - 'model': { - 'provider': 'anthropic', - 'name': 'claude-2', - 'mode': 'chat', - 'completion_params': {} - }, - 'query': ['sys', 'query'], - 'parameters': [{ - 'name': 'location', - 'type': 'string', - 'description': 'location', - 'required': True - }], - 'reasoning_mode': 'prompt', - 'instruction': '', - 'memory': { - 'window': { - 'enabled': True, - 'size': 50 - } - }, - } - } + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": {"provider": "anthropic", "name": "claude-2", "mode": "chat", "completion_params": {}}, + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "reasoning_mode": "prompt", + "instruction": "", + "memory": {"window": {"enabled": True, "size": 50}}, + }, + }, ) node._fetch_model_config = get_mocked_fetch_model_config( - provider='anthropic', model='claude-2', mode='chat', credentials={ - 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') - } + provider="anthropic", + model="claude-2", + mode="chat", + credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}, ) - node._fetch_memory = get_mocked_fetch_memory('customized memory') + node._fetch_memory = get_mocked_fetch_memory("customized memory") db.session.close = MagicMock() # construct variable pool - pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather in SF', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs.get('location') == '' - assert result.outputs.get('__reason') == 'Failed to extract result from function call or text response, using empty result.' - prompts = result.process_data.get('prompts') + assert result.outputs.get("location") == "" + assert ( + result.outputs.get("__reason") + == "Failed to extract result from function call or text response, using empty result." + ) + prompts = result.process_data.get("prompts") latest_role = None for prompt in prompts: - if prompt.get('role') == 'user': - if '' in prompt.get('text'): - assert '\n{"type": "object"' in prompt.get('text') - elif prompt.get('role') == 'system': - assert 'customized memory' in prompt.get('text') + if prompt.get("role") == "user": + if "" in prompt.get("text"): + assert '\n{"type": "object"' in prompt.get("text") + elif prompt.get("role") == "system": + assert "customized memory" in prompt.get("text") if latest_role is not None: - assert latest_role != prompt.get('role') + assert latest_role != prompt.get("role") - if prompt.get('role') in ['user', 'assistant']: - latest_role = prompt.get('role') + if prompt.get("role") in ["user", "assistant"]: + latest_role = prompt.get("role") diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 781dfbc50f..617b6370c9 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -8,42 +8,39 @@ from models.workflow import WorkflowNodeExecutionStatus from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock -@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) def test_execute_code(setup_code_executor_mock): - code = '''{{args2}}''' + code = """{{args2}}""" node = TemplateTransformNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.END_USER, config={ - 'id': '1', - 'data': { - 'title': '123', - 'variables': [ + "id": "1", + "data": { + "title": "123", + "variables": [ { - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], + "variable": "args1", + "value_selector": ["1", "123", "args1"], }, - { - 'variable': 'args2', - 'value_selector': ['1', '123', 'args2'] - } + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, ], - 'template': code, - } - } + "template": code, + }, + }, ) # construct variable pool pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(['1', '123', 'args1'], 1) - pool.add(['1', '123', 'args2'], 3) - + pool.add(["1", "123", "args1"], 1) + pool.add(["1", "123", "args2"], 3) + # execute node result = node.run(pool) - + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs['output'] == '3' + assert result.outputs["output"] == "3" diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 01d62280e8..29c1efa8e7 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -7,78 +7,79 @@ from models.workflow import WorkflowNodeExecutionStatus def test_tool_variable_invoke(): pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(['1', '123', 'args1'], '1+1') + pool.add(["1", "123", "args1"], "1+1") node = ToolNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': '1', - 'data': { - 'title': 'a', - 'desc': 'a', - 'provider_id': 'maths', - 'provider_type': 'builtin', - 'provider_name': 'maths', - 'tool_name': 'eval_expression', - 'tool_label': 'eval_expression', - 'tool_configurations': {}, - 'tool_parameters': { - 'expression': { - 'type': 'variable', - 'value': ['1', '123', 'args1'], + "id": "1", + "data": { + "title": "a", + "desc": "a", + "provider_id": "maths", + "provider_type": "builtin", + "provider_name": "maths", + "tool_name": "eval_expression", + "tool_label": "eval_expression", + "tool_configurations": {}, + "tool_parameters": { + "expression": { + "type": "variable", + "value": ["1", "123", "args1"], } - } - } - } + }, + }, + }, ) # execute node result = node.run(pool) - + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert '2' in result.outputs['text'] - assert result.outputs['files'] == [] + assert "2" in result.outputs["text"] + assert result.outputs["files"] == [] + def test_tool_mixed_invoke(): pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(['1', 'args1'], '1+1') + pool.add(["1", "args1"], "1+1") node = ToolNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': '1', - 'data': { - 'title': 'a', - 'desc': 'a', - 'provider_id': 'maths', - 'provider_type': 'builtin', - 'provider_name': 'maths', - 'tool_name': 'eval_expression', - 'tool_label': 'eval_expression', - 'tool_configurations': {}, - 'tool_parameters': { - 'expression': { - 'type': 'mixed', - 'value': '{{#1.args1#}}', + "id": "1", + "data": { + "title": "a", + "desc": "a", + "provider_id": "maths", + "provider_type": "builtin", + "provider_name": "maths", + "tool_name": "eval_expression", + "tool_label": "eval_expression", + "tool_configurations": {}, + "tool_parameters": { + "expression": { + "type": "mixed", + "value": "{{#1.args1#}}", } - } - } - } + }, + }, + }, ) # execute node result = node.run(pool) - + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert '2' in result.outputs['text'] - assert result.outputs['files'] == [] \ No newline at end of file + assert "2" in result.outputs["text"] + assert result.outputs["files"] == [] diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py index 949a5a1769..3f639ccacc 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -3,21 +3,26 @@ from textwrap import dedent import pytest from flask import Flask +from yarl import URL from configs.app_config import DifyConfig -EXAMPLE_ENV_FILENAME = '.env' +EXAMPLE_ENV_FILENAME = ".env" @pytest.fixture def example_env_file(tmp_path, monkeypatch) -> str: monkeypatch.chdir(tmp_path) file_path = tmp_path.joinpath(EXAMPLE_ENV_FILENAME) - file_path.write_text(dedent( - """ + file_path.write_text( + dedent( + """ CONSOLE_API_URL=https://example.com CONSOLE_WEB_URL=https://example.com - """)) + HTTP_REQUEST_MAX_WRITE_TIMEOUT=30 + """ + ) + ) return str(file_path) @@ -29,7 +34,7 @@ def test_dify_config_undefined_entry(example_env_file): # entries not defined in app settings with pytest.raises(TypeError): # TypeError: 'AppSettings' object is not subscriptable - assert config['LOG_LEVEL'] == 'INFO' + assert config["LOG_LEVEL"] == "INFO" def test_dify_config(example_env_file): @@ -37,47 +42,56 @@ def test_dify_config(example_env_file): config = DifyConfig(_env_file=example_env_file) # constant values - assert config.COMMIT_SHA == '' + assert config.COMMIT_SHA == "" # default values - assert config.EDITION == 'SELF_HOSTED' + assert config.EDITION == "SELF_HOSTED" assert config.API_COMPRESSION_ENABLED is False assert config.SENTRY_TRACES_SAMPLE_RATE == 1.0 + # annotated field with default value + assert config.HTTP_REQUEST_MAX_READ_TIMEOUT == 60 + + # annotated field with configured value + assert config.HTTP_REQUEST_MAX_WRITE_TIMEOUT == 30 + # NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected. # This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`. def test_flask_configs(example_env_file): - flask_app = Flask('app') + flask_app = Flask("app") # clear system environment variables os.environ.clear() flask_app.config.from_mapping(DifyConfig(_env_file=example_env_file).model_dump()) # pyright: ignore config = flask_app.config # configs read from pydantic-settings - assert config['LOG_LEVEL'] == 'INFO' - assert config['COMMIT_SHA'] == '' - assert config['EDITION'] == 'SELF_HOSTED' - assert config['API_COMPRESSION_ENABLED'] is False - assert config['SENTRY_TRACES_SAMPLE_RATE'] == 1.0 - assert config['TESTING'] == False + assert config["LOG_LEVEL"] == "INFO" + assert config["COMMIT_SHA"] == "" + assert config["EDITION"] == "SELF_HOSTED" + assert config["API_COMPRESSION_ENABLED"] is False + assert config["SENTRY_TRACES_SAMPLE_RATE"] == 1.0 + assert config["TESTING"] == False # value from env file - assert config['CONSOLE_API_URL'] == 'https://example.com' + assert config["CONSOLE_API_URL"] == "https://example.com" # fallback to alias choices value as CONSOLE_API_URL - assert config['FILES_URL'] == 'https://example.com' + assert config["FILES_URL"] == "https://example.com" - assert config['SQLALCHEMY_DATABASE_URI'] == 'postgresql://postgres:@localhost:5432/dify' - assert config['SQLALCHEMY_ENGINE_OPTIONS'] == { - 'connect_args': { - 'options': '-c timezone=UTC', + assert config["SQLALCHEMY_DATABASE_URI"] == "postgresql://postgres:@localhost:5432/dify" + assert config["SQLALCHEMY_ENGINE_OPTIONS"] == { + "connect_args": { + "options": "-c timezone=UTC", }, - 'max_overflow': 10, - 'pool_pre_ping': False, - 'pool_recycle': 3600, - 'pool_size': 30, + "max_overflow": 10, + "pool_pre_ping": False, + "pool_recycle": 3600, + "pool_size": 30, } - assert config['CONSOLE_WEB_URL']=='https://example.com' - assert config['CONSOLE_CORS_ALLOW_ORIGINS']==['https://example.com'] - assert config['WEB_API_CORS_ALLOW_ORIGINS'] == ['*'] + assert config["CONSOLE_WEB_URL"] == "https://example.com" + assert config["CONSOLE_CORS_ALLOW_ORIGINS"] == ["https://example.com"] + assert config["WEB_API_CORS_ALLOW_ORIGINS"] == ["*"] + + assert str(config["CODE_EXECUTION_ENDPOINT"]) == "http://sandbox:8194/" + assert str(URL(str(config["CODE_EXECUTION_ENDPOINT"])) / "v1") == "http://sandbox:8194/v1" diff --git a/api/tests/unit_tests/core/app/segments/test_factory.py b/api/tests/unit_tests/core/app/segments/test_factory.py index afd0fa50b5..0824c8e9e9 100644 --- a/api/tests/unit_tests/core/app/segments/test_factory.py +++ b/api/tests/unit_tests/core/app/segments/test_factory.py @@ -17,31 +17,31 @@ from core.app.segments.exc import VariableError def test_string_variable(): - test_data = {'value_type': 'string', 'name': 'test_text', 'value': 'Hello, World!'} + test_data = {"value_type": "string", "name": "test_text", "value": "Hello, World!"} result = factory.build_variable_from_mapping(test_data) assert isinstance(result, StringVariable) def test_integer_variable(): - test_data = {'value_type': 'number', 'name': 'test_int', 'value': 42} + test_data = {"value_type": "number", "name": "test_int", "value": 42} result = factory.build_variable_from_mapping(test_data) assert isinstance(result, IntegerVariable) def test_float_variable(): - test_data = {'value_type': 'number', 'name': 'test_float', 'value': 3.14} + test_data = {"value_type": "number", "name": "test_float", "value": 3.14} result = factory.build_variable_from_mapping(test_data) assert isinstance(result, FloatVariable) def test_secret_variable(): - test_data = {'value_type': 'secret', 'name': 'test_secret', 'value': 'secret_value'} + test_data = {"value_type": "secret", "name": "test_secret", "value": "secret_value"} result = factory.build_variable_from_mapping(test_data) assert isinstance(result, SecretVariable) def test_invalid_value_type(): - test_data = {'value_type': 'unknown', 'name': 'test_invalid', 'value': 'value'} + test_data = {"value_type": "unknown", "name": "test_invalid", "value": "value"} with pytest.raises(VariableError): factory.build_variable_from_mapping(test_data) @@ -49,51 +49,51 @@ def test_invalid_value_type(): def test_build_a_blank_string(): result = factory.build_variable_from_mapping( { - 'value_type': 'string', - 'name': 'blank', - 'value': '', + "value_type": "string", + "name": "blank", + "value": "", } ) assert isinstance(result, StringVariable) - assert result.value == '' + assert result.value == "" def test_build_a_object_variable_with_none_value(): var = factory.build_segment( { - 'key1': None, + "key1": None, } ) assert isinstance(var, ObjectSegment) - assert var.value['key1'] is None + assert var.value["key1"] is None def test_object_variable(): mapping = { - 'id': str(uuid4()), - 'value_type': 'object', - 'name': 'test_object', - 'description': 'Description of the variable.', - 'value': { - 'key1': 'text', - 'key2': 2, + "id": str(uuid4()), + "value_type": "object", + "name": "test_object", + "description": "Description of the variable.", + "value": { + "key1": "text", + "key2": 2, }, } variable = factory.build_variable_from_mapping(mapping) assert isinstance(variable, ObjectSegment) - assert isinstance(variable.value['key1'], str) - assert isinstance(variable.value['key2'], int) + assert isinstance(variable.value["key1"], str) + assert isinstance(variable.value["key2"], int) def test_array_string_variable(): mapping = { - 'id': str(uuid4()), - 'value_type': 'array[string]', - 'name': 'test_array', - 'description': 'Description of the variable.', - 'value': [ - 'text', - 'text', + "id": str(uuid4()), + "value_type": "array[string]", + "name": "test_array", + "description": "Description of the variable.", + "value": [ + "text", + "text", ], } variable = factory.build_variable_from_mapping(mapping) @@ -104,11 +104,11 @@ def test_array_string_variable(): def test_array_number_variable(): mapping = { - 'id': str(uuid4()), - 'value_type': 'array[number]', - 'name': 'test_array', - 'description': 'Description of the variable.', - 'value': [ + "id": str(uuid4()), + "value_type": "array[number]", + "name": "test_array", + "description": "Description of the variable.", + "value": [ 1, 2.0, ], @@ -121,18 +121,18 @@ def test_array_number_variable(): def test_array_object_variable(): mapping = { - 'id': str(uuid4()), - 'value_type': 'array[object]', - 'name': 'test_array', - 'description': 'Description of the variable.', - 'value': [ + "id": str(uuid4()), + "value_type": "array[object]", + "name": "test_array", + "description": "Description of the variable.", + "value": [ { - 'key1': 'text', - 'key2': 1, + "key1": "text", + "key2": 1, }, { - 'key1': 'text', - 'key2': 1, + "key1": "text", + "key2": 1, }, ], } @@ -140,19 +140,19 @@ def test_array_object_variable(): assert isinstance(variable, ArrayObjectVariable) assert isinstance(variable.value[0], dict) assert isinstance(variable.value[1], dict) - assert isinstance(variable.value[0]['key1'], str) - assert isinstance(variable.value[0]['key2'], int) - assert isinstance(variable.value[1]['key1'], str) - assert isinstance(variable.value[1]['key2'], int) + assert isinstance(variable.value[0]["key1"], str) + assert isinstance(variable.value[0]["key2"], int) + assert isinstance(variable.value[1]["key1"], str) + assert isinstance(variable.value[1]["key2"], int) def test_variable_cannot_large_than_5_kb(): with pytest.raises(VariableError): factory.build_variable_from_mapping( { - 'id': str(uuid4()), - 'value_type': 'string', - 'name': 'test_text', - 'value': 'a' * 1024 * 6, + "id": str(uuid4()), + "value_type": "string", + "name": "test_text", + "value": "a" * 1024 * 6, } ) diff --git a/api/tests/unit_tests/core/app/segments/test_segment.py b/api/tests/unit_tests/core/app/segments/test_segment.py index 7e3e69ffbf..7cc339d212 100644 --- a/api/tests/unit_tests/core/app/segments/test_segment.py +++ b/api/tests/unit_tests/core/app/segments/test_segment.py @@ -1,26 +1,26 @@ from core.app.segments import SecretVariable, StringSegment, parser from core.helper import encrypter from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey def test_segment_group_to_text(): variable_pool = VariablePool( system_variables={ - SystemVariable('user_id'): 'fake-user-id', + SystemVariableKey("user_id"): "fake-user-id", }, user_inputs={}, environment_variables=[ - SecretVariable(name='secret_key', value='fake-secret-key'), + SecretVariable(name="secret_key", value="fake-secret-key"), ], ) - variable_pool.add(('node_id', 'custom_query'), 'fake-user-query') + variable_pool.add(("node_id", "custom_query"), "fake-user-query") template = ( - 'Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}.' + "Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}." ) segments_group = parser.convert_template(template=template, variable_pool=variable_pool) - assert segments_group.text == 'Hello, fake-user-id! Your query is fake-user-query. And your key is fake-secret-key.' + assert segments_group.text == "Hello, fake-user-id! Your query is fake-user-query. And your key is fake-secret-key." assert ( segments_group.log == f"Hello, fake-user-id! Your query is fake-user-query. And your key is {encrypter.obfuscated_token('fake-secret-key')}." @@ -33,22 +33,22 @@ def test_convert_constant_to_segment_group(): user_inputs={}, environment_variables=[], ) - template = 'Hello, world!' + template = "Hello, world!" segments_group = parser.convert_template(template=template, variable_pool=variable_pool) - assert segments_group.text == 'Hello, world!' - assert segments_group.log == 'Hello, world!' + assert segments_group.text == "Hello, world!" + assert segments_group.log == "Hello, world!" def test_convert_variable_to_segment_group(): variable_pool = VariablePool( system_variables={ - SystemVariable('user_id'): 'fake-user-id', + SystemVariableKey("user_id"): "fake-user-id", }, user_inputs={}, environment_variables=[], ) - template = '{{#sys.user_id#}}' + template = "{{#sys.user_id#}}" segments_group = parser.convert_template(template=template, variable_pool=variable_pool) - assert segments_group.text == 'fake-user-id' - assert segments_group.log == 'fake-user-id' - assert segments_group.value == [StringSegment(value='fake-user-id')] + assert segments_group.text == "fake-user-id" + assert segments_group.log == "fake-user-id" + assert segments_group.value == [StringSegment(value="fake-user-id")] diff --git a/api/tests/unit_tests/core/app/segments/test_variables.py b/api/tests/unit_tests/core/app/segments/test_variables.py index 1f45c15f87..b3f0ae626c 100644 --- a/api/tests/unit_tests/core/app/segments/test_variables.py +++ b/api/tests/unit_tests/core/app/segments/test_variables.py @@ -13,60 +13,60 @@ from core.app.segments import ( def test_frozen_variables(): - var = StringVariable(name='text', value='text') + var = StringVariable(name="text", value="text") with pytest.raises(ValidationError): - var.value = 'new value' + var.value = "new value" - int_var = IntegerVariable(name='integer', value=42) + int_var = IntegerVariable(name="integer", value=42) with pytest.raises(ValidationError): int_var.value = 100 - float_var = FloatVariable(name='float', value=3.14) + float_var = FloatVariable(name="float", value=3.14) with pytest.raises(ValidationError): float_var.value = 2.718 - secret_var = SecretVariable(name='secret', value='secret_value') + secret_var = SecretVariable(name="secret", value="secret_value") with pytest.raises(ValidationError): - secret_var.value = 'new_secret_value' + secret_var.value = "new_secret_value" def test_variable_value_type_immutable(): with pytest.raises(ValidationError): - StringVariable(value_type=SegmentType.ARRAY_ANY, name='text', value='text') + StringVariable(value_type=SegmentType.ARRAY_ANY, name="text", value="text") with pytest.raises(ValidationError): - StringVariable.model_validate({'value_type': 'not text', 'name': 'text', 'value': 'text'}) + StringVariable.model_validate({"value_type": "not text", "name": "text", "value": "text"}) - var = IntegerVariable(name='integer', value=42) + var = IntegerVariable(name="integer", value=42) with pytest.raises(ValidationError): IntegerVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value) - var = FloatVariable(name='float', value=3.14) + var = FloatVariable(name="float", value=3.14) with pytest.raises(ValidationError): FloatVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value) - var = SecretVariable(name='secret', value='secret_value') + var = SecretVariable(name="secret", value="secret_value") with pytest.raises(ValidationError): SecretVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value) def test_object_variable_to_object(): var = ObjectVariable( - name='object', + name="object", value={ - 'key1': { - 'key2': 'value2', + "key1": { + "key2": "value2", }, - 'key2': ['value5_1', 42, {}], + "key2": ["value5_1", 42, {}], }, ) assert var.to_object() == { - 'key1': { - 'key2': 'value2', + "key1": { + "key2": "value2", }, - 'key2': [ - 'value5_1', + "key2": [ + "value5_1", 42, {}, ], @@ -74,11 +74,11 @@ def test_object_variable_to_object(): def test_variable_to_object(): - var = StringVariable(name='text', value='text') - assert var.to_object() == 'text' - var = IntegerVariable(name='integer', value=42) + var = StringVariable(name="text", value="text") + assert var.to_object() == "text" + var = IntegerVariable(name="integer", value=42) assert var.to_object() == 42 - var = FloatVariable(name='float', value=3.14) + var = FloatVariable(name="float", value=3.14) assert var.to_object() == 3.14 - var = SecretVariable(name='secret', value='secret_value') - assert var.to_object() == 'secret_value' + var = SecretVariable(name="secret", value="secret_value") + assert var.to_object() == "secret_value" diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py index d917bb1003..7a0bc70c63 100644 --- a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py +++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py @@ -4,17 +4,17 @@ from unittest.mock import MagicMock, patch from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request -@patch('httpx.request') +@patch("httpx.request") def test_successful_request(mock_request): mock_response = MagicMock() mock_response.status_code = 200 mock_request.return_value = mock_response - response = make_request('GET', 'http://example.com') + response = make_request("GET", "http://example.com") assert response.status_code == 200 -@patch('httpx.request') +@patch("httpx.request") def test_retry_exceed_max_retries(mock_request): mock_response = MagicMock() mock_response.status_code = 500 @@ -23,13 +23,13 @@ def test_retry_exceed_max_retries(mock_request): mock_request.side_effect = side_effects try: - make_request('GET', 'http://example.com', max_retries=SSRF_DEFAULT_MAX_RETRIES - 1) + make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES - 1) raise AssertionError("Expected Exception not raised") except Exception as e: assert str(e) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com" -@patch('httpx.request') +@patch("httpx.request") def test_retry_logic_success(mock_request): side_effects = [] @@ -45,8 +45,8 @@ def test_retry_logic_success(mock_request): mock_request.side_effect = side_effects - response = make_request('GET', 'http://example.com', max_retries=SSRF_DEFAULT_MAX_RETRIES) + response = make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES) assert response.status_code == 200 assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1 - assert mock_request.call_args_list[0][1].get('method') == 'GET' + assert mock_request.call_args_list[0][1].get("method") == "GET" diff --git a/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/test_text_embedding.py b/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/test_text_embedding.py index 68334fde82..5b159b49b6 100644 --- a/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/test_text_embedding.py +++ b/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/test_text_embedding.py @@ -21,18 +21,18 @@ def test_max_chunks(): def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding: return _MockTextEmbedding() - model = 'embedding-v1' + model = "embedding-v1" credentials = { - 'api_key': 'xxxx', - 'secret_key': 'yyyy', + "api_key": "xxxx", + "secret_key": "yyyy", } embedding_model = WenxinTextEmbeddingModel() context_size = embedding_model._get_context_size(model, credentials) max_chunks = embedding_model._get_max_chunks(model, credentials) embedding_model._create_text_embedding = _create_text_embedding - texts = ['0123456789' for i in range(0, max_chunks * 2)] - result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test') + texts = ["0123456789" for i in range(0, max_chunks * 2)] + result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, "test") assert len(result.embeddings) == max_chunks * 2 @@ -41,16 +41,16 @@ def test_context_size(): return GPT2Tokenizer.get_num_tokens(text) def mock_text(token_size: int) -> str: - _text = "".join(['0' for i in range(token_size)]) + _text = "".join(["0" for i in range(token_size)]) num_tokens = get_num_tokens_by_gpt2(_text) ratio = int(np.floor(len(_text) / num_tokens)) m_text = "".join([_text for i in range(ratio)]) return m_text - model = 'embedding-v1' + model = "embedding-v1" credentials = { - 'api_key': 'xxxx', - 'secret_key': 'yyyy', + "api_key": "xxxx", + "secret_key": "yyyy", } embedding_model = WenxinTextEmbeddingModel() context_size = embedding_model._get_context_size(model, credentials) @@ -71,5 +71,5 @@ def test_context_size(): assert get_num_tokens_by_gpt2(text) == context_size * 2 texts = [text] - result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test') + result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, "test") assert result.usage.tokens == context_size diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index d24cd4aae9..24bbde6d4e 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -14,39 +14,24 @@ from models.model import Conversation def test__get_completion_model_prompt_messages(): model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = 'openai' - model_config_mock.model = 'gpt-3.5-turbo-instruct' + model_config_mock.provider = "openai" + model_config_mock.model = "gpt-3.5-turbo-instruct" prompt_template = "Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}." - prompt_template_config = CompletionModelPromptTemplate( - text=prompt_template - ) + prompt_template_config = CompletionModelPromptTemplate(text=prompt_template) memory_config = MemoryConfig( - role_prefix=MemoryConfig.RolePrefix( - user="Human", - assistant="Assistant" - ), - window=MemoryConfig.WindowConfig( - enabled=False - ) + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=False), ) - inputs = { - "name": "John" - } + inputs = {"name": "John"} files = [] context = "I am superman." - memory = TokenBufferMemory( - conversation=Conversation(), - model_instance=model_config_mock - ) + memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock) - history_prompt_messages = [ - UserPromptMessage(content="Hi"), - AssistantPromptMessage(content="Hello") - ] + history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")] memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages) prompt_transform = AdvancedPromptTransform() @@ -59,16 +44,22 @@ def test__get_completion_model_prompt_messages(): context=context, memory_config=memory_config, memory=memory, - model_config=model_config_mock + model_config=model_config_mock, ) assert len(prompt_messages) == 1 - assert prompt_messages[0].content == PromptTemplateParser(template=prompt_template).format({ - "#context#": context, - "#histories#": "\n".join([f"{'Human' if prompt.role.value == 'user' else 'Assistant'}: " - f"{prompt.content}" for prompt in history_prompt_messages]), - **inputs, - }) + assert prompt_messages[0].content == PromptTemplateParser(template=prompt_template).format( + { + "#context#": context, + "#histories#": "\n".join( + [ + f"{'Human' if prompt.role.value == 'user' else 'Assistant'}: " f"{prompt.content}" + for prompt in history_prompt_messages + ] + ), + **inputs, + } + ) def test__get_chat_model_prompt_messages(get_chat_model_args): @@ -77,15 +68,9 @@ def test__get_chat_model_prompt_messages(get_chat_model_args): files = [] query = "Hi2." - memory = TokenBufferMemory( - conversation=Conversation(), - model_instance=model_config_mock - ) + memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock) - history_prompt_messages = [ - UserPromptMessage(content="Hi1."), - AssistantPromptMessage(content="Hello1!") - ] + history_prompt_messages = [UserPromptMessage(content="Hi1."), AssistantPromptMessage(content="Hello1!")] memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages) prompt_transform = AdvancedPromptTransform() @@ -98,14 +83,14 @@ def test__get_chat_model_prompt_messages(get_chat_model_args): context=context, memory_config=memory_config, memory=memory, - model_config=model_config_mock + model_config=model_config_mock, ) assert len(prompt_messages) == 6 assert prompt_messages[0].role == PromptMessageRole.SYSTEM - assert prompt_messages[0].content == PromptTemplateParser( - template=messages[0].text - ).format({**inputs, "#context#": context}) + assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format( + {**inputs, "#context#": context} + ) assert prompt_messages[5].content == query @@ -124,14 +109,14 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args): context=context, memory_config=None, memory=None, - model_config=model_config_mock + model_config=model_config_mock, ) assert len(prompt_messages) == 3 assert prompt_messages[0].role == PromptMessageRole.SYSTEM - assert prompt_messages[0].content == PromptTemplateParser( - template=messages[0].text - ).format({**inputs, "#context#": context}) + assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format( + {**inputs, "#context#": context} + ) def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args): @@ -148,7 +133,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg image_config={ "detail": "high", } - ) + ), ) ] @@ -162,14 +147,14 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg context=context, memory_config=None, memory=None, - model_config=model_config_mock + model_config=model_config_mock, ) assert len(prompt_messages) == 4 assert prompt_messages[0].role == PromptMessageRole.SYSTEM - assert prompt_messages[0].content == PromptTemplateParser( - template=messages[0].text - ).format({**inputs, "#context#": context}) + assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format( + {**inputs, "#context#": context} + ) assert isinstance(prompt_messages[3].content, list) assert len(prompt_messages[3].content) == 2 assert prompt_messages[3].content[1].data == files[0].url @@ -178,33 +163,20 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg @pytest.fixture def get_chat_model_args(): model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = 'openai' - model_config_mock.model = 'gpt-4' + model_config_mock.provider = "openai" + model_config_mock.model = "gpt-4" - memory_config = MemoryConfig( - window=MemoryConfig.WindowConfig( - enabled=False - ) - ) + memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) prompt_messages = [ ChatModelMessage( - text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", - role=PromptMessageRole.SYSTEM + text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", role=PromptMessageRole.SYSTEM ), - ChatModelMessage( - text="Hi.", - role=PromptMessageRole.USER - ), - ChatModelMessage( - text="Hello!", - role=PromptMessageRole.ASSISTANT - ) + ChatModelMessage(text="Hi.", role=PromptMessageRole.USER), + ChatModelMessage(text="Hello!", role=PromptMessageRole.ASSISTANT), ] - inputs = { - "name": "John" - } + inputs = {"name": "John"} context = "I am superman." diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py index 9de268d762..0fd176e65d 100644 --- a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py @@ -18,27 +18,28 @@ from models.model import Conversation def test_get_prompt(): prompt_messages = [ - SystemPromptMessage(content='System Template'), - UserPromptMessage(content='User Query'), + SystemPromptMessage(content="System Template"), + UserPromptMessage(content="User Query"), ] history_messages = [ - SystemPromptMessage(content='System Prompt 1'), - UserPromptMessage(content='User Prompt 1'), - AssistantPromptMessage(content='Assistant Thought 1'), - ToolPromptMessage(content='Tool 1-1', name='Tool 1-1', tool_call_id='1'), - ToolPromptMessage(content='Tool 1-2', name='Tool 1-2', tool_call_id='2'), - SystemPromptMessage(content='System Prompt 2'), - UserPromptMessage(content='User Prompt 2'), - AssistantPromptMessage(content='Assistant Thought 2'), - ToolPromptMessage(content='Tool 2-1', name='Tool 2-1', tool_call_id='3'), - ToolPromptMessage(content='Tool 2-2', name='Tool 2-2', tool_call_id='4'), - UserPromptMessage(content='User Prompt 3'), - AssistantPromptMessage(content='Assistant Thought 3'), + SystemPromptMessage(content="System Prompt 1"), + UserPromptMessage(content="User Prompt 1"), + AssistantPromptMessage(content="Assistant Thought 1"), + ToolPromptMessage(content="Tool 1-1", name="Tool 1-1", tool_call_id="1"), + ToolPromptMessage(content="Tool 1-2", name="Tool 1-2", tool_call_id="2"), + SystemPromptMessage(content="System Prompt 2"), + UserPromptMessage(content="User Prompt 2"), + AssistantPromptMessage(content="Assistant Thought 2"), + ToolPromptMessage(content="Tool 2-1", name="Tool 2-1", tool_call_id="3"), + ToolPromptMessage(content="Tool 2-2", name="Tool 2-2", tool_call_id="4"), + UserPromptMessage(content="User Prompt 3"), + AssistantPromptMessage(content="Assistant Thought 3"), ] # use message number instead of token for testing def side_effect_get_num_tokens(*args): return len(args[2]) + large_language_model_mock = MagicMock(spec=LargeLanguageModel) large_language_model_mock.get_num_tokens = MagicMock(side_effect=side_effect_get_num_tokens) @@ -46,20 +47,17 @@ def test_get_prompt(): provider_model_bundle_mock.model_type_instance = large_language_model_mock model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity) - model_config_mock.model = 'openai' + model_config_mock.model = "openai" model_config_mock.credentials = {} model_config_mock.provider_model_bundle = provider_model_bundle_mock - memory = TokenBufferMemory( - conversation=Conversation(), - model_instance=model_config_mock - ) + memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock) transform = AgentHistoryPromptTransform( model_config=model_config_mock, prompt_messages=prompt_messages, history_messages=history_messages, - memory=memory + memory=memory, ) max_token_limit = 5 diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py index 2bcc6f4292..89c14463bb 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -12,19 +12,15 @@ from core.prompt.prompt_transform import PromptTransform def test__calculate_rest_token(): model_schema_mock = MagicMock(spec=AIModelEntity) parameter_rule_mock = MagicMock(spec=ParameterRule) - parameter_rule_mock.name = 'max_tokens' - model_schema_mock.parameter_rules = [ - parameter_rule_mock - ] - model_schema_mock.model_properties = { - ModelPropertyKey.CONTEXT_SIZE: 62 - } + parameter_rule_mock.name = "max_tokens" + model_schema_mock.parameter_rules = [parameter_rule_mock] + model_schema_mock.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 62} large_language_model_mock = MagicMock(spec=LargeLanguageModel) large_language_model_mock.get_num_tokens.return_value = 6 provider_mock = MagicMock(spec=ProviderEntity) - provider_mock.provider = 'openai' + provider_mock.provider = "openai" provider_configuration_mock = MagicMock(spec=ProviderConfiguration) provider_configuration_mock.provider = provider_mock @@ -35,11 +31,9 @@ def test__calculate_rest_token(): provider_model_bundle_mock.configuration = provider_configuration_mock model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.model = 'gpt-4' + model_config_mock.model = "gpt-4" model_config_mock.credentials = {} - model_config_mock.parameters = { - 'max_tokens': 50 - } + model_config_mock.parameters = {"max_tokens": 50} model_config_mock.model_schema = model_schema_mock model_config_mock.provider_model_bundle = provider_model_bundle_mock @@ -49,8 +43,10 @@ def test__calculate_rest_token(): rest_tokens = prompt_transform._calculate_rest_token(prompt_messages, model_config_mock) # Validate based on the mock configuration and expected logic - expected_rest_tokens = (model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE] - - model_config_mock.parameters['max_tokens'] - - large_language_model_mock.get_num_tokens.return_value) + expected_rest_tokens = ( + model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE] + - model_config_mock.parameters["max_tokens"] + - large_language_model_mock.get_num_tokens.return_value + ) assert rest_tokens == expected_rest_tokens assert rest_tokens == 6 diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py index 6d6363610b..c32fc2bc34 100644 --- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -19,12 +19,15 @@ def test_get_common_chat_app_prompt_template_with_pcqm(): query_in_prompt=True, with_memory_prompt=True, ) - prompt_rules = prompt_template['prompt_rules'] - assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] - + pre_prompt + '\n' - + prompt_rules['histories_prompt'] - + prompt_rules['query_prompt']) - assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#'] + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == ( + prompt_rules["context_prompt"] + + pre_prompt + + "\n" + + prompt_rules["histories_prompt"] + + prompt_rules["query_prompt"] + ) + assert prompt_template["special_variable_keys"] == ["#context#", "#histories#", "#query#"] def test_get_baichuan_chat_app_prompt_template_with_pcqm(): @@ -39,12 +42,15 @@ def test_get_baichuan_chat_app_prompt_template_with_pcqm(): query_in_prompt=True, with_memory_prompt=True, ) - prompt_rules = prompt_template['prompt_rules'] - assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] - + pre_prompt + '\n' - + prompt_rules['histories_prompt'] - + prompt_rules['query_prompt']) - assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#'] + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == ( + prompt_rules["context_prompt"] + + pre_prompt + + "\n" + + prompt_rules["histories_prompt"] + + prompt_rules["query_prompt"] + ) + assert prompt_template["special_variable_keys"] == ["#context#", "#histories#", "#query#"] def test_get_common_completion_app_prompt_template_with_pcq(): @@ -59,11 +65,11 @@ def test_get_common_completion_app_prompt_template_with_pcq(): query_in_prompt=True, with_memory_prompt=False, ) - prompt_rules = prompt_template['prompt_rules'] - assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] - + pre_prompt + '\n' - + prompt_rules['query_prompt']) - assert prompt_template['special_variable_keys'] == ['#context#', '#query#'] + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == ( + prompt_rules["context_prompt"] + pre_prompt + "\n" + prompt_rules["query_prompt"] + ) + assert prompt_template["special_variable_keys"] == ["#context#", "#query#"] def test_get_baichuan_completion_app_prompt_template_with_pcq(): @@ -78,12 +84,12 @@ def test_get_baichuan_completion_app_prompt_template_with_pcq(): query_in_prompt=True, with_memory_prompt=False, ) - print(prompt_template['prompt_template'].template) - prompt_rules = prompt_template['prompt_rules'] - assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] - + pre_prompt + '\n' - + prompt_rules['query_prompt']) - assert prompt_template['special_variable_keys'] == ['#context#', '#query#'] + print(prompt_template["prompt_template"].template) + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == ( + prompt_rules["context_prompt"] + pre_prompt + "\n" + prompt_rules["query_prompt"] + ) + assert prompt_template["special_variable_keys"] == ["#context#", "#query#"] def test_get_common_chat_app_prompt_template_with_q(): @@ -98,9 +104,9 @@ def test_get_common_chat_app_prompt_template_with_q(): query_in_prompt=True, with_memory_prompt=False, ) - prompt_rules = prompt_template['prompt_rules'] - assert prompt_template['prompt_template'].template == prompt_rules['query_prompt'] - assert prompt_template['special_variable_keys'] == ['#query#'] + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == prompt_rules["query_prompt"] + assert prompt_template["special_variable_keys"] == ["#query#"] def test_get_common_chat_app_prompt_template_with_cq(): @@ -115,10 +121,11 @@ def test_get_common_chat_app_prompt_template_with_cq(): query_in_prompt=True, with_memory_prompt=False, ) - prompt_rules = prompt_template['prompt_rules'] - assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] - + prompt_rules['query_prompt']) - assert prompt_template['special_variable_keys'] == ['#context#', '#query#'] + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == ( + prompt_rules["context_prompt"] + prompt_rules["query_prompt"] + ) + assert prompt_template["special_variable_keys"] == ["#context#", "#query#"] def test_get_common_chat_app_prompt_template_with_p(): @@ -133,30 +140,25 @@ def test_get_common_chat_app_prompt_template_with_p(): query_in_prompt=False, with_memory_prompt=False, ) - assert prompt_template['prompt_template'].template == pre_prompt + '\n' - assert prompt_template['custom_variable_keys'] == ['name'] - assert prompt_template['special_variable_keys'] == [] + assert prompt_template["prompt_template"].template == pre_prompt + "\n" + assert prompt_template["custom_variable_keys"] == ["name"] + assert prompt_template["special_variable_keys"] == [] def test__get_chat_model_prompt_messages(): model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity) - model_config_mock.provider = 'openai' - model_config_mock.model = 'gpt-4' + model_config_mock.provider = "openai" + model_config_mock.model = "gpt-4" memory_mock = MagicMock(spec=TokenBufferMemory) - history_prompt_messages = [ - UserPromptMessage(content="Hi"), - AssistantPromptMessage(content="Hello") - ] + history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")] memory_mock.get_history_prompt_messages.return_value = history_prompt_messages prompt_transform = SimplePromptTransform() prompt_transform._calculate_rest_token = MagicMock(return_value=2000) pre_prompt = "You are a helpful assistant {{name}}." - inputs = { - "name": "John" - } + inputs = {"name": "John"} context = "yes or no." query = "How are you?" prompt_messages, _ = prompt_transform._get_chat_model_prompt_messages( @@ -167,7 +169,7 @@ def test__get_chat_model_prompt_messages(): files=[], context=context, memory=memory_mock, - model_config=model_config_mock + model_config=model_config_mock, ) prompt_template = prompt_transform.get_prompt_template( @@ -180,8 +182,8 @@ def test__get_chat_model_prompt_messages(): with_memory_prompt=False, ) - full_inputs = {**inputs, '#context#': context} - real_system_prompt = prompt_template['prompt_template'].format(full_inputs) + full_inputs = {**inputs, "#context#": context} + real_system_prompt = prompt_template["prompt_template"].format(full_inputs) assert len(prompt_messages) == 4 assert prompt_messages[0].content == real_system_prompt @@ -192,26 +194,18 @@ def test__get_chat_model_prompt_messages(): def test__get_completion_model_prompt_messages(): model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity) - model_config_mock.provider = 'openai' - model_config_mock.model = 'gpt-3.5-turbo-instruct' + model_config_mock.provider = "openai" + model_config_mock.model = "gpt-3.5-turbo-instruct" - memory = TokenBufferMemory( - conversation=Conversation(), - model_instance=model_config_mock - ) + memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock) - history_prompt_messages = [ - UserPromptMessage(content="Hi"), - AssistantPromptMessage(content="Hello") - ] + history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")] memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages) prompt_transform = SimplePromptTransform() prompt_transform._calculate_rest_token = MagicMock(return_value=2000) pre_prompt = "You are a helpful assistant {{name}}." - inputs = { - "name": "John" - } + inputs = {"name": "John"} context = "yes or no." query = "How are you?" prompt_messages, stops = prompt_transform._get_completion_model_prompt_messages( @@ -222,7 +216,7 @@ def test__get_completion_model_prompt_messages(): files=[], context=context, memory=memory, - model_config=model_config_mock + model_config=model_config_mock, ) prompt_template = prompt_transform.get_prompt_template( @@ -235,14 +229,19 @@ def test__get_completion_model_prompt_messages(): with_memory_prompt=True, ) - prompt_rules = prompt_template['prompt_rules'] - full_inputs = {**inputs, '#context#': context, '#query#': query, '#histories#': memory.get_history_prompt_text( - max_token_limit=2000, - human_prefix=prompt_rules.get("human_prefix", "Human"), - ai_prefix=prompt_rules.get("assistant_prefix", "Assistant") - )} - real_prompt = prompt_template['prompt_template'].format(full_inputs) + prompt_rules = prompt_template["prompt_rules"] + full_inputs = { + **inputs, + "#context#": context, + "#query#": query, + "#histories#": memory.get_history_prompt_text( + max_token_limit=2000, + human_prefix=prompt_rules.get("human_prefix", "Human"), + ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"), + ), + } + real_prompt = prompt_template["prompt_template"].format(full_inputs) assert len(prompt_messages) == 1 - assert stops == prompt_rules.get('stops') + assert stops == prompt_rules.get("stops") assert prompt_messages[0].content == real_prompt diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py index 9e43b23658..8d735cae86 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py @@ -5,20 +5,15 @@ from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig def test_default_value(): - valid_config = { - 'host': 'localhost', - 'port': 19530, - 'user': 'root', - 'password': 'Milvus' - } + valid_config = {"host": "localhost", "port": 19530, "user": "root", "password": "Milvus"} for key in valid_config: config = valid_config.copy() del config[key] with pytest.raises(ValidationError) as e: MilvusConfig(**config) - assert e.value.errors()[0]['msg'] == f'Value error, config MILVUS_{key.upper()} is required' + assert e.value.errors()[0]["msg"] == f"Value error, config MILVUS_{key.upper()} is required" config = MilvusConfig(**valid_config) assert config.secure is False - assert config.database == 'default' + assert config.database == "default" diff --git a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py index a8bba11e16..d5a1d8f436 100644 --- a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py +++ b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py @@ -9,19 +9,17 @@ from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_resp def test_firecrawl_web_extractor_crawl_mode(mocker): url = "https://firecrawl.dev" - api_key = os.getenv('FIRECRAWL_API_KEY') or 'fc-' - base_url = 'https://api.firecrawl.dev' - firecrawl_app = FirecrawlApp(api_key=api_key, - base_url=base_url) + api_key = os.getenv("FIRECRAWL_API_KEY") or "fc-" + base_url = "https://api.firecrawl.dev" + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=base_url) params = { - 'crawlerOptions': { + "crawlerOptions": { "includes": [], "excludes": [], "generateImgAltText": True, "maxDepth": 1, "limit": 1, - 'returnOnlyUrls': False, - + "returnOnlyUrls": False, } } mocked_firecrawl = { diff --git a/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py index b231fe479b..eea584a2f8 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py @@ -8,11 +8,8 @@ page_id = "page1" extractor = notion_extractor.NotionExtractor( - notion_workspace_id='x', - notion_obj_id='x', - notion_page_type='page', - tenant_id='x', - notion_access_token='x') + notion_workspace_id="x", notion_obj_id="x", notion_page_type="page", tenant_id="x", notion_access_token="x" +) def _generate_page(page_title: str): @@ -21,16 +18,10 @@ def _generate_page(page_title: str): "id": page_id, "properties": { "Page": { - "type": "title", - "title": [ - { - "type": "text", - "text": {"content": page_title}, - "plain_text": page_title - } - ] + "type": "title", + "title": [{"type": "text", "text": {"content": page_title}, "plain_text": page_title}], } - } + }, } @@ -38,10 +29,7 @@ def _generate_block(block_id: str, block_type: str, block_text: str): return { "object": "block", "id": block_id, - "parent": { - "type": "page_id", - "page_id": page_id - }, + "parent": {"type": "page_id", "page_id": page_id}, "type": block_type, "has_children": False, block_type: { @@ -49,10 +37,11 @@ def _generate_block(block_id: str, block_type: str, block_text: str): { "type": "text", "text": {"content": block_text}, - "plain_text": block_text, - }] - } - } + "plain_text": block_text, + } + ] + }, + } def _mock_response(data): @@ -63,7 +52,7 @@ def _mock_response(data): def _remove_multiple_new_lines(text): - while '\n\n' in text: + while "\n\n" in text: text = text.replace("\n\n", "\n") return text.strip() @@ -71,21 +60,21 @@ def _remove_multiple_new_lines(text): def test_notion_page(mocker): texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"] mocked_notion_page = { - "object": "list", - "results": [ - _generate_block("b1", "heading_1", texts[0]), - _generate_block("b2", "heading_2", texts[1]), - _generate_block("b3", "paragraph", texts[2]), - _generate_block("b4", "heading_3", texts[3]) - ], - "next_cursor": None + "object": "list", + "results": [ + _generate_block("b1", "heading_1", texts[0]), + _generate_block("b2", "heading_2", texts[1]), + _generate_block("b3", "paragraph", texts[2]), + _generate_block("b4", "heading_3", texts[3]), + ], + "next_cursor": None, } mocker.patch("requests.request", return_value=_mock_response(mocked_notion_page)) page_docs = extractor._load_data_as_documents(page_id, "page") assert len(page_docs) == 1 content = _remove_multiple_new_lines(page_docs[0].page_content) - assert content == '# Head 1\n## 1.1\nparagraph 1\n### 1.1.1' + assert content == "# Head 1\n## 1.1\nparagraph 1\n### 1.1.1" def test_notion_database(mocker): @@ -93,10 +82,10 @@ def test_notion_database(mocker): mocked_notion_database = { "object": "list", "results": [_generate_page(i) for i in page_title_list], - "next_cursor": None + "next_cursor": None, } mocker.patch("requests.post", return_value=_mock_response(mocked_notion_database)) database_docs = extractor._load_data_as_documents(database_id, "database") assert len(database_docs) == 1 content = _remove_multiple_new_lines(database_docs[0].page_content) - assert content == '\n'.join([f'Page:{i}' for i in page_title_list]) + assert content == "\n".join([f"Page:{i}" for i in page_title_list]) diff --git a/api/tests/unit_tests/core/test_model_manager.py b/api/tests/unit_tests/core/test_model_manager.py index 3024a54a4d..2808b5b0fa 100644 --- a/api/tests/unit_tests/core/test_model_manager.py +++ b/api/tests/unit_tests/core/test_model_manager.py @@ -10,36 +10,24 @@ from core.model_runtime.entities.model_entities import ModelType @pytest.fixture def lb_model_manager(): load_balancing_configs = [ - ModelLoadBalancingConfiguration( - id='id1', - name='__inherit__', - credentials={} - ), - ModelLoadBalancingConfiguration( - id='id2', - name='first', - credentials={"openai_api_key": "fake_key"} - ), - ModelLoadBalancingConfiguration( - id='id3', - name='second', - credentials={"openai_api_key": "fake_key"} - ) + ModelLoadBalancingConfiguration(id="id1", name="__inherit__", credentials={}), + ModelLoadBalancingConfiguration(id="id2", name="first", credentials={"openai_api_key": "fake_key"}), + ModelLoadBalancingConfiguration(id="id3", name="second", credentials={"openai_api_key": "fake_key"}), ] lb_model_manager = LBModelManager( - tenant_id='tenant_id', - provider='openai', + tenant_id="tenant_id", + provider="openai", model_type=ModelType.LLM, - model='gpt-4', + model="gpt-4", load_balancing_configs=load_balancing_configs, - managed_credentials={"openai_api_key": "fake_key"} + managed_credentials={"openai_api_key": "fake_key"}, ) lb_model_manager.cooldown = MagicMock(return_value=None) def is_cooldown(config: ModelLoadBalancingConfiguration): - if config.id == 'id1': + if config.id == "id1": return True return False @@ -61,14 +49,15 @@ def test_lb_model_manager_fetch_next(mocker, lb_model_manager): assert lb_model_manager.in_cooldown(config3) is False start_index = 0 + def incr(key): nonlocal start_index start_index += 1 return start_index - mocker.patch('redis.Redis.incr', side_effect=incr) - mocker.patch('redis.Redis.set', return_value=None) - mocker.patch('redis.Redis.expire', return_value=None) + mocker.patch("redis.Redis.incr", side_effect=incr) + mocker.patch("redis.Redis.set", return_value=None) + mocker.patch("redis.Redis.expire", return_value=None) config = lb_model_manager.fetch_next() assert config == config2 diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index 072b6f100f..2f4214a580 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -11,62 +11,62 @@ def test__to_model_settings(mocker): provider_entity = None for provider in provider_entities: - if provider.provider == 'openai': + if provider.provider == "openai": provider_entity = provider # Mocking the inputs - provider_model_settings = [ProviderModelSetting( - id='id', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - enabled=True, - load_balancing_enabled=True - )] - load_balancing_model_configs = [ - LoadBalancingModelConfig( - id='id1', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - name='__inherit__', - encrypted_config=None, - enabled=True - ), - LoadBalancingModelConfig( - id='id2', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - name='first', - encrypted_config='{"openai_api_key": "fake_key"}', - enabled=True + provider_model_settings = [ + ProviderModelSetting( + id="id", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + enabled=True, + load_balancing_enabled=True, ) ] + load_balancing_model_configs = [ + LoadBalancingModelConfig( + id="id1", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="__inherit__", + encrypted_config=None, + enabled=True, + ), + LoadBalancingModelConfig( + id="id2", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="first", + encrypted_config='{"openai_api_key": "fake_key"}', + enabled=True, + ), + ] - mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"}) + mocker.patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} + ) provider_manager = ProviderManager() # Running the method - result = provider_manager._to_model_settings( - provider_entity, - provider_model_settings, - load_balancing_model_configs - ) + result = provider_manager._to_model_settings(provider_entity, provider_model_settings, load_balancing_model_configs) # Asserting that the result is as expected assert len(result) == 1 assert isinstance(result[0], ModelSettings) - assert result[0].model == 'gpt-4' + assert result[0].model == "gpt-4" assert result[0].model_type == ModelType.LLM assert result[0].enabled is True assert len(result[0].load_balancing_configs) == 2 - assert result[0].load_balancing_configs[0].name == '__inherit__' - assert result[0].load_balancing_configs[1].name == 'first' + assert result[0].load_balancing_configs[0].name == "__inherit__" + assert result[0].load_balancing_configs[1].name == "first" def test__to_model_settings_only_one_lb(mocker): @@ -75,47 +75,47 @@ def test__to_model_settings_only_one_lb(mocker): provider_entity = None for provider in provider_entities: - if provider.provider == 'openai': + if provider.provider == "openai": provider_entity = provider # Mocking the inputs - provider_model_settings = [ProviderModelSetting( - id='id', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - enabled=True, - load_balancing_enabled=True - )] + provider_model_settings = [ + ProviderModelSetting( + id="id", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + enabled=True, + load_balancing_enabled=True, + ) + ] load_balancing_model_configs = [ LoadBalancingModelConfig( - id='id1', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - name='__inherit__', + id="id1", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="__inherit__", encrypted_config=None, - enabled=True + enabled=True, ) ] - mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"}) + mocker.patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} + ) provider_manager = ProviderManager() # Running the method - result = provider_manager._to_model_settings( - provider_entity, - provider_model_settings, - load_balancing_model_configs - ) + result = provider_manager._to_model_settings(provider_entity, provider_model_settings, load_balancing_model_configs) # Asserting that the result is as expected assert len(result) == 1 assert isinstance(result[0], ModelSettings) - assert result[0].model == 'gpt-4' + assert result[0].model == "gpt-4" assert result[0].model_type == ModelType.LLM assert result[0].enabled is True assert len(result[0].load_balancing_configs) == 0 @@ -127,57 +127,57 @@ def test__to_model_settings_lb_disabled(mocker): provider_entity = None for provider in provider_entities: - if provider.provider == 'openai': + if provider.provider == "openai": provider_entity = provider # Mocking the inputs - provider_model_settings = [ProviderModelSetting( - id='id', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - enabled=True, - load_balancing_enabled=False - )] - load_balancing_model_configs = [ - LoadBalancingModelConfig( - id='id1', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - name='__inherit__', - encrypted_config=None, - enabled=True - ), - LoadBalancingModelConfig( - id='id2', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - name='first', - encrypted_config='{"openai_api_key": "fake_key"}', - enabled=True + provider_model_settings = [ + ProviderModelSetting( + id="id", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + enabled=True, + load_balancing_enabled=False, ) ] + load_balancing_model_configs = [ + LoadBalancingModelConfig( + id="id1", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="__inherit__", + encrypted_config=None, + enabled=True, + ), + LoadBalancingModelConfig( + id="id2", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="first", + encrypted_config='{"openai_api_key": "fake_key"}', + enabled=True, + ), + ] - mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"}) + mocker.patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} + ) provider_manager = ProviderManager() # Running the method - result = provider_manager._to_model_settings( - provider_entity, - provider_model_settings, - load_balancing_model_configs - ) + result = provider_manager._to_model_settings(provider_entity, provider_model_settings, load_balancing_model_configs) # Asserting that the result is as expected assert len(result) == 1 assert isinstance(result[0], ModelSettings) - assert result[0].model == 'gpt-4' + assert result[0].model == "gpt-4" assert result[0].model_type == ModelType.LLM assert result[0].enabled is True assert len(result[0].load_balancing_configs) == 0 diff --git a/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py b/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py index 9addeeadca..279a6cdbc3 100644 --- a/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py +++ b/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py @@ -5,52 +5,52 @@ from core.tools.utils.tool_parameter_converter import ToolParameterConverter def test_get_parameter_type(): - assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.STRING) == 'string' - assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.SELECT) == 'string' - assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.BOOLEAN) == 'boolean' - assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.NUMBER) == 'number' + assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.STRING) == "string" + assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.SELECT) == "string" + assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.BOOLEAN) == "boolean" + assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.NUMBER) == "number" with pytest.raises(ValueError): - ToolParameterConverter.get_parameter_type('unsupported_type') + ToolParameterConverter.get_parameter_type("unsupported_type") def test_cast_parameter_by_type(): # string - assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.STRING) == 'test' - assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.STRING) == '1' - assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.STRING) == '1.0' - assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.STRING) == '' + assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.STRING) == "test" + assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.STRING) == "1" + assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.STRING) == "1.0" + assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.STRING) == "" # secret input - assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.SECRET_INPUT) == 'test' - assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SECRET_INPUT) == '1' - assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SECRET_INPUT) == '1.0' - assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SECRET_INPUT) == '' + assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SECRET_INPUT) == "test" + assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SECRET_INPUT) == "1" + assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SECRET_INPUT) == "1.0" + assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SECRET_INPUT) == "" # select - assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.SELECT) == 'test' - assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SELECT) == '1' - assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SELECT) == '1.0' - assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SELECT) == '' + assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SELECT) == "test" + assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SELECT) == "1" + assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SELECT) == "1.0" + assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SELECT) == "" # boolean - true_values = [True, 'True', 'true', '1', 'YES', 'Yes', 'yes', 'y', 'something'] + true_values = [True, "True", "true", "1", "YES", "Yes", "yes", "y", "something"] for value in true_values: assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is True - false_values = [False, 'False', 'false', '0', 'NO', 'No', 'no', 'n', None, ''] + false_values = [False, "False", "false", "0", "NO", "No", "no", "n", None, ""] for value in false_values: assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is False # number - assert ToolParameterConverter.cast_parameter_by_type('1', ToolParameter.ToolParameterType.NUMBER) == 1 - assert ToolParameterConverter.cast_parameter_by_type('1.0', ToolParameter.ToolParameterType.NUMBER) == 1.0 - assert ToolParameterConverter.cast_parameter_by_type('-1.0', ToolParameter.ToolParameterType.NUMBER) == -1.0 + assert ToolParameterConverter.cast_parameter_by_type("1", ToolParameter.ToolParameterType.NUMBER) == 1 + assert ToolParameterConverter.cast_parameter_by_type("1.0", ToolParameter.ToolParameterType.NUMBER) == 1.0 + assert ToolParameterConverter.cast_parameter_by_type("-1.0", ToolParameter.ToolParameterType.NUMBER) == -1.0 assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.NUMBER) == 1 assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.NUMBER) == 1.0 assert ToolParameterConverter.cast_parameter_by_type(-1.0, ToolParameter.ToolParameterType.NUMBER) == -1.0 assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None # unknown - assert ToolParameterConverter.cast_parameter_by_type('1', 'unknown_type') == '1' - assert ToolParameterConverter.cast_parameter_by_type(1, 'unknown_type') == '1' + assert ToolParameterConverter.cast_parameter_by_type("1", "unknown_type") == "1" + assert ToolParameterConverter.cast_parameter_by_type(1, "unknown_type") == "1" assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py index 4617b6a42f..8020674ee6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py @@ -2,7 +2,7 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.base_node import UserFrom from extensions.ext_database import db @@ -11,29 +11,30 @@ from models.workflow import WorkflowNodeExecutionStatus def test_execute_answer(): node = AnswerNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, config={ - 'id': 'answer', - 'data': { - 'title': '123', - 'type': 'answer', - 'answer': 'Today\'s weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.' - } - } + "id": "answer", + "data": { + "title": "123", + "type": "answer", + "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", + }, + }, ) # construct variable pool - pool = VariablePool(system_variables={ - SystemVariable.FILES: [], - SystemVariable.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) - pool.add(['start', 'weather'], 'sunny') - pool.add(['llm', 'text'], 'You are a helpful AI.') + pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + ) + pool.add(["start", "weather"], "sunny") + pool.add(["llm", "text"], "You are a helpful AI.") # Mock db.session.close() db.session.close = MagicMock() @@ -42,4 +43,4 @@ def test_execute_answer(): result = node._run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs['answer'] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin." + assert result.outputs["answer"] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin." diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index d21b7785c4..9535bc2186 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -2,7 +2,7 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.if_else.if_else_node import IfElseNode from extensions.ext_database import db @@ -11,134 +11,81 @@ from models.workflow import WorkflowNodeExecutionStatus def test_execute_if_else_result_true(): node = IfElseNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, config={ - 'id': 'if-else', - 'data': { - 'title': '123', - 'type': 'if-else', - 'logical_operator': 'and', - 'conditions': [ + "id": "if-else", + "data": { + "title": "123", + "type": "if-else", + "logical_operator": "and", + "conditions": [ { - 'comparison_operator': 'contains', - 'variable_selector': ['start', 'array_contains'], - 'value': 'ab' + "comparison_operator": "contains", + "variable_selector": ["start", "array_contains"], + "value": "ab", }, { - 'comparison_operator': 'not contains', - 'variable_selector': ['start', 'array_not_contains'], - 'value': 'ab' + "comparison_operator": "not contains", + "variable_selector": ["start", "array_not_contains"], + "value": "ab", }, + {"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"}, { - 'comparison_operator': 'contains', - 'variable_selector': ['start', 'contains'], - 'value': 'ab' + "comparison_operator": "not contains", + "variable_selector": ["start", "not_contains"], + "value": "ab", }, + {"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"}, + {"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"}, + {"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"}, + {"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"}, + {"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"}, + {"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"}, + {"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"}, + {"comparison_operator": "≠", "variable_selector": ["start", "not_equals"], "value": "22"}, + {"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"}, + {"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"}, { - 'comparison_operator': 'not contains', - 'variable_selector': ['start', 'not_contains'], - 'value': 'ab' + "comparison_operator": "≥", + "variable_selector": ["start", "greater_than_or_equal"], + "value": "22", }, - { - 'comparison_operator': 'start with', - 'variable_selector': ['start', 'start_with'], - 'value': 'ab' - }, - { - 'comparison_operator': 'end with', - 'variable_selector': ['start', 'end_with'], - 'value': 'ab' - }, - { - 'comparison_operator': 'is', - 'variable_selector': ['start', 'is'], - 'value': 'ab' - }, - { - 'comparison_operator': 'is not', - 'variable_selector': ['start', 'is_not'], - 'value': 'ab' - }, - { - 'comparison_operator': 'empty', - 'variable_selector': ['start', 'empty'], - 'value': 'ab' - }, - { - 'comparison_operator': 'not empty', - 'variable_selector': ['start', 'not_empty'], - 'value': 'ab' - }, - { - 'comparison_operator': '=', - 'variable_selector': ['start', 'equals'], - 'value': '22' - }, - { - 'comparison_operator': '≠', - 'variable_selector': ['start', 'not_equals'], - 'value': '22' - }, - { - 'comparison_operator': '>', - 'variable_selector': ['start', 'greater_than'], - 'value': '22' - }, - { - 'comparison_operator': '<', - 'variable_selector': ['start', 'less_than'], - 'value': '22' - }, - { - 'comparison_operator': '≥', - 'variable_selector': ['start', 'greater_than_or_equal'], - 'value': '22' - }, - { - 'comparison_operator': '≤', - 'variable_selector': ['start', 'less_than_or_equal'], - 'value': '22' - }, - { - 'comparison_operator': 'null', - 'variable_selector': ['start', 'null'] - }, - { - 'comparison_operator': 'not null', - 'variable_selector': ['start', 'not_null'] - }, - ] - } - } + {"comparison_operator": "≤", "variable_selector": ["start", "less_than_or_equal"], "value": "22"}, + {"comparison_operator": "null", "variable_selector": ["start", "null"]}, + {"comparison_operator": "not null", "variable_selector": ["start", "not_null"]}, + ], + }, + }, ) # construct variable pool - pool = VariablePool(system_variables={ - SystemVariable.FILES: [], - SystemVariable.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) - pool.add(['start', 'array_contains'], ['ab', 'def']) - pool.add(['start', 'array_not_contains'], ['ac', 'def']) - pool.add(['start', 'contains'], 'cabcde') - pool.add(['start', 'not_contains'], 'zacde') - pool.add(['start', 'start_with'], 'abc') - pool.add(['start', 'end_with'], 'zzab') - pool.add(['start', 'is'], 'ab') - pool.add(['start', 'is_not'], 'aab') - pool.add(['start', 'empty'], '') - pool.add(['start', 'not_empty'], 'aaa') - pool.add(['start', 'equals'], 22) - pool.add(['start', 'not_equals'], 23) - pool.add(['start', 'greater_than'], 23) - pool.add(['start', 'less_than'], 21) - pool.add(['start', 'greater_than_or_equal'], 22) - pool.add(['start', 'less_than_or_equal'], 21) - pool.add(['start', 'not_null'], '1212') + pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + ) + pool.add(["start", "array_contains"], ["ab", "def"]) + pool.add(["start", "array_not_contains"], ["ac", "def"]) + pool.add(["start", "contains"], "cabcde") + pool.add(["start", "not_contains"], "zacde") + pool.add(["start", "start_with"], "abc") + pool.add(["start", "end_with"], "zzab") + pool.add(["start", "is"], "ab") + pool.add(["start", "is_not"], "aab") + pool.add(["start", "empty"], "") + pool.add(["start", "not_empty"], "aaa") + pool.add(["start", "equals"], 22) + pool.add(["start", "not_equals"], 23) + pool.add(["start", "greater_than"], 23) + pool.add(["start", "less_than"], 21) + pool.add(["start", "greater_than_or_equal"], 22) + pool.add(["start", "less_than_or_equal"], 21) + pool.add(["start", "not_null"], "1212") # Mock db.session.close() db.session.close = MagicMock() @@ -147,46 +94,47 @@ def test_execute_if_else_result_true(): result = node._run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs['result'] is True + assert result.outputs["result"] is True def test_execute_if_else_result_false(): node = IfElseNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, config={ - 'id': 'if-else', - 'data': { - 'title': '123', - 'type': 'if-else', - 'logical_operator': 'or', - 'conditions': [ + "id": "if-else", + "data": { + "title": "123", + "type": "if-else", + "logical_operator": "or", + "conditions": [ { - 'comparison_operator': 'contains', - 'variable_selector': ['start', 'array_contains'], - 'value': 'ab' + "comparison_operator": "contains", + "variable_selector": ["start", "array_contains"], + "value": "ab", }, { - 'comparison_operator': 'not contains', - 'variable_selector': ['start', 'array_not_contains'], - 'value': 'ab' - } - ] - } - } + "comparison_operator": "not contains", + "variable_selector": ["start", "array_not_contains"], + "value": "ab", + }, + ], + }, + }, ) # construct variable pool - pool = VariablePool(system_variables={ - SystemVariable.FILES: [], - SystemVariable.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) - pool.add(['start', 'array_contains'], ['1ab', 'def']) - pool.add(['start', 'array_not_contains'], ['ab', 'def']) + pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + ) + pool.add(["start", "array_contains"], ["1ab", "def"]) + pool.add(["start", "array_not_contains"], ["ab", "def"]) # Mock db.session.close() db.session.close = MagicMock() @@ -195,4 +143,4 @@ def test_execute_if_else_result_false(): result = node._run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs['result'] is False + assert result.outputs["result"] is False diff --git a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py index 0b37d06fc0..e26c7df642 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py @@ -4,45 +4,45 @@ from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom from core.app.segments import ArrayStringVariable, StringVariable from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode -DEFAULT_NODE_ID = 'node_id' +DEFAULT_NODE_ID = "node_id" def test_overwrite_string_variable(): conversation_variable = StringVariable( id=str(uuid4()), - name='test_conversation_variable', - value='the first value', + name="test_conversation_variable", + value="the first value", ) input_variable = StringVariable( id=str(uuid4()), - name='test_string_variable', - value='the second value', + name="test_string_variable", + value="the second value", ) node = VariableAssignerNode( - tenant_id='tenant_id', - app_id='app_id', - workflow_id='workflow_id', - user_id='user_id', + tenant_id="tenant_id", + app_id="app_id", + workflow_id="workflow_id", + user_id="user_id", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, config={ - 'id': 'node_id', - 'data': { - 'assigned_variable_selector': ['conversation', conversation_variable.name], - 'write_mode': WriteMode.OVER_WRITE.value, - 'input_variable_selector': [DEFAULT_NODE_ID, input_variable.name], + "id": "node_id", + "data": { + "assigned_variable_selector": ["conversation", conversation_variable.name], + "write_mode": WriteMode.OVER_WRITE.value, + "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], }, }, ) variable_pool = VariablePool( - system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'}, + system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -52,48 +52,48 @@ def test_overwrite_string_variable(): input_variable, ) - with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run: + with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run: node.run(variable_pool) mock_run.assert_called_once() - got = variable_pool.get(['conversation', conversation_variable.name]) + got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None - assert got.value == 'the second value' - assert got.to_object() == 'the second value' + assert got.value == "the second value" + assert got.to_object() == "the second value" def test_append_variable_to_array(): conversation_variable = ArrayStringVariable( id=str(uuid4()), - name='test_conversation_variable', - value=['the first value'], + name="test_conversation_variable", + value=["the first value"], ) input_variable = StringVariable( id=str(uuid4()), - name='test_string_variable', - value='the second value', + name="test_string_variable", + value="the second value", ) node = VariableAssignerNode( - tenant_id='tenant_id', - app_id='app_id', - workflow_id='workflow_id', - user_id='user_id', + tenant_id="tenant_id", + app_id="app_id", + workflow_id="workflow_id", + user_id="user_id", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, config={ - 'id': 'node_id', - 'data': { - 'assigned_variable_selector': ['conversation', conversation_variable.name], - 'write_mode': WriteMode.APPEND.value, - 'input_variable_selector': [DEFAULT_NODE_ID, input_variable.name], + "id": "node_id", + "data": { + "assigned_variable_selector": ["conversation", conversation_variable.name], + "write_mode": WriteMode.APPEND.value, + "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], }, }, ) variable_pool = VariablePool( - system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'}, + system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -103,41 +103,41 @@ def test_append_variable_to_array(): input_variable, ) - with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run: + with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run: node.run(variable_pool) mock_run.assert_called_once() - got = variable_pool.get(['conversation', conversation_variable.name]) + got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None - assert got.to_object() == ['the first value', 'the second value'] + assert got.to_object() == ["the first value", "the second value"] def test_clear_array(): conversation_variable = ArrayStringVariable( id=str(uuid4()), - name='test_conversation_variable', - value=['the first value'], + name="test_conversation_variable", + value=["the first value"], ) node = VariableAssignerNode( - tenant_id='tenant_id', - app_id='app_id', - workflow_id='workflow_id', - user_id='user_id', + tenant_id="tenant_id", + app_id="app_id", + workflow_id="workflow_id", + user_id="user_id", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, config={ - 'id': 'node_id', - 'data': { - 'assigned_variable_selector': ['conversation', conversation_variable.name], - 'write_mode': WriteMode.CLEAR.value, - 'input_variable_selector': [], + "id": "node_id", + "data": { + "assigned_variable_selector": ["conversation", conversation_variable.name], + "write_mode": WriteMode.CLEAR.value, + "input_variable_selector": [], }, }, ) variable_pool = VariablePool( - system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'}, + system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -145,6 +145,6 @@ def test_clear_array(): node.run(variable_pool) - got = variable_pool.get(['conversation', conversation_variable.name]) + got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None assert got.to_object() == [] diff --git a/api/tests/unit_tests/libs/test_pandas.py b/api/tests/unit_tests/libs/test_pandas.py index bbc372ed61..21c2f0781d 100644 --- a/api/tests/unit_tests/libs/test_pandas.py +++ b/api/tests/unit_tests/libs/test_pandas.py @@ -3,50 +3,46 @@ import pandas as pd def test_pandas_csv(tmp_path, monkeypatch): monkeypatch.chdir(tmp_path) - data = {'col1': [1, 2.2, -3.3, 4.0, 5], - 'col2': ['A', 'B', 'C', 'D', 'E']} + data = {"col1": [1, 2.2, -3.3, 4.0, 5], "col2": ["A", "B", "C", "D", "E"]} df1 = pd.DataFrame(data) # write to csv file - csv_file_path = tmp_path.joinpath('example.csv') + csv_file_path = tmp_path.joinpath("example.csv") df1.to_csv(csv_file_path, index=False) # read from csv file - df2 = pd.read_csv(csv_file_path, on_bad_lines='skip') - assert df2[df2.columns[0]].to_list() == data['col1'] - assert df2[df2.columns[1]].to_list() == data['col2'] + df2 = pd.read_csv(csv_file_path, on_bad_lines="skip") + assert df2[df2.columns[0]].to_list() == data["col1"] + assert df2[df2.columns[1]].to_list() == data["col2"] def test_pandas_xlsx(tmp_path, monkeypatch): monkeypatch.chdir(tmp_path) - data = {'col1': [1, 2.2, -3.3, 4.0, 5], - 'col2': ['A', 'B', 'C', 'D', 'E']} + data = {"col1": [1, 2.2, -3.3, 4.0, 5], "col2": ["A", "B", "C", "D", "E"]} df1 = pd.DataFrame(data) # write to xlsx file - xlsx_file_path = tmp_path.joinpath('example.xlsx') + xlsx_file_path = tmp_path.joinpath("example.xlsx") df1.to_excel(xlsx_file_path, index=False) # read from xlsx file df2 = pd.read_excel(xlsx_file_path) - assert df2[df2.columns[0]].to_list() == data['col1'] - assert df2[df2.columns[1]].to_list() == data['col2'] + assert df2[df2.columns[0]].to_list() == data["col1"] + assert df2[df2.columns[1]].to_list() == data["col2"] def test_pandas_xlsx_with_sheets(tmp_path, monkeypatch): monkeypatch.chdir(tmp_path) - data1 = {'col1': [1, 2, 3, 4, 5], - 'col2': ['A', 'B', 'C', 'D', 'E']} + data1 = {"col1": [1, 2, 3, 4, 5], "col2": ["A", "B", "C", "D", "E"]} df1 = pd.DataFrame(data1) - data2 = {'col1': [6, 7, 8, 9, 10], - 'col2': ['F', 'G', 'H', 'I', 'J']} + data2 = {"col1": [6, 7, 8, 9, 10], "col2": ["F", "G", "H", "I", "J"]} df2 = pd.DataFrame(data2) # write to xlsx file with sheets - xlsx_file_path = tmp_path.joinpath('example_with_sheets.xlsx') - sheet1 = 'Sheet1' - sheet2 = 'Sheet2' + xlsx_file_path = tmp_path.joinpath("example_with_sheets.xlsx") + sheet1 = "Sheet1" + sheet2 = "Sheet2" with pd.ExcelWriter(xlsx_file_path) as excel_writer: df1.to_excel(excel_writer, sheet_name=sheet1, index=False) df2.to_excel(excel_writer, sheet_name=sheet2, index=False) @@ -54,9 +50,9 @@ def test_pandas_xlsx_with_sheets(tmp_path, monkeypatch): # read from xlsx file with sheets with pd.ExcelFile(xlsx_file_path) as excel_file: df1 = pd.read_excel(excel_file, sheet_name=sheet1) - assert df1[df1.columns[0]].to_list() == data1['col1'] - assert df1[df1.columns[1]].to_list() == data1['col2'] + assert df1[df1.columns[0]].to_list() == data1["col1"] + assert df1[df1.columns[1]].to_list() == data1["col2"] df2 = pd.read_excel(excel_file, sheet_name=sheet2) - assert df2[df2.columns[0]].to_list() == data2['col1'] - assert df2[df2.columns[1]].to_list() == data2['col2'] + assert df2[df2.columns[0]].to_list() == data2["col1"] + assert df2[df2.columns[1]].to_list() == data2["col2"] diff --git a/api/tests/unit_tests/libs/test_rsa.py b/api/tests/unit_tests/libs/test_rsa.py index a979b77d70..2dc51252f0 100644 --- a/api/tests/unit_tests/libs/test_rsa.py +++ b/api/tests/unit_tests/libs/test_rsa.py @@ -15,7 +15,7 @@ def test_gmpy2_pkcs10aep_cipher() -> None: private_rsa_key = RSA.import_key(private_key) private_cipher_rsa = gmpy2_pkcs10aep_cipher.new(private_rsa_key) - raw_text = 'raw_text' + raw_text = "raw_text" raw_text_bytes = raw_text.encode() # RSA encryption by public key and decryption by private key diff --git a/api/tests/unit_tests/libs/test_yarl.py b/api/tests/unit_tests/libs/test_yarl.py index 75a5344126..b9aee4af5f 100644 --- a/api/tests/unit_tests/libs/test_yarl.py +++ b/api/tests/unit_tests/libs/test_yarl.py @@ -3,21 +3,21 @@ from yarl import URL def test_yarl_urls(): - expected_1 = 'https://dify.ai/api' - assert str(URL('https://dify.ai') / 'api') == expected_1 - assert str(URL('https://dify.ai/') / 'api') == expected_1 + expected_1 = "https://dify.ai/api" + assert str(URL("https://dify.ai") / "api") == expected_1 + assert str(URL("https://dify.ai/") / "api") == expected_1 - expected_2 = 'http://dify.ai:12345/api' - assert str(URL('http://dify.ai:12345') / 'api') == expected_2 - assert str(URL('http://dify.ai:12345/') / 'api') == expected_2 + expected_2 = "http://dify.ai:12345/api" + assert str(URL("http://dify.ai:12345") / "api") == expected_2 + assert str(URL("http://dify.ai:12345/") / "api") == expected_2 - expected_3 = 'https://dify.ai/api/v1' - assert str(URL('https://dify.ai') / 'api' / 'v1') == expected_3 - assert str(URL('https://dify.ai') / 'api/v1') == expected_3 - assert str(URL('https://dify.ai/') / 'api/v1') == expected_3 - assert str(URL('https://dify.ai/api') / 'v1') == expected_3 - assert str(URL('https://dify.ai/api/') / 'v1') == expected_3 + expected_3 = "https://dify.ai/api/v1" + assert str(URL("https://dify.ai") / "api" / "v1") == expected_3 + assert str(URL("https://dify.ai") / "api/v1") == expected_3 + assert str(URL("https://dify.ai/") / "api/v1") == expected_3 + assert str(URL("https://dify.ai/api") / "v1") == expected_3 + assert str(URL("https://dify.ai/api/") / "v1") == expected_3 with pytest.raises(ValueError) as e1: - str(URL('https://dify.ai') / '/api') + str(URL("https://dify.ai") / "/api") assert str(e1.value) == "Appending path '/api' starting from slash is forbidden" diff --git a/api/tests/unit_tests/models/test_account.py b/api/tests/unit_tests/models/test_account.py index 006b99fb7d..026912ffbe 100644 --- a/api/tests/unit_tests/models/test_account.py +++ b/api/tests/unit_tests/models/test_account.py @@ -2,13 +2,13 @@ from models.account import TenantAccountRole def test_account_is_privileged_role() -> None: - assert TenantAccountRole.ADMIN == 'admin' - assert TenantAccountRole.OWNER == 'owner' - assert TenantAccountRole.EDITOR == 'editor' - assert TenantAccountRole.NORMAL == 'normal' + assert TenantAccountRole.ADMIN == "admin" + assert TenantAccountRole.OWNER == "owner" + assert TenantAccountRole.EDITOR == "editor" + assert TenantAccountRole.NORMAL == "normal" assert TenantAccountRole.is_privileged_role(TenantAccountRole.ADMIN) assert TenantAccountRole.is_privileged_role(TenantAccountRole.OWNER) assert not TenantAccountRole.is_privileged_role(TenantAccountRole.NORMAL) assert not TenantAccountRole.is_privileged_role(TenantAccountRole.EDITOR) - assert not TenantAccountRole.is_privileged_role('') + assert not TenantAccountRole.is_privileged_role("") diff --git a/api/tests/unit_tests/models/test_conversation_variable.py b/api/tests/unit_tests/models/test_conversation_variable.py index 9e16010d7e..7968347dec 100644 --- a/api/tests/unit_tests/models/test_conversation_variable.py +++ b/api/tests/unit_tests/models/test_conversation_variable.py @@ -7,19 +7,19 @@ from models import ConversationVariable def test_from_variable_and_to_variable(): variable = factory.build_variable_from_mapping( { - 'id': str(uuid4()), - 'name': 'name', - 'value_type': SegmentType.OBJECT, - 'value': { - 'key': { - 'key': 'value', + "id": str(uuid4()), + "name": "name", + "value_type": SegmentType.OBJECT, + "value": { + "key": { + "key": "value", } }, } ) conversation_variable = ConversationVariable.from_variable( - app_id='app_id', conversation_id='conversation_id', variable=variable + app_id="app_id", conversation_id="conversation_id", variable=variable ) assert conversation_variable.to_variable() == variable diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py index facea34b5b..40483d7e3a 100644 --- a/api/tests/unit_tests/models/test_workflow.py +++ b/api/tests/unit_tests/models/test_workflow.py @@ -8,20 +8,30 @@ from models.workflow import Workflow def test_environment_variables(): - contexts.tenant_id.set('tenant_id') + contexts.tenant_id.set("tenant_id") # Create a Workflow instance - workflow = Workflow() + workflow = Workflow( + tenant_id="tenant_id", + app_id="app_id", + type="workflow", + version="draft", + graph="{}", + features="{}", + created_by="account_id", + environment_variables=[], + conversation_variables=[], + ) # Create some EnvironmentVariable instances - variable1 = StringVariable.model_validate({'name': 'var1', 'value': 'value1', 'id': str(uuid4())}) - variable2 = IntegerVariable.model_validate({'name': 'var2', 'value': 123, 'id': str(uuid4())}) - variable3 = SecretVariable.model_validate({'name': 'var3', 'value': 'secret', 'id': str(uuid4())}) - variable4 = FloatVariable.model_validate({'name': 'var4', 'value': 3.14, 'id': str(uuid4())}) + variable1 = StringVariable.model_validate({"name": "var1", "value": "value1", "id": str(uuid4())}) + variable2 = IntegerVariable.model_validate({"name": "var2", "value": 123, "id": str(uuid4())}) + variable3 = SecretVariable.model_validate({"name": "var3", "value": "secret", "id": str(uuid4())}) + variable4 = FloatVariable.model_validate({"name": "var4", "value": 3.14, "id": str(uuid4())}) with ( - mock.patch('core.helper.encrypter.encrypt_token', return_value='encrypted_token'), - mock.patch('core.helper.encrypter.decrypt_token', return_value='secret'), + mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), + mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"), ): # Set the environment_variables property of the Workflow instance variables = [variable1, variable2, variable3, variable4] @@ -32,20 +42,30 @@ def test_environment_variables(): def test_update_environment_variables(): - contexts.tenant_id.set('tenant_id') + contexts.tenant_id.set("tenant_id") # Create a Workflow instance - workflow = Workflow() + workflow = Workflow( + tenant_id="tenant_id", + app_id="app_id", + type="workflow", + version="draft", + graph="{}", + features="{}", + created_by="account_id", + environment_variables=[], + conversation_variables=[], + ) # Create some EnvironmentVariable instances - variable1 = StringVariable.model_validate({'name': 'var1', 'value': 'value1', 'id': str(uuid4())}) - variable2 = IntegerVariable.model_validate({'name': 'var2', 'value': 123, 'id': str(uuid4())}) - variable3 = SecretVariable.model_validate({'name': 'var3', 'value': 'secret', 'id': str(uuid4())}) - variable4 = FloatVariable.model_validate({'name': 'var4', 'value': 3.14, 'id': str(uuid4())}) + variable1 = StringVariable.model_validate({"name": "var1", "value": "value1", "id": str(uuid4())}) + variable2 = IntegerVariable.model_validate({"name": "var2", "value": 123, "id": str(uuid4())}) + variable3 = SecretVariable.model_validate({"name": "var3", "value": "secret", "id": str(uuid4())}) + variable4 = FloatVariable.model_validate({"name": "var4", "value": 3.14, "id": str(uuid4())}) with ( - mock.patch('core.helper.encrypter.encrypt_token', return_value='encrypted_token'), - mock.patch('core.helper.encrypter.decrypt_token', return_value='secret'), + mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), + mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"), ): variables = [variable1, variable2, variable3, variable4] @@ -56,40 +76,48 @@ def test_update_environment_variables(): # Update the name of variable3 and keep the value as it is variables[2] = variable3.model_copy( update={ - 'name': 'new name', - 'value': HIDDEN_VALUE, + "name": "new name", + "value": HIDDEN_VALUE, } ) workflow.environment_variables = variables - assert workflow.environment_variables[2].name == 'new name' + assert workflow.environment_variables[2].name == "new name" assert workflow.environment_variables[2].value == variable3.value def test_to_dict(): - contexts.tenant_id.set('tenant_id') + contexts.tenant_id.set("tenant_id") # Create a Workflow instance - workflow = Workflow() - workflow.graph = '{}' - workflow.features = '{}' + workflow = Workflow( + tenant_id="tenant_id", + app_id="app_id", + type="workflow", + version="draft", + graph="{}", + features="{}", + created_by="account_id", + environment_variables=[], + conversation_variables=[], + ) # Create some EnvironmentVariable instances with ( - mock.patch('core.helper.encrypter.encrypt_token', return_value='encrypted_token'), - mock.patch('core.helper.encrypter.decrypt_token', return_value='secret'), + mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), + mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"), ): # Set the environment_variables property of the Workflow instance workflow.environment_variables = [ - SecretVariable.model_validate({'name': 'secret', 'value': 'secret', 'id': str(uuid4())}), - StringVariable.model_validate({'name': 'text', 'value': 'text', 'id': str(uuid4())}), + SecretVariable.model_validate({"name": "secret", "value": "secret", "id": str(uuid4())}), + StringVariable.model_validate({"name": "text", "value": "text", "id": str(uuid4())}), ] workflow_dict = workflow.to_dict() - assert workflow_dict['environment_variables'][0]['value'] == '' - assert workflow_dict['environment_variables'][1]['value'] == 'text' + assert workflow_dict["environment_variables"][0]["value"] == "" + assert workflow_dict["environment_variables"][1]["value"] == "text" workflow_dict = workflow.to_dict(include_secret=True) - assert workflow_dict['environment_variables'][0]['value'] == 'secret' - assert workflow_dict['environment_variables'][1]['value'] == 'text' + assert workflow_dict["environment_variables"][0]["value"] == "secret" + assert workflow_dict["environment_variables"][1]["value"] == "text" diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py index f589cd2097..805d92dfc9 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -14,6 +14,7 @@ from core.app.app_config.entities import ( ModelConfigEntity, PromptTemplateEntity, VariableEntity, + VariableEntityType, ) from core.helper import encrypter from core.model_runtime.entities.llm_entities import LLMMode @@ -25,23 +26,24 @@ from services.workflow.workflow_converter import WorkflowConverter @pytest.fixture def default_variables(): - return [ + value = [ VariableEntity( variable="text_input", label="text-input", - type=VariableEntity.Type.TEXT_INPUT + type=VariableEntityType.TEXT_INPUT, ), VariableEntity( variable="paragraph", label="paragraph", - type=VariableEntity.Type.PARAGRAPH + type=VariableEntityType.PARAGRAPH, ), VariableEntity( variable="select", label="select", - type=VariableEntity.Type.SELECT - ) + type=VariableEntityType.SELECT, + ), ] + return value def test__convert_to_start_node(default_variables): @@ -81,18 +83,12 @@ def test__convert_to_http_request_node_for_chatbot(default_variables): external_data_variables = [ ExternalDataVariableEntity( - variable="external_variable", - type="api", - config={ - "api_based_extension_id": api_based_extension_id - } + variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id} ) ] nodes, _ = workflow_converter._convert_to_http_request_node( - app_model=app_model, - variables=default_variables, - external_data_variables=external_data_variables + app_model=app_model, variables=default_variables, external_data_variables=external_data_variables ) assert len(nodes) == 2 @@ -103,10 +99,7 @@ def test__convert_to_http_request_node_for_chatbot(default_variables): assert http_request_node["data"]["method"] == "post" assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint assert http_request_node["data"]["authorization"]["type"] == "api-key" - assert http_request_node["data"]["authorization"]["config"] == { - "type": "bearer", - "api_key": "api_key" - } + assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"} assert http_request_node["data"]["body"]["type"] == "json" body_data = http_request_node["data"]["body"]["data"] @@ -151,18 +144,12 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables): external_data_variables = [ ExternalDataVariableEntity( - variable="external_variable", - type="api", - config={ - "api_based_extension_id": api_based_extension_id - } + variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id} ) ] nodes, _ = workflow_converter._convert_to_http_request_node( - app_model=app_model, - variables=default_variables, - external_data_variables=external_data_variables + app_model=app_model, variables=default_variables, external_data_variables=external_data_variables ) assert len(nodes) == 2 @@ -173,10 +160,7 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables): assert http_request_node["data"]["method"] == "post" assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint assert http_request_node["data"]["authorization"]["type"] == "api-key" - assert http_request_node["data"]["authorization"]["config"] == { - "type": "bearer", - "api_key": "api_key" - } + assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"} assert http_request_node["data"]["body"]["type"] == "json" body_data = http_request_node["data"]["body"]["data"] @@ -205,37 +189,25 @@ def test__convert_to_knowledge_retrieval_node_for_chatbot(): retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, top_k=5, score_threshold=0.8, - reranking_model={ - 'reranking_provider_name': 'cohere', - 'reranking_model_name': 'rerank-english-v2.0' - }, - reranking_enabled=True - ) + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"}, + reranking_enabled=True, + ), ) - model_config = ModelConfigEntity( - provider='openai', - model='gpt-4', - mode='chat', - parameters={}, - stop=[] - ) + model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[]) node = WorkflowConverter()._convert_to_knowledge_retrieval_node( - new_app_mode=new_app_mode, - dataset_config=dataset_config, - model_config=model_config + new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config ) assert node["data"]["type"] == "knowledge-retrieval" assert node["data"]["query_variable_selector"] == ["sys", "query"] assert node["data"]["dataset_ids"] == dataset_config.dataset_ids - assert (node["data"]["retrieval_mode"] - == dataset_config.retrieve_config.retrieve_strategy.value) + assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value assert node["data"]["multiple_retrieval_config"] == { "top_k": dataset_config.retrieve_config.top_k, "score_threshold": dataset_config.retrieve_config.score_threshold, - "reranking_model": dataset_config.retrieve_config.reranking_model + "reranking_model": dataset_config.retrieve_config.reranking_model, } @@ -249,37 +221,25 @@ def test__convert_to_knowledge_retrieval_node_for_workflow_app(): retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, top_k=5, score_threshold=0.8, - reranking_model={ - 'reranking_provider_name': 'cohere', - 'reranking_model_name': 'rerank-english-v2.0' - }, - reranking_enabled=True - ) + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"}, + reranking_enabled=True, + ), ) - model_config = ModelConfigEntity( - provider='openai', - model='gpt-4', - mode='chat', - parameters={}, - stop=[] - ) + model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[]) node = WorkflowConverter()._convert_to_knowledge_retrieval_node( - new_app_mode=new_app_mode, - dataset_config=dataset_config, - model_config=model_config + new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config ) assert node["data"]["type"] == "knowledge-retrieval" assert node["data"]["query_variable_selector"] == ["start", dataset_config.retrieve_config.query_variable] assert node["data"]["dataset_ids"] == dataset_config.dataset_ids - assert (node["data"]["retrieval_mode"] - == dataset_config.retrieve_config.retrieve_strategy.value) + assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value assert node["data"]["multiple_retrieval_config"] == { "top_k": dataset_config.retrieve_config.top_k, "score_threshold": dataset_config.retrieve_config.score_threshold, - "reranking_model": dataset_config.retrieve_config.reranking_model + "reranking_model": dataset_config.retrieve_config.reranking_model, } @@ -291,14 +251,12 @@ def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables): workflow_converter = WorkflowConverter() start_node = workflow_converter._convert_to_start_node(default_variables) graph = { - "nodes": [ - start_node - ], - "edges": [] # no need + "nodes": [start_node], + "edges": [], # no need } model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = 'openai' + model_config_mock.provider = "openai" model_config_mock.model = model model_config_mock.mode = model_mode.value model_config_mock.parameters = {} @@ -306,7 +264,7 @@ def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables): prompt_template = PromptTemplateEntity( prompt_type=PromptTemplateEntity.PromptType.SIMPLE, - simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}." + simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.", ) llm_node = workflow_converter._convert_to_llm_node( @@ -314,17 +272,17 @@ def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables): new_app_mode=new_app_mode, model_config=model_config_mock, graph=graph, - prompt_template=prompt_template + prompt_template=prompt_template, ) assert llm_node["data"]["type"] == "llm" - assert llm_node["data"]["model"]['name'] == model - assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["model"]["name"] == model + assert llm_node["data"]["model"]["mode"] == model_mode.value template = prompt_template.simple_prompt_template for v in default_variables: - template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}') - assert llm_node["data"]["prompt_template"][0]['text'] == template + '\n' - assert llm_node["data"]['context']['enabled'] is False + template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}") + assert llm_node["data"]["prompt_template"][0]["text"] == template + "\n" + assert llm_node["data"]["context"]["enabled"] is False def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variables): @@ -335,14 +293,12 @@ def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variab workflow_converter = WorkflowConverter() start_node = workflow_converter._convert_to_start_node(default_variables) graph = { - "nodes": [ - start_node - ], - "edges": [] # no need + "nodes": [start_node], + "edges": [], # no need } model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = 'openai' + model_config_mock.provider = "openai" model_config_mock.model = model model_config_mock.mode = model_mode.value model_config_mock.parameters = {} @@ -350,7 +306,7 @@ def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variab prompt_template = PromptTemplateEntity( prompt_type=PromptTemplateEntity.PromptType.SIMPLE, - simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}." + simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.", ) llm_node = workflow_converter._convert_to_llm_node( @@ -358,17 +314,17 @@ def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variab new_app_mode=new_app_mode, model_config=model_config_mock, graph=graph, - prompt_template=prompt_template + prompt_template=prompt_template, ) assert llm_node["data"]["type"] == "llm" - assert llm_node["data"]["model"]['name'] == model - assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["model"]["name"] == model + assert llm_node["data"]["model"]["mode"] == model_mode.value template = prompt_template.simple_prompt_template for v in default_variables: - template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}') - assert llm_node["data"]["prompt_template"]['text'] == template + '\n' - assert llm_node["data"]['context']['enabled'] is False + template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}") + assert llm_node["data"]["prompt_template"]["text"] == template + "\n" + assert llm_node["data"]["context"]["enabled"] is False def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables): @@ -379,14 +335,12 @@ def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables) workflow_converter = WorkflowConverter() start_node = workflow_converter._convert_to_start_node(default_variables) graph = { - "nodes": [ - start_node - ], - "edges": [] # no need + "nodes": [start_node], + "edges": [], # no need } model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = 'openai' + model_config_mock.provider = "openai" model_config_mock.model = model model_config_mock.mode = model_mode.value model_config_mock.parameters = {} @@ -394,12 +348,16 @@ def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables) prompt_template = PromptTemplateEntity( prompt_type=PromptTemplateEntity.PromptType.ADVANCED, - advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(messages=[ - AdvancedChatMessageEntity(text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", - role=PromptMessageRole.SYSTEM), - AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER), - AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT), - ]) + advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity( + messages=[ + AdvancedChatMessageEntity( + text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", + role=PromptMessageRole.SYSTEM, + ), + AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER), + AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT), + ] + ), ) llm_node = workflow_converter._convert_to_llm_node( @@ -407,18 +365,18 @@ def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables) new_app_mode=new_app_mode, model_config=model_config_mock, graph=graph, - prompt_template=prompt_template + prompt_template=prompt_template, ) assert llm_node["data"]["type"] == "llm" - assert llm_node["data"]["model"]['name'] == model - assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["model"]["name"] == model + assert llm_node["data"]["model"]["mode"] == model_mode.value assert isinstance(llm_node["data"]["prompt_template"], list) assert len(llm_node["data"]["prompt_template"]) == len(prompt_template.advanced_chat_prompt_template.messages) template = prompt_template.advanced_chat_prompt_template.messages[0].text for v in default_variables: - template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}') - assert llm_node["data"]["prompt_template"][0]['text'] == template + template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}") + assert llm_node["data"]["prompt_template"][0]["text"] == template def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_variables): @@ -429,14 +387,12 @@ def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_var workflow_converter = WorkflowConverter() start_node = workflow_converter._convert_to_start_node(default_variables) graph = { - "nodes": [ - start_node - ], - "edges": [] # no need + "nodes": [start_node], + "edges": [], # no need } model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = 'openai' + model_config_mock.provider = "openai" model_config_mock.model = model model_config_mock.mode = model_mode.value model_config_mock.parameters = {} @@ -446,12 +402,9 @@ def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_var prompt_type=PromptTemplateEntity.PromptType.ADVANCED, advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity( prompt="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}\n\n" - "Human: hi\nAssistant: ", - role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity( - user="Human", - assistant="Assistant" - ) - ) + "Human: hi\nAssistant: ", + role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="Human", assistant="Assistant"), + ), ) llm_node = workflow_converter._convert_to_llm_node( @@ -459,14 +412,14 @@ def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_var new_app_mode=new_app_mode, model_config=model_config_mock, graph=graph, - prompt_template=prompt_template + prompt_template=prompt_template, ) assert llm_node["data"]["type"] == "llm" - assert llm_node["data"]["model"]['name'] == model - assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["model"]["name"] == model + assert llm_node["data"]["model"]["mode"] == model_mode.value assert isinstance(llm_node["data"]["prompt_template"], dict) template = prompt_template.advanced_completion_prompt_template.prompt for v in default_variables: - template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}') - assert llm_node["data"]["prompt_template"]['text'] == template + template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}") + assert llm_node["data"]["prompt_template"]["text"] == template diff --git a/api/tests/unit_tests/utils/position_helper/test_position_helper.py b/api/tests/unit_tests/utils/position_helper/test_position_helper.py index 2237319904..29558a93c2 100644 --- a/api/tests/unit_tests/utils/position_helper/test_position_helper.py +++ b/api/tests/unit_tests/utils/position_helper/test_position_helper.py @@ -2,54 +2,122 @@ from textwrap import dedent import pytest -from core.helper.position_helper import get_position_map +from core.helper.position_helper import get_position_map, is_filtered, pin_position_map, sort_by_position_map @pytest.fixture def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str: monkeypatch.chdir(tmp_path) - tmp_path.joinpath("example_positions.yaml").write_text(dedent( - """\ + tmp_path.joinpath("example_positions.yaml").write_text( + dedent( + """\ - first - second # - commented - third - + - 9999999999999 - forth - """)) + """ + ) + ) return str(tmp_path) @pytest.fixture def prepare_empty_commented_positions_yaml(tmp_path, monkeypatch) -> str: monkeypatch.chdir(tmp_path) - tmp_path.joinpath("example_positions_all_commented.yaml").write_text(dedent( - """\ + tmp_path.joinpath("example_positions_all_commented.yaml").write_text( + dedent( + """\ # - commented1 # - commented2 - - - - - - """)) + - + - + + """ + ) + ) return str(tmp_path) def test_position_helper(prepare_example_positions_yaml): - position_map = get_position_map( - folder_path=prepare_example_positions_yaml, - file_name='example_positions.yaml') + position_map = get_position_map(folder_path=prepare_example_positions_yaml, file_name="example_positions.yaml") assert len(position_map) == 4 assert position_map == { - 'first': 0, - 'second': 1, - 'third': 2, - 'forth': 3, + "first": 0, + "second": 1, + "third": 2, + "forth": 3, } def test_position_helper_with_all_commented(prepare_empty_commented_positions_yaml): position_map = get_position_map( - folder_path=prepare_empty_commented_positions_yaml, - file_name='example_positions_all_commented.yaml') + folder_path=prepare_empty_commented_positions_yaml, file_name="example_positions_all_commented.yaml" + ) assert position_map == {} + + +def test_excluded_position_data(prepare_example_positions_yaml): + position_map = get_position_map(folder_path=prepare_example_positions_yaml, file_name="example_positions.yaml") + pin_list = ["forth", "first"] + include_set = set() + exclude_set = {"9999999999999"} + + position_map = pin_position_map(original_position_map=position_map, pin_list=pin_list) + + data = [ + "forth", + "first", + "second", + "third", + "9999999999999", + "extra1", + "extra2", + ] + + # filter out the data + data = [item for item in data if not is_filtered(include_set, exclude_set, item, lambda x: x)] + + # sort data by position map + sorted_data = sort_by_position_map( + position_map=position_map, + data=data, + name_func=lambda x: x, + ) + + # assert the result in the correct order + assert sorted_data == ["forth", "first", "second", "third", "extra1", "extra2"] + + +def test_included_position_data(prepare_example_positions_yaml): + position_map = get_position_map(folder_path=prepare_example_positions_yaml, file_name="example_positions.yaml") + pin_list = ["forth", "first"] + include_set = {"forth", "first"} + exclude_set = {} + + position_map = pin_position_map(original_position_map=position_map, pin_list=pin_list) + + data = [ + "forth", + "first", + "second", + "third", + "9999999999999", + "extra1", + "extra2", + ] + + # filter out the data + data = [item for item in data if not is_filtered(include_set, exclude_set, item, lambda x: x)] + + # sort data by position map + sorted_data = sort_by_position_map( + position_map=position_map, + data=data, + name_func=lambda x: x, + ) + + # assert the result in the correct order + assert sorted_data == ["forth", "first"] diff --git a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py index c0452b4e4d..95b93651d5 100644 --- a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py +++ b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py @@ -5,17 +5,18 @@ from yaml import YAMLError from core.tools.utils.yaml_utils import load_yaml_file -EXAMPLE_YAML_FILE = 'example_yaml.yaml' -INVALID_YAML_FILE = 'invalid_yaml.yaml' -NON_EXISTING_YAML_FILE = 'non_existing_file.yaml' +EXAMPLE_YAML_FILE = "example_yaml.yaml" +INVALID_YAML_FILE = "invalid_yaml.yaml" +NON_EXISTING_YAML_FILE = "non_existing_file.yaml" @pytest.fixture def prepare_example_yaml_file(tmp_path, monkeypatch) -> str: monkeypatch.chdir(tmp_path) file_path = tmp_path.joinpath(EXAMPLE_YAML_FILE) - file_path.write_text(dedent( - """\ + file_path.write_text( + dedent( + """\ address: city: Example City country: Example Country @@ -26,7 +27,9 @@ def prepare_example_yaml_file(tmp_path, monkeypatch) -> str: - Java - C++ empty_key: - """)) + """ + ) + ) return str(file_path) @@ -34,8 +37,9 @@ def prepare_example_yaml_file(tmp_path, monkeypatch) -> str: def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str: monkeypatch.chdir(tmp_path) file_path = tmp_path.joinpath(INVALID_YAML_FILE) - file_path.write_text(dedent( - """\ + file_path.write_text( + dedent( + """\ address: city: Example City country: Example Country @@ -45,13 +49,15 @@ def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str: - Python - Java - C++ - """)) + """ + ) + ) return str(file_path) def test_load_yaml_non_existing_file(): assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE) == {} - assert load_yaml_file(file_path='') == {} + assert load_yaml_file(file_path="") == {} with pytest.raises(FileNotFoundError): load_yaml_file(file_path=NON_EXISTING_YAML_FILE, ignore_error=False) @@ -60,12 +66,12 @@ def test_load_yaml_non_existing_file(): def test_load_valid_yaml_file(prepare_example_yaml_file): yaml_data = load_yaml_file(file_path=prepare_example_yaml_file) assert len(yaml_data) > 0 - assert yaml_data['age'] == 30 - assert yaml_data['gender'] == 'male' - assert yaml_data['address']['city'] == 'Example City' - assert set(yaml_data['languages']) == {'Python', 'Java', 'C++'} - assert yaml_data.get('empty_key') is None - assert yaml_data.get('non_existed_key') is None + assert yaml_data["age"] == 30 + assert yaml_data["gender"] == "male" + assert yaml_data["address"]["city"] == "Example City" + assert set(yaml_data["languages"]) == {"Python", "Java", "C++"} + assert yaml_data.get("empty_key") is None + assert yaml_data.get("non_existed_key") is None def test_load_invalid_yaml_file(prepare_invalid_yaml_file): diff --git a/docker-legacy/docker-compose.yaml b/docker-legacy/docker-compose.yaml index edefd129d5..548a1c004e 100644 --- a/docker-legacy/docker-compose.yaml +++ b/docker-legacy/docker-compose.yaml @@ -2,7 +2,7 @@ version: '3' services: # API service api: - image: langgenius/dify-api:0.7.1 + image: langgenius/dify-api:0.7.3 restart: always environment: # Startup mode, 'api' starts the API server. @@ -229,7 +229,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.7.1 + image: langgenius/dify-api:0.7.3 restart: always environment: CONSOLE_WEB_URL: '' @@ -400,7 +400,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.7.1 + image: langgenius/dify-web:0.7.3 restart: always environment: # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is diff --git a/docker/.env.example b/docker/.env.example index 03e1e4e50e..d8fa14f7c0 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -285,6 +285,8 @@ ALIYUN_OSS_SECRET_KEY=your-secret-key ALIYUN_OSS_ENDPOINT=https://oss-ap-southeast-1-internal.aliyuncs.com ALIYUN_OSS_REGION=ap-southeast-1 ALIYUN_OSS_AUTH_VERSION=v4 +# Don't start with '/'. OSS doesn't support leading slash in object names. +ALIYUN_OSS_PATH=your-path # Tencent COS Configuration # The name of the Tencent COS bucket to use for storing files. @@ -514,6 +516,8 @@ RESET_PASSWORD_TOKEN_EXPIRY_HOURS=24 CODE_EXECUTION_ENDPOINT=http://sandbox:8194 CODE_MAX_NUMBER=9223372036854775807 CODE_MIN_NUMBER=-9223372036854775808 +CODE_MAX_DEPTH=5 +CODE_MAX_PRECISION=20 CODE_MAX_STRING_LENGTH=80000 TEMPLATE_TRANSFORM_MAX_LENGTH=80000 CODE_MAX_STRING_ARRAY_LENGTH=30 @@ -701,3 +705,22 @@ COMPOSE_PROFILES=${VECTOR_STORE:-weaviate} # ------------------------------ EXPOSE_NGINX_PORT=80 EXPOSE_NGINX_SSL_PORT=443 + +# ---------------------------------------------------------------------------- +# ModelProvider & Tool Position Configuration +# Used to specify the model providers and tools that can be used in the app. +# ---------------------------------------------------------------------------- + +# Pin, include, and exclude tools +# Use comma-separated values with no spaces between items. +# Example: POSITION_TOOL_PINS=bing,google +POSITION_TOOL_PINS= +POSITION_TOOL_INCLUDES= +POSITION_TOOL_EXCLUDES= + +# Pin, include, and exclude model providers +# Use comma-separated values with no spaces between items. +# Example: POSITION_PROVIDER_PINS=openai,openllm +POSITION_PROVIDER_PINS= +POSITION_PROVIDER_INCLUDES= +POSITION_PROVIDER_EXCLUDES= \ No newline at end of file diff --git a/docker/certbot/README.md b/docker/certbot/README.md index 3fab2f4bb7..c6f73ae699 100644 --- a/docker/certbot/README.md +++ b/docker/certbot/README.md @@ -16,7 +16,7 @@ Use `docker-compose --profile certbot up` to use this features. CERTBOT_DOMAIN=your_domain.com CERTBOT_EMAIL=example@your_domain.com ``` - excecute command: + execute command: ```shell sudo docker network prune sudo docker-compose --profile certbot up --force-recreate -d @@ -30,7 +30,7 @@ Use `docker-compose --profile certbot up` to use this features. ```properties NGINX_HTTPS_ENABLED=true ``` - excecute command: + execute command: ```shell sudo docker-compose --profile certbot up -d --no-deps --force-recreate nginx ``` @@ -73,4 +73,4 @@ To use cert files dir `nginx/ssl` as before, simply launch containers WITHOUT `- ```shell sudo docker-compose up -d -``` \ No newline at end of file +``` diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 9a4c40448b..adaecc69ec 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -19,6 +19,11 @@ services: - ./volumes/db/data:/var/lib/postgresql/data ports: - "${EXPOSE_POSTGRES_PORT:-5432}:5432" + healthcheck: + test: [ "CMD", "pg_isready" ] + interval: 1s + timeout: 3s + retries: 30 # The redis cache. redis: @@ -31,6 +36,8 @@ services: command: redis-server --requirepass difyai123456 ports: - "${EXPOSE_REDIS_PORT:-6379}:6379" + healthcheck: + test: [ "CMD", "redis-cli", "ping" ] # The DifySandbox sandbox: diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index aea16e3817..9a532eed07 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -66,6 +66,7 @@ x-shared-env: &shared-api-worker-env ALIYUN_OSS_ENDPOINT: ${ALIYUN_OSS_ENDPOINT:-} ALIYUN_OSS_REGION: ${ALIYUN_OSS_REGION:-} ALIYUN_OSS_AUTH_VERSION: ${ALIYUN_OSS_AUTH_VERSION:-v4} + ALIYUN_OSS_PATHS: ${ALIYUN_OSS_PATH:-} TENCENT_COS_BUCKET_NAME: ${TENCENT_COS_BUCKET_NAME:-} TENCENT_COS_SECRET_KEY: ${TENCENT_COS_SECRET_KEY:-} TENCENT_COS_SECRET_ID: ${TENCENT_COS_SECRET_ID:-} @@ -177,6 +178,8 @@ x-shared-env: &shared-api-worker-env CODE_EXECUTION_API_KEY: ${SANDBOX_API_KEY:-dify-sandbox} CODE_MAX_NUMBER: ${CODE_MAX_NUMBER:-9223372036854775807} CODE_MIN_NUMBER: ${CODE_MIN_NUMBER:--9223372036854775808} + CODE_MAX_DEPTH: ${CODE_MAX_DEPTH:-5} + CODE_MAX_PRECISION: ${CODE_MAX_PRECISION:-20} CODE_MAX_STRING_LENGTH: ${CODE_MAX_STRING_LENGTH:-80000} TEMPLATE_TRANSFORM_MAX_LENGTH: ${TEMPLATE_TRANSFORM_MAX_LENGTH:-80000} CODE_MAX_STRING_ARRAY_LENGTH: ${CODE_MAX_STRING_ARRAY_LENGTH:-30} @@ -188,7 +191,7 @@ x-shared-env: &shared-api-worker-env services: # API service api: - image: langgenius/dify-api:0.7.1 + image: langgenius/dify-api:0.7.3 restart: always environment: # Use the shared environment variables. @@ -208,7 +211,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.7.1 + image: langgenius/dify-api:0.7.3 restart: always environment: # Use the shared environment variables. @@ -227,7 +230,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.7.1 + image: langgenius/dify-web:0.7.3 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -292,7 +295,7 @@ services: # ssrf_proxy server # for more information, please refer to - # https://docs.dify.ai/learn-more/faq/self-host-faq#id-18.-why-is-ssrf_proxy-needed + # https://docs.dify.ai/learn-more/faq/install-faq#id-18.-why-is-ssrf_proxy-needed ssrf_proxy: image: ubuntu/squid:latest restart: always diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx index 09569df8bf..8723420d84 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx @@ -15,13 +15,14 @@ import { } from '@remixicon/react' import { useTranslation } from 'react-i18next' import { useShallow } from 'zustand/react/shallow' +import { useContextSelector } from 'use-context-selector' import s from './style.module.css' import cn from '@/utils/classnames' import { useStore } from '@/app/components/app/store' import AppSideBar from '@/app/components/app-sidebar' import type { NavIcon } from '@/app/components/app-sidebar/navLink' -import { fetchAppDetail } from '@/service/apps' -import { useAppContext } from '@/context/app-context' +import { fetchAppDetail, fetchAppSSO } from '@/service/apps' +import AppContext, { useAppContext } from '@/context/app-context' import Loading from '@/app/components/base/loading' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' @@ -52,6 +53,7 @@ const AppDetailLayout: FC = (props) => { icon: NavIcon selectedIcon: NavIcon }>>([]) + const systemFeatures = useContextSelector(AppContext, state => state.systemFeatures) const getNavigations = useCallback((appId: string, isCurrentWorkspaceEditor: boolean, mode: string) => { const navs = [ @@ -114,14 +116,19 @@ const AppDetailLayout: FC = (props) => { router.replace(`/app/${appId}/configuration`) } else { - setAppDetail(res) + setAppDetail({ ...res, enable_sso: false }) setNavigation(getNavigations(appId, isCurrentWorkspaceEditor, res.mode)) + if (systemFeatures.enable_web_sso_switch_component) { + fetchAppSSO({ appId }).then((ssoRes) => { + setAppDetail({ ...res, enable_sso: ssoRes.enabled }) + }) + } } }).catch((e: any) => { if (e.status === 404) router.replace('/apps') }) - }, [appId, isCurrentWorkspaceEditor]) + }, [appId, isCurrentWorkspaceEditor, systemFeatures]) useUnmount(() => { setAppDetail() diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView.tsx index 5fa9a2e406..8f3ee510b8 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView.tsx @@ -2,22 +2,25 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' +import { useContext, useContextSelector } from 'use-context-selector' import AppCard from '@/app/components/app/overview/appCard' import Loading from '@/app/components/base/loading' import { ToastContext } from '@/app/components/base/toast' import { fetchAppDetail, + fetchAppSSO, + updateAppSSO, updateAppSiteAccessToken, updateAppSiteConfig, updateAppSiteStatus, } from '@/service/apps' -import type { App } from '@/types/app' +import type { App, AppSSO } from '@/types/app' import type { UpdateAppSiteCodeResponse } from '@/models/app' import { asyncRunSafe } from '@/utils' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import type { IAppCardProps } from '@/app/components/app/overview/appCard' import { useStore as useAppStore } from '@/app/components/app/store' +import AppContext from '@/context/app-context' export type ICardViewProps = { appId: string @@ -28,11 +31,20 @@ const CardView: FC = ({ appId }) => { const { notify } = useContext(ToastContext) const appDetail = useAppStore(state => state.appDetail) const setAppDetail = useAppStore(state => state.setAppDetail) + const systemFeatures = useContextSelector(AppContext, state => state.systemFeatures) const updateAppDetail = async () => { - fetchAppDetail({ url: '/apps', id: appId }).then((res) => { - setAppDetail(res) - }) + try { + const res = await fetchAppDetail({ url: '/apps', id: appId }) + if (systemFeatures.enable_web_sso_switch_component) { + const ssoRes = await fetchAppSSO({ appId }) + setAppDetail({ ...res, enable_sso: ssoRes.enabled }) + } + else { + setAppDetail({ ...res }) + } + } + catch (error) { console.error(error) } } const handleCallbackResult = (err: Error | null, message?: string) => { @@ -81,6 +93,16 @@ const CardView: FC = ({ appId }) => { if (!err) localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1') + if (systemFeatures.enable_web_sso_switch_component) { + const [sso_err] = await asyncRunSafe( + updateAppSSO({ id: appId, enabled: Boolean(params.enable_sso) }) as Promise, + ) + if (sso_err) { + handleCallbackResult(sso_err) + return + } + } + handleCallbackResult(err) } diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx index ff32a157fc..b01bc1b856 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx @@ -4,7 +4,7 @@ import dayjs from 'dayjs' import quarterOfYear from 'dayjs/plugin/quarterOfYear' import { useTranslation } from 'react-i18next' import type { PeriodParams } from '@/app/components/app/overview/appChart' -import { AvgResponseTime, AvgSessionInteractions, AvgUserInteractions, ConversationsChart, CostChart, EndUsersChart, TokenPerSecond, UserSatisfactionRate, WorkflowCostChart, WorkflowDailyTerminalsChart, WorkflowMessagesChart } from '@/app/components/app/overview/appChart' +import { AvgResponseTime, AvgSessionInteractions, AvgUserInteractions, ConversationsChart, CostChart, EndUsersChart, MessagesChart, TokenPerSecond, UserSatisfactionRate, WorkflowCostChart, WorkflowDailyTerminalsChart, WorkflowMessagesChart } from '@/app/components/app/overview/appChart' import type { Item } from '@/app/components/base/select' import { SimpleSelect } from '@/app/components/base/select' import { TIME_PERIOD_LIST } from '@/app/components/app/log/filter' @@ -79,6 +79,11 @@ export default function ChartView({ appId }: IChartViewProps) { )} + {!isWorkflow && isChatApp && ( +
+ +
+ )} {isWorkflow && (
diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx index 7aa1fca96d..8e3d8f9ec6 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx @@ -10,7 +10,7 @@ import { TracingProvider } from './type' import ProviderConfigModal from './provider-config-modal' import Indicator from '@/app/components/header/indicator' import Switch from '@/app/components/base/switch' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' const I18N_PREFIX = 'app.tracing' @@ -85,6 +85,7 @@ const ConfigPopup: FC = ({ = ({ = ({ <> {providerAllNotConfigured ? ( - {switchContent} - - + ) : switchContent} diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx index 120fe29dff..b908322a92 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx @@ -6,6 +6,7 @@ import { TracingProvider } from './type' import cn from '@/utils/classnames' import { LangfuseIconBig, LangsmithIconBig } from '@/app/components/base/icons/src/public/tracing' import { Settings04 } from '@/app/components/base/icons/src/vender/line/general' +import { Eye as View } from '@/app/components/base/icons/src/vender/solid/general' const I18N_PREFIX = 'app.tracing' @@ -13,6 +14,7 @@ type Props = { type: TracingProvider readOnly: boolean isChosen: boolean + config: any onChoose: () => void hasConfigured: boolean onConfig: () => void @@ -29,6 +31,7 @@ const ProviderPanel: FC = ({ type, readOnly, isChosen, + config, onChoose, hasConfigured, onConfig, @@ -41,6 +44,14 @@ const ProviderPanel: FC = ({ onConfig() }, [onConfig]) + const viewBtnClick = useCallback((e: React.MouseEvent) => { + e.preventDefault() + e.stopPropagation() + + const url = `${config?.host}/project/${config?.project_key}` + window.open(url, '_blank', 'noopener,noreferrer') + }, []) + const handleChosen = useCallback((e: React.MouseEvent) => { e.stopPropagation() if (isChosen || !hasConfigured || readOnly) @@ -58,12 +69,20 @@ const ProviderPanel: FC = ({ {isChosen &&
{t(`${I18N_PREFIX}.inUse`)}
}
{!readOnly && ( -
- -
{t(`${I18N_PREFIX}.config`)}
+
+ {hasConfigured && ( +
+ +
{t(`${I18N_PREFIX}.view`)}
+
+ )} +
+ +
{t(`${I18N_PREFIX}.config`)}
+
)} diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/toggle-fold-btn.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/toggle-fold-btn.tsx index 9119deede8..934eb681b9 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/toggle-fold-btn.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/toggle-fold-btn.tsx @@ -3,7 +3,7 @@ import { ChevronDoubleDownIcon } from '@heroicons/react/20/solid' import type { FC } from 'react' import { useTranslation } from 'react-i18next' import React, { useCallback } from 'react' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' const I18N_PREFIX = 'app.tracing' @@ -25,9 +25,8 @@ const ToggleFoldBtn: FC = ({ return ( // text-[0px] to hide spacing between tooltip elements
- {isFold && (
@@ -39,7 +38,7 @@ const ToggleFoldBtn: FC = ({
)} -
+
) } diff --git a/web/app/(commonLayout)/apps/AppCard.tsx b/web/app/(commonLayout)/apps/AppCard.tsx index bc7308a711..fb39dee5a3 100644 --- a/web/app/(commonLayout)/apps/AppCard.tsx +++ b/web/app/(commonLayout)/apps/AppCard.tsx @@ -79,6 +79,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { icon, icon_background, description, + use_icon_as_answer_icon, }) => { try { await updateAppInfo({ @@ -88,6 +89,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { icon, icon_background, description, + use_icon_as_answer_icon, }) setShowEditModal(false) notify({ @@ -255,7 +257,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { e.preventDefault() getRedirection(isCurrentWorkspaceEditor, app, push) }} - className='group flex col-span-1 bg-white border-2 border-solid border-transparent rounded-xl shadow-sm min-h-[160px] flex flex-col transition-all duration-200 ease-in-out cursor-pointer hover:shadow-lg' + className='relative group col-span-1 bg-white border-2 border-solid border-transparent rounded-xl shadow-sm flex flex-col transition-all duration-200 ease-in-out cursor-pointer hover:shadow-lg' >
@@ -297,17 +299,16 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
-
- {app.description} +
+
+ {app.description} +
{isCurrentWorkspaceEditor && ( @@ -371,6 +372,8 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { appIconBackground={app.icon_background} appIconUrl={app.icon_url} appDescription={app.description} + appMode={app.mode} + appUseIconAsAnswerIcon={app.use_icon_as_answer_icon} show={showEditModal} onConfirm={onEdit} onHide={() => setShowEditModal(false)} diff --git a/web/app/(commonLayout)/apps/Apps.tsx b/web/app/(commonLayout)/apps/Apps.tsx index c16512bd50..132096c6b4 100644 --- a/web/app/(commonLayout)/apps/Apps.tsx +++ b/web/app/(commonLayout)/apps/Apps.tsx @@ -139,7 +139,7 @@ const Apps = () => {
diff --git a/web/app/(commonLayout)/datasets/template/template.en.mdx b/web/app/(commonLayout)/datasets/template/template.en.mdx index 44c5964d77..33451b8a0b 100644 --- a/web/app/(commonLayout)/datasets/template/template.en.mdx +++ b/web/app/(commonLayout)/datasets/template/template.en.mdx @@ -236,6 +236,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from Knowledge name + + Permission + - only_me Only me + - all_team_members All team members + - partial_members Partial members + @@ -243,14 +249,15 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from title="Request" tag="POST" label="/datasets" - targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"name": "name"}'`} + targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"name": "name", "permission": "only_me"}'`} > ```bash {{ title: 'cURL' }} curl --location --request POST '${apiBaseUrl}/v1/datasets' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ --data-raw '{ - "name": "name" + "name": "name", + "permission": "only_me" }' ``` diff --git a/web/app/(commonLayout)/datasets/template/template.zh.mdx b/web/app/(commonLayout)/datasets/template/template.zh.mdx index 9f79b0f900..dc48f92f18 100644 --- a/web/app/(commonLayout)/datasets/template/template.zh.mdx +++ b/web/app/(commonLayout)/datasets/template/template.zh.mdx @@ -236,6 +236,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from 知识库名称 + + 权限 + - only_me 仅自己 + - all_team_members 所有团队成员 + - partial_members 部分团队成员 + @@ -243,14 +249,15 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from title="Request" tag="POST" label="/datasets" - targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"name": "name"}'`} + targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"name": "name", "permission": "only_me"}'`} > ```bash {{ title: 'cURL' }} curl --location --request POST '${props.apiBaseUrl}/datasets' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ --data-raw '{ - "name": "name" + "name": "name", + "permission": "only_me" }' ``` diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx index c5bc3bb210..b4a7a8a619 100644 --- a/web/app/components/app-sidebar/app-info.tsx +++ b/web/app/components/app-sidebar/app-info.tsx @@ -63,6 +63,7 @@ const AppInfo = ({ expand }: IAppInfoProps) => { icon, icon_background, description, + use_icon_as_answer_icon, }) => { if (!appDetail) return @@ -74,6 +75,7 @@ const AppInfo = ({ expand }: IAppInfoProps) => { icon, icon_background, description, + use_icon_as_answer_icon, }) setShowEditModal(false) notify({ @@ -423,6 +425,8 @@ const AppInfo = ({ expand }: IAppInfoProps) => { appIconBackground={appDetail.icon_background} appIconUrl={appDetail.icon_url} appDescription={appDetail.description} + appMode={appDetail.mode} + appUseIconAsAnswerIcon={appDetail.use_icon_as_answer_icon} show={showEditModal} onConfirm={onEdit} onHide={() => setShowEditModal(false)} diff --git a/web/app/components/app-sidebar/basic.tsx b/web/app/components/app-sidebar/basic.tsx index 09f978b04b..c939cb7bb3 100644 --- a/web/app/components/app-sidebar/basic.tsx +++ b/web/app/components/app-sidebar/basic.tsx @@ -1,15 +1,11 @@ import React from 'react' -import { - InformationCircleIcon, -} from '@heroicons/react/24/outline' -import Tooltip from '../base/tooltip' import AppIcon from '../base/app-icon' -import { randomString } from '@/utils' +import Tooltip from '@/app/components/base/tooltip' export type IAppBasicProps = { iconType?: 'app' | 'api' | 'dataset' | 'webapp' | 'notion' icon?: string - icon_background?: string + icon_background?: string | null name: string type: string | React.ReactNode hoverTip?: string @@ -74,9 +70,17 @@ export default function AppBasic({ icon, icon_background, name, type, hoverTip,
{name} {hoverTip - && - - } + && + {hoverTip} +
+ } + popupClassName='ml-1' + triggerClassName='w-4 h-4 ml-1' + position='top' + /> + }
{type}
} diff --git a/web/app/components/app/annotation/index.tsx b/web/app/components/app/annotation/index.tsx index 1e65d7a94f..0f54f5bfc3 100644 --- a/web/app/components/app/annotation/index.tsx +++ b/web/app/components/app/annotation/index.tsx @@ -280,7 +280,7 @@ const Annotation: FC = ({ onSave={async (embeddingModel, score) => { if ( embeddingModel.embedding_model_name !== annotationConfig?.embedding_model?.embedding_model_name - && embeddingModel.embedding_provider_name !== annotationConfig?.embedding_model?.embedding_provider_name + || embeddingModel.embedding_provider_name !== annotationConfig?.embedding_model?.embedding_provider_name ) { const { job_id: jobId }: any = await updateAnnotationStatus(appDetail.id, AnnotationEnableStatus.enable, embeddingModel, score) await ensureJobCompleted(jobId, AnnotationEnableStatus.enable) diff --git a/web/app/components/app/app-publisher/index.tsx b/web/app/components/app/app-publisher/index.tsx index e971274a71..2bcc74ec01 100644 --- a/web/app/components/app/app-publisher/index.tsx +++ b/web/app/components/app/app-publisher/index.tsx @@ -24,6 +24,7 @@ import { LeftIndent02 } from '@/app/components/base/icons/src/vender/line/editor import { FileText } from '@/app/components/base/icons/src/vender/line/files' import WorkflowToolConfigureButton from '@/app/components/tools/workflow-tool/configure-button' import type { InputVar } from '@/app/components/workflow/types' +import { appDefaultIconBackground } from '@/config' export type AppPublisherProps = { disabled?: boolean @@ -62,6 +63,7 @@ const AppPublisher = ({ const [published, setPublished] = useState(false) const [open, setOpen] = useState(false) const appDetail = useAppStore(state => state.appDetail) + const [publishedTime, setPublishedTime] = useState(publishedAt) const { app_base_url: appBaseURL = '', access_token: accessToken = '' } = appDetail?.site ?? {} const appMode = (appDetail?.mode !== 'completion' && appDetail?.mode !== 'workflow') ? 'chat' : appDetail.mode const appURL = `${appBaseURL}/${appMode}/${accessToken}` @@ -75,6 +77,7 @@ const AppPublisher = ({ try { await onPublish?.(modelAndParameter) setPublished(true) + setPublishedTime(Date.now()) } catch (e) { setPublished(false) @@ -130,13 +133,13 @@ const AppPublisher = ({
- {publishedAt ? t('workflow.common.latestPublished') : t('workflow.common.currentDraftUnpublished')} + {publishedTime ? t('workflow.common.latestPublished') : t('workflow.common.currentDraftUnpublished')}
- {publishedAt + {publishedTime ? (
- {t('workflow.common.publishedAt')} {formatTimeFromNow(publishedAt)} + {t('workflow.common.publishedAt')} {formatTimeFromNow(publishedTime)}
- }>{t('workflow.common.runApp')} + }>{t('workflow.common.runApp')} {appDetail?.mode === 'workflow' ? ( } > @@ -198,22 +201,22 @@ const AppPublisher = ({ setEmbeddingModalOpen(true) handleTrigger() }} - disabled={!publishedAt} + disabled={!publishedTime} icon={} > {t('workflow.common.embedIntoSite')} )} - }>{t('workflow.common.accessAPIReference')} + }>{t('workflow.common.accessAPIReference')} {appDetail?.mode === 'workflow' && ( = ({
{t('appDebug.pageTitle.line1')}
- {t('appDebug.promptTip')} -
} - selector='config-prompt-tooltip'> - - + popupContent={ +
+ {t('appDebug.promptTip')} +
+ } + />
)}
{canDelete && ( diff --git a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx index adcfcdd126..69e01a8e22 100644 --- a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx @@ -3,9 +3,6 @@ import type { FC } from 'react' import React, { useState } from 'react' import { useTranslation } from 'react-i18next' import { useBoolean } from 'ahooks' -import { - RiQuestionLine, -} from '@remixicon/react' import produce from 'immer' import { useContext } from 'use-context-selector' import ConfirmAddVar from './confirm-add-var' @@ -156,12 +153,12 @@ const Prompt: FC = ({
{mode !== AppType.completion ? t('appDebug.chatSubTitle') : t('appDebug.completionSubTitle')}
{!readonly && ( - {t('appDebug.promptTip')} -
} - selector='config-prompt-tooltip'> - - + popupContent={ +
+ {t('appDebug.promptTip')} +
+ } + /> )}
diff --git a/web/app/components/app/configuration/config-var/config-modal/index.tsx b/web/app/components/app/configuration/config-var/config-modal/index.tsx index 20fcf49de1..3296c77fb2 100644 --- a/web/app/components/app/configuration/config-var/config-modal/index.tsx +++ b/web/app/components/app/configuration/config-var/config-modal/index.tsx @@ -42,18 +42,19 @@ const ConfigModal: FC = ({ const { type, label, variable, options, max_length } = tempPayload const isStringInput = type === InputVarType.textInput || type === InputVarType.paragraph + const checkVariableName = useCallback((value: string) => { + const { isValid, errorMessageKey } = checkKeys([value], false) + if (!isValid) { + Toast.notify({ + type: 'error', + message: t(`appDebug.varKeyError.${errorMessageKey}`, { key: t('appDebug.variableConig.varName') }), + }) + return false + } + return true + }, [t]) const handlePayloadChange = useCallback((key: string) => { return (value: any) => { - if (key === 'variable') { - const { isValid, errorKey, errorMessageKey } = checkKeys([value], true) - if (!isValid) { - Toast.notify({ - type: 'error', - message: t(`appDebug.varKeyError.${errorMessageKey}`, { key: errorKey }), - }) - return - } - } setTempPayload((prev) => { const newPayload = { ...prev, @@ -63,19 +64,20 @@ const ConfigModal: FC = ({ return newPayload }) } - }, [t]) + }, []) const handleVarKeyBlur = useCallback((e: any) => { - if (tempPayload.label) + const varName = e.target.value + if (!checkVariableName(varName) || tempPayload.label) return setTempPayload((prev) => { return { ...prev, - label: e.target.value, + label: varName, } }) - }, [tempPayload]) + }, [checkVariableName, tempPayload.label]) const handleConfirm = () => { const moreInfo = tempPayload.variable === payload?.variable @@ -84,10 +86,11 @@ const ConfigModal: FC = ({ type: ChangeType.changeVarName, payload: { beforeKey: payload?.variable || '', afterKey: tempPayload.variable }, } - if (!tempPayload.variable) { - Toast.notify({ type: 'error', message: t('appDebug.variableConig.errorMsg.varNameRequired') }) + + const isVariableNameValid = checkVariableName(tempPayload.variable) + if (!isVariableNameValid) return - } + // TODO: check if key already exists. should the consider the edit case // if (varKeys.map(key => key?.trim()).includes(tempPayload.variable.trim())) { // Toast.notify({ diff --git a/web/app/components/app/configuration/config-var/index.tsx b/web/app/components/app/configuration/config-var/index.tsx index 82a220c6db..802528e0af 100644 --- a/web/app/components/app/configuration/config-var/index.tsx +++ b/web/app/components/app/configuration/config-var/index.tsx @@ -8,7 +8,6 @@ import { useContext } from 'use-context-selector' import produce from 'immer' import { RiDeleteBinLine, - RiQuestionLine, } from '@remixicon/react' import Panel from '../base/feature-panel' import EditModal from './config-modal' @@ -282,11 +281,13 @@ const ConfigVar: FC = ({ promptVariables, readonly, onPromptVar
{t('appDebug.variableTitle')}
{!readonly && ( - - {t('appDebug.variableTip')} -
} selector='config-var-tooltip'> - - + + {t('appDebug.variableTip')} +
+ } + /> )} } diff --git a/web/app/components/app/configuration/config-vision/index.tsx b/web/app/components/app/configuration/config-vision/index.tsx index 9b12e059b5..515709bff1 100644 --- a/web/app/components/app/configuration/config-vision/index.tsx +++ b/web/app/components/app/configuration/config-vision/index.tsx @@ -2,9 +2,6 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import { - RiQuestionLine, -} from '@remixicon/react' import { useContext } from 'use-context-selector' import Panel from '../base/feature-panel' import ParamConfig from './param-config' @@ -33,11 +30,13 @@ const ConfigVision: FC = () => { title={
{t('appDebug.vision.name')}
- - {t('appDebug.vision.description')} -
} selector='config-vision-tooltip'> - - + + {t('appDebug.vision.description')} + + } + /> } headerRight={ diff --git a/web/app/components/app/configuration/config-vision/param-config-content.tsx b/web/app/components/app/configuration/config-vision/param-config-content.tsx index 89fad411e7..cb01f57254 100644 --- a/web/app/components/app/configuration/config-vision/param-config-content.tsx +++ b/web/app/components/app/configuration/config-vision/param-config-content.tsx @@ -3,9 +3,6 @@ import type { FC } from 'react' import React from 'react' import { useContext } from 'use-context-selector' import { useTranslation } from 'react-i18next' -import { - RiQuestionLine, -} from '@remixicon/react' import RadioGroup from './radio-group' import ConfigContext from '@/context/debug-configuration' import { Resolution, TransferMethod } from '@/types/app' @@ -37,13 +34,15 @@ const ParamConfigContent: FC = () => {
{t('appDebug.vision.visionSettings.resolution')}
- - {t('appDebug.vision.visionSettings.resolutionTooltip').split('\n').map(item => ( -
{item}
- ))} -
} selector='config-resolution-tooltip'> - - + + {t('appDebug.vision.visionSettings.resolutionTooltip').split('\n').map(item => ( +
{item}
+ ))} +
+ } + /> {
{t('appDebug.voice.voiceSettings.language')}
- - {t('appDebug.voice.voiceSettings.resolutionTooltip').split('\n').map(item => ( -
{item}
- ))} -
} selector='config-resolution-tooltip'> - -
+ + {t('appDebug.voice.voiceSettings.resolutionTooltip').split('\n').map(item => ( +
{item}
+ ))} + + } + /> = ({ {icon}
{name}
{description} } - selector={`agent-setting-tooltip-${name}`} > -
diff --git a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx index 16f2257c38..1280fd8928 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx @@ -7,13 +7,11 @@ import produce from 'immer' import { RiDeleteBinLine, RiHammerFill, - RiQuestionLine, } from '@remixicon/react' import { useFormattingChangedDispatcher } from '../../../debug/hooks' import SettingBuiltInTool from './setting-built-in-tool' import cn from '@/utils/classnames' import Panel from '@/app/components/app/configuration/base/feature-panel' -import Tooltip from '@/app/components/base/tooltip' import { InfoCircle } from '@/app/components/base/icons/src/vender/line/general' import OperationBtn from '@/app/components/app/configuration/base/operation-btn' import AppIcon from '@/app/components/base/app-icon' @@ -23,7 +21,7 @@ import type { AgentTool } from '@/types/app' import { type Collection, CollectionType } from '@/app/components/tools/types' import { MAX_TOOLS_NUM } from '@/config' import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import { DefaultToolIcon } from '@/app/components/base/icons/src/public/other' import AddToolModal from '@/app/components/tools/add-tool-modal' @@ -68,11 +66,13 @@ const AgentTools: FC = () => { title={
{t('appDebug.agent.tools.name')}
- - {t('appDebug.agent.tools.description')} -
} selector='config-tools-tooltip'> - - + + {t('appDebug.agent.tools.description')} +
+ } + /> } headerRight={ @@ -119,19 +119,20 @@ const AgentTools: FC = () => { className={cn((item.isDeleted || item.notAuthor) ? 'line-through opacity-50' : '', 'grow w-0 ml-2 leading-[18px] text-[13px] font-medium text-gray-800 truncate')} > {item.provider_type === CollectionType.builtIn ? item.provider_name : item.tool_label} - {item.tool_name} - +
{(item.isDeleted || item.notAuthor) ? (
-
{ if (item.notAuthor) @@ -139,7 +140,7 @@ const AgentTools: FC = () => { }}>
-
+
{ const newModelConfig = produce(modelConfig, (draft) => { @@ -155,16 +156,17 @@ const AgentTools: FC = () => { ) : (
- -
{ +
{ setCurrentTool(item) setIsShowSettingTool(true) }}>
- +
{ const newModelConfig = produce(modelConfig, (draft) => { diff --git a/web/app/components/app/configuration/dataset-config/card-item/index.tsx b/web/app/components/app/configuration/dataset-config/card-item/index.tsx index 7b369d9d79..a528b2288c 100644 --- a/web/app/components/app/configuration/dataset-config/card-item/index.tsx +++ b/web/app/components/app/configuration/dataset-config/card-item/index.tsx @@ -39,10 +39,9 @@ const CardItem: FC = ({
{config.name}
{!config.embedding_available && ( - {t('dataset.unavailable')} + {t('dataset.unavailable')} )}
diff --git a/web/app/components/app/configuration/dataset-config/context-var/index.tsx b/web/app/components/app/configuration/dataset-config/context-var/index.tsx index be0ae47242..0de182b0a9 100644 --- a/web/app/components/app/configuration/dataset-config/context-var/index.tsx +++ b/web/app/components/app/configuration/dataset-config/context-var/index.tsx @@ -2,9 +2,6 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import { - RiQuestionLine, -} from '@remixicon/react' import type { Props } from './var-picker' import VarPicker from './var-picker' import cn from '@/utils/classnames' @@ -24,13 +21,12 @@ const ContextVar: FC = (props) => {
{t('appDebug.feature.dataSet.queryVariable.title')}
- {t('appDebug.feature.dataSet.queryVariable.tip')} -
} - selector='context-var-tooltip' - > - - + popupContent={ +
+ {t('appDebug.feature.dataSet.queryVariable.tip')} +
+ } + />
diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx index 3656bf6ea7..7f55649dab 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx @@ -5,7 +5,6 @@ import type { FC } from 'react' import { useTranslation } from 'react-i18next' import { RiAlertFill, - RiQuestionLine, } from '@remixicon/react' import WeightedScore from './weighted-score' import TopKItem from '@/app/components/base/param-item/top-k-item' @@ -23,7 +22,7 @@ import ModelSelector from '@/app/components/header/account-setting/model-provide import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import type { ModelConfig } from '@/app/components/workflow/types' import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import type { DataSet, @@ -173,7 +172,7 @@ const ConfigContent: FC = ({ title={(
{t('appDebug.datasetConfig.retrieveOneWay.title')} - {t('dataset.nTo1RetrievalLegacy')} @@ -181,7 +180,7 @@ const ConfigContent: FC = ({ )} >
legacy
-
+
)} description={t('appDebug.datasetConfig.retrieveOneWay.description')} @@ -250,12 +249,15 @@ const ConfigContent: FC = ({ onClick={() => handleRerankModeChange(option.value)} >
{option.label}
- {option.tips}
} - hideArrow - > - - + + {option.tips} +
+ } + popupClassName='ml-0.5' + triggerClassName='ml-0.5 w-3.5 h-3.5' + /> )) } @@ -281,9 +283,15 @@ const ConfigContent: FC = ({ ) }
{t('common.modelProvider.rerankModel.key')}
- {t('common.modelProvider.rerankModel.tip')}}> - - + + {t('common.modelProvider.rerankModel.tip')} + + } + popupClassName='ml-0.5' + triggerClassName='ml-0.5 w-3.5 h-3.5' + />
= ({
{t('common.modelProvider.systemReasoningModel.key')}
- - - + />
= ({ { currentModel && currentModel.status !== ModelStatusEnum.active && ( - + - + ) }
diff --git a/web/app/components/app/configuration/features/chat-group/suggested-questions-after-answer/index.tsx b/web/app/components/app/configuration/features/chat-group/suggested-questions-after-answer/index.tsx index e27eec46c8..199558f4aa 100644 --- a/web/app/components/app/configuration/features/chat-group/suggested-questions-after-answer/index.tsx +++ b/web/app/components/app/configuration/features/chat-group/suggested-questions-after-answer/index.tsx @@ -2,9 +2,6 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import { - RiQuestionLine, -} from '@remixicon/react' import Panel from '@/app/components/app/configuration/base/feature-panel' import SuggestedQuestionsAfterAnswerIcon from '@/app/components/app/configuration/base/icons/suggested-questions-after-answer-icon' import Tooltip from '@/app/components/base/tooltip' @@ -15,13 +12,15 @@ const SuggestedQuestionsAfterAnswer: FC = () => { return ( +
{t('appDebug.feature.suggestedQuestionsAfterAnswer.title')}
- - {t('appDebug.feature.suggestedQuestionsAfterAnswer.description')} -
} selector='suggestion-question-tooltip'> - - + + {t('appDebug.feature.suggestedQuestionsAfterAnswer.description')} +
+ } + /> } headerIcon={} diff --git a/web/app/components/app/configuration/prompt-value-panel/index.tsx b/web/app/components/app/configuration/prompt-value-panel/index.tsx index 0192024c83..dc9b0a4333 100644 --- a/web/app/components/app/configuration/prompt-value-panel/index.tsx +++ b/web/app/components/app/configuration/prompt-value-panel/index.tsx @@ -16,7 +16,7 @@ import { AppType, ModelModeType } from '@/types/app' import Select from '@/app/components/base/select' import { DEFAULT_VALUE_MAX_LEN } from '@/config' import Button from '@/app/components/base/button' -import Tooltip from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import TextGenerationImageUploader from '@/app/components/base/image-uploader/text-generation-image-uploader' import type { VisionFile, VisionSettings } from '@/types/app' @@ -207,6 +207,7 @@ const PromptValuePanel: FC = ({ {canNotRun ? ( {renderRunButton()} ) diff --git a/web/app/components/app/configuration/toolbox/annotation/annotation-ctrl-btn/index.tsx b/web/app/components/app/configuration/toolbox/annotation/annotation-ctrl-btn/index.tsx index b2c6792107..111c380afc 100644 --- a/web/app/components/app/configuration/toolbox/annotation/annotation-ctrl-btn/index.tsx +++ b/web/app/components/app/configuration/toolbox/annotation/annotation-ctrl-btn/index.tsx @@ -8,7 +8,7 @@ import { MessageCheckRemove, MessageFastPlus } from '@/app/components/base/icons import { MessageFast } from '@/app/components/base/icons/src/vender/solid/communication' import { Edit04 } from '@/app/components/base/icons/src/vender/line/general' import RemoveAnnotationConfirmModal from '@/app/components/app/annotation/remove-annotation-confirm-modal' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import { addAnnotation, delAnnotation } from '@/service/annotation' import Toast from '@/app/components/base/toast' import { useProviderContext } from '@/context/provider-context' @@ -99,8 +99,9 @@ const CacheCtrlBtn: FC = ({ ) : answer ? ( -
= ({ >
-
+
) : null } -
= ({ >
-
+
{title}
- {tooltip}
} - > - - + />
{children}
@@ -103,7 +98,7 @@ const AnnotationReplyConfig: FC = ({ let isEmbeddingModelChanged = false if ( embeddingModel.embedding_model_name !== annotationConfig.embedding_model.embedding_model_name - && embeddingModel.embedding_provider_name !== annotationConfig.embedding_model.embedding_provider_name + || embeddingModel.embedding_provider_name !== annotationConfig.embedding_model.embedding_provider_name ) { await onEmbeddingChange(embeddingModel) isEmbeddingModelChanged = true diff --git a/web/app/components/app/configuration/toolbox/index.tsx b/web/app/components/app/configuration/toolbox/index.tsx index 488ca86b90..00ea301a42 100644 --- a/web/app/components/app/configuration/toolbox/index.tsx +++ b/web/app/components/app/configuration/toolbox/index.tsx @@ -32,7 +32,7 @@ const Toolbox: FC = ({ ) } { - (showAnnotation || true) && ( + showAnnotation && ( {
{t('appDebug.feature.tools.title')}
- {t('appDebug.feature.tools.tips')}}> - - + + {t('appDebug.feature.tools.tips')} + + } + /> { !expanded && !!externalDataToolsConfig.length && ( @@ -143,7 +146,7 @@ const Tools = () => { background={item.icon_background} />
{item.label}
-
{ > {item.variable}
-
+
{
{t('app.newApp.captionAppType')}
- {t('app.newApp.chatbotDescription')}
} @@ -120,9 +119,8 @@ const CreateAppModal = ({ show, onSuccess, onClose }: CreateAppDialogProps) => {
{t('app.types.chatbot')}
- - +
{t('app.newApp.completionDescription')}
@@ -143,9 +141,8 @@ const CreateAppModal = ({ show, onSuccess, onClose }: CreateAppDialogProps) => {
{t('app.newApp.completeApp')}
- - + {t('app.newApp.agentDescription')} } @@ -164,9 +161,8 @@ const CreateAppModal = ({ show, onSuccess, onClose }: CreateAppDialogProps) => {
{t('app.types.agent')}
-
- +
{t('app.newApp.workflowDescription')}
@@ -188,7 +184,7 @@ const CreateAppModal = ({ show, onSuccess, onClose }: CreateAppDialogProps) => {
{t('app.types.workflow')}
BETA -
+ {showChatBotType && ( diff --git a/web/app/components/app/log-annotation/index.tsx b/web/app/components/app/log-annotation/index.tsx index 852e57035c..e0ace764f0 100644 --- a/web/app/components/app/log-annotation/index.tsx +++ b/web/app/components/app/log-annotation/index.tsx @@ -1,6 +1,6 @@ 'use client' import type { FC } from 'react' -import React from 'react' +import React, { useMemo } from 'react' import { useTranslation } from 'react-i18next' import { useRouter } from 'next/navigation' import cn from '@/utils/classnames' @@ -23,10 +23,14 @@ const LogAnnotation: FC = ({ const router = useRouter() const appDetail = useAppStore(state => state.appDetail) - const options = [ - { value: PageType.log, text: t('appLog.title') }, - { value: PageType.annotation, text: t('appAnnotation.title') }, - ] + const options = useMemo(() => { + if (appDetail?.mode === 'completion') + return [{ value: PageType.log, text: t('appLog.title') }] + return [ + { value: PageType.log, text: t('appLog.title') }, + { value: PageType.annotation, text: t('appAnnotation.title') }, + ] + }, [appDetail]) if (!appDetail) { return ( diff --git a/web/app/components/app/log/filter.tsx b/web/app/components/app/log/filter.tsx index 80a58bb5a2..0552b44d16 100644 --- a/web/app/components/app/log/filter.tsx +++ b/web/app/components/app/log/filter.tsx @@ -10,6 +10,7 @@ import dayjs from 'dayjs' import quarterOfYear from 'dayjs/plugin/quarterOfYear' import type { QueryParam } from './index' import { SimpleSelect } from '@/app/components/base/select' +import Sort from '@/app/components/base/sort' import { fetchAnnotationsCount } from '@/service/log' dayjs.extend(quarterOfYear) @@ -28,18 +29,19 @@ export const TIME_PERIOD_LIST = [ ] type IFilterProps = { + isChatMode?: boolean appId: string queryParams: QueryParam setQueryParams: (v: QueryParam) => void } -const Filter: FC = ({ appId, queryParams, setQueryParams }: IFilterProps) => { +const Filter: FC = ({ isChatMode, appId, queryParams, setQueryParams }: IFilterProps) => { const { data } = useSWR({ url: `/apps/${appId}/annotations/count` }, fetchAnnotationsCount) const { t } = useTranslation() if (!data) return null return ( -
+
({ value: item.value, name: t(`appLog.filter.period.${item.name}`) }))} className='mt-0 !w-40' @@ -68,7 +70,7 @@ const Filter: FC = ({ appId, queryParams, setQueryParams }: IFilte { @@ -76,6 +78,22 @@ const Filter: FC = ({ appId, queryParams, setQueryParams }: IFilte }} />
+ {isChatMode && ( + <> +
+ { + setQueryParams({ ...queryParams, sort_by: value as string }) + }} + /> + + )}
) } diff --git a/web/app/components/app/log/index.tsx b/web/app/components/app/log/index.tsx index e9ad2f43c6..dd6ebd08f0 100644 --- a/web/app/components/app/log/index.tsx +++ b/web/app/components/app/log/index.tsx @@ -24,6 +24,7 @@ export type QueryParam = { period?: number | string annotation_status?: string keyword?: string + sort_by?: string } const ThreeDotsIcon = ({ className }: SVGProps) => { @@ -52,9 +53,16 @@ const EmptyElement: FC<{ appUrl: string }> = ({ appUrl }) => { const Logs: FC = ({ appDetail }) => { const { t } = useTranslation() - const [queryParams, setQueryParams] = useState({ period: 7, annotation_status: 'all' }) + const [queryParams, setQueryParams] = useState({ + period: 7, + annotation_status: 'all', + sort_by: '-created_at', + }) const [currPage, setCurrPage] = React.useState(0) + // Get the app type first + const isChatMode = appDetail.mode !== 'completion' + const query = { page: currPage + 1, limit: APP_PAGE_LIMIT, @@ -64,6 +72,7 @@ const Logs: FC = ({ appDetail }) => { end: dayjs().endOf('day').format('YYYY-MM-DD HH:mm'), } : {}), + ...(isChatMode ? { sort_by: queryParams.sort_by } : {}), ...omit(queryParams, ['period']), } @@ -73,9 +82,6 @@ const Logs: FC = ({ appDetail }) => { return appType } - // Get the app type first - const isChatMode = appDetail.mode !== 'completion' - // When the details are obtained, proceed to the next request const { data: chatConversations, mutate: mutateChatList } = useSWR(() => isChatMode ? { @@ -97,7 +103,7 @@ const Logs: FC = ({ appDetail }) => {

{t('appLog.description')}

- + {total === undefined ? : total > 0 diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index 646ae80116..0bc118c46f 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -5,10 +5,9 @@ import useSWR from 'swr' import { HandThumbDownIcon, HandThumbUpIcon, - InformationCircleIcon, XMarkIcon, } from '@heroicons/react/24/outline' -import { RiEditFill } from '@remixicon/react' +import { RiEditFill, RiQuestionLine } from '@remixicon/react' import { get } from 'lodash-es' import InfiniteScroll from 'react-infinite-scroll-component' import dayjs from 'dayjs' @@ -20,7 +19,6 @@ import { useTranslation } from 'react-i18next' import s from './style.module.css' import VarPanel from './var-panel' import cn from '@/utils/classnames' -import { randomString } from '@/utils' import type { FeedbackFunc, Feedbacktype, IChatItem, SubmitAnnotationFunc } from '@/app/components/base/chat/chat/type' import type { Annotation, ChatConversationFullDetailResponse, ChatConversationGeneralDetail, ChatConversationsResponse, ChatMessage, ChatMessagesRequest, CompletionConversationFullDetailResponse, CompletionConversationGeneralDetail, CompletionConversationsResponse, LogAnnotation } from '@/models/log' import type { App } from '@/types/app' @@ -28,7 +26,6 @@ import Loading from '@/app/components/base/loading' import Drawer from '@/app/components/base/drawer' import Popover from '@/app/components/base/popover' import Chat from '@/app/components/base/chat/chat' -import Tooltip from '@/app/components/base/tooltip' import { ToastContext } from '@/app/components/base/toast' import { fetchChatConversationDetail, fetchChatMessages, fetchCompletionConversationDetail, updateLogMessageAnnotations, updateLogMessageFeedbacks } from '@/service/log' import { TONE_LIST } from '@/config' @@ -42,7 +39,7 @@ import MessageLogModal from '@/app/components/base/message-log-modal' import { useStore as useAppStore } from '@/app/components/app/store' import { useAppContext } from '@/context/app-context' import useTimestamp from '@/hooks/use-timestamp' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import { CopyIcon } from '@/app/components/base/copy-icon' dayjs.extend(utc) @@ -346,11 +343,11 @@ function DetailPanel{isChatMode ? t('appLog.detail.conversationId') : t('appLog.detail.time')}
{isChatMode && (
- +
{detail.id}
-
+
)} @@ -380,7 +377,7 @@ function DetailPanel {targetTone} - + } htmlContent={
@@ -641,13 +638,12 @@ const ConversationList: FC = ({ logs, appDetail, onRefresh }) const renderTdValue = (value: string | number | null, isEmptyStyle: boolean, isHighlight = false, annotation?: LogAnnotation) => { return ( {`${t('appLog.detail.annotationTip', { user: annotation?.account?.name })} ${formatTime(annotation?.created_at || dayjs().unix(), 'MM-DD hh:mm A')}`} } - className={(isHighlight && !isChatMode) ? '' : '!hidden'} - selector={`highlight-${randomString(16)}`} + popupClassName={(isHighlight && !isChatMode) ? '' : '!hidden'} >
{value || '-'} @@ -671,17 +667,18 @@ const ConversationList: FC = ({ logs, appDetail, onRefresh }) - {t('appLog.table.header.time')} - {t('appLog.table.header.endUser')} {isChatMode ? t('appLog.table.header.summary') : t('appLog.table.header.input')} + {t('appLog.table.header.endUser')} {isChatMode ? t('appLog.table.header.messageCount') : t('appLog.table.header.output')} {t('appLog.table.header.userRate')} {t('appLog.table.header.adminRate')} + {t('appLog.table.header.updatedTime')} + {t('appLog.table.header.time')} {logs.data.map((log: any) => { - const endUser = log.from_end_user_session_id + const endUser = log.from_end_user_session_id || log.from_account_name const leftValue = get(log, isChatMode ? 'name' : 'message.inputs.query') || (!isChatMode ? (get(log, 'message.query') || get(log, 'message.inputs.default_input')) : '') || '' const rightValue = get(log, isChatMode ? 'message_count' : 'message.answer') return = ({ logs, appDetail, onRefresh }) setCurrentConversation(log) }}> {!log.read_at && } - {formatTime(log.created_at, t('appLog.dateTimeFormat') as string)} - {renderTdValue(endUser || defaultValue, !endUser)} {renderTdValue(leftValue || t('appLog.table.empty.noChat'), !leftValue, isChatMode && log.annotated)} + {renderTdValue(endUser || defaultValue, !endUser)} {renderTdValue(rightValue === 0 ? 0 : (rightValue || t('appLog.table.empty.noOutput')), !rightValue, !isChatMode && !!log.annotation?.content, log.annotation)} @@ -718,6 +714,8 @@ const ConversationList: FC = ({ logs, appDetail, onRefresh }) } + {formatTime(log.updated_at, t('appLog.dateTimeFormat') as string)} + {formatTime(log.created_at, t('appLog.dateTimeFormat') as string)} })} diff --git a/web/app/components/app/overview/appCard.tsx b/web/app/components/app/overview/appCard.tsx index ea0b793857..f9f5c1fbff 100644 --- a/web/app/components/app/overview/appCard.tsx +++ b/web/app/components/app/overview/appCard.tsx @@ -27,10 +27,11 @@ import ShareQRCode from '@/app/components/base/qrcode' import SecretKeyButton from '@/app/components/develop/secret-key/secret-key-button' import type { AppDetailResponse } from '@/models/app' import { useAppContext } from '@/context/app-context' +import type { AppSSO } from '@/types/app' export type IAppCardProps = { className?: string - appInfo: AppDetailResponse + appInfo: AppDetailResponse & Partial cardType?: 'api' | 'webapp' customBgColor?: string onChangeStatus: (val: boolean) => Promise @@ -133,8 +134,8 @@ function AppCard({ return (
@@ -175,7 +176,6 @@ function AppCard({ {isApp && } {/* button copy link/ button regenerate */} @@ -194,16 +194,15 @@ function AppCard({ )} {isApp && isCurrentWorkspaceManager && (
setShowConfirmDelete(true)} >
@@ -226,11 +225,10 @@ function AppCard({ disabled={disabled} >
diff --git a/web/app/components/app/overview/appChart.tsx b/web/app/components/app/overview/appChart.tsx index 7fd316a34b..e0788bcda3 100644 --- a/web/app/components/app/overview/appChart.tsx +++ b/web/app/components/app/overview/appChart.tsx @@ -10,8 +10,8 @@ import { useTranslation } from 'react-i18next' import { formatNumber } from '@/utils/format' import Basic from '@/app/components/app-sidebar/basic' import Loading from '@/app/components/base/loading' -import type { AppDailyConversationsResponse, AppDailyEndUsersResponse, AppTokenCostsResponse } from '@/models/app' -import { getAppDailyConversations, getAppDailyEndUsers, getAppStatistics, getAppTokenCosts, getWorkflowDailyConversations } from '@/service/apps' +import type { AppDailyConversationsResponse, AppDailyEndUsersResponse, AppDailyMessagesResponse, AppTokenCostsResponse } from '@/models/app' +import { getAppDailyConversations, getAppDailyEndUsers, getAppDailyMessages, getAppStatistics, getAppTokenCosts, getWorkflowDailyConversations } from '@/service/apps' const valueFormatter = (v: string | number) => v const COLOR_TYPE_MAP = { @@ -36,12 +36,15 @@ const COMMON_COLOR_MAP = { } type IColorType = 'green' | 'orange' | 'blue' -type IChartType = 'conversations' | 'endUsers' | 'costs' | 'workflowCosts' +type IChartType = 'messages' | 'conversations' | 'endUsers' | 'costs' | 'workflowCosts' type IChartConfigType = { colorType: IColorType; showTokens?: boolean } const commonDateFormat = 'MMM D, YYYY' const CHART_TYPE_CONFIG: Record = { + messages: { + colorType: 'green', + }, conversations: { colorType: 'green', }, @@ -89,7 +92,7 @@ export type IChartProps = { unit?: string yMax?: number chartType: IChartType - chartData: AppDailyConversationsResponse | AppDailyEndUsersResponse | AppTokenCostsResponse | { data: Array<{ date: string; count: number }> } + chartData: AppDailyMessagesResponse | AppDailyConversationsResponse | AppDailyEndUsersResponse | AppTokenCostsResponse | { data: Array<{ date: string; count: number }> } } const Chart: React.FC = ({ @@ -258,6 +261,20 @@ const getDefaultChartData = ({ start, end, key = 'count' }: { start: string; end }) } +export const MessagesChart: FC = ({ id, period }) => { + const { t } = useTranslation() + const { data: response } = useSWR({ url: `/apps/${id}/statistics/daily-messages`, params: period.query }, getAppDailyMessages) + if (!response) + return + const noDataFlag = !response.data || response.data.length === 0 + return +} + export const ConversationsChart: FC = ({ id, period }) => { const { t } = useTranslation() const { data: response } = useSWR({ url: `/apps/${id}/statistics/daily-conversations`, params: period.query }, getAppDailyConversations) @@ -265,7 +282,7 @@ export const ConversationsChart: FC = ({ id, period }) => { return const noDataFlag = !response.data || response.data.length === 0 return
diff --git a/web/app/components/app/overview/settings/index.tsx b/web/app/components/app/overview/settings/index.tsx index 8da9b9864f..a501d06ce4 100644 --- a/web/app/components/app/overview/settings/index.tsx +++ b/web/app/components/app/overview/settings/index.tsx @@ -4,21 +4,25 @@ import React, { useEffect, useState } from 'react' import { ChevronRightIcon } from '@heroicons/react/20/solid' import Link from 'next/link' import { Trans, useTranslation } from 'react-i18next' +import { useContextSelector } from 'use-context-selector' import s from './style.module.css' import Modal from '@/app/components/base/modal' import Button from '@/app/components/base/button' import AppIcon from '@/app/components/base/app-icon' +import Switch from '@/app/components/base/switch' import { SimpleSelect } from '@/app/components/base/select' import type { AppDetailResponse } from '@/models/app' -import type { AppIconType, Language } from '@/types/app' +import type { AppIconType, AppSSO, Language } from '@/types/app' import { useToastContext } from '@/app/components/base/toast' import { languages } from '@/i18n/language' +import Tooltip from '@/app/components/base/tooltip' +import AppContext from '@/context/app-context' import type { AppIconSelection } from '@/app/components/base/app-icon-picker' import AppIconPicker from '@/app/components/base/app-icon-picker' export type ISettingsModalProps = { isChat: boolean - appInfo: AppDetailResponse + appInfo: AppDetailResponse & Partial isShow: boolean defaultValue?: string onClose: () => void @@ -39,6 +43,8 @@ export type ConfigParams = { icon: string icon_background?: string show_workflow_steps: boolean + use_icon_as_answer_icon: boolean + enable_sso?: boolean } const prefixSettings = 'appOverview.overview.appInfo.settings' @@ -50,6 +56,7 @@ const SettingsModal: FC = ({ onClose, onSave, }) => { + const systemFeatures = useContextSelector(AppContext, state => state.systemFeatures) const { notify } = useToastContext() const [isShowMore, setIsShowMore] = useState(false) const { @@ -66,6 +73,7 @@ const SettingsModal: FC = ({ custom_disclaimer, default_language, show_workflow_steps, + use_icon_as_answer_icon, } = appInfo.site const [inputInfo, setInputInfo] = useState({ title, @@ -76,6 +84,8 @@ const SettingsModal: FC = ({ privacyPolicy: privacy_policy, customDisclaimer: custom_disclaimer, show_workflow_steps, + use_icon_as_answer_icon, + enable_sso: appInfo.enable_sso, }) const [language, setLanguage] = useState(default_language) const [saveLoading, setSaveLoading] = useState(false) @@ -87,6 +97,7 @@ const SettingsModal: FC = ({ ? { type: 'image', url: icon_url!, fileId: icon } : { type: 'emoji', icon, background: icon_background! }, ) + const isChatBot = appInfo.mode === 'chat' || appInfo.mode === 'advanced-chat' || appInfo.mode === 'agent-chat' useEffect(() => { setInputInfo({ @@ -98,6 +109,8 @@ const SettingsModal: FC = ({ privacyPolicy: privacy_policy, customDisclaimer: custom_disclaimer, show_workflow_steps, + use_icon_as_answer_icon, + enable_sso: appInfo.enable_sso, }) setLanguage(default_language) setAppIcon(icon_type === 'image' @@ -149,6 +162,8 @@ const SettingsModal: FC = ({ icon: appIcon.type === 'emoji' ? appIcon.icon : appIcon.fileId, icon_background: appIcon.type === 'emoji' ? appIcon.background : undefined, show_workflow_steps: inputInfo.show_workflow_steps, + use_icon_as_answer_icon: inputInfo.use_icon_as_answer_icon, + enable_sso: inputInfo.enable_sso, } await onSave?.(params) setSaveLoading(false) @@ -200,28 +215,61 @@ const SettingsModal: FC = ({ onChange={onChange('desc')} placeholder={t(`${prefixSettings}.webDescPlaceholder`) as string} /> + {isChatBot && ( +
+
+
{t('app.answerIcon.title')}
+ setInputInfo({ ...inputInfo, use_icon_as_answer_icon: v })} + /> +
+

{t('app.answerIcon.description')}

+
+ )}
{t(`${prefixSettings}.language`)}
item.supported)} defaultValue={language} onSelect={item => setLanguage(item.value as Language)} /> - {(appInfo.mode === 'workflow' || appInfo.mode === 'advanced-chat') && <> -
{t(`${prefixSettings}.workflow.title`)}
- setInputInfo({ ...inputInfo, show_workflow_steps: item.value === 'true' })} - /> - } +
+

{t(`${prefixSettings}.workflow.title`)}

+
+
{t(`${prefixSettings}.workflow.subTitle`)}
+ setInputInfo({ ...inputInfo, show_workflow_steps: v })} + /> +
+

{t(`${prefixSettings}.workflow.showDesc`)}

+
+ {isChat && <>
{t(`${prefixSettings}.chatColorTheme`)}

{t(`${prefixSettings}.chatColorThemeDesc`)}

} + {systemFeatures.enable_web_sso_switch_component &&
+

{t(`${prefixSettings}.sso.label`)}

+
+
{t(`${prefixSettings}.sso.title`)}
+ {t(`${prefixSettings}.sso.tooltip`)}
+ } + asChild={false} + > + setInputInfo({ ...inputInfo, enable_sso: v })}> + +
+

{t(`${prefixSettings}.sso.description`)}

+
} {!isShowMore &&
setIsShowMore(true)}>
{t(`${prefixSettings}.more.entry`)}
diff --git a/web/app/components/app/store.ts b/web/app/components/app/store.ts index a89b96d65d..0209102372 100644 --- a/web/app/components/app/store.ts +++ b/web/app/components/app/store.ts @@ -1,9 +1,9 @@ import { create } from 'zustand' -import type { App } from '@/types/app' +import type { App, AppSSO } from '@/types/app' import type { IChatItem } from '@/app/components/base/chat/chat/type' type State = { - appDetail?: App + appDetail?: App & Partial appSidebarExpand: string currentLogItem?: IChatItem currentLogModalActiveTab: string @@ -13,7 +13,7 @@ type State = { } type Action = { - setAppDetail: (appDetail?: App) => void + setAppDetail: (appDetail?: App & Partial) => void setAppSiderbarExpand: (state: string) => void setCurrentLogItem: (item?: IChatItem) => void setCurrentLogModalActiveTab: (tab: string) => void diff --git a/web/app/components/app/workflow-log/list.tsx b/web/app/components/app/workflow-log/list.tsx index f4707dce59..5a81c8cd8b 100644 --- a/web/app/components/app/workflow-log/list.tsx +++ b/web/app/components/app/workflow-log/list.tsx @@ -91,7 +91,7 @@ const WorkflowAppLogList: FC = ({ logs, appDetail, onRefresh }) => { {logs.data.map((log: WorkflowAppLogDetail) => { - const endUser = log.created_by_end_user ? log.created_by_end_user.session_id : defaultValue + const endUser = log.created_by_end_user ? log.created_by_end_user.session_id : log.created_by_account ? log.created_by_account.name : defaultValue return = ({ + iconType, + icon, + background, + imageUrl, +}) => { + const wrapperClassName = classNames( + 'flex', + 'items-center', + 'justify-center', + 'w-full', + 'h-full', + 'rounded-full', + 'border-[0.5px]', + 'border-black/5', + 'text-xl', + ) + const isValidImageIcon = iconType === 'image' && imageUrl + return
+ {isValidImageIcon + ? answer icon + : (icon && icon !== '') ? : + } +
+} + +export default AnswerIcon diff --git a/web/app/components/base/audio-btn/index.tsx b/web/app/components/base/audio-btn/index.tsx index 675f58b530..d57c79b571 100644 --- a/web/app/components/base/audio-btn/index.tsx +++ b/web/app/components/base/audio-btn/index.tsx @@ -83,25 +83,25 @@ const AudioBtn = ({ }[audioState] return ( -
+
diff --git a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx index 8eda66c52a..1129b18a8d 100644 --- a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx +++ b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx @@ -13,6 +13,7 @@ import { getUrl, stopChatMessageResponding, } from '@/service/share' +import AnswerIcon from '@/app/components/base/answer-icon' const ChatWrapper = () => { const { @@ -128,6 +129,15 @@ const ChatWrapper = () => { isMobile, ]) + const answerIcon = (appData?.site && appData.site.use_icon_as_answer_icon) + ? + : null + return ( { allToolIcons={appMeta?.tool_icons || {}} onFeedback={handleFeedback} suggestedQuestions={suggestedQuestions} + answerIcon={answerIcon} hideProcessDetail themeBuilder={themeBuilder} /> diff --git a/web/app/components/base/chat/chat-with-history/config-panel/index.tsx b/web/app/components/base/chat/chat-with-history/config-panel/index.tsx index 6bee44b8ef..05f253290f 100644 --- a/web/app/components/base/chat/chat-with-history/config-panel/index.tsx +++ b/web/app/components/base/chat/chat-with-history/config-panel/index.tsx @@ -43,9 +43,12 @@ const ConfigPanel = () => { <>
{appData?.site.title}
diff --git a/web/app/components/base/chat/chat-with-history/hooks.tsx b/web/app/components/base/chat/chat-with-history/hooks.tsx index 624cc53a18..fe952efc27 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.tsx @@ -65,6 +65,7 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { prompt_public: false, copyright: '', show_workflow_steps: true, + use_icon_as_answer_icon: app.use_icon_as_answer_icon, }, plan: 'basic', } as AppData diff --git a/web/app/components/base/chat/chat/answer/index.tsx b/web/app/components/base/chat/chat/answer/index.tsx index 78a0842595..270cd553c2 100644 --- a/web/app/components/base/chat/chat/answer/index.tsx +++ b/web/app/components/base/chat/chat/answer/index.tsx @@ -22,6 +22,7 @@ import Citation from '@/app/components/base/chat/chat/citation' import { EditTitle } from '@/app/components/app/annotation/edit-annotation-modal/edit-item' import type { Emoji } from '@/app/components/tools/types' import type { AppData } from '@/models/share' +import AnswerIcon from '@/app/components/base/answer-icon' type AnswerProps = { item: ChatItem @@ -89,11 +90,7 @@ const Answer: FC = ({
{ - answerIcon || ( -
- 🤖 -
- ) + answerIcon || } { responding && ( diff --git a/web/app/components/base/chat/chat/answer/operation.tsx b/web/app/components/base/chat/chat/answer/operation.tsx index 8ec5c0f3b2..f54e736826 100644 --- a/web/app/components/base/chat/chat/answer/operation.tsx +++ b/web/app/components/base/chat/chat/answer/operation.tsx @@ -17,7 +17,7 @@ import { ThumbsDown, ThumbsUp, } from '@/app/components/base/icons/src/vender/line/alertsAndFeedback' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import Log from '@/app/components/base/chat/chat/log' type OperationProps = { @@ -162,28 +162,34 @@ const Operation: FC = ({ { config?.supportFeedback && !localFeedback?.rating && onFeedback && !isOpeningStatement && (
- +
handleFeedback('like')} >
-
- + +
handleFeedback('dislike')} >
-
+
) } { config?.supportFeedback && localFeedback?.rating && onFeedback && !isOpeningStatement && ( - +
= ({ ) }
-
+ ) }
diff --git a/web/app/components/base/chat/chat/chat-input.tsx b/web/app/components/base/chat/chat/chat-input.tsx index 0c083157a1..c4578fab62 100644 --- a/web/app/components/base/chat/chat/chat-input.tsx +++ b/web/app/components/base/chat/chat/chat-input.tsx @@ -17,7 +17,7 @@ import { TransferMethod } from '../types' import { useChatWithHistoryContext } from '../chat-with-history/context' import type { Theme } from '../embedded-chatbot/theme/theme-context' import { CssTransform } from '../embedded-chatbot/theme/utils' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import { ToastContext } from '@/app/components/base/toast' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import VoiceInput from '@/app/components/base/voice-input' @@ -220,7 +220,7 @@ const ChatInput: FC = ({ {isMobile ? sendBtn : ( -
{t('common.operation.send')} Enter
@@ -229,7 +229,7 @@ const ChatInput: FC = ({ } > {sendBtn} -
+ )}
{ diff --git a/web/app/components/base/chat/chat/hooks.ts b/web/app/components/base/chat/chat/hooks.ts index a70b1c2ed0..4caa8116df 100644 --- a/web/app/components/base/chat/chat/hooks.ts +++ b/web/app/components/base/chat/chat/hooks.ts @@ -372,11 +372,16 @@ export const useChat = ( handleUpdateChatList(newChatList) } if (config?.suggested_questions_after_answer?.enabled && !hasStopResponded.current && onGetSuggestedQuestions) { - const { data }: any = await onGetSuggestedQuestions( - responseItem.id, - newAbortController => suggestedQuestionsAbortControllerRef.current = newAbortController, - ) - setSuggestQuestions(data) + try { + const { data }: any = await onGetSuggestedQuestions( + responseItem.id, + newAbortController => suggestedQuestionsAbortControllerRef.current = newAbortController, + ) + setSuggestQuestions(data) + } + catch (e) { + setSuggestQuestions([]) + } } }, onFile(file) { diff --git a/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx b/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx index 68646a17db..4934339e87 100644 --- a/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx +++ b/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx @@ -15,6 +15,7 @@ import { stopChatMessageResponding, } from '@/service/share' import LogoAvatar from '@/app/components/base/logo/logo-embeded-chat-avatar' +import AnswerIcon from '@/app/components/base/answer-icon' const ChatWrapper = () => { const { @@ -114,6 +115,17 @@ const ChatWrapper = () => { return null }, [currentConversationId, inputsForms, isMobile]) + const answerIcon = isDify() + ? + : (appData?.site && appData.site.use_icon_as_answer_icon) + ? + : null + return ( { allToolIcons={appMeta?.tool_icons || {}} onFeedback={handleFeedback} suggestedQuestions={suggestedQuestions} - answerIcon={isDify() ? : null} + answerIcon={answerIcon} hideProcessDetail themeBuilder={themeBuilder} /> diff --git a/web/app/components/base/chat/embedded-chatbot/config-panel/index.tsx b/web/app/components/base/chat/embedded-chatbot/config-panel/index.tsx index 81f57a04ae..df5d12ef14 100644 --- a/web/app/components/base/chat/embedded-chatbot/config-panel/index.tsx +++ b/web/app/components/base/chat/embedded-chatbot/config-panel/index.tsx @@ -48,9 +48,12 @@ const ConfigPanel = () => { <>
{appData?.site.title}
diff --git a/web/app/components/base/chat/embedded-chatbot/header.tsx b/web/app/components/base/chat/embedded-chatbot/header.tsx index c35f98e3f2..a5c74434c6 100644 --- a/web/app/components/base/chat/embedded-chatbot/header.tsx +++ b/web/app/components/base/chat/embedded-chatbot/header.tsx @@ -41,9 +41,7 @@ const Header: FC = ({
{ onCreateNewChat?.() diff --git a/web/app/components/base/chat/embedded-chatbot/index.tsx b/web/app/components/base/chat/embedded-chatbot/index.tsx index d34fe164d1..480adaae2d 100644 --- a/web/app/components/base/chat/embedded-chatbot/index.tsx +++ b/web/app/components/base/chat/embedded-chatbot/index.tsx @@ -88,9 +88,7 @@ const Chatbot = () => { {!isMobile && (
diff --git a/web/app/components/base/copy-btn/index.tsx b/web/app/components/base/copy-btn/index.tsx index a03b991057..2acb5d8e76 100644 --- a/web/app/components/base/copy-btn/index.tsx +++ b/web/app/components/base/copy-btn/index.tsx @@ -1,10 +1,9 @@ 'use client' -import { useRef, useState } from 'react' +import { useState } from 'react' import { t } from 'i18next' import copy from 'copy-to-clipboard' import s from './style.module.css' import Tooltip from '@/app/components/base/tooltip' -import { randomString } from '@/utils' type ICopyBtnProps = { value: string @@ -18,14 +17,11 @@ const CopyBtn = ({ isPlain, }: ICopyBtnProps) => { const [isCopied, setIsCopied] = useState(false) - const selector = useRef(`copy-tooltip-${randomString(4)}`) return (
{ +const CopyFeedback = ({ content, className }: Props) => { const { t } = useTranslation() const [isCopied, setIsCopied] = useState(false) @@ -30,8 +28,7 @@ const CopyFeedback = ({ content, selectorId, className }: Props) => { return ( { className={`w-8 h-8 cursor-pointer hover:bg-gray-100 rounded-lg ${ className ?? '' }`} - onMouseLeave={onMouseLeave} >
- +
) } diff --git a/web/app/components/base/copy-icon/index.tsx b/web/app/components/base/copy-icon/index.tsx index 0eb0356ffe..425a9ad293 100644 --- a/web/app/components/base/copy-icon/index.tsx +++ b/web/app/components/base/copy-icon/index.tsx @@ -3,7 +3,7 @@ import React, { useState } from 'react' import { useTranslation } from 'react-i18next' import { debounce } from 'lodash-es' import copy from 'copy-to-clipboard' -import TooltipPlus from '../tooltip-plus' +import Tooltip from '../tooltip' import { Clipboard, ClipboardCheck, @@ -29,7 +29,7 @@ export const CopyIcon = ({ content }: Props) => { }, 100) return ( - { ) }
- +
) } diff --git a/web/app/components/base/features/feature-panel/suggested-questions-after-answer/index.tsx b/web/app/components/base/features/feature-panel/suggested-questions-after-answer/index.tsx index e424c4ead5..e6d0b6e7e0 100644 --- a/web/app/components/base/features/feature-panel/suggested-questions-after-answer/index.tsx +++ b/web/app/components/base/features/feature-panel/suggested-questions-after-answer/index.tsx @@ -2,11 +2,8 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import { - RiQuestionLine, -} from '@remixicon/react' import { MessageSmileSquare } from '@/app/components/base/icons/src/vender/solid/communication' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' const SuggestedQuestionsAfterAnswer: FC = () => { const { t } = useTranslation() @@ -18,9 +15,7 @@ const SuggestedQuestionsAfterAnswer: FC = () => {
{t('appDebug.feature.suggestedQuestionsAfterAnswer.title')}
- - - +
{t('appDebug.feature.suggestedQuestionsAfterAnswer.resDes')}
diff --git a/web/app/components/base/features/feature-panel/text-to-speech/param-config-content.tsx b/web/app/components/base/features/feature-panel/text-to-speech/param-config-content.tsx index a5a2eb7bb7..e923d9a333 100644 --- a/web/app/components/base/features/feature-panel/text-to-speech/param-config-content.tsx +++ b/web/app/components/base/features/feature-panel/text-to-speech/param-config-content.tsx @@ -2,9 +2,6 @@ import useSWR from 'swr' import produce from 'immer' import React, { Fragment } from 'react' -import { - RiQuestionLine, -} from '@remixicon/react' import { usePathname } from 'next/navigation' import { useTranslation } from 'react-i18next' import { Listbox, Transition } from '@headlessui/react' @@ -74,13 +71,16 @@ const VoiceParamConfig = ({
{t('appDebug.voice.voiceSettings.language')}
- - {t('appDebug.voice.voiceSettings.resolutionTooltip').split('\n').map(item => ( -
{item}
- ))} -
} selector='config-resolution-tooltip'> - - + + {t('appDebug.voice.voiceSettings.resolutionTooltip').split('\n').map(item => ( +
{item} +
+ ))} +
+ } + />
= ({ )} {item.progress === -1 && ( - - + )}
)} diff --git a/web/app/components/base/input/index.tsx b/web/app/components/base/input/index.tsx index 5ab8249446..cac5efd879 100644 --- a/web/app/components/base/input/index.tsx +++ b/web/app/components/base/input/index.tsx @@ -1,43 +1,83 @@ -'use client' -import type { SVGProps } from 'react' -import React, { useState } from 'react' +import type { CSSProperties } from 'react' +import React from 'react' import { useTranslation } from 'react-i18next' -import cn from 'classnames' +import { RiCloseCircleFill, RiErrorWarningLine, RiSearchLine } from '@remixicon/react' +import { type VariantProps, cva } from 'class-variance-authority' +import cn from '@/utils/classnames' -type InputProps = { - placeholder?: string - value?: string - defaultValue?: string - onChange?: (v: string) => void - className?: string - wrapperClassName?: string - type?: string - showPrefix?: React.ReactNode - prefixIcon?: React.ReactNode -} - -const GlassIcon = ({ className }: SVGProps) => ( - - - +export const inputVariants = cva( + '', + { + variants: { + size: { + regular: 'px-3 radius-md system-sm-regular', + large: 'px-4 radius-lg system-md-regular', + }, + }, + defaultVariants: { + size: 'regular', + }, + }, ) -const Input = ({ value, defaultValue, onChange, className = '', wrapperClassName = '', placeholder, type, showPrefix, prefixIcon }: InputProps) => { - const [localValue, setLocalValue] = useState(value ?? defaultValue) +export type InputProps = { + showLeftIcon?: boolean + showClearIcon?: boolean + onClear?: () => void + disabled?: boolean + destructive?: boolean + wrapperClassName?: string + styleCss?: CSSProperties +} & React.InputHTMLAttributes & VariantProps + +const Input = ({ + size, + disabled, + destructive, + showLeftIcon, + showClearIcon, + onClear, + wrapperClassName, + className, + styleCss, + value, + placeholder, + onChange, + ...props +}: InputProps) => { const { t } = useTranslation() return ( -
- {showPrefix && {prefixIcon ?? }} +
+ {showLeftIcon && } { - setLocalValue(e.target.value) - onChange && onChange(e.target.value) - }} + style={styleCss} + className={cn( + 'w-full py-[7px] bg-components-input-bg-normal border border-transparent text-components-input-text-filled hover:bg-components-input-bg-hover hover:border-components-input-border-hover focus:bg-components-input-bg-active focus:border-components-input-border-active focus:shadow-xs placeholder:text-components-input-text-placeholder appearance-none outline-none caret-primary-600', + inputVariants({ size }), + showLeftIcon && 'pl-[26px]', + showLeftIcon && size === 'large' && 'pl-7', + showClearIcon && value && 'pr-[26px]', + showClearIcon && value && size === 'large' && 'pr-7', + destructive && 'pr-[26px]', + destructive && size === 'large' && 'pr-7', + disabled && 'bg-components-input-bg-disabled border-transparent text-components-input-text-filled-disabled cursor-not-allowed hover:bg-components-input-bg-disabled hover:border-transparent', + destructive && 'bg-components-input-bg-destructive border-components-input-border-destructive text-components-input-text-filled hover:bg-components-input-bg-destructive hover:border-components-input-border-destructive focus:bg-components-input-bg-destructive focus:border-components-input-border-destructive', + className, + )} + placeholder={placeholder ?? (showLeftIcon ? t('common.operation.search') ?? '' : 'please input')} + value={value} + onChange={onChange} + disabled={disabled} + {...props} /> + {showClearIcon && value && !disabled && !destructive && ( +
+ +
+ )} + {destructive && ( + + )}
) } diff --git a/web/app/components/base/markdown.tsx b/web/app/components/base/markdown.tsx index af4b13ff70..6e8ae6c9e6 100644 --- a/web/app/components/base/markdown.tsx +++ b/web/app/components/base/markdown.tsx @@ -8,7 +8,7 @@ import RemarkGfm from 'remark-gfm' import SyntaxHighlighter from 'react-syntax-highlighter' import { atelierHeathLight } from 'react-syntax-highlighter/dist/esm/styles/hljs' import type { RefObject } from 'react' -import { memo, useEffect, useMemo, useRef, useState } from 'react' +import { Component, memo, useEffect, useMemo, useRef, useState } from 'react' import type { CodeComponent } from 'react-markdown/lib/ast-to-react' import cn from '@/utils/classnames' import CopyBtn from '@/app/components/base/copy-btn' @@ -104,7 +104,7 @@ const CodeBlock: CodeComponent = memo(({ inline, className, children, ...props } const match = /language-(\w+)/.exec(className || '') const language = match?.[1] const languageShowName = getCorrectCapitalizationLanguageName(language || '') - let chartData = JSON.parse(String('{"title":{"text":"Something went wrong."}}').replace(/\n$/, '')) + let chartData = JSON.parse(String('{"title":{"text":"ECharts error - Wrong JSON format."}}').replace(/\n$/, '')) if (language === 'echarts') { try { chartData = JSON.parse(String(children).replace(/\n$/, '')) @@ -143,10 +143,10 @@ const CodeBlock: CodeComponent = memo(({ inline, className, children, ...props } ? () : ( (language === 'echarts') - ? (
-
) +
) : ( ) } + +// **Add an ECharts runtime error handler +// Avoid error #7832 (Crash when ECharts accesses undefined objects) +// This can happen when a component attempts to access an undefined object that references an unregistered map, causing the program to crash. + +export default class ErrorBoundary extends Component { + constructor(props) { + super(props) + this.state = { hasError: false } + } + + componentDidCatch(error, errorInfo) { + this.setState({ hasError: true }) + console.error(error, errorInfo) + } + + render() { + if (this.state.hasError) + return
Oops! ECharts reported a runtime error.
(see the browser console for more information)
+ return this.props.children + } +} diff --git a/web/app/components/base/param-item/index.tsx b/web/app/components/base/param-item/index.tsx index 20beb7a9a1..318b0fb0e3 100644 --- a/web/app/components/base/param-item/index.tsx +++ b/web/app/components/base/param-item/index.tsx @@ -1,10 +1,6 @@ 'use client' import type { FC } from 'react' -import { - RiQuestionLine, -} from '@remixicon/react' - -import Tooltip from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import Slider from '@/app/components/base/slider' import Switch from '@/app/components/base/switch' @@ -40,9 +36,9 @@ const ParamItem: FC = ({ className, id, name, noTooltip, tip, step = 0.1, )} {name} {!noTooltip && ( - {tip}
}> - - + {tip}
} + /> )}
diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx index 39193fc31d..65f3dad3a2 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx @@ -25,7 +25,7 @@ import { BubbleX, Env } from '@/app/components/base/icons/src/vender/line/others import { VarBlockIcon } from '@/app/components/workflow/block-icon' import { Line3 } from '@/app/components/base/icons/src/public/common' import { isConversationVar, isENV, isSystemVar } from '@/app/components/workflow/nodes/_base/components/variable/utils' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' type WorkflowVariableBlockComponentProps = { nodeKey: string @@ -113,9 +113,9 @@ const WorkflowVariableBlockComponent = ({ if (!node && !isEnv && !isChatVar) { return ( - + {Item} - + ) } diff --git a/web/app/components/base/qrcode/index.tsx b/web/app/components/base/qrcode/index.tsx index 721f8c9029..c9323992e9 100644 --- a/web/app/components/base/qrcode/index.tsx +++ b/web/app/components/base/qrcode/index.tsx @@ -2,8 +2,8 @@ import React, { useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import QRCode from 'qrcode.react' -import Tooltip from '../tooltip' import QrcodeStyle from './style.module.css' +import Tooltip from '@/app/components/base/tooltip' type Props = { content: string @@ -51,8 +51,7 @@ const ShareQRCode = ({ content, selectorId, className }: Props) => { return (
+ +type Props = { + order?: string + value: number | string + items: Item[] + onSelect: (item: any) => void +} +const Sort: FC = ({ + order, + value, + items, + onSelect, +}) => { + const { t } = useTranslation() + const [open, setOpen] = useState(false) + + const triggerContent = useMemo(() => { + return items.find(item => item.value === value)?.name || '' + }, [items, value]) + + return ( +
+ +
+ setOpen(v => !v)} + className='block' + > +
+
+
{t('appLog.filter.sortBy')}
+
+ {triggerContent} +
+
+ +
+
+ +
+
+ {items.map(item => ( +
{ + onSelect(`${order}${item.value}`) + setOpen(false) + }} + > +
{item.name}
+ {value === item.value && } +
+ ))} +
+
+
+
+
+
onSelect(`${order ? '' : '-'}${value}`)}> + {!order && } + {order && } +
+
+ + ) +} + +export default Sort diff --git a/web/app/components/base/tag-management/selector.tsx b/web/app/components/base/tag-management/selector.tsx index 74e0357064..fd271a82e8 100644 --- a/web/app/components/base/tag-management/selector.tsx +++ b/web/app/components/base/tag-management/selector.tsx @@ -68,6 +68,7 @@ const Panel = (props: PanelProps) => { ...tagList, newTag, ]) + setKeywords('') setCreating(false) onCreate() } @@ -123,11 +124,8 @@ const Panel = (props: PanelProps) => { handleValueChange() }) - const onMouseLeave = async () => { - props.onClose?.() - } return ( -
+
diff --git a/web/app/components/base/tag-management/tag-item-editor.tsx b/web/app/components/base/tag-management/tag-item-editor.tsx index f20e61a43c..3735695302 100644 --- a/web/app/components/base/tag-management/tag-item-editor.tsx +++ b/web/app/components/base/tag-management/tag-item-editor.tsx @@ -8,7 +8,7 @@ import { useDebounceFn } from 'ahooks' import { useContext } from 'use-context-selector' import { useTranslation } from 'react-i18next' import { useStore as useTagStore } from './store' -import TagRemoveModal from './tag-remove-modal' +import Confirm from '@/app/components/base/confirm' import cn from '@/utils/classnames' import type { Tag } from '@/app/components/base/tag-management/constant' import { ToastContext } from '@/app/components/base/toast' @@ -134,14 +134,15 @@ const TagItemEditor: FC = ({ /> )}
- { handleRemove() setShowRemoveModal(false) }} - onClose={() => setShowRemoveModal(false)} + onCancel={() => setShowRemoveModal(false)} /> ) diff --git a/web/app/components/base/tooltip-plus/index.tsx b/web/app/components/base/tooltip-plus/index.tsx deleted file mode 100644 index 1f8a091fa5..0000000000 --- a/web/app/components/base/tooltip-plus/index.tsx +++ /dev/null @@ -1,109 +0,0 @@ -'use client' -import type { FC } from 'react' -import React, { useEffect, useRef, useState } from 'react' -import { useBoolean } from 'ahooks' -import type { OffsetOptions, Placement } from '@floating-ui/react' -import cn from '@/utils/classnames' -import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem' -export type TooltipProps = { - position?: Placement - triggerMethod?: 'hover' | 'click' - disabled?: boolean - popupContent: React.ReactNode - children: React.ReactNode - hideArrow?: boolean - popupClassName?: string - offset?: OffsetOptions - asChild?: boolean -} - -const arrow = ( - -) - -const Tooltip: FC = ({ - position = 'top', - triggerMethod = 'hover', - disabled = false, - popupContent, - children, - hideArrow, - popupClassName, - offset, - asChild, -}) => { - const [open, setOpen] = useState(false) - const [isHoverPopup, { - setTrue: setHoverPopup, - setFalse: setNotHoverPopup, - }] = useBoolean(false) - - const isHoverPopupRef = useRef(isHoverPopup) - useEffect(() => { - isHoverPopupRef.current = isHoverPopup - }, [isHoverPopup]) - - const [isHoverTrigger, { - setTrue: setHoverTrigger, - setFalse: setNotHoverTrigger, - }] = useBoolean(false) - - const isHoverTriggerRef = useRef(isHoverTrigger) - useEffect(() => { - isHoverTriggerRef.current = isHoverTrigger - }, [isHoverTrigger]) - - const handleLeave = (isTrigger: boolean) => { - if (isTrigger) - setNotHoverTrigger() - - else - setNotHoverPopup() - - // give time to move to the popup - setTimeout(() => { - if (!isHoverPopupRef.current && !isHoverTriggerRef.current) - setOpen(false) - }, 500) - } - - return ( - - triggerMethod === 'click' && setOpen(v => !v)} - onMouseEnter={() => { - if (triggerMethod === 'hover') { - setHoverTrigger() - setOpen(true) - } - }} - onMouseLeave={() => triggerMethod === 'hover' && handleLeave(true)} - asChild={asChild} - > - {children} - - -
triggerMethod === 'hover' && setHoverPopup()} - onMouseLeave={() => triggerMethod === 'hover' && handleLeave(false)} - > - {popupContent} - {!hideArrow && arrow} -
-
-
- ) -} - -export default React.memo(Tooltip) diff --git a/web/app/components/base/tooltip/index.tsx b/web/app/components/base/tooltip/index.tsx index e7795c6537..f3b4cff132 100644 --- a/web/app/components/base/tooltip/index.tsx +++ b/web/app/components/base/tooltip/index.tsx @@ -1,52 +1,112 @@ 'use client' import type { FC } from 'react' -import React from 'react' -import { Tooltip as ReactTooltip } from 'react-tooltip' // fixed version to 5.8.3 https://github.com/ReactTooltip/react-tooltip/issues/972 -import classNames from '@/utils/classnames' -import 'react-tooltip/dist/react-tooltip.css' - -type TooltipProps = { - selector: string - content?: string +import React, { useEffect, useRef, useState } from 'react' +import { useBoolean } from 'ahooks' +import type { OffsetOptions, Placement } from '@floating-ui/react' +import { RiQuestionLine } from '@remixicon/react' +import cn from '@/utils/classnames' +import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem' +export type TooltipProps = { + position?: Placement + triggerMethod?: 'hover' | 'click' + triggerClassName?: string disabled?: boolean - htmlContent?: React.ReactNode - className?: string // This should use !impornant to override the default styles eg: '!bg-white' - position?: 'top' | 'right' | 'bottom' | 'left' - clickable?: boolean - children: React.ReactNode - noArrow?: boolean + popupContent?: React.ReactNode + children?: React.ReactNode + popupClassName?: string + offset?: OffsetOptions + needsDelay?: boolean + asChild?: boolean } const Tooltip: FC = ({ - selector, - content, - disabled, position = 'top', + triggerMethod = 'hover', + triggerClassName, + disabled = false, + popupContent, children, - htmlContent, - className, - clickable, - noArrow, + popupClassName, + offset, + asChild = true, + needsDelay = false, }) => { + const [open, setOpen] = useState(false) + const [isHoverPopup, { + setTrue: setHoverPopup, + setFalse: setNotHoverPopup, + }] = useBoolean(false) + + const isHoverPopupRef = useRef(isHoverPopup) + useEffect(() => { + isHoverPopupRef.current = isHoverPopup + }, [isHoverPopup]) + + const [isHoverTrigger, { + setTrue: setHoverTrigger, + setFalse: setNotHoverTrigger, + }] = useBoolean(false) + + const isHoverTriggerRef = useRef(isHoverTrigger) + useEffect(() => { + isHoverTriggerRef.current = isHoverTrigger + }, [isHoverTrigger]) + + const handleLeave = (isTrigger: boolean) => { + if (isTrigger) + setNotHoverTrigger() + + else + setNotHoverPopup() + + // give time to move to the popup + if (needsDelay) { + setTimeout(() => { + if (!isHoverPopupRef.current && !isHoverTriggerRef.current) + setOpen(false) + }, 500) + } + else { + setOpen(false) + } + } + return ( -
- {React.cloneElement(children as React.ReactElement, { - 'data-tooltip-id': selector, - }) - } - + triggerMethod === 'click' && setOpen(v => !v)} + onMouseEnter={() => { + if (triggerMethod === 'hover') { + setHoverTrigger() + setOpen(true) + } + }} + onMouseLeave={() => triggerMethod === 'hover' && handleLeave(true)} + asChild={asChild} > - {htmlContent && htmlContent} - -
+ {children ||
} + + + {popupContent && (
triggerMethod === 'hover' && setHoverPopup()} + onMouseLeave={() => triggerMethod === 'hover' && handleLeave(false)} + > + {popupContent} +
)} +
+ ) } -export default Tooltip +export default React.memo(Tooltip) diff --git a/web/app/components/billing/pricing/plan-item.tsx b/web/app/components/billing/pricing/plan-item.tsx index 87a20437c3..b6ac17472e 100644 --- a/web/app/components/billing/pricing/plan-item.tsx +++ b/web/app/components/billing/pricing/plan-item.tsx @@ -2,14 +2,11 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import { - RiQuestionLine, -} from '@remixicon/react' import { useContext } from 'use-context-selector' import { Plan } from '../type' import { ALL_PLANS, NUM_INFINITE, contactSalesUrl, contractSales, unAvailable } from '../config' import Toast from '../../base/toast' -import TooltipPlus from '../../base/tooltip-plus' +import Tooltip from '../../base/tooltip' import { PlanRange } from './select-plan-range' import cn from '@/utils/classnames' import { useAppContext } from '@/context/app-context' @@ -30,13 +27,11 @@ const KeyValue = ({ label, value, tooltip }: { label: string; value: string | nu
{label}
{tooltip && ( - {tooltip}
} - > - - + /> )}
{value}
@@ -136,25 +131,21 @@ const PlanItem: FC = ({
+
{t('billing.plansCommon.supportItems.llmLoadingBalancing')}
- {t('billing.plansCommon.supportItems.llmLoadingBalancingTooltip')}
} - > - - + />
+
 {t('billing.plansCommon.supportItems.ragAPIRequest')}
- {t('billing.plansCommon.ragAPIRequestTooltip')}
} - > - - + />
{comingSoon}
diff --git a/web/app/components/billing/priority-label/index.tsx b/web/app/components/billing/priority-label/index.tsx index d8ad27b6e0..36338cf4a8 100644 --- a/web/app/components/billing/priority-label/index.tsx +++ b/web/app/components/billing/priority-label/index.tsx @@ -9,7 +9,7 @@ import { ZapFast, ZapNarrow, } from '@/app/components/base/icons/src/vender/solid/general' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' const PriorityLabel = () => { const { t } = useTranslation() @@ -27,7 +27,7 @@ const PriorityLabel = () => { }, [plan]) return ( -
{`${t('billing.plansCommon.documentProcessingPriority')}: ${t(`billing.plansCommon.priority.${priority}`)}`}
{ @@ -53,7 +53,7 @@ const PriorityLabel = () => { } {t(`billing.plansCommon.priority.${priority}`)} -
+
) } diff --git a/web/app/components/billing/usage-info/index.tsx b/web/app/components/billing/usage-info/index.tsx index e929249584..ee41508ea6 100644 --- a/web/app/components/billing/usage-info/index.tsx +++ b/web/app/components/billing/usage-info/index.tsx @@ -2,7 +2,6 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import { InfoCircle } from '../../base/icons/src/vender/line/general' import ProgressBar from '../progress-bar' import { NUM_INFINITE } from '../config' import Tooltip from '@/app/components/base/tooltip' @@ -48,11 +47,13 @@ const UsageInfo: FC = ({
{name}
{tooltip && ( - - {tooltip} -
} selector='config-var-tooltip'> - - + + {tooltip} +
+ } + /> )}
diff --git a/web/app/components/custom/custom-app-header-brand/index.tsx b/web/app/components/custom/custom-app-header-brand/index.tsx deleted file mode 100644 index 9564986c28..0000000000 --- a/web/app/components/custom/custom-app-header-brand/index.tsx +++ /dev/null @@ -1,62 +0,0 @@ -import { useTranslation } from 'react-i18next' -import s from './style.module.css' -import Button from '@/app/components/base/button' -import { Grid01 } from '@/app/components/base/icons/src/vender/solid/layout' -import { Container, Database01 } from '@/app/components/base/icons/src/vender/line/development' -import { ImagePlus } from '@/app/components/base/icons/src/vender/line/images' -import { useProviderContext } from '@/context/provider-context' -import { Plan } from '@/app/components/billing/type' - -const CustomAppHeaderBrand = () => { - const { t } = useTranslation() - const { plan } = useProviderContext() - - return ( -
-
{t('custom.app.title')}
-
-
-
-
-
-
YOUR LOGO
-
-
-
-
-
-
- -
-
-
- -
-
-
- -
-
-
-
-
-
- -
- -
-
{t('custom.app.changeLogoTip')}
-
- ) -} - -export default CustomAppHeaderBrand diff --git a/web/app/components/custom/custom-app-header-brand/style.module.css b/web/app/components/custom/custom-app-header-brand/style.module.css deleted file mode 100644 index 492733ff9f..0000000000 --- a/web/app/components/custom/custom-app-header-brand/style.module.css +++ /dev/null @@ -1,3 +0,0 @@ -.mask { - background: linear-gradient(95deg, rgba(255, 255, 255, 0.00) 43.9%, rgba(255, 255, 255, 0.80) 95.76%); ; -} \ No newline at end of file diff --git a/web/app/components/custom/custom-page/index.tsx b/web/app/components/custom/custom-page/index.tsx index c3b1e93da3..75d592389d 100644 --- a/web/app/components/custom/custom-page/index.tsx +++ b/web/app/components/custom/custom-page/index.tsx @@ -1,6 +1,5 @@ import { useTranslation } from 'react-i18next' import CustomWebAppBrand from '../custom-web-app-brand' -import CustomAppHeaderBrand from '../custom-app-header-brand' import s from '../style.module.css' import GridMask from '@/app/components/base/grid-mask' import UpgradeBtn from '@/app/components/billing/upgrade-btn' @@ -13,7 +12,6 @@ const CustomPage = () => { const { plan, enableBilling } = useProviderContext() const showBillingTip = enableBilling && plan.type === Plan.sandbox - const showCustomAppHeaderBrand = enableBilling && plan.type === Plan.sandbox const showContact = enableBilling && (plan.type === Plan.professional || plan.type === Plan.team) return ( @@ -32,14 +30,6 @@ const CustomPage = () => { ) } - { - showCustomAppHeaderBrand && ( - <> -
- - - ) - } { showContact && (
diff --git a/web/app/components/datasets/common/retrieval-method-config/index.tsx b/web/app/components/datasets/common/retrieval-method-config/index.tsx index 1e407b62e1..20d93568ad 100644 --- a/web/app/components/datasets/common/retrieval-method-config/index.tsx +++ b/web/app/components/datasets/common/retrieval-method-config/index.tsx @@ -11,6 +11,11 @@ import { FileSearch02 } from '@/app/components/base/icons/src/vender/solid/files import { useProviderContext } from '@/context/provider-context' import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { + DEFAULT_WEIGHTED_SCORE, + RerankingModeEnum, + WeightedScoreEnum, +} from '@/models/datasets' type Props = { value: RetrievalConfig @@ -32,6 +37,18 @@ const RetrievalMethodConfig: FC = ({ reranking_provider_name: rerankDefaultModel?.provider.provider || '', reranking_model_name: rerankDefaultModel?.model || '', }, + reranking_mode: passValue.reranking_mode || (rerankDefaultModel ? RerankingModeEnum.RerankingModel : RerankingModeEnum.WeightedScore), + weights: passValue.weights || { + weight_type: WeightedScoreEnum.Customized, + vector_setting: { + vector_weight: DEFAULT_WEIGHTED_SCORE.other.semantic, + embedding_provider_name: '', + embedding_model_name: '', + }, + keyword_setting: { + keyword_weight: DEFAULT_WEIGHTED_SCORE.other.keyword, + }, + }, } } return passValue diff --git a/web/app/components/datasets/common/retrieval-param-config/index.tsx b/web/app/components/datasets/common/retrieval-param-config/index.tsx index 98676f2e83..323e47f3b4 100644 --- a/web/app/components/datasets/common/retrieval-param-config/index.tsx +++ b/web/app/components/datasets/common/retrieval-param-config/index.tsx @@ -2,15 +2,13 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import { - RiQuestionLine, -} from '@remixicon/react' + import cn from '@/utils/classnames' import TopKItem from '@/app/components/base/param-item/top-k-item' import ScoreThresholdItem from '@/app/components/base/param-item/score-threshold-item' import { RETRIEVE_METHOD } from '@/types/app' import Switch from '@/app/components/base/switch' -import Tooltip from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import type { RetrievalConfig } from '@/types/app' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' import { useModelListAndDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' @@ -114,9 +112,11 @@ const RetrievalParamConfig: FC = ({ )}
{t('common.modelProvider.rerankModel.key')} - {t('common.modelProvider.rerankModel.tip')}
}> - - + {t('common.modelProvider.rerankModel.tip')}
+ } + />
= ({
{option.label}
{option.tips}
} - hideArrow - > - - + triggerClassName='ml-0.5 w-3.5 h-3.5' + />
)) } diff --git a/web/app/components/datasets/create/embedding-process/index.tsx b/web/app/components/datasets/create/embedding-process/index.tsx index 1e340d692f..7786582085 100644 --- a/web/app/components/datasets/create/embedding-process/index.tsx +++ b/web/app/components/datasets/create/embedding-process/index.tsx @@ -13,8 +13,7 @@ import cn from '@/utils/classnames' import { FieldInfo } from '@/app/components/datasets/documents/detail/metadata' import Button from '@/app/components/base/button' import type { FullDocumentDetail, IndexingStatusResponse, ProcessRuleResponse } from '@/models/datasets' -import { formatNumber } from '@/utils/format' -import { fetchIndexingStatusBatch as doFetchIndexingStatus, fetchIndexingEstimateBatch, fetchProcessRule } from '@/service/datasets' +import { fetchIndexingStatusBatch as doFetchIndexingStatus, fetchProcessRule } from '@/service/datasets' import { DataSourceType } from '@/models/datasets' import NotionIcon from '@/app/components/base/notion-icon' import PriorityLabel from '@/app/components/billing/priority-label' @@ -22,7 +21,7 @@ import { Plan } from '@/app/components/billing/type' import { ZapFast } from '@/app/components/base/icons/src/vender/solid/general' import UpgradeBtn from '@/app/components/billing/upgrade-btn' import { useProviderContext } from '@/context/provider-context' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import { sleep } from '@/utils' type Props = { @@ -142,14 +141,6 @@ const EmbeddingProcess: FC = ({ datasetId, batchId, documents = [], index }, apiParams => fetchProcessRule(omit(apiParams, 'action')), { revalidateOnFocus: false, }) - // get cost - const { data: indexingEstimateDetail } = useSWR({ - action: 'fetchIndexingEstimateBatch', - datasetId, - batchId, - }, apiParams => fetchIndexingEstimateBatch(omit(apiParams, 'action')), { - revalidateOnFocus: false, - }) const router = useRouter() const navToDocumentList = () => { @@ -190,28 +181,11 @@ const EmbeddingProcess: FC = ({ datasetId, batchId, documents = [], index return ( <> -
+
{isEmbedding && t('datasetDocuments.embedding.processing')} {isEmbeddingCompleted && t('datasetDocuments.embedding.completed')}
-
- {indexingType === 'high_quality' && ( -
-
- {t('datasetDocuments.embedding.highQuality')} · {t('datasetDocuments.embedding.estimate')} - {formatNumber(indexingEstimateDetail?.tokens || 0)}tokens - (${formatNumber(indexingEstimateDetail?.total_price || 0)}) -
- )} - {indexingType === 'economy' && ( -
-
- {t('datasetDocuments.embedding.economy')} · {t('datasetDocuments.embedding.estimate')} - 0tokens -
- )} -
{ enableBilling && plan.type !== Plan.team && ( @@ -259,16 +233,18 @@ const EmbeddingProcess: FC = ({ datasetId, batchId, documents = [], index
{`${getSourcePercent(indexingStatusDetail)}%`}
)} {indexingStatusDetail.indexing_status === 'error' && indexingStatusDetail.error && ( - - {indexingStatusDetail.error} -
- )}> + + {indexingStatusDetail.error} +
+ )} + >
Error
- + )} {indexingStatusDetail.indexing_status === 'error' && !indexingStatusDetail.error && (
diff --git a/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx b/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx index e9247c49df..c490d50b88 100644 --- a/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx +++ b/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx @@ -58,7 +58,7 @@ const EmptyDatasetCreationModal = ({
{t('datasetCreation.stepOne.modal.tip')}
{t('datasetCreation.stepOne.modal.input')}
- + setInputValue(e.target.value)} />
diff --git a/web/app/components/datasets/create/step-two/index.module.css b/web/app/components/datasets/create/step-two/index.module.css index 24a62c8e3c..f89d6d67ea 100644 --- a/web/app/components/datasets/create/step-two/index.module.css +++ b/web/app/components/datasets/create/step-two/index.module.css @@ -30,7 +30,7 @@ } .indexItem { - min-height: 146px; + min-height: 126px; } .indexItem .disableMask { @@ -121,10 +121,6 @@ @apply pb-1; } -.radioItem.indexItem .typeHeader .tip { - @apply pb-3; -} - .radioItem .typeIcon { position: absolute; top: 18px; @@ -264,7 +260,7 @@ } .input { - @apply inline-flex h-9 w-full py-1 px-2 rounded-lg text-xs leading-normal; + @apply inline-flex h-9 w-full py-1 px-2 pr-14 rounded-lg text-xs leading-normal; @apply bg-gray-100 caret-primary-600 hover:bg-gray-100 focus:ring-1 focus:ring-inset focus:ring-gray-200 focus-visible:outline-none focus:bg-white placeholder:text-gray-400; } diff --git a/web/app/components/datasets/create/step-two/index.tsx b/web/app/components/datasets/create/step-two/index.tsx index 3849f817d6..b25ca71311 100644 --- a/web/app/components/datasets/create/step-two/index.tsx +++ b/web/app/components/datasets/create/step-two/index.tsx @@ -7,16 +7,14 @@ import { XMarkIcon } from '@heroicons/react/20/solid' import { RocketLaunchIcon } from '@heroicons/react/24/outline' import { RiCloseLine, - RiQuestionLine, } from '@remixicon/react' import Link from 'next/link' import { groupBy } from 'lodash-es' -import RetrievalMethodInfo from '../../common/retrieval-method-info' import PreviewItem, { PreviewType } from './preview-item' import LanguageSelect from './language-select' import s from './index.module.css' import cn from '@/utils/classnames' -import type { CrawlOptions, CrawlResultItem, CreateDocumentReq, CustomFile, FileIndexingEstimateResponse, FullDocumentDetail, IndexingEstimateParams, IndexingEstimateResponse, NotionInfo, PreProcessingRule, ProcessRule, Rules, createDocumentResponse } from '@/models/datasets' +import type { CrawlOptions, CrawlResultItem, CreateDocumentReq, CustomFile, FileIndexingEstimateResponse, FullDocumentDetail, IndexingEstimateParams, NotionInfo, PreProcessingRule, ProcessRule, Rules, createDocumentResponse } from '@/models/datasets' import { createDocument, createFirstDocument, @@ -43,9 +41,10 @@ import { IS_CE_EDITION } from '@/config' import { RETRIEVE_METHOD } from '@/types/app' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import Tooltip from '@/app/components/base/tooltip' -import TooltipPlus from '@/app/components/base/tooltip-plus' -import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { useDefaultModel, useModelList, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { LanguagesSupported } from '@/i18n/language' +import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' +import type { DefaultModel } from '@/app/components/header/account-setting/model-provider-page/declarations' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { Globe01 } from '@/app/components/base/icons/src/vender/line/mapsAndTravel' @@ -112,7 +111,7 @@ const StepTwo = ({ const [previewScrolled, setPreviewScrolled] = useState(false) const [segmentationType, setSegmentationType] = useState(SegmentType.AUTO) const [segmentIdentifier, setSegmentIdentifier] = useState('\\n') - const [max, setMax] = useState(500) + const [max, setMax] = useState(5000) // default chunk length const [overlap, setOverlap] = useState(50) const [rules, setRules] = useState([]) const [defaultConfig, setDefaultConfig] = useState() @@ -126,13 +125,14 @@ const StepTwo = ({ const [docForm, setDocForm] = useState( (datasetId && documentDetail) ? documentDetail.doc_form : DocForm.TEXT, ) - const [docLanguage, setDocLanguage] = useState(locale !== LanguagesSupported[1] ? 'English' : 'Chinese') + const [docLanguage, setDocLanguage] = useState( + (datasetId && documentDetail) ? documentDetail.doc_language : (locale !== LanguagesSupported[1] ? 'English' : 'Chinese'), + ) const [QATipHide, setQATipHide] = useState(false) const [previewSwitched, setPreviewSwitched] = useState(false) const [showPreview, { setTrue: setShowPreview, setFalse: hidePreview }] = useBoolean() const [customFileIndexingEstimate, setCustomFileIndexingEstimate] = useState(null) const [automaticFileIndexingEstimate, setAutomaticFileIndexingEstimate] = useState(null) - const [estimateTokes, setEstimateTokes] = useState | null>(null) const fileIndexingEstimate = (() => { return segmentationType === SegmentType.AUTO ? automaticFileIndexingEstimate : customFileIndexingEstimate @@ -193,13 +193,10 @@ const StepTwo = ({ const fetchFileIndexingEstimate = async (docForm = DocForm.TEXT) => { // eslint-disable-next-line @typescript-eslint/no-use-before-define const res = await didFetchFileIndexingEstimate(getFileIndexingEstimateParams(docForm)!) - if (segmentationType === SegmentType.CUSTOM) { + if (segmentationType === SegmentType.CUSTOM) setCustomFileIndexingEstimate(res) - } - else { + else setAutomaticFileIndexingEstimate(res) - indexType === IndexingType.QUALIFIED && setEstimateTokes({ tokens: res.tokens, total_price: res.total_price }) - } } const confirmChangeCustomConfig = () => { @@ -311,6 +308,19 @@ const StepTwo = ({ defaultModel: rerankDefaultModel, currentModel: isRerankDefaultModelVaild, } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) + const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding) + const { data: defaultEmbeddingModel } = useDefaultModel(ModelTypeEnum.textEmbedding) + const [embeddingModel, setEmbeddingModel] = useState( + currentDataset?.embedding_model + ? { + provider: currentDataset.embedding_model_provider, + model: currentDataset.embedding_model, + } + : { + provider: defaultEmbeddingModel?.provider.provider || '', + model: defaultEmbeddingModel?.model || '', + }, + ) const getCreationParams = () => { let params if (segmentationType === SegmentType.CUSTOM && overlap > max) { @@ -325,6 +335,8 @@ const StepTwo = ({ process_rule: getProcessRule(), // eslint-disable-next-line @typescript-eslint/no-use-before-define retrieval_model: retrievalConfig, // Readonly. If want to changed, just go to settings page. + embedding_model: embeddingModel.model, // Readonly + embedding_model_provider: embeddingModel.provider, // Readonly } as CreateDocumentReq } else { // create @@ -361,6 +373,8 @@ const StepTwo = ({ doc_language: docLanguage, retrieval_model: postRetrievalConfig, + embedding_model: embeddingModel.model, + embedding_model_provider: embeddingModel.provider, } as CreateDocumentReq if (dataSourceType === DataSourceType.FILE) { params.data_source.info_list.file_info_list = { @@ -556,7 +570,7 @@ const StepTwo = ({ className='border-[0.5px] !h-8 hover:outline hover:outline-[0.5px] hover:outline-gray-300 text-gray-700 font-medium bg-white shadow-[0px_1px_2px_0px_rgba(16,24,40,0.05)]' onClick={setShowPreview} > - +
{t('datasetCreation.stepTwo.previewTitleButton')} @@ -614,36 +628,42 @@ const StepTwo = ({
{t('datasetCreation.stepTwo.maxLength')}
- setMax(parseInt(e.target.value.replace(/^0+/, ''), 10))} - /> +
+ setMax(parseInt(e.target.value.replace(/^0+/, ''), 10))} + /> +
Tokens
+
{t('datasetCreation.stepTwo.overlap')} - - {t('datasetCreation.stepTwo.overlapTip')} -
- }> - - + + {t('datasetCreation.stepTwo.overlapTip')} +
+ } + /> +
+
+ setOverlap(parseInt(e.target.value.replace(/^0+/, ''), 10))} + /> +
Tokens
- setOverlap(parseInt(e.target.value.replace(/^0+/, ''), 10))} - />
@@ -676,7 +696,7 @@ const StepTwo = ({ !isAPIKeySet && s.disabled, !hasSetIndexType && indexType === IndexingType.QUALIFIED && s.active, hasSetIndexType && s.disabled, - hasSetIndexType && '!w-full', + hasSetIndexType && '!w-full !min-h-[96px]', )} onClick={() => { if (isAPIKeySet) @@ -691,16 +711,6 @@ const StepTwo = ({ {!hasSetIndexType && {t('datasetCreation.stepTwo.recommend')}}
{t('datasetCreation.stepTwo.qualifiedTip')}
-
{t('datasetCreation.stepTwo.emstimateCost')}
- { - estimateTokes - ? ( -
{formatNumber(estimateTokes.tokens)} tokens(${formatNumber(estimateTokes.total_price)})
- ) - : ( -
{t('datasetCreation.stepTwo.calculating')}
- ) - }
{!isAPIKeySet && (
@@ -718,7 +728,7 @@ const StepTwo = ({ s.indexItem, !hasSetIndexType && indexType === IndexingType.ECONOMICAL && s.active, hasSetIndexType && s.disabled, - hasSetIndexType && '!w-full', + hasSetIndexType && '!w-full !min-h-[96px]', )} onClick={changeToEconomicalType} > @@ -727,13 +737,11 @@ const StepTwo = ({
{t('datasetCreation.stepTwo.economical')}
{t('datasetCreation.stepTwo.economicalTip')}
-
{t('datasetCreation.stepTwo.emstimateCost')}
-
0 tokens
)}
- {hasSetIndexType && ( + {hasSetIndexType && indexType === IndexingType.ECONOMICAL && (
{t('datasetCreation.stepTwo.indexSettedTip')} {t('datasetCreation.stepTwo.datasetSettingLink')} @@ -768,12 +776,32 @@ const StepTwo = ({ )}
)} + {/* Embedding model */} + {indexType === IndexingType.QUALIFIED && ( +
+
{t('datasetSettings.form.embeddingModel')}
+ { + setEmbeddingModel(model) + }} + /> + {!!datasetId && ( +
+ {t('datasetCreation.stepTwo.indexSettedTip')} + {t('datasetCreation.stepTwo.datasetSettingLink')} +
+ )} +
+ )} {/* Retrieval Method Config */}
{!datasetId ? (
- {t('datasetSettings.form.retrievalSetting.title')} +
{t('datasetSettings.form.retrievalSetting.title')}
{t('datasetSettings.form.retrievalSetting.learnMore')} {t('datasetSettings.form.retrievalSetting.longDescription')} @@ -787,34 +815,21 @@ const StepTwo = ({ )}
- {!datasetId - ? (<> - {getIndexing_technique() === IndexingType.QUALIFIED - ? ( - - ) - : ( - - )} - ) - : ( -
- -
- {t('datasetCreation.stepTwo.retrivalSettedTip')} - {t('datasetCreation.stepTwo.datasetSettingLink')} -
-
- )} - + ) + : ( + + ) + }
diff --git a/web/app/components/datasets/create/steps-nav-bar/index.tsx b/web/app/components/datasets/create/steps-nav-bar/index.tsx index 70724a308c..b676f3ace4 100644 --- a/web/app/components/datasets/create/steps-nav-bar/index.tsx +++ b/web/app/components/datasets/create/steps-nav-bar/index.tsx @@ -49,7 +49,7 @@ const StepsNavBar = ({ key={item} className={cn(s.stepItem, s[`step${item}`], step === item && s.active, step > item && s.done, isMobile && 'px-0')} > -
{item}
+
{step > item ? '' : item}
{isMobile ? '' : t(STEP_T_MAP[item])}
))} diff --git a/web/app/components/datasets/create/website/firecrawl/base/field.tsx b/web/app/components/datasets/create/website/firecrawl/base/field.tsx index b1b7858d78..5b5ca90c5d 100644 --- a/web/app/components/datasets/create/website/firecrawl/base/field.tsx +++ b/web/app/components/datasets/create/website/firecrawl/base/field.tsx @@ -1,12 +1,9 @@ 'use client' import type { FC } from 'react' import React from 'react' -import { - RiQuestionLine, -} from '@remixicon/react' import Input from './input' import cn from '@/utils/classnames' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' type Props = { className?: string @@ -37,11 +34,12 @@ const Field: FC = ({
{label}
{isRequired && *} {tooltip && ( - {tooltip}
- }> - - + {tooltip}
+ } + triggerClassName='ml-0.5 w-4 h-4' + /> )} = ({ value, diff --git a/web/app/components/datasets/documents/detail/completed/index.tsx b/web/app/components/datasets/documents/detail/completed/index.tsx index f2addac2e2..195b75b0a2 100644 --- a/web/app/components/datasets/documents/detail/completed/index.tsx +++ b/web/app/components/datasets/documents/detail/completed/index.tsx @@ -391,7 +391,7 @@ const Completed: FC = ({ defaultValue={'all'} className={s.select} wrapperClassName='h-fit w-[120px] mr-2' /> - + setSearchValue(e.target.value), 500)} /> = ( } -const EmbeddingDetail: FC = ({ detail, stopPosition = 'top', datasetId: dstId, documentId: docId, indexingType, detailUpdate }) => { +const EmbeddingDetail: FC = ({ detail, stopPosition = 'top', datasetId: dstId, documentId: docId, detailUpdate }) => { const onTop = stopPosition === 'top' const { t } = useTranslation() const { notify } = useContext(ToastContext) const { datasetId = '', documentId = '' } = useContext(DocumentContext) - const { indexingTechnique } = useContext(DatasetDetailContext) const localDatasetId = dstId ?? datasetId const localDocumentId = docId ?? documentId - const localIndexingTechnique = indexingType ?? indexingTechnique const [indexingStatusDetail, setIndexingStatusDetail] = useState(null) const fetchIndexingStatus = async () => { @@ -160,14 +156,6 @@ const EmbeddingDetail: FC = ({ detail, stopPosition = 'top', datasetId: d } }, [startQueryStatus, stopQueryStatus]) - const { data: indexingEstimateDetail, error: indexingEstimateErr } = useSWR({ - action: 'fetchIndexingEstimate', - datasetId: localDatasetId, - documentId: localDocumentId, - }, apiParams => fetchIndexingEstimate(omit(apiParams, 'action')), { - revalidateOnFocus: false, - }) - const { data: ruleDetail, error: ruleError } = useSWR({ action: 'fetchProcessRule', params: { documentId: localDocumentId }, @@ -250,21 +238,6 @@ const EmbeddingDetail: FC = ({ detail, stopPosition = 'top', datasetId: d
{t('datasetDocuments.embedding.segments')} {indexingStatusDetail?.completed_segments}/{indexingStatusDetail?.total_segments} · {percent}%
- {localIndexingTechnique === 'high_quaility' && ( -
-
- {t('datasetDocuments.embedding.highQuality')} · {t('datasetDocuments.embedding.estimate')} - {formatNumber(indexingEstimateDetail?.tokens || 0)}tokens - (${formatNumber(indexingEstimateDetail?.total_price || 0)}) -
- )} - {localIndexingTechnique === 'economy' && ( -
-
- {t('datasetDocuments.embedding.economy')} · {t('datasetDocuments.embedding.estimate')} - 0tokens -
- )}
{!onTop && ( diff --git a/web/app/components/datasets/documents/detail/embedding/style.module.css b/web/app/components/datasets/documents/detail/embedding/style.module.css index 6dc1a5e80b..c24444ac12 100644 --- a/web/app/components/datasets/documents/detail/embedding/style.module.css +++ b/web/app/components/datasets/documents/detail/embedding/style.module.css @@ -31,7 +31,7 @@ @apply rounded-r-md; } .progressData { - @apply w-full flex justify-between items-center text-xs text-gray-700; + @apply w-full flex items-center text-xs text-gray-700; } .previewTip { @apply pb-1 pt-12 text-gray-900 text-sm font-medium; diff --git a/web/app/components/datasets/documents/detail/metadata/index.tsx b/web/app/components/datasets/documents/detail/metadata/index.tsx index 9210926022..2e0f8af961 100644 --- a/web/app/components/datasets/documents/detail/metadata/index.tsx +++ b/web/app/components/datasets/documents/detail/metadata/index.tsx @@ -79,7 +79,7 @@ export const FieldInfo: FC = ({ /> : onUpdate?.(e.target.value)} value={value} defaultValue={defaultValue} placeholder={`${t('datasetDocuments.metadata.placeholder.add')}${label}`} @@ -102,7 +102,9 @@ const IconButton: FC<{ const metadataMap = useMetadataMap() return ( - +
)} {!isExplore && ( -
+
+ {/* answer icon */} + {isEditModal && (appMode === 'chat' || appMode === 'advanced-chat' || appMode === 'agent-chat') && ( +
+
+
{t('app.answerIcon.title')}
+ setUseIconAsAnswerIcon(v)} + /> +
+

{t('app.answerIcon.descriptionInExplore')}

+
+ )} {!isEditModal && isAppsFull && }
diff --git a/web/app/components/header/account-dropdown/index.tsx b/web/app/components/header/account-dropdown/index.tsx index 2298bff82d..1dd03e5290 100644 --- a/web/app/components/header/account-dropdown/index.tsx +++ b/web/app/components/header/account-dropdown/index.tsx @@ -125,7 +125,7 @@ export default function AppSelector({ isMobile }: IAppSelecotr) { className={classNames(itemClassName, 'group justify-between')} href='https://github.com/langgenius/dify/discussions/categories/feedbacks' target='_blank' rel='noopener noreferrer'> -
{t('common.userProfile.roadmapAndFeedback')}
+
{t('common.userProfile.communityFeedback')}
@@ -149,6 +149,15 @@ export default function AppSelector({ isMobile }: IAppSelecotr) { + + +
{t('common.userProfile.roadmap')}
+ + +
{ document?.body?.getAttribute('data-public-site-about') !== 'hide' && ( diff --git a/web/app/components/header/account-setting/members-page/invited-modal/index.tsx b/web/app/components/header/account-setting/members-page/invited-modal/index.tsx index d3bcf9870e..7af19b06c3 100644 --- a/web/app/components/header/account-setting/members-page/invited-modal/index.tsx +++ b/web/app/components/header/account-setting/members-page/invited-modal/index.tsx @@ -1,5 +1,6 @@ import { CheckCircleIcon } from '@heroicons/react/24/solid' -import { QuestionMarkCircleIcon, XMarkIcon } from '@heroicons/react/24/outline' +import { XMarkIcon } from '@heroicons/react/24/outline' +import { RiQuestionLine } from '@remixicon/react' import { useTranslation } from 'react-i18next' import { useMemo } from 'react' import InvitationLink from './invitation-link' @@ -64,12 +65,11 @@ const InvitedModal = ({ failedInvationResults.map(item =>
{item.email} - +
, diff --git a/web/app/components/header/account-setting/members-page/invited-modal/invitation-link.tsx b/web/app/components/header/account-setting/members-page/invited-modal/invitation-link.tsx index 876c7217d5..e912847786 100644 --- a/web/app/components/header/account-setting/members-page/invited-modal/invitation-link.tsx +++ b/web/app/components/header/account-setting/members-page/invited-modal/invitation-link.tsx @@ -39,18 +39,14 @@ const InvitationLink = ({
{value.url}
diff --git a/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx b/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx index 3cf3fc513f..afcb4ad50e 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx @@ -1,8 +1,6 @@ import { Fragment, useState } from 'react' import type { FC } from 'react' -import { - RiQuestionLine, -} from '@remixicon/react' +import { RiQuestionLine } from '@remixicon/react' import { ValidatingTip } from '../../key-validator/ValidateStatus' import type { CredentialFormSchema, @@ -18,7 +16,7 @@ import { useLanguage } from '../hooks' import Input from './Input' import cn from '@/utils/classnames' import { SimpleSelect } from '@/app/components/base/select' -import Tooltip from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import Radio from '@/app/components/base/radio' type FormProps = { className?: string diff --git a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx index eced2a8082..e60ef418ed 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx @@ -1,8 +1,5 @@ import type { FC } from 'react' import { useEffect, useRef, useState } from 'react' -import { - RiQuestionLine, -} from '@remixicon/react' import type { ModelParameterRule } from '../declarations' import { useLanguage } from '../hooks' import { isNullOrUndefined } from '../utils' @@ -41,7 +38,7 @@ const ParameterItem: FC = ({ if (parameterRule.type === 'int' || parameterRule.type === 'float') defaultValue = isNullOrUndefined(parameterRule.default) ? (parameterRule.min || 0) : parameterRule.default - else if (parameterRule.type === 'string') + else if (parameterRule.type === 'string' || parameterRule.type === 'text') defaultValue = parameterRule.options?.length ? (parameterRule.default || '') : (parameterRule.default || '') else if (parameterRule.type === 'boolean') defaultValue = !isNullOrUndefined(parameterRule.default) ? parameterRule.default : false @@ -241,18 +238,18 @@ const ParameterItem: FC = ({ { parameterRule.help && ( {parameterRule.help[language] || parameterRule.help.en_US}
)} - > - - + popupClassName='mr-1' + triggerClassName='mr-1 w-4 h-4 shrink-0' + /> ) } { !parameterRule.required && parameterRule.name !== 'stop' && ( = ({ { disabled ? ( - = ({ } > - + ) : ( diff --git a/web/app/components/header/account-setting/model-provider-page/model-selector/deprecated-model-trigger.tsx b/web/app/components/header/account-setting/model-provider-page/model-selector/deprecated-model-trigger.tsx index 4eb7c3ba04..f40423d869 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-selector/deprecated-model-trigger.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-selector/deprecated-model-trigger.tsx @@ -3,7 +3,7 @@ import { useTranslation } from 'react-i18next' import ModelIcon from '../model-icon' import { AlertTriangle } from '@/app/components/base/icons/src/vender/line/alertsAndFeedback' import { useProviderContext } from '@/context/provider-context' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' type ModelTriggerProps = { modelName: string @@ -35,9 +35,9 @@ const ModelTrigger: FC = ({ {modelName}
- + - +
) diff --git a/web/app/components/header/account-setting/model-provider-page/model-selector/feature-icon.tsx b/web/app/components/header/account-setting/model-provider-page/model-selector/feature-icon.tsx index bf4d15ee3a..32bd58d318 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-selector/feature-icon.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-selector/feature-icon.tsx @@ -11,7 +11,7 @@ import { // MagicWand, // Robot, } from '@/app/components/base/icons/src/vender/solid/mediaAndDevices' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' type FeatureIconProps = { feature: ModelFeatureEnum @@ -25,49 +25,51 @@ const FeatureIcon: FC = ({ // if (feature === ModelFeatureEnum.agentThought) { // return ( - // // // // - // + // // ) // } // if (feature === ModelFeatureEnum.toolCall) { // return ( - // // // // - // + // // ) // } // if (feature === ModelFeatureEnum.multiToolCall) { // return ( - // // // // - // + // // ) // } if (feature === ModelFeatureEnum.vision) { return ( - - - - - +
+ + + +
+ ) } diff --git a/web/app/components/header/account-setting/model-provider-page/model-selector/model-trigger.tsx b/web/app/components/header/account-setting/model-provider-page/model-selector/model-trigger.tsx index 1ecf1e0e9d..023c6a5cd2 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-selector/model-trigger.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-selector/model-trigger.tsx @@ -12,7 +12,7 @@ import { useLanguage } from '../hooks' import ModelIcon from '../model-icon' import ModelName from '../model-name' import { AlertTriangle } from '@/app/components/base/icons/src/vender/line/alertsAndFeedback' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' type ModelTriggerProps = { open: boolean @@ -56,9 +56,9 @@ const ModelTrigger: FC = ({ { model.status !== ModelStatusEnum.active ? ( - + - + ) : ( = ({ { model.models.map(modelItem => (
{ return displayTime ? ( - + - + ) : null } diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-list-item.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-list-item.tsx index 1272627061..3fc73a62b2 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-list-item.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-list-item.tsx @@ -11,7 +11,7 @@ import Button from '@/app/components/base/button' import { Balance } from '@/app/components/base/icons/src/vender/line/financeAndECommerce' import { Settings01 } from '@/app/components/base/icons/src/vender/line/general' import Switch from '@/app/components/base/switch' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import { useProviderContext, useProviderContextSelector } from '@/context/provider-context' import { disableModel, enableModel } from '@/service/common' import { Plan } from '@/app/components/billing/type' @@ -99,9 +99,14 @@ const ModelListItem = ({ model, provider, isConfigurable, onConfig, onModifyLoad { model.deprecated ? ( - {t('common.modelProvider.modelHasBeenDeprecated')}} offset={{ mainAxis: 4 }}> + {t('common.modelProvider.modelHasBeenDeprecated')}} offset={{ mainAxis: 4 } + } + needsDelay + > - + ) : (isCurrentWorkspaceManager && (
{t('common.modelProvider.loadBalancing')} - - - +
{t('common.modelProvider.loadBalancingDescription')}
@@ -191,9 +192,9 @@ const ModelLoadBalancingConfigs = ({ clearCountdown(index)} /> ) : ( - + - +
)}
diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/priority-use-tip.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/priority-use-tip.tsx index 294a13ecfc..24e91d2214 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/priority-use-tip.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/priority-use-tip.tsx @@ -7,8 +7,7 @@ const PriorityUseTip = () => { return (
diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/quota-panel.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/quota-panel.tsx index c00933468f..0f5c265d52 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/quota-panel.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/quota-panel.tsx @@ -10,8 +10,7 @@ import { MODEL_PROVIDER_QUOTA_GET_PAID, } from '../utils' import PriorityUseTip from './priority-use-tip' -import { InfoCircle } from '@/app/components/base/icons/src/vender/line/general' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import { formatNumber } from '@/utils/format' type QuotaPanelProps = { @@ -32,13 +31,12 @@ const QuotaPanel: FC = ({
{t('common.modelProvider.quota')} - - - + } + />
{ currentQuota && ( diff --git a/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx b/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx index b2dfe4bfe4..1574785898 100644 --- a/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx @@ -1,9 +1,6 @@ import type { FC } from 'react' import { useState } from 'react' import { useTranslation } from 'react-i18next' -import { - RiQuestionLine, -} from '@remixicon/react' import ModelSelector from '../model-selector' import { useModelList, @@ -146,13 +143,13 @@ const SystemModel: FC = ({
{t('common.modelProvider.systemReasoningModel.key')} {t('common.modelProvider.systemReasoningModel.tip')}
+ popupContent={ +
+ {t('common.modelProvider.systemReasoningModel.tip')} +
} - > - - + triggerClassName='ml-0.5 w-4 h-4 shrink-0' + />
= ({
{t('common.modelProvider.embeddingModel.key')} {t('common.modelProvider.embeddingModel.tip')}
+ popupContent={ +
+ {t('common.modelProvider.embeddingModel.tip')} +
} - > - - + triggerClassName='ml-0.5 w-4 h-4 shrink-0' + />
= ({
{t('common.modelProvider.rerankModel.key')} {t('common.modelProvider.rerankModel.tip')}
+ popupContent={ +
+ {t('common.modelProvider.rerankModel.tip')} +
} - > - - + triggerClassName='ml-0.5 w-4 h-4 shrink-0' + />
= ({
{t('common.modelProvider.speechToTextModel.key')} {t('common.modelProvider.speechToTextModel.tip')}
+ popupContent={ +
+ {t('common.modelProvider.speechToTextModel.tip')} +
} - > - - + triggerClassName='ml-0.5 w-4 h-4 shrink-0' + />
= ({
{t('common.modelProvider.ttsModel.key')} {t('common.modelProvider.ttsModel.tip')}
+ popupContent={ +
+ {t('common.modelProvider.ttsModel.tip')} +
} - > - - + triggerClassName='ml-0.5 w-4 h-4 shrink-0' + />
{ })(isCurrentWorkspaceEditor, app) return { id: app.id, + icon_type: app.icon_type, icon: app.icon, icon_background: app.icon_background, + icon_url: app.icon_url, name: app.name, mode: app.mode, link, diff --git a/web/app/components/header/nav/nav-selector/index.tsx b/web/app/components/header/nav/nav-selector/index.tsx index 26f538d72d..ab5189bf92 100644 --- a/web/app/components/header/nav/nav-selector/index.tsx +++ b/web/app/components/header/nav/nav-selector/index.tsx @@ -16,13 +16,16 @@ import { Route } from '@/app/components/base/icons/src/vender/solid/mapsAndTrave import { useAppContext } from '@/context/app-context' import { useStore as useAppStore } from '@/app/components/app/store' import { FileArrow01, FilePlus01, FilePlus02 } from '@/app/components/base/icons/src/vender/line/files' +import type { AppIconType } from '@/types/app' export type NavItem = { id: string name: string link: string + icon_type: AppIconType | null icon: string icon_background: string + icon_url: string | null mode?: string } export type INavSelectorProps = { @@ -82,7 +85,7 @@ const NavSelector = ({ curNav, navs, createText, isApp, onCreate, onLoadmore }: router.push(nav.link) }} title={nav.name}>
- + {!!nav.mode && ( = ({
- +
{siteInfo.title}
{!isPC && ( diff --git a/web/app/components/share/text-generation/result/header.tsx b/web/app/components/share/text-generation/result/header.tsx index bb3c5695f3..bd5c317153 100644 --- a/web/app/components/share/text-generation/result/header.tsx +++ b/web/app/components/share/text-generation/result/header.tsx @@ -42,8 +42,7 @@ const Header: FC = ({ {showFeedback && feedback.rating && feedback.rating === 'like' && (
{ @@ -59,8 +58,7 @@ const Header: FC = ({ {showFeedback && feedback.rating && feedback.rating === 'dislike' && (
{ @@ -77,8 +75,8 @@ const Header: FC = ({ {showFeedback && !feedback.rating && (
{ @@ -91,8 +89,8 @@ const Header: FC = ({
{ diff --git a/web/app/components/tools/add-tool-modal/tools.tsx b/web/app/components/tools/add-tool-modal/tools.tsx index 8810294d98..af26dd3e25 100644 --- a/web/app/components/tools/add-tool-modal/tools.tsx +++ b/web/app/components/tools/add-tool-modal/tools.tsx @@ -68,10 +68,9 @@ const Blocks = ({ return ( )} - noArrow + needsDelay >
= ({
{t('tools.createTool.authMethod.key')} {t('tools.createTool.authMethod.keyTooltip')}
} - > - - + triggerClassName='ml-0.5 w-4 h-4' + />
= ({
{t('tools.createTool.name')} *
- { setShowEmojiPicker(true) }} className='cursor-pointer' icon={emoji.content} background={emoji.background} /> + { setShowEmojiPicker(true) }} className='cursor-pointer' iconType='emoji' icon={emoji.content} background={emoji.background} /> = ({
{t('tools.createTool.nameForToolCall')} * {t('tools.createTool.nameForToolCallPlaceHolder')}
} - selector='workflow-tool-modal-tooltip' - > - - + />
( {nodesExtraData[block.type].about}
)} - noArrow >
( {tool.description[language]}
)} - noArrow >
-
{tool.label[language]}
+
{tool.label[language]}
)) diff --git a/web/app/components/workflow/header/view-history.tsx b/web/app/components/workflow/header/view-history.tsx index a6318dbfeb..06eebfd329 100644 --- a/web/app/components/workflow/header/view-history.tsx +++ b/web/app/components/workflow/header/view-history.tsx @@ -24,7 +24,7 @@ import { PortalToFollowElemContent, PortalToFollowElemTrigger, } from '@/app/components/base/portal-to-follow-elem' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import { useStore as useAppStore } from '@/app/components/app/store' import { ClockPlay, @@ -100,7 +100,7 @@ const ViewHistory = ({ } { !withText && ( -
-
+ ) } diff --git a/web/app/components/workflow/hooks/use-nodes-interactions.ts b/web/app/components/workflow/hooks/use-nodes-interactions.ts index 87d1b4de8c..3645e18449 100644 --- a/web/app/components/workflow/hooks/use-nodes-interactions.ts +++ b/web/app/components/workflow/hooks/use-nodes-interactions.ts @@ -1027,7 +1027,7 @@ export const useNodesInteractions = () => { handleNodeSelect(node.id) }, [workflowStore, handleNodeSelect]) - const handleNodesCopy = useCallback(() => { + const handleNodesCopy = useCallback((nodeId?: string) => { if (getNodesReadOnly()) return @@ -1038,17 +1038,27 @@ export const useNodesInteractions = () => { } = store.getState() const nodes = getNodes() - const bundledNodes = nodes.filter(node => node.data._isBundled && node.data.type !== BlockEnum.Start && !node.data.isInIteration) - if (bundledNodes.length) { - setClipboardElements(bundledNodes) - return + if (nodeId) { + // If nodeId is provided, copy that specific node + const nodeToCopy = nodes.find(node => node.id === nodeId && node.data.type !== BlockEnum.Start) + if (nodeToCopy) + setClipboardElements([nodeToCopy]) } + else { + // If no nodeId is provided, fall back to the current behavior + const bundledNodes = nodes.filter(node => node.data._isBundled && node.data.type !== BlockEnum.Start && !node.data.isInIteration) - const selectedNode = nodes.find(node => node.data.selected && node.data.type !== BlockEnum.Start) + if (bundledNodes.length) { + setClipboardElements(bundledNodes) + return + } - if (selectedNode) - setClipboardElements([selectedNode]) + const selectedNode = nodes.find(node => node.data.selected && node.data.type !== BlockEnum.Start) + + if (selectedNode) + setClipboardElements([selectedNode]) + } }, [getNodesReadOnly, store, workflowStore]) const handleNodesPaste = useCallback(() => { @@ -1128,11 +1138,11 @@ export const useNodesInteractions = () => { } }, [getNodesReadOnly, workflowStore, store, reactflow, saveStateToHistory, handleSyncWorkflowDraft, handleNodeIterationChildrenCopy]) - const handleNodesDuplicate = useCallback(() => { + const handleNodesDuplicate = useCallback((nodeId?: string) => { if (getNodesReadOnly()) return - handleNodesCopy() + handleNodesCopy(nodeId) handleNodesPaste() }, [getNodesReadOnly, handleNodesCopy, handleNodesPaste]) diff --git a/web/app/components/workflow/hooks/use-shortcuts.ts b/web/app/components/workflow/hooks/use-shortcuts.ts index 9484f9c16e..666c3a45ba 100644 --- a/web/app/components/workflow/hooks/use-shortcuts.ts +++ b/web/app/components/workflow/hooks/use-shortcuts.ts @@ -37,12 +37,25 @@ export const useShortcuts = (): void => { const { handleLayout } = useWorkflowOrganize() const { - zoomIn, - zoomOut, zoomTo, + getZoom, fitView, } = useReactFlow() + // Zoom out to a minimum of 0.5 for shortcut + const constrainedZoomOut = () => { + const currentZoom = getZoom() + const newZoom = Math.max(currentZoom - 0.1, 0.5) + zoomTo(newZoom) + } + + // Zoom in to a maximum of 1 for shortcut + const constrainedZoomIn = () => { + const currentZoom = getZoom() + const newZoom = Math.min(currentZoom + 0.1, 1) + zoomTo(newZoom) + } + const shouldHandleShortcut = useCallback((e: KeyboardEvent) => { const { showFeaturesPanel } = workflowStore.getState() return !showFeaturesPanel && !isEventTargetInputArea(e.target as HTMLElement) @@ -165,7 +178,7 @@ export const useShortcuts = (): void => { useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.dash`, (e) => { if (shouldHandleShortcut(e)) { e.preventDefault() - zoomOut() + constrainedZoomOut() handleSyncWorkflowDraft() } }, { @@ -176,7 +189,7 @@ export const useShortcuts = (): void => { useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.equalsign`, (e) => { if (shouldHandleShortcut(e)) { e.preventDefault() - zoomIn() + constrainedZoomIn() handleSyncWorkflowDraft() } }, { diff --git a/web/app/components/workflow/node-contextmenu.tsx b/web/app/components/workflow/node-contextmenu.tsx index adfed37b26..311bf1fddf 100644 --- a/web/app/components/workflow/node-contextmenu.tsx +++ b/web/app/components/workflow/node-contextmenu.tsx @@ -1,5 +1,6 @@ import { memo, + useEffect, useRef, } from 'react' import { useClickAway } from 'ahooks' @@ -9,13 +10,18 @@ import type { Node } from './types' import { useStore } from './store' import { usePanelInteractions } from './hooks' -const PanelContextmenu = () => { +const NodeContextmenu = () => { const ref = useRef(null) const nodes = useNodes() - const { handleNodeContextmenuCancel } = usePanelInteractions() + const { handleNodeContextmenuCancel, handlePaneContextmenuCancel } = usePanelInteractions() const nodeMenu = useStore(s => s.nodeMenu) const currentNode = nodes.find(node => node.id === nodeMenu?.nodeId) as Node + useEffect(() => { + if (nodeMenu) + handlePaneContextmenuCancel() + }, [nodeMenu, handlePaneContextmenuCancel]) + useClickAway(() => { handleNodeContextmenuCancel() }, ref) @@ -42,4 +48,4 @@ const PanelContextmenu = () => { ) } -export default memo(PanelContextmenu) +export default memo(NodeContextmenu) diff --git a/web/app/components/workflow/nodes/_base/components/field.tsx b/web/app/components/workflow/nodes/_base/components/field.tsx index 344fc3d708..334bce2fb8 100644 --- a/web/app/components/workflow/nodes/_base/components/field.tsx +++ b/web/app/components/workflow/nodes/_base/components/field.tsx @@ -3,12 +3,11 @@ import type { FC } from 'react' import React from 'react' import { RiArrowDownSLine, - RiQuestionLine, } from '@remixicon/react' import { useBoolean } from 'ahooks' import type { DefaultTFuncReturn } from 'i18next' import cn from '@/utils/classnames' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' type Props = { className?: string @@ -40,12 +39,11 @@ const Filed: FC = ({
{title}
{tooltip && ( - - {tooltip} -
}> - - + )}
diff --git a/web/app/components/workflow/nodes/_base/components/help-link.tsx b/web/app/components/workflow/nodes/_base/components/help-link.tsx index 248eb5a546..a2b0837fbd 100644 --- a/web/app/components/workflow/nodes/_base/components/help-link.tsx +++ b/web/app/components/workflow/nodes/_base/components/help-link.tsx @@ -2,7 +2,7 @@ import { memo } from 'react' import { useTranslation } from 'react-i18next' import { RiBookOpenLine } from '@remixicon/react' import { useNodeHelpLink } from '../hooks/use-node-help-link' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import TooltipPlus from '@/app/components/base/tooltip' import type { BlockEnum } from '@/app/components/workflow/types' type HelpLinkProps = { @@ -15,7 +15,9 @@ const HelpLink = ({ const link = useNodeHelpLink(nodeType) return ( - + = ({ {readOnly &&
} {isFocus && (
-
-
+
)} diff --git a/web/app/components/workflow/nodes/_base/components/node-control.tsx b/web/app/components/workflow/nodes/_base/components/node-control.tsx index 25b941c216..1ce78220a1 100644 --- a/web/app/components/workflow/nodes/_base/components/node-control.tsx +++ b/web/app/components/workflow/nodes/_base/components/node-control.tsx @@ -19,7 +19,7 @@ import PanelOperator from './panel-operator' import { Stop, } from '@/app/components/base/icons/src/vender/line/mediaAndDevices' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' type NodeControlProps = Pick const NodeControl: FC = ({ @@ -68,11 +68,12 @@ const NodeControl: FC = ({ data._isSingleRun ? : ( - - +
) }
diff --git a/web/app/components/workflow/nodes/_base/components/option-card.tsx b/web/app/components/workflow/nodes/_base/components/option-card.tsx index 71c2c2958d..f19338d1b7 100644 --- a/web/app/components/workflow/nodes/_base/components/option-card.tsx +++ b/web/app/components/workflow/nodes/_base/components/option-card.tsx @@ -4,6 +4,7 @@ import React, { useCallback } from 'react' import type { VariantProps } from 'class-variance-authority' import { cva } from 'class-variance-authority' import cn from '@/utils/classnames' +import Tooltip from '@/app/components/base/tooltip' const variants = cva([], { variants: { @@ -26,6 +27,7 @@ type Props = { selected: boolean disabled?: boolean align?: 'left' | 'center' | 'right' + tooltip?: string } & VariantProps const OptionCard: FC = ({ @@ -35,6 +37,7 @@ const OptionCard: FC = ({ selected, disabled, align = 'center', + tooltip, }) => { const handleSelect = useCallback(() => { if (selected || disabled) @@ -54,7 +57,16 @@ const OptionCard: FC = ({ )} onClick={handleSelect} > - {title} + {title} + {tooltip + && + {tooltip} +
+ } + /> + }
) } diff --git a/web/app/components/workflow/nodes/_base/components/panel-operator/panel-operator-popup.tsx b/web/app/components/workflow/nodes/_base/components/panel-operator/panel-operator-popup.tsx index aade4d8be8..bd642fcd66 100644 --- a/web/app/components/workflow/nodes/_base/components/panel-operator/panel-operator-popup.tsx +++ b/web/app/components/workflow/nodes/_base/components/panel-operator/panel-operator-popup.tsx @@ -138,7 +138,7 @@ const PanelOperatorPopup = ({ className='flex items-center justify-between px-3 h-8 text-sm text-gray-700 rounded-lg cursor-pointer hover:bg-gray-50' onClick={() => { onClosePopup() - handleNodesDuplicate() + handleNodesDuplicate(id) }} > {t('workflow.common.duplicate')} diff --git a/web/app/components/workflow/nodes/_base/components/prompt/editor.tsx b/web/app/components/workflow/nodes/_base/components/prompt/editor.tsx index daf4bddbc4..080346b18d 100644 --- a/web/app/components/workflow/nodes/_base/components/prompt/editor.tsx +++ b/web/app/components/workflow/nodes/_base/components/prompt/editor.tsx @@ -31,7 +31,7 @@ import { useEventEmitterContextContext } from '@/context/event-emitter' import { PROMPT_EDITOR_INSERT_QUICKLY } from '@/app/components/base/prompt-editor/plugins/update-block' import { Variable02 } from '@/app/components/base/icons/src/vender/solid/development' import ActionButton from '@/app/components/base/action-button' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor/editor-support-vars' import Switch from '@/app/components/base/switch' import { Jinja } from '@/app/components/base/icons/src/vender/workflow' @@ -141,14 +141,14 @@ const Editor: FC = ({ {/* Operations */} } - hideArrow + needsDelay >
@@ -160,18 +160,17 @@ const Editor: FC = ({ }} />
- + )} {!readOnly && ( - - + )} {showRemove && ( diff --git a/web/app/components/workflow/nodes/_base/components/variable/utils.ts b/web/app/components/workflow/nodes/_base/components/variable/utils.ts index e42088fd1b..3deec09dc2 100644 --- a/web/app/components/workflow/nodes/_base/components/variable/utils.ts +++ b/web/app/components/workflow/nodes/_base/components/variable/utils.ts @@ -99,6 +99,10 @@ const formatItem = ( variable: 'sys.query', type: VarType.string, }) + res.vars.push({ + variable: 'sys.dialogue_count', + type: VarType.number, + }) res.vars.push({ variable: 'sys.conversation_id', type: VarType.string, diff --git a/web/app/components/workflow/nodes/_base/panel.tsx b/web/app/components/workflow/nodes/_base/panel.tsx index 269d8110dc..83387621fc 100644 --- a/web/app/components/workflow/nodes/_base/panel.tsx +++ b/web/app/components/workflow/nodes/_base/panel.tsx @@ -35,7 +35,7 @@ import { useWorkflowHistory, } from '@/app/components/workflow/hooks' import { canRunBySingle } from '@/app/components/workflow/utils' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import type { Node } from '@/app/components/workflow/types' import { useStore as useAppStore } from '@/app/components/app/store' import { useStore } from '@/app/components/workflow/store' @@ -127,8 +127,9 @@ const BasePanel: FC = ({
{ canRunBySingle(data.type) && !nodesReadOnly && ( -
= ({ >
-
+ ) } diff --git a/web/app/components/workflow/nodes/assigner/panel.tsx b/web/app/components/workflow/nodes/assigner/panel.tsx index 6bba0d2a31..ff5a6420f3 100644 --- a/web/app/components/workflow/nodes/assigner/panel.tsx +++ b/web/app/components/workflow/nodes/assigner/panel.tsx @@ -49,7 +49,6 @@ const Panel: FC> = ({
{writeModeTypes.map(type => ( @@ -59,6 +58,7 @@ const Panel: FC> = ({ onSelect={handleWriteModeChange(type)} selected={inputs.write_mode === type} disabled={!isSupportAppend && type === WriteMode.Append} + tooltip={type === WriteMode.Append ? t(`${i18nPrefix}.writeModeTip`)! : undefined} /> ))}
diff --git a/web/app/components/workflow/nodes/code/dependency-picker.tsx b/web/app/components/workflow/nodes/code/dependency-picker.tsx deleted file mode 100644 index 3aa6c45f20..0000000000 --- a/web/app/components/workflow/nodes/code/dependency-picker.tsx +++ /dev/null @@ -1,97 +0,0 @@ -import type { FC } from 'react' -import React, { useCallback, useState } from 'react' -import { t } from 'i18next' -import { - RiArrowDownSLine, - RiSearchLine, -} from '@remixicon/react' -import type { CodeDependency } from './types' -import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem' -import { Check } from '@/app/components/base/icons/src/vender/line/general' -import { XCircle } from '@/app/components/base/icons/src/vender/solid/general' - -type Props = { - value: CodeDependency - available_dependencies: CodeDependency[] - onChange: (dependency: CodeDependency) => void -} - -const DependencyPicker: FC = ({ - available_dependencies, - value, - onChange, -}) => { - const [open, setOpen] = useState(false) - const [searchText, setSearchText] = useState('') - - const handleChange = useCallback((dependency: CodeDependency) => { - return () => { - setOpen(false) - onChange(dependency) - } - }, [onChange]) - - return ( - - setOpen(!open)} className='flex-grow cursor-pointer'> -
-
{value.name}
- -
-
- -
-
- - setSearchText(e.target.value)} - autoFocus - /> - { - searchText && ( -
setSearchText('')} - > - -
- ) - } -
-
- {available_dependencies.filter((v) => { - if (!searchText) - return true - return v.name.toLowerCase().includes(searchText.toLowerCase()) - }).map(dependency => ( -
-
{dependency.name}
- {dependency.name === value.name && } -
- ))} -
-
-
-
- ) -} - -export default React.memo(DependencyPicker) diff --git a/web/app/components/workflow/nodes/code/dependency.tsx b/web/app/components/workflow/nodes/code/dependency.tsx deleted file mode 100644 index 5e868efe31..0000000000 --- a/web/app/components/workflow/nodes/code/dependency.tsx +++ /dev/null @@ -1,36 +0,0 @@ -import type { FC } from 'react' -import React from 'react' -import RemoveButton from '../_base/components/remove-button' -import type { CodeDependency } from './types' -import DependencyPicker from './dependency-picker' - -type Props = { - available_dependencies: CodeDependency[] - dependencies: CodeDependency[] - handleRemove: (index: number) => void - handleChange: (index: number, dependency: CodeDependency) => void -} - -const Dependencies: FC = ({ - available_dependencies, dependencies, handleRemove, handleChange, -}) => { - return ( -
- {dependencies.map((dependency, index) => ( -
- handleChange(index, dependency)} - /> - handleRemove(index)} - /> -
- ))} -
- ) -} - -export default React.memo(Dependencies) diff --git a/web/app/components/workflow/nodes/code/panel.tsx b/web/app/components/workflow/nodes/code/panel.tsx index 8ab9b3d0e5..838e7190d3 100644 --- a/web/app/components/workflow/nodes/code/panel.tsx +++ b/web/app/components/workflow/nodes/code/panel.tsx @@ -5,7 +5,6 @@ import RemoveEffectVarConfirm from '../_base/components/remove-effect-var-confir import useConfig from './use-config' import type { CodeNodeType } from './types' import { CodeLanguage } from './types' -import Dependencies from './dependency' import VarList from '@/app/components/workflow/nodes/_base/components/variable/var-list' import OutputVarList from '@/app/components/workflow/nodes/_base/components/variable/output-var-list' import AddButton from '@/app/components/base/button/add-button' @@ -60,11 +59,6 @@ const Panel: FC> = ({ varInputs, inputVarValues, setInputVarValues, - allowDependencies, - availableDependencies, - handleAddDependency, - handleRemoveDependency, - handleChangeDependency, } = useConfig(id, data) return ( @@ -84,31 +78,6 @@ const Panel: FC> = ({ filterVar={filterVar} />
- { - allowDependencies - ? ( -
- -
- handleAddDependency({ name: '', version: '' })} /> - } - tooltip={t(`${i18nPrefix}.advancedDependenciesTip`)!} - > - handleRemoveDependency(index)} - handleChange={(index, dependency) => handleChangeDependency(index, dependency)} - /> - -
-
- ) - : null - } { const appId = useAppStore.getState().appDetail?.id const [allLanguageDefault, setAllLanguageDefault] = useState | null>(null) - const [allLanguageDependencies, setAllLanguageDependencies] = useState | null>(null) useEffect(() => { if (appId) { (async () => { const { config: javaScriptConfig } = await fetchNodeDefault(appId, BlockEnum.Code, { code_language: CodeLanguage.javascript }) as any - const { config: pythonConfig, available_dependencies: pythonDependencies } = await fetchNodeDefault(appId, BlockEnum.Code, { code_language: CodeLanguage.python3 }) as any + const { config: pythonConfig } = await fetchNodeDefault(appId, BlockEnum.Code, { code_language: CodeLanguage.python3 }) as any setAllLanguageDefault({ [CodeLanguage.javascript]: javaScriptConfig as CodeNodeType, [CodeLanguage.python3]: pythonConfig as CodeNodeType, } as any) - setAllLanguageDependencies({ - [CodeLanguage.python3]: pythonDependencies as CodeDependency[], - } as any) })() } }, [appId]) @@ -45,62 +41,6 @@ const useConfig = (id: string, payload: CodeNodeType) => { setInputs, }) - const handleAddDependency = useCallback((dependency: CodeDependency) => { - const newInputs = produce(inputs, (draft) => { - if (!draft.dependencies) - draft.dependencies = [] - draft.dependencies.push(dependency) - }) - setInputs(newInputs) - }, [inputs, setInputs]) - - const handleRemoveDependency = useCallback((index: number) => { - const newInputs = produce(inputs, (draft) => { - if (!draft.dependencies) - draft.dependencies = [] - draft.dependencies.splice(index, 1) - }) - setInputs(newInputs) - }, [inputs, setInputs]) - - const handleChangeDependency = useCallback((index: number, dependency: CodeDependency) => { - const newInputs = produce(inputs, (draft) => { - if (!draft.dependencies) - draft.dependencies = [] - draft.dependencies[index] = dependency - }) - setInputs(newInputs) - }, [inputs, setInputs]) - - const [allowDependencies, setAllowDependencies] = useState(false) - useEffect(() => { - if (!inputs.code_language) - return - if (!allLanguageDependencies) - return - - const newAllowDependencies = !!allLanguageDependencies[inputs.code_language] - setAllowDependencies(newAllowDependencies) - }, [allLanguageDependencies, inputs.code_language]) - - const [availableDependencies, setAvailableDependencies] = useState([]) - useEffect(() => { - if (!inputs.code_language) - return - if (!allLanguageDependencies) - return - - const newAvailableDependencies = produce(allLanguageDependencies[inputs.code_language], (draft) => { - const currentLanguage = inputs.code_language - if (!currentLanguage || !draft || !inputs.dependencies) - return [] - return draft.filter((dependency) => { - return !inputs.dependencies?.find(d => d.name === dependency.name) - }) - }) - setAvailableDependencies(newAvailableDependencies || []) - }, [allLanguageDependencies, inputs.code_language, inputs.dependencies]) - const [outputKeyOrders, setOutputKeyOrders] = useState([]) const syncOutputKeyOrders = useCallback((outputs: OutputVar) => { setOutputKeyOrders(Object.keys(outputs)) @@ -223,11 +163,6 @@ const useConfig = (id: string, payload: CodeNodeType) => { inputVarValues, setInputVarValues, runResult, - availableDependencies, - allowDependencies, - handleAddDependency, - handleRemoveDependency, - handleChangeDependency, } } diff --git a/web/app/components/workflow/nodes/http/components/edit-body/index.tsx b/web/app/components/workflow/nodes/http/components/edit-body/index.tsx index 645bcdcdf9..6e8f4eac3b 100644 --- a/web/app/components/workflow/nodes/http/components/edit-body/index.tsx +++ b/web/app/components/workflow/nodes/http/components/edit-body/index.tsx @@ -44,7 +44,7 @@ const EditBody: FC = ({ const { availableVars, availableNodes } = useAvailableVarList(nodeId, { onlyLeafNodeVar: false, filterVar: (varPayload: Var) => { - return [VarType.string, VarType.number, VarType.secret].includes(varPayload.type) + return [VarType.string, VarType.number, VarType.secret, VarType.arrayNumber, VarType.arrayString].includes(varPayload.type) }, }) diff --git a/web/app/components/workflow/nodes/http/components/key-value/key-value-edit/item.tsx b/web/app/components/workflow/nodes/http/components/key-value/key-value-edit/item.tsx index 93c2696b98..0f28db96c3 100644 --- a/web/app/components/workflow/nodes/http/components/key-value/key-value-edit/item.tsx +++ b/web/app/components/workflow/nodes/http/components/key-value/key-value-edit/item.tsx @@ -73,7 +73,7 @@ const KeyValueItem: FC = ({ handleChange('key')(e.target.value)} /> )}
diff --git a/web/app/components/workflow/nodes/iteration/add-block.tsx b/web/app/components/workflow/nodes/iteration/add-block.tsx index fb61dede28..fd8480b7df 100644 --- a/web/app/components/workflow/nodes/iteration/add-block.tsx +++ b/web/app/components/workflow/nodes/iteration/add-block.tsx @@ -29,7 +29,7 @@ import type { import { BlockEnum, } from '@/app/components/workflow/types' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' type AddBlockProps = { iterationNodeId: string @@ -99,11 +99,11 @@ const AddBlock = ({ return (
- +
-
+
{ iterationNodeData.startNodeType && ( diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/components/retrieval-config.tsx b/web/app/components/workflow/nodes/knowledge-retrieval/components/retrieval-config.tsx index ebe9d2151e..b335b62e33 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/components/retrieval-config.tsx +++ b/web/app/components/workflow/nodes/knowledge-retrieval/components/retrieval-config.tsx @@ -70,7 +70,7 @@ const RetrievalConfig: FC = ({ } onMultipleRetrievalConfigChange({ top_k: configs.top_k, - score_threshold: configs.score_threshold_enabled ? (configs.score_threshold || DATASET_DEFAULT.score_threshold) : null, + score_threshold: configs.score_threshold_enabled ? (configs.score_threshold ?? DATASET_DEFAULT.score_threshold) : null, reranking_model: payload.retrieval_mode === RETRIEVE_TYPE.oneWay ? undefined : (!configs.reranking_model?.reranking_provider_name diff --git a/web/app/components/workflow/nodes/llm/components/config-prompt-item.tsx b/web/app/components/workflow/nodes/llm/components/config-prompt-item.tsx index 39715a7c71..c8d4d92fda 100644 --- a/web/app/components/workflow/nodes/llm/components/config-prompt-item.tsx +++ b/web/app/components/workflow/nodes/llm/components/config-prompt-item.tsx @@ -3,13 +3,12 @@ import type { FC } from 'react' import React, { useCallback, useEffect, useState } from 'react' import { uniqueId } from 'lodash-es' import { useTranslation } from 'react-i18next' -import { RiQuestionLine } from '@remixicon/react' import type { ModelConfig, PromptItem, Variable } from '../../../types' import { EditionType } from '../../../types' import { useWorkflowStore } from '../../../store' import Editor from '@/app/components/workflow/nodes/_base/components/prompt/editor' import TypeSelector from '@/app/components/workflow/nodes/_base/components/selector' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import { PromptRole } from '@/models/debug' const i18nPrefix = 'workflow.nodes.llm' @@ -118,13 +117,12 @@ const ConfigPromptItem: FC = ({ /> )} - {t(`${i18nPrefix}.roleDescription.${payload.role}`)}
} - > - - + triggerClassName='w-4 h-4' + />
} value={payload.edition_type === EditionType.jinja2 ? (payload.jinja2_text || '') : payload.text} diff --git a/web/app/components/workflow/nodes/llm/panel.tsx b/web/app/components/workflow/nodes/llm/panel.tsx index 1c2ec3c985..569e7f3feb 100644 --- a/web/app/components/workflow/nodes/llm/panel.tsx +++ b/web/app/components/workflow/nodes/llm/panel.tsx @@ -1,7 +1,6 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import { RiQuestionLine } from '@remixicon/react' import MemoryConfig from '../_base/components/memory-config' import VarReferencePicker from '../_base/components/variable/var-reference-picker' import useConfig from './use-config' @@ -19,7 +18,7 @@ import { InputVarType, type NodePanelProps } from '@/app/components/workflow/typ import BeforeRunForm from '@/app/components/workflow/nodes/_base/components/before-run-form' import type { Props as FormProps } from '@/app/components/workflow/nodes/_base/components/before-run-form/form' import ResultPanel from '@/app/components/workflow/run/result-panel' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import Editor from '@/app/components/workflow/nodes/_base/components/prompt/editor' import Switch from '@/app/components/base/switch' const i18nPrefix = 'workflow.nodes.llm' @@ -206,11 +205,10 @@ const Panel: FC> = ({
{t('workflow.nodes.common.memories.title')}
- - - + triggerClassName='w-4 h-4' + />
{t('workflow.nodes.common.memories.builtIn')}
@@ -219,13 +217,12 @@ const Panel: FC> = ({
user
- {t('workflow.nodes.llm.roleDescription.user')}
} - > - - + triggerClassName='w-4 h-4' + />
} value={inputs.memory.query_prompt_template || '{{#sys.query#}}'} onChange={handleSyeQueryChange} diff --git a/web/app/components/workflow/nodes/parameter-extractor/panel.tsx b/web/app/components/workflow/nodes/parameter-extractor/panel.tsx index 7c5686fadc..0fc7c3ff98 100644 --- a/web/app/components/workflow/nodes/parameter-extractor/panel.tsx +++ b/web/app/components/workflow/nodes/parameter-extractor/panel.tsx @@ -1,9 +1,6 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import { - RiQuestionLine, -} from '@remixicon/react' import MemoryConfig from '../_base/components/memory-config' import VarReferencePicker from '../_base/components/variable/var-reference-picker' import Editor from '../_base/components/prompt/editor' @@ -19,7 +16,7 @@ import Split from '@/app/components/workflow/nodes/_base/components/split' import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' import OutputVars, { VarItem } from '@/app/components/workflow/nodes/_base/components/output-vars' import { InputVarType, type NodePanelProps } from '@/app/components/workflow/types' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import BeforeRunForm from '@/app/components/workflow/nodes/_base/components/before-run-form' import { VarType } from '@/app/components/workflow/types' @@ -126,12 +123,14 @@ const Panel: FC> = ({ title={
{t(`${i18nPrefix}.instruction`)} - - {t(`${i18nPrefix}.instructionTip`)} -
}> - - + + {t(`${i18nPrefix}.instructionTip`)} +
+ } + triggerClassName='w-3.5 h-3.5 ml-0.5' + />
} value={inputs.instruction} diff --git a/web/app/components/workflow/nodes/question-classifier/components/advanced-setting.tsx b/web/app/components/workflow/nodes/question-classifier/components/advanced-setting.tsx index c89e0f2668..dc654607f7 100644 --- a/web/app/components/workflow/nodes/question-classifier/components/advanced-setting.tsx +++ b/web/app/components/workflow/nodes/question-classifier/components/advanced-setting.tsx @@ -2,13 +2,10 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import { - RiQuestionLine, -} from '@remixicon/react' import MemoryConfig from '../../_base/components/memory-config' import Editor from '@/app/components/workflow/nodes/_base/components/prompt/editor' import type { Memory, Node, NodeOutPutVar } from '@/app/components/workflow/types' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' const i18nPrefix = 'workflow.nodes.questionClassifiers' type Props = { @@ -50,12 +47,14 @@ const AdvancedSetting: FC = ({ title={
{t(`${i18nPrefix}.instruction`)} - - {t(`${i18nPrefix}.instructionTip`)} -
}> - - + + {t(`${i18nPrefix}.instructionTip`)} +
+ } + triggerClassName='w-3.5 h-3.5 ml-0.5' + />
} value={instruction} diff --git a/web/app/components/workflow/nodes/start/panel.tsx b/web/app/components/workflow/nodes/start/panel.tsx index 48b5d6b7c2..ce86a34265 100644 --- a/web/app/components/workflow/nodes/start/panel.tsx +++ b/web/app/components/workflow/nodes/start/panel.tsx @@ -84,17 +84,30 @@ const Panel: FC> = ({ /> { isChatMode && ( - - String -
- } - /> + <> + + Number + + } + /> + + String + + } + /> + ) } +
{icon}
- +
) } diff --git a/web/app/components/workflow/operator/tip-popup.tsx b/web/app/components/workflow/operator/tip-popup.tsx index ecd108dffc..a389d9e4c6 100644 --- a/web/app/components/workflow/operator/tip-popup.tsx +++ b/web/app/components/workflow/operator/tip-popup.tsx @@ -1,6 +1,6 @@ import { memo } from 'react' import ShortcutsName from '../shortcuts-name' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' type TipPopupProps = { title: string @@ -13,9 +13,8 @@ const TipPopup = ({ shortcuts, }: TipPopupProps) => { return ( - @@ -27,7 +26,7 @@ const TipPopup = ({ } > {children} - + ) } diff --git a/web/app/components/workflow/panel-contextmenu.tsx b/web/app/components/workflow/panel-contextmenu.tsx index 502967ce2c..f01e3037a2 100644 --- a/web/app/components/workflow/panel-contextmenu.tsx +++ b/web/app/components/workflow/panel-contextmenu.tsx @@ -1,5 +1,6 @@ import { memo, + useEffect, useRef, } from 'react' import { useTranslation } from 'react-i18next' @@ -23,11 +24,16 @@ const PanelContextmenu = () => { const clipboardElements = useStore(s => s.clipboardElements) const setShowImportDSLModal = useStore(s => s.setShowImportDSLModal) const { handleNodesPaste } = useNodesInteractions() - const { handlePaneContextmenuCancel } = usePanelInteractions() + const { handlePaneContextmenuCancel, handleNodeContextmenuCancel } = usePanelInteractions() const { handleStartWorkflowRun } = useWorkflowStartRun() const { handleAddNote } = useOperator() const { exportCheck } = useDSL() + useEffect(() => { + if (panelMenu) + handleNodeContextmenuCancel() + }, [panelMenu, handleNodeContextmenuCancel]) + useClickAway(() => { handlePaneContextmenuCancel() }, ref) diff --git a/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx b/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx index 289e29d592..e6c1ebb5cc 100644 --- a/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx +++ b/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx @@ -15,6 +15,7 @@ import type { ConversationVariable } from '@/app/components/workflow/types' import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' import { ChatVarType } from '@/app/components/workflow/panel/chat-variable-panel/type' import cn from '@/utils/classnames' +import { checkKeys } from '@/utils/var' export type ModalPropsType = { chatVar?: ConversationVariable @@ -128,14 +129,16 @@ const ChatVariableModal = ({ } } - const handleNameChange = (v: string) => { - if (!v) - return setName('') - if (!/^[a-zA-Z0-9_]+$/.test(v)) - return notify({ type: 'error', message: 'name is can only contain letters, numbers and underscores' }) - if (/^[0-9]/.test(v)) - return notify({ type: 'error', message: 'name can not start with a number' }) - setName(v) + const checkVariableName = (value: string) => { + const { isValid, errorMessageKey } = checkKeys([value], false) + if (!isValid) { + notify({ + type: 'error', + message: t(`appDebug.varKeyError.${errorMessageKey}`, { key: t('workflow.env.modal.name') }), + }) + return false + } + return true } const handleTypeChange = (v: ChatVarType) => { @@ -211,8 +214,8 @@ const ChatVariableModal = ({ } const handleSave = () => { - if (!name) - return notify({ type: 'error', message: 'name can not be empty' }) + if (!checkVariableName(name)) + return if (!chatVar && varList.some(chatVar => chatVar.name === name)) return notify({ type: 'error', message: 'name is existed' }) // if (type !== ChatVarType.Object && !value) @@ -272,7 +275,8 @@ const ChatVariableModal = ({ className='block px-3 w-full h-8 bg-components-input-bg-normal system-sm-regular radius-md border border-transparent appearance-none outline-none caret-primary-600 hover:border-components-input-border-hover hover:bg-components-input-bg-hover focus:bg-components-input-bg-active focus:border-components-input-border-active focus:shadow-xs placeholder:system-sm-regular placeholder:text-components-input-text-placeholder' placeholder={t('workflow.chatVariable.modal.namePlaceholder') || ''} value={name} - onChange={e => handleNameChange(e.target.value)} + onChange={e => setName(e.target.value || '')} + onBlur={e => checkVariableName(e.target.value)} type='text' /> diff --git a/web/app/components/workflow/panel/debug-and-preview/hooks.ts b/web/app/components/workflow/panel/debug-and-preview/hooks.ts index 155d2a84ac..54d3915a13 100644 --- a/web/app/components/workflow/panel/debug-and-preview/hooks.ts +++ b/web/app/components/workflow/panel/debug-and-preview/hooks.ts @@ -248,11 +248,16 @@ export const useChat = ( } if (config?.suggested_questions_after_answer?.enabled && !hasStopResponded.current && onGetSuggestedQuestions) { - const { data }: any = await onGetSuggestedQuestions( - responseItem.id, - newAbortController => suggestedQuestionsAbortControllerRef.current = newAbortController, - ) - setSuggestQuestions(data) + try { + const { data }: any = await onGetSuggestedQuestions( + responseItem.id, + newAbortController => suggestedQuestionsAbortControllerRef.current = newAbortController, + ) + setSuggestQuestions(data) + } + catch (error) { + setSuggestQuestions([]) + } } }, onMessageEnd: (messageEnd) => { diff --git a/web/app/components/workflow/panel/debug-and-preview/index.tsx b/web/app/components/workflow/panel/debug-and-preview/index.tsx index 1f94b4fbc3..29fc48f896 100644 --- a/web/app/components/workflow/panel/debug-and-preview/index.tsx +++ b/web/app/components/workflow/panel/debug-and-preview/index.tsx @@ -18,7 +18,7 @@ import ChatWrapper from './chat-wrapper' import cn from '@/utils/classnames' import { RefreshCcw01 } from '@/app/components/base/icons/src/vender/line/arrows' import { BubbleX } from '@/app/components/base/icons/src/vender/line/others' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import ActionButton, { ActionButtonState } from '@/app/components/base/action-button' import { useStore } from '@/app/components/workflow/store' @@ -63,31 +63,31 @@ const DebugAndPreview = () => {
{t('workflow.common.debugAndPreview').toLocaleUpperCase()}
- handleRestartChat()}> - + {varList.length > 0 && ( - setShowConversationVariableModal(true)}> - + )} {variables.length > 0 && (
- setExpanded(!expanded)}> - + {expanded &&
}
)} diff --git a/web/app/components/workflow/panel/env-panel/variable-modal.tsx b/web/app/components/workflow/panel/env-panel/variable-modal.tsx index 46f92bd8ed..2180782ffe 100644 --- a/web/app/components/workflow/panel/env-panel/variable-modal.tsx +++ b/web/app/components/workflow/panel/env-panel/variable-modal.tsx @@ -1,14 +1,15 @@ import React, { useEffect } from 'react' import { useTranslation } from 'react-i18next' import { v4 as uuid4 } from 'uuid' -import { RiCloseLine, RiQuestionLine } from '@remixicon/react' +import { RiCloseLine } from '@remixicon/react' import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import { ToastContext } from '@/app/components/base/toast' import { useStore } from '@/app/components/workflow/store' import type { EnvironmentVariable } from '@/app/components/workflow/types' import cn from '@/utils/classnames' +import { checkKeys } from '@/utils/var' export type ModalPropsType = { env?: EnvironmentVariable @@ -28,19 +29,21 @@ const VariableModal = ({ const [name, setName] = React.useState('') const [value, setValue] = React.useState() - const handleNameChange = (v: string) => { - if (!v) - return setName('') - if (!/^[a-zA-Z0-9_]+$/.test(v)) - return notify({ type: 'error', message: 'name is can only contain letters, numbers and underscores' }) - if (/^[0-9]/.test(v)) - return notify({ type: 'error', message: 'name can not start with a number' }) - setName(v) + const checkVariableName = (value: string) => { + const { isValid, errorMessageKey } = checkKeys([value], false) + if (!isValid) { + notify({ + type: 'error', + message: t(`appDebug.varKeyError.${errorMessageKey}`, { key: t('workflow.env.modal.name') }), + }) + return false + } + return true } const handleSave = () => { - if (!name) - return notify({ type: 'error', message: 'name can not be empty' }) + if (!checkVariableName(name)) + return if (!value) return notify({ type: 'error', message: 'value can not be empty' }) if (!env && envList.some(env => env.name === name)) @@ -99,13 +102,14 @@ const VariableModal = ({ type === 'secret' && 'text-text-primary font-medium border-[1.5px] shadow-xs bg-components-option-card-option-selected-bg border-components-option-card-option-selected-border hover:border-components-option-card-option-selected-border', )} onClick={() => setType('secret')}> Secret - - {t('workflow.env.modal.secretTip')} -
- }> - - + + {t('workflow.env.modal.secretTip')} +
+ } + triggerClassName='ml-0.5 w-3.5 h-3.5' + />
@@ -118,7 +122,8 @@ const VariableModal = ({ className='block px-3 w-full h-8 bg-components-input-bg-normal system-sm-regular radius-md border border-transparent appearance-none outline-none caret-primary-600 hover:border-components-input-border-hover hover:bg-components-input-bg-hover focus:bg-components-input-bg-active focus:border-components-input-border-active focus:shadow-xs placeholder:system-sm-regular placeholder:text-components-input-text-placeholder' placeholder={t('workflow.env.modal.namePlaceholder') || ''} value={name} - onChange={e => handleNameChange(e.target.value)} + onChange={e => setName(e.target.value || '')} + onBlur={e => checkVariableName(e.target.value)} type='text' /> diff --git a/web/app/signin/oneMoreStep.tsx b/web/app/signin/oneMoreStep.tsx index d1e94ce2fc..42a797b4d6 100644 --- a/web/app/signin/oneMoreStep.tsx +++ b/web/app/signin/oneMoreStep.tsx @@ -6,8 +6,7 @@ import useSWR from 'swr' import { useRouter } from 'next/navigation' // import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' -import Tooltip from '@/app/components/base/tooltip/index' - +import Tooltip from '@/app/components/base/tooltip' import { SimpleSelect } from '@/app/components/base/select' import { timezones } from '@/utils/timezone' import { LanguagesSupported, languages } from '@/i18n/language' @@ -88,9 +87,7 @@ const OneMoreStep = () => {