merge feat/plugins

This commit is contained in:
Joel 2024-11-01 12:37:42 +08:00
commit a273ae35f9
653 changed files with 46039 additions and 3369 deletions

View File

@ -1,11 +1,12 @@
#!/bin/bash
cd web && npm install
npm add -g pnpm@9.12.2
cd web && pnpm install
pipx install poetry
echo 'alias start-api="cd /workspaces/dify/api && poetry run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
echo 'alias start-worker="cd /workspaces/dify/api && poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
echo 'alias start-web="cd /workspaces/dify/web && npm run dev"' >> ~/.bashrc
echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc
source /home/vscode/.bashrc
source /home/vscode/.bashrc

View File

@ -78,7 +78,7 @@ jobs:
- name: Run Workflow
run: poetry run -C api bash dev/pytest/pytest_workflow.sh
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch)
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase)
uses: hoverkraft-tech/compose-action@v2.0.0
with:
compose-file: |
@ -86,6 +86,7 @@ jobs:
services: |
weaviate
qdrant
couchbase-server
etcd
minio
milvus-standalone

View File

@ -7,5 +7,7 @@ yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/dock
yq eval '.services.pgvector.ports += ["5433:5432"]' -i docker/docker-compose.yaml
yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compose.yaml
yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml
yq eval '.services.couchbase-server.ports += ["8091-8096:8091-8096"]' -i docker/docker-compose.yaml
yq eval '.services.couchbase-server.ports += ["11210:11210"]' -i docker/docker-compose.yaml
echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch"
echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase"

View File

@ -76,16 +76,16 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
with:
node-version: 20
cache: yarn
cache: pnpm
cache-dependency-path: ./web/package.json
- name: Web dependencies
if: steps.changed-files.outputs.any_changed == 'true'
run: yarn install --frozen-lockfile
run: pnpm install --frozen-lockfile
- name: Web style check
if: steps.changed-files.outputs.any_changed == 'true'
run: yarn run lint
run: pnpm run lint
superlinter:

View File

@ -32,10 +32,10 @@ jobs:
with:
node-version: ${{ matrix.node-version }}
cache: ''
cache-dependency-path: 'yarn.lock'
cache-dependency-path: 'pnpm-lock.yaml'
- name: Install Dependencies
run: yarn install
run: pnpm install
- name: Test
run: yarn test
run: pnpm test

View File

@ -38,11 +38,11 @@ jobs:
- name: Install dependencies
if: env.FILES_CHANGED == 'true'
run: yarn install --frozen-lockfile
run: pnpm install --frozen-lockfile
- name: Run npm script
if: env.FILES_CHANGED == 'true'
run: npm run auto-gen-i18n
run: pnpm run auto-gen-i18n
- name: Create Pull Request
if: env.FILES_CHANGED == 'true'

View File

@ -34,13 +34,13 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
with:
node-version: 20
cache: yarn
cache: pnpm
cache-dependency-path: ./web/package.json
- name: Install dependencies
if: steps.changed-files.outputs.any_changed == 'true'
run: yarn install --frozen-lockfile
run: pnpm install --frozen-lockfile
- name: Run tests
if: steps.changed-files.outputs.any_changed == 'true'
run: yarn test
run: pnpm test

7
.gitignore vendored
View File

@ -173,6 +173,8 @@ docker/volumes/myscale/log/*
docker/volumes/unstructured/*
docker/volumes/pgvector/data/*
docker/volumes/pgvecto_rs/data/*
docker/volumes/couchbase/*
docker/volumes/oceanbase/*
docker/nginx/conf.d/default.conf
docker/nginx/ssl/*
@ -189,4 +191,7 @@ pyrightconfig.json
api/.vscode
.idea/
.vscode
.vscode
# pnpm
/.pnpm-store

View File

@ -1,5 +1,9 @@
![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab)
<p align="center">
📌 <a href="https://dify.ai/blog/introducing-dify-workflow-file-upload-a-demo-on-ai-podcast">Introducing Dify Workflow File Upload: Recreate Google NotebookLM Podcast</a>
</p>
<p align="center">
<a href="https://cloud.dify.ai">Dify Cloud</a> ·
<a href="https://docs.dify.ai/getting-started/install-self-hosted">Self-hosting</a> ·

View File

@ -154,7 +154,7 @@ Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI
我们提供[ Dify 云服务](https://dify.ai),任何人都可以零设置尝试。它提供了自部署版本的所有功能,并在沙盒计划中包含 200 次免费的 GPT-4 调用。
- **自托管 Dify 社区版</br>**
使用这个[入门指南](#quick-start)快速在您的环境中运行 Dify。
使用这个[入门指南](#快速启动)快速在您的环境中运行 Dify。
使用我们的[文档](https://docs.dify.ai)进行进一步的参考和更深入的说明。
- **面向企业/组织的 Dify</br>**

241
README_PT.md Normal file
View File

@ -0,0 +1,241 @@
![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab)
<p align="center">
📌 <a href="https://dify.ai/blog/introducing-dify-workflow-file-upload-a-demo-on-ai-podcast">Introduzindo o Dify Workflow com Upload de Arquivo: Recrie o Podcast Google NotebookLM</a>
</p>
<p align="center">
<a href="https://cloud.dify.ai">Dify Cloud</a> ·
<a href="https://docs.dify.ai/getting-started/install-self-hosted">Auto-hospedagem</a> ·
<a href="https://docs.dify.ai">Documentação</a> ·
<a href="https://udify.app/chat/22L1zSxg6yW1cWQg">Consultas empresariais</a>
</p>
<p align="center">
<a href="https://dify.ai" target="_blank">
<img alt="Static Badge" src="https://img.shields.io/badge/Product-F04438"></a>
<a href="https://dify.ai/pricing" target="_blank">
<img alt="Static Badge" src="https://img.shields.io/badge/free-pricing?logo=free&color=%20%23155EEF&label=pricing&labelColor=%20%23528bff"></a>
<a href="https://discord.gg/FngNHpbcY7" target="_blank">
<img src="https://img.shields.io/discord/1082486657678311454?logo=discord&labelColor=%20%235462eb&logoColor=%20%23f5f5f5&color=%20%235462eb"
alt="chat on Discord"></a>
<a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
<img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
alt="follow on X(Twitter)"></a>
<a href="https://hub.docker.com/u/langgenius" target="_blank">
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
<a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">
<img alt="Commits last month" src="https://img.shields.io/github/commit-activity/m/langgenius/dify?labelColor=%20%2332b583&color=%20%2312b76a"></a>
<a href="https://github.com/langgenius/dify/" target="_blank">
<img alt="Issues closed" src="https://img.shields.io/github/issues-search?query=repo%3Alanggenius%2Fdify%20is%3Aclosed&label=issues%20closed&labelColor=%20%237d89b0&color=%20%235d6b98"></a>
<a href="https://github.com/langgenius/dify/discussions/" target="_blank">
<img alt="Discussion posts" src="https://img.shields.io/github/discussions/langgenius/dify?labelColor=%20%239b8afb&color=%20%237a5af8"></a>
</p>
<p align="center">
<a href="./README.md"><img alt="README em Inglês" src="https://img.shields.io/badge/English-d9d9d9"></a>
<a href="./README_CN.md"><img alt="简体中文版自述文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
<a href="./README_JA.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
<a href="./README_ES.md"><img alt="README em Espanhol" src="https://img.shields.io/badge/Español-d9d9d9"></a>
<a href="./README_FR.md"><img alt="README em Francês" src="https://img.shields.io/badge/Français-d9d9d9"></a>
<a href="./README_KL.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
<a href="./README_KR.md"><img alt="README em Coreano" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
<a href="./README_AR.md"><img alt="README em Árabe" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
<a href="./README_TR.md"><img alt="README em Turco" src="https://img.shields.io/badge/Türkçe-d9d9d9"></a>
<a href="./README_VI.md"><img alt="README em Vietnamita" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a>
<a href="./README_PT.md"><img alt="README em Português - BR" src="https://img.shields.io/badge/Portugu%C3%AAs-BR?style=flat&label=BR&color=d9d9d9"></a>
</p>
Dify é uma plataforma de desenvolvimento de aplicativos LLM de código aberto. Sua interface intuitiva combina workflow de IA, pipeline RAG, capacidades de agente, gerenciamento de modelos, recursos de observabilidade e muito mais, permitindo que você vá rapidamente do protótipo à produção. Aqui está uma lista das principais funcionalidades:
</br> </br>
**1. Workflow**:
Construa e teste workflows poderosos de IA em uma interface visual, aproveitando todos os recursos a seguir e muito mais.
https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
**2. Suporte abrangente a modelos**:
Integração perfeita com centenas de LLMs proprietários e de código aberto de diversas provedoras e soluções auto-hospedadas, abrangendo GPT, Mistral, Llama3 e qualquer modelo compatível com a API da OpenAI. A lista completa de provedores suportados pode ser encontrada [aqui](https://docs.dify.ai/getting-started/readme/model-providers).
![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3)
**3. IDE de Prompt**:
Interface intuitiva para criação de prompts, comparação de desempenho de modelos e adição de recursos como conversão de texto para fala em um aplicativo baseado em chat.
**4. Pipeline RAG**:
Extensas capacidades de RAG que cobrem desde a ingestão de documentos até a recuperação, com suporte nativo para extração de texto de PDFs, PPTs e outros formatos de documentos comuns.
**5. Capacidades de agente**:
Você pode definir agentes com base em LLM Function Calling ou ReAct e adicionar ferramentas pré-construídas ou personalizadas para o agente. O Dify oferece mais de 50 ferramentas integradas para agentes de IA, como Google Search, DALL·E, Stable Diffusion e WolframAlpha.
**6. LLMOps**:
Monitore e analise os registros e o desempenho do aplicativo ao longo do tempo. É possível melhorar continuamente prompts, conjuntos de dados e modelos com base nos dados de produção e anotações.
**7. Backend como Serviço**:
Todas os recursos do Dify vêm com APIs correspondentes, permitindo que você integre o Dify sem esforço na lógica de negócios da sua empresa.
## Comparação de recursos
<table style="width: 100%;">
<tr>
<th align="center">Recurso</th>
<th align="center">Dify.AI</th>
<th align="center">LangChain</th>
<th align="center">Flowise</th>
<th align="center">OpenAI Assistants API</th>
</tr>
<tr>
<td align="center">Abordagem de Programação</td>
<td align="center">Orientada a API + Aplicativo</td>
<td align="center">Código Python</td>
<td align="center">Orientada a Aplicativo</td>
<td align="center">Orientada a API</td>
</tr>
<tr>
<td align="center">LLMs Suportados</td>
<td align="center">Variedade Rica</td>
<td align="center">Variedade Rica</td>
<td align="center">Variedade Rica</td>
<td align="center">Apenas OpenAI</td>
</tr>
<tr>
<td align="center">RAG Engine</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
<td align="center">Agente</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
<td align="center">Workflow</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
<td align="center">Observabilidade</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
<td align="center">Recursos Empresariais (SSO/Controle de Acesso)</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
<td align="center">Implantação Local</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
</table>
## Usando o Dify
- **Nuvem </br>**
Oferecemos o serviço [Dify Cloud](https://dify.ai) para qualquer pessoa experimentar sem nenhuma configuração. Ele fornece todas as funcionalidades da versão auto-hospedada, incluindo 200 chamadas GPT-4 gratuitas no plano sandbox.
- **Auto-hospedagem do Dify Community Edition</br>**
Configure rapidamente o Dify no seu ambiente com este [guia inicial](#quick-start).
Use nossa [documentação](https://docs.dify.ai) para referências adicionais e instruções mais detalhadas.
- **Dify para empresas/organizações</br>**
Oferecemos recursos adicionais voltados para empresas. [Envie suas perguntas através deste chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) ou [envie-nos um e-mail](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) para discutir necessidades empresariais. </br>
> Para startups e pequenas empresas que utilizam AWS, confira o [Dify Premium no AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) e implemente no seu próprio AWS VPC com um clique. É uma oferta AMI acessível com a opção de criar aplicativos com logotipo e marca personalizados.
## Mantendo-se atualizado
Dê uma estrela no Dify no GitHub e seja notificado imediatamente sobre novos lançamentos.
![star-us](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4)
## Início rápido
> Antes de instalar o Dify, certifique-se de que sua máquina atenda aos seguintes requisitos mínimos de sistema:
>
>- CPU >= 2 Núcleos
>- RAM >= 4 GiB
</br>
A maneira mais fácil de iniciar o servidor Dify é executar nosso arquivo [docker-compose.yml](docker/docker-compose.yaml). Antes de rodar o comando de instalação, certifique-se de que o [Docker](https://docs.docker.com/get-docker/) e o [Docker Compose](https://docs.docker.com/compose/install/) estão instalados na sua máquina:
```bash
cd docker
cp .env.example .env
docker compose up -d
```
Após a execução, você pode acessar o painel do Dify no navegador em [http://localhost/install](http://localhost/install) e iniciar o processo de inicialização.
> Se você deseja contribuir com o Dify ou fazer desenvolvimento adicional, consulte nosso [guia para implantar a partir do código fonte](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code).
## Próximos passos
Se precisar personalizar a configuração, consulte os comentários no nosso arquivo [.env.example](docker/.env.example) e atualize os valores correspondentes no seu arquivo `.env`. Além disso, talvez seja necessário fazer ajustes no próprio arquivo `docker-compose.yaml`, como alterar versões de imagem, mapeamentos de portas ou montagens de volumes, com base no seu ambiente de implantação específico e nas suas necessidades. Após fazer quaisquer alterações, execute novamente `docker-compose up -d`. Você pode encontrar a lista completa de variáveis de ambiente disponíveis [aqui](https://docs.dify.ai/getting-started/install-self-hosted/environments).
Se deseja configurar uma instalação de alta disponibilidade, há [Helm Charts](https://helm.sh/) e arquivos YAML contribuídos pela comunidade que permitem a implantação do Dify no Kubernetes.
- [Helm Chart de @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
- [Helm Chart de @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
- [Arquivo YAML de @Winson-030](https://github.com/Winson-030/dify-kubernetes)
#### Usando o Terraform para Implantação
Implante o Dify na Plataforma Cloud com um único clique usando [terraform](https://www.terraform.io/)
##### Azure Global
- [Azure Terraform por @nikawang](https://github.com/nikawang/dify-azure-terraform)
##### Google Cloud
- [Google Cloud Terraform por @sotazum](https://github.com/DeNA/dify-google-cloud-terraform)
## Contribuindo
Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
Ao mesmo tempo, considere apoiar o Dify compartilhando-o nas redes sociais e em eventos e conferências.
> Estamos buscando contribuidores para ajudar na tradução do Dify para idiomas além de Mandarim e Inglês. Se você tiver interesse em ajudar, consulte o [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) para mais informações e deixe-nos um comentário no canal `global-users` em nosso [Servidor da Comunidade no Discord](https://discord.gg/8Tpq4AcN9c).
**Contribuidores**
<a href="https://github.com/langgenius/dify/graphs/contributors">
<img src="https://contrib.rocks/image?repo=langgenius/dify" />
</a>
## Comunidade e contato
* [Discussões no GitHub](https://github.com/langgenius/dify/discussions). Melhor para: compartilhar feedback e fazer perguntas.
* [Problemas no GitHub](https://github.com/langgenius/dify/issues). Melhor para: relatar bugs encontrados no Dify.AI e propor novos recursos. Veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
* [Discord](https://discord.gg/FngNHpbcY7). Melhor para: compartilhar suas aplicações e interagir com a comunidade.
* [X(Twitter)](https://twitter.com/dify_ai). Melhor para: compartilhar suas aplicações e interagir com a comunidade.
## Histórico de estrelas
[![Gráfico de Histórico de Estrelas](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date)
## Divulgação de segurança
Para proteger sua privacidade, evite postar problemas de segurança no GitHub. Em vez disso, envie suas perguntas para security@dify.ai e forneceremos uma resposta mais detalhada.
## Licença
Este repositório está disponível sob a [Licença de Código Aberto Dify](LICENSE), que é essencialmente Apache 2.0 com algumas restrições adicionais.

View File

@ -31,8 +31,17 @@ REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_USERNAME=
REDIS_PASSWORD=difyai123456
REDIS_USE_SSL=false
REDIS_DB=0
# redis Sentinel configuration.
REDIS_USE_SENTINEL=false
REDIS_SENTINELS=
REDIS_SENTINEL_SERVICE_NAME=
REDIS_SENTINEL_USERNAME=
REDIS_SENTINEL_PASSWORD=
REDIS_SENTINEL_SOCKET_TIMEOUT=0.1
# PostgreSQL database configuration
DB_USERNAME=postgres
DB_PASSWORD=difyai123456
@ -111,7 +120,7 @@ SUPABASE_URL=your-server-url
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, vikingdb
# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash
VECTOR_STORE=weaviate
# Weaviate configuration
@ -127,6 +136,13 @@ QDRANT_CLIENT_TIMEOUT=20
QDRANT_GRPC_ENABLED=false
QDRANT_GRPC_PORT=6334
#Couchbase configuration
COUCHBASE_CONNECTION_STRING=127.0.0.1
COUCHBASE_USER=Administrator
COUCHBASE_PASSWORD=password
COUCHBASE_BUCKET_NAME=Embeddings
COUCHBASE_SCOPE_NAME=_default
# Milvus configuration
MILVUS_URI=http://127.0.0.1:19530
MILVUS_TOKEN=
@ -186,6 +202,20 @@ TIDB_VECTOR_USER=xxx.root
TIDB_VECTOR_PASSWORD=xxxxxx
TIDB_VECTOR_DATABASE=dify
# Tidb on qdrant configuration
TIDB_ON_QDRANT_URL=http://127.0.0.1
TIDB_ON_QDRANT_API_KEY=dify
TIDB_ON_QDRANT_CLIENT_TIMEOUT=20
TIDB_ON_QDRANT_GRPC_ENABLED=false
TIDB_ON_QDRANT_GRPC_PORT=6334
TIDB_PUBLIC_KEY=dify
TIDB_PRIVATE_KEY=dify
TIDB_API_URL=http://127.0.0.1
TIDB_IAM_API_URL=http://127.0.0.1
TIDB_REGION=regions/aws-us-east-1
TIDB_PROJECT_ID=dify
TIDB_SPEND_LIMIT=100
# Chroma configuration
CHROMA_HOST=127.0.0.1
CHROMA_PORT=8000
@ -220,6 +250,10 @@ BAIDU_VECTOR_DB_DATABASE=dify
BAIDU_VECTOR_DB_SHARD=1
BAIDU_VECTOR_DB_REPLICAS=3
# Upstash configuration
UPSTASH_VECTOR_URL=your-server-url
UPSTASH_VECTOR_TOKEN=your-access-token
# ViKingDB configuration
VIKINGDB_ACCESS_KEY=your-ak
VIKINGDB_SECRET_KEY=your-sk
@ -229,6 +263,14 @@ VIKINGDB_SCHEMA=http
VIKINGDB_CONNECTION_TIMEOUT=30
VIKINGDB_SOCKET_TIMEOUT=30
# OceanBase Vector configuration
OCEANBASE_VECTOR_HOST=127.0.0.1
OCEANBASE_VECTOR_PORT=2881
OCEANBASE_VECTOR_USER=root@test
OCEANBASE_VECTOR_PASSWORD=
OCEANBASE_VECTOR_DATABASE=test
OCEANBASE_MEMORY_LIMIT=6G
# Upload configuration
UPLOAD_FILE_SIZE_LIMIT=15
UPLOAD_FILE_BATCH_LIMIT=5
@ -239,6 +281,7 @@ UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
# Model Configuration
MULTIMODAL_SEND_IMAGE_FORMAT=base64
PROMPT_GENERATION_MAX_TOKENS=512
CODE_GENERATION_MAX_TOKENS=1024
# Mail configuration, support: resend, smtp
MAIL_TYPE=
@ -304,6 +347,10 @@ RESPECT_XFORWARD_HEADERS_ENABLED=false
# Log file path
LOG_FILE=
# Log file max size, the unit is MB
LOG_FILE_MAX_SIZE=20
# Log file max backup count
LOG_FILE_BACKUP_COUNT=5
# Indexing configuration
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=1000

View File

@ -55,7 +55,9 @@ RUN apt-get update \
&& echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
&& apt-get update \
# For Security
&& apt-get install -y --no-install-recommends libsqlite3-0=3.46.1-1 \
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.3-1 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-6 libsqlite3-0=3.46.1-1 \
# install a chinese font to support the use of tools like matplotlib
&& apt-get install -y fonts-noto-cjk \
&& apt-get autoremove -y \
&& rm -rf /var/lib/apt/lists/*

View File

@ -277,6 +277,9 @@ def migrate_knowledge_vector_database():
VectorType.TENCENT,
VectorType.BAIDU,
VectorType.VIKINGDB,
VectorType.UPSTASH,
VectorType.COUCHBASE,
VectorType.OCEANBASE,
}
page = 1
while True:

View File

@ -10,6 +10,7 @@ from pydantic import (
PositiveInt,
computed_field,
)
from pydantic_extra_types.timezone_name import TimeZoneName
from pydantic_settings import BaseSettings
from configs.feature.hosted_service import HostedServiceConfig
@ -372,6 +373,16 @@ class LoggingConfig(BaseSettings):
default=None,
)
LOG_FILE_MAX_SIZE: PositiveInt = Field(
description="Maximum file size for file rotation retention, the unit is megabytes (MB)",
default=20,
)
LOG_FILE_BACKUP_COUNT: PositiveInt = Field(
description="Maximum file backup count file rotation retention",
default=5,
)
LOG_FORMAT: str = Field(
description="Format string for log messages",
default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s",
@ -382,8 +393,9 @@ class LoggingConfig(BaseSettings):
default=None,
)
LOG_TZ: Optional[str] = Field(
description="Timezone for log timestamps (e.g., 'America/New_York')",
LOG_TZ: Optional[TimeZoneName] = Field(
description="Timezone for log timestamps. Allowed timezone values can be referred to IANA Time Zone Database,"
" e.g., 'America/New_York')",
default=None,
)
@ -614,6 +626,11 @@ class DataSetConfig(BaseSettings):
default=False,
)
TIDB_SERVERLESS_NUMBER: PositiveInt = Field(
description="number of tidb serverless cluster",
default=500,
)
class WorkspaceConfig(BaseSettings):
"""

View File

@ -17,9 +17,11 @@ from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCO
from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
from configs.middleware.vdb.chroma_config import ChromaConfig
from configs.middleware.vdb.couchbase_config import CouchbaseConfig
from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
from configs.middleware.vdb.milvus_config import MilvusConfig
from configs.middleware.vdb.myscale_config import MyScaleConfig
from configs.middleware.vdb.oceanbase_config import OceanBaseVectorConfig
from configs.middleware.vdb.opensearch_config import OpenSearchConfig
from configs.middleware.vdb.oracle_config import OracleConfig
from configs.middleware.vdb.pgvector_config import PGVectorConfig
@ -27,7 +29,9 @@ from configs.middleware.vdb.pgvectors_config import PGVectoRSConfig
from configs.middleware.vdb.qdrant_config import QdrantConfig
from configs.middleware.vdb.relyt_config import RelytConfig
from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig
from configs.middleware.vdb.tidb_on_qdrant_config import TidbOnQdrantConfig
from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig
from configs.middleware.vdb.upstash_config import UpstashConfig
from configs.middleware.vdb.vikingdb_config import VikingDBConfig
from configs.middleware.vdb.weaviate_config import WeaviateConfig
@ -53,6 +57,11 @@ class VectorStoreConfig(BaseSettings):
default=None,
)
VECTOR_STORE_WHITELIST_ENABLE: Optional[bool] = Field(
description="Enable whitelist for vector store.",
default=False,
)
class KeywordStoreConfig(BaseSettings):
KEYWORD_STORE: str = Field(
@ -244,7 +253,11 @@ class MiddlewareConfig(
TiDBVectorConfig,
WeaviateConfig,
ElasticsearchConfig,
CouchbaseConfig,
InternalTestConfig,
VikingDBConfig,
UpstashConfig,
TidbOnQdrantConfig,
OceanBaseVectorConfig,
):
pass

View File

@ -0,0 +1,34 @@
from typing import Optional
from pydantic import BaseModel, Field
class CouchbaseConfig(BaseModel):
"""
Couchbase configs
"""
COUCHBASE_CONNECTION_STRING: Optional[str] = Field(
description="COUCHBASE connection string",
default=None,
)
COUCHBASE_USER: Optional[str] = Field(
description="COUCHBASE user",
default=None,
)
COUCHBASE_PASSWORD: Optional[str] = Field(
description="COUCHBASE password",
default=None,
)
COUCHBASE_BUCKET_NAME: Optional[str] = Field(
description="COUCHBASE bucket name",
default=None,
)
COUCHBASE_SCOPE_NAME: Optional[str] = Field(
description="COUCHBASE scope name",
default=None,
)

View File

@ -0,0 +1,35 @@
from typing import Optional
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class OceanBaseVectorConfig(BaseSettings):
"""
Configuration settings for OceanBase Vector database
"""
OCEANBASE_VECTOR_HOST: Optional[str] = Field(
description="Hostname or IP address of the OceanBase Vector server (e.g. 'localhost')",
default=None,
)
OCEANBASE_VECTOR_PORT: Optional[PositiveInt] = Field(
description="Port number on which the OceanBase Vector server is listening (default is 2881)",
default=2881,
)
OCEANBASE_VECTOR_USER: Optional[str] = Field(
description="Username for authenticating with the OceanBase Vector database",
default=None,
)
OCEANBASE_VECTOR_PASSWORD: Optional[str] = Field(
description="Password for authenticating with the OceanBase Vector database",
default=None,
)
OCEANBASE_VECTOR_DATABASE: Optional[str] = Field(
description="Name of the OceanBase Vector database to connect to",
default=None,
)

View File

@ -0,0 +1,70 @@
from typing import Optional
from pydantic import Field, NonNegativeInt, PositiveInt
from pydantic_settings import BaseSettings
class TidbOnQdrantConfig(BaseSettings):
"""
Tidb on Qdrant configs
"""
TIDB_ON_QDRANT_URL: Optional[str] = Field(
description="Tidb on Qdrant url",
default=None,
)
TIDB_ON_QDRANT_API_KEY: Optional[str] = Field(
description="Tidb on Qdrant api key",
default=None,
)
TIDB_ON_QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field(
description="Tidb on Qdrant client timeout in seconds",
default=20,
)
TIDB_ON_QDRANT_GRPC_ENABLED: bool = Field(
description="whether enable grpc support for Tidb on Qdrant connection",
default=False,
)
TIDB_ON_QDRANT_GRPC_PORT: PositiveInt = Field(
description="Tidb on Qdrant grpc port",
default=6334,
)
TIDB_PUBLIC_KEY: Optional[str] = Field(
description="Tidb account public key",
default=None,
)
TIDB_PRIVATE_KEY: Optional[str] = Field(
description="Tidb account private key",
default=None,
)
TIDB_API_URL: Optional[str] = Field(
description="Tidb API url",
default=None,
)
TIDB_IAM_API_URL: Optional[str] = Field(
description="Tidb IAM API url",
default=None,
)
TIDB_REGION: Optional[str] = Field(
description="Tidb serverless region",
default="regions/aws-us-east-1",
)
TIDB_PROJECT_ID: Optional[str] = Field(
description="Tidb project id",
default=None,
)
TIDB_SPEND_LIMIT: Optional[int] = Field(
description="Tidb spend limit",
default=100,
)

View File

@ -0,0 +1,20 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class UpstashConfig(BaseSettings):
"""
Configuration settings for Upstash vector database
"""
UPSTASH_VECTOR_URL: Optional[str] = Field(
description="URL of the upstash server (e.g., 'https://vector.upstash.io')",
default=None,
)
UPSTASH_VECTOR_TOKEN: Optional[str] = Field(
description="Token for authenticating with the upstash server",
default=None,
)

View File

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="0.10.0",
default="0.10.2",
)
COMMIT_SHA: str = Field(

View File

@ -15,7 +15,9 @@ AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
if dify_config.ETL_TYPE == "Unstructured":
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls"]
DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "ppt", "xml", "epub"))
DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
if dify_config.UNSTRUCTURED_API_URL:
DOCUMENT_EXTENSIONS.append("ppt")
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
else:
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]

View File

@ -52,4 +52,39 @@ class RuleGenerateApi(Resource):
return rules
class RuleCodeGenerateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
parser.add_argument("code_language", type=str, required=False, default="javascript", location="json")
args = parser.parse_args()
account = current_user
CODE_GENERATION_MAX_TOKENS = int(os.getenv("CODE_GENERATION_MAX_TOKENS", "1024"))
try:
code_result = LLMGenerator.generate_code(
tenant_id=account.current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
code_language=args["code_language"],
max_tokens=CODE_GENERATION_MAX_TOKENS,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
return code_result
api.add_resource(RuleGenerateApi, "/rule-generate")
api.add_resource(RuleCodeGenerateApi, "/rule-code-generate")

View File

@ -105,6 +105,8 @@ class ChatMessageListApi(Resource):
if rest_count > 0:
has_more = True
history_messages = list(reversed(history_messages))
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)

View File

@ -1,11 +1,10 @@
from typing import cast
import flask_login
from flask import redirect, request
from flask import request
from flask_restful import Resource, reqparse
import services
from configs import dify_config
from constants.languages import languages
from controllers.console import api
from controllers.console.auth.error import (
@ -196,10 +195,7 @@ class EmailCodeLoginApi(Resource):
email=user_email, name=user_email, interface_language=languages[0]
)
except WorkSpaceNotAllowedCreateError:
return redirect(
f"{dify_config.CONSOLE_WEB_URL}/signin"
"?message=Workspace not found, please contact system admin to invite you to join in a workspace."
)
return NotAllowedCreateWorkspace()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": token_pair.model_dump()}

View File

@ -96,17 +96,15 @@ class OAuthCallback(Resource):
account = _generate_account(provider, user_info)
except AccountNotFoundError:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
except WorkSpaceNotFoundError:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found.")
except WorkSpaceNotAllowedCreateError:
except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError):
return redirect(
f"{dify_config.CONSOLE_WEB_URL}/signin"
"?message=Workspace not found, please contact system admin to invite you to join in a workspace."
)
# Check account status
if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
return {"error": "Account is banned or closed."}, 403
if account.status == AccountStatus.BANNED.value:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.")
if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value

View File

@ -102,6 +102,13 @@ class DatasetListApi(Resource):
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
type=str,
nullable=True,
required=False,
default="",
)
parser.add_argument(
"indexing_technique",
type=str,
@ -140,6 +147,7 @@ class DatasetListApi(Resource):
dataset = DatasetService.create_empty_dataset(
tenant_id=current_user.current_tenant_id,
name=args["name"],
description=args["description"],
indexing_technique=args["indexing_technique"],
account=current_user,
permission=DatasetPermissionEnum.ONLY_ME,
@ -619,6 +627,8 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.PGVECTO_RS
| VectorType.BAIDU
| VectorType.VIKINGDB
| VectorType.UPSTASH
| VectorType.OCEANBASE
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
@ -630,6 +640,8 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.PGVECTOR
| VectorType.TIDB_ON_QDRANT
| VectorType.COUCHBASE
):
return {
"retrieval_method": [
@ -657,6 +669,8 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.PGVECTO_RS
| VectorType.BAIDU
| VectorType.VIKINGDB
| VectorType.UPSTASH
| VectorType.OCEANBASE
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
@ -667,6 +681,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.COUCHBASE
| VectorType.PGVECTOR
):
return {

View File

@ -41,7 +41,7 @@ class AlreadyActivateError(BaseHTTPException):
class NotAllowedCreateWorkspace(BaseHTTPException):
error_code = "unauthorized"
error_code = "not_allowed_create_workspace"
description = "Workspace not found, please contact system admin to invite you to join in a workspace."
code = 400

View File

@ -21,7 +21,12 @@ class AppParameterApi(InstalledAppResource):
"options": fields.List(fields.String),
}
system_parameters_fields = {"image_file_size_limit": fields.String}
system_parameters_fields = {
"image_file_size_limit": fields.Integer,
"video_file_size_limit": fields.Integer,
"audio_file_size_limit": fields.Integer,
"file_size_limit": fields.Integer,
}
parameters_fields = {
"opening_statement": fields.String,
@ -82,7 +87,12 @@ class AppParameterApi(InstalledAppResource):
}
},
),
"system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT},
"system_parameters": {
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
},
}

View File

@ -1,5 +1,5 @@
from flask import Response, request
from flask_restful import Resource
from flask_restful import Resource, reqparse
from werkzeug.exceptions import NotFound
import services
@ -41,24 +41,39 @@ class FilePreviewApi(Resource):
def get(self, file_id):
file_id = str(file_id)
timestamp = request.args.get("timestamp")
nonce = request.args.get("nonce")
sign = request.args.get("sign")
parser = reqparse.RequestParser()
parser.add_argument("timestamp", type=str, required=True, location="args")
parser.add_argument("nonce", type=str, required=True, location="args")
parser.add_argument("sign", type=str, required=True, location="args")
parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
if not timestamp or not nonce or not sign:
args = parser.parse_args()
if not args["timestamp"] or not args["nonce"] or not args["sign"]:
return {"content": "Invalid request."}, 400
try:
generator, mimetype = FileService.get_signed_file_preview(
generator, upload_file = FileService.get_file_generator_by_file_id(
file_id=file_id,
timestamp=timestamp,
nonce=nonce,
sign=sign,
timestamp=args["timestamp"],
nonce=args["nonce"],
sign=args["sign"],
)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return Response(generator, mimetype=mimetype)
response = Response(
generator,
mimetype=upload_file.mime_type,
direct_passthrough=True,
headers={},
)
if upload_file.size > 0:
response.headers["Content-Length"] = str(upload_file.size)
if args["as_attachment"]:
response.headers["Content-Disposition"] = f"attachment; filename={upload_file.name}"
return response
class WorkspaceWebappLogoApi(Resource):

View File

@ -21,7 +21,7 @@ class EnterpriseWorkspace(Resource):
if account is None:
return {"message": "owner account not found."}, 404
tenant = TenantService.create_tenant(args["name"])
tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True)
TenantService.create_tenant_member(tenant, account, role="owner")
tenant_was_created.send(tenant)

View File

@ -21,7 +21,12 @@ class AppParameterApi(Resource):
"options": fields.List(fields.String),
}
system_parameters_fields = {"image_file_size_limit": fields.String}
system_parameters_fields = {
"image_file_size_limit": fields.Integer,
"video_file_size_limit": fields.Integer,
"audio_file_size_limit": fields.Integer,
"file_size_limit": fields.Integer,
}
parameters_fields = {
"opening_statement": fields.String,
@ -81,7 +86,12 @@ class AppParameterApi(Resource):
}
},
),
"system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT},
"system_parameters": {
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
},
}

View File

@ -66,6 +66,13 @@ class DatasetListApi(DatasetApiResource):
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
type=str,
nullable=True,
required=False,
default="",
)
parser.add_argument(
"indexing_technique",
type=str,
@ -108,6 +115,7 @@ class DatasetListApi(DatasetApiResource):
dataset = DatasetService.create_empty_dataset(
tenant_id=tenant_id,
name=args["name"],
description=args["description"],
indexing_technique=args["indexing_technique"],
account=current_user,
permission=args["permission"],

View File

@ -230,7 +230,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch}
return documents_and_batch_fields, 200

View File

@ -21,7 +21,12 @@ class AppParameterApi(WebApiResource):
"options": fields.List(fields.String),
}
system_parameters_fields = {"image_file_size_limit": fields.String}
system_parameters_fields = {
"image_file_size_limit": fields.Integer,
"video_file_size_limit": fields.Integer,
"audio_file_size_limit": fields.Integer,
"file_size_limit": fields.Integer,
}
parameters_fields = {
"opening_statement": fields.String,
@ -80,7 +85,12 @@ class AppParameterApi(WebApiResource):
}
},
),
"system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT},
"system_parameters": {
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
},
}

View File

@ -156,6 +156,12 @@ class BaseAgentRunner(AppRunner):
continue
parameter_type = parameter.type.as_normal_type()
if parameter.type in {
ToolParameter.ToolParameterType.SYSTEM_FILES,
ToolParameter.ToolParameterType.FILE,
ToolParameter.ToolParameterType.FILES,
}:
continue
enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options]
@ -243,6 +249,12 @@ class BaseAgentRunner(AppRunner):
continue
parameter_type = parameter.type.as_normal_type()
if parameter.type in {
ToolParameter.ToolParameterType.SYSTEM_FILES,
ToolParameter.ToolParameterType.FILE,
ToolParameter.ToolParameterType.FILES,
}:
continue
enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options]

View File

@ -27,6 +27,7 @@ from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBa
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from extensions.ext_database import db
from models import Account
from models.enums import CreatedByRole
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
@ -240,7 +241,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
belongs_to="user",
url=file.remote_url,
upload_file_id=file.related_id,
created_by_role=("account" if account_id else "end_user"),
created_by_role=(CreatedByRole.ACCOUNT if account_id else CreatedByRole.END_USER),
created_by=account_id or end_user_id or "",
)
db.session.add(message_file)

View File

@ -76,8 +76,16 @@ def to_prompt_message_content(f: File, /):
def download(f: File, /):
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
return _download_file_content(upload_file.key)
if f.transfer_method == FileTransferMethod.TOOL_FILE:
tool_file = file_repository.get_tool_file(session=db.session(), file=f)
return _download_file_content(tool_file.file_key)
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
return _download_file_content(upload_file.key)
# remote file
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
response.raise_for_status()
return response.content
def _download_file_content(path: str, /):

View File

@ -8,6 +8,8 @@ from core.llm_generator.output_parser.suggested_questions_after_answer import Su
from core.llm_generator.prompts import (
CONVERSATION_TITLE_PROMPT,
GENERATOR_QA_PROMPT,
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
)
from core.model_manager import ModelManager
@ -239,6 +241,54 @@ class LLMGenerator:
return rule_config
@classmethod
def generate_code(
cls,
tenant_id: str,
instruction: str,
model_config: dict,
code_language: str = "javascript",
max_tokens: int = 1000,
) -> dict:
if code_language == "python":
prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
else:
prompt_template = PromptTemplateParser(JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE)
prompt = prompt_template.format(
inputs={
"INSTRUCTION": instruction,
"CODE_LANGUAGE": code_language,
},
remove_template_variables=False,
)
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.get("provider") if model_config else None,
model=model_config.get("name") if model_config else None,
)
prompt_messages = [UserPromptMessage(content=prompt)]
model_parameters = {"max_tokens": max_tokens, "temperature": 0.01}
try:
response = model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
)
generated_code = response.message.content
return {"code": generated_code, "language": code_language, "error": ""}
except InvokeError as e:
error = str(e)
return {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"}
except Exception as e:
logging.exception(e)
return {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"}
@classmethod
def generate_qa_document(cls, tenant_id: str, query, document_language: str):
prompt = GENERATOR_QA_PROMPT.format(language=document_language)

View File

@ -61,6 +61,73 @@ User Input: yo, 你今天咋样?
User Input:
""" # noqa: E501
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE = (
"You are an expert programmer. Generate code based on the following instructions:\n\n"
"Instructions: {{INSTRUCTION}}\n\n"
"Write the code in {{CODE_LANGUAGE}}.\n\n"
"Please ensure that you meet the following requirements:\n"
"1. Define a function named 'main'.\n"
"2. The 'main' function must return a dictionary (dict).\n"
"3. You may modify the arguments of the 'main' function, but include appropriate type hints.\n"
"4. The returned dictionary should contain at least one key-value pair.\n\n"
"5. You may ONLY use the following libraries in your code: \n"
"- json\n"
"- datetime\n"
"- math\n"
"- random\n"
"- re\n"
"- string\n"
"- sys\n"
"- time\n"
"- traceback\n"
"- uuid\n"
"- os\n"
"- base64\n"
"- hashlib\n"
"- hmac\n"
"- binascii\n"
"- collections\n"
"- functools\n"
"- operator\n"
"- itertools\n\n"
"Example:\n"
"def main(arg1: str, arg2: int) -> dict:\n"
" return {\n"
' "result": arg1 * arg2,\n'
" }\n\n"
"IMPORTANT:\n"
"- Provide ONLY the code without any additional explanations, comments, or markdown formatting.\n"
"- DO NOT use markdown code blocks (``` or ``` python). Return the raw code directly.\n"
"- The code should start immediately after this instruction, without any preceding newlines or spaces.\n"
"- The code should be complete, functional, and follow best practices for {{CODE_LANGUAGE}}.\n\n"
"- Always use the format return {'result': ...} for the output.\n\n"
"Generated Code:\n"
)
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE = (
"You are an expert programmer. Generate code based on the following instructions:\n\n"
"Instructions: {{INSTRUCTION}}\n\n"
"Write the code in {{CODE_LANGUAGE}}.\n\n"
"Please ensure that you meet the following requirements:\n"
"1. Define a function named 'main'.\n"
"2. The 'main' function must return an object.\n"
"3. You may modify the arguments of the 'main' function, but include appropriate JSDoc annotations.\n"
"4. The returned object should contain at least one key-value pair.\n\n"
"5. The returned object should always be in the format: {result: ...}\n\n"
"Example:\n"
"function main(arg1, arg2) {\n"
" return {\n"
" result: arg1 * arg2\n"
" };\n"
"}\n\n"
"IMPORTANT:\n"
"- Provide ONLY the code without any additional explanations, comments, or markdown formatting.\n"
"- DO NOT use markdown code blocks (``` or ``` javascript). Return the raw code directly.\n"
"- The code should start immediately after this instruction, without any preceding newlines or spaces.\n"
"- The code should be complete, functional, and follow best practices for {{CODE_LANGUAGE}}.\n\n"
"Generated Code:\n"
)
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
"Please help me predict the three most likely questions that human would ask, "
"and keeping each question under 20 characters.\n"

View File

@ -105,6 +105,7 @@ class LLMResult(BaseModel):
Model class for llm result.
"""
id: Optional[str] = None
model: str
prompt_messages: list[PromptMessage]
message: AssistantPromptMessage

View File

@ -0,0 +1,9 @@
- claude-3-5-sonnet-20241022
- claude-3-5-sonnet-20240620
- claude-3-haiku-20240307
- claude-3-opus-20240229
- claude-3-sonnet-20240229
- claude-2.1
- claude-instant-1.2
- claude-2
- claude-instant-1

View File

@ -0,0 +1,39 @@
model: claude-3-5-sonnet-20241022
label:
en_US: claude-3-5-sonnet-20241022
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
pricing:
input: '3.00'
output: '15.00'
unit: '0.000001'
currency: USD

View File

@ -0,0 +1,245 @@
provider: azure_openai
label:
en_US: Azure OpenAI Service Model
icon_small:
en_US: icon_s_en.svg
icon_large:
en_US: icon_l_en.png
background: "#E3F0FF"
help:
title:
en_US: Get your API key from Azure
zh_Hans: 从 Azure 获取 API Key
url:
en_US: https://azure.microsoft.com/en-us/products/ai-services/openai-service
supported_model_types:
- llm
- text-embedding
- speech2text
- tts
configurate_methods:
- customizable-model
model_credential_schema:
model:
label:
en_US: Deployment Name
zh_Hans: 部署名称
placeholder:
en_US: Enter your Deployment Name here, matching the Azure deployment name.
zh_Hans: 在此输入您的部署名称,与 Azure 部署名称匹配。
credential_form_schemas:
- variable: openai_api_base
label:
en_US: API Endpoint URL
zh_Hans: API 域名
type: text-input
required: true
placeholder:
zh_Hans: '在此输入您的 API 域名https://example.com/xxx'
en_US: 'Enter your API Endpoint, eg: https://example.com/xxx'
- variable: openai_api_key
label:
en_US: API Key
zh_Hans: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API key here
- variable: openai_api_version
label:
zh_Hans: API 版本
en_US: API Version
type: select
required: true
options:
- label:
en_US: 2024-10-01-preview
value: 2024-10-01-preview
- label:
en_US: 2024-09-01-preview
value: 2024-09-01-preview
- label:
en_US: 2024-08-01-preview
value: 2024-08-01-preview
- label:
en_US: 2024-07-01-preview
value: 2024-07-01-preview
- label:
en_US: 2024-05-01-preview
value: 2024-05-01-preview
- label:
en_US: 2024-04-01-preview
value: 2024-04-01-preview
- label:
en_US: 2024-03-01-preview
value: 2024-03-01-preview
- label:
en_US: 2024-02-15-preview
value: 2024-02-15-preview
- label:
en_US: 2023-12-01-preview
value: 2023-12-01-preview
- label:
en_US: '2024-02-01'
value: '2024-02-01'
- label:
en_US: '2024-06-01'
value: '2024-06-01'
placeholder:
zh_Hans: 在此选择您的 API 版本
en_US: Select your API Version here
- variable: base_model_name
label:
en_US: Base Model
zh_Hans: 基础模型
type: select
required: true
options:
- label:
en_US: gpt-35-turbo
value: gpt-35-turbo
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-35-turbo-0125
value: gpt-35-turbo-0125
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-35-turbo-16k
value: gpt-35-turbo-16k
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4
value: gpt-4
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4-32k
value: gpt-4-32k
show_on:
- variable: __model_type
value: llm
- label:
en_US: o1-mini
value: o1-mini
show_on:
- variable: __model_type
value: llm
- label:
en_US: o1-preview
value: o1-preview
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4o-mini
value: gpt-4o-mini
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4o-mini-2024-07-18
value: gpt-4o-mini-2024-07-18
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4o
value: gpt-4o
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4o-2024-05-13
value: gpt-4o-2024-05-13
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4o-2024-08-06
value: gpt-4o-2024-08-06
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4-turbo
value: gpt-4-turbo
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4-turbo-2024-04-09
value: gpt-4-turbo-2024-04-09
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4-0125-preview
value: gpt-4-0125-preview
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4-1106-preview
value: gpt-4-1106-preview
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4-vision-preview
value: gpt-4-vision-preview
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-35-turbo-instruct
value: gpt-35-turbo-instruct
show_on:
- variable: __model_type
value: llm
- label:
en_US: text-embedding-ada-002
value: text-embedding-ada-002
show_on:
- variable: __model_type
value: text-embedding
- label:
en_US: text-embedding-3-small
value: text-embedding-3-small
show_on:
- variable: __model_type
value: text-embedding
- label:
en_US: text-embedding-3-large
value: text-embedding-3-large
show_on:
- variable: __model_type
value: text-embedding
- label:
en_US: whisper-1
value: whisper-1
show_on:
- variable: __model_type
value: speech2text
- label:
en_US: tts-1
value: tts-1
show_on:
- variable: __model_type
value: tts
- label:
en_US: tts-1-hd
value: tts-1-hd
show_on:
- variable: __model_type
value: tts
placeholder:
zh_Hans: 在此输入您的模型版本
en_US: Enter your model version

View File

@ -0,0 +1,764 @@
import copy
import json
import logging
from collections.abc import Generator, Sequence
from typing import Optional, Union, cast
import tiktoken
from openai import AzureOpenAI, Stream
from openai.types import Completion
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
PromptMessageFunction,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS
from core.model_runtime.utils import helper
logger = logging.getLogger(__name__)
class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
base_model_name = self._get_base_model_name(credentials)
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
# chat model
return self._chat_generate(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
)
else:
# text completion model
return self._generate(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
stop=stop,
stream=stream,
user=user,
)
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
base_model_name = self._get_base_model_name(credentials)
model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
if not model_entity:
raise ValueError(f"Base Model Name {base_model_name} is invalid")
model_mode = model_entity.entity.model_properties.get(ModelPropertyKey.MODE)
if model_mode == LLMMode.CHAT.value:
# chat model
return self._num_tokens_from_messages(credentials, prompt_messages, tools)
else:
# text completion model, do not support tool calling
content = prompt_messages[0].content
assert isinstance(content, str)
return self._num_tokens_from_string(credentials, content)
def validate_credentials(self, model: str, credentials: dict) -> None:
if "openai_api_base" not in credentials:
raise CredentialsValidateFailedError("Azure OpenAI API Base Endpoint is required")
if "openai_api_key" not in credentials:
raise CredentialsValidateFailedError("Azure OpenAI API key is required")
if "base_model_name" not in credentials:
raise CredentialsValidateFailedError("Base Model Name is required")
base_model_name = self._get_base_model_name(credentials)
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
if not ai_model_entity:
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
try:
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
if model.startswith("o1"):
client.chat.completions.create(
messages=[{"role": "user", "content": "ping"}],
model=model,
temperature=1,
max_completion_tokens=20,
stream=False,
)
elif ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
# chat model
client.chat.completions.create(
messages=[{"role": "user", "content": "ping"}],
model=model,
temperature=0,
max_tokens=20,
stream=False,
)
else:
# text completion model
client.completions.create(
prompt="ping",
model=model,
temperature=0,
max_tokens=20,
stream=False,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
base_model_name = self._get_base_model_name(credentials)
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
return ai_model_entity.entity if ai_model_entity else None
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
extra_model_kwargs = {}
if stop:
extra_model_kwargs["stop"] = stop
if user:
extra_model_kwargs["user"] = user
# text completion model
response = client.completions.create(
prompt=prompt_messages[0].content, model=model, stream=stream, **model_parameters, **extra_model_kwargs
)
if stream:
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_response(
self, model: str, credentials: dict, response: Completion, prompt_messages: list[PromptMessage]
):
assistant_text = response.choices[0].text
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=assistant_text)
# calculate num tokens
if response.usage:
# transform usage
prompt_tokens = response.usage.prompt_tokens
completion_tokens = response.usage.completion_tokens
else:
# calculate num tokens
content = prompt_messages[0].content
assert isinstance(content, str)
prompt_tokens = self._num_tokens_from_string(credentials, content)
completion_tokens = self._num_tokens_from_string(credentials, assistant_text)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
# transform response
result = LLMResult(
model=response.model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage,
system_fingerprint=response.system_fingerprint,
)
return result
def _handle_generate_stream_response(
self, model: str, credentials: dict, response: Stream[Completion], prompt_messages: list[PromptMessage]
) -> Generator:
full_text = ""
for chunk in response:
if len(chunk.choices) == 0:
continue
delta = chunk.choices[0]
if delta.finish_reason is None and (delta.text is None or delta.text == ""):
continue
# transform assistant message to prompt message
text = delta.text or ""
assistant_prompt_message = AssistantPromptMessage(content=text)
full_text += text
if delta.finish_reason is not None:
# calculate num tokens
if chunk.usage:
# transform usage
prompt_tokens = chunk.usage.prompt_tokens
completion_tokens = chunk.usage.completion_tokens
else:
# calculate num tokens
content = prompt_messages[0].content
assert isinstance(content, str)
prompt_tokens = self._num_tokens_from_string(credentials, content)
completion_tokens = self._num_tokens_from_string(credentials, full_text)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
yield LLMResultChunk(
model=chunk.model,
prompt_messages=prompt_messages,
system_fingerprint=chunk.system_fingerprint,
delta=LLMResultChunkDelta(
index=delta.index,
message=assistant_prompt_message,
finish_reason=delta.finish_reason,
usage=usage,
),
)
else:
yield LLMResultChunk(
model=chunk.model,
prompt_messages=prompt_messages,
system_fingerprint=chunk.system_fingerprint,
delta=LLMResultChunkDelta(
index=delta.index,
message=assistant_prompt_message,
),
)
def _chat_generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
response_format = model_parameters.get("response_format")
if response_format:
if response_format == "json_schema":
json_schema = model_parameters.get("json_schema")
if not json_schema:
raise ValueError("Must define JSON Schema when the response format is json_schema")
try:
schema = json.loads(json_schema)
except:
raise ValueError(f"not correct json_schema format: {json_schema}")
model_parameters.pop("json_schema")
model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema}
else:
model_parameters["response_format"] = {"type": response_format}
extra_model_kwargs = {}
if tools:
extra_model_kwargs["tools"] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
if stop:
extra_model_kwargs["stop"] = stop
if user:
extra_model_kwargs["user"] = user
# clear illegal prompt messages
prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
block_as_stream = False
if model.startswith("o1"):
if stream:
block_as_stream = True
stream = False
if "stream_options" in extra_model_kwargs:
del extra_model_kwargs["stream_options"]
if "stop" in extra_model_kwargs:
del extra_model_kwargs["stop"]
# chat model
response = client.chat.completions.create(
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
model=model,
stream=stream,
**model_parameters,
**extra_model_kwargs,
)
if stream:
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
if block_as_stream:
return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
return block_result
def _handle_chat_block_as_stream_response(
self,
block_result: LLMResult,
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None,
) -> Generator[LLMResultChunk, None, None]:
"""
Handle llm chat response
:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:param stop: stop words
:return: llm response chunk generator
"""
text = block_result.message.content
text = cast(str, text)
if stop:
text = self.enforce_stop_tokens(text, stop)
yield LLMResultChunk(
model=block_result.model,
prompt_messages=prompt_messages,
system_fingerprint=block_result.system_fingerprint,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=text),
finish_reason="stop",
usage=block_result.usage,
),
)
def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Clear illegal prompt messages for OpenAI API
:param model: model name
:param prompt_messages: prompt messages
:return: cleaned prompt messages
"""
checklist = ["gpt-4-turbo", "gpt-4-turbo-2024-04-09"]
if model in checklist:
# count how many user messages are there
user_message_count = len([m for m in prompt_messages if isinstance(m, UserPromptMessage)])
if user_message_count > 1:
for prompt_message in prompt_messages:
if isinstance(prompt_message, UserPromptMessage):
if isinstance(prompt_message.content, list):
prompt_message.content = "\n".join(
[
item.data
if item.type == PromptMessageContentType.TEXT
else "[IMAGE]"
if item.type == PromptMessageContentType.IMAGE
else ""
for item in prompt_message.content
]
)
if model.startswith("o1"):
system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)])
if system_message_count > 0:
new_prompt_messages = []
for prompt_message in prompt_messages:
if isinstance(prompt_message, SystemPromptMessage):
prompt_message = UserPromptMessage(
content=prompt_message.content,
name=prompt_message.name,
)
new_prompt_messages.append(prompt_message)
prompt_messages = new_prompt_messages
return prompt_messages
def _handle_chat_generate_response(
self,
model: str,
credentials: dict,
response: ChatCompletion,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
):
assistant_message = response.choices[0].message
assistant_message_tool_calls = assistant_message.tool_calls
# extract tool calls from response
tool_calls = []
self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=assistant_message_tool_calls)
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls)
# calculate num tokens
if response.usage:
# transform usage
prompt_tokens = response.usage.prompt_tokens
completion_tokens = response.usage.completion_tokens
else:
# calculate num tokens
prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
completion_tokens = self._num_tokens_from_messages(credentials, [assistant_prompt_message])
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
# transform response
result = LLMResult(
model=response.model or model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage,
system_fingerprint=response.system_fingerprint,
)
return result
def _handle_chat_generate_stream_response(
self,
model: str,
credentials: dict,
response: Stream[ChatCompletionChunk],
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
):
index = 0
full_assistant_content = ""
real_model = model
system_fingerprint = None
completion = ""
tool_calls = []
for chunk in response:
if len(chunk.choices) == 0:
continue
delta = chunk.choices[0]
# NOTE: For fix https://github.com/langgenius/dify/issues/5790
if delta.delta is None:
continue
# extract tool calls from response
self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=delta.delta.tool_calls)
# Handling exceptions when content filters' streaming mode is set to asynchronous modified filter
if delta.finish_reason is None and not delta.delta.content:
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)
full_assistant_content += delta.delta.content or ""
real_model = chunk.model
system_fingerprint = chunk.system_fingerprint
completion += delta.delta.content or ""
yield LLMResultChunk(
model=real_model,
prompt_messages=prompt_messages,
system_fingerprint=system_fingerprint,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
),
)
index += 1
# calculate num tokens
prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
full_assistant_prompt_message = AssistantPromptMessage(content=completion)
completion_tokens = self._num_tokens_from_messages(credentials, [full_assistant_prompt_message])
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
yield LLMResultChunk(
model=real_model,
prompt_messages=prompt_messages,
system_fingerprint=system_fingerprint,
delta=LLMResultChunkDelta(
index=index, message=AssistantPromptMessage(content=""), finish_reason="stop", usage=usage
),
)
@staticmethod
def _update_tool_calls(
tool_calls: list[AssistantPromptMessage.ToolCall],
tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]],
) -> None:
if tool_calls_response:
for response_tool_call in tool_calls_response:
if isinstance(response_tool_call, ChatCompletionMessageToolCall):
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call.function.name, arguments=response_tool_call.function.arguments
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call.id, type=response_tool_call.type, function=function
)
tool_calls.append(tool_call)
elif isinstance(response_tool_call, ChoiceDeltaToolCall):
index = response_tool_call.index
if index < len(tool_calls):
tool_calls[index].id = response_tool_call.id or tool_calls[index].id
tool_calls[index].type = response_tool_call.type or tool_calls[index].type
if response_tool_call.function:
tool_calls[index].function.name = (
response_tool_call.function.name or tool_calls[index].function.name
)
tool_calls[index].function.arguments += response_tool_call.function.arguments or ""
else:
assert response_tool_call.id is not None
assert response_tool_call.type is not None
assert response_tool_call.function is not None
assert response_tool_call.function.name is not None
assert response_tool_call.function.arguments is not None
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call.function.name, arguments=response_tool_call.function.arguments
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call.id, type=response_tool_call.type, function=function
)
tool_calls.append(tool_call)
@staticmethod
def _convert_prompt_message_to_dict(message: PromptMessage):
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
sub_messages = []
assert message.content is not None
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(TextPromptMessageContent, message_content)
sub_message_dict = {"type": "text", "text": message_content.data}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
sub_message_dict = {
"type": "image_url",
"image_url": {"url": message_content.data, "detail": message_content.detail.value},
}
sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
# message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls:
message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls]
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {
"role": "tool",
"name": message.name,
"content": message.content,
"tool_call_id": message.tool_call_id,
}
else:
raise ValueError(f"Got unknown type {message}")
if message.name:
message_dict["name"] = message.name
return message_dict
def _num_tokens_from_string(
self, credentials: dict, text: str, tools: Optional[list[PromptMessageTool]] = None
) -> int:
try:
encoding = tiktoken.encoding_for_model(credentials["base_model_name"])
except KeyError:
encoding = tiktoken.get_encoding("cl100k_base")
num_tokens = len(encoding.encode(text))
if tools:
num_tokens += self._num_tokens_for_tools(encoding, tools)
return num_tokens
def _num_tokens_from_messages(
self, credentials: dict, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
model = credentials["base_model_name"]
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base"
encoding = tiktoken.get_encoding(model)
if model.startswith("gpt-35-turbo-0301"):
# every message follows <im_start>{role/name}\n{content}<im_end>\n
tokens_per_message = 4
# if there's a name, the role is omitted
tokens_per_name = -1
elif model.startswith("gpt-35-turbo") or model.startswith("gpt-4") or model.startswith("o1"):
tokens_per_message = 3
tokens_per_name = 1
else:
raise NotImplementedError(
f"get_num_tokens_from_messages() is not presently implemented "
f"for model {model}."
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
"information on how messages are converted to tokens."
)
num_tokens = 0
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():
# Cast str(value) in case the message value is not a string
# This occurs with function messages
# TODO: The current token calculation method for the image type is not implemented,
# which need to download the image and then get the resolution for calculation,
# and will increase the request delay
if isinstance(value, list):
text = ""
for item in value:
if isinstance(item, dict) and item["type"] == "text":
text += item["text"]
value = text
if key == "tool_calls":
for tool_call in value:
assert isinstance(tool_call, dict)
for t_key, t_value in tool_call.items():
num_tokens += len(encoding.encode(t_key))
if t_key == "function":
for f_key, f_value in t_value.items():
num_tokens += len(encoding.encode(f_key))
num_tokens += len(encoding.encode(f_value))
else:
num_tokens += len(encoding.encode(t_key))
num_tokens += len(encoding.encode(t_value))
else:
num_tokens += len(encoding.encode(str(value)))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
if tools:
num_tokens += self._num_tokens_for_tools(encoding, tools)
return num_tokens
@staticmethod
def _num_tokens_for_tools(encoding: tiktoken.Encoding, tools: list[PromptMessageTool]) -> int:
num_tokens = 0
for tool in tools:
num_tokens += len(encoding.encode("type"))
num_tokens += len(encoding.encode("function"))
# calculate num tokens for function object
num_tokens += len(encoding.encode("name"))
num_tokens += len(encoding.encode(tool.name))
num_tokens += len(encoding.encode("description"))
num_tokens += len(encoding.encode(tool.description))
parameters = tool.parameters
num_tokens += len(encoding.encode("parameters"))
if "title" in parameters:
num_tokens += len(encoding.encode("title"))
num_tokens += len(encoding.encode(parameters["title"]))
num_tokens += len(encoding.encode("type"))
num_tokens += len(encoding.encode(parameters["type"]))
if "properties" in parameters:
num_tokens += len(encoding.encode("properties"))
for key, value in parameters["properties"].items():
num_tokens += len(encoding.encode(key))
for field_key, field_value in value.items():
num_tokens += len(encoding.encode(field_key))
if field_key == "enum":
for enum_field in field_value:
num_tokens += 3
num_tokens += len(encoding.encode(enum_field))
else:
num_tokens += len(encoding.encode(field_key))
num_tokens += len(encoding.encode(str(field_value)))
if "required" in parameters:
num_tokens += len(encoding.encode("required"))
for required_field in parameters["required"]:
num_tokens += 3
num_tokens += len(encoding.encode(required_field))
return num_tokens
@staticmethod
def _get_ai_model_entity(base_model_name: str, model: str):
for ai_model_entity in LLM_BASE_MODELS:
if ai_model_entity.base_model_name == base_model_name:
ai_model_entity_copy = copy.deepcopy(ai_model_entity)
ai_model_entity_copy.entity.model = model
ai_model_entity_copy.entity.label.en_US = model
ai_model_entity_copy.entity.label.zh_Hans = model
return ai_model_entity_copy
def _get_base_model_name(self, credentials: dict) -> str:
base_model_name = credentials.get("base_model_name")
if not base_model_name:
raise ValueError("Base Model Name is required")
return base_model_name

View File

@ -0,0 +1,60 @@
model: anthropic.claude-3-5-sonnet-20241022-v2:0
label:
en_US: Claude 3.5 Sonnet V2
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.003'
output: '0.015'
unit: '0.001'
currency: USD

View File

@ -0,0 +1,60 @@
model: eu.anthropic.claude-3-5-sonnet-20241022-v2:0
label:
en_US: Claude 3.5 Sonnet V2(EU.Cross Region Inference)
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.003'
output: '0.015'
unit: '0.001'
currency: USD

View File

@ -0,0 +1,60 @@
model: us.anthropic.claude-3-5-sonnet-20241022-v2:0
label:
en_US: Claude 3.5 Sonnet V2(US.Cross Region Inference)
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.003'
output: '0.015'
unit: '0.001'
currency: USD

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 9.8 KiB

View File

@ -0,0 +1,3 @@
<svg width="40" height="40" viewBox="0 0 40 40" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M25.132 24.3947C25.497 25.7527 25.8984 27.1413 26.3334 28.5834C26.7302 29.8992 25.5459 30.4167 25.0752 29.1758C24.571 27.8466 24.0885 26.523 23.6347 25.1729C21.065 26.4654 18.5025 27.5424 15.5961 28.7541C16.7581 33.0256 17.8309 36.5984 19.4952 39.9935C19.4953 39.9936 19.4953 39.9937 19.4954 39.9938C19.6631 39.9979 19.8313 40 20 40C31.0457 40 40 31.0457 40 20C40 16.0335 38.8453 12.3366 36.8537 9.22729C31.6585 9.69534 27.0513 10.4562 22.8185 11.406C22.8882 12.252 22.9677 13.0739 23.0555 13.855C23.3824 16.7604 23.9112 19.5281 24.6137 22.3836C27.0581 21.2848 29.084 20.3225 30.6816 19.522C32.2154 18.7535 33.6943 18.7062 31.2018 20.6594C29.0388 22.1602 27.0644 23.3566 25.132 24.3947ZM36.1559 8.20846C33.0001 3.89184 28.1561 0.887462 22.5955 0.166882C22.4257 2.86234 22.4785 6.26344 22.681 9.50447C26.7473 8.88859 31.1721 8.46032 36.1559 8.20846ZM19.9369 9.73661e-05C19.7594 2.92694 19.8384 6.65663 20.19 9.91293C17.3748 10.4109 14.7225 11.0064 12.1592 11.7038C12.0486 10.4257 11.9927 9.25764 11.9927 8.24178C11.9927 7.5054 11.3957 6.90844 10.6593 6.90844C9.92296 6.90844 9.32601 7.5054 9.32601 8.24178C9.32601 9.47868 9.42873 10.898 9.61402 12.438C8.33567 12.8278 7.07397 13.2443 5.81918 13.688C5.12493 13.9336 4.76118 14.6954 5.0067 15.3896C5.25223 16.0839 6.01406 16.4476 6.7083 16.2021C7.7931 15.8185 8.88482 15.4388 9.98927 15.0659C10.5222 18.3344 11.3344 21.9428 12.2703 25.4156C12.4336 26.0218 12.6062 26.6262 12.7863 27.2263C9.34168 28.4135 5.82612 29.3782 2.61128 29.8879C0.949407 26.9716 0 23.5967 0 20C0 8.97534 8.92023 0.0341108 19.9369 9.73661e-05ZM4.19152 32.2527C7.45069 36.4516 12.3458 39.3173 17.9204 39.8932C16.5916 37.455 14.9338 33.717 13.5405 29.5901C10.4404 30.7762 7.25883 31.6027 4.19152 32.2527ZM22.9735 23.1135C22.1479 20.41 21.4462 17.5441 20.9225 14.277C20.746 13.5841 20.5918 12.8035 20.4593 11.9636C17.6508 12.6606 14.9992 13.4372 12.4356 14.2598C12.8479 17.4766 13.5448 21.1334 14.5118 24.7218C14.662 25.2792 14.8081 25.8248 14.9514 26.3594L14.9516 26.3603L14.9524 26.3634L14.9526 26.3639L14.973 26.4401C16.1833 25.9872 17.3746 25.5123 18.53 25.0259C20.1235 24.3552 21.6051 23.7165 22.9735 23.1135Z" fill="#141519"/>
</svg>

After

Width:  |  Height:  |  Size: 2.2 KiB

View File

@ -0,0 +1,47 @@
from dashscope.common.error import (
AuthenticationError,
InvalidParameter,
RequestFailure,
ServiceUnavailableError,
UnsupportedHTTPMethod,
UnsupportedModel,
)
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
class _CommonGiteeAI:
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
RequestFailure,
],
InvokeServerUnavailableError: [
ServiceUnavailableError,
],
InvokeRateLimitError: [],
InvokeAuthorizationError: [
AuthenticationError,
],
InvokeBadRequestError: [
InvalidParameter,
UnsupportedModel,
UnsupportedHTTPMethod,
],
}

View File

@ -0,0 +1,25 @@
import logging
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class GiteeAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(model="Qwen2-7B-Instruct", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex

View File

@ -0,0 +1,35 @@
provider: gitee_ai
label:
en_US: Gitee AI
zh_Hans: Gitee AI
description:
en_US: 快速体验大模型,领先探索 AI 开源世界
zh_Hans: 快速体验大模型,领先探索 AI 开源世界
icon_small:
en_US: Gitee-AI-Logo.svg
icon_large:
en_US: Gitee-AI-Logo-full.svg
help:
title:
en_US: Get your token from Gitee AI
zh_Hans: 从 Gitee AI 获取 token
url:
en_US: https://ai.gitee.com/dashboard/settings/tokens
supported_model_types:
- llm
- text-embedding
- rerank
- speech2text
- tts
configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key

View File

@ -0,0 +1,105 @@
model: Qwen2-72B-Instruct
label:
zh_Hans: Qwen2-72B-Instruct
en_US: Qwen2-72B-Instruct
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 6400
parameter_rules:
- name: stream
use_template: boolean
label:
en_US: "Stream"
zh_Hans: "流式"
type: boolean
default: true
required: true
help:
en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process."
zh_Hans: "是否通过流式分批返回结果。如果设置为 true生成过程中实时地向用户推送每一部分生成的文本。"
- name: max_tokens
use_template: max_tokens
label:
en_US: "Max Tokens"
zh_Hans: "最大Token数"
type: int
default: 512
min: 1
required: true
help:
en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"
- name: temperature
use_template: temperature
label:
en_US: "Temperature"
zh_Hans: "采样温度"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: top_p
use_template: top_p
label:
en_US: "Top P"
zh_Hans: "Top P"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: top_k
use_template: top_k
label:
en_US: "Top K"
zh_Hans: "Top K"
type: int
default: 50
min: 0
max: 100
required: true
help:
en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"
- name: frequency_penalty
use_template: frequency_penalty
label:
en_US: "Frequency Penalty"
zh_Hans: "频率惩罚"
type: float
default: 0
min: -1.0
max: 1.0
precision: 1
required: false
help:
en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"
- name: user
use_template: text
label:
en_US: "User"
zh_Hans: "用户"
type: string
required: false
help:
en_US: "Used to track and differentiate conversation requests from different users."
zh_Hans: "用于追踪和区分不同用户的对话请求。"

View File

@ -0,0 +1,105 @@
model: Qwen2-7B-Instruct
label:
zh_Hans: Qwen2-7B-Instruct
en_US: Qwen2-7B-Instruct
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 32768
parameter_rules:
- name: stream
use_template: boolean
label:
en_US: "Stream"
zh_Hans: "流式"
type: boolean
default: true
required: true
help:
en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process."
zh_Hans: "是否通过流式分批返回结果。如果设置为 true生成过程中实时地向用户推送每一部分生成的文本。"
- name: max_tokens
use_template: max_tokens
label:
en_US: "Max Tokens"
zh_Hans: "最大Token数"
type: int
default: 512
min: 1
required: true
help:
en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"
- name: temperature
use_template: temperature
label:
en_US: "Temperature"
zh_Hans: "采样温度"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: top_p
use_template: top_p
label:
en_US: "Top P"
zh_Hans: "Top P"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: top_k
use_template: top_k
label:
en_US: "Top K"
zh_Hans: "Top K"
type: int
default: 50
min: 0
max: 100
required: true
help:
en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"
- name: frequency_penalty
use_template: frequency_penalty
label:
en_US: "Frequency Penalty"
zh_Hans: "频率惩罚"
type: float
default: 0
min: -1.0
max: 1.0
precision: 1
required: false
help:
en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"
- name: user
use_template: text
label:
en_US: "User"
zh_Hans: "用户"
type: string
required: false
help:
en_US: "Used to track and differentiate conversation requests from different users."
zh_Hans: "用于追踪和区分不同用户的对话请求。"

View File

@ -0,0 +1,105 @@
model: Yi-1.5-34B-Chat
label:
zh_Hans: Yi-1.5-34B-Chat
en_US: Yi-1.5-34B-Chat
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 4096
parameter_rules:
- name: stream
use_template: boolean
label:
en_US: "Stream"
zh_Hans: "流式"
type: boolean
default: true
required: true
help:
en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process."
zh_Hans: "是否通过流式分批返回结果。如果设置为 true生成过程中实时地向用户推送每一部分生成的文本。"
- name: max_tokens
use_template: max_tokens
label:
en_US: "Max Tokens"
zh_Hans: "最大Token数"
type: int
default: 512
min: 1
required: true
help:
en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"
- name: temperature
use_template: temperature
label:
en_US: "Temperature"
zh_Hans: "采样温度"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: top_p
use_template: top_p
label:
en_US: "Top P"
zh_Hans: "Top P"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: top_k
use_template: top_k
label:
en_US: "Top K"
zh_Hans: "Top K"
type: int
default: 50
min: 0
max: 100
required: true
help:
en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"
- name: frequency_penalty
use_template: frequency_penalty
label:
en_US: "Frequency Penalty"
zh_Hans: "频率惩罚"
type: float
default: 0
min: -1.0
max: 1.0
precision: 1
required: false
help:
en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"
- name: user
use_template: text
label:
en_US: "User"
zh_Hans: "用户"
type: string
required: false
help:
en_US: "Used to track and differentiate conversation requests from different users."
zh_Hans: "用于追踪和区分不同用户的对话请求。"

View File

@ -0,0 +1,7 @@
- Qwen2-7B-Instruct
- Qwen2-72B-Instruct
- Yi-1.5-34B-Chat
- glm-4-9b-chat
- deepseek-coder-33B-instruct-chat
- deepseek-coder-33B-instruct-completions
- codegeex4-all-9b

View File

@ -0,0 +1,105 @@
model: codegeex4-all-9b
label:
zh_Hans: codegeex4-all-9b
en_US: codegeex4-all-9b
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 40960
parameter_rules:
- name: stream
use_template: boolean
label:
en_US: "Stream"
zh_Hans: "流式"
type: boolean
default: true
required: true
help:
en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process."
zh_Hans: "是否通过流式分批返回结果。如果设置为 true生成过程中实时地向用户推送每一部分生成的文本。"
- name: max_tokens
use_template: max_tokens
label:
en_US: "Max Tokens"
zh_Hans: "最大Token数"
type: int
default: 512
min: 1
required: true
help:
en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"
- name: temperature
use_template: temperature
label:
en_US: "Temperature"
zh_Hans: "采样温度"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: top_p
use_template: top_p
label:
en_US: "Top P"
zh_Hans: "Top P"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: top_k
use_template: top_k
label:
en_US: "Top K"
zh_Hans: "Top K"
type: int
default: 50
min: 0
max: 100
required: true
help:
en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"
- name: frequency_penalty
use_template: frequency_penalty
label:
en_US: "Frequency Penalty"
zh_Hans: "频率惩罚"
type: float
default: 0
min: -1.0
max: 1.0
precision: 1
required: false
help:
en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"
- name: user
use_template: text
label:
en_US: "User"
zh_Hans: "用户"
type: string
required: false
help:
en_US: "Used to track and differentiate conversation requests from different users."
zh_Hans: "用于追踪和区分不同用户的对话请求。"

View File

@ -0,0 +1,105 @@
model: deepseek-coder-33B-instruct-chat
label:
zh_Hans: deepseek-coder-33B-instruct-chat
en_US: deepseek-coder-33B-instruct-chat
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 9000
parameter_rules:
- name: stream
use_template: boolean
label:
en_US: "Stream"
zh_Hans: "流式"
type: boolean
default: true
required: true
help:
en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process."
zh_Hans: "是否通过流式分批返回结果。如果设置为 true生成过程中实时地向用户推送每一部分生成的文本。"
- name: max_tokens
use_template: max_tokens
label:
en_US: "Max Tokens"
zh_Hans: "最大Token数"
type: int
default: 512
min: 1
required: true
help:
en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"
- name: temperature
use_template: temperature
label:
en_US: "Temperature"
zh_Hans: "采样温度"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: top_p
use_template: top_p
label:
en_US: "Top P"
zh_Hans: "Top P"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: top_k
use_template: top_k
label:
en_US: "Top K"
zh_Hans: "Top K"
type: int
default: 50
min: 0
max: 100
required: true
help:
en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"
- name: frequency_penalty
use_template: frequency_penalty
label:
en_US: "Frequency Penalty"
zh_Hans: "频率惩罚"
type: float
default: 0
min: -1.0
max: 1.0
precision: 1
required: false
help:
en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"
- name: user
use_template: text
label:
en_US: "User"
zh_Hans: "用户"
type: string
required: false
help:
en_US: "Used to track and differentiate conversation requests from different users."
zh_Hans: "用于追踪和区分不同用户的对话请求。"

View File

@ -0,0 +1,91 @@
model: deepseek-coder-33B-instruct-completions
label:
zh_Hans: deepseek-coder-33B-instruct-completions
en_US: deepseek-coder-33B-instruct-completions
model_type: llm
features:
- agent-thought
model_properties:
mode: completion
context_size: 9000
parameter_rules:
- name: stream
use_template: boolean
label:
en_US: "Stream"
zh_Hans: "流式"
type: boolean
default: true
required: true
help:
en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process."
zh_Hans: "是否通过流式分批返回结果。如果设置为 true生成过程中实时地向用户推送每一部分生成的文本。"
- name: max_tokens
use_template: max_tokens
label:
en_US: "Max Tokens"
zh_Hans: "最大Token数"
type: int
default: 512
min: 1
required: true
help:
en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"
- name: temperature
use_template: temperature
label:
en_US: "Temperature"
zh_Hans: "采样温度"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: top_p
use_template: top_p
label:
en_US: "Top P"
zh_Hans: "Top P"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: frequency_penalty
use_template: frequency_penalty
label:
en_US: "Frequency Penalty"
zh_Hans: "频率惩罚"
type: float
default: 0
min: -1.0
max: 1.0
precision: 1
required: false
help:
en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"
- name: user
use_template: text
label:
en_US: "User"
zh_Hans: "用户"
type: string
required: false
help:
en_US: "Used to track and differentiate conversation requests from different users."
zh_Hans: "用于追踪和区分不同用户的对话请求。"

View File

@ -0,0 +1,105 @@
model: glm-4-9b-chat
label:
zh_Hans: glm-4-9b-chat
en_US: glm-4-9b-chat
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 32768
parameter_rules:
- name: stream
use_template: boolean
label:
en_US: "Stream"
zh_Hans: "流式"
type: boolean
default: true
required: true
help:
en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process."
zh_Hans: "是否通过流式分批返回结果。如果设置为 true生成过程中实时地向用户推送每一部分生成的文本。"
- name: max_tokens
use_template: max_tokens
label:
en_US: "Max Tokens"
zh_Hans: "最大Token数"
type: int
default: 512
min: 1
required: true
help:
en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"
- name: temperature
use_template: temperature
label:
en_US: "Temperature"
zh_Hans: "采样温度"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: top_p
use_template: top_p
label:
en_US: "Top P"
zh_Hans: "Top P"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: top_k
use_template: top_k
label:
en_US: "Top K"
zh_Hans: "Top K"
type: int
default: 50
min: 0
max: 100
required: true
help:
en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"
- name: frequency_penalty
use_template: frequency_penalty
label:
en_US: "Frequency Penalty"
zh_Hans: "频率惩罚"
type: float
default: 0
min: -1.0
max: 1.0
precision: 1
required: false
help:
en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"
- name: user
use_template: text
label:
en_US: "User"
zh_Hans: "用户"
type: string
required: false
help:
en_US: "Used to track and differentiate conversation requests from different users."
zh_Hans: "用于追踪和区分不同用户的对话请求。"

View File

@ -0,0 +1,47 @@
from collections.abc import Generator
from typing import Optional, Union
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
from core.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageTool,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
class GiteeAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
MODEL_TO_IDENTITY: dict[str, str] = {
"Yi-1.5-34B-Chat": "Yi-34B-Chat",
"deepseek-coder-33B-instruct-completions": "deepseek-coder-33B-instruct",
"deepseek-coder-33B-instruct-chat": "deepseek-coder-33B-instruct",
}
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials, model, model_parameters)
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials, model, None)
super().validate_credentials(model, credentials)
@staticmethod
def _add_custom_parameters(credentials: dict, model: str, model_parameters: dict) -> None:
if model is None:
model = "bge-large-zh-v1.5"
model_identity = GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model)
credentials["endpoint_url"] = f"https://ai.gitee.com/api/serverless/{model_identity}/"
if model.endswith("completions"):
credentials["mode"] = LLMMode.COMPLETION.value
else:
credentials["mode"] = LLMMode.CHAT.value

View File

@ -0,0 +1 @@
- bge-reranker-v2-m3

View File

@ -0,0 +1,4 @@
model: bge-reranker-v2-m3
model_type: rerank
model_properties:
context_size: 1024

View File

@ -0,0 +1,128 @@
from typing import Optional
import httpx
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
class GiteeAIRerankModel(RerankModel):
"""
Model class for rerank model.
"""
def _invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n documents to return
:param user: unique user id
:return: rerank result
"""
if len(docs) == 0:
return RerankResult(model=model, docs=[])
base_url = credentials.get("base_url", "https://ai.gitee.com/api/serverless")
base_url = base_url.removesuffix("/")
try:
body = {"model": model, "query": query, "documents": docs}
if top_n is not None:
body["top_n"] = top_n
response = httpx.post(
f"{base_url}/{model}/rerank",
json=body,
headers={"Authorization": f"Bearer {credentials.get('api_key')}"},
)
response.raise_for_status()
results = response.json()
rerank_documents = []
for result in results["results"]:
rerank_document = RerankDocument(
index=result["index"],
text=result["document"]["text"],
score=result["relevance_score"],
)
if score_threshold is None or result["relevance_score"] >= score_threshold:
rerank_documents.append(rerank_document)
return RerankResult(model=model, docs=rerank_documents)
except httpx.HTTPStatusError as e:
raise InvokeServerUnavailableError(str(e))
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
self._invoke(
model=model,
credentials=credentials,
query="What is the capital of the United States?",
docs=[
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
"Census, Carson City had a population of 55,274.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.01,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
"""
return {
InvokeConnectionError: [httpx.ConnectError],
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
InvokeRateLimitError: [],
InvokeAuthorizationError: [httpx.HTTPStatusError],
InvokeBadRequestError: [httpx.RequestError],
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
generate custom model entities from credentials
"""
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.RERANK,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
)
return entity

View File

@ -0,0 +1,2 @@
- whisper-base
- whisper-large

View File

@ -0,0 +1,53 @@
import os
from typing import IO, Optional
import requests
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.gitee_ai._common import _CommonGiteeAI
class GiteeAISpeech2TextModel(_CommonGiteeAI, Speech2TextModel):
"""
Model class for OpenAI Compatible Speech to text model.
"""
def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
"""
Invoke speech2text model
:param model: model name
:param credentials: model credentials
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
# doc: https://ai.gitee.com/docs/openapi/serverless#tag/serverless/POST/{service}/speech-to-text
endpoint_url = f"https://ai.gitee.com/api/serverless/{model}/speech-to-text"
files = [("file", file)]
_, file_ext = os.path.splitext(file.name)
headers = {"Content-Type": f"audio/{file_ext}", "Authorization": f"Bearer {credentials.get('api_key')}"}
response = requests.post(endpoint_url, headers=headers, files=files)
if response.status_code != 200:
raise InvokeBadRequestError(response.text)
response_data = response.json()
return response_data["text"]
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
audio_file_path = self._get_demo_file_path()
with open(audio_file_path, "rb") as audio_file:
self._invoke(model, credentials, audio_file)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

View File

@ -0,0 +1,5 @@
model: whisper-base
model_type: speech2text
model_properties:
file_upload_limit: 1
supported_file_extensions: flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm

View File

@ -0,0 +1,5 @@
model: whisper-large
model_type: speech2text
model_properties:
file_upload_limit: 1
supported_file_extensions: flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm

View File

@ -0,0 +1,3 @@
- bge-large-zh-v1.5
- bge-small-zh-v1.5
- bge-m3

View File

@ -0,0 +1,8 @@
model: bge-large-zh-v1.5
label:
zh_Hans: bge-large-zh-v1.5
en_US: bge-large-zh-v1.5
model_type: text-embedding
model_properties:
context_size: 200000
max_chunks: 20

View File

@ -0,0 +1,8 @@
model: bge-m3
label:
zh_Hans: bge-m3
en_US: bge-m3
model_type: text-embedding
model_properties:
context_size: 200000
max_chunks: 20

View File

@ -0,0 +1,8 @@
model: bge-small-zh-v1.5
label:
zh_Hans: bge-small-zh-v1.5
en_US: bge-small-zh-v1.5
model_type: text-embedding
model_properties:
context_size: 200000
max_chunks: 20

View File

@ -0,0 +1,31 @@
from typing import Optional
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import (
OAICompatEmbeddingModel,
)
class GiteeAIEmbeddingModel(OAICompatEmbeddingModel):
def _invoke(
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
self._add_custom_parameters(credentials, model)
return super()._invoke(model, credentials, texts, user, input_type)
def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials, None)
super().validate_credentials(model, credentials)
@staticmethod
def _add_custom_parameters(credentials: dict, model: str) -> None:
if model is None:
model = "bge-m3"
credentials["endpoint_url"] = f"https://ai.gitee.com/api/serverless/{model}/v1/"

View File

@ -0,0 +1,11 @@
model: ChatTTS
model_type: tts
model_properties:
default_voice: 'default'
voices:
- mode: 'default'
name: 'Default'
language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ]
word_limit: 3500
audio_type: 'mp3'
max_workers: 5

View File

@ -0,0 +1,11 @@
model: FunAudioLLM-CosyVoice-300M
model_type: tts
model_properties:
default_voice: 'default'
voices:
- mode: 'default'
name: 'Default'
language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ]
word_limit: 3500
audio_type: 'mp3'
max_workers: 5

View File

@ -0,0 +1,4 @@
- speecht5_tts
- ChatTTS
- fish-speech-1.2-sft
- FunAudioLLM-CosyVoice-300M

View File

@ -0,0 +1,11 @@
model: fish-speech-1.2-sft
model_type: tts
model_properties:
default_voice: 'default'
voices:
- mode: 'default'
name: 'Default'
language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ]
word_limit: 3500
audio_type: 'mp3'
max_workers: 5

View File

@ -0,0 +1,11 @@
model: speecht5_tts
model_type: tts
model_properties:
default_voice: 'default'
voices:
- mode: 'default'
name: 'Default'
language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ]
word_limit: 3500
audio_type: 'mp3'
max_workers: 5

View File

@ -0,0 +1,79 @@
from typing import Optional
import requests
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.gitee_ai._common import _CommonGiteeAI
class GiteeAIText2SpeechModel(_CommonGiteeAI, TTSModel):
"""
Model class for OpenAI Speech to text model.
"""
def _invoke(
self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
) -> any:
"""
_invoke text2speech model
:param model: model name
:param tenant_id: user tenant id
:param credentials: model credentials
:param content_text: text content to be translated
:param voice: model timbre
:param user: unique user id
:return: text translated to audio file
"""
return self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, voice=voice)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
validate credentials text2speech model
:param model: model name
:param credentials: model credentials
:return: text translated to audio file
"""
try:
self._tts_invoke_streaming(
model=model,
credentials=credentials,
content_text="Hello Dify!",
voice=self._get_model_default_voice(model, credentials),
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any:
"""
_tts_invoke_streaming text2speech model
:param model: model name
:param credentials: model credentials
:param content_text: text content to be translated
:param voice: model timbre
:return: text translated to audio file
"""
try:
# doc: https://ai.gitee.com/docs/openapi/serverless#tag/serverless/POST/{service}/text-to-speech
endpoint_url = "https://ai.gitee.com/api/serverless/" + model + "/text-to-speech"
headers = {"Content-Type": "application/json"}
api_key = credentials.get("api_key")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
payload = {"inputs": content_text}
response = requests.post(endpoint_url, headers=headers, json=payload)
if response.status_code != 200:
raise InvokeBadRequestError(response.text)
data = response.content
for i in range(0, len(data), 1024):
yield data[i : i + 1024]
except Exception as ex:
raise InvokeBadRequestError(str(ex))

View File

@ -0,0 +1,450 @@
import base64
import io
import json
import logging
from collections.abc import Generator
from typing import Optional, Union, cast
import google.ai.generativelanguage as glm
import google.generativeai as genai
import requests
from google.api_core import exceptions
from google.generativeai.client import _ClientManager
from google.generativeai.types import ContentType, GenerateContentResponse
from google.generativeai.types.content_types import to_part
from PIL import Image
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
logger = logging.getLogger(__name__)
GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
""" # noqa: E501
class GoogleLargeLanguageModel(LargeLanguageModel):
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
# invoke model
return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return:md = genai.GenerativeModel(model)
"""
prompt = self._convert_messages_to_prompt(prompt_messages)
return self._get_num_tokens_by_gpt2(prompt)
def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
"""
Format a list of messages into a full prompt for the Google model
:param messages: List of PromptMessage to combine.
:return: Combined string with necessary human_prompt and ai_prompt tags.
"""
messages = messages.copy() # don't mutate the original list
text = "".join(self._convert_one_message_to_text(message) for message in messages)
return text.rstrip()
def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
"""
Convert tool messages to glm tools
:param tools: tool messages
:return: glm tools
"""
function_declarations = []
for tool in tools:
properties = {}
for key, value in tool.parameters.get("properties", {}).items():
properties[key] = {
"type_": glm.Type.STRING,
"description": value.get("description", ""),
"enum": value.get("enum", []),
}
if properties:
parameters = glm.Schema(
type=glm.Type.OBJECT,
properties=properties,
required=tool.parameters.get("required", []),
)
else:
parameters = None
function_declaration = glm.FunctionDeclaration(
name=tool.name,
parameters=parameters,
description=tool.description,
)
function_declarations.append(function_declaration)
return glm.Tool(function_declarations=function_declarations)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
ping_message = SystemPromptMessage(content="ping")
self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: credentials kwargs
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
config_kwargs = model_parameters.copy()
config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)
if stop:
config_kwargs["stop_sequences"] = stop
google_model = genai.GenerativeModel(model_name=model)
history = []
# hack for gemini-pro-vision, which currently does not support multi-turn chat
if model == "gemini-pro-vision":
last_msg = prompt_messages[-1]
content = self._format_message_to_glm_content(last_msg)
history.append(content)
else:
for msg in prompt_messages: # makes message roles strictly alternating
content = self._format_message_to_glm_content(msg)
if history and history[-1]["role"] == content["role"]:
history[-1]["parts"].extend(content["parts"])
else:
history.append(content)
# Create a new ClientManager with tenant's API key
new_client_manager = _ClientManager()
new_client_manager.configure(api_key=credentials["google_api_key"])
new_custom_client = new_client_manager.make_client("generative")
google_model._client = new_custom_client
response = google_model.generate_content(
contents=history,
generation_config=genai.types.GenerationConfig(**config_kwargs),
stream=stream,
tools=self._convert_tools_to_glm_tool(tools) if tools else None,
request_options={"timeout": 600},
)
if stream:
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_response(
self, model: str, credentials: dict, response: GenerateContentResponse, prompt_messages: list[PromptMessage]
) -> LLMResult:
"""
Handle llm response
:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:return: llm response
"""
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=response.text)
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
# transform response
result = LLMResult(
model=model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage,
)
return result
def _handle_generate_stream_response(
self, model: str, credentials: dict, response: GenerateContentResponse, prompt_messages: list[PromptMessage]
) -> Generator:
"""
Handle llm stream response
:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:return: llm response chunk generator result
"""
index = -1
for chunk in response:
for part in chunk.parts:
assistant_prompt_message = AssistantPromptMessage(content="")
if part.text:
assistant_prompt_message.content += part.text
if part.function_call:
assistant_prompt_message.tool_calls = [
AssistantPromptMessage.ToolCall(
id=part.function_call.name,
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=part.function_call.name,
arguments=json.dumps(dict(part.function_call.args.items())),
),
)
]
index += 1
if not response._done:
# transform assistant message to prompt message
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message),
)
else:
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
finish_reason=str(chunk.candidates[0].finish_reason),
usage=usage,
),
)
def _convert_one_message_to_text(self, message: PromptMessage) -> str:
"""
Convert a single message to a string.
:param message: PromptMessage to convert.
:return: String representation of the message.
"""
human_prompt = "\n\nuser:"
ai_prompt = "\n\nmodel:"
content = message.content
if isinstance(content, list):
content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE)
if isinstance(message, UserPromptMessage):
message_text = f"{human_prompt} {content}"
elif isinstance(message, AssistantPromptMessage):
message_text = f"{ai_prompt} {content}"
elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
message_text = f"{human_prompt} {content}"
else:
raise ValueError(f"Got unknown type {message}")
return message_text
def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
"""
Format a single message into glm.Content for Google API
:param message: one PromptMessage
:return: glm Content representation of message
"""
if isinstance(message, UserPromptMessage):
glm_content = {"role": "user", "parts": []}
if isinstance(message.content, str):
glm_content["parts"].append(to_part(message.content))
else:
for c in message.content:
if c.type == PromptMessageContentType.TEXT:
glm_content["parts"].append(to_part(c.data))
elif c.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, c)
if message_content.data.startswith("data:"):
metadata, base64_data = c.data.split(",", 1)
mime_type = metadata.split(";", 1)[0].split(":")[1]
else:
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}}
glm_content["parts"].append(blob)
return glm_content
elif isinstance(message, AssistantPromptMessage):
glm_content = {"role": "model", "parts": []}
if message.content:
glm_content["parts"].append(to_part(message.content))
if message.tool_calls:
glm_content["parts"].append(
to_part(
glm.FunctionCall(
name=message.tool_calls[0].function.name,
args=json.loads(message.tool_calls[0].function.arguments),
)
)
)
return glm_content
elif isinstance(message, SystemPromptMessage):
return {"role": "user", "parts": [to_part(message.content)]}
elif isinstance(message, ToolPromptMessage):
return {
"role": "function",
"parts": [
glm.Part(
function_response=glm.FunctionResponse(
name=message.name, response={"response": message.content}
)
)
],
}
else:
raise ValueError(f"Got unknown type {message}")
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the ermd = genai.GenerativeModel(model) error type thrown to the caller
The value is the md = genai.GenerativeModel(model) error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke emd = genai.GenerativeModel(model) error mapping
"""
return {
InvokeConnectionError: [exceptions.RetryError],
InvokeServerUnavailableError: [
exceptions.ServiceUnavailable,
exceptions.InternalServerError,
exceptions.BadGateway,
exceptions.GatewayTimeout,
exceptions.DeadlineExceeded,
],
InvokeRateLimitError: [exceptions.ResourceExhausted, exceptions.TooManyRequests],
InvokeAuthorizationError: [
exceptions.Unauthenticated,
exceptions.PermissionDenied,
exceptions.Unauthenticated,
exceptions.Forbidden,
],
InvokeBadRequestError: [
exceptions.BadRequest,
exceptions.InvalidArgument,
exceptions.FailedPrecondition,
exceptions.OutOfRange,
exceptions.NotFound,
exceptions.MethodNotAllowed,
exceptions.Conflict,
exceptions.AlreadyExists,
exceptions.Aborted,
exceptions.LengthRequired,
exceptions.PreconditionFailed,
exceptions.RequestRangeNotSatisfiable,
exceptions.Cancelled,
],
}

View File

@ -0,0 +1,330 @@
import json
from collections.abc import Generator
from typing import Optional, Union, cast
import requests
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
ModelFeature,
ModelPropertyKey,
ModelType,
ParameterRule,
ParameterType,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
self._add_function_call(model, credentials)
user = user[:32] if user else None
# {"response_format": "json_object"} need convert to {"response_format": {"type": "json_object"}}
if "response_format" in model_parameters:
model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
return AIModelEntity(
model=model,
label=I18nObject(en_US=model, zh_Hans=model),
model_type=ModelType.LLM,
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
if credentials.get("function_calling_type") == "tool_call"
else [],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)),
ModelPropertyKey.MODE: LLMMode.CHAT.value,
},
parameter_rules=[
ParameterRule(
name="temperature",
use_template="temperature",
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
type=ParameterType.FLOAT,
),
ParameterRule(
name="max_tokens",
use_template="max_tokens",
default=512,
min=1,
max=int(credentials.get("max_tokens", 4096)),
label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"),
type=ParameterType.INT,
),
ParameterRule(
name="top_p",
use_template="top_p",
label=I18nObject(en_US="Top P", zh_Hans="Top P"),
type=ParameterType.FLOAT,
),
],
)
def _add_custom_parameters(self, credentials: dict) -> None:
credentials["mode"] = "chat"
if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
credentials["endpoint_url"] = "https://api.moonshot.cn/v1"
def _add_function_call(self, model: str, credentials: dict) -> None:
model_schema = self.get_model_schema(model, credentials)
if model_schema and {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}.intersection(
model_schema.features or []
):
credentials["function_calling_type"] = "tool_call"
def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict:
"""
Convert PromptMessage to dict for OpenAI API format
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
sub_messages = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(PromptMessageContent, message_content)
sub_message_dict = {"type": "text", "text": message_content.data}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
sub_message_dict = {
"type": "image_url",
"image_url": {"url": message_content.data, "detail": message_content.detail.value},
}
sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls:
message_dict["tool_calls"] = []
for function_call in message.tool_calls:
message_dict["tool_calls"].append(
{
"id": function_call.id,
"type": function_call.type,
"function": {
"name": function_call.function.name,
"arguments": function_call.function.arguments,
},
}
)
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
if message.name:
message_dict["name"] = message.name
return message_dict
def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]:
"""
Extract tool calls from response
:param response_tool_calls: response tool calls
:return: list of tool calls
"""
tool_calls = []
if response_tool_calls:
for response_tool_call in response_tool_calls:
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call["function"]["name"]
if response_tool_call.get("function", {}).get("name")
else "",
arguments=response_tool_call["function"]["arguments"]
if response_tool_call.get("function", {}).get("arguments")
else "",
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call["id"] if response_tool_call.get("id") else "",
type=response_tool_call["type"] if response_tool_call.get("type") else "",
function=function,
)
tool_calls.append(tool_call)
return tool_calls
def _handle_generate_stream_response(
self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
) -> Generator:
"""
Handle llm stream response
:param model: model name
:param credentials: model credentials
:param response: streamed response
:param prompt_messages: prompt messages
:return: llm response chunk generator
"""
full_assistant_content = ""
chunk_index = 0
def create_final_llm_result_chunk(
index: int, message: AssistantPromptMessage, finish_reason: str
) -> LLMResultChunk:
# calculate num tokens
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
return LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
)
tools_calls: list[AssistantPromptMessage.ToolCall] = []
finish_reason = "Unknown"
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
def get_tool_call(tool_name: str):
if not tool_name:
return tools_calls[-1]
tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None)
if tool_call is None:
tool_call = AssistantPromptMessage.ToolCall(
id="",
type="",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""),
)
tools_calls.append(tool_call)
return tool_call
for new_tool_call in new_tool_calls:
# get tool call
tool_call = get_tool_call(new_tool_call.function.name)
# update tool call
if new_tool_call.id:
tool_call.id = new_tool_call.id
if new_tool_call.type:
tool_call.type = new_tool_call.type
if new_tool_call.function.name:
tool_call.function.name = new_tool_call.function.name
if new_tool_call.function.arguments:
tool_call.function.arguments += new_tool_call.function.arguments
for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"):
if chunk:
# ignore sse comments
if chunk.startswith(":"):
continue
decoded_chunk = chunk.strip().lstrip("data: ").lstrip()
chunk_json = None
try:
chunk_json = json.loads(decoded_chunk)
# stream ended
except json.JSONDecodeError as e:
yield create_final_llm_result_chunk(
index=chunk_index + 1,
message=AssistantPromptMessage(content=""),
finish_reason="Non-JSON encountered.",
)
break
if not chunk_json or len(chunk_json["choices"]) == 0:
continue
choice = chunk_json["choices"][0]
finish_reason = chunk_json["choices"][0].get("finish_reason")
chunk_index += 1
if "delta" in choice:
delta = choice["delta"]
delta_content = delta.get("content")
assistant_message_tool_calls = delta.get("tool_calls", None)
# assistant_message_function_call = delta.delta.function_call
# extract tool calls from response
if assistant_message_tool_calls:
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
increase_tool_call(tool_calls)
if delta_content is None or delta_content == "":
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta_content, tool_calls=tool_calls if assistant_message_tool_calls else []
)
full_assistant_content += delta_content
elif "text" in choice:
choice_text = choice.get("text", "")
if choice_text == "":
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=choice_text)
full_assistant_content += choice_text
else:
continue
# check payload indicator for completion
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=chunk_index,
message=assistant_prompt_message,
),
)
chunk_index += 1
if tools_calls:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=chunk_index,
message=AssistantPromptMessage(tool_calls=tools_calls, content=""),
),
)
yield create_final_llm_result_chunk(
index=chunk_index, message=AssistantPromptMessage(content=""), finish_reason=finish_reason
)

View File

@ -0,0 +1,847 @@
import json
import logging
from collections.abc import Generator
from decimal import Decimal
from typing import Optional, Union, cast
from urllib.parse import urljoin
import requests
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
PromptMessageFunction,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
AIModelEntity,
DefaultParameterName,
FetchFrom,
ModelFeature,
ModelPropertyKey,
ModelType,
ParameterRule,
ParameterType,
PriceConfig,
)
from core.model_runtime.errors.invoke import InvokeError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
from core.model_runtime.utils import helper
logger = logging.getLogger(__name__)
class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
"""
Model class for OpenAI large language model.
"""
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
# text completion model
return self._generate(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
)
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
:param model:
:param credentials:
:param prompt_messages:
:param tools: tools for tool calling
:return:
"""
return self._num_tokens_from_messages(model, prompt_messages, tools, credentials)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials using requests to ensure compatibility with all providers following
OpenAI's API standard.
:param model: model name
:param credentials: model credentials
:return:
"""
try:
headers = {"Content-Type": "application/json"}
api_key = credentials.get("api_key")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
endpoint_url = credentials["endpoint_url"]
if not endpoint_url.endswith("/"):
endpoint_url += "/"
# prepare the payload for a simple ping to the model
data = {"model": model, "max_tokens": 5}
completion_type = LLMMode.value_of(credentials["mode"])
if completion_type is LLMMode.CHAT:
data["messages"] = [
{"role": "user", "content": "ping"},
]
endpoint_url = urljoin(endpoint_url, "chat/completions")
elif completion_type is LLMMode.COMPLETION:
data["prompt"] = "ping"
endpoint_url = urljoin(endpoint_url, "completions")
else:
raise ValueError("Unsupported completion type for model configuration.")
# send a post request to validate the credentials
response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300))
if response.status_code != 200:
raise CredentialsValidateFailedError(
f"Credentials validation failed with status code {response.status_code}"
)
try:
json_result = response.json()
except json.JSONDecodeError as e:
raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error")
if completion_type is LLMMode.CHAT and json_result.get("object", "") == "":
json_result["object"] = "chat.completion"
elif completion_type is LLMMode.COMPLETION and json_result.get("object", "") == "":
json_result["object"] = "text_completion"
if completion_type is LLMMode.CHAT and (
"object" not in json_result or json_result["object"] != "chat.completion"
):
raise CredentialsValidateFailedError(
"Credentials validation failed: invalid response object, must be 'chat.completion'"
)
elif completion_type is LLMMode.COMPLETION and (
"object" not in json_result or json_result["object"] != "text_completion"
):
raise CredentialsValidateFailedError(
"Credentials validation failed: invalid response object, must be 'text_completion'"
)
except CredentialsValidateFailedError:
raise
except Exception as ex:
raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}")
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
generate custom model entities from credentials
"""
features = []
function_calling_type = credentials.get("function_calling_type", "no_call")
if function_calling_type == "function_call":
features.append(ModelFeature.TOOL_CALL)
elif function_calling_type == "tool_call":
features.append(ModelFeature.MULTI_TOOL_CALL)
stream_function_calling = credentials.get("stream_function_calling", "supported")
if stream_function_calling == "supported":
features.append(ModelFeature.STREAM_TOOL_CALL)
vision_support = credentials.get("vision_support", "not_support")
if vision_support == "support":
features.append(ModelFeature.VISION)
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.LLM,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
features=features,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "4096")),
ModelPropertyKey.MODE: credentials.get("mode"),
},
parameter_rules=[
ParameterRule(
name=DefaultParameterName.TEMPERATURE.value,
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
help=I18nObject(
en_US="Kernel sampling threshold. Used to determine the randomness of the results."
"The higher the value, the stronger the randomness."
"The higher the possibility of getting different answers to the same question.",
zh_Hans="核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。",
),
type=ParameterType.FLOAT,
default=float(credentials.get("temperature", 0.7)),
min=0,
max=2,
precision=2,
),
ParameterRule(
name=DefaultParameterName.TOP_P.value,
label=I18nObject(en_US="Top P", zh_Hans="Top P"),
help=I18nObject(
en_US="The probability threshold of the nucleus sampling method during the generation process."
"The larger the value is, the higher the randomness of generation will be."
"The smaller the value is, the higher the certainty of generation will be.",
zh_Hans="生成过程中核采样方法概率阈值。取值越大,生成的随机性越高;取值越小,生成的确定性越高。",
),
type=ParameterType.FLOAT,
default=float(credentials.get("top_p", 1)),
min=0,
max=1,
precision=2,
),
ParameterRule(
name=DefaultParameterName.FREQUENCY_PENALTY.value,
label=I18nObject(en_US="Frequency Penalty", zh_Hans="频率惩罚"),
help=I18nObject(
en_US="For controlling the repetition rate of words used by the model."
"Increasing this can reduce the repetition of the same words in the model's output.",
zh_Hans="用于控制模型已使用字词的重复率。 提高此项可以降低模型在输出中重复相同字词的重复度。",
),
type=ParameterType.FLOAT,
default=float(credentials.get("frequency_penalty", 0)),
min=-2,
max=2,
),
ParameterRule(
name=DefaultParameterName.PRESENCE_PENALTY.value,
label=I18nObject(en_US="Presence Penalty", zh_Hans="存在惩罚"),
help=I18nObject(
en_US="Used to control the repetition rate when generating models."
"Increasing this can reduce the repetition rate of model generation.",
zh_Hans="用于控制模型生成时的重复度。提高此项可以降低模型生成的重复度。",
),
type=ParameterType.FLOAT,
default=float(credentials.get("presence_penalty", 0)),
min=-2,
max=2,
),
ParameterRule(
name=DefaultParameterName.MAX_TOKENS.value,
label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"),
help=I18nObject(
en_US="Maximum length of tokens for the model response.", zh_Hans="模型回答的tokens的最大长度。"
),
type=ParameterType.INT,
default=512,
min=1,
max=int(credentials.get("max_tokens_to_sample", 4096)),
),
],
pricing=PriceConfig(
input=Decimal(credentials.get("input_price", 0)),
output=Decimal(credentials.get("output_price", 0)),
unit=Decimal(credentials.get("unit", 0)),
currency=credentials.get("currency", "USD"),
),
)
if credentials["mode"] == "chat":
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
elif credentials["mode"] == "completion":
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
else:
raise ValueError(f"Unknown completion type {credentials['completion_type']}")
return entity
# validate_credentials method has been rewritten to use the requests library for compatibility with all providers
# following OpenAI's API standard.
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke llm completion model
:param model: model name
:param credentials: credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
headers = {
"Content-Type": "application/json",
"Accept-Charset": "utf-8",
}
extra_headers = credentials.get("extra_headers")
if extra_headers is not None:
headers = {
**headers,
**extra_headers,
}
api_key = credentials.get("api_key")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
endpoint_url = credentials["endpoint_url"]
if not endpoint_url.endswith("/"):
endpoint_url += "/"
data = {"model": model, "stream": stream, **model_parameters}
completion_type = LLMMode.value_of(credentials["mode"])
if completion_type is LLMMode.CHAT:
endpoint_url = urljoin(endpoint_url, "chat/completions")
data["messages"] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages]
elif completion_type is LLMMode.COMPLETION:
endpoint_url = urljoin(endpoint_url, "completions")
data["prompt"] = prompt_messages[0].content
else:
raise ValueError("Unsupported completion type for model configuration.")
# annotate tools with names, descriptions, etc.
function_calling_type = credentials.get("function_calling_type", "no_call")
formatted_tools = []
if tools:
if function_calling_type == "function_call":
data["functions"] = [
{"name": tool.name, "description": tool.description, "parameters": tool.parameters}
for tool in tools
]
elif function_calling_type == "tool_call":
data["tool_choice"] = "auto"
for tool in tools:
formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
data["tools"] = formatted_tools
if stop:
data["stop"] = stop
if user:
data["user"] = user
response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream)
if response.encoding is None or response.encoding == "ISO-8859-1":
response.encoding = "utf-8"
if response.status_code != 200:
raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")
if stream:
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_stream_response(
self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
) -> Generator:
"""
Handle llm stream response
:param model: model name
:param credentials: model credentials
:param response: streamed response
:param prompt_messages: prompt messages
:return: llm response chunk generator
"""
full_assistant_content = ""
chunk_index = 0
def create_final_llm_result_chunk(
id: Optional[str], index: int, message: AssistantPromptMessage, finish_reason: str, usage: dict
) -> LLMResultChunk:
# calculate num tokens
prompt_tokens = usage and usage.get("prompt_tokens")
if prompt_tokens is None:
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
completion_tokens = usage and usage.get("completion_tokens")
if completion_tokens is None:
completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
return LLMResultChunk(
id=id,
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
)
# delimiter for stream response, need unicode_escape
import codecs
delimiter = credentials.get("stream_mode_delimiter", "\n\n")
delimiter = codecs.decode(delimiter, "unicode_escape")
tools_calls: list[AssistantPromptMessage.ToolCall] = []
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
def get_tool_call(tool_call_id: str):
if not tool_call_id:
return tools_calls[-1]
tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None)
if tool_call is None:
tool_call = AssistantPromptMessage.ToolCall(
id=tool_call_id,
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
)
tools_calls.append(tool_call)
return tool_call
for new_tool_call in new_tool_calls:
# get tool call
tool_call = get_tool_call(new_tool_call.function.name)
# update tool call
if new_tool_call.id:
tool_call.id = new_tool_call.id
if new_tool_call.type:
tool_call.type = new_tool_call.type
if new_tool_call.function.name:
tool_call.function.name = new_tool_call.function.name
if new_tool_call.function.arguments:
tool_call.function.arguments += new_tool_call.function.arguments
finish_reason = None # The default value of finish_reason is None
message_id, usage = None, None
for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
chunk = chunk.strip()
if chunk:
# ignore sse comments
if chunk.startswith(":"):
continue
decoded_chunk = chunk.strip().lstrip("data: ").lstrip()
if decoded_chunk == "[DONE]": # Some provider returns "data: [DONE]"
continue
try:
chunk_json: dict = json.loads(decoded_chunk)
# stream ended
except json.JSONDecodeError as e:
yield create_final_llm_result_chunk(
id=message_id,
index=chunk_index + 1,
message=AssistantPromptMessage(content=""),
finish_reason="Non-JSON encountered.",
usage=usage,
)
break
if chunk_json:
if u := chunk_json.get("usage"):
usage = u
if not chunk_json or len(chunk_json["choices"]) == 0:
continue
choice = chunk_json["choices"][0]
finish_reason = chunk_json["choices"][0].get("finish_reason")
message_id = chunk_json.get("id")
chunk_index += 1
if "delta" in choice:
delta = choice["delta"]
delta_content = delta.get("content")
assistant_message_tool_calls = None
if "tool_calls" in delta and credentials.get("function_calling_type", "no_call") == "tool_call":
assistant_message_tool_calls = delta.get("tool_calls", None)
elif (
"function_call" in delta
and credentials.get("function_calling_type", "no_call") == "function_call"
):
assistant_message_tool_calls = [
{"id": "tool_call_id", "type": "function", "function": delta.get("function_call", {})}
]
# assistant_message_function_call = delta.delta.function_call
# extract tool calls from response
if assistant_message_tool_calls:
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
increase_tool_call(tool_calls)
if delta_content is None or delta_content == "":
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta_content,
)
# reset tool calls
tool_calls = []
full_assistant_content += delta_content
elif "text" in choice:
choice_text = choice.get("text", "")
if choice_text == "":
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=choice_text)
full_assistant_content += choice_text
else:
continue
yield LLMResultChunk(
id=message_id,
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=chunk_index,
message=assistant_prompt_message,
),
)
chunk_index += 1
if tools_calls:
yield LLMResultChunk(
id=message_id,
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=chunk_index,
message=AssistantPromptMessage(tool_calls=tools_calls, content=""),
),
)
yield create_final_llm_result_chunk(
id=message_id,
index=chunk_index,
message=AssistantPromptMessage(content=""),
finish_reason=finish_reason,
usage=usage,
)
def _handle_generate_response(
self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
) -> LLMResult:
response_json: dict = response.json()
completion_type = LLMMode.value_of(credentials["mode"])
output = response_json["choices"][0]
message_id = response_json.get("id")
response_content = ""
tool_calls = None
function_calling_type = credentials.get("function_calling_type", "no_call")
if completion_type is LLMMode.CHAT:
response_content = output.get("message", {})["content"]
if function_calling_type == "tool_call":
tool_calls = output.get("message", {}).get("tool_calls")
elif function_calling_type == "function_call":
tool_calls = output.get("message", {}).get("function_call")
elif completion_type is LLMMode.COMPLETION:
response_content = output["text"]
assistant_message = AssistantPromptMessage(content=response_content, tool_calls=[])
if tool_calls:
if function_calling_type == "tool_call":
assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls)
elif function_calling_type == "function_call":
assistant_message.tool_calls = [self._extract_response_function_call(tool_calls)]
usage = response_json.get("usage")
if usage:
# transform usage
prompt_tokens = usage["prompt_tokens"]
completion_tokens = usage["completion_tokens"]
else:
# calculate num tokens
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
completion_tokens = self._num_tokens_from_string(model, assistant_message.content)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
# transform response
result = LLMResult(
id=message_id,
model=response_json["model"],
prompt_messages=prompt_messages,
message=assistant_message,
usage=usage,
)
return result
def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict:
"""
Convert PromptMessage to dict for OpenAI API format
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
sub_messages = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(PromptMessageContent, message_content)
sub_message_dict = {"type": "text", "text": message_content.data}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
sub_message_dict = {
"type": "image_url",
"image_url": {"url": message_content.data, "detail": message_content.detail.value},
}
sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls:
function_calling_type = credentials.get("function_calling_type", "no_call")
if function_calling_type == "tool_call":
message_dict["tool_calls"] = [tool_call.dict() for tool_call in message.tool_calls]
elif function_calling_type == "function_call":
function_call = message.tool_calls[0]
message_dict["function_call"] = {
"name": function_call.function.name,
"arguments": function_call.function.arguments,
}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
function_calling_type = credentials.get("function_calling_type", "no_call")
if function_calling_type == "tool_call":
message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
elif function_calling_type == "function_call":
message_dict = {"role": "function", "content": message.content, "name": message.tool_call_id}
else:
raise ValueError(f"Got unknown type {message}")
if message.name and message_dict.get("role", "") != "tool":
message_dict["name"] = message.name
return message_dict
def _num_tokens_from_string(
self, model: str, text: Union[str, list[PromptMessageContent]], tools: Optional[list[PromptMessageTool]] = None
) -> int:
"""
Approximate num tokens for model with gpt2 tokenizer.
:param model: model name
:param text: prompt text
:param tools: tools for tool calling
:return: number of tokens
"""
if isinstance(text, str):
full_text = text
else:
full_text = ""
for message_content in text:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(PromptMessageContent, message_content)
full_text += message_content.data
num_tokens = self._get_num_tokens_by_gpt2(full_text)
if tools:
num_tokens += self._num_tokens_for_tools(tools)
return num_tokens
def _num_tokens_from_messages(
self,
model: str,
messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
credentials: Optional[dict] = None,
) -> int:
"""
Approximate num tokens with GPT2 tokenizer.
"""
tokens_per_message = 3
tokens_per_name = 1
num_tokens = 0
messages_dict = [self._convert_prompt_message_to_dict(m, credentials) for m in messages]
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():
# Cast str(value) in case the message value is not a string
# This occurs with function messages
# TODO: The current token calculation method for the image type is not implemented,
# which need to download the image and then get the resolution for calculation,
# and will increase the request delay
if isinstance(value, list):
text = ""
for item in value:
if isinstance(item, dict) and item["type"] == "text":
text += item["text"]
value = text
if key == "tool_calls":
for tool_call in value:
for t_key, t_value in tool_call.items():
num_tokens += self._get_num_tokens_by_gpt2(t_key)
if t_key == "function":
for f_key, f_value in t_value.items():
num_tokens += self._get_num_tokens_by_gpt2(f_key)
num_tokens += self._get_num_tokens_by_gpt2(f_value)
else:
num_tokens += self._get_num_tokens_by_gpt2(t_key)
num_tokens += self._get_num_tokens_by_gpt2(t_value)
else:
num_tokens += self._get_num_tokens_by_gpt2(str(value))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
if tools:
num_tokens += self._num_tokens_for_tools(tools)
return num_tokens
def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
"""
Calculate num tokens for tool calling with tiktoken package.
:param tools: tools for tool calling
:return: number of tokens
"""
num_tokens = 0
for tool in tools:
num_tokens += self._get_num_tokens_by_gpt2("type")
num_tokens += self._get_num_tokens_by_gpt2("function")
num_tokens += self._get_num_tokens_by_gpt2("function")
# calculate num tokens for function object
num_tokens += self._get_num_tokens_by_gpt2("name")
num_tokens += self._get_num_tokens_by_gpt2(tool.name)
num_tokens += self._get_num_tokens_by_gpt2("description")
num_tokens += self._get_num_tokens_by_gpt2(tool.description)
parameters = tool.parameters
num_tokens += self._get_num_tokens_by_gpt2("parameters")
if "title" in parameters:
num_tokens += self._get_num_tokens_by_gpt2("title")
num_tokens += self._get_num_tokens_by_gpt2(parameters.get("title"))
num_tokens += self._get_num_tokens_by_gpt2("type")
num_tokens += self._get_num_tokens_by_gpt2(parameters.get("type"))
if "properties" in parameters:
num_tokens += self._get_num_tokens_by_gpt2("properties")
for key, value in parameters.get("properties").items():
num_tokens += self._get_num_tokens_by_gpt2(key)
for field_key, field_value in value.items():
num_tokens += self._get_num_tokens_by_gpt2(field_key)
if field_key == "enum":
for enum_field in field_value:
num_tokens += 3
num_tokens += self._get_num_tokens_by_gpt2(enum_field)
else:
num_tokens += self._get_num_tokens_by_gpt2(field_key)
num_tokens += self._get_num_tokens_by_gpt2(str(field_value))
if "required" in parameters:
num_tokens += self._get_num_tokens_by_gpt2("required")
for required_field in parameters["required"]:
num_tokens += 3
num_tokens += self._get_num_tokens_by_gpt2(required_field)
return num_tokens
def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]:
"""
Extract tool calls from response
:param response_tool_calls: response tool calls
:return: list of tool calls
"""
tool_calls = []
if response_tool_calls:
for response_tool_call in response_tool_calls:
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call.get("function", {}).get("name", ""),
arguments=response_tool_call.get("function", {}).get("arguments", ""),
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call.get("id", ""), type=response_tool_call.get("type", ""), function=function
)
tool_calls.append(tool_call)
return tool_calls
def _extract_response_function_call(self, response_function_call) -> AssistantPromptMessage.ToolCall:
"""
Extract function call from response
:param response_function_call: response function call
:return: tool call
"""
tool_call = None
if response_function_call:
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_function_call.get("name", ""), arguments=response_function_call.get("arguments", "")
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_function_call.get("id", ""), type="function", function=function
)
return tool_call

View File

@ -0,0 +1,55 @@
model: claude-3-5-sonnet-v2@20241022
label:
en_US: Claude 3.5 Sonnet v2
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.003'
output: '0.015'
unit: '0.001'
currency: USD

View File

@ -0,0 +1,378 @@
import json
import logging
import time
import uuid
from datetime import timedelta
from typing import Any
from couchbase import search
from couchbase.auth import PasswordAuthenticator
from couchbase.cluster import Cluster
from couchbase.management.search import SearchIndex
# needed for options -- cluster, timeout, SQL++ (N1QL) query, etc.
from couchbase.options import ClusterOptions, SearchOptions
from couchbase.vector_search import VectorQuery, VectorSearch
from flask import current_app
from pydantic import BaseModel, model_validator
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class CouchbaseConfig(BaseModel):
connection_string: str
user: str
password: str
bucket_name: str
scope_name: str
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values.get("connection_string"):
raise ValueError("config COUCHBASE_CONNECTION_STRING is required")
if not values.get("user"):
raise ValueError("config COUCHBASE_USER is required")
if not values.get("password"):
raise ValueError("config COUCHBASE_PASSWORD is required")
if not values.get("bucket_name"):
raise ValueError("config COUCHBASE_PASSWORD is required")
if not values.get("scope_name"):
raise ValueError("config COUCHBASE_SCOPE_NAME is required")
return values
class CouchbaseVector(BaseVector):
def __init__(self, collection_name: str, config: CouchbaseConfig):
super().__init__(collection_name)
self._client_config = config
"""Connect to couchbase"""
auth = PasswordAuthenticator(config.user, config.password)
options = ClusterOptions(auth)
self._cluster = Cluster(config.connection_string, options)
self._bucket = self._cluster.bucket(config.bucket_name)
self._scope = self._bucket.scope(config.scope_name)
self._bucket_name = config.bucket_name
self._scope_name = config.scope_name
# Wait until the cluster is ready for use.
self._cluster.wait_until_ready(timedelta(seconds=5))
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
index_id = str(uuid.uuid4()).replace("-", "")
self._create_collection(uuid=index_id, vector_length=len(embeddings[0]))
self.add_texts(texts, embeddings)
def _create_collection(self, vector_length: int, uuid: str):
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
if self._collection_exists(self._collection_name):
return
manager = self._bucket.collections()
manager.create_collection(self._client_config.scope_name, self._collection_name)
index_manager = self._scope.search_indexes()
index_definition = json.loads("""
{
"type": "fulltext-index",
"name": "Embeddings._default.Vector_Search",
"uuid": "26d4db528e78b716",
"sourceType": "gocbcore",
"sourceName": "Embeddings",
"sourceUUID": "2242e4a25b4decd6650c9c7b3afa1dbf",
"planParams": {
"maxPartitionsPerPIndex": 1024,
"indexPartitions": 1
},
"params": {
"doc_config": {
"docid_prefix_delim": "",
"docid_regexp": "",
"mode": "scope.collection.type_field",
"type_field": "type"
},
"mapping": {
"analysis": { },
"default_analyzer": "standard",
"default_datetime_parser": "dateTimeOptional",
"default_field": "_all",
"default_mapping": {
"dynamic": true,
"enabled": true
},
"default_type": "_default",
"docvalues_dynamic": false,
"index_dynamic": true,
"store_dynamic": true,
"type_field": "_type",
"types": {
"collection_name": {
"dynamic": true,
"enabled": true,
"properties": {
"embedding": {
"dynamic": false,
"enabled": true,
"fields": [
{
"dims": 1536,
"index": true,
"name": "embedding",
"similarity": "dot_product",
"type": "vector",
"vector_index_optimized_for": "recall"
}
]
},
"metadata": {
"dynamic": true,
"enabled": true
},
"text": {
"dynamic": false,
"enabled": true,
"fields": [
{
"index": true,
"name": "text",
"store": true,
"type": "text"
}
]
}
}
}
}
},
"store": {
"indexType": "scorch",
"segmentVersion": 16
}
},
"sourceParams": { }
}
""")
index_definition["name"] = self._collection_name + "_search"
index_definition["uuid"] = uuid
index_definition["params"]["mapping"]["types"]["collection_name"]["properties"]["embedding"]["fields"][0][
"dims"
] = vector_length
index_definition["params"]["mapping"]["types"][self._scope_name + "." + self._collection_name] = (
index_definition["params"]["mapping"]["types"].pop("collection_name")
)
time.sleep(2)
index_manager.upsert_index(
SearchIndex(
index_definition["name"],
params=index_definition["params"],
source_name=self._bucket_name,
),
)
time.sleep(1)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _collection_exists(self, name: str):
scope_collection_map: dict[str, Any] = {}
# Get a list of all scopes in the bucket
for scope in self._bucket.collections().get_all_scopes():
scope_collection_map[scope.name] = []
# Get a list of all the collections in the scope
for collection in scope.collections:
scope_collection_map[scope.name].append(collection.name)
# Check if the collection exists in the scope
return self._collection_name in scope_collection_map[self._scope_name]
def get_type(self) -> str:
return VectorType.COUCHBASE
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
doc_ids = []
documents_to_insert = [
{"text": text, "embedding": vector, "metadata": metadata}
for id, text, vector, metadata in zip(uuids, texts, embeddings, metadatas)
]
for doc, id in zip(documents_to_insert, uuids):
result = self._scope.collection(self._collection_name).upsert(id, doc)
doc_ids.extend(uuids)
return doc_ids
def text_exists(self, id: str) -> bool:
# Use a parameterized query for safety and correctness
query = f"""
SELECT COUNT(1) AS count FROM
`{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
WHERE META().id = $doc_id
"""
# Pass the id as a parameter to the query
result = self._cluster.query(query, named_parameters={"doc_id": id}).execute()
for row in result:
return row["count"] > 0
return False # Return False if no rows are returned
def delete_by_ids(self, ids: list[str]) -> None:
query = f"""
DELETE FROM `{self._bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
WHERE META().id IN $doc_ids;
"""
try:
self._cluster.query(query, named_parameters={"doc_ids": ids}).execute()
except Exception as e:
logger.error(e)
def delete_by_document_id(self, document_id: str):
query = f"""
DELETE FROM
`{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
WHERE META().id = $doc_id;
"""
self._cluster.query(query, named_parameters={"doc_id": document_id}).execute()
# def get_ids_by_metadata_field(self, key: str, value: str):
# query = f"""
# SELECT id FROM
# `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
# WHERE `metadata.{key}` = $value;
# """
# result = self._cluster.query(query, named_parameters={'value':value})
# return [row['id'] for row in result.rows()]
def delete_by_metadata_field(self, key: str, value: str) -> None:
query = f"""
DELETE FROM `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
WHERE metadata.{key} = $value;
"""
self._cluster.query(query, named_parameters={"value": value}).execute()
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
score_threshold = kwargs.get("score_threshold") or 0.0
search_req = search.SearchRequest.create(
VectorSearch.from_vector_query(
VectorQuery(
"embedding",
query_vector,
top_k,
)
)
)
try:
search_iter = self._scope.search(
self._collection_name + "_search",
search_req,
SearchOptions(limit=top_k, collections=[self._collection_name], fields=["*"]),
)
docs = []
# Parse the results
for row in search_iter.rows():
text = row.fields.pop("text")
metadata = self._format_metadata(row.fields)
score = row.score
metadata["score"] = score
doc = Document(page_content=text, metadata=metadata)
if score >= score_threshold:
docs.append(doc)
except Exception as e:
raise ValueError(f"Search failed with error: {e}")
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 2)
try:
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
search_iter = self._scope.search(
self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
)
docs = []
for row in search_iter.rows():
text = row.fields.pop("text")
metadata = self._format_metadata(row.fields)
score = row.score
metadata["score"] = score
doc = Document(page_content=text, metadata=metadata)
docs.append(doc)
except Exception as e:
raise ValueError(f"Search failed with error: {e}")
return docs
def delete(self):
manager = self._bucket.collections()
scopes = manager.get_all_scopes()
for scope in scopes:
for collection in scope.collections:
if collection.name == self._collection_name:
manager.drop_collection("_default", self._collection_name)
def _format_metadata(self, row_fields: dict[str, Any]) -> dict[str, Any]:
"""Helper method to format the metadata from the Couchbase Search API.
Args:
row_fields (Dict[str, Any]): The fields to format.
Returns:
Dict[str, Any]: The formatted metadata.
"""
metadata = {}
for key, value in row_fields.items():
# Couchbase Search returns the metadata key with a prefix
# `metadata.` We remove it to get the original metadata key
if key.startswith("metadata"):
new_key = key.split("metadata" + ".")[-1]
metadata[new_key] = value
else:
metadata[key] = value
return metadata
class CouchbaseVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> CouchbaseVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.COUCHBASE, collection_name))
config = current_app.config
return CouchbaseVector(
collection_name=collection_name,
config=CouchbaseConfig(
connection_string=config.get("COUCHBASE_CONNECTION_STRING"),
user=config.get("COUCHBASE_USER"),
password=config.get("COUCHBASE_PASSWORD"),
bucket_name=config.get("COUCHBASE_BUCKET_NAME"),
scope_name=config.get("COUCHBASE_SCOPE_NAME"),
),
)

View File

@ -142,7 +142,7 @@ class ElasticSearchVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
query_str = {"match": {Field.CONTENT_KEY.value: query}}
results = self._client.search(index=self._collection_name, query=query_str)
results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
docs = []
for hit in results["hits"]["hits"]:
docs.append(

View File

@ -0,0 +1,209 @@
import json
import logging
import math
from typing import Any
from pydantic import BaseModel, model_validator
from pyobvector import VECTOR, ObVecClient
from sqlalchemy import JSON, Column, String, func
from sqlalchemy.dialects.mysql import LONGTEXT
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
DEFAULT_OCEANBASE_HNSW_BUILD_PARAM = {"M": 16, "efConstruction": 256}
DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM = {"efSearch": 64}
OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPE = "HNSW"
DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE = "l2"
class OceanBaseVectorConfig(BaseModel):
host: str
port: int
user: str
password: str
database: str
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config OCEANBASE_VECTOR_HOST is required")
if not values["port"]:
raise ValueError("config OCEANBASE_VECTOR_PORT is required")
if not values["user"]:
raise ValueError("config OCEANBASE_VECTOR_USER is required")
if not values["database"]:
raise ValueError("config OCEANBASE_VECTOR_DATABASE is required")
return values
class OceanBaseVector(BaseVector):
def __init__(self, collection_name: str, config: OceanBaseVectorConfig):
super().__init__(collection_name)
self._config = config
self._hnsw_ef_search = -1
self._client = ObVecClient(
uri=f"{self._config.host}:{self._config.port}",
user=self._config.user,
password=self._config.password,
db_name=self._config.database,
)
def get_type(self) -> str:
return VectorType.OCEANBASE
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self._vec_dim = len(embeddings[0])
self._create_collection()
self.add_texts(texts, embeddings)
def _create_collection(self) -> None:
lock_name = "vector_indexing_lock_" + self._collection_name
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = "vector_indexing_" + self._collection_name
if redis_client.get(collection_exist_cache_key):
return
if self._client.check_table_exists(self._collection_name):
return
self.delete()
cols = [
Column("id", String(36), primary_key=True, autoincrement=False),
Column("vector", VECTOR(self._vec_dim)),
Column("text", LONGTEXT),
Column("metadata", JSON),
]
vidx_params = self._client.prepare_index_params()
vidx_params.add_index(
field_name="vector",
index_type=OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPE,
index_name="vector_index",
metric_type=DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE,
params=DEFAULT_OCEANBASE_HNSW_BUILD_PARAM,
)
self._client.create_table_with_index_params(
table_name=self._collection_name,
columns=cols,
vidxs=vidx_params,
)
vals = []
params = self._client.perform_raw_text_sql("SHOW PARAMETERS LIKE '%ob_vector_memory_limit_percentage%'")
for row in params:
val = int(row[6])
vals.append(val)
if len(vals) == 0:
print("ob_vector_memory_limit_percentage not found in parameters.")
exit(1)
if any(val == 0 for val in vals):
try:
self._client.perform_raw_text_sql("ALTER SYSTEM SET ob_vector_memory_limit_percentage = 30")
except Exception as e:
raise Exception(
"Failed to set ob_vector_memory_limit_percentage. "
+ "Maybe the database user has insufficient privilege.",
e,
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
ids = self._get_uuids(documents)
for id, doc, emb in zip(ids, documents, embeddings):
self._client.insert(
table_name=self._collection_name,
data={
"id": id,
"vector": emb,
"text": doc.page_content,
"metadata": doc.metadata,
},
)
def text_exists(self, id: str) -> bool:
cur = self._client.get(table_name=self._collection_name, id=id)
return cur.rowcount != 0
def delete_by_ids(self, ids: list[str]) -> None:
self._client.delete(table_name=self._collection_name, ids=ids)
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
cur = self._client.get(
table_name=self._collection_name,
where_clause=f"metadata->>'$.{key}' = '{value}'",
output_column_name=["id"],
)
return [row[0] for row in cur]
def delete_by_metadata_field(self, key: str, value: str) -> None:
ids = self.get_ids_by_metadata_field(key, value)
self.delete_by_ids(ids)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return []
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
ef_search = kwargs.get("ef_search", self._hnsw_ef_search)
if ef_search != self._hnsw_ef_search:
self._client.set_ob_hnsw_ef_search(ef_search)
self._hnsw_ef_search = ef_search
topk = kwargs.get("top_k", 10)
cur = self._client.ann_search(
table_name=self._collection_name,
vec_column_name="vector",
vec_data=query_vector,
topk=topk,
distance_func=func.l2_distance,
output_column_names=["text", "metadata"],
with_dist=True,
)
docs = []
for text, metadata, distance in cur:
metadata = json.loads(metadata)
metadata["score"] = 1 - distance / math.sqrt(2)
docs.append(
Document(
page_content=text,
metadata=metadata,
)
)
return docs
def delete(self) -> None:
self._client.drop_table_if_exist(self._collection_name)
class OceanBaseVectorFactory(AbstractVectorFactory):
def init_vector(
self,
dataset: Dataset,
attributes: list,
embeddings: Embeddings,
) -> BaseVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OCEANBASE, collection_name))
return OceanBaseVector(
collection_name,
OceanBaseVectorConfig(
host=dify_config.OCEANBASE_VECTOR_HOST,
port=dify_config.OCEANBASE_VECTOR_PORT,
user=dify_config.OCEANBASE_VECTOR_USER,
password=(dify_config.OCEANBASE_VECTOR_PASSWORD or ""),
database=dify_config.OCEANBASE_VECTOR_DATABASE,
),
)

View File

@ -0,0 +1,17 @@
from typing import Optional
from pydantic import BaseModel
class ClusterEntity(BaseModel):
"""
Model Config Entity.
"""
name: str
cluster_id: str
displayName: str
region: str
spendingLimit: Optional[int] = 1000
version: str
createdBy: str

View File

@ -0,0 +1,526 @@
import json
import os
import uuid
from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any, Optional, Union, cast
import qdrant_client
import requests
from flask import current_app
from pydantic import BaseModel
from qdrant_client.http import models as rest
from qdrant_client.http.models import (
FilterSelector,
HnswConfigDiff,
PayloadSchemaType,
TextIndexParams,
TextIndexType,
TokenizerType,
)
from qdrant_client.local.qdrant_local import QdrantLocal
from requests.auth import HTTPDigestAuth
from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, TidbAuthBinding
if TYPE_CHECKING:
from qdrant_client import grpc # noqa
from qdrant_client.conversions import common_types
from qdrant_client.http import models as rest
DictFilter = dict[str, Union[str, int, bool, dict, list]]
MetadataFilter = Union[DictFilter, common_types.Filter]
class TidbOnQdrantConfig(BaseModel):
endpoint: str
api_key: Optional[str] = None
timeout: float = 20
root_path: Optional[str] = None
grpc_port: int = 6334
prefer_grpc: bool = False
def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith("path:"):
path = self.endpoint.replace("path:", "")
if not os.path.isabs(path):
path = os.path.join(self.root_path, path)
return {"path": path}
else:
return {
"url": self.endpoint,
"api_key": self.api_key,
"timeout": self.timeout,
"verify": False,
"grpc_port": self.grpc_port,
"prefer_grpc": self.prefer_grpc,
}
class TidbConfig(BaseModel):
api_url: str
public_key: str
private_key: str
class TidbOnQdrantVector(BaseVector):
def __init__(self, collection_name: str, group_id: str, config: TidbOnQdrantConfig, distance_func: str = "Cosine"):
super().__init__(collection_name)
self._client_config = config
self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params())
self._distance_func = distance_func.upper()
self._group_id = group_id
def get_type(self) -> str:
return VectorType.TIDB_ON_QDRANT
def to_index_struct(self) -> dict:
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
if texts:
# get embedding vector size
vector_size = len(embeddings[0])
# get collection name
collection_name = self._collection_name
# create collection
self.create_collection(collection_name, vector_size)
self.add_texts(texts, embeddings, **kwargs)
def create_collection(self, collection_name: str, vector_size: int):
lock_name = "vector_indexing_lock_{}".format(collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
collection_name = collection_name or uuid.uuid4().hex
all_collection_name = []
collections_response = self._client.get_collections()
collection_list = collections_response.collections
for collection in collection_list:
all_collection_name.append(collection.name)
if collection_name not in all_collection_name:
from qdrant_client.http import models as rest
vectors_config = rest.VectorParams(
size=vector_size,
distance=rest.Distance[self._distance_func],
)
hnsw_config = HnswConfigDiff(
m=0,
payload_m=16,
ef_construct=100,
full_scan_threshold=10000,
max_indexing_threads=0,
on_disk=False,
)
self._client.recreate_collection(
collection_name=collection_name,
vectors_config=vectors_config,
hnsw_config=hnsw_config,
timeout=int(self._client_config.timeout),
)
# create group_id payload index
self._client.create_payload_index(
collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD
)
# create doc_id payload index
self._client.create_payload_index(
collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD
)
# create full text index
text_index_params = TextIndexParams(
type=TextIndexType.TEXT,
tokenizer=TokenizerType.MULTILINGUAL,
min_token_len=2,
max_token_len=20,
lowercase=True,
)
self._client.create_payload_index(
collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
added_ids = []
for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id):
self._client.upsert(collection_name=self._collection_name, points=points)
added_ids.extend(batch_ids)
return added_ids
def _generate_rest_batches(
self,
texts: Iterable[str],
embeddings: list[list[float]],
metadatas: Optional[list[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
group_id: Optional[str] = None,
) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
from qdrant_client.http import models as rest
texts_iterator = iter(texts)
embeddings_iterator = iter(embeddings)
metadatas_iterator = iter(metadatas or [])
ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
while batch_texts := list(islice(texts_iterator, batch_size)):
# Take the corresponding metadata and id for each text in a batch
batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
batch_ids = list(islice(ids_iterator, batch_size))
# Generate the embeddings for all the texts in a batch
batch_embeddings = list(islice(embeddings_iterator, batch_size))
points = [
rest.PointStruct(
id=point_id,
vector=vector,
payload=payload,
)
for point_id, vector, payload in zip(
batch_ids,
batch_embeddings,
self._build_payloads(
batch_texts,
batch_metadatas,
Field.CONTENT_KEY.value,
Field.METADATA_KEY.value,
group_id,
Field.GROUP_KEY.value,
),
)
]
yield batch_ids, points
@classmethod
def _build_payloads(
cls,
texts: Iterable[str],
metadatas: Optional[list[dict]],
content_payload_key: str,
metadata_payload_key: str,
group_id: str,
group_payload_key: str,
) -> list[dict]:
payloads = []
for i, text in enumerate(texts):
if text is None:
raise ValueError(
"At least one of the texts is None. Please remove it before "
"calling .from_texts or .add_texts on Qdrant instance."
)
metadata = metadatas[i] if metadatas is not None else None
payloads.append({content_payload_key: text, metadata_payload_key: metadata, group_payload_key: group_id})
return payloads
def delete_by_metadata_field(self, key: str, value: str):
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
try:
filter = models.Filter(
must=[
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
),
],
)
self._reload_if_needed()
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
def delete(self):
from qdrant_client.http.exceptions import UnexpectedResponse
try:
self._client.delete_collection(collection_name=self._collection_name)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
def delete_by_ids(self, ids: list[str]) -> None:
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
for node_id in ids:
try:
filter = models.Filter(
must=[
models.FieldCondition(
key="metadata.doc_id",
match=models.MatchValue(value=node_id),
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
def text_exists(self, id: str) -> bool:
all_collection_name = []
collections_response = self._client.get_collections()
collection_list = collections_response.collections
for collection in collection_list:
all_collection_name.append(collection.name)
if self._collection_name not in all_collection_name:
return False
response = self._client.retrieve(collection_name=self._collection_name, ids=[id])
return len(response) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from qdrant_client.http import models
filter = models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self._group_id),
),
],
)
results = self._client.search(
collection_name=self._collection_name,
query_vector=query_vector,
query_filter=filter,
limit=kwargs.get("top_k", 4),
with_payload=True,
with_vectors=True,
score_threshold=kwargs.get("score_threshold", 0.0),
)
docs = []
for result in results:
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
score_threshold = kwargs.get("score_threshold") or 0.0
if result.score > score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value),
metadata=metadata,
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""Return docs most similar by bm25.
Returns:
List of documents most similar to the query text and distance for each.
"""
from qdrant_client.http import models
scroll_filter = models.Filter(
must=[
models.FieldCondition(
key="page_content",
match=models.MatchText(text=query),
)
]
)
response = self._client.scroll(
collection_name=self._collection_name,
scroll_filter=scroll_filter,
limit=kwargs.get("top_k", 2),
with_payload=True,
with_vectors=True,
)
results = response[0]
documents = []
for result in results:
if result:
document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value)
document.metadata["vector"] = result.vector
documents.append(document)
return documents
def _reload_if_needed(self):
if isinstance(self._client, QdrantLocal):
self._client = cast(QdrantLocal, self._client)
self._client._load()
@classmethod
def _document_from_scored_point(
cls,
scored_point: Any,
content_payload_key: str,
metadata_payload_key: str,
) -> Document:
return Document(
page_content=scored_point.payload.get(content_payload_key),
metadata=scored_point.payload.get(metadata_payload_key) or {},
)
class TidbOnQdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
tidb_auth_binding = (
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
)
if not tidb_auth_binding:
idle_tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.limit(1)
.one_or_none()
)
if idle_tidb_auth_binding:
idle_tidb_auth_binding.active = True
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
else:
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
.one_or_none()
)
if tidb_auth_binding:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
else:
new_cluster = TidbService.create_tidb_serverless_cluster(
dify_config.TIDB_PROJECT_ID,
dify_config.TIDB_API_URL,
dify_config.TIDB_IAM_API_URL,
dify_config.TIDB_PUBLIC_KEY,
dify_config.TIDB_PRIVATE_KEY,
dify_config.TIDB_REGION,
)
new_tidb_auth_binding = TidbAuthBinding(
cluster_id=new_cluster["cluster_id"],
cluster_name=new_cluster["cluster_name"],
account=new_cluster["account"],
password=new_cluster["password"],
tenant_id=dataset.tenant_id,
active=True,
status="ACTIVE",
)
db.session.add(new_tidb_auth_binding)
db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"
else:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TIDB_ON_QDRANT, collection_name))
config = current_app.config
return TidbOnQdrantVector(
collection_name=collection_name,
group_id=dataset.id,
config=TidbOnQdrantConfig(
endpoint=dify_config.TIDB_ON_QDRANT_URL,
api_key=TIDB_ON_QDRANT_API_KEY,
root_path=config.root_path,
timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT,
grpc_port=dify_config.TIDB_ON_QDRANT_GRPC_PORT,
prefer_grpc=dify_config.TIDB_ON_QDRANT_GRPC_ENABLED,
),
)
def create_tidb_serverless_cluster(self, tidb_config: TidbConfig, display_name: str, region: str):
"""
Creates a new TiDB Serverless cluster.
:param tidb_config: The configuration for the TiDB Cloud API.
:param display_name: The user-friendly display name of the cluster (required).
:param region: The region where the cluster will be created (required).
:return: The response from the API.
"""
region_object = {
"name": region,
}
labels = {
"tidb.cloud/project": "1372813089454548012",
}
cluster_data = {"displayName": display_name, "region": region_object, "labels": labels}
response = requests.post(
f"{tidb_config.api_url}/clusters",
json=cluster_data,
auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key),
)
if response.status_code == 200:
return response.json()
else:
response.raise_for_status()
def change_tidb_serverless_root_password(self, tidb_config: TidbConfig, cluster_id: str, new_password: str):
"""
Changes the root password of a specific TiDB Serverless cluster.
:param tidb_config: The configuration for the TiDB Cloud API.
:param cluster_id: The ID of the cluster for which the password is to be changed (required).
:param new_password: The new password for the root user (required).
:return: The response from the API.
"""
body = {"password": new_password}
response = requests.put(
f"{tidb_config.api_url}/clusters/{cluster_id}/password",
json=body,
auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key),
)
if response.status_code == 200:
return response.json()
else:
response.raise_for_status()

View File

@ -0,0 +1,251 @@
import time
import uuid
import requests
from requests.auth import HTTPDigestAuth
from configs import dify_config
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import TidbAuthBinding
class TidbService:
@staticmethod
def create_tidb_serverless_cluster(
project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str
):
"""
Creates a new TiDB Serverless cluster.
:param project_id: The project ID of the TiDB Cloud project (required).
:param api_url: The URL of the TiDB Cloud API (required).
:param iam_url: The URL of the TiDB Cloud IAM API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
:param display_name: The user-friendly display name of the cluster (required).
:param region: The region where the cluster will be created (required).
:return: The response from the API.
"""
region_object = {
"name": region,
}
labels = {
"tidb.cloud/project": project_id,
}
spending_limit = {
"monthly": 100,
}
password = str(uuid.uuid4()).replace("-", "")[:16]
display_name = str(uuid.uuid4()).replace("-", "")[:16]
cluster_data = {
"displayName": display_name,
"region": region_object,
"labels": labels,
"spendingLimit": spending_limit,
"rootPassword": password,
}
response = requests.post(f"{api_url}/clusters", json=cluster_data, auth=HTTPDigestAuth(public_key, private_key))
if response.status_code == 200:
response_data = response.json()
cluster_id = response_data["clusterId"]
retry_count = 0
max_retries = 30
while retry_count < max_retries:
cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id)
if cluster_response["state"] == "ACTIVE":
user_prefix = cluster_response["userPrefix"]
return {
"cluster_id": cluster_id,
"cluster_name": display_name,
"account": f"{user_prefix}.root",
"password": password,
}
time.sleep(30) # wait 30 seconds before retrying
retry_count += 1
else:
response.raise_for_status()
@staticmethod
def delete_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str):
"""
Deletes a specific TiDB Serverless cluster.
:param api_url: The URL of the TiDB Cloud API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
:param cluster_id: The ID of the cluster to be deleted (required).
:return: The response from the API.
"""
response = requests.delete(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key))
if response.status_code == 200:
return response.json()
else:
response.raise_for_status()
@staticmethod
def get_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str):
"""
Deletes a specific TiDB Serverless cluster.
:param api_url: The URL of the TiDB Cloud API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
:param cluster_id: The ID of the cluster to be deleted (required).
:return: The response from the API.
"""
response = requests.get(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key))
if response.status_code == 200:
return response.json()
else:
response.raise_for_status()
@staticmethod
def change_tidb_serverless_root_password(
api_url: str, public_key: str, private_key: str, cluster_id: str, account: str, new_password: str
):
"""
Changes the root password of a specific TiDB Serverless cluster.
:param api_url: The URL of the TiDB Cloud API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
:param cluster_id: The ID of the cluster for which the password is to be changed (required).+
:param account: The account for which the password is to be changed (required).
:param new_password: The new password for the root user (required).
:return: The response from the API.
"""
body = {"password": new_password, "builtinRole": "role_admin", "customRoles": []}
response = requests.patch(
f"{api_url}/clusters/{cluster_id}/sqlUsers/{account}",
json=body,
auth=HTTPDigestAuth(public_key, private_key),
)
if response.status_code == 200:
return response.json()
else:
response.raise_for_status()
@staticmethod
def batch_update_tidb_serverless_cluster_status(
tidb_serverless_list: list[TidbAuthBinding],
project_id: str,
api_url: str,
iam_url: str,
public_key: str,
private_key: str,
) -> list[dict]:
"""
Update the status of a new TiDB Serverless cluster.
:param project_id: The project ID of the TiDB Cloud project (required).
:param api_url: The URL of the TiDB Cloud API (required).
:param iam_url: The URL of the TiDB Cloud IAM API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
:param display_name: The user-friendly display name of the cluster (required).
:param region: The region where the cluster will be created (required).
:return: The response from the API.
"""
clusters = []
tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list}
cluster_ids = [item.cluster_id for item in tidb_serverless_list]
params = {"clusterIds": cluster_ids, "view": "FULL"}
response = requests.get(
f"{api_url}/clusters:batchGet", params=params, auth=HTTPDigestAuth(public_key, private_key)
)
if response.status_code == 200:
response_data = response.json()
cluster_infos = []
for item in response_data["clusters"]:
state = item["state"]
userPrefix = item["userPrefix"]
if state == "ACTIVE" and len(userPrefix) > 0:
cluster_info = tidb_serverless_list_map[item["clusterId"]]
cluster_info.status = "ACTIVE"
cluster_info.account = f"{userPrefix}.root"
db.session.add(cluster_info)
db.session.commit()
else:
response.raise_for_status()
@staticmethod
def batch_create_tidb_serverless_cluster(
batch_size: int, project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str
) -> list[dict]:
"""
Creates a new TiDB Serverless cluster.
:param project_id: The project ID of the TiDB Cloud project (required).
:param api_url: The URL of the TiDB Cloud API (required).
:param iam_url: The URL of the TiDB Cloud IAM API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
:param display_name: The user-friendly display name of the cluster (required).
:param region: The region where the cluster will be created (required).
:return: The response from the API.
"""
clusters = []
for _ in range(batch_size):
region_object = {
"name": region,
}
labels = {
"tidb.cloud/project": project_id,
}
spending_limit = {
"monthly": dify_config.TIDB_SPEND_LIMIT,
}
password = str(uuid.uuid4()).replace("-", "")[:16]
display_name = str(uuid.uuid4()).replace("-", "")
cluster_data = {
"cluster": {
"displayName": display_name,
"region": region_object,
"labels": labels,
"spendingLimit": spending_limit,
"rootPassword": password,
}
}
cache_key = f"tidb_serverless_cluster_password:{display_name}"
redis_client.setex(cache_key, 3600, password)
clusters.append(cluster_data)
request_body = {"requests": clusters}
response = requests.post(
f"{api_url}/clusters:batchCreate", json=request_body, auth=HTTPDigestAuth(public_key, private_key)
)
if response.status_code == 200:
response_data = response.json()
cluster_infos = []
for item in response_data["clusters"]:
cache_key = f"tidb_serverless_cluster_password:{item['displayName']}"
password = redis_client.get(cache_key)
if not password:
continue
cluster_info = {
"cluster_id": item["clusterId"],
"cluster_name": item["displayName"],
"account": "root",
"password": password.decode("utf-8"),
}
cluster_infos.append(cluster_info)
return cluster_infos
else:
response.raise_for_status()

View File

@ -0,0 +1,129 @@
import json
from typing import Any
from uuid import uuid4
from pydantic import BaseModel, model_validator
from upstash_vector import Index, Vector
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from models.dataset import Dataset
class UpstashVectorConfig(BaseModel):
url: str
token: str
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["url"]:
raise ValueError("Upstash URL is required")
if not values["token"]:
raise ValueError("Upstash Token is required")
return values
class UpstashVector(BaseVector):
def __init__(self, collection_name: str, config: UpstashVectorConfig):
super().__init__(collection_name)
self._table_name = collection_name
self.index = Index(url=config.url, token=config.token)
def _get_index_dimension(self) -> int:
index_info = self.index.info()
if index_info and index_info.dimension:
return index_info.dimension
else:
return 1536
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
vectors = [
Vector(
id=str(uuid4()),
vector=embedding,
metadata=doc.metadata,
data=doc.page_content,
)
for doc, embedding in zip(documents, embeddings)
]
self.index.upsert(vectors=vectors)
def text_exists(self, id: str) -> bool:
response = self.get_ids_by_metadata_field("doc_id", id)
return len(response) > 0
def delete_by_ids(self, ids: list[str]) -> None:
item_ids = []
for doc_id in ids:
ids = self.get_ids_by_metadata_field("doc_id", doc_id)
if id:
item_ids += ids
self._delete_by_ids(ids=item_ids)
def _delete_by_ids(self, ids: list[str]) -> None:
if ids:
self.index.delete(ids=ids)
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
query_result = self.index.query(
vector=[1.001 * i for i in range(self._get_index_dimension())],
include_metadata=True,
top_k=1000,
filter=f"{key} = '{value}'",
)
return [result.id for result in query_result]
def delete_by_metadata_field(self, key: str, value: str) -> None:
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self._delete_by_ids(ids)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
result = self.index.query(vector=query_vector, top_k=top_k, include_metadata=True, include_data=True)
docs = []
score_threshold = float(kwargs.get("score_threshold") or 0.0)
for record in result:
metadata = record.metadata
text = record.data
score = record.score
metadata["score"] = score
if score > score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return []
def delete(self) -> None:
self.index.reset()
def get_type(self) -> str:
return VectorType.UPSTASH
class UpstashVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> UpstashVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.UPSTASH, collection_name))
return UpstashVector(
collection_name=collection_name,
config=UpstashVectorConfig(
url=dify_config.UPSTASH_VECTOR_URL,
token=dify_config.UPSTASH_VECTOR_TOKEN,
),
)

View File

@ -9,8 +9,9 @@ from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.cached_embedding import CacheEmbedding
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.dataset import Dataset, Whitelist
class AbstractVectorFactory(ABC):
@ -35,8 +36,18 @@ class Vector:
def _init_vector(self) -> BaseVector:
vector_type = dify_config.VECTOR_STORE
if self._dataset.index_struct_dict:
vector_type = self._dataset.index_struct_dict["type"]
else:
if dify_config.VECTOR_STORE_WHITELIST_ENABLE:
whitelist = (
db.session.query(Whitelist)
.filter(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
.one_or_none()
)
if whitelist:
vector_type = VectorType.TIDB_ON_QDRANT
if not vector_type:
raise ValueError("Vector store must be specified.")
@ -103,6 +114,10 @@ class Vector:
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory
return AnalyticdbVectorFactory
case VectorType.COUCHBASE:
from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseVectorFactory
return CouchbaseVectorFactory
case VectorType.BAIDU:
from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory
@ -111,6 +126,18 @@ class Vector:
from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBVectorFactory
return VikingDBVectorFactory
case VectorType.UPSTASH:
from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVectorFactory
return UpstashVectorFactory
case VectorType.TIDB_ON_QDRANT:
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory
return TidbOnQdrantVectorFactory
case VectorType.OCEANBASE:
from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory
return OceanBaseVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")

View File

@ -16,5 +16,9 @@ class VectorType(str, Enum):
TENCENT = "tencent"
ORACLE = "oracle"
ELASTICSEARCH = "elasticsearch"
COUCHBASE = "couchbase"
BAIDU = "baidu"
VIKINGDB = "vikingdb"
UPSTASH = "upstash"
TIDB_ON_QDRANT = "tidb_on_qdrant"
OCEANBASE = "oceanbase"

Some files were not shown because too many files have changed in this diff Show More