From 3ad943a9eb3b4d762ee93f8a025816c7d67d38ce Mon Sep 17 00:00:00 2001
From: Garfield Dai
Date: Fri, 23 Feb 2024 16:12:43 +0800
Subject: [PATCH 01/36] Feat/openai llm trial paid config (#2545)
---
api/.env.example | 2 ++
api/config.py | 4 ++++
api/core/hosting_configuration.py | 36 +++++++++++--------------------
3 files changed, 18 insertions(+), 24 deletions(-)
diff --git a/api/.env.example b/api/.env.example
index 89d550ba5a..fbff8385f8 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -111,8 +111,10 @@ HOSTED_OPENAI_API_KEY=
HOSTED_OPENAI_API_BASE=
HOSTED_OPENAI_API_ORGANIZATION=
HOSTED_OPENAI_TRIAL_ENABLED=false
+HOSTED_OPENAI_TRIAL_MODELS=gpt-3.5-turbo,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,text-davinci-003
HOSTED_OPENAI_QUOTA_LIMIT=200
HOSTED_OPENAI_PAID_ENABLED=false
+HOSTED_OPENAI_PAID_MODELS=gpt-4,gpt-4-turbo-preview,gpt-4-32k,gpt-4-1106-preview,gpt-4-0125-preview,gpt-3.5-turbo,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-1106,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,gpt-3.5-turbo-instruct,text-davinci-003
HOSTED_AZURE_OPENAI_ENABLED=false
HOSTED_AZURE_OPENAI_API_KEY=
diff --git a/api/config.py b/api/config.py
index 83336e6c45..6bb0496be6 100644
--- a/api/config.py
+++ b/api/config.py
@@ -38,7 +38,9 @@ DEFAULTS = {
'LOG_LEVEL': 'INFO',
'HOSTED_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_OPENAI_TRIAL_ENABLED': 'False',
+ 'HOSTED_OPENAI_TRIAL_MODELS': '',
'HOSTED_OPENAI_PAID_ENABLED': 'False',
+ 'HOSTED_OPENAI_PAID_MODELS': '',
'HOSTED_AZURE_OPENAI_ENABLED': 'False',
'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
@@ -261,8 +263,10 @@ class Config:
self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
self.HOSTED_OPENAI_TRIAL_ENABLED = get_bool_env('HOSTED_OPENAI_TRIAL_ENABLED')
+ self.HOSTED_OPENAI_TRIAL_MODELS = get_env('HOSTED_OPENAI_TRIAL_MODELS')
self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
+ self.HOSTED_OPENAI_PAID_MODELS = get_env('HOSTED_OPENAI_PAID_MODELS')
self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py
index 58b551f295..880a30cdf4 100644
--- a/api/core/hosting_configuration.py
+++ b/api/core/hosting_configuration.py
@@ -104,37 +104,17 @@ class HostingConfiguration:
if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"):
hosted_quota_limit = int(app_config.get("HOSTED_OPENAI_QUOTA_LIMIT", "200"))
+ trial_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_TRIAL_MODELS")
trial_quota = TrialHostingQuota(
quota_limit=hosted_quota_limit,
- restrict_models=[
- RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
- RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
- RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
- RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
- RestrictModel(model="gpt-3.5-turbo-16k-0613", model_type=ModelType.LLM),
- RestrictModel(model="gpt-3.5-turbo-0613", model_type=ModelType.LLM),
- RestrictModel(model="gpt-3.5-turbo-0125", model_type=ModelType.LLM),
- RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
- ]
+ restrict_models=trial_models
)
quotas.append(trial_quota)
if app_config.get("HOSTED_OPENAI_PAID_ENABLED"):
+ paid_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_PAID_MODELS")
paid_quota = PaidHostingQuota(
- restrict_models=[
- RestrictModel(model="gpt-4", model_type=ModelType.LLM),
- RestrictModel(model="gpt-4-turbo-preview", model_type=ModelType.LLM),
- RestrictModel(model="gpt-4-1106-preview", model_type=ModelType.LLM),
- RestrictModel(model="gpt-4-0125-preview", model_type=ModelType.LLM),
- RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
- RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
- RestrictModel(model="gpt-3.5-turbo-16k-0613", model_type=ModelType.LLM),
- RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
- RestrictModel(model="gpt-3.5-turbo-0613", model_type=ModelType.LLM),
- RestrictModel(model="gpt-3.5-turbo-0125", model_type=ModelType.LLM),
- RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
- RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
- ]
+ restrict_models=paid_models
)
quotas.append(paid_quota)
@@ -258,3 +238,11 @@ class HostingConfiguration:
return HostedModerationConfig(
enabled=False
)
+
+ @staticmethod
+ def parse_restrict_models_from_env(app_config: Config, env_var: str) -> list[RestrictModel]:
+ models_str = app_config.get(env_var)
+ models_list = models_str.split(",") if models_str else []
+ return [RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) for model_name in models_list if
+ model_name.strip()]
+
From 49da8a23a847b75e6f918d58316b019ff328e99c Mon Sep 17 00:00:00 2001
From: Garfield Dai
Date: Fri, 23 Feb 2024 16:48:58 +0800
Subject: [PATCH 02/36] feat: openai llm get trial or paid models from config.
(#2546)
---
api/.env.example | 2 --
api/config.py | 4 ++--
2 files changed, 2 insertions(+), 4 deletions(-)
diff --git a/api/.env.example b/api/.env.example
index fbff8385f8..89d550ba5a 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -111,10 +111,8 @@ HOSTED_OPENAI_API_KEY=
HOSTED_OPENAI_API_BASE=
HOSTED_OPENAI_API_ORGANIZATION=
HOSTED_OPENAI_TRIAL_ENABLED=false
-HOSTED_OPENAI_TRIAL_MODELS=gpt-3.5-turbo,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,text-davinci-003
HOSTED_OPENAI_QUOTA_LIMIT=200
HOSTED_OPENAI_PAID_ENABLED=false
-HOSTED_OPENAI_PAID_MODELS=gpt-4,gpt-4-turbo-preview,gpt-4-32k,gpt-4-1106-preview,gpt-4-0125-preview,gpt-3.5-turbo,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-1106,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,gpt-3.5-turbo-instruct,text-davinci-003
HOSTED_AZURE_OPENAI_ENABLED=false
HOSTED_AZURE_OPENAI_API_KEY=
diff --git a/api/config.py b/api/config.py
index 6bb0496be6..8eeede0ff9 100644
--- a/api/config.py
+++ b/api/config.py
@@ -38,9 +38,9 @@ DEFAULTS = {
'LOG_LEVEL': 'INFO',
'HOSTED_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_OPENAI_TRIAL_ENABLED': 'False',
- 'HOSTED_OPENAI_TRIAL_MODELS': '',
+ 'HOSTED_OPENAI_TRIAL_MODELS': 'gpt-3.5-turbo,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,text-davinci-003',
'HOSTED_OPENAI_PAID_ENABLED': 'False',
- 'HOSTED_OPENAI_PAID_MODELS': '',
+ 'HOSTED_OPENAI_PAID_MODELS': 'gpt-4,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-4-0125-preview,gpt-3.5-turbo,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-1106,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,gpt-3.5-turbo-instruct,text-davinci-003',
'HOSTED_AZURE_OPENAI_ENABLED': 'False',
'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
From 4be30876427028b6c1000a6723e02c193412e1d6 Mon Sep 17 00:00:00 2001
From: Jyong <76649700+JohnJyong@users.noreply.github.com>
Date: Fri, 23 Feb 2024 16:54:15 +0800
Subject: [PATCH 03/36] Fix/new RAG bugs (#2547)
Co-authored-by: jyong
---
api/core/indexing_runner.py | 2 +-
api/core/rag/datasource/retrieval_service.py | 8 ++++++--
.../dataset_retriever/dataset_multi_retriever_tool.py | 2 +-
.../tool/dataset_retriever/dataset_retriever_tool.py | 2 +-
api/tasks/clean_dataset_task.py | 1 -
5 files changed, 9 insertions(+), 6 deletions(-)
diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index c8a2e09443..d2d04c984b 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -365,7 +365,7 @@ class IndexingRunner:
notion_info={
"notion_workspace_id": data_source_info['notion_workspace_id'],
"notion_obj_id": data_source_info['notion_page_id'],
- "notion_page_type": data_source_info['notion_page_type'],
+ "notion_page_type": data_source_info['type'],
"document": dataset_document
},
document_model=dataset_document.doc_form
diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py
index 79673ffa83..c0205d1aa9 100644
--- a/api/core/rag/datasource/retrieval_service.py
+++ b/api/core/rag/datasource/retrieval_service.py
@@ -2,7 +2,6 @@ import threading
from typing import Optional
from flask import Flask, current_app
-from flask_login import current_user
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
@@ -27,6 +26,11 @@ class RetrievalService:
@classmethod
def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None):
+ dataset = db.session.query(Dataset).filter(
+ Dataset.id == dataset_id
+ ).first()
+ if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
+ return []
all_documents = []
threads = []
# retrieval_model source with keyword
@@ -73,7 +77,7 @@ class RetrievalService:
thread.join()
if retrival_method == 'hybrid_search':
- data_post_processor = DataPostProcessor(str(current_user.current_tenant_id), reranking_model, False)
+ data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
all_documents = data_post_processor.invoke(
query=query,
documents=all_documents,
diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py
index 57b6e090c4..d9934acff9 100644
--- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py
+++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py
@@ -171,7 +171,7 @@ class DatasetMultiRetrieverTool(BaseTool):
if dataset.indexing_technique == "economy":
# use keyword table query
- documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
+ documents = RetrievalService.retrieve(retrival_method='keyword_search',
dataset_id=dataset.id,
query=query,
top_k=self.top_k
diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py
index d3ec0fba69..13331d981b 100644
--- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py
+++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py
@@ -69,7 +69,7 @@ class DatasetRetrieverTool(BaseTool):
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
- documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
+ documents = RetrievalService.retrieve(retrival_method='keyword_search',
dataset_id=dataset.id,
query=query,
top_k=self.top_k
diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py
index 16e4affc91..37e109c847 100644
--- a/api/tasks/clean_dataset_task.py
+++ b/api/tasks/clean_dataset_task.py
@@ -40,7 +40,6 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
indexing_technique=indexing_technique,
index_struct=index_struct,
collection_binding_id=collection_binding_id,
- doc_form=doc_form
)
documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
From 952e13fef8d1b411705a85a1e75664fd4b34d73e Mon Sep 17 00:00:00 2001
From: takatost
Date: Fri, 23 Feb 2024 17:38:03 +0800
Subject: [PATCH 04/36] Update README_CN.md (#2550)
---
api/core/model_runtime/README_CN.md | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/api/core/model_runtime/README_CN.md b/api/core/model_runtime/README_CN.md
index 6950cdc0c7..3664fa2ca3 100644
--- a/api/core/model_runtime/README_CN.md
+++ b/api/core/model_runtime/README_CN.md
@@ -20,7 +20,7 @@

- 展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等,规则设计详见:[Schema](./schema.md)。
+ 展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等,规则设计详见:[Schema](./docs/zh_Hans/schema.md)。
- 可选择的模型列表展示
@@ -86,4 +86,4 @@ Model Runtime 分三层:

### [接口的具体实现 👈🏻](./docs/zh_Hans/interfaces.md)
-你可以在这里找到你想要查看的接口的具体实现,以及接口的参数和返回值的具体含义。
\ No newline at end of file
+你可以在这里找到你想要查看的接口的具体实现,以及接口的参数和返回值的具体含义。
From ca69af7b975cdaf9b78262840386164b30a61e57 Mon Sep 17 00:00:00 2001
From: Rozstone <42225395+wststone@users.noreply.github.com>
Date: Sat, 24 Feb 2024 09:28:27 +0800
Subject: [PATCH 05/36] feat: change max_question_num to 5 (#2520)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
---
.../features/chat-group/opening-statement/index.tsx | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/web/app/components/app/configuration/features/chat-group/opening-statement/index.tsx b/web/app/components/app/configuration/features/chat-group/opening-statement/index.tsx
index 29ecce5281..6be76210da 100644
--- a/web/app/components/app/configuration/features/chat-group/opening-statement/index.tsx
+++ b/web/app/components/app/configuration/features/chat-group/opening-statement/index.tsx
@@ -18,7 +18,7 @@ import { getNewVar } from '@/utils/var'
import { varHighlightHTML } from '@/app/components/app/configuration/base/var-highlight'
import { Plus, Trash03 } from '@/app/components/base/icons/src/vender/line/general'
-const MAX_QUESTION_NUM = 3
+const MAX_QUESTION_NUM = 5
export type IOpeningStatementProps = {
value: string
From d93288f71112f2f054376138397153097f1d55a8 Mon Sep 17 00:00:00 2001
From: Rozstone <42225395+wststone@users.noreply.github.com>
Date: Mon, 26 Feb 2024 12:52:59 +0800
Subject: [PATCH 06/36] Feat/use searchparams as state (#2554)
Co-authored-by: crazywoola <427733928@qq.com>
---
web/app/(commonLayout)/apps/Apps.tsx | 5 ++-
web/app/(commonLayout)/datasets/Container.tsx | 9 +++--
.../app/annotation/header-opts/index.tsx | 5 +--
.../annotation/header-opts/style.module.css | 4 +--
web/app/components/explore/app-list/index.tsx | 5 ++-
web/app/components/tools/index.tsx | 5 ++-
web/hooks/use-tab-searchparams.ts | 34 +++++++++++++++++++
7 files changed, 58 insertions(+), 9 deletions(-)
create mode 100644 web/hooks/use-tab-searchparams.ts
diff --git a/web/app/(commonLayout)/apps/Apps.tsx b/web/app/(commonLayout)/apps/Apps.tsx
index 106f90810d..b121aa644c 100644
--- a/web/app/(commonLayout)/apps/Apps.tsx
+++ b/web/app/(commonLayout)/apps/Apps.tsx
@@ -11,6 +11,7 @@ import { fetchAppList } from '@/service/apps'
import { useAppContext } from '@/context/app-context'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import { CheckModal } from '@/hooks/use-pay'
+import { useTabSearchParams } from '@/hooks/use-tab-searchparams'
import TabSlider from '@/app/components/base/tab-slider'
import { SearchLg } from '@/app/components/base/icons/src/vender/line/general'
import { XCircle } from '@/app/components/base/icons/src/vender/solid/general'
@@ -35,7 +36,9 @@ const getKey = (
const Apps = () => {
const { t } = useTranslation()
const { isCurrentWorkspaceManager } = useAppContext()
- const [activeTab, setActiveTab] = useState('all')
+ const [activeTab, setActiveTab] = useTabSearchParams({
+ defaultTab: 'all',
+ })
const [keywords, setKeywords] = useState('')
const [searchKeywords, setSearchKeywords] = useState('')
diff --git a/web/app/(commonLayout)/datasets/Container.tsx b/web/app/(commonLayout)/datasets/Container.tsx
index c3ebef2ea8..d70ed1cb63 100644
--- a/web/app/(commonLayout)/datasets/Container.tsx
+++ b/web/app/(commonLayout)/datasets/Container.tsx
@@ -1,7 +1,7 @@
'use client'
// Libraries
-import { useRef, useState } from 'react'
+import { useRef } from 'react'
import { useTranslation } from 'react-i18next'
import useSWR from 'swr'
@@ -15,6 +15,9 @@ import TabSlider from '@/app/components/base/tab-slider'
// Services
import { fetchDatasetApiBaseUrl } from '@/service/datasets'
+// Hooks
+import { useTabSearchParams } from '@/hooks/use-tab-searchparams'
+
const Container = () => {
const { t } = useTranslation()
@@ -23,7 +26,9 @@ const Container = () => {
{ value: 'api', text: t('dataset.datasetsApi') },
]
- const [activeTab, setActiveTab] = useState('dataset')
+ const [activeTab, setActiveTab] = useTabSearchParams({
+ defaultTab: 'dataset',
+ })
const containerRef = useRef(null)
const { data } = useSWR(activeTab === 'dataset' ? null : '/datasets/api-base-info', fetchDatasetApiBaseUrl)
diff --git a/web/app/components/app/annotation/header-opts/index.tsx b/web/app/components/app/annotation/header-opts/index.tsx
index 90b1a9672e..aba3b6324c 100644
--- a/web/app/components/app/annotation/header-opts/index.tsx
+++ b/web/app/components/app/annotation/header-opts/index.tsx
@@ -42,6 +42,7 @@ const HeaderOptions: FC = ({
const { locale } = useContext(I18n)
const { CSVDownloader, Type } = useCSVDownloader()
const [list, setList] = useState([])
+ const annotationUnavailable = list.length === 0
const listTransformer = (list: AnnotationItemBasic[]) => list.map(
(item: AnnotationItemBasic) => {
@@ -116,11 +117,11 @@ const HeaderOptions: FC = ({
...list.map(item => [item.question, item.answer]),
]}
>
-
+
+
+ Dify.AI Upcoming Meetup Event [👉 Click to Join the Event Here 👈]
+
+
+ - US EST: 09:00 (9:00 AM)
+ - CET: 15:00 (3:00 PM)
+ - CST: 22:00 (10:00 PM)
+
+
+
Dify.AI Unveils AI Agent: Creating GPTs and Assistants with Various LLMs
From 562ca45e07502aa1d8d66ceff1ca1c66594b720d Mon Sep 17 00:00:00 2001
From: Bowen Liang
Date: Tue, 27 Feb 2024 11:14:35 +0800
Subject: [PATCH 12/36] fix weaviate delete_by_ids (#2565)
---
api/core/rag/datasource/vdb/weaviate/weaviate_vector.py | 9 +++++----
1 file changed, 5 insertions(+), 4 deletions(-)
diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
index 78033379d6..008e54085d 100644
--- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
+++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
@@ -150,10 +150,11 @@ class WeaviateVector(BaseVector):
return True
def delete_by_ids(self, ids: list[str]) -> None:
- self._client.data_object.delete(
- ids,
- class_name=self._collection_name
- )
+ for uuid in ids:
+ self._client.data_object.delete(
+ class_name=self._collection_name,
+ uuid=uuid,
+ )
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""Look up similar documents by embedding vector in Weaviate."""
From 5b953c1ef2ec1081bd4156d4708ec35f7bd395da Mon Sep 17 00:00:00 2001
From: Jyong <76649700+JohnJyong@users.noreply.github.com>
Date: Tue, 27 Feb 2024 11:39:05 +0800
Subject: [PATCH 13/36] Fix some RAG bugs (#2570)
Co-authored-by: jyong
---
.../console/datasets/data_source.py | 6 +-
api/controllers/console/datasets/datasets.py | 3 +-
.../console/datasets/datasets_document.py | 3 +-
api/core/indexing_runner.py | 3 +-
api/core/rag/datasource/retrieval_service.py | 3 +-
.../rag/extractor/entity/extract_setting.py | 1 +
api/core/rag/extractor/extract_processor.py | 3 +-
api/core/rag/extractor/html_extractor.py | 63 ++++---------------
api/core/rag/extractor/notion_extractor.py | 4 +-
api/tasks/document_indexing_sync_task.py | 3 +-
10 files changed, 33 insertions(+), 59 deletions(-)
diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py
index c0c345baea..f3e639c6ac 100644
--- a/api/controllers/console/datasets/data_source.py
+++ b/api/controllers/console/datasets/data_source.py
@@ -178,7 +178,8 @@ class DataSourceNotionApi(Resource):
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type,
- notion_access_token=data_source_binding.access_token
+ notion_access_token=data_source_binding.access_token,
+ tenant_id=current_user.current_tenant_id
)
text_docs = extractor.extract()
@@ -208,7 +209,8 @@ class DataSourceNotionApi(Resource):
notion_info={
"notion_workspace_id": workspace_id,
"notion_obj_id": page['page_id'],
- "notion_page_type": page['type']
+ "notion_page_type": page['type'],
+ "tenant_id": current_user.current_tenant_id
},
document_model=args['doc_form']
)
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index f80b4de48d..e633631c42 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -298,7 +298,8 @@ class DatasetIndexingEstimateApi(Resource):
notion_info={
"notion_workspace_id": workspace_id,
"notion_obj_id": page['page_id'],
- "notion_page_type": page['type']
+ "notion_page_type": page['type'],
+ "tenant_id": current_user.current_tenant_id
},
document_model=args['doc_form']
)
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index a990ef96ee..c383cdc762 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -455,7 +455,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
notion_info={
"notion_workspace_id": data_source_info['notion_workspace_id'],
"notion_obj_id": data_source_info['notion_page_id'],
- "notion_page_type": data_source_info['type']
+ "notion_page_type": data_source_info['type'],
+ "tenant_id": current_user.current_tenant_id
},
document_model=document.doc_form
)
diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index 68bb294a18..f5ea49bb5e 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -366,7 +366,8 @@ class IndexingRunner:
"notion_workspace_id": data_source_info['notion_workspace_id'],
"notion_obj_id": data_source_info['notion_page_id'],
"notion_page_type": data_source_info['type'],
- "document": dataset_document
+ "document": dataset_document,
+ "tenant_id": dataset_document.tenant_id
},
document_model=dataset_document.doc_form
)
diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py
index c0205d1aa9..e295e58950 100644
--- a/api/core/rag/datasource/retrieval_service.py
+++ b/api/core/rag/datasource/retrieval_service.py
@@ -39,7 +39,8 @@ class RetrievalService:
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
- 'top_k': top_k
+ 'top_k': top_k,
+ 'all_documents': all_documents
})
threads.append(keyword_thread)
keyword_thread.start()
diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py
index bc5310f7be..49cd4d0c03 100644
--- a/api/core/rag/extractor/entity/extract_setting.py
+++ b/api/core/rag/extractor/entity/extract_setting.py
@@ -12,6 +12,7 @@ class NotionInfo(BaseModel):
notion_obj_id: str
notion_page_type: str
document: Document = None
+ tenant_id: str
class Config:
arbitrary_types_allowed = True
diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py
index 7c7dc5bdae..0de7065335 100644
--- a/api/core/rag/extractor/extract_processor.py
+++ b/api/core/rag/extractor/extract_processor.py
@@ -132,7 +132,8 @@ class ExtractProcessor:
notion_workspace_id=extract_setting.notion_info.notion_workspace_id,
notion_obj_id=extract_setting.notion_info.notion_obj_id,
notion_page_type=extract_setting.notion_info.notion_page_type,
- document_model=extract_setting.notion_info.document
+ document_model=extract_setting.notion_info.document,
+ tenant_id=extract_setting.notion_info.tenant_id,
)
return extractor.extract()
else:
diff --git a/api/core/rag/extractor/html_extractor.py b/api/core/rag/extractor/html_extractor.py
index 557ea42b19..ceb5306255 100644
--- a/api/core/rag/extractor/html_extractor.py
+++ b/api/core/rag/extractor/html_extractor.py
@@ -1,13 +1,14 @@
"""Abstract interface for document loader implementations."""
-from typing import Optional
+from bs4 import BeautifulSoup
from core.rag.extractor.extractor_base import BaseExtractor
-from core.rag.extractor.helpers import detect_file_encodings
from core.rag.models.document import Document
class HtmlExtractor(BaseExtractor):
- """Load html files.
+
+ """
+ Load html files.
Args:
@@ -15,57 +16,19 @@ class HtmlExtractor(BaseExtractor):
"""
def __init__(
- self,
- file_path: str,
- encoding: Optional[str] = None,
- autodetect_encoding: bool = False,
- source_column: Optional[str] = None,
- csv_args: Optional[dict] = None,
+ self,
+ file_path: str
):
"""Initialize with file path."""
self._file_path = file_path
- self._encoding = encoding
- self._autodetect_encoding = autodetect_encoding
- self.source_column = source_column
- self.csv_args = csv_args or {}
def extract(self) -> list[Document]:
- """Load data into document objects."""
- try:
- with open(self._file_path, newline="", encoding=self._encoding) as csvfile:
- docs = self._read_from_file(csvfile)
- except UnicodeDecodeError as e:
- if self._autodetect_encoding:
- detected_encodings = detect_file_encodings(self._file_path)
- for encoding in detected_encodings:
- try:
- with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile:
- docs = self._read_from_file(csvfile)
- break
- except UnicodeDecodeError:
- continue
- else:
- raise RuntimeError(f"Error loading {self._file_path}") from e
+ return [Document(page_content=self._load_as_text())]
- return docs
+ def _load_as_text(self) -> str:
+ with open(self._file_path, "rb") as fp:
+ soup = BeautifulSoup(fp, 'html.parser')
+ text = soup.get_text()
+ text = text.strip() if text else ''
- def _read_from_file(self, csvfile) -> list[Document]:
- docs = []
- csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
- for i, row in enumerate(csv_reader):
- content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
- try:
- source = (
- row[self.source_column]
- if self.source_column is not None
- else ''
- )
- except KeyError:
- raise ValueError(
- f"Source column '{self.source_column}' not found in CSV file."
- )
- metadata = {"source": source, "row": i}
- doc = Document(page_content=content, metadata=metadata)
- docs.append(doc)
-
- return docs
+ return text
\ No newline at end of file
diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py
index f28436ffd9..38dd36361a 100644
--- a/api/core/rag/extractor/notion_extractor.py
+++ b/api/core/rag/extractor/notion_extractor.py
@@ -30,8 +30,10 @@ class NotionExtractor(BaseExtractor):
notion_workspace_id: str,
notion_obj_id: str,
notion_page_type: str,
+ tenant_id: str,
document_model: Optional[DocumentModel] = None,
- notion_access_token: Optional[str] = None
+ notion_access_token: Optional[str] = None,
+
):
self._notion_access_token = None
self._document_model = document_model
diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py
index 84e2029705..a646158dbd 100644
--- a/api/tasks/document_indexing_sync_task.py
+++ b/api/tasks/document_indexing_sync_task.py
@@ -58,7 +58,8 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type,
- notion_access_token=data_source_binding.access_token
+ notion_access_token=data_source_binding.access_token,
+ tenant_id=document.tenant_id
)
last_edited_time = loader.get_notion_last_edited_time()
From 0c0e96c55fde4f8c8623f0d32c682ac31bad5c57 Mon Sep 17 00:00:00 2001
From: zxhlyh
Date: Tue, 27 Feb 2024 11:59:54 +0800
Subject: [PATCH 14/36] fix: notion binding (#2572)
---
web/hooks/use-pay.tsx | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/web/hooks/use-pay.tsx b/web/hooks/use-pay.tsx
index 6ec4795940..cd43e8dd99 100644
--- a/web/hooks/use-pay.tsx
+++ b/web/hooks/use-pay.tsx
@@ -122,7 +122,7 @@ export const useCheckNotion = () => {
const notionCode = searchParams.get('code')
const notionError = searchParams.get('error')
const { data } = useSWR(
- canBinding
+ (canBinding && notionCode)
? `/oauth/data-source/binding/notion?code=${notionCode}`
: null,
fetchDataSourceNotionBinding,
From fc64cdee64e91f9185ec7422cd3dc3ff4215bab1 Mon Sep 17 00:00:00 2001
From: Jyong <76649700+JohnJyong@users.noreply.github.com>
Date: Tue, 27 Feb 2024 12:23:13 +0800
Subject: [PATCH 15/36] fix milvus delete by ids error (#2573)
Co-authored-by: jyong
---
api/core/rag/datasource/vdb/field.py | 2 +-
api/core/rag/datasource/vdb/milvus/milvus_vector.py | 7 ++++++-
2 files changed, 7 insertions(+), 2 deletions(-)
diff --git a/api/core/rag/datasource/vdb/field.py b/api/core/rag/datasource/vdb/field.py
index 6a594a83ca..dc400dafbb 100644
--- a/api/core/rag/datasource/vdb/field.py
+++ b/api/core/rag/datasource/vdb/field.py
@@ -7,4 +7,4 @@ class Field(Enum):
GROUP_KEY = "group_id"
VECTOR = "vector"
TEXT_KEY = "text"
- PRIMARY_KEY = " id"
+ PRIMARY_KEY = "id"
diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py
index bb12ef1b56..0fc8ed5a26 100644
--- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py
+++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py
@@ -124,7 +124,12 @@ class MilvusVector(BaseVector):
def delete_by_ids(self, doc_ids: list[str]) -> None:
- self._client.delete(collection_name=self._collection_name, pks=doc_ids)
+ result = self._client.query(collection_name=self._collection_name,
+ filter=f'metadata["doc_id"] in {doc_ids}',
+ output_fields=["id"])
+ if result:
+ ids = [item["id"] for item in result]
+ self._client.delete(collection_name=self._collection_name, pks=ids)
def delete(self) -> None:
alias = uuid4().hex
From 07fbeb6cf0656133c4f1dfc4b74dcda5a41ca3aa Mon Sep 17 00:00:00 2001
From: Rozstone <42225395+wststone@users.noreply.github.com>
Date: Tue, 27 Feb 2024 15:58:57 +0800
Subject: [PATCH 16/36] enhancement: improve client-side code (#2568)
---
web/app/components/base/button/index.tsx | 14 +++++++++-----
.../components/billing/billing-page/index.tsx | 18 +++++++-----------
.../header/account-setting/index.tsx | 12 ++++--------
3 files changed, 20 insertions(+), 24 deletions(-)
diff --git a/web/app/components/base/button/index.tsx b/web/app/components/base/button/index.tsx
index e617a5d12d..24d58c6ea5 100644
--- a/web/app/components/base/button/index.tsx
+++ b/web/app/components/base/button/index.tsx
@@ -3,13 +3,16 @@ import React from 'react'
import Spinner from '../spinner'
export type IButtonProps = {
- type?: string
+ /**
+ * The style of the button
+ */
+ type?: 'primary' | 'warning' | (string & {})
className?: string
disabled?: boolean
loading?: boolean
tabIndex?: number
children: React.ReactNode
- onClick?: MouseEventHandler
+ onClick?: MouseEventHandler
}
const Button: FC = ({
@@ -35,15 +38,16 @@ const Button: FC = ({
}
return (
-
{children}
{/* Spinner is hidden when loading is false */}
-
+
)
}
diff --git a/web/app/components/billing/billing-page/index.tsx b/web/app/components/billing/billing-page/index.tsx
index 494851ea5c..843d0995e5 100644
--- a/web/app/components/billing/billing-page/index.tsx
+++ b/web/app/components/billing/billing-page/index.tsx
@@ -1,7 +1,8 @@
'use client'
import type { FC } from 'react'
-import React, { useEffect } from 'react'
+import React from 'react'
import { useTranslation } from 'react-i18next'
+import useSWR from 'swr'
import PlanComp from '../plan'
import { ReceiptList } from '../../base/icons/src/vender/line/financeAndECommerce'
import { LinkExternal01 } from '../../base/icons/src/vender/line/general'
@@ -12,17 +13,11 @@ import { useProviderContext } from '@/context/provider-context'
const Billing: FC = () => {
const { t } = useTranslation()
const { isCurrentWorkspaceManager } = useAppContext()
- const [billingUrl, setBillingUrl] = React.useState('')
const { enableBilling } = useProviderContext()
-
- useEffect(() => {
- if (!enableBilling || !isCurrentWorkspaceManager)
- return
- (async () => {
- const { url } = await fetchBillingUrl()
- setBillingUrl(url)
- })()
- }, [isCurrentWorkspaceManager])
+ const { data: billingUrl } = useSWR(
+ (!enableBilling || !isCurrentWorkspaceManager) ? null : ['/billing/invoices'],
+ () => fetchBillingUrl().then(data => data.url),
+ )
return (
@@ -39,4 +34,5 @@ const Billing: FC = () => {
)
}
+
export default React.memo(Billing)
diff --git a/web/app/components/header/account-setting/index.tsx b/web/app/components/header/account-setting/index.tsx
index a83542ef05..d0f5db243a 100644
--- a/web/app/components/header/account-setting/index.tsx
+++ b/web/app/components/header/account-setting/index.tsx
@@ -138,16 +138,12 @@ export default function AccountSetting({
]
const scrollRef = useRef(null)
const [scrolled, setScrolled] = useState(false)
- const scrollHandle = (e: Event) => {
- if ((e.target as HTMLDivElement).scrollTop > 0)
- setScrolled(true)
-
- else
- setScrolled(false)
- }
useEffect(() => {
const targetElement = scrollRef.current
-
+ const scrollHandle = (e: Event) => {
+ const userScrolled = (e.target as HTMLDivElement).scrollTop > 0
+ setScrolled(userScrolled)
+ }
targetElement?.addEventListener('scroll', scrollHandle)
return () => {
targetElement?.removeEventListener('scroll', scrollHandle)
From ac96d192a65e5e8b537b1eb7d65a4750f6dfed16 Mon Sep 17 00:00:00 2001
From: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
Date: Tue, 27 Feb 2024 15:59:11 +0800
Subject: [PATCH 17/36] fix: parameter type handling in API tool and parser
(#2574)
---
api/core/tools/tool/api_tool.py | 2 +-
api/core/tools/utils/parser.py | 47 +++++++++++++++++++++++++++++----
2 files changed, 43 insertions(+), 6 deletions(-)
diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py
index f6914d3473..2a1ee92e78 100644
--- a/api/core/tools/tool/api_tool.py
+++ b/api/core/tools/tool/api_tool.py
@@ -200,7 +200,7 @@ class ApiTool(Tool):
# replace path parameters
for name, value in path_params.items():
- url = url.replace(f'{{{name}}}', value)
+ url = url.replace(f'{{{name}}}', f'{value}')
# parse http body data if needed, for GET/HEAD/OPTIONS/TRACE, the body is ignored
if 'Content-Type' in headers:
diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py
index 91c18be3f5..889316c235 100644
--- a/api/core/tools/utils/parser.py
+++ b/api/core/tools/utils/parser.py
@@ -1,4 +1,6 @@
+import re
+import uuid
from json import loads as json_loads
from requests import get
@@ -46,7 +48,7 @@ class ApiBasedToolSchemaParser:
parameters = []
if 'parameters' in interface['operation']:
for parameter in interface['operation']['parameters']:
- parameters.append(ToolParameter(
+ tool_parameter = ToolParameter(
name=parameter['name'],
label=I18nObject(
en_US=parameter['name'],
@@ -61,7 +63,14 @@ class ApiBasedToolSchemaParser:
form=ToolParameter.ToolParameterForm.LLM,
llm_description=parameter.get('description'),
default=parameter['schema']['default'] if 'schema' in parameter and 'default' in parameter['schema'] else None,
- ))
+ )
+
+ # check if there is a type
+ typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter)
+ if typ:
+ tool_parameter.type = typ
+
+ parameters.append(tool_parameter)
# create tool bundle
# check if there is a request body
if 'requestBody' in interface['operation']:
@@ -80,13 +89,14 @@ class ApiBasedToolSchemaParser:
root = root[ref]
# overwrite the content
interface['operation']['requestBody']['content'][content_type]['schema'] = root
+
# parse body parameters
if 'schema' in interface['operation']['requestBody']['content'][content_type]:
body_schema = interface['operation']['requestBody']['content'][content_type]['schema']
required = body_schema['required'] if 'required' in body_schema else []
properties = body_schema['properties'] if 'properties' in body_schema else {}
for name, property in properties.items():
- parameters.append(ToolParameter(
+ tool = ToolParameter(
name=name,
label=I18nObject(
en_US=name,
@@ -101,7 +111,14 @@ class ApiBasedToolSchemaParser:
form=ToolParameter.ToolParameterForm.LLM,
llm_description=property['description'] if 'description' in property else '',
default=property['default'] if 'default' in property else None,
- ))
+ )
+
+ # check if there is a type
+ typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
+ if typ:
+ tool.type = typ
+
+ parameters.append(tool)
# check if parameters is duplicated
parameters_count = {}
@@ -119,7 +136,11 @@ class ApiBasedToolSchemaParser:
path = interface['path']
if interface['path'].startswith('/'):
path = interface['path'][1:]
- path = path.replace('/', '_')
+ # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
+ path = re.sub(r'[^a-zA-Z0-9_-]', '', path)
+ if not path:
+ path = str(uuid.uuid4())
+
interface['operation']['operationId'] = f'{path}_{interface["method"]}'
bundles.append(ApiBasedToolBundle(
@@ -134,7 +155,23 @@ class ApiBasedToolSchemaParser:
))
return bundles
+
+ @staticmethod
+ def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType:
+ parameter = parameter or {}
+ typ = None
+ if 'type' in parameter:
+ typ = parameter['type']
+ elif 'schema' in parameter and 'type' in parameter['schema']:
+ typ = parameter['schema']['type']
+ if typ == 'integer' or typ == 'number':
+ return ToolParameter.ToolParameterType.NUMBER
+ elif typ == 'boolean':
+ return ToolParameter.ToolParameterType.BOOLEAN
+ elif typ == 'string':
+ return ToolParameter.ToolParameterType.STRING
+
@staticmethod
def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]:
"""
From 920b2c2b40a017cb5c25fae9def16b101896d5c3 Mon Sep 17 00:00:00 2001
From: Jyong <76649700+JohnJyong@users.noreply.github.com>
Date: Tue, 27 Feb 2024 17:30:52 +0800
Subject: [PATCH 18/36] Fix/hit test tsne issue (#2581)
Co-authored-by: jyong
---
api/celerybeat-schedule.db | Bin 16384 -> 0 bytes
api/core/features/annotation_reply.py | 2 +-
api/core/rag/datasource/retrieval_service.py | 4 ++--
api/services/hit_testing_service.py | 5 +++--
4 files changed, 6 insertions(+), 5 deletions(-)
delete mode 100644 api/celerybeat-schedule.db
diff --git a/api/celerybeat-schedule.db b/api/celerybeat-schedule.db
deleted file mode 100644
index b8c01de27bfe7ea04f1dd868cec4935ef336f2b5..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001
literal 16384
zcmeI%J#W)M7zglk>f!`QTojc7@c|m78dNo6>rydA=}>~icDmRP*lOxV@fkIhEM?+F
zx~ENo1C0EEPX#K4Zk%h}inp%k`MMgNm7_ax8R_xN|`DeS_kV2srmv)?kd
zVnTMAG0O~jXZ12L`QnGAa$GZ`oyR9}_Q8yK%je{M;jLbjy6|POAOs))0SG_<0uX=z
z1Rwwb2teT62-Mj(_lx`4{p7xP-?*>cL367)Xr7z$Q78l;009U<00Izz00bZafio4D
z*(VRmKSDFTrmp!T5;3R!Aq7DcKjgmfL*h~-ds|{FA(gYPz8m9>b+|(oz
zlF3h^(Edd@H?RIgm^Z6Ln3vLFqkP=;mUX?Y!&dQh8(}+K?Xmv-+WeBQRvRb$J&FTf
zY(P5JdAXSdB=Qjg4Oi6}AWu2CL*wcPbKyxZF2{1Hu(=pg2NW2r$3a74k(-tpwo
znZ77k90Cx400bZa0SG_<0uX=z1Rwx`g$R88js^h;KmY;|fB*y_009U<00Izz!2d4r
E0}denXaE2J
diff --git a/api/core/features/annotation_reply.py b/api/core/features/annotation_reply.py
index e1b64cf73f..fd516e465f 100644
--- a/api/core/features/annotation_reply.py
+++ b/api/core/features/annotation_reply.py
@@ -59,7 +59,7 @@ class AnnotationReplyFeature:
documents = vector.search_by_vector(
query=query,
- k=1,
+ top_k=1,
score_threshold=score_threshold,
filter={
'group_id': [dataset.id]
diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py
index e295e58950..0f9c753056 100644
--- a/api/core/rag/datasource/retrieval_service.py
+++ b/api/core/rag/datasource/retrieval_service.py
@@ -101,7 +101,7 @@ class RetrievalService:
documents = keyword.search(
query,
- k=top_k
+ top_k=top_k
)
all_documents.extend(documents)
@@ -121,7 +121,7 @@ class RetrievalService:
documents = vector.search_by_vector(
query,
search_type='similarity_score_threshold',
- k=top_k,
+ top_k=top_k,
score_threshold=score_threshold,
filter={
'group_id': [dataset.id]
diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py
index 568974b74f..6d5a0537d3 100644
--- a/api/services/hit_testing_service.py
+++ b/api/services/hit_testing_service.py
@@ -133,8 +133,9 @@ class HitTestingService:
if embedding_length <= 1:
return [{'x': 0, 'y': 0}]
- concatenate_data = np.array(embeddings).reshape(embedding_length, -1)
- # concatenate_data = np.concatenate(embeddings)
+ noise = np.random.normal(0, 1e-4, np.array(embeddings).shape)
+ concatenate_data = np.array(embeddings) + noise
+ concatenate_data = concatenate_data.reshape(embedding_length, -1)
perplexity = embedding_length / 2 + 1
if perplexity >= embedding_length:
From 29ab244de601da5c20e10c67b1b15cb57d85169d Mon Sep 17 00:00:00 2001
From: Bowen Liang
Date: Tue, 27 Feb 2024 18:05:48 +0800
Subject: [PATCH 19/36] fix: correct the parent class of CacheEmbedding (#2578)
---
api/core/embedding/cached_embedding.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/api/core/embedding/cached_embedding.py b/api/core/embedding/cached_embedding.py
index a86afd817a..7498a07559 100644
--- a/api/core/embedding/cached_embedding.py
+++ b/api/core/embedding/cached_embedding.py
@@ -3,12 +3,12 @@ import logging
from typing import Optional, cast
import numpy as np
-from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.rag.datasource.entity.embedding import Embeddings
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper
From 3a34370422cda5c76227d7bb92db0d78439fa942 Mon Sep 17 00:00:00 2001
From: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
Date: Tue, 27 Feb 2024 19:15:07 +0800
Subject: [PATCH 20/36] =?UTF-8?q?fix:=20convert=20tool=20messages=20into?=
=?UTF-8?q?=20user=20messages=20in=20react=20mode=20and=20fill=20=E2=80=A6?=
=?UTF-8?q?=20(#2584)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
api/core/features/assistant_base_runner.py | 64 ++++++++++++----------
api/core/features/assistant_cot_runner.py | 13 ++++-
2 files changed, 45 insertions(+), 32 deletions(-)
diff --git a/api/core/features/assistant_base_runner.py b/api/core/features/assistant_base_runner.py
index c4a5767b04..2a4ae7e135 100644
--- a/api/core/features/assistant_base_runner.py
+++ b/api/core/features/assistant_base_runner.py
@@ -606,36 +606,42 @@ class BaseAssistantApplicationRunner(AppRunner):
for message in messages:
result.append(UserPromptMessage(content=message.query))
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
- for agent_thought in agent_thoughts:
- tools = agent_thought.tool
- if tools:
- tools = tools.split(';')
- tool_calls: list[AssistantPromptMessage.ToolCall] = []
- tool_call_response: list[ToolPromptMessage] = []
- tool_inputs = json.loads(agent_thought.tool_input)
- for tool in tools:
- # generate a uuid for tool call
- tool_call_id = str(uuid.uuid4())
- tool_calls.append(AssistantPromptMessage.ToolCall(
- id=tool_call_id,
- type='function',
- function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+ if agent_thoughts:
+ for agent_thought in agent_thoughts:
+ tools = agent_thought.tool
+ if tools:
+ tools = tools.split(';')
+ tool_calls: list[AssistantPromptMessage.ToolCall] = []
+ tool_call_response: list[ToolPromptMessage] = []
+ tool_inputs = json.loads(agent_thought.tool_input)
+ for tool in tools:
+ # generate a uuid for tool call
+ tool_call_id = str(uuid.uuid4())
+ tool_calls.append(AssistantPromptMessage.ToolCall(
+ id=tool_call_id,
+ type='function',
+ function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+ name=tool,
+ arguments=json.dumps(tool_inputs.get(tool, {})),
+ )
+ ))
+ tool_call_response.append(ToolPromptMessage(
+ content=agent_thought.observation,
name=tool,
- arguments=json.dumps(tool_inputs.get(tool, {})),
- )
- ))
- tool_call_response.append(ToolPromptMessage(
- content=agent_thought.observation,
- name=tool,
- tool_call_id=tool_call_id,
- ))
+ tool_call_id=tool_call_id,
+ ))
- result.extend([
- AssistantPromptMessage(
- content=agent_thought.thought,
- tool_calls=tool_calls,
- ),
- *tool_call_response
- ])
+ result.extend([
+ AssistantPromptMessage(
+ content=agent_thought.thought,
+ tool_calls=tool_calls,
+ ),
+ *tool_call_response
+ ])
+ if not tools:
+ result.append(AssistantPromptMessage(content=agent_thought.thought))
+ else:
+ if message.answer:
+ result.append(AssistantPromptMessage(content=message.answer))
return result
\ No newline at end of file
diff --git a/api/core/features/assistant_cot_runner.py b/api/core/features/assistant_cot_runner.py
index aa4a6797cd..809834c8cb 100644
--- a/api/core/features/assistant_cot_runner.py
+++ b/api/core/features/assistant_cot_runner.py
@@ -154,7 +154,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
thought='',
action_str='',
observation='',
- action=None
+ action=None,
)
# publish agent thought if it's first iteration
@@ -469,7 +469,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
thought=message.content,
action_str='',
action=None,
- observation=None
+ observation=None,
)
if message.tool_calls:
try:
@@ -484,7 +484,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
elif isinstance(message, ToolPromptMessage):
if current_scratchpad:
current_scratchpad.observation = message.content
-
+
return agent_scratchpad
def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"],
@@ -607,6 +607,13 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
prompt_message.content = system_message
overridden = True
break
+
+ # convert tool prompt messages to user prompt messages
+ for idx, prompt_message in enumerate(prompt_messages):
+ if isinstance(prompt_message, ToolPromptMessage):
+ prompt_messages[idx] = UserPromptMessage(
+ content=prompt_message.content
+ )
if not overridden:
prompt_messages.insert(0, SystemPromptMessage(
From f1cbd55007c451b02d54c17bca6704a502b26f1b Mon Sep 17 00:00:00 2001
From: Rozstone <42225395+wststone@users.noreply.github.com>
Date: Tue, 27 Feb 2024 19:16:22 +0800
Subject: [PATCH 21/36] =?UTF-8?q?enhancement:=20skip=20fetching=20to=20imp?=
=?UTF-8?q?rove=20user=20experience=20when=20switching=20=E2=80=A6=20(#258?=
=?UTF-8?q?0)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
web/app/components/base/logo/logo-site.tsx | 4 +-
web/app/components/explore/app-list/index.tsx | 82 ++++++++++++-------
web/app/components/header/index.tsx | 2 +-
web/models/explore.ts | 2 +-
web/service/explore.ts | 6 +-
5 files changed, 61 insertions(+), 35 deletions(-)
diff --git a/web/app/components/base/logo/logo-site.tsx b/web/app/components/base/logo/logo-site.tsx
index 9d9bccfaf8..65569c8c99 100644
--- a/web/app/components/base/logo/logo-site.tsx
+++ b/web/app/components/base/logo/logo-site.tsx
@@ -1,15 +1,17 @@
import type { FC } from 'react'
+import classNames from 'classnames'
type LogoSiteProps = {
className?: string
}
+
const LogoSite: FC = ({
className,
}) => {
return (
)
diff --git a/web/app/components/explore/app-list/index.tsx b/web/app/components/explore/app-list/index.tsx
index 6ceabe23bf..d3e90e6212 100644
--- a/web/app/components/explore/app-list/index.tsx
+++ b/web/app/components/explore/app-list/index.tsx
@@ -1,13 +1,14 @@
'use client'
import type { FC } from 'react'
-import React, { useEffect } from 'react'
+import React from 'react'
import { useRouter } from 'next/navigation'
import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
+import useSWR from 'swr'
import Toast from '../../base/toast'
import s from './style.module.css'
import ExploreContext from '@/context/explore-context'
-import type { App, AppCategory } from '@/models/explore'
+import type { App } from '@/models/explore'
import Category from '@/app/components/explore/category'
import AppCard from '@/app/components/explore/app-card'
import { fetchAppDetail, fetchAppList } from '@/service/explore'
@@ -25,34 +26,44 @@ const Apps: FC = () => {
const { isCurrentWorkspaceManager } = useAppContext()
const router = useRouter()
const { hasEditPermission } = useContext(ExploreContext)
+ const allCategoriesEn = t('explore.apps.allCategories', { lng: 'en' })
const [currCategory, setCurrCategory] = useTabSearchParams({
- defaultTab: '',
+ defaultTab: allCategoriesEn,
})
- const [allList, setAllList] = React.useState([])
- const [isLoaded, setIsLoaded] = React.useState(false)
+ const {
+ data: { categories, allList },
+ isLoading,
+ } = useSWR(
+ ['/explore/apps'],
+ () =>
+ fetchAppList().then(({ categories, recommended_apps }) => ({
+ categories,
+ allList: recommended_apps.sort((a, b) => a.position - b.position),
+ })),
+ {
+ fallbackData: {
+ categories: [],
+ allList: [],
+ },
+ },
+ )
const currList = (() => {
- if (currCategory === '')
+ if (currCategory === allCategoriesEn)
return allList
return allList.filter(item => item.category === currCategory)
})()
- const [categories, setCategories] = React.useState([])
- useEffect(() => {
- (async () => {
- const { categories, recommended_apps }: any = await fetchAppList()
- const sortedRecommendedApps = [...recommended_apps]
- sortedRecommendedApps.sort((a, b) => a.position - b.position) // position from small to big
- setCategories(categories)
- setAllList(sortedRecommendedApps)
- setIsLoaded(true)
- })()
- }, [])
-
const [currApp, setCurrApp] = React.useState(null)
const [isShowCreateModal, setIsShowCreateModal] = React.useState(false)
- const onCreate: CreateAppModalProps['onConfirm'] = async ({ name, icon, icon_background }) => {
- const { app_model_config: model_config } = await fetchAppDetail(currApp?.app.id as string)
+ const onCreate: CreateAppModalProps['onConfirm'] = async ({
+ name,
+ icon,
+ icon_background,
+ }) => {
+ const { app_model_config: model_config } = await fetchAppDetail(
+ currApp?.app.id as string,
+ )
try {
const app = await createApp({
@@ -68,36 +79,45 @@ const Apps: FC = () => {
message: t('app.newApp.appCreated'),
})
localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1')
- router.push(`/app/${app.id}/${isCurrentWorkspaceManager ? 'configuration' : 'overview'}`)
+ router.push(
+ `/app/${app.id}/${
+ isCurrentWorkspaceManager ? 'configuration' : 'overview'
+ }`,
+ )
}
catch (e) {
Toast.notify({ type: 'error', message: t('app.newApp.appCreateFailed') })
}
}
- if (!isLoaded) {
+ if (!isLoading) {
return (
-
-
+
+
)
}
return (
-
-
-
{t('explore.apps.title')}
-
{t('explore.apps.description')}
+
+
+
+ {t('explore.apps.title')}
+
+
+ {t('explore.apps.description')}
+
-
+
}
{!isMobile && <>
-
+
>}
diff --git a/web/models/explore.ts b/web/models/explore.ts
index 1cb98b2a28..c90d9ba22b 100644
--- a/web/models/explore.ts
+++ b/web/models/explore.ts
@@ -16,7 +16,7 @@ export type App = {
app_id: string
description: string
copyright: string
- privacy_policy: string
+ privacy_policy: string | null
category: AppCategory
position: number
is_listed: boolean
diff --git a/web/service/explore.ts b/web/service/explore.ts
index 60fb8b1128..bb608f7ee5 100644
--- a/web/service/explore.ts
+++ b/web/service/explore.ts
@@ -1,7 +1,11 @@
import { del, get, patch, post } from './base'
+import type { App, AppCategory } from '@/models/explore'
export const fetchAppList = () => {
- return get('/explore/apps')
+ return get<{
+ categories: AppCategory[]
+ recommended_apps: App[]
+ }>('/explore/apps')
}
export const fetchAppDetail = (id: string): Promise
=> {
From 582ba45c009f42d21c013d90ebf67e9cbedd4a9d Mon Sep 17 00:00:00 2001
From: crazywoola <100913391+crazywoola@users.noreply.github.com>
Date: Wed, 28 Feb 2024 11:27:17 +0800
Subject: [PATCH 22/36] Fix 500 error when creating from the template and the
provider is None (#2591)
---
api/controllers/console/app/app.py | 12 +++---------
1 file changed, 3 insertions(+), 9 deletions(-)
diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index 87cad07462..59a7535144 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -124,19 +124,13 @@ class AppListApi(Resource):
available_models_names = [f'{model.provider.provider}.{model.model}' for model in available_models]
provider_model = f"{model_config_dict['model']['provider']}.{model_config_dict['model']['name']}"
if provider_model not in available_models_names:
- model_manager = ModelManager()
- model_instance = model_manager.get_default_model_instance(
- tenant_id=current_user.current_tenant_id,
- model_type=ModelType.LLM
- )
-
- if not model_instance:
+ if not default_model_entity:
raise ProviderNotInitializeError(
"No Default System Reasoning Model available. Please configure "
"in the Settings -> Model Provider.")
else:
- model_config_dict["model"]["provider"] = model_instance.provider
- model_config_dict["model"]["name"] = model_instance.model
+ model_config_dict["model"]["provider"] = default_model_entity.provider
+ model_config_dict["model"]["name"] = default_model_entity.model
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
From 9b1c4f47fbc9498bd91b5a747a46966bab04ba0c Mon Sep 17 00:00:00 2001
From: Joshua <138381132+joshua20231026@users.noreply.github.com>
Date: Wed, 28 Feb 2024 12:22:57 +0800
Subject: [PATCH 23/36] feat:add mistral ai (#2594)
---
.../model_providers/mistralai/__init__.py | 0
.../mistralai/_assets/icon_l_en.png | Bin 0 -> 7064 bytes
.../mistralai/_assets/icon_s_en.png | Bin 0 -> 2359 bytes
.../mistralai/llm/_position.yaml | 5 ++
.../model_providers/mistralai/llm/llm.py | 31 +++++++++++
.../mistralai/llm/mistral-large-latest.yaml | 50 ++++++++++++++++++
.../mistralai/llm/mistral-medium-latest.yaml | 50 ++++++++++++++++++
.../mistralai/llm/mistral-small-latest.yaml | 50 ++++++++++++++++++
.../mistralai/llm/open-mistral-7b.yaml | 50 ++++++++++++++++++
.../mistralai/llm/open-mixtral-8x7b.yaml | 50 ++++++++++++++++++
.../model_providers/mistralai/mistralai.py | 30 +++++++++++
.../model_providers/mistralai/mistralai.yaml | 31 +++++++++++
12 files changed, 347 insertions(+)
create mode 100644 api/core/model_runtime/model_providers/mistralai/__init__.py
create mode 100644 api/core/model_runtime/model_providers/mistralai/_assets/icon_l_en.png
create mode 100644 api/core/model_runtime/model_providers/mistralai/_assets/icon_s_en.png
create mode 100644 api/core/model_runtime/model_providers/mistralai/llm/_position.yaml
create mode 100644 api/core/model_runtime/model_providers/mistralai/llm/llm.py
create mode 100644 api/core/model_runtime/model_providers/mistralai/llm/mistral-large-latest.yaml
create mode 100644 api/core/model_runtime/model_providers/mistralai/llm/mistral-medium-latest.yaml
create mode 100644 api/core/model_runtime/model_providers/mistralai/llm/mistral-small-latest.yaml
create mode 100644 api/core/model_runtime/model_providers/mistralai/llm/open-mistral-7b.yaml
create mode 100644 api/core/model_runtime/model_providers/mistralai/llm/open-mixtral-8x7b.yaml
create mode 100644 api/core/model_runtime/model_providers/mistralai/mistralai.py
create mode 100644 api/core/model_runtime/model_providers/mistralai/mistralai.yaml
diff --git a/api/core/model_runtime/model_providers/mistralai/__init__.py b/api/core/model_runtime/model_providers/mistralai/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/mistralai/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/mistralai/_assets/icon_l_en.png
new file mode 100644
index 0000000000000000000000000000000000000000..f019b1edceac20bb838c9be16985b4e8767344d8
GIT binary patch
literal 7064
zcmV;J8)xKFNk&GH8vp=TMM6+kP&il$0000G00030002<{06|PpNIwk#009|-Y#V^i
z-}q;U=zp5>aRdTM`-_hd5vl9Awr!=k7pDpq70?6Yw3}fHavk}NeE$tZOuz(q!~Xx!
z9Qpqq&)2)E5hr7;cgLR{+laHaZQHhO+qN;z*lUe4o-~q9U-f>yZdBEUPIvnF+>3|_
zxNU&{OKE9=1QIm=_QP_b*U;oWZ&om`7xVp?@5Q`N{C|JU`!pYryu#-Jxoh4R^S*f^
zNjz-_X0;r(q*?_i^;EUV@Bp(M4QBnQz^se#=pm^Vxtk=wy$cFKKmODt-R&RGk<*>=H|y9
zm92-+*jGG`JT6;J)x-@~U!O&EvHsq!(57$G<-_)fNE^>JZ&~}3PDfm%G+8P9N$Kby
z5!&|hI~AZ9{lJ1idHM1z0KeO~ix@7VLE8c%O;mm32~P&N-%rhB_;b4fTXxxE9Q8qZ
zUA9>P<{RF6FwNJ0?mZJqokzTY-tXLFn)|<&E^puX8I&O8Xa=66G?&j{u#2fLwSEO-ovNtH6G0rjqC@EB0Je5qpoRd(9l-p}uOkP6
z&|a^<=mlrPSf^6}Y6u|gW?-UI72qMPSh42oKYG%gikjt<7f&nX<>}u@RSmek<6evt
zv_3TqHTu(1H;o!${7?4SSPF8iby-uO~ETc;)qh|ysnPXLkb
z4C3U5>*#MJ@BI{X(Ii
z9Ohnp`W
z?egGvEr59JBLc8|^uD37cfHHvF30@V@*LuWC^rA=P)>p4m-4rfyVYRZ4FFRRY~Dn(
zv47-bSAP(T@r$2#Lity77{8YK9RSnU(X1K-M0MlGKK`=GblxJdswb89-r0MFc=0#Lj!%NQ%%|m;(|}RAo(Nu0_x+2@3xM
zDcE^<0#s=YaM5+w93!nMJBm324I++=_3KCrwED4+h!NIA~WDU(@G@Sn>$ck8HFqe`(OAa?#f9Q6SB
zV~M8$5cbj^o)+0~-|5-ge)pF1%m*Xx@-b(pgAag^z6>b#IvR1_R^52@;c)PM{wC&3
z8wBii>89QHe+d4YlL!V9;~U=djDoBpe0beiuXvyNfL1&kz%6p|!SgMX^6CW{E=$8x
zXTC7<(Ap2#+%ZtGV(RBl_E9docJ~1h+WOohpY`CSc%*Q9?~}h^J_v|ulw_#*(n}TK
zv7RWZpX4q43UgMG`y?O~PpMFVB5i-Lh#z59`3o%@6d}9d$#;9q-S~(?k{*t5@(Ppe)4=8
z|9qkcXV7XKxNbpi0Wkx`j1C`6T@Y=?X=RMQSApG_j}b;OyLYv4Z9
z#BAvVT{;mpr4SH+5}`sk7ajJ#$0RaUNV4*8kG{sd4Tp$sa3TP*{qE=Ms+1pp)W-W;
zmjCUq1mK)Ewm5RJ?o|SUKxw9hW{G5lQd_d5LV(HW)<{6;md79Sj8%Km_|LC>=yr43
z&5}`cW0s5rAie4Af;9B5HEUIb&g!HLgl@m&WGlH}*-=P|?NH2$&b#)UPLr7coOkUx
zowiMp`OQ(eE0iY*LOSnByWjVqy+)`0b=4(zVoqzt=Z=|{O*+v~@wsEBWs{j8v+>Xo
zNAUDvkD>qIG{)oh=t%AaPg$0?htRALTuYb1wKN8|UqD~OO`CpAb8@FW*mOQ!2b<2P
zOI=Ch6*NXea4qeBXS)=$qzS?szy#rdte5Z_IHo4gbT0$KTR;YcH-rfYHvt~Bnt*V?
zv|WT|38!nPHF7WAI(?2zkCExtX>@vY+MVrE%wwZL0}3=#0wQ*_iiY03MXHHjBU-E{
zfl=;Y-7Zj!3`KW^wXm&*sfSW6-f2ipswRT4@z}9ahQR3B(;wa>0#ZKxwIM*!Rp=eR
zP8meJ?ldl&@?;P4Lyg`EJO4x!f7_C2Z5v=t+`jJ!M#nXr)r)PRYSK?
z^)$z7u2Jm%>eBI{@BB99n)xA5tJF={*4Ixd=8(`Fld4tEqyCF^K?4Uon_h>QQ^g$X
ztjbja9kbJI3Lx0laRUj6AT$>?xvDb;evJ4PY02zYplg2Q-TAIIM;G*j^7Cg1qcX6
z7dUulOUsqNARzAami7i1EsF%EBp|eK7YL2p066fzz{cQwtSVN%S?mBNngt+C01*u!
zqaL)X!GpTOPy!HuFzSi^8*5UPxIl4sRb*Bx?+6S!azp=vK89idnKfWr&qe?E_kF~X
z`(JIoqch8z-{W}1Mc2Xf@1E(oT=^kmJ{r+KRt{(5Za*&Awhzf!OP-eu&{;`{yWEeu
zjN^8@NimO-?{)@>oE=b+bSURSdkHGqw^{pO>d9#K
zEKnw9$cq&~NtvQvb*2}R6SJFb$t
zlBC%R09H^qASw<30PsElodGHU0a5@y5d?uigBz^?U?7|Ur!&~`1H3kI@IT)_5Irkf8;;7{>cAq`UC!J^=sW%_mA7ZaDUbR?f=t%U_Yn#
zEdHhcC)5Y<|K@+_f1>-o|D*p&{a4(F_J8vq=6%3Fl>a~VsQ*9gQT-$Rr?5ZqKk0s#
z|MY*v{2PAK|K0yZ@GbnS`j_7)$OoVYsORtw$Pe(o5x&I!J^tJETKhZAnWZ0;{FnIM
z+kC$+AN@c4ALVzT2lqdW9`E-W|E=|t!gJ}~f`8FJ-+z4eUH@0j&$|A_|IdHd{lItw
z^b-AF6TAz=s_Go6Y++o;YDItb4ka$tk>^JL)J+TK(JCfKaBhSjCAh?_!mN_9Px}+O
zFY|&4*0QOW83D3ie}|K=VURrFeOAOKPp--N2Z1jg&(WN!XY|$^2bESH7Hcz=yb?m=R)ZOkbzB&jZ;myMoxr}PTdCz`qMdB0^iaED=~QQ^)12nx%Q&$9(`r0n
z>IemHpaA~=>Z6-e?-J%7>VhQ+IX_oU6z(T5g5sti3w^jZvn2vw2lM9VG*3w%Ev7a0X@f0X70%x`p5m|R)|A3)CyV;YO9Hj<9Z
zWbcRUdgvleJ^#@nYb`39xQ3zR^;CA<{8Y*_p#S+{GzaIiI9+)*zYuZz%tue>{Mp~`
zxARhEW&@u
zxNBE?#kzvFj^xa6W=@;qnBUKD@3p@cL*G@@G=3?^{%JdVar(^uWJx_}FXxkrWA}RN
zWQ53Vp%>LiRvXizJ~GTN?&v7ZlWWur#e8IPI-EdowBSfhKpnSGtyKM*%@zJj|59A~
zC%Bxt2eNPgE2lypqs-))uNIgkUz4SeV%*w)t7;E=L9;#MYzrRnDf9*ERS#xuA{
z(Mv)oh(<3y_-{Os3A5BdUl|?yC);ju?7i5Iua2(y8KZ}qB)jaxnrD0ovHLA
zDzSF)7QM(<3tOKEz=ZLI=Mw+fRBM&8OQm0*HO`5qWPOu=3sG5^{|n5SyZ{rUb2Lha
zeTCTyFQrEuR}ry-NrqSug4nRpIE7d}12*Kb@zt0r$R-=d59l(YU-XlvA0c#EL3+2m
zy6X?d7>UE`$s5?iy%3Y0emqmHHX#ghA3wU%463yGnKvp
zhL?yjKGbHApOIMW>yVwIyn;j|@EBRy3Dlsi!rRtgYiJ;MBU>32XZBpYW<%lKZX&sG
z<`QlL*@ZoPgL~OZ-Up>l{6{1eg2bx3HH9>K`s&^c2x^WhSntG2AwY`QQNa09HNx9wp3X32dfc$l$M-K5N7yR#SB7l#=4?M!k56HHe0>CklLZxv
z+knoC)Y~-s%#JByNtzv^ETZWTIF>)J?I<&R{vIl+`@F&zW8w^a$_1-$HPdy?#66xD
z7X5~q$@A?koJer+j=qGq6X^`WX#l)Ki$gnPIcftK1m%s;X8hdHg&9~rs!Mp?3W@68
zZRfxo4~Q2}ve6;o8(-a=Sm9_3{~TW@_NTv9se$L8RexAH&J~tPeaGu`YX~XL9YV|W
z3HE^UX+g~AjhjH^=6({;T;)0Zy`f3}f%*kS9w*q4%GPlpXmhnPvVRTnIqH|K+2j7g
zUi8l2^{N|V&qv0Sd-$~Qt4e$PbPJ%720FrZQ0>1p4ObUiA=G?KY*4^iE^SAgxsSE2lsi
z0u+7MiLE{iT2?U=j>pzkWqiHz###UNm*GOxrZFV1+}>-K6OXu0@Hbqb@N6AnN?rd-
z)p_oML=rHS$4`Tq)2n(N$Y8g`tu?Gv-4K4tVGlZz2-A!(-miGi?4T
z&!z|X8fz6800T=+Vag0R+a@}$v6Lm|BdPS@&^U&7RO%zcK9o?UiQHVbzH$8LJo8d=
z;hNKC+dX&!9T?xpE(2xSr}z%NDydQYBCoMBhixKlq-@$GHaxtIM;FE|MzKSWr3jF$%eBj;xMXc8}YGJb}^}I
zn_5*pFkjLzmcWJM(T5na_>!VMIhn4U|;zdE@?M?yBt63GI|bm653*#CL!FE7;vWV$$C}`DtbsQeEFvl&d=bRi!4o1TCCL8QD7-9z@pY|U#Qx;wagw1
z)Sa$<{)JvAW9>kS!}kM`PS!)8CYP`12<_ajhaNcB6O)b)^i+|Y;s5QjZ43Y62Kmqb
z%NchqlB!eU9S<(v1ROf8uuCxgYyxZ_H3HvK?5q%rp
z&c@Oa@)TQ$qIP(rl84EsaBSOKG{=m6#S&M**l$R`;udYw^nG~f#FXWg)Z}%8^>lxe{Q0{{3nymu(+y;^)B=Jw4}Sjc*}@AC9$IoqY9-
z#YeFu2J!#nDB^`MYv1GFnrrmj&*pU-dg~ubO^l9?0oJ=pdd=W*q+8leyVi{JH=4Qc
z8$t%z&A8(T@Y=SF`1;+wC8{#qFY(paV!cVO)6Zwr}#su{^LJlblmhr{gC0?VPU{>1Zw)
z+*$hzRuxU-2T$FM9>xAqz`V~tzL^{_u3lg%zu@1-n-fmv1&7dOpwPigf{D=lm+N?w
zaG5xi81MsO?mbIuaYzrW#S&{MF;c8ZB>b;e$aF-nGRSiZzD`5QOoRUw(2|gyC@V$B?tI~xE^500y3vCu(q&q
z&EaJ^!peJ_iR-E$+kaNUiP}v6x!AUvuupbixZ}Wj*_~lZ48x8fhW{SC|NR*5#xhJQ
zW%?h^@HU0{MiIk{Vut^D4FBpGZgn%`Qbm!|-=8)13`$|Iae1s)V%w?O-ek@(X5g
zcy=R=fq|pR)5S5Q;?~=%2f3OYMA|NDntQgnC~2tf{U5w-ookVJ^W%zh!aUReCkHVB
zwbBGUu<#Q+Y%Rb2-?iPQPU5?wl}i?K%Ysqt3y!-#TKX>Z`oG|aec|N(B2TxZMXzML
z+Y6`iY#?JHH&FEW45NH5FlOP6wOcJ*Oqg@vScYWlPT_`;t$hgqj
zcj0an$kM#t3$5?2fovBCiGJk7a1&vZAO?U;0y-3ofKGfYY~^Hb2XZUWv%9x~EWInW
zko)dhknJF^$vzgwa1)wIJ^r9T0r?v2eTY9m;Rp6U$ViA+K;Z`tNRR{GT?6Vtax~aF
zWRo<%3LbuRE$8cLo!wL3E8N}uuUP>SBS5c%V*wsspkM+an4e&ghOP?i9Z+Op(+}4S
zjdwh%ppgqo9k^8CmVs#oMI|uGVPS*A@xUM=qzarM@OvHSUNX
zC`yL1p>roW$#X}(QeHI0-9K4#y;=3b|5xvQ=G?ohbH}t#ugpHZN&@ClbLsnxad+bD!|BtAoMi%JqtBr5#fS0$|phcA|yLvO>#!y^oA{At`|cUC6L4=!MPAm7NMG_
zz(pcf@4~ZJ3Z%S))6;ngEl!ka7rM!m}^Put`
zT#iACMU-;9JR1}=h-wU6Y5`RMMe&&=F&|VQqm=i^^(8zlgKIfZbq8|5-L;@X99+vm
mDk+dCib;GlsrW&LI>^cE&+)o@=v8Dri0kR<=d#Wzp$Pz6^0S-(
literal 0
HcmV?d00001
diff --git a/api/core/model_runtime/model_providers/mistralai/llm/_position.yaml b/api/core/model_runtime/model_providers/mistralai/llm/_position.yaml
new file mode 100644
index 0000000000..5e74dc5dfe
--- /dev/null
+++ b/api/core/model_runtime/model_providers/mistralai/llm/_position.yaml
@@ -0,0 +1,5 @@
+- open-mistral-7b
+- open-mixtral-8x7b
+- mistral-small-latest
+- mistral-medium-latest
+- mistral-large-latest
diff --git a/api/core/model_runtime/model_providers/mistralai/llm/llm.py b/api/core/model_runtime/model_providers/mistralai/llm/llm.py
new file mode 100644
index 0000000000..01ed8010de
--- /dev/null
+++ b/api/core/model_runtime/model_providers/mistralai/llm/llm.py
@@ -0,0 +1,31 @@
+from collections.abc import Generator
+from typing import Optional, Union
+
+from core.model_runtime.entities.llm_entities import LLMResult
+from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
+from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
+
+
+class MistralAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
+ def _invoke(self, model: str, credentials: dict,
+ prompt_messages: list[PromptMessage], model_parameters: dict,
+ tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+ stream: bool = True, user: Optional[str] = None) \
+ -> Union[LLMResult, Generator]:
+
+ self._add_custom_parameters(credentials)
+
+ # mistral dose not support user/stop arguments
+ stop = []
+ user = None
+
+ return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+
+ def validate_credentials(self, model: str, credentials: dict) -> None:
+ self._add_custom_parameters(credentials)
+ super().validate_credentials(model, credentials)
+
+ @staticmethod
+ def _add_custom_parameters(credentials: dict) -> None:
+ credentials['mode'] = 'chat'
+ credentials['endpoint_url'] = 'https://api.mistral.ai/v1'
diff --git a/api/core/model_runtime/model_providers/mistralai/llm/mistral-large-latest.yaml b/api/core/model_runtime/model_providers/mistralai/llm/mistral-large-latest.yaml
new file mode 100644
index 0000000000..b729012c40
--- /dev/null
+++ b/api/core/model_runtime/model_providers/mistralai/llm/mistral-large-latest.yaml
@@ -0,0 +1,50 @@
+model: mistral-large-latest
+label:
+ zh_Hans: mistral-large-latest
+ en_US: mistral-large-latest
+model_type: llm
+features:
+ - agent-thought
+model_properties:
+ context_size: 32000
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ default: 0.7
+ min: 0
+ max: 1
+ - name: top_p
+ use_template: top_p
+ default: 1
+ min: 0
+ max: 1
+ - name: max_tokens
+ use_template: max_tokens
+ default: 1024
+ min: 1
+ max: 8000
+ - name: safe_prompt
+    default: false
+ type: boolean
+ help:
+ en_US: Whether to inject a safety prompt before all conversations.
+ zh_Hans: 是否开启提示词审查
+ label:
+ en_US: SafePrompt
+ zh_Hans: 提示词审查
+ - name: random_seed
+ type: int
+ help:
+ en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
+ zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
+ label:
+ en_US: RandomSeed
+ zh_Hans: 随机数种子
+ default: 0
+ min: 0
+ max: 2147483647
+pricing:
+ input: '0.008'
+ output: '0.024'
+ unit: '0.001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/mistralai/llm/mistral-medium-latest.yaml b/api/core/model_runtime/model_providers/mistralai/llm/mistral-medium-latest.yaml
new file mode 100644
index 0000000000..6e586b4843
--- /dev/null
+++ b/api/core/model_runtime/model_providers/mistralai/llm/mistral-medium-latest.yaml
@@ -0,0 +1,50 @@
+model: mistral-medium-latest
+label:
+ zh_Hans: mistral-medium-latest
+ en_US: mistral-medium-latest
+model_type: llm
+features:
+ - agent-thought
+model_properties:
+ context_size: 32000
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ default: 0.7
+ min: 0
+ max: 1
+ - name: top_p
+ use_template: top_p
+ default: 1
+ min: 0
+ max: 1
+ - name: max_tokens
+ use_template: max_tokens
+ default: 1024
+ min: 1
+ max: 8000
+ - name: safe_prompt
+    default: false
+ type: boolean
+ help:
+ en_US: Whether to inject a safety prompt before all conversations.
+ zh_Hans: 是否开启提示词审查
+ label:
+ en_US: SafePrompt
+ zh_Hans: 提示词审查
+ - name: random_seed
+ type: int
+ help:
+ en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
+ zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
+ label:
+ en_US: RandomSeed
+ zh_Hans: 随机数种子
+ default: 0
+ min: 0
+ max: 2147483647
+pricing:
+ input: '0.0027'
+ output: '0.0081'
+ unit: '0.001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/mistralai/llm/mistral-small-latest.yaml b/api/core/model_runtime/model_providers/mistralai/llm/mistral-small-latest.yaml
new file mode 100644
index 0000000000..4e7e6147f5
--- /dev/null
+++ b/api/core/model_runtime/model_providers/mistralai/llm/mistral-small-latest.yaml
@@ -0,0 +1,50 @@
+model: mistral-small-latest
+label:
+ zh_Hans: mistral-small-latest
+ en_US: mistral-small-latest
+model_type: llm
+features:
+ - agent-thought
+model_properties:
+ context_size: 32000
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ default: 0.7
+ min: 0
+ max: 1
+ - name: top_p
+ use_template: top_p
+ default: 1
+ min: 0
+ max: 1
+ - name: max_tokens
+ use_template: max_tokens
+ default: 1024
+ min: 1
+ max: 8000
+ - name: safe_prompt
+    default: false
+ type: boolean
+ help:
+ en_US: Whether to inject a safety prompt before all conversations.
+ zh_Hans: 是否开启提示词审查
+ label:
+ en_US: SafePrompt
+ zh_Hans: 提示词审查
+ - name: random_seed
+ type: int
+ help:
+ en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
+ zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
+ label:
+ en_US: RandomSeed
+ zh_Hans: 随机数种子
+ default: 0
+ min: 0
+ max: 2147483647
+pricing:
+ input: '0.002'
+ output: '0.006'
+ unit: '0.001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/mistralai/llm/open-mistral-7b.yaml b/api/core/model_runtime/model_providers/mistralai/llm/open-mistral-7b.yaml
new file mode 100644
index 0000000000..30454f7df2
--- /dev/null
+++ b/api/core/model_runtime/model_providers/mistralai/llm/open-mistral-7b.yaml
@@ -0,0 +1,50 @@
+model: open-mistral-7b
+label:
+ zh_Hans: open-mistral-7b
+ en_US: open-mistral-7b
+model_type: llm
+features:
+ - agent-thought
+model_properties:
+ context_size: 8000
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ default: 0.7
+ min: 0
+ max: 1
+ - name: top_p
+ use_template: top_p
+ default: 1
+ min: 0
+ max: 1
+ - name: max_tokens
+ use_template: max_tokens
+ default: 1024
+ min: 1
+ max: 2048
+ - name: safe_prompt
+    default: false
+ type: boolean
+ help:
+ en_US: Whether to inject a safety prompt before all conversations.
+ zh_Hans: 是否开启提示词审查
+ label:
+ en_US: SafePrompt
+ zh_Hans: 提示词审查
+ - name: random_seed
+ type: int
+ help:
+ en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
+ zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
+ label:
+ en_US: RandomSeed
+ zh_Hans: 随机数种子
+ default: 0
+ min: 0
+ max: 2147483647
+pricing:
+ input: '0.00025'
+ output: '0.00025'
+ unit: '0.001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/mistralai/llm/open-mixtral-8x7b.yaml b/api/core/model_runtime/model_providers/mistralai/llm/open-mixtral-8x7b.yaml
new file mode 100644
index 0000000000..a35cf0a9ae
--- /dev/null
+++ b/api/core/model_runtime/model_providers/mistralai/llm/open-mixtral-8x7b.yaml
@@ -0,0 +1,50 @@
+model: open-mixtral-8x7b
+label:
+ zh_Hans: open-mixtral-8x7b
+ en_US: open-mixtral-8x7b
+model_type: llm
+features:
+ - agent-thought
+model_properties:
+ context_size: 32000
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ default: 0.7
+ min: 0
+ max: 1
+ - name: top_p
+ use_template: top_p
+ default: 1
+ min: 0
+ max: 1
+ - name: max_tokens
+ use_template: max_tokens
+ default: 1024
+ min: 1
+ max: 8000
+ - name: safe_prompt
+    default: false
+ type: boolean
+ help:
+ en_US: Whether to inject a safety prompt before all conversations.
+ zh_Hans: 是否开启提示词审查
+ label:
+ en_US: SafePrompt
+ zh_Hans: 提示词审查
+ - name: random_seed
+ type: int
+ help:
+ en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
+ zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
+ label:
+ en_US: RandomSeed
+ zh_Hans: 随机数种子
+ default: 0
+ min: 0
+ max: 2147483647
+pricing:
+ input: '0.0007'
+ output: '0.0007'
+ unit: '0.001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/mistralai/mistralai.py b/api/core/model_runtime/model_providers/mistralai/mistralai.py
new file mode 100644
index 0000000000..f1d825f6c6
--- /dev/null
+++ b/api/core/model_runtime/model_providers/mistralai/mistralai.py
@@ -0,0 +1,30 @@
+import logging
+
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class MistralAIProvider(ModelProvider):
+
+ def validate_provider_credentials(self, credentials: dict) -> None:
+ """
+ Validate provider credentials
+ if validate failed, raise exception
+
+ :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
+ """
+ try:
+ model_instance = self.get_model_instance(ModelType.LLM)
+
+ model_instance.validate_credentials(
+ model='open-mistral-7b',
+ credentials=credentials
+ )
+ except CredentialsValidateFailedError as ex:
+ raise ex
+ except Exception as ex:
+ logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
+ raise ex
diff --git a/api/core/model_runtime/model_providers/mistralai/mistralai.yaml b/api/core/model_runtime/model_providers/mistralai/mistralai.yaml
new file mode 100644
index 0000000000..c9b4226ea6
--- /dev/null
+++ b/api/core/model_runtime/model_providers/mistralai/mistralai.yaml
@@ -0,0 +1,31 @@
+provider: mistralai
+label:
+ en_US: MistralAI
+description:
+ en_US: Models provided by MistralAI, such as open-mistral-7b and mistral-large-latest.
+ zh_Hans: MistralAI 提供的模型,例如 open-mistral-7b 和 mistral-large-latest。
+icon_small:
+ en_US: icon_s_en.png
+icon_large:
+ en_US: icon_l_en.png
+background: "#FFFFFF"
+help:
+ title:
+ en_US: Get your API Key from MistralAI
+ zh_Hans: 从 MistralAI 获取 API Key
+ url:
+ en_US: https://console.mistral.ai/api-keys/
+supported_model_types:
+ - llm
+configurate_methods:
+ - predefined-model
+provider_credential_schema:
+ credential_form_schemas:
+ - variable: api_key
+ label:
+ en_US: API Key
+ type: secret-input
+ required: true
+ placeholder:
+ zh_Hans: 在此输入您的 API Key
+ en_US: Enter your API Key
From 174ee1b646f81675131261aeb90a4dd5b8945ae6 Mon Sep 17 00:00:00 2001
From: takatost
Date: Wed, 28 Feb 2024 12:23:34 +0800
Subject: [PATCH 24/36] fix: parameter `user` exceeded max length when invoking
moonshot llm (#2596)
---
api/core/model_runtime/model_providers/moonshot/llm/llm.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/api/core/model_runtime/model_providers/moonshot/llm/llm.py b/api/core/model_runtime/model_providers/moonshot/llm/llm.py
index 5db3e2827b..05feee877e 100644
--- a/api/core/model_runtime/model_providers/moonshot/llm/llm.py
+++ b/api/core/model_runtime/model_providers/moonshot/llm/llm.py
@@ -13,6 +13,7 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
+ user = user[:32] if user else None
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def validate_credentials(self, model: str, credentials: dict) -> None:
From dc93a292c36fd871d00c219207386b092d697c66 Mon Sep 17 00:00:00 2001
From: Joshua <138381132+joshua20231026@users.noreply.github.com>
Date: Wed, 28 Feb 2024 13:39:55 +0800
Subject: [PATCH 25/36] Feat/provider mistralai (#2598)
---
api/core/model_runtime/model_providers/_position.yaml | 1 +
1 file changed, 1 insertion(+)
diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/core/model_runtime/model_providers/_position.yaml
index b2c6518395..8c878d67d8 100644
--- a/api/core/model_runtime/model_providers/_position.yaml
+++ b/api/core/model_runtime/model_providers/_position.yaml
@@ -6,6 +6,7 @@
- bedrock
- togetherai
- ollama
+- mistralai
- replicate
- huggingface_hub
- zhipuai
From c4caa7c401698e62ff57603b259e8456cb55e25c Mon Sep 17 00:00:00 2001
From: crazywoola <100913391+crazywoola@users.noreply.github.com>
Date: Wed, 28 Feb 2024 13:40:57 +0800
Subject: [PATCH 26/36] doc: props.appDetail.api_base_url (#2597)
---
web/app/components/develop/template/template.en.mdx | 4 ++--
web/app/components/develop/template/template.zh.mdx | 4 ++--
web/app/components/develop/template/template_chat.en.mdx | 8 ++++----
web/app/components/develop/template/template_chat.zh.mdx | 8 ++++----
4 files changed, 12 insertions(+), 12 deletions(-)
diff --git a/web/app/components/develop/template/template.en.mdx b/web/app/components/develop/template/template.en.mdx
index 9bc994551b..121cf78d18 100644
--- a/web/app/components/develop/template/template.en.mdx
+++ b/web/app/components/develop/template/template.en.mdx
@@ -289,9 +289,9 @@ The text generation application offers non-session support and is ideal for tran
### Request Example
-
+
```bash {{ title: 'cURL' }}
- curl -X POST 'https://cloud.dify.ai/v1/completion-messages/:task_id/stop' \
+ curl -X POST '${props.appDetail.api_base_url}/v1/completion-messages/:task_id/stop' \
-H 'Authorization: Bearer {api_key}' \
-H 'Content-Type: application/json' \
--data-raw '{
diff --git a/web/app/components/develop/template/template.zh.mdx b/web/app/components/develop/template/template.zh.mdx
index 6e2ff881d1..5a86406225 100644
--- a/web/app/components/develop/template/template.zh.mdx
+++ b/web/app/components/develop/template/template.zh.mdx
@@ -266,9 +266,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
- `result` (string) 固定返回 success
-
+
```bash {{ title: 'cURL' }}
- curl -X POST 'https://cloud.dify.ai/v1/completion-messages/:task_id/stop' \
+ curl -X POST '${props.appDetail.api_base_url}/v1/completion-messages/:task_id/stop' \
-H 'Authorization: Bearer {api_key}' \
-H 'Content-Type: application/json' \
--data-raw '{
diff --git a/web/app/components/develop/template/template_chat.en.mdx b/web/app/components/develop/template/template_chat.en.mdx
index 9e8dd69874..33dd6049b3 100644
--- a/web/app/components/develop/template/template_chat.en.mdx
+++ b/web/app/components/develop/template/template_chat.en.mdx
@@ -344,9 +344,9 @@ Chat applications support session persistence, allowing previous chat history to
### Request Example
-
+
```bash {{ title: 'cURL' }}
- curl -X POST 'https://cloud.dify.ai/v1/chat-messages/:task_id/stop' \
+ curl -X POST '${props.appDetail.api_base_url}/v1/chat-messages/:task_id/stop' \
-H 'Authorization: Bearer {api_key}' \
-H 'Content-Type: application/json' \
--data-raw '{
@@ -1025,9 +1025,9 @@ Chat applications support session persistence, allowing previous chat history to
- (string) url of icon
-
+
```bash {{ title: 'cURL' }}
- curl -X GET 'https://cloud.dify.ai/v1/meta?user=abc-123' \
+ curl -X GET '${props.appDetail.api_base_url}/v1/meta?user=abc-123' \
-H 'Authorization: Bearer {api_key}'
```
diff --git a/web/app/components/develop/template/template_chat.zh.mdx b/web/app/components/develop/template/template_chat.zh.mdx
index 47f64466e7..a62d68abba 100644
--- a/web/app/components/develop/template/template_chat.zh.mdx
+++ b/web/app/components/develop/template/template_chat.zh.mdx
@@ -360,9 +360,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
- `result` (string) 固定返回 success
-
+
```bash {{ title: 'cURL' }}
- curl -X POST 'https://cloud.dify.ai/v1/chat-messages/:task_id/stop' \
+ curl -X POST '${props.appDetail.api_base_url}/v1/chat-messages/:task_id/stop' \
-H 'Authorization: Bearer {api_key}' \
-H 'Content-Type: application/json' \
--data-raw '{
@@ -1022,9 +1022,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
- (string) 图标URL
-
+
```bash {{ title: 'cURL' }}
- curl -X GET 'https://cloud.dify.ai/v1/meta?user=abc-123' \
+ curl -X GET '${props.appDetail.api_base_url}/v1/meta?user=abc-123' \
-H 'Authorization: Bearer {api_key}'
```
From 69ce3b3d33cbff1ade7d609e48a85d29bff2bd14 Mon Sep 17 00:00:00 2001
From: cola <45722758+xiangpingjiang@users.noreply.github.com>
Date: Wed, 28 Feb 2024 15:13:38 +0800
Subject: [PATCH 27/36] fix props.appDetail.api_base_url /v1 repeat error
(#2601)
---
web/app/components/develop/template/template.en.mdx | 4 ++--
web/app/components/develop/template/template.zh.mdx | 4 ++--
web/app/components/develop/template/template_chat.en.mdx | 8 ++++----
web/app/components/develop/template/template_chat.zh.mdx | 8 ++++----
4 files changed, 12 insertions(+), 12 deletions(-)
diff --git a/web/app/components/develop/template/template.en.mdx b/web/app/components/develop/template/template.en.mdx
index 121cf78d18..f930cfe1c9 100644
--- a/web/app/components/develop/template/template.en.mdx
+++ b/web/app/components/develop/template/template.en.mdx
@@ -289,9 +289,9 @@ The text generation application offers non-session support and is ideal for tran
### Request Example
-
+
```bash {{ title: 'cURL' }}
- curl -X POST '${props.appDetail.api_base_url}/v1/completion-messages/:task_id/stop' \
+ curl -X POST '${props.appDetail.api_base_url}/completion-messages/:task_id/stop' \
-H 'Authorization: Bearer {api_key}' \
-H 'Content-Type: application/json' \
--data-raw '{
diff --git a/web/app/components/develop/template/template.zh.mdx b/web/app/components/develop/template/template.zh.mdx
index 5a86406225..8153906d0a 100644
--- a/web/app/components/develop/template/template.zh.mdx
+++ b/web/app/components/develop/template/template.zh.mdx
@@ -266,9 +266,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
- `result` (string) 固定返回 success
-
+
```bash {{ title: 'cURL' }}
- curl -X POST '${props.appDetail.api_base_url}/v1/completion-messages/:task_id/stop' \
+ curl -X POST '${props.appDetail.api_base_url}/completion-messages/:task_id/stop' \
-H 'Authorization: Bearer {api_key}' \
-H 'Content-Type: application/json' \
--data-raw '{
diff --git a/web/app/components/develop/template/template_chat.en.mdx b/web/app/components/develop/template/template_chat.en.mdx
index 33dd6049b3..e102108154 100644
--- a/web/app/components/develop/template/template_chat.en.mdx
+++ b/web/app/components/develop/template/template_chat.en.mdx
@@ -344,9 +344,9 @@ Chat applications support session persistence, allowing previous chat history to
### Request Example
-
+
```bash {{ title: 'cURL' }}
- curl -X POST '${props.appDetail.api_base_url}/v1/chat-messages/:task_id/stop' \
+ curl -X POST '${props.appDetail.api_base_url}/chat-messages/:task_id/stop' \
-H 'Authorization: Bearer {api_key}' \
-H 'Content-Type: application/json' \
--data-raw '{
@@ -1025,9 +1025,9 @@ Chat applications support session persistence, allowing previous chat history to
- (string) url of icon
-
+
```bash {{ title: 'cURL' }}
- curl -X GET '${props.appDetail.api_base_url}/v1/meta?user=abc-123' \
+ curl -X GET '${props.appDetail.api_base_url}/meta?user=abc-123' \
-H 'Authorization: Bearer {api_key}'
```
diff --git a/web/app/components/develop/template/template_chat.zh.mdx b/web/app/components/develop/template/template_chat.zh.mdx
index a62d68abba..7bc3cd5337 100644
--- a/web/app/components/develop/template/template_chat.zh.mdx
+++ b/web/app/components/develop/template/template_chat.zh.mdx
@@ -360,9 +360,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
- `result` (string) 固定返回 success
-
+
```bash {{ title: 'cURL' }}
- curl -X POST '${props.appDetail.api_base_url}/v1/chat-messages/:task_id/stop' \
+ curl -X POST '${props.appDetail.api_base_url}/chat-messages/:task_id/stop' \
-H 'Authorization: Bearer {api_key}' \
-H 'Content-Type: application/json' \
--data-raw '{
@@ -1022,9 +1022,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
- (string) 图标URL
-
+
```bash {{ title: 'cURL' }}
- curl -X GET '${props.appDetail.api_base_url}/v1/meta?user=abc-123' \
+ curl -X GET '${props.appDetail.api_base_url}/meta?user=abc-123' \
-H 'Authorization: Bearer {api_key}'
```
From c9257ab4bf315193b13f14eab86c6fa12dcebe4b Mon Sep 17 00:00:00 2001
From: crazywoola <100913391+crazywoola@users.noreply.github.com>
Date: Wed, 28 Feb 2024 15:17:49 +0800
Subject: [PATCH 28/36] Fix/2559 upload powered by brand image not showing up
(#2602)
---
api/services/workspace_service.py | 14 ++++++++++++--
.../custom/custom-web-app-brand/index.tsx | 4 +---
2 files changed, 13 insertions(+), 5 deletions(-)
diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py
index 923e44dd85..778b4e51d3 100644
--- a/api/services/workspace_service.py
+++ b/api/services/workspace_service.py
@@ -1,3 +1,5 @@
+
+from flask import current_app
from flask_login import current_user
from extensions.ext_database import db
@@ -31,7 +33,15 @@ class WorkspaceService:
can_replace_logo = FeatureService.get_features(tenant_info['id']).can_replace_logo
- if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]):
- tenant_info['custom_config'] = tenant.custom_config_dict
+ if can_replace_logo and TenantService.has_roles(tenant,
+ [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]):
+ base_url = current_app.config.get('FILES_URL')
+ replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None
+ remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False)
+
+ tenant_info['custom_config'] = {
+ 'remove_webapp_brand': remove_webapp_brand,
+ 'replace_webapp_logo': replace_webapp_logo,
+ }
return tenant_info
diff --git a/web/app/components/custom/custom-web-app-brand/index.tsx b/web/app/components/custom/custom-web-app-brand/index.tsx
index 4817cfddab..857706bf26 100644
--- a/web/app/components/custom/custom-web-app-brand/index.tsx
+++ b/web/app/components/custom/custom-web-app-brand/index.tsx
@@ -16,8 +16,6 @@ import {
updateCurrentWorkspace,
} from '@/service/common'
import { useAppContext } from '@/context/app-context'
-import { API_PREFIX } from '@/config'
-import { getPurifyHref } from '@/utils'
const ALLOW_FILE_EXTENSIONS = ['svg', 'png']
@@ -123,7 +121,7 @@ const CustomWebAppBrand = () => {
POWERED BY
{
webappLogo
- ?
+ ?
:
}
From 816b707a16e5b205336a5c1eb3eea1c76e907535 Mon Sep 17 00:00:00 2001
From: crazywoola <100913391+crazywoola@users.noreply.github.com>
Date: Wed, 28 Feb 2024 15:43:42 +0800
Subject: [PATCH 29/36] Fix: explore apps is not shown (#2604)
---
web/app/components/explore/app-list/index.tsx | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/web/app/components/explore/app-list/index.tsx b/web/app/components/explore/app-list/index.tsx
index d3e90e6212..9031bb6299 100644
--- a/web/app/components/explore/app-list/index.tsx
+++ b/web/app/components/explore/app-list/index.tsx
@@ -26,13 +26,12 @@ const Apps: FC = () => {
const { isCurrentWorkspaceManager } = useAppContext()
const router = useRouter()
const { hasEditPermission } = useContext(ExploreContext)
- const allCategoriesEn = t('explore.apps.allCategories', { lng: 'en' })
+ const allCategoriesEn = t('explore.apps.allCategories')
const [currCategory, setCurrCategory] = useTabSearchParams({
defaultTab: allCategoriesEn,
})
const {
data: { categories, allList },
- isLoading,
} = useSWR(
['/explore/apps'],
() =>
@@ -90,7 +89,7 @@ const Apps: FC = () => {
}
}
- if (!isLoading) {
+ if (!categories) {
return (
From 0828873b526cfa38b9fd0ae6c1cec4c9882d5c36 Mon Sep 17 00:00:00 2001
From: takatost
Date: Wed, 28 Feb 2024 16:09:56 +0800
Subject: [PATCH 30/36] fix: missing default user for APP service api (#2606)
---
api/controllers/service_api/app/__init__.py | 27 ------
api/controllers/service_api/app/app.py | 14 +--
api/controllers/service_api/app/audio.py | 19 ++--
api/controllers/service_api/app/completion.py | 56 ++++--------
.../service_api/app/conversation.py | 35 +++-----
api/controllers/service_api/app/file.py | 15 ++--
api/controllers/service_api/app/message.py | 48 +++-------
api/controllers/service_api/wraps.py | 90 ++++++++++++++++---
8 files changed, 141 insertions(+), 163 deletions(-)
diff --git a/api/controllers/service_api/app/__init__.py b/api/controllers/service_api/app/__init__.py
index d8018ee385..e69de29bb2 100644
--- a/api/controllers/service_api/app/__init__.py
+++ b/api/controllers/service_api/app/__init__.py
@@ -1,27 +0,0 @@
-from extensions.ext_database import db
-from models.model import EndUser
-
-
-def create_or_update_end_user_for_user_id(app_model, user_id):
- """
- Create or update session terminal based on user ID.
- """
- end_user = db.session.query(EndUser) \
- .filter(
- EndUser.tenant_id == app_model.tenant_id,
- EndUser.session_id == user_id,
- EndUser.type == 'service_api'
- ).first()
-
- if end_user is None:
- end_user = EndUser(
- tenant_id=app_model.tenant_id,
- app_id=app_model.id,
- type='service_api',
- is_anonymous=True,
- session_id=user_id
- )
- db.session.add(end_user)
- db.session.commit()
-
- return end_user
diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py
index 9cd9770c09..a3151fc4a2 100644
--- a/api/controllers/service_api/app/app.py
+++ b/api/controllers/service_api/app/app.py
@@ -1,16 +1,16 @@
import json
from flask import current_app
-from flask_restful import fields, marshal_with
+from flask_restful import fields, marshal_with, Resource
from controllers.service_api import api
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import validate_app_token
from extensions.ext_database import db
from models.model import App, AppModelConfig
from models.tools import ApiToolProvider
-class AppParameterApi(AppApiResource):
+class AppParameterApi(Resource):
"""Resource for app variables."""
variable_fields = {
@@ -42,8 +42,9 @@ class AppParameterApi(AppApiResource):
'system_parameters': fields.Nested(system_parameters_fields)
}
+ @validate_app_token
@marshal_with(parameters_fields)
- def get(self, app_model: App, end_user):
+ def get(self, app_model: App):
"""Retrieve app parameters."""
app_model_config = app_model.app_model_config
@@ -64,8 +65,9 @@ class AppParameterApi(AppApiResource):
}
}
-class AppMetaApi(AppApiResource):
- def get(self, app_model: App, end_user):
+class AppMetaApi(Resource):
+ @validate_app_token
+ def get(self, app_model: App):
"""Get app meta"""
app_model_config: AppModelConfig = app_model.app_model_config
diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py
index d2906b1d6e..58ab56a292 100644
--- a/api/controllers/service_api/app/audio.py
+++ b/api/controllers/service_api/app/audio.py
@@ -1,7 +1,7 @@
import logging
from flask import request
-from flask_restful import reqparse
+from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError
import services
@@ -17,10 +17,10 @@ from controllers.service_api.app.error import (
ProviderQuotaExceededError,
UnsupportedAudioTypeError,
)
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
-from models.model import App, AppModelConfig
+from models.model import App, AppModelConfig, EndUser
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@@ -30,8 +30,9 @@ from services.errors.audio import (
)
-class AudioApi(AppApiResource):
- def post(self, app_model: App, end_user):
+class AudioApi(Resource):
+ @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
+ def post(self, app_model: App, end_user: EndUser):
app_model_config: AppModelConfig = app_model.app_model_config
if not app_model_config.speech_to_text_dict['enabled']:
@@ -73,11 +74,11 @@ class AudioApi(AppApiResource):
raise InternalServerError()
-class TextApi(AppApiResource):
- def post(self, app_model: App, end_user):
+class TextApi(Resource):
+ @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+ def post(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument('text', type=str, required=True, nullable=False, location='json')
- parser.add_argument('user', type=str, required=True, nullable=False, location='json')
parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json')
args = parser.parse_args()
@@ -85,7 +86,7 @@ class TextApi(AppApiResource):
response = AudioService.transcript_tts(
tenant_id=app_model.tenant_id,
text=args['text'],
- end_user=args['user'],
+ end_user=end_user,
voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
streaming=args['streaming']
)
diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py
index 5331f796e7..c6cfb24378 100644
--- a/api/controllers/service_api/app/completion.py
+++ b/api/controllers/service_api/app/completion.py
@@ -4,12 +4,11 @@ from collections.abc import Generator
from typing import Union
from flask import Response, stream_with_context
-from flask_restful import reqparse
+from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError, NotFound
import services
from controllers.service_api import api
-from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import (
AppUnavailableError,
CompletionRequestError,
@@ -19,17 +18,19 @@ from controllers.service_api.app.error import (
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value
+from models.model import App, EndUser
from services.completion_service import CompletionService
-class CompletionApi(AppApiResource):
- def post(self, app_model, end_user):
+class CompletionApi(Resource):
+ @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+ def post(self, app_model: App, end_user: EndUser):
if app_model.mode != 'completion':
raise AppUnavailableError()
@@ -38,16 +39,12 @@ class CompletionApi(AppApiResource):
parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
- parser.add_argument('user', required=True, nullable=False, type=str, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
args = parser.parse_args()
streaming = args['response_mode'] == 'streaming'
- if end_user is None and args['user'] is not None:
- end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
args['auto_generate_name'] = False
try:
@@ -82,29 +79,20 @@ class CompletionApi(AppApiResource):
raise InternalServerError()
-class CompletionStopApi(AppApiResource):
- def post(self, app_model, end_user, task_id):
+class CompletionStopApi(Resource):
+ @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+ def post(self, app_model: App, end_user: EndUser, task_id):
if app_model.mode != 'completion':
raise AppUnavailableError()
- if end_user is None:
- parser = reqparse.RequestParser()
- parser.add_argument('user', required=True, nullable=False, type=str, location='json')
- args = parser.parse_args()
-
- user = args.get('user')
- if user is not None:
- end_user = create_or_update_end_user_for_user_id(app_model, user)
- else:
- raise ValueError("arg user muse be input.")
-
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {'result': 'success'}, 200
-class ChatApi(AppApiResource):
- def post(self, app_model, end_user):
+class ChatApi(Resource):
+ @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+ def post(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat':
raise NotChatAppError()
@@ -114,7 +102,6 @@ class ChatApi(AppApiResource):
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
- parser.add_argument('user', type=str, required=True, nullable=False, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json')
@@ -122,9 +109,6 @@ class ChatApi(AppApiResource):
streaming = args['response_mode'] == 'streaming'
- if end_user is None and args['user'] is not None:
- end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
try:
response = CompletionService.completion(
app_model=app_model,
@@ -157,22 +141,12 @@ class ChatApi(AppApiResource):
raise InternalServerError()
-class ChatStopApi(AppApiResource):
- def post(self, app_model, end_user, task_id):
+class ChatStopApi(Resource):
+ @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+ def post(self, app_model: App, end_user: EndUser, task_id):
if app_model.mode != 'chat':
raise NotChatAppError()
- if end_user is None:
- parser = reqparse.RequestParser()
- parser.add_argument('user', required=True, nullable=False, type=str, location='json')
- args = parser.parse_args()
-
- user = args.get('user')
- if user is not None:
- end_user = create_or_update_end_user_for_user_id(app_model, user)
- else:
- raise ValueError("arg user muse be input.")
-
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {'result': 'success'}, 200
diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py
index 3c157bed99..4a5fe2f19f 100644
--- a/api/controllers/service_api/app/conversation.py
+++ b/api/controllers/service_api/app/conversation.py
@@ -1,52 +1,44 @@
-from flask import request
-from flask_restful import marshal_with, reqparse
+from flask_restful import Resource, marshal_with, reqparse
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound
import services
from controllers.service_api import api
-from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import NotChatAppError
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
+from models.model import App, EndUser
from services.conversation_service import ConversationService
-class ConversationApi(AppApiResource):
+class ConversationApi(Resource):
+ @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@marshal_with(conversation_infinite_scroll_pagination_fields)
- def get(self, app_model, end_user):
+ def get(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat':
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
- parser.add_argument('user', type=str, location='args')
args = parser.parse_args()
- if end_user is None and args['user'] is not None:
- end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
try:
return ConversationService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit'])
except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
-class ConversationDetailApi(AppApiResource):
+class ConversationDetailApi(Resource):
+ @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(simple_conversation_fields)
- def delete(self, app_model, end_user, c_id):
+ def delete(self, app_model: App, end_user: EndUser, c_id):
if app_model.mode != 'chat':
raise NotChatAppError()
conversation_id = str(c_id)
- user = request.get_json().get('user')
-
- if end_user is None and user is not None:
- end_user = create_or_update_end_user_for_user_id(app_model, user)
-
try:
ConversationService.delete(app_model, conversation_id, end_user)
except services.errors.conversation.ConversationNotExistsError:
@@ -54,10 +46,11 @@ class ConversationDetailApi(AppApiResource):
return {"result": "success"}, 204
-class ConversationRenameApi(AppApiResource):
+class ConversationRenameApi(Resource):
+ @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(simple_conversation_fields)
- def post(self, app_model, end_user, c_id):
+ def post(self, app_model: App, end_user: EndUser, c_id):
if app_model.mode != 'chat':
raise NotChatAppError()
@@ -65,13 +58,9 @@ class ConversationRenameApi(AppApiResource):
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=False, location='json')
- parser.add_argument('user', type=str, location='json')
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
args = parser.parse_args()
- if end_user is None and args['user'] is not None:
- end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
try:
return ConversationService.rename(
app_model,
diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py
index a901375ec0..5dbc1b1d1b 100644
--- a/api/controllers/service_api/app/file.py
+++ b/api/controllers/service_api/app/file.py
@@ -1,30 +1,27 @@
from flask import request
-from flask_restful import marshal_with
+from flask_restful import Resource, marshal_with
import services
from controllers.service_api import api
-from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import (
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,
UnsupportedFileTypeError,
)
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.file_fields import file_fields
+from models.model import App, EndUser
from services.file_service import FileService
-class FileApi(AppApiResource):
+class FileApi(Resource):
+ @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
@marshal_with(file_fields)
- def post(self, app_model, end_user):
+ def post(self, app_model: App, end_user: EndUser):
file = request.files['file']
- user_args = request.form.get('user')
-
- if end_user is None and user_args is not None:
- end_user = create_or_update_end_user_for_user_id(app_model, user_args)
# check file
if 'file' not in request.files:
diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py
index d90f536a42..0050ab1aee 100644
--- a/api/controllers/service_api/app/message.py
+++ b/api/controllers/service_api/app/message.py
@@ -1,20 +1,18 @@
-from flask_restful import fields, marshal_with, reqparse
+from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound
import services
from controllers.service_api import api
-from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import NotChatAppError
-from controllers.service_api.wraps import AppApiResource
-from extensions.ext_database import db
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, uuid_value
-from models.model import EndUser, Message
+from models.model import App, EndUser
from services.message_service import MessageService
-class MessageListApi(AppApiResource):
+class MessageListApi(Resource):
feedback_fields = {
'rating': fields.String
}
@@ -70,8 +68,9 @@ class MessageListApi(AppApiResource):
'data': fields.List(fields.Nested(message_fields))
}
+ @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@marshal_with(message_infinite_scroll_pagination_fields)
- def get(self, app_model, end_user):
+ def get(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat':
raise NotChatAppError()
@@ -79,12 +78,8 @@ class MessageListApi(AppApiResource):
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
parser.add_argument('first_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
- parser.add_argument('user', type=str, location='args')
args = parser.parse_args()
- if end_user is None and args['user'] is not None:
- end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
try:
return MessageService.pagination_by_first_id(app_model, end_user,
args['conversation_id'], args['first_id'], args['limit'])
@@ -94,18 +89,15 @@ class MessageListApi(AppApiResource):
raise NotFound("First Message Not Exists.")
-class MessageFeedbackApi(AppApiResource):
- def post(self, app_model, end_user, message_id):
+class MessageFeedbackApi(Resource):
+ @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
+ def post(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id)
parser = reqparse.RequestParser()
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
- parser.add_argument('user', type=str, location='json')
args = parser.parse_args()
- if end_user is None and args['user'] is not None:
- end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
try:
MessageService.create_feedback(app_model, message_id, end_user, args['rating'])
except services.errors.message.MessageNotExistsError:
@@ -114,29 +106,17 @@ class MessageFeedbackApi(AppApiResource):
return {'result': 'success'}
-class MessageSuggestedApi(AppApiResource):
- def get(self, app_model, end_user, message_id):
+class MessageSuggestedApi(Resource):
+ @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
+ def get(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id)
if app_model.mode != 'chat':
raise NotChatAppError()
- try:
- message = db.session.query(Message).filter(
- Message.id == message_id,
- Message.app_id == app_model.id,
- ).first()
- if end_user is None and message.from_end_user_id is not None:
- user = db.session.query(EndUser) \
- .filter(
- EndUser.tenant_id == app_model.tenant_id,
- EndUser.id == message.from_end_user_id,
- EndUser.type == 'service_api'
- ).first()
- else:
- user = end_user
+ try:
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model,
- user=user,
+ user=end_user,
message_id=message_id,
check_enabled=False
)
diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py
index a0d89fe62f..169c475af9 100644
--- a/api/controllers/service_api/wraps.py
+++ b/api/controllers/service_api/wraps.py
@@ -1,22 +1,40 @@
+from collections.abc import Callable
from datetime import datetime
+from enum import Enum
from functools import wraps
+from typing import Optional
from flask import current_app, request
from flask_login import user_logged_in
from flask_restful import Resource
+from pydantic import BaseModel
from werkzeug.exceptions import NotFound, Unauthorized
from extensions.ext_database import db
from libs.login import _get_user
from models.account import Account, Tenant, TenantAccountJoin
-from models.model import ApiToken, App
+from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService
-def validate_app_token(view=None):
- def decorator(view):
- @wraps(view)
- def decorated(*args, **kwargs):
+class WhereisUserArg(Enum):
+ """
+ Enum for whereis_user_arg.
+ """
+ QUERY = 'query'
+ JSON = 'json'
+ FORM = 'form'
+
+
+class FetchUserArg(BaseModel):
+ fetch_from: WhereisUserArg
+ required: bool = False
+
+
+def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None):
+ def decorator(view_func):
+ @wraps(view_func)
+ def decorated_view(*args, **kwargs):
api_token = validate_and_get_api_token('app')
app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
@@ -29,16 +47,35 @@ def validate_app_token(view=None):
if not app_model.enable_api:
raise NotFound()
- return view(app_model, None, *args, **kwargs)
- return decorated
+ kwargs['app_model'] = app_model
- if view:
+ if not fetch_user_arg:
+ # use default-user
+ user_id = None
+ else:
+ if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
+ user_id = request.args.get('user')
+ elif fetch_user_arg.fetch_from == WhereisUserArg.JSON:
+ user_id = request.get_json().get('user')
+ elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
+ user_id = request.form.get('user')
+ else:
+ # use default-user
+ user_id = None
+
+ if not user_id and fetch_user_arg.required:
+ raise ValueError("Arg user must be provided.")
+
+ kwargs['end_user'] = create_or_update_end_user_for_user_id(app_model, user_id)
+
+ return view_func(*args, **kwargs)
+ return decorated_view
+
+ if view is None:
+ return decorator
+ else:
return decorator(view)
- # if view is None, it means that the decorator is used without parentheses
- # use the decorator as a function for method_decorators
- return decorator
-
def cloud_edition_billing_resource_check(resource: str,
api_token_type: str,
@@ -128,8 +165,33 @@ def validate_and_get_api_token(scope=None):
return api_token
-class AppApiResource(Resource):
- method_decorators = [validate_app_token]
+def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] = None) -> EndUser:
+ """
+ Create or update session terminal based on user ID.
+ """
+ if not user_id:
+ user_id = 'DEFAULT-USER'
+
+ end_user = db.session.query(EndUser) \
+ .filter(
+ EndUser.tenant_id == app_model.tenant_id,
+ EndUser.app_id == app_model.id,
+ EndUser.session_id == user_id,
+ EndUser.type == 'service_api'
+ ).first()
+
+ if end_user is None:
+ end_user = EndUser(
+ tenant_id=app_model.tenant_id,
+ app_id=app_model.id,
+ type='service_api',
+ is_anonymous=True if user_id == 'DEFAULT-USER' else False,
+ session_id=user_id
+ )
+ db.session.add(end_user)
+ db.session.commit()
+
+ return end_user
class DatasetApiResource(Resource):
From 90bdc85f8c0ad89fbd505858ae784d8af091d306 Mon Sep 17 00:00:00 2001
From: takatost
Date: Wed, 28 Feb 2024 16:46:50 +0800
Subject: [PATCH 31/36] fix: AppParameterApi.get() got an unexpected keyword
argument 'end_user' (#2607)
---
api/controllers/service_api/wraps.py | 10 +++++-----
api/core/app_runner/generate_task_pipeline.py | 2 +-
2 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py
index 169c475af9..9819c73d37 100644
--- a/api/controllers/service_api/wraps.py
+++ b/api/controllers/service_api/wraps.py
@@ -49,10 +49,7 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
kwargs['app_model'] = app_model
- if not fetch_user_arg:
- # use default-user
- user_id = None
- else:
+ if fetch_user_arg:
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
user_id = request.args.get('user')
elif fetch_user_arg.fetch_from == WhereisUserArg.JSON:
@@ -66,7 +63,10 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
if not user_id and fetch_user_arg.required:
raise ValueError("Arg user must be provided.")
- kwargs['end_user'] = create_or_update_end_user_for_user_id(app_model, user_id)
+ if user_id:
+ user_id = str(user_id)
+
+ kwargs['end_user'] = create_or_update_end_user_for_user_id(app_model, user_id)
return view_func(*args, **kwargs)
return decorated_view
diff --git a/api/core/app_runner/generate_task_pipeline.py b/api/core/app_runner/generate_task_pipeline.py
index 20e4bc7992..5fd635bc3b 100644
--- a/api/core/app_runner/generate_task_pipeline.py
+++ b/api/core/app_runner/generate_task_pipeline.py
@@ -175,7 +175,7 @@ class GenerateTaskPipeline:
'id': self._message.id,
'message_id': self._message.id,
'mode': self._conversation.mode,
- 'answer': event.llm_result.message.content,
+ 'answer': self._task_state.llm_result.message.content,
'metadata': {},
'created_at': int(self._message.created_at.timestamp())
}
From a4d86496e1f21fd4a593e9982cae6401fd9facdd Mon Sep 17 00:00:00 2001
From: takatost
Date: Wed, 28 Feb 2024 17:08:27 +0800
Subject: [PATCH 32/36] =?UTF-8?q?fix:=20notion=20extractor=20raise=20'None?=
=?UTF-8?q?Type'=20object=20has=20no=20attribute=20'curre=E2=80=A6=20(#260?=
=?UTF-8?q?8)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
api/core/rag/extractor/notion_extractor.py | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py
index 38dd36361a..c40064fd1d 100644
--- a/api/core/rag/extractor/notion_extractor.py
+++ b/api/core/rag/extractor/notion_extractor.py
@@ -4,7 +4,6 @@ from typing import Any, Optional
import requests
from flask import current_app
-from flask_login import current_user
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
@@ -43,7 +42,7 @@ class NotionExtractor(BaseExtractor):
if notion_access_token:
self._notion_access_token = notion_access_token
else:
- self._notion_access_token = self._get_access_token(current_user.current_tenant_id,
+ self._notion_access_token = self._get_access_token(tenant_id,
self._notion_workspace_id)
if not self._notion_access_token:
integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
From 3cf5c1853debd0210a683767318936fb0fd05941 Mon Sep 17 00:00:00 2001
From: crazywoola <100913391+crazywoola@users.noreply.github.com>
Date: Wed, 28 Feb 2024 17:34:20 +0800
Subject: [PATCH 33/36] Fix: default button behavior (#2609)
---
web/app/components/base/button/index.tsx | 14 +++++---------
1 file changed, 5 insertions(+), 9 deletions(-)
diff --git a/web/app/components/base/button/index.tsx b/web/app/components/base/button/index.tsx
index 24d58c6ea5..e617a5d12d 100644
--- a/web/app/components/base/button/index.tsx
+++ b/web/app/components/base/button/index.tsx
@@ -3,16 +3,13 @@ import React from 'react'
import Spinner from '../spinner'
export type IButtonProps = {
- /**
- * The style of the button
- */
- type?: 'primary' | 'warning' | (string & {})
+ type?: string
className?: string
disabled?: boolean
loading?: boolean
tabIndex?: number
children: React.ReactNode
- onClick?: MouseEventHandler
+ onClick?: MouseEventHandler
}
const Button: FC = ({
@@ -38,16 +35,15 @@ const Button: FC = ({
}
return (
-
{children}
{/* Spinner is hidden when loading is false */}
-
+
)
}
From 5bd3b02be652fb4f5a49bb4f39f4b5122401a9f3 Mon Sep 17 00:00:00 2001
From: takatost
Date: Wed, 28 Feb 2024 18:07:13 +0800
Subject: [PATCH 34/36] version to 0.5.7 (#2610)
---
api/config.py | 2 +-
docker/docker-compose.yaml | 6 +++---
web/package.json | 2 +-
3 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/api/config.py b/api/config.py
index 8eeede0ff9..3f6980bdea 100644
--- a/api/config.py
+++ b/api/config.py
@@ -90,7 +90,7 @@ class Config:
# ------------------------
# General Configurations.
# ------------------------
- self.CURRENT_VERSION = "0.5.6"
+ self.CURRENT_VERSION = "0.5.7"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index e3a7bdbbe2..7cd09fd6ea 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -2,7 +2,7 @@ version: '3.1'
services:
# API service
api:
- image: langgenius/dify-api:0.5.6
+ image: langgenius/dify-api:0.5.7
restart: always
environment:
# Startup mode, 'api' starts the API server.
@@ -135,7 +135,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
- image: langgenius/dify-api:0.5.6
+ image: langgenius/dify-api:0.5.7
restart: always
environment:
# Startup mode, 'worker' starts the Celery worker for processing the queue.
@@ -206,7 +206,7 @@ services:
# Frontend web application.
web:
- image: langgenius/dify-web:0.5.6
+ image: langgenius/dify-web:0.5.7
restart: always
environment:
EDITION: SELF_HOSTED
diff --git a/web/package.json b/web/package.json
index 72cc5bc967..f22c5df595 100644
--- a/web/package.json
+++ b/web/package.json
@@ -1,6 +1,6 @@
{
"name": "dify-web",
- "version": "0.5.6",
+ "version": "0.5.7",
"private": true,
"scripts": {
"dev": "next dev",
From d44b05a9e502b5578bebc8b6d139805b4a030803 Mon Sep 17 00:00:00 2001
From: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
Date: Wed, 28 Feb 2024 23:19:08 +0800
Subject: [PATCH 35/36] feat: support auth types like basic, bearer and custom
(#2613)
---
api/core/tools/provider/api_tool_provider.py | 15 +++++
api/core/tools/tool/api_tool.py | 11 ++++
.../config-credentials.tsx | 56 ++++++++++++++++---
.../edit-custom-collection-modal/index.tsx | 5 +-
web/app/components/tools/tool-list/index.tsx | 6 +-
web/app/components/tools/types.ts | 7 +++
web/i18n/en-US/tools.ts | 9 +++
web/i18n/pt-BR/tools.ts | 7 +++
web/i18n/uk-UA/tools.ts | 7 +++
web/i18n/zh-Hans/tools.ts | 9 +++
10 files changed, 122 insertions(+), 10 deletions(-)
diff --git a/api/core/tools/provider/api_tool_provider.py b/api/core/tools/provider/api_tool_provider.py
index 13f4bc2c3d..eb839e9341 100644
--- a/api/core/tools/provider/api_tool_provider.py
+++ b/api/core/tools/provider/api_tool_provider.py
@@ -55,6 +55,21 @@ class ApiBasedToolProviderController(ToolProviderController):
en_US='The api key',
zh_Hans='api key的值'
)
+ ),
+ 'api_key_header_prefix': ToolProviderCredentials(
+ name='api_key_header_prefix',
+ required=False,
+ default='basic',
+ type=ToolProviderCredentials.CredentialsType.SELECT,
+ help=I18nObject(
+ en_US='The prefix of the api key header',
+ zh_Hans='api key header 的前缀'
+ ),
+ options=[
+ ToolCredentialsOption(value='basic', label=I18nObject(en_US='Basic', zh_Hans='Basic')),
+ ToolCredentialsOption(value='bearer', label=I18nObject(en_US='Bearer', zh_Hans='Bearer')),
+ ToolCredentialsOption(value='custom', label=I18nObject(en_US='Custom', zh_Hans='Custom'))
+ ]
)
}
elif auth_type == ApiProviderAuthType.NONE:
diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py
index 2a1ee92e78..781eff13b4 100644
--- a/api/core/tools/tool/api_tool.py
+++ b/api/core/tools/tool/api_tool.py
@@ -62,6 +62,17 @@ class ApiTool(Tool):
if 'api_key_value' not in credentials:
raise ToolProviderCredentialValidationError('Missing api_key_value')
+ elif not isinstance(credentials['api_key_value'], str):
+ raise ToolProviderCredentialValidationError('api_key_value must be a string')
+
+ if 'api_key_header_prefix' in credentials:
+ api_key_header_prefix = credentials['api_key_header_prefix']
+ if api_key_header_prefix == 'basic':
+ credentials['api_key_value'] = f'Basic {credentials["api_key_value"]}'
+ elif api_key_header_prefix == 'bearer':
+ credentials['api_key_value'] = f'Bearer {credentials["api_key_value"]}'
+ elif api_key_header_prefix == 'custom':
+ pass
headers[api_key_header] = credentials['api_key_value']
diff --git a/web/app/components/tools/edit-custom-collection-modal/config-credentials.tsx b/web/app/components/tools/edit-custom-collection-modal/config-credentials.tsx
index 1deef1b531..9da0ff7dcc 100644
--- a/web/app/components/tools/edit-custom-collection-modal/config-credentials.tsx
+++ b/web/app/components/tools/edit-custom-collection-modal/config-credentials.tsx
@@ -3,11 +3,13 @@ import type { FC } from 'react'
import React from 'react'
import { useTranslation } from 'react-i18next'
import cn from 'classnames'
+import Tooltip from '../../base/tooltip'
+import { HelpCircle } from '../../base/icons/src/vender/line/general'
import type { Credential } from '@/app/components/tools/types'
import Drawer from '@/app/components/base/drawer-plus'
import Button from '@/app/components/base/button'
import Radio from '@/app/components/base/radio/ui'
-import { AuthType } from '@/app/components/tools/types'
+import { AuthHeaderPrefix, AuthType } from '@/app/components/tools/types'
type Props = {
credential: Credential
@@ -18,9 +20,9 @@ const keyClassNames = 'py-2 leading-5 text-sm font-medium text-gray-900'
type ItemProps = {
text: string
- value: AuthType
+ value: AuthType | AuthHeaderPrefix
isChecked: boolean
- onClick: (value: AuthType) => void
+ onClick: (value: AuthType | AuthHeaderPrefix) => void
}
const SelectItem: FC = ({ text, value, isChecked, onClick }) => {
@@ -31,7 +33,6 @@ const SelectItem: FC = ({ text, value, isChecked, onClick }) => {
>
{text}
-
)
}
@@ -43,6 +44,7 @@ const ConfigCredential: FC
= ({
}) => {
const { t } = useTranslation()
const [tempCredential, setTempCredential] = React.useState(credential)
+
return (
= ({
text={t('tools.createTool.authMethod.types.none')}
value={AuthType.none}
isChecked={tempCredential.auth_type === AuthType.none}
- onClick={value => setTempCredential({ ...tempCredential, auth_type: value })}
+ onClick={value => setTempCredential({ ...tempCredential, auth_type: value as AuthType })}
/>
setTempCredential({ ...tempCredential, auth_type: value })}
+ onClick={value => setTempCredential({
+ ...tempCredential,
+ auth_type: value as AuthType,
+ api_key_header: tempCredential.api_key_header || 'Authorization',
+ api_key_value: tempCredential.api_key_value || '',
+ api_key_header_prefix: tempCredential.api_key_header_prefix || AuthHeaderPrefix.custom,
+ })}
/>
{tempCredential.auth_type === AuthType.apiKey && (
<>
+
{t('tools.createTool.authHeaderPrefix.title')}
+
+ setTempCredential({ ...tempCredential, api_key_header_prefix: value as AuthHeaderPrefix })}
+ />
+ setTempCredential({ ...tempCredential, api_key_header_prefix: value as AuthHeaderPrefix })}
+ />
+ setTempCredential({ ...tempCredential, api_key_header_prefix: value as AuthHeaderPrefix })}
+ />
+
-
{t('tools.createTool.authMethod.key')}
+
+ {t('tools.createTool.authMethod.key')}
+
+ {t('tools.createTool.authMethod.keyTooltip')}
+
+ }
+ >
+
+
+
setTempCredential({ ...tempCredential, api_key_header: e.target.value })}
@@ -83,7 +124,6 @@ const ConfigCredential: FC
= ({
placeholder={t('tools.createTool.authMethod.types.apiKeyPlaceholder')!}
/>
-
{t('tools.createTool.authMethod.value')}
= ({
const { t } = useTranslation()
const isAdd = !payload
const isEdit = !!payload
+
const [editFirst, setEditFirst] = useState(!isAdd)
const [paramsSchemas, setParamsSchemas] = useState
(payload?.tools || [])
const [customCollection, setCustomCollection, getCustomCollection] = useGetState(isAdd
@@ -44,6 +45,8 @@ const EditCustomCollectionModal: FC = ({
provider: '',
credentials: {
auth_type: AuthType.none,
+ api_key_header: 'Authorization',
+ api_key_header_prefix: AuthHeaderPrefix.basic,
},
icon: {
content: '🕵️',
diff --git a/web/app/components/tools/tool-list/index.tsx b/web/app/components/tools/tool-list/index.tsx
index 58fcf5613b..3bee3292e6 100644
--- a/web/app/components/tools/tool-list/index.tsx
+++ b/web/app/components/tools/tool-list/index.tsx
@@ -3,7 +3,7 @@ import type { FC } from 'react'
import React, { useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
import cn from 'classnames'
-import { CollectionType, LOC } from '../types'
+import { AuthHeaderPrefix, AuthType, CollectionType, LOC } from '../types'
import type { Collection, CustomCollectionBackend, Tool } from '../types'
import Loading from '../../base/loading'
import { ArrowNarrowRight } from '../../base/icons/src/vender/line/arrows'
@@ -53,6 +53,10 @@ const ToolList: FC = ({
(async () => {
if (collection.type === CollectionType.custom) {
const res = await fetchCustomCollection(collection.name)
+ if (res.credentials.auth_type === AuthType.apiKey && !res.credentials.api_key_header_prefix) {
+ if (res.credentials.api_key_value)
+ res.credentials.api_key_header_prefix = AuthHeaderPrefix.custom
+ }
setCustomCollection({
...res,
provider: collection.name,
diff --git a/web/app/components/tools/types.ts b/web/app/components/tools/types.ts
index e06e011767..389276e81c 100644
--- a/web/app/components/tools/types.ts
+++ b/web/app/components/tools/types.ts
@@ -9,10 +9,17 @@ export enum AuthType {
apiKey = 'api_key',
}
+export enum AuthHeaderPrefix {
+ basic = 'basic',
+ bearer = 'bearer',
+ custom = 'custom',
+}
+
export type Credential = {
'auth_type': AuthType
'api_key_header'?: string
'api_key_value'?: string
+ 'api_key_header_prefix'?: AuthHeaderPrefix
}
export enum CollectionType {
diff --git a/web/i18n/en-US/tools.ts b/web/i18n/en-US/tools.ts
index 9746fab8bc..30e075210c 100644
--- a/web/i18n/en-US/tools.ts
+++ b/web/i18n/en-US/tools.ts
@@ -51,6 +51,7 @@ const translation = {
authMethod: {
title: 'Authorization method',
type: 'Authorization type',
+ keyTooltip: 'Http Header Key, You can leave it with "Authorization" if you have no idea what it is or set it to a custom value',
types: {
none: 'None',
api_key: 'API Key',
@@ -60,6 +61,14 @@ const translation = {
key: 'Key',
value: 'Value',
},
+ authHeaderPrefix: {
+ title: 'Auth Type',
+ types: {
+ basic: 'Basic',
+ bearer: 'Bearer',
+ custom: 'Custom',
+ },
+ },
privacyPolicy: 'Privacy policy',
privacyPolicyPlaceholder: 'Please enter privacy policy',
},
diff --git a/web/i18n/pt-BR/tools.ts b/web/i18n/pt-BR/tools.ts
index 9e2da08a1a..3434bd15ee 100644
--- a/web/i18n/pt-BR/tools.ts
+++ b/web/i18n/pt-BR/tools.ts
@@ -58,6 +58,13 @@ const translation = {
key: 'Chave',
value: 'Valor',
},
+ authHeaderPrefix: {
+ types: {
+ basic: 'Basic',
+ bearer: 'Bearer',
+ custom: 'Custom',
+ },
+ },
privacyPolicy: 'Política de Privacidade',
privacyPolicyPlaceholder: 'Digite a política de privacidade',
},
diff --git a/web/i18n/uk-UA/tools.ts b/web/i18n/uk-UA/tools.ts
index 56b4371cfb..307149c386 100644
--- a/web/i18n/uk-UA/tools.ts
+++ b/web/i18n/uk-UA/tools.ts
@@ -58,6 +58,13 @@ const translation = {
key: 'Ключ',
value: 'Значення',
},
+ authHeaderPrefix: {
+ types: {
+ basic: 'Basic',
+ bearer: 'Bearer',
+ custom: 'Custom',
+ },
+ },
privacyPolicy: 'Політика конфіденційності',
privacyPolicyPlaceholder: 'Введіть політику конфіденційності',
},
diff --git a/web/i18n/zh-Hans/tools.ts b/web/i18n/zh-Hans/tools.ts
index ff3b5c0fb8..c709d62547 100644
--- a/web/i18n/zh-Hans/tools.ts
+++ b/web/i18n/zh-Hans/tools.ts
@@ -51,6 +51,7 @@ const translation = {
authMethod: {
title: '鉴权方法',
type: '鉴权类型',
+ keyTooltip: 'HTTP 头部名称,如果你不知道是什么,可以将其保留为 Authorization 或设置为自定义值',
types: {
none: '无',
api_key: 'API Key',
@@ -60,6 +61,14 @@ const translation = {
key: '键',
value: '值',
},
+ authHeaderPrefix: {
+ title: '鉴权头部前缀',
+ types: {
+ basic: 'Basic',
+ bearer: 'Bearer',
+ custom: 'Custom',
+ },
+ },
privacyPolicy: '隐私协议',
privacyPolicyPlaceholder: '请输入隐私协议',
},
From dd961985f058387d3d9c5b63a956c803093662d2 Mon Sep 17 00:00:00 2001
From: takatost
Date: Wed, 28 Feb 2024 23:32:47 +0800
Subject: [PATCH 36/36] refactor: remove unused code, move core/agent module
into dataset retrieval feature (#2614)
---
api/core/agent/agent/calc_token_mixin.py | 49 --
api/core/agent/agent/openai_function_call.py | 361 --------------
api/core/agent/agent/structured_chat.py | 306 ------------
api/core/app_runner/assistant_app_runner.py | 40 +-
api/core/app_runner/basic_app_runner.py | 2 +-
api/core/entities/agent_entities.py | 8 +
api/core/features/agent_runner.py | 199 --------
.../dataset_retrieval}/__init__.py | 0
.../dataset_retrieval/agent}/__init__.py | 0
.../agent/agent_llm_callback.py | 0
.../dataset_retrieval/agent/fake_llm.py} | 0
.../dataset_retrieval/agent}/llm_chain.py | 4 +-
.../agent/multi_dataset_router_agent.py | 2 +-
.../agent/output_parser/__init__.py} | 0
.../agent/output_parser/structured_chat.py | 0
.../structed_multi_dataset_router_agent.py | 2 +-
.../agent_based_dataset_executor.py} | 42 +-
.../dataset_retrieval.py | 3 +-
api/core/third_party/spark/spark_llm.py | 189 --------
api/core/tool/current_datetime_tool.py | 24 -
api/core/tool/provider/base.py | 63 ---
api/core/tool/provider/errors.py | 2 -
api/core/tool/provider/serpapi_provider.py | 77 ---
.../tool/provider/tool_provider_service.py | 43 --
api/core/tool/serpapi_wrapper.py | 51 --
api/core/tool/web_reader_tool.py | 443 ------------------
api/core/tools/tool/dataset_retriever_tool.py | 36 +-
api/core/tools/utils/web_reader_tool.py | 109 -----
api/services/app_model_config_service.py | 2 +-
29 files changed, 41 insertions(+), 2016 deletions(-)
delete mode 100644 api/core/agent/agent/calc_token_mixin.py
delete mode 100644 api/core/agent/agent/openai_function_call.py
delete mode 100644 api/core/agent/agent/structured_chat.py
create mode 100644 api/core/entities/agent_entities.py
delete mode 100644 api/core/features/agent_runner.py
rename api/core/{third_party/langchain/llms => features/dataset_retrieval}/__init__.py (100%)
rename api/core/{third_party/spark => features/dataset_retrieval/agent}/__init__.py (100%)
rename api/core/{agent => features/dataset_retrieval}/agent/agent_llm_callback.py (100%)
rename api/core/{third_party/langchain/llms/fake.py => features/dataset_retrieval/agent/fake_llm.py} (100%)
rename api/core/{chain => features/dataset_retrieval/agent}/llm_chain.py (91%)
rename api/core/{agent => features/dataset_retrieval}/agent/multi_dataset_router_agent.py (98%)
rename api/core/{data_loader/file_extractor.py => features/dataset_retrieval/agent/output_parser/__init__.py} (100%)
rename api/core/{agent => features/dataset_retrieval}/agent/output_parser/structured_chat.py (100%)
rename api/core/{agent => features/dataset_retrieval}/agent/structed_multi_dataset_router_agent.py (99%)
rename api/core/{agent/agent_executor.py => features/dataset_retrieval/agent_based_dataset_executor.py} (69%)
rename api/core/features/{ => dataset_retrieval}/dataset_retrieval.py (97%)
delete mode 100644 api/core/third_party/spark/spark_llm.py
delete mode 100644 api/core/tool/current_datetime_tool.py
delete mode 100644 api/core/tool/provider/base.py
delete mode 100644 api/core/tool/provider/errors.py
delete mode 100644 api/core/tool/provider/serpapi_provider.py
delete mode 100644 api/core/tool/provider/tool_provider_service.py
delete mode 100644 api/core/tool/serpapi_wrapper.py
delete mode 100644 api/core/tool/web_reader_tool.py
diff --git a/api/core/agent/agent/calc_token_mixin.py b/api/core/agent/agent/calc_token_mixin.py
deleted file mode 100644
index 9c0f9c5b36..0000000000
--- a/api/core/agent/agent/calc_token_mixin.py
+++ /dev/null
@@ -1,49 +0,0 @@
-from typing import cast
-
-from core.entities.application_entities import ModelConfigEntity
-from core.model_runtime.entities.message_entities import PromptMessage
-from core.model_runtime.entities.model_entities import ModelPropertyKey
-from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-
-
-class CalcTokenMixin:
-
- def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: list[PromptMessage], **kwargs) -> int:
- """
- Got the rest tokens available for the model after excluding messages tokens and completion max tokens
-
- :param model_config:
- :param messages:
- :return:
- """
- model_type_instance = model_config.provider_model_bundle.model_type_instance
- model_type_instance = cast(LargeLanguageModel, model_type_instance)
-
- model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
-
- max_tokens = 0
- for parameter_rule in model_config.model_schema.parameter_rules:
- if (parameter_rule.name == 'max_tokens'
- or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
- max_tokens = (model_config.parameters.get(parameter_rule.name)
- or model_config.parameters.get(parameter_rule.use_template)) or 0
-
- if model_context_tokens is None:
- return 0
-
- if max_tokens is None:
- max_tokens = 0
-
- prompt_tokens = model_type_instance.get_num_tokens(
- model_config.model,
- model_config.credentials,
- messages
- )
-
- rest_tokens = model_context_tokens - max_tokens - prompt_tokens
-
- return rest_tokens
-
-
-class ExceededLLMTokensLimitError(Exception):
- pass
diff --git a/api/core/agent/agent/openai_function_call.py b/api/core/agent/agent/openai_function_call.py
deleted file mode 100644
index 1f2d5f24b3..0000000000
--- a/api/core/agent/agent/openai_function_call.py
+++ /dev/null
@@ -1,361 +0,0 @@
-from collections.abc import Sequence
-from typing import Any, Optional, Union
-
-from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
-from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
-from langchain.callbacks.base import BaseCallbackManager
-from langchain.callbacks.manager import Callbacks
-from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
-from langchain.memory.prompt import SUMMARY_PROMPT
-from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import (
- AgentAction,
- AgentFinish,
- AIMessage,
- BaseMessage,
- HumanMessage,
- SystemMessage,
- get_buffer_string,
-)
-from langchain.tools import BaseTool
-from pydantic import root_validator
-
-from core.agent.agent.agent_llm_callback import AgentLLMCallback
-from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
-from core.chain.llm_chain import LLMChain
-from core.entities.application_entities import ModelConfigEntity
-from core.entities.message_entities import lc_messages_to_prompt_messages
-from core.model_manager import ModelInstance
-from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
-from core.third_party.langchain.llms.fake import FakeLLM
-
-
-class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
- moving_summary_buffer: str = ""
- moving_summary_index: int = 0
- summary_model_config: ModelConfigEntity = None
- model_config: ModelConfigEntity
- agent_llm_callback: Optional[AgentLLMCallback] = None
-
- class Config:
- """Configuration for this pydantic object."""
-
- arbitrary_types_allowed = True
-
- @root_validator
- def validate_llm(cls, values: dict) -> dict:
- return values
-
- @classmethod
- def from_llm_and_tools(
- cls,
- model_config: ModelConfigEntity,
- tools: Sequence[BaseTool],
- callback_manager: Optional[BaseCallbackManager] = None,
- extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
- system_message: Optional[SystemMessage] = SystemMessage(
- content="You are a helpful AI assistant."
- ),
- agent_llm_callback: Optional[AgentLLMCallback] = None,
- **kwargs: Any,
- ) -> BaseSingleActionAgent:
- prompt = cls.create_prompt(
- extra_prompt_messages=extra_prompt_messages,
- system_message=system_message,
- )
- return cls(
- model_config=model_config,
- llm=FakeLLM(response=''),
- prompt=prompt,
- tools=tools,
- callback_manager=callback_manager,
- agent_llm_callback=agent_llm_callback,
- **kwargs,
- )
-
- def should_use_agent(self, query: str):
- """
- return should use agent
-
- :param query:
- :return:
- """
- original_max_tokens = 0
- for parameter_rule in self.model_config.model_schema.parameter_rules:
- if (parameter_rule.name == 'max_tokens'
- or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
- original_max_tokens = (self.model_config.parameters.get(parameter_rule.name)
- or self.model_config.parameters.get(parameter_rule.use_template)) or 0
-
- self.model_config.parameters['max_tokens'] = 40
-
- prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
- messages = prompt.to_messages()
-
- try:
- prompt_messages = lc_messages_to_prompt_messages(messages)
- model_instance = ModelInstance(
- provider_model_bundle=self.model_config.provider_model_bundle,
- model=self.model_config.model,
- )
-
- tools = []
- for function in self.functions:
- tool = PromptMessageTool(
- **function
- )
-
- tools.append(tool)
-
- result = model_instance.invoke_llm(
- prompt_messages=prompt_messages,
- tools=tools,
- stream=False,
- model_parameters={
- 'temperature': 0.2,
- 'top_p': 0.3,
- 'max_tokens': 1500
- }
- )
- except Exception as e:
- raise e
-
- self.model_config.parameters['max_tokens'] = original_max_tokens
-
- return True if result.message.tool_calls else False
-
- def plan(
- self,
- intermediate_steps: list[tuple[AgentAction, str]],
- callbacks: Callbacks = None,
- **kwargs: Any,
- ) -> Union[AgentAction, AgentFinish]:
- """Given input, decided what to do.
-
- Args:
- intermediate_steps: Steps the LLM has taken to date, along with observations
- **kwargs: User inputs.
-
- Returns:
- Action specifying what tool to use.
- """
- agent_scratchpad = _format_intermediate_steps(intermediate_steps)
- selected_inputs = {
- k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
- }
- full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
- prompt = self.prompt.format_prompt(**full_inputs)
- messages = prompt.to_messages()
-
- prompt_messages = lc_messages_to_prompt_messages(messages)
-
- # summarize messages if rest_tokens < 0
- try:
- prompt_messages = self.summarize_messages_if_needed(prompt_messages, functions=self.functions)
- except ExceededLLMTokensLimitError as e:
- return AgentFinish(return_values={"output": str(e)}, log=str(e))
-
- model_instance = ModelInstance(
- provider_model_bundle=self.model_config.provider_model_bundle,
- model=self.model_config.model,
- )
-
- tools = []
- for function in self.functions:
- tool = PromptMessageTool(
- **function
- )
-
- tools.append(tool)
-
- result = model_instance.invoke_llm(
- prompt_messages=prompt_messages,
- tools=tools,
- stream=False,
- callbacks=[self.agent_llm_callback] if self.agent_llm_callback else [],
- model_parameters={
- 'temperature': 0.2,
- 'top_p': 0.3,
- 'max_tokens': 1500
- }
- )
-
- ai_message = AIMessage(
- content=result.message.content or "",
- additional_kwargs={
- 'function_call': {
- 'id': result.message.tool_calls[0].id,
- **result.message.tool_calls[0].function.dict()
- } if result.message.tool_calls else None
- }
- )
- agent_decision = _parse_ai_message(ai_message)
-
- if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
- tool_inputs = agent_decision.tool_input
- if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
- tool_inputs['query'] = kwargs['input']
- agent_decision.tool_input = tool_inputs
-
- return agent_decision
-
- @classmethod
- def get_system_message(cls):
- return SystemMessage(content="You are a helpful AI assistant.\n"
- "The current date or current time you know is wrong.\n"
- "Respond directly if appropriate.")
-
- def return_stopped_response(
- self,
- early_stopping_method: str,
- intermediate_steps: list[tuple[AgentAction, str]],
- **kwargs: Any,
- ) -> AgentFinish:
- try:
- return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
- except ValueError:
- return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
-
- def summarize_messages_if_needed(self, messages: list[PromptMessage], **kwargs) -> list[PromptMessage]:
- # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
- rest_tokens = self.get_message_rest_tokens(
- self.model_config,
- messages,
- **kwargs
- )
-
- rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
- if rest_tokens >= 0:
- return messages
-
- system_message = None
- human_message = None
- should_summary_messages = []
- for message in messages:
- if isinstance(message, SystemMessage):
- system_message = message
- elif isinstance(message, HumanMessage):
- human_message = message
- else:
- should_summary_messages.append(message)
-
- if len(should_summary_messages) > 2:
- ai_message = should_summary_messages[-2]
- function_message = should_summary_messages[-1]
- should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
- self.moving_summary_index = len(should_summary_messages)
- else:
- error_msg = "Exceeded LLM tokens limit, stopped."
- raise ExceededLLMTokensLimitError(error_msg)
-
- new_messages = [system_message, human_message]
-
- if self.moving_summary_index == 0:
- should_summary_messages.insert(0, human_message)
-
- self.moving_summary_buffer = self.predict_new_summary(
- messages=should_summary_messages,
- existing_summary=self.moving_summary_buffer
- )
-
- new_messages.append(AIMessage(content=self.moving_summary_buffer))
- new_messages.append(ai_message)
- new_messages.append(function_message)
-
- return new_messages
-
- def predict_new_summary(
- self, messages: list[BaseMessage], existing_summary: str
- ) -> str:
- new_lines = get_buffer_string(
- messages,
- human_prefix="Human",
- ai_prefix="AI",
- )
-
- chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
- return chain.predict(summary=existing_summary, new_lines=new_lines)
-
- def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: list[BaseMessage], **kwargs) -> int:
- """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
-
- Official documentation: https://github.com/openai/openai-cookbook/blob/
- main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
- if model_config.provider == 'azure_openai':
- model = model_config.model
- model = model.replace("gpt-35", "gpt-3.5")
- else:
- model = model_config.credentials.get("base_model_name")
-
- tiktoken_ = _import_tiktoken()
- try:
- encoding = tiktoken_.encoding_for_model(model)
- except KeyError:
- model = "cl100k_base"
- encoding = tiktoken_.get_encoding(model)
-
- if model.startswith("gpt-3.5-turbo"):
- # every message follows {role/name}\n{content}\n
- tokens_per_message = 4
- # if there's a name, the role is omitted
- tokens_per_name = -1
- elif model.startswith("gpt-4"):
- tokens_per_message = 3
- tokens_per_name = 1
- else:
- raise NotImplementedError(
- f"get_num_tokens_from_messages() is not presently implemented "
- f"for model {model}."
- "See https://github.com/openai/openai-python/blob/main/chatml.md for "
- "information on how messages are converted to tokens."
- )
- num_tokens = 0
- for m in messages:
- message = _convert_message_to_dict(m)
- num_tokens += tokens_per_message
- for key, value in message.items():
- if key == "function_call":
- for f_key, f_value in value.items():
- num_tokens += len(encoding.encode(f_key))
- num_tokens += len(encoding.encode(f_value))
- else:
- num_tokens += len(encoding.encode(value))
-
- if key == "name":
- num_tokens += tokens_per_name
- # every reply is primed with assistant
- num_tokens += 3
-
- if kwargs.get('functions'):
- for function in kwargs.get('functions'):
- num_tokens += len(encoding.encode('name'))
- num_tokens += len(encoding.encode(function.get("name")))
- num_tokens += len(encoding.encode('description'))
- num_tokens += len(encoding.encode(function.get("description")))
- parameters = function.get("parameters")
- num_tokens += len(encoding.encode('parameters'))
- if 'title' in parameters:
- num_tokens += len(encoding.encode('title'))
- num_tokens += len(encoding.encode(parameters.get("title")))
- num_tokens += len(encoding.encode('type'))
- num_tokens += len(encoding.encode(parameters.get("type")))
- if 'properties' in parameters:
- num_tokens += len(encoding.encode('properties'))
- for key, value in parameters.get('properties').items():
- num_tokens += len(encoding.encode(key))
- for field_key, field_value in value.items():
- num_tokens += len(encoding.encode(field_key))
- if field_key == 'enum':
- for enum_field in field_value:
- num_tokens += 3
- num_tokens += len(encoding.encode(enum_field))
- else:
- num_tokens += len(encoding.encode(field_key))
- num_tokens += len(encoding.encode(str(field_value)))
- if 'required' in parameters:
- num_tokens += len(encoding.encode('required'))
- for required_field in parameters['required']:
- num_tokens += 3
- num_tokens += len(encoding.encode(required_field))
-
- return num_tokens
diff --git a/api/core/agent/agent/structured_chat.py b/api/core/agent/agent/structured_chat.py
deleted file mode 100644
index e1be624204..0000000000
--- a/api/core/agent/agent/structured_chat.py
+++ /dev/null
@@ -1,306 +0,0 @@
-import re
-from collections.abc import Sequence
-from typing import Any, Optional, Union, cast
-
-from langchain import BasePromptTemplate, PromptTemplate
-from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
-from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
-from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
-from langchain.callbacks.base import BaseCallbackManager
-from langchain.callbacks.manager import Callbacks
-from langchain.memory.prompt import SUMMARY_PROMPT
-from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
-from langchain.schema import (
- AgentAction,
- AgentFinish,
- AIMessage,
- BaseMessage,
- HumanMessage,
- OutputParserException,
- get_buffer_string,
-)
-from langchain.tools import BaseTool
-
-from core.agent.agent.agent_llm_callback import AgentLLMCallback
-from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
-from core.chain.llm_chain import LLMChain
-from core.entities.application_entities import ModelConfigEntity
-from core.entities.message_entities import lc_messages_to_prompt_messages
-
-FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
-The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
-Valid "action" values: "Final Answer" or {tool_names}
-
-Provide only ONE action per $JSON_BLOB, as shown:
-
-```
-{{{{
- "action": $TOOL_NAME,
- "action_input": $INPUT
-}}}}
-```
-
-Follow this format:
-
-Question: input question to answer
-Thought: consider previous and subsequent steps
-Action:
-```
-$JSON_BLOB
-```
-Observation: action result
-... (repeat Thought/Action/Observation N times)
-Thought: I know what to respond
-Action:
-```
-{{{{
- "action": "Final Answer",
- "action_input": "Final response to human"
-}}}}
-```"""
-
-
-class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
- moving_summary_buffer: str = ""
- moving_summary_index: int = 0
- summary_model_config: ModelConfigEntity = None
-
- class Config:
- """Configuration for this pydantic object."""
-
- arbitrary_types_allowed = True
-
- def should_use_agent(self, query: str):
- """
- return should use agent
- Using the ReACT mode to determine whether an agent is needed is costly,
- so it's better to just use an Agent for reasoning, which is cheaper.
-
- :param query:
- :return:
- """
- return True
-
- def plan(
- self,
- intermediate_steps: list[tuple[AgentAction, str]],
- callbacks: Callbacks = None,
- **kwargs: Any,
- ) -> Union[AgentAction, AgentFinish]:
- """Given input, decided what to do.
-
- Args:
- intermediate_steps: Steps the LLM has taken to date,
- along with observatons
- callbacks: Callbacks to run.
- **kwargs: User inputs.
-
- Returns:
- Action specifying what tool to use.
- """
- full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
- prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
-
- messages = []
- if prompts:
- messages = prompts[0].to_messages()
-
- prompt_messages = lc_messages_to_prompt_messages(messages)
-
- rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_config, prompt_messages)
- if rest_tokens < 0:
- full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
-
- try:
- full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
- except Exception as e:
- raise e
-
- try:
- agent_decision = self.output_parser.parse(full_output)
- if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
- tool_inputs = agent_decision.tool_input
- if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
- tool_inputs['query'] = kwargs['input']
- agent_decision.tool_input = tool_inputs
- return agent_decision
- except OutputParserException:
- return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
- "I don't know how to respond to that."}, "")
-
- def summarize_messages(self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs):
- if len(intermediate_steps) >= 2 and self.summary_model_config:
- should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
- should_summary_messages = [AIMessage(content=observation)
- for _, observation in should_summary_intermediate_steps]
- if self.moving_summary_index == 0:
- should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input")))
-
- self.moving_summary_index = len(intermediate_steps)
- else:
- error_msg = "Exceeded LLM tokens limit, stopped."
- raise ExceededLLMTokensLimitError(error_msg)
-
- if self.moving_summary_buffer and 'chat_history' in kwargs:
- kwargs["chat_history"].pop()
-
- self.moving_summary_buffer = self.predict_new_summary(
- messages=should_summary_messages,
- existing_summary=self.moving_summary_buffer
- )
-
- if 'chat_history' in kwargs:
- kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer))
-
- return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
-
- def predict_new_summary(
- self, messages: list[BaseMessage], existing_summary: str
- ) -> str:
- new_lines = get_buffer_string(
- messages,
- human_prefix="Human",
- ai_prefix="AI",
- )
-
- chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
- return chain.predict(summary=existing_summary, new_lines=new_lines)
-
- @classmethod
- def create_prompt(
- cls,
- tools: Sequence[BaseTool],
- prefix: str = PREFIX,
- suffix: str = SUFFIX,
- human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
- format_instructions: str = FORMAT_INSTRUCTIONS,
- input_variables: Optional[list[str]] = None,
- memory_prompts: Optional[list[BasePromptTemplate]] = None,
- ) -> BasePromptTemplate:
- tool_strings = []
- for tool in tools:
- args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
- tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
- formatted_tools = "\n".join(tool_strings)
- tool_names = ", ".join([('"' + tool.name + '"') for tool in tools])
- format_instructions = format_instructions.format(tool_names=tool_names)
- template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
- if input_variables is None:
- input_variables = ["input", "agent_scratchpad"]
- _memory_prompts = memory_prompts or []
- messages = [
- SystemMessagePromptTemplate.from_template(template),
- *_memory_prompts,
- HumanMessagePromptTemplate.from_template(human_message_template),
- ]
- return ChatPromptTemplate(input_variables=input_variables, messages=messages)
-
- @classmethod
- def create_completion_prompt(
- cls,
- tools: Sequence[BaseTool],
- prefix: str = PREFIX,
- format_instructions: str = FORMAT_INSTRUCTIONS,
- input_variables: Optional[list[str]] = None,
- ) -> PromptTemplate:
- """Create prompt in the style of the zero shot agent.
-
- Args:
- tools: List of tools the agent will have access to, used to format the
- prompt.
- prefix: String to put before the list of tools.
- input_variables: List of input variables the final prompt will expect.
-
- Returns:
- A PromptTemplate with the template assembled from the pieces here.
- """
- suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
-Question: {input}
-Thought: {agent_scratchpad}
-"""
-
- tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
- tool_names = ", ".join([tool.name for tool in tools])
- format_instructions = format_instructions.format(tool_names=tool_names)
- template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
- if input_variables is None:
- input_variables = ["input", "agent_scratchpad"]
- return PromptTemplate(template=template, input_variables=input_variables)
-
- def _construct_scratchpad(
- self, intermediate_steps: list[tuple[AgentAction, str]]
- ) -> str:
- agent_scratchpad = ""
- for action, observation in intermediate_steps:
- agent_scratchpad += action.log
- agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
-
- if not isinstance(agent_scratchpad, str):
- raise ValueError("agent_scratchpad should be of type string.")
- if agent_scratchpad:
- llm_chain = cast(LLMChain, self.llm_chain)
- if llm_chain.model_config.mode == "chat":
- return (
- f"This was your previous work "
- f"(but I haven't seen any of it! I only see what "
- f"you return as final answer):\n{agent_scratchpad}"
- )
- else:
- return agent_scratchpad
- else:
- return agent_scratchpad
-
- @classmethod
- def from_llm_and_tools(
- cls,
- model_config: ModelConfigEntity,
- tools: Sequence[BaseTool],
- callback_manager: Optional[BaseCallbackManager] = None,
- output_parser: Optional[AgentOutputParser] = None,
- prefix: str = PREFIX,
- suffix: str = SUFFIX,
- human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
- format_instructions: str = FORMAT_INSTRUCTIONS,
- input_variables: Optional[list[str]] = None,
- memory_prompts: Optional[list[BasePromptTemplate]] = None,
- agent_llm_callback: Optional[AgentLLMCallback] = None,
- **kwargs: Any,
- ) -> Agent:
- """Construct an agent from an LLM and tools."""
- cls._validate_tools(tools)
- if model_config.mode == "chat":
- prompt = cls.create_prompt(
- tools,
- prefix=prefix,
- suffix=suffix,
- human_message_template=human_message_template,
- format_instructions=format_instructions,
- input_variables=input_variables,
- memory_prompts=memory_prompts,
- )
- else:
- prompt = cls.create_completion_prompt(
- tools,
- prefix=prefix,
- format_instructions=format_instructions,
- input_variables=input_variables,
- )
- llm_chain = LLMChain(
- model_config=model_config,
- prompt=prompt,
- callback_manager=callback_manager,
- agent_llm_callback=agent_llm_callback,
- parameters={
- 'temperature': 0.2,
- 'top_p': 0.3,
- 'max_tokens': 1500
- }
- )
- tool_names = [tool.name for tool in tools]
- _output_parser = output_parser
- return cls(
- llm_chain=llm_chain,
- allowed_tools=tool_names,
- output_parser=_output_parser,
- **kwargs,
- )
diff --git a/api/core/app_runner/assistant_app_runner.py b/api/core/app_runner/assistant_app_runner.py
index a4845d0ff1..d9a3447bda 100644
--- a/api/core/app_runner/assistant_app_runner.py
+++ b/api/core/app_runner/assistant_app_runner.py
@@ -1,4 +1,3 @@
-import json
import logging
from typing import cast
@@ -15,7 +14,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
from core.moderation.base import ModerationException
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
from extensions.ext_database import db
-from models.model import App, Conversation, Message, MessageAgentThought, MessageChain
+from models.model import App, Conversation, Message, MessageAgentThought
from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__)
@@ -173,11 +172,6 @@ class AssistantApplicationRunner(AppRunner):
# convert db variables to tool variables
tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
-
- message_chain = self._init_message_chain(
- message=message,
- query=query
- )
# init model instance
model_instance = ModelInstance(
@@ -290,38 +284,6 @@ class AssistantApplicationRunner(AppRunner):
'pool': db_variables.variables
})
- def _init_message_chain(self, message: Message, query: str) -> MessageChain:
- """
- Init MessageChain
- :param message: message
- :param query: query
- :return:
- """
- message_chain = MessageChain(
- message_id=message.id,
- type="AgentExecutor",
- input=json.dumps({
- "input": query
- })
- )
-
- db.session.add(message_chain)
- db.session.commit()
-
- return message_chain
-
- def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
- """
- Save MessageChain
- :param message_chain: message chain
- :param output_text: output text
- :return:
- """
- message_chain.output = json.dumps({
- "output": output_text
- })
- db.session.commit()
-
def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
message: Message) -> LLMUsage:
"""
diff --git a/api/core/app_runner/basic_app_runner.py b/api/core/app_runner/basic_app_runner.py
index e1972efb51..99df249ddf 100644
--- a/api/core/app_runner/basic_app_runner.py
+++ b/api/core/app_runner/basic_app_runner.py
@@ -5,7 +5,7 @@ from core.app_runner.app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity
-from core.features.dataset_retrieval import DatasetRetrievalFeature
+from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.moderation.base import ModerationException
diff --git a/api/core/entities/agent_entities.py b/api/core/entities/agent_entities.py
new file mode 100644
index 0000000000..0cdf8670c4
--- /dev/null
+++ b/api/core/entities/agent_entities.py
@@ -0,0 +1,8 @@
+from enum import Enum
+
+
+class PlanningStrategy(Enum):
+ ROUTER = 'router'
+ REACT_ROUTER = 'react_router'
+ REACT = 'react'
+ FUNCTION_CALL = 'function_call'
diff --git a/api/core/features/agent_runner.py b/api/core/features/agent_runner.py
deleted file mode 100644
index 7412d81281..0000000000
--- a/api/core/features/agent_runner.py
+++ /dev/null
@@ -1,199 +0,0 @@
-import logging
-from typing import Optional, cast
-
-from langchain.tools import BaseTool
-
-from core.agent.agent.agent_llm_callback import AgentLLMCallback
-from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
-from core.application_queue_manager import ApplicationQueueManager
-from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
-from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
-from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
-from core.entities.application_entities import (
- AgentEntity,
- AppOrchestrationConfigEntity,
- InvokeFrom,
- ModelConfigEntity,
-)
-from core.memory.token_buffer_memory import TokenBufferMemory
-from core.model_runtime.entities.model_entities import ModelFeature, ModelType
-from core.model_runtime.model_providers import model_provider_factory
-from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
-from extensions.ext_database import db
-from models.dataset import Dataset
-from models.model import Message
-
-logger = logging.getLogger(__name__)
-
-
-class AgentRunnerFeature:
- def __init__(self, tenant_id: str,
- app_orchestration_config: AppOrchestrationConfigEntity,
- model_config: ModelConfigEntity,
- config: AgentEntity,
- queue_manager: ApplicationQueueManager,
- message: Message,
- user_id: str,
- agent_llm_callback: AgentLLMCallback,
- callback: AgentLoopGatherCallbackHandler,
- memory: Optional[TokenBufferMemory] = None,) -> None:
- """
- Agent runner
- :param tenant_id: tenant id
- :param app_orchestration_config: app orchestration config
- :param model_config: model config
- :param config: dataset config
- :param queue_manager: queue manager
- :param message: message
- :param user_id: user id
- :param agent_llm_callback: agent llm callback
- :param callback: callback
- :param memory: memory
- """
- self.tenant_id = tenant_id
- self.app_orchestration_config = app_orchestration_config
- self.model_config = model_config
- self.config = config
- self.queue_manager = queue_manager
- self.message = message
- self.user_id = user_id
- self.agent_llm_callback = agent_llm_callback
- self.callback = callback
- self.memory = memory
-
- def run(self, query: str,
- invoke_from: InvokeFrom) -> Optional[str]:
- """
- Retrieve agent loop result.
- :param query: query
- :param invoke_from: invoke from
- :return:
- """
- provider = self.config.provider
- model = self.config.model
- tool_configs = self.config.tools
-
- # check model is support tool calling
- provider_instance = model_provider_factory.get_provider_instance(provider=provider)
- model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
- model_type_instance = cast(LargeLanguageModel, model_type_instance)
-
- # get model schema
- model_schema = model_type_instance.get_model_schema(
- model=model,
- credentials=self.model_config.credentials
- )
-
- if not model_schema:
- return None
-
- planning_strategy = PlanningStrategy.REACT
- features = model_schema.features
- if features:
- if ModelFeature.TOOL_CALL in features \
- or ModelFeature.MULTI_TOOL_CALL in features:
- planning_strategy = PlanningStrategy.FUNCTION_CALL
-
- tools = self.to_tools(
- tool_configs=tool_configs,
- invoke_from=invoke_from,
- callbacks=[self.callback, DifyStdOutCallbackHandler()],
- )
-
- if len(tools) == 0:
- return None
-
- agent_configuration = AgentConfiguration(
- strategy=planning_strategy,
- model_config=self.model_config,
- tools=tools,
- memory=self.memory,
- max_iterations=10,
- max_execution_time=400.0,
- early_stopping_method="generate",
- agent_llm_callback=self.agent_llm_callback,
- callbacks=[self.callback, DifyStdOutCallbackHandler()]
- )
-
- agent_executor = AgentExecutor(agent_configuration)
-
- try:
- # check if should use agent
- should_use_agent = agent_executor.should_use_agent(query)
- if not should_use_agent:
- return None
-
- result = agent_executor.run(query)
- return result.output
- except Exception as ex:
- logger.exception("agent_executor run failed")
- return None
-
- def to_dataset_retriever_tool(self, tool_config: dict,
- invoke_from: InvokeFrom) \
- -> Optional[BaseTool]:
- """
- A dataset tool is a tool that can be used to retrieve information from a dataset
- :param tool_config: tool config
- :param invoke_from: invoke from
- """
- show_retrieve_source = self.app_orchestration_config.show_retrieve_source
-
- hit_callback = DatasetIndexToolCallbackHandler(
- queue_manager=self.queue_manager,
- app_id=self.message.app_id,
- message_id=self.message.id,
- user_id=self.user_id,
- invoke_from=invoke_from
- )
-
- # get dataset from dataset id
- dataset = db.session.query(Dataset).filter(
- Dataset.tenant_id == self.tenant_id,
- Dataset.id == tool_config.get("id")
- ).first()
-
- # pass if dataset is not available
- if not dataset:
- return None
-
- # pass if dataset is not available
- if (dataset and dataset.available_document_count == 0
- and dataset.available_document_count == 0):
- return None
-
- # get retrieval model config
- default_retrieval_model = {
- 'search_method': 'semantic_search',
- 'reranking_enable': False,
- 'reranking_model': {
- 'reranking_provider_name': '',
- 'reranking_model_name': ''
- },
- 'top_k': 2,
- 'score_threshold_enabled': False
- }
-
- retrieval_model_config = dataset.retrieval_model \
- if dataset.retrieval_model else default_retrieval_model
-
- # get top k
- top_k = retrieval_model_config['top_k']
-
- # get score threshold
- score_threshold = None
- score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
- if score_threshold_enabled:
- score_threshold = retrieval_model_config.get("score_threshold")
-
- tool = DatasetRetrieverTool.from_dataset(
- dataset=dataset,
- top_k=top_k,
- score_threshold=score_threshold,
- hit_callbacks=[hit_callback],
- return_resource=show_retrieve_source,
- retriever_from=invoke_from.to_source()
- )
-
- return tool
\ No newline at end of file
diff --git a/api/core/third_party/langchain/llms/__init__.py b/api/core/features/dataset_retrieval/__init__.py
similarity index 100%
rename from api/core/third_party/langchain/llms/__init__.py
rename to api/core/features/dataset_retrieval/__init__.py
diff --git a/api/core/third_party/spark/__init__.py b/api/core/features/dataset_retrieval/agent/__init__.py
similarity index 100%
rename from api/core/third_party/spark/__init__.py
rename to api/core/features/dataset_retrieval/agent/__init__.py
diff --git a/api/core/agent/agent/agent_llm_callback.py b/api/core/features/dataset_retrieval/agent/agent_llm_callback.py
similarity index 100%
rename from api/core/agent/agent/agent_llm_callback.py
rename to api/core/features/dataset_retrieval/agent/agent_llm_callback.py
diff --git a/api/core/third_party/langchain/llms/fake.py b/api/core/features/dataset_retrieval/agent/fake_llm.py
similarity index 100%
rename from api/core/third_party/langchain/llms/fake.py
rename to api/core/features/dataset_retrieval/agent/fake_llm.py
diff --git a/api/core/chain/llm_chain.py b/api/core/features/dataset_retrieval/agent/llm_chain.py
similarity index 91%
rename from api/core/chain/llm_chain.py
rename to api/core/features/dataset_retrieval/agent/llm_chain.py
index 86fb156292..e5155e15a0 100644
--- a/api/core/chain/llm_chain.py
+++ b/api/core/features/dataset_retrieval/agent/llm_chain.py
@@ -5,11 +5,11 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import Generation, LLMResult
from langchain.schema.language_model import BaseLanguageModel
-from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
+from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
+from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
from core.model_manager import ModelInstance
-from core.third_party.langchain.llms.fake import FakeLLM
class LLMChain(LCLLMChain):
diff --git a/api/core/agent/agent/multi_dataset_router_agent.py b/api/core/features/dataset_retrieval/agent/multi_dataset_router_agent.py
similarity index 98%
rename from api/core/agent/agent/multi_dataset_router_agent.py
rename to api/core/features/dataset_retrieval/agent/multi_dataset_router_agent.py
index eb594c3d21..59923202fd 100644
--- a/api/core/agent/agent/multi_dataset_router_agent.py
+++ b/api/core/features/dataset_retrieval/agent/multi_dataset_router_agent.py
@@ -12,9 +12,9 @@ from pydantic import root_validator
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
+from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessageTool
-from core.third_party.langchain.llms.fake import FakeLLM
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
diff --git a/api/core/data_loader/file_extractor.py b/api/core/features/dataset_retrieval/agent/output_parser/__init__.py
similarity index 100%
rename from api/core/data_loader/file_extractor.py
rename to api/core/features/dataset_retrieval/agent/output_parser/__init__.py
diff --git a/api/core/agent/agent/output_parser/structured_chat.py b/api/core/features/dataset_retrieval/agent/output_parser/structured_chat.py
similarity index 100%
rename from api/core/agent/agent/output_parser/structured_chat.py
rename to api/core/features/dataset_retrieval/agent/output_parser/structured_chat.py
diff --git a/api/core/agent/agent/structed_multi_dataset_router_agent.py b/api/core/features/dataset_retrieval/agent/structed_multi_dataset_router_agent.py
similarity index 99%
rename from api/core/agent/agent/structed_multi_dataset_router_agent.py
rename to api/core/features/dataset_retrieval/agent/structed_multi_dataset_router_agent.py
index e104bb01f9..e69302bfd6 100644
--- a/api/core/agent/agent/structed_multi_dataset_router_agent.py
+++ b/api/core/features/dataset_retrieval/agent/structed_multi_dataset_router_agent.py
@@ -12,8 +12,8 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, Sy
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool
-from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
+from core.features.dataset_retrieval.agent.llm_chain import LLMChain
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
diff --git a/api/core/agent/agent_executor.py b/api/core/features/dataset_retrieval/agent_based_dataset_executor.py
similarity index 69%
rename from api/core/agent/agent_executor.py
rename to api/core/features/dataset_retrieval/agent_based_dataset_executor.py
index 70fe00ee13..588ccc91f5 100644
--- a/api/core/agent/agent_executor.py
+++ b/api/core/features/dataset_retrieval/agent_based_dataset_executor.py
@@ -1,4 +1,3 @@
-import enum
import logging
from typing import Optional, Union
@@ -8,14 +7,13 @@ from langchain.callbacks.manager import Callbacks
from langchain.tools import BaseTool
from pydantic import BaseModel, Extra
-from core.agent.agent.agent_llm_callback import AgentLLMCallback
-from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
-from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
-from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
-from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
-from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
+from core.entities.agent_entities import PlanningStrategy
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import prompt_messages_to_lc_messages
+from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
+from core.features.dataset_retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
+from core.features.dataset_retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
+from core.features.dataset_retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.helper import moderation
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.errors.invoke import InvokeError
@@ -23,13 +21,6 @@ from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import Datas
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
-class PlanningStrategy(str, enum.Enum):
- ROUTER = 'router'
- REACT_ROUTER = 'react_router'
- REACT = 'react'
- FUNCTION_CALL = 'function_call'
-
-
class AgentConfiguration(BaseModel):
strategy: PlanningStrategy
model_config: ModelConfigEntity
@@ -62,28 +53,7 @@ class AgentExecutor:
self.agent = self._init_agent()
def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
- if self.configuration.strategy == PlanningStrategy.REACT:
- agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
- model_config=self.configuration.model_config,
- tools=self.configuration.tools,
- output_parser=StructuredChatOutputParser(),
- summary_model_config=self.configuration.summary_model_config
- if self.configuration.summary_model_config else None,
- agent_llm_callback=self.configuration.agent_llm_callback,
- verbose=True
- )
- elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
- agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
- model_config=self.configuration.model_config,
- tools=self.configuration.tools,
- extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
- if self.configuration.memory else None, # used for read chat histories memory
- summary_model_config=self.configuration.summary_model_config
- if self.configuration.summary_model_config else None,
- agent_llm_callback=self.configuration.agent_llm_callback,
- verbose=True
- )
- elif self.configuration.strategy == PlanningStrategy.ROUTER:
+ if self.configuration.strategy == PlanningStrategy.ROUTER:
self.configuration.tools = [t for t in self.configuration.tools
if isinstance(t, DatasetRetrieverTool)
or isinstance(t, DatasetMultiRetrieverTool)]
diff --git a/api/core/features/dataset_retrieval.py b/api/core/features/dataset_retrieval/dataset_retrieval.py
similarity index 97%
rename from api/core/features/dataset_retrieval.py
rename to api/core/features/dataset_retrieval/dataset_retrieval.py
index 488a8ca8d0..3e54d8644d 100644
--- a/api/core/features/dataset_retrieval.py
+++ b/api/core/features/dataset_retrieval/dataset_retrieval.py
@@ -2,9 +2,10 @@ from typing import Optional, cast
from langchain.tools import BaseTool
-from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+from core.entities.agent_entities import PlanningStrategy
from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity
+from core.features.dataset_retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
diff --git a/api/core/third_party/spark/spark_llm.py b/api/core/third_party/spark/spark_llm.py
deleted file mode 100644
index 5c97bba530..0000000000
--- a/api/core/third_party/spark/spark_llm.py
+++ /dev/null
@@ -1,189 +0,0 @@
-import base64
-import hashlib
-import hmac
-import json
-import queue
-import ssl
-from datetime import datetime
-from time import mktime
-from typing import Optional
-from urllib.parse import urlencode, urlparse
-from wsgiref.handlers import format_date_time
-
-import websocket
-
-
-class SparkLLMClient:
- def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
- domain = 'spark-api.xf-yun.com'
- endpoint = 'chat'
- if api_domain:
- domain = api_domain
- if model_name == 'spark-v3':
- endpoint = 'multimodal'
-
- model_api_configs = {
- 'spark': {
- 'version': 'v1.1',
- 'chat_domain': 'general'
- },
- 'spark-v2': {
- 'version': 'v2.1',
- 'chat_domain': 'generalv2'
- },
- 'spark-v3': {
- 'version': 'v3.1',
- 'chat_domain': 'generalv3'
- },
- 'spark-v3.5': {
- 'version': 'v3.5',
- 'chat_domain': 'generalv3.5'
- }
- }
-
- api_version = model_api_configs[model_name]['version']
-
- self.chat_domain = model_api_configs[model_name]['chat_domain']
- self.api_base = f"wss://{domain}/{api_version}/{endpoint}"
- self.app_id = app_id
- self.ws_url = self.create_url(
- urlparse(self.api_base).netloc,
- urlparse(self.api_base).path,
- self.api_base,
- api_key,
- api_secret
- )
-
- self.queue = queue.Queue()
- self.blocking_message = ''
-
- def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str:
- # generate timestamp by RFC1123
- now = datetime.now()
- date = format_date_time(mktime(now.timetuple()))
-
- signature_origin = "host: " + host + "\n"
- signature_origin += "date: " + date + "\n"
- signature_origin += "GET " + path + " HTTP/1.1"
-
- # encrypt using hmac-sha256
- signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
- digestmod=hashlib.sha256).digest()
-
- signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
-
- authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
-
- authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
-
- v = {
- "authorization": authorization,
- "date": date,
- "host": host
- }
- # generate url
- url = api_base + '?' + urlencode(v)
- return url
-
- def run(self, messages: list, user_id: str,
- model_kwargs: Optional[dict] = None, streaming: bool = False):
- websocket.enableTrace(False)
- ws = websocket.WebSocketApp(
- self.ws_url,
- on_message=self.on_message,
- on_error=self.on_error,
- on_close=self.on_close,
- on_open=self.on_open
- )
- ws.messages = messages
- ws.user_id = user_id
- ws.model_kwargs = model_kwargs
- ws.streaming = streaming
- ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
-
- def on_error(self, ws, error):
- self.queue.put({
- 'status_code': error.status_code,
- 'error': error.resp_body.decode('utf-8')
- })
- ws.close()
-
- def on_close(self, ws, close_status_code, close_reason):
- self.queue.put({'done': True})
-
- def on_open(self, ws):
- self.blocking_message = ''
- data = json.dumps(self.gen_params(
- messages=ws.messages,
- user_id=ws.user_id,
- model_kwargs=ws.model_kwargs
- ))
- ws.send(data)
-
- def on_message(self, ws, message):
- data = json.loads(message)
- code = data['header']['code']
- if code != 0:
- self.queue.put({
- 'status_code': 400,
- 'error': f"Code: {code}, Error: {data['header']['message']}"
- })
- ws.close()
- else:
- choices = data["payload"]["choices"]
- status = choices["status"]
- content = choices["text"][0]["content"]
- if ws.streaming:
- self.queue.put({'data': content})
- else:
- self.blocking_message += content
-
- if status == 2:
- if not ws.streaming:
- self.queue.put({'data': self.blocking_message})
- ws.close()
-
- def gen_params(self, messages: list, user_id: str,
- model_kwargs: Optional[dict] = None) -> dict:
- data = {
- "header": {
- "app_id": self.app_id,
- "uid": user_id
- },
- "parameter": {
- "chat": {
- "domain": self.chat_domain
- }
- },
- "payload": {
- "message": {
- "text": messages
- }
- }
- }
-
- if model_kwargs:
- data['parameter']['chat'].update(model_kwargs)
-
- return data
-
- def subscribe(self):
- while True:
- content = self.queue.get()
- if 'error' in content:
- if content['status_code'] == 401:
- raise SparkError('[Spark] The credentials you provided are incorrect. '
- 'Please double-check and fill them in again.')
- elif content['status_code'] == 403:
- raise SparkError("[Spark] Sorry, the credentials you provided are access denied. "
- "Please try again after obtaining the necessary permissions.")
- else:
- raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}")
-
- if 'data' not in content:
- break
- yield content
-
-
-class SparkError(Exception):
- pass
diff --git a/api/core/tool/current_datetime_tool.py b/api/core/tool/current_datetime_tool.py
deleted file mode 100644
index 208490a5bf..0000000000
--- a/api/core/tool/current_datetime_tool.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from datetime import datetime
-
-from langchain.tools import BaseTool
-from pydantic import BaseModel, Field
-
-
-class DatetimeToolInput(BaseModel):
- type: str = Field(..., description="Type for current time, must be: datetime.")
-
-
-class DatetimeTool(BaseTool):
- """Tool for querying current datetime."""
- name: str = "current_datetime"
- args_schema: type[BaseModel] = DatetimeToolInput
- description: str = "A tool when you want to get the current date, time, week, month or year, " \
- "and the time zone is UTC. Result is \"