From c7ee0f2a93e9699c4071de67f1b58877f2eb6179 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 25 Jun 2025 10:18:20 +0800 Subject: [PATCH 01/12] fix(knowledge_base): Unchecked metadata name length (#21454) Signed-off-by: -LAN- --- api/services/metadata_service.py | 8 ++++++++ .../datasets/metadata/hooks/use-check-metadata-name.ts | 6 ++++++ web/i18n/en-US/dataset.ts | 1 + 3 files changed, 15 insertions(+) diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py index 26d6d4ce18..cfcb121153 100644 --- a/api/services/metadata_service.py +++ b/api/services/metadata_service.py @@ -19,6 +19,10 @@ from services.entities.knowledge_entities.knowledge_entities import ( class MetadataService: @staticmethod def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata: + # check if metadata name is too long + if len(metadata_args.name) > 255: + raise ValueError("Metadata name cannot exceed 255 characters.") + # check if metadata name already exists if ( db.session.query(DatasetMetadata) @@ -42,6 +46,10 @@ class MetadataService: @staticmethod def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata: # type: ignore + # check if metadata name is too long + if len(name) > 255: + raise ValueError("Metadata name cannot exceed 255 characters.") + lock_key = f"dataset_metadata_lock_{dataset_id}" # check if metadata name already exists if ( diff --git a/web/app/components/datasets/metadata/hooks/use-check-metadata-name.ts b/web/app/components/datasets/metadata/hooks/use-check-metadata-name.ts index 45fe7675f1..aea7efd5c8 100644 --- a/web/app/components/datasets/metadata/hooks/use-check-metadata-name.ts +++ b/web/app/components/datasets/metadata/hooks/use-check-metadata-name.ts @@ -18,6 +18,12 @@ const useCheckMetadataName = () => { } } + if (name.length > 255) { + return { + errorMsg: t(`${i18nPrefix}.tooLong`, { max: 255 }), + } + } + return { errorMsg: '', } diff --git a/web/i18n/en-US/dataset.ts b/web/i18n/en-US/dataset.ts index 3e251d1bf1..e5d3f5ff3e 100644 --- a/web/i18n/en-US/dataset.ts +++ b/web/i18n/en-US/dataset.ts @@ -183,6 +183,7 @@ const translation = { checkName: { empty: 'Metadata name cannot be empty', invalid: 'Metadata name can only contain lowercase letters, numbers, and underscores and must start with a lowercase letter', + tooLong: 'Metadata name cannot exceed {{max}} characters', }, batchEditMetadata: { editMetadata: 'Edit Metadata', From 5b33d086c65389c2bbddfa5c2f7312d9e07e2ac8 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 25 Jun 2025 10:26:21 +0800 Subject: [PATCH 02/12] chore: translate i18n files (#21459) Co-authored-by: laipz8200 <16485841+laipz8200@users.noreply.github.com> --- web/i18n/de-DE/dataset.ts | 1 + web/i18n/es-ES/dataset.ts | 1 + web/i18n/fa-IR/dataset.ts | 1 + web/i18n/fr-FR/dataset.ts | 1 + web/i18n/hi-IN/dataset.ts | 1 + web/i18n/it-IT/dataset.ts | 1 + web/i18n/ja-JP/dataset.ts | 1 + web/i18n/ko-KR/dataset.ts | 1 + web/i18n/pl-PL/dataset.ts | 1 + web/i18n/pt-BR/dataset.ts | 1 + web/i18n/ro-RO/dataset.ts | 1 + web/i18n/ru-RU/dataset.ts | 1 + web/i18n/sl-SI/dataset.ts | 1 + web/i18n/th-TH/dataset.ts | 1 + web/i18n/tr-TR/dataset.ts | 1 + web/i18n/uk-UA/dataset.ts | 1 + web/i18n/vi-VN/dataset.ts | 1 + web/i18n/zh-Hans/dataset.ts | 1 + web/i18n/zh-Hant/dataset.ts | 1 + 19 files changed, 19 insertions(+) diff --git a/web/i18n/de-DE/dataset.ts b/web/i18n/de-DE/dataset.ts index 4d513aab1b..3d2d924bac 100644 --- a/web/i18n/de-DE/dataset.ts +++ b/web/i18n/de-DE/dataset.ts @@ -179,6 +179,7 @@ const translation = { checkName: { empty: 'Der Metadatenname darf nicht leer sein.', invalid: 'Der Metadatenname darf nur Kleinbuchstaben, Zahlen und Unterstriche enthalten und muss mit einem Kleinbuchstaben beginnen.', + tooLong: 'Der Metadatenname darf {{max}} Zeichen nicht überschreiten.', }, batchEditMetadata: { editMetadata: 'Metadaten bearbeiten', diff --git a/web/i18n/es-ES/dataset.ts b/web/i18n/es-ES/dataset.ts index 69c395287d..16745b56d7 100644 --- a/web/i18n/es-ES/dataset.ts +++ b/web/i18n/es-ES/dataset.ts @@ -179,6 +179,7 @@ const translation = { checkName: { empty: 'El nombre de metadatos no puede estar vacío', invalid: 'El nombre de los metadatos solo puede contener letras minúsculas, números y guiones bajos, y debe comenzar con una letra minúscula.', + tooLong: 'El nombre de los metadatos no puede exceder {{max}} caracteres.', }, batchEditMetadata: { multipleValue: 'Valor Múltiple', diff --git a/web/i18n/fa-IR/dataset.ts b/web/i18n/fa-IR/dataset.ts index 8ab2fb9179..9b2069d778 100644 --- a/web/i18n/fa-IR/dataset.ts +++ b/web/i18n/fa-IR/dataset.ts @@ -179,6 +179,7 @@ const translation = { checkName: { invalid: 'نام متاداده فقط می‌تواند شامل حروف کوچک، اعداد و زیرخط‌ها باشد و باید با یک حرف کوچک آغاز شود.', empty: 'نام فراداده نمی‌تواند خالی باشد', + tooLong: 'نام متا داده نمی‌تواند بیشتر از {{max}} کاراکتر باشد', }, batchEditMetadata: { multipleValue: 'چندین ارزش', diff --git a/web/i18n/fr-FR/dataset.ts b/web/i18n/fr-FR/dataset.ts index c9739e20a1..a0858e998b 100644 --- a/web/i18n/fr-FR/dataset.ts +++ b/web/i18n/fr-FR/dataset.ts @@ -179,6 +179,7 @@ const translation = { checkName: { empty: 'Le nom des métadonnées ne peut pas être vide', invalid: 'Le nom des métadonnées ne peut contenir que des lettres minuscules, des chiffres et des tirets bas et doit commencer par une lettre minuscule.', + tooLong: 'Le nom des métadonnées ne peut pas dépasser {{max}} caractères', }, batchEditMetadata: { editMetadata: 'Modifier les métadonnées', diff --git a/web/i18n/hi-IN/dataset.ts b/web/i18n/hi-IN/dataset.ts index 9be333cdec..e94c777718 100644 --- a/web/i18n/hi-IN/dataset.ts +++ b/web/i18n/hi-IN/dataset.ts @@ -186,6 +186,7 @@ const translation = { checkName: { empty: 'मेटाडाटा का नाम खाली नहीं हो सकता', invalid: 'मेटाडेटा नाम में केवल छोटे अक्षर, संख्या और अंडरस्कोर शामिल हो सकते हैं और इसे छोटे अक्षर से शुरू होना चाहिए।', + tooLong: 'मेटाडेटा नाम {{max}} वर्णों से अधिक नहीं हो सकता', }, batchEditMetadata: { editMetadata: 'मेटाडेटा संपादित करें', diff --git a/web/i18n/it-IT/dataset.ts b/web/i18n/it-IT/dataset.ts index c2c4963371..9f66ee8e43 100644 --- a/web/i18n/it-IT/dataset.ts +++ b/web/i18n/it-IT/dataset.ts @@ -186,6 +186,7 @@ const translation = { checkName: { invalid: 'Il nome dei metadati può contenere solo lettere minuscole, numeri e underscore e deve iniziare con una lettera minuscola.', empty: 'Il nome dei metadati non può essere vuoto', + tooLong: 'Il nome dei metadati non può superare {{max}} caratteri.', }, batchEditMetadata: { multipleValue: 'Valore Multiplo', diff --git a/web/i18n/ja-JP/dataset.ts b/web/i18n/ja-JP/dataset.ts index 2bdc4a8d28..1a40f17afc 100644 --- a/web/i18n/ja-JP/dataset.ts +++ b/web/i18n/ja-JP/dataset.ts @@ -183,6 +183,7 @@ const translation = { checkName: { empty: 'メタデータ名を入力してください', invalid: 'メタデータ名は小文字、数字、アンダースコアのみを使用し、小文字で始める必要があります', + tooLong: 'メタデータ名は {{max}} 文字を超えることはできません', }, batchEditMetadata: { editMetadata: 'メタデータを編集', diff --git a/web/i18n/ko-KR/dataset.ts b/web/i18n/ko-KR/dataset.ts index 3eb4634194..ddf4d82450 100644 --- a/web/i18n/ko-KR/dataset.ts +++ b/web/i18n/ko-KR/dataset.ts @@ -178,6 +178,7 @@ const translation = { checkName: { empty: '메타데이터 이름은 비어 있을 수 없습니다.', invalid: '메타데이터 이름은 소문자, 숫자 및 밑줄만 포함할 수 있으며 소문자로 시작해야 합니다.', + tooLong: '메타데이터 이름은 {{max}}자를 초과할 수 없습니다.', }, batchEditMetadata: { multipleValue: '다중 값', diff --git a/web/i18n/pl-PL/dataset.ts b/web/i18n/pl-PL/dataset.ts index 3006c46c97..0a37698750 100644 --- a/web/i18n/pl-PL/dataset.ts +++ b/web/i18n/pl-PL/dataset.ts @@ -185,6 +185,7 @@ const translation = { checkName: { empty: 'Nazwa metadanych nie może być pusta', invalid: 'Nazwa metadanych może zawierać tylko małe litery, cyfry i podkreślenia oraz musi zaczynać się od małej litery', + tooLong: 'Nazwa metadanych nie może przekraczać {{max}} znaków', }, batchEditMetadata: { multipleValue: 'Wielokrotna wartość', diff --git a/web/i18n/pt-BR/dataset.ts b/web/i18n/pt-BR/dataset.ts index 7d5f75aae6..53b13b7567 100644 --- a/web/i18n/pt-BR/dataset.ts +++ b/web/i18n/pt-BR/dataset.ts @@ -179,6 +179,7 @@ const translation = { checkName: { empty: 'O nome dos metadados não pode estar vazio', invalid: 'O nome de metadata só pode conter letras minúsculas, números e sublinhados e deve começar com uma letra minúscula.', + tooLong: 'O nome dos metadados não pode exceder {{max}} caracteres.', }, batchEditMetadata: { editDocumentsNum: 'Editando {{num}} documentos', diff --git a/web/i18n/ro-RO/dataset.ts b/web/i18n/ro-RO/dataset.ts index dd7b2d6900..8305029a97 100644 --- a/web/i18n/ro-RO/dataset.ts +++ b/web/i18n/ro-RO/dataset.ts @@ -179,6 +179,7 @@ const translation = { checkName: { invalid: 'Numele metadatelor poate conține doar litere mici, cifre și underscore și trebuie să înceapă cu o literă mică.', empty: 'Numele metadatelor nu poate fi gol', + tooLong: 'Numele metadatelor nu poate depăși {{max}} caractere', }, batchEditMetadata: { multipleValue: 'Valoare multiplă', diff --git a/web/i18n/ru-RU/dataset.ts b/web/i18n/ru-RU/dataset.ts index 33a4cdf851..fd6839cd33 100644 --- a/web/i18n/ru-RU/dataset.ts +++ b/web/i18n/ru-RU/dataset.ts @@ -179,6 +179,7 @@ const translation = { checkName: { empty: 'Имя метаданных не может быть пустым', invalid: 'Имя метаданных может содержать только строчные буквы, цифры и знаки нижнего подчеркивания и должно начинаться со строчной буквы.', + tooLong: 'Имя метаданных не может превышать {{max}} символов', }, batchEditMetadata: { applyToAllSelectDocumentTip: 'Автоматически создайте все вышеуказанные редактируемые и новые метаданные для всех выбранных документов, иначе редактирование метаданных будет применяться только к документам с ними.', diff --git a/web/i18n/sl-SI/dataset.ts b/web/i18n/sl-SI/dataset.ts index a9f9ccbad7..20403bd722 100644 --- a/web/i18n/sl-SI/dataset.ts +++ b/web/i18n/sl-SI/dataset.ts @@ -179,6 +179,7 @@ const translation = { checkName: { empty: 'Ime metapodatkov ne more biti prazno', invalid: 'Ime metapodatkov lahko vsebuje samo male črke, številke in podčrtaje ter se mora začeti z malo črko.', + tooLong: 'Ime metapodatkov ne sme presegati {{max}} znakov', }, batchEditMetadata: { editMetadata: 'Uredi metapodatke', diff --git a/web/i18n/th-TH/dataset.ts b/web/i18n/th-TH/dataset.ts index 15ef381605..ede1be1609 100644 --- a/web/i18n/th-TH/dataset.ts +++ b/web/i18n/th-TH/dataset.ts @@ -178,6 +178,7 @@ const translation = { checkName: { invalid: 'ชื่อเมตาดาต้าต้องประกอบด้วยตัวอักษรตัวเล็กเท่านั้น เลข และขีดล่าง และต้องเริ่มต้นด้วยตัวอักษรตัวเล็ก', empty: 'ชื่อข้อมูลเมตาไม่สามารถเป็นค่าแEmpty', + tooLong: 'ชื่อเมตาดาต้าไม่สามารถเกิน {{max}} ตัวอักษร', }, batchEditMetadata: { multipleValue: 'หลายค่า', diff --git a/web/i18n/tr-TR/dataset.ts b/web/i18n/tr-TR/dataset.ts index 96f120c13d..1875d5bd40 100644 --- a/web/i18n/tr-TR/dataset.ts +++ b/web/i18n/tr-TR/dataset.ts @@ -179,6 +179,7 @@ const translation = { checkName: { empty: 'Meta veri adı boş olamaz', invalid: 'Meta verisi adı yalnızca küçük harfler, sayılar ve alt çizgiler içerebilir ve küçük bir harfle başlamalıdır.', + tooLong: 'Meta veri adı {{max}} karakteri geçemez', }, batchEditMetadata: { multipleValue: 'Birden Fazla Değer', diff --git a/web/i18n/uk-UA/dataset.ts b/web/i18n/uk-UA/dataset.ts index 84bc8d9ad1..cfa85c950c 100644 --- a/web/i18n/uk-UA/dataset.ts +++ b/web/i18n/uk-UA/dataset.ts @@ -180,6 +180,7 @@ const translation = { checkName: { empty: 'Ім\'я метаданих не може бути порожнім', invalid: 'Ім\'я метаданих може містити лише малі літери, цифри та підкреслення, і повинно починатися з малої літери', + tooLong: 'Назва метаданих не може перевищувати {{max}} символів', }, batchEditMetadata: { editMetadata: 'Редагувати метадані', diff --git a/web/i18n/vi-VN/dataset.ts b/web/i18n/vi-VN/dataset.ts index 4227d712a0..41260c982c 100644 --- a/web/i18n/vi-VN/dataset.ts +++ b/web/i18n/vi-VN/dataset.ts @@ -179,6 +179,7 @@ const translation = { checkName: { invalid: 'Tên siêu dữ liệu chỉ có thể chứa chữ cái thường, số và dấu gạch dưới, và phải bắt đầu bằng một chữ cái thường.', empty: 'Tên siêu dữ liệu không được để trống', + tooLong: 'Tên siêu dữ liệu không được vượt quá {{max}} ký tự', }, batchEditMetadata: { applyToAllSelectDocumentTip: 'Tự động tạo tất cả các siêu dữ liệu đã chỉnh sửa và mới cho tất cả các tài liệu được chọn, nếu không, việc chỉnh sửa siêu dữ liệu sẽ chỉ áp dụng cho các tài liệu có nó.', diff --git a/web/i18n/zh-Hans/dataset.ts b/web/i18n/zh-Hans/dataset.ts index 9b5691ae24..fa1348c527 100644 --- a/web/i18n/zh-Hans/dataset.ts +++ b/web/i18n/zh-Hans/dataset.ts @@ -183,6 +183,7 @@ const translation = { checkName: { empty: '元数据名称不能为空', invalid: '元数据名称只能包含小写字母、数字和下划线,并且必须以小写字母开头', + tooLong: '元数据名称不得超过{{max}}个字符', }, batchEditMetadata: { editMetadata: '编辑元数据', diff --git a/web/i18n/zh-Hant/dataset.ts b/web/i18n/zh-Hant/dataset.ts index 0a6b015155..676c89e185 100644 --- a/web/i18n/zh-Hant/dataset.ts +++ b/web/i18n/zh-Hant/dataset.ts @@ -179,6 +179,7 @@ const translation = { checkName: { empty: '元數據名稱不能為空', invalid: '元數據名稱只能包含小寫字母、數字和底線,並且必須以小寫字母開頭', + tooLong: '元數據名稱不能超過 {{max}} 個字符', }, batchEditMetadata: { applyToAllSelectDocumentTip: '自動為所有選定文檔創建上述所有編輯和新元數據,否則編輯元數據將僅適用於具有該元數據的文檔。', From 27172b08980ab45475f8fb8aea5148bdd5a6005a Mon Sep 17 00:00:00 2001 From: AntzUhl <1325200471@qq.com> Date: Wed, 25 Jun 2025 10:27:36 +0800 Subject: [PATCH 03/12] fix (conf/code): Variables not correctly filled can still be referenced (#21451) Co-authored-by: crazywoola <427733928@qq.com> --- .../config-prompt/advanced-prompt-input.tsx | 2 +- .../config-prompt/simple-prompt-input.tsx | 31 +++++++++++++------ .../config/agent/prompt-editor.tsx | 2 +- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx index 1eec51992d..437e25fde4 100644 --- a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx @@ -227,7 +227,7 @@ const AdvancedPromptInput: FC = ({ }} variableBlock={{ show: true, - variables: modelConfig.configs.prompt_variables.filter(item => item.type !== 'api').map(item => ({ + variables: modelConfig.configs.prompt_variables.filter(item => item.type !== 'api' && item.key && item.key.trim() && item.name && item.name.trim()).map(item => ({ name: item.name, value: item.key, })), diff --git a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx index 2a9a15296e..eb0f524386 100644 --- a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx @@ -97,20 +97,31 @@ const Prompt: FC = ({ }, }) } - const promptVariablesObj = (() => { - const obj: Record = {} - promptVariables.forEach((item) => { - obj[item.key] = true - }) - return obj - })() const [newPromptVariables, setNewPromptVariables] = React.useState(promptVariables) const [newTemplates, setNewTemplates] = React.useState('') const [isShowConfirmAddVar, { setTrue: showConfirmAddVar, setFalse: hideConfirmAddVar }] = useBoolean(false) const handleChange = (newTemplates: string, keys: string[]) => { - const newPromptVariables = keys.filter(key => !(key in promptVariablesObj) && !externalDataToolsConfig.find(item => item.variable === key)).map(key => getNewVar(key, '')) + // Filter out keys that are not properly defined (either not exist or exist but without valid name) + const newPromptVariables = keys.filter((key) => { + // Check if key exists in external data tools + if (externalDataToolsConfig.find((item: ExternalDataTool) => item.variable === key)) + return false + + // Check if key exists in prompt variables + const existingVar = promptVariables.find((item: PromptVariable) => item.key === key) + if (!existingVar) { + // Variable doesn't exist at all + return true + } + + // Variable exists but check if it has valid name and key + return !existingVar.name || !existingVar.name.trim() || !existingVar.key || !existingVar.key.trim() + + return false + }).map(key => getNewVar(key, '')) + if (newPromptVariables.length > 0) { setNewPromptVariables(newPromptVariables) setNewTemplates(newTemplates) @@ -210,14 +221,14 @@ const Prompt: FC = ({ }} variableBlock={{ show: true, - variables: modelConfig.configs.prompt_variables.filter(item => item.type !== 'api').map(item => ({ + variables: modelConfig.configs.prompt_variables.filter((item: PromptVariable) => item.type !== 'api' && item.key && item.key.trim() && item.name && item.name.trim()).map((item: PromptVariable) => ({ name: item.name, value: item.key, })), }} externalToolBlock={{ show: true, - externalTools: modelConfig.configs.prompt_variables.filter(item => item.type === 'api').map(item => ({ + externalTools: modelConfig.configs.prompt_variables.filter((item: PromptVariable) => item.type === 'api').map((item: PromptVariable) => ({ name: item.name, variableName: item.key, icon: item.icon, diff --git a/web/app/components/app/configuration/config/agent/prompt-editor.tsx b/web/app/components/app/configuration/config/agent/prompt-editor.tsx index 7f7f140eda..579b7c4d64 100644 --- a/web/app/components/app/configuration/config/agent/prompt-editor.tsx +++ b/web/app/components/app/configuration/config/agent/prompt-editor.tsx @@ -107,7 +107,7 @@ const Editor: FC = ({ }} variableBlock={{ show: true, - variables: modelConfig.configs.prompt_variables.map(item => ({ + variables: modelConfig.configs.prompt_variables.filter(item => item.key && item.key.trim() && item.name && item.name.trim()).map(item => ({ name: item.name, value: item.key, })), From 8a20134a5bf2b30378dbda24e5f9c4be3b804840 Mon Sep 17 00:00:00 2001 From: crazywoola <100913391+crazywoola@users.noreply.github.com> Date: Wed, 25 Jun 2025 11:02:27 +0800 Subject: [PATCH 04/12] fix: inputs.items is undefined (#21463) --- .../workflow/nodes/assigner/use-single-run-form-params.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/app/components/workflow/nodes/assigner/use-single-run-form-params.ts b/web/app/components/workflow/nodes/assigner/use-single-run-form-params.ts index c6d483a3ee..7ff31d91c7 100644 --- a/web/app/components/workflow/nodes/assigner/use-single-run-form-params.ts +++ b/web/app/components/workflow/nodes/assigner/use-single-run-form-params.ts @@ -24,7 +24,7 @@ const useSingleRunFormParams = ({ }: Params) => { const { inputs } = useNodeCrud(id, payload) - const vars = inputs.items.filter((item) => { + const vars = (inputs.items ?? []).filter((item) => { return item.operation !== WriteMode.clear && item.operation !== WriteMode.set && item.operation !== WriteMode.removeFirst && item.operation !== WriteMode.removeLast && !writeModeTypesNum.includes(item.operation) From 257809bb3041e6f4f08a321cd0f482e830aafa5e Mon Sep 17 00:00:00 2001 From: thief Date: Wed, 25 Jun 2025 11:14:46 +0800 Subject: [PATCH 05/12] fix: in plugin config panel appselector can not list all apps() (#21457) --- .../app-selector/app-picker.tsx | 115 ++++++++++++++---- .../app-selector/index.tsx | 87 +++++++++++-- 2 files changed, 171 insertions(+), 31 deletions(-) diff --git a/web/app/components/plugins/plugin-detail-panel/app-selector/app-picker.tsx b/web/app/components/plugins/plugin-detail-panel/app-selector/app-picker.tsx index 2f57276559..3c79acb653 100644 --- a/web/app/components/plugins/plugin-detail-panel/app-selector/app-picker.tsx +++ b/web/app/components/plugins/plugin-detail-panel/app-selector/app-picker.tsx @@ -1,7 +1,6 @@ 'use client' import type { FC } from 'react' -import React, { useMemo } from 'react' -import { useState } from 'react' +import React, { useCallback, useEffect, useRef } from 'react' import { PortalToFollowElem, PortalToFollowElemContent, @@ -14,9 +13,9 @@ import type { import Input from '@/app/components/base/input' import AppIcon from '@/app/components/base/app-icon' import type { App } from '@/types/app' +import { useTranslation } from 'react-i18next' type Props = { - appList: App[] scope: string disabled: boolean trigger: React.ReactNode @@ -25,11 +24,16 @@ type Props = { isShow: boolean onShowChange: (isShow: boolean) => void onSelect: (app: App) => void + apps: App[] + isLoading: boolean + hasMore: boolean + onLoadMore: () => void + searchText: string + onSearchChange: (text: string) => void } const AppPicker: FC = ({ scope, - appList, disabled, trigger, placement = 'right-start', @@ -37,19 +41,81 @@ const AppPicker: FC = ({ isShow, onShowChange, onSelect, + apps, + isLoading, + hasMore, + onLoadMore, + searchText, + onSearchChange, }) => { - const [searchText, setSearchText] = useState('') - const filteredAppList = useMemo(() => { - return (appList || []) - .filter(app => app.name.toLowerCase().includes(searchText.toLowerCase())) - .filter(app => (app.mode !== 'advanced-chat' && app.mode !== 'workflow') || !!app.workflow) - .filter(app => scope === 'all' - || (scope === 'completion' && app.mode === 'completion') - || (scope === 'workflow' && app.mode === 'workflow') - || (scope === 'chat' && app.mode === 'advanced-chat') - || (scope === 'chat' && app.mode === 'agent-chat') - || (scope === 'chat' && app.mode === 'chat')) - }, [appList, scope, searchText]) + const { t } = useTranslation() + const observerTarget = useRef(null) + const observerRef = useRef(null) + const loadingRef = useRef(false) + + const handleIntersection = useCallback((entries: IntersectionObserverEntry[]) => { + const target = entries[0] + if (!target.isIntersecting || loadingRef.current || !hasMore || isLoading) return + + loadingRef.current = true + onLoadMore() + // Reset loading state + setTimeout(() => { + loadingRef.current = false + }, 500) + }, [hasMore, isLoading, onLoadMore]) + + useEffect(() => { + if (!isShow) { + if (observerRef.current) { + observerRef.current.disconnect() + observerRef.current = null + } + return + } + + let mutationObserver: MutationObserver | null = null + + const setupIntersectionObserver = () => { + if (!observerTarget.current) return + + // Create new observer + observerRef.current = new IntersectionObserver(handleIntersection, { + root: null, + rootMargin: '100px', + threshold: 0.1, + }) + + observerRef.current.observe(observerTarget.current) + } + + // Set up MutationObserver to watch DOM changes + mutationObserver = new MutationObserver((mutations) => { + if (observerTarget.current) { + setupIntersectionObserver() + mutationObserver?.disconnect() + } + }) + + // Watch body changes since Portal adds content to body + mutationObserver.observe(document.body, { + childList: true, + subtree: true, + }) + + // If element exists, set up IntersectionObserver directly + if (observerTarget.current) + setupIntersectionObserver() + + return () => { + if (observerRef.current) { + observerRef.current.disconnect() + observerRef.current = null + } + mutationObserver?.disconnect() + } + }, [isShow, handleIntersection]) + const getAppType = (app: App) => { switch (app.mode) { case 'advanced-chat': @@ -84,18 +150,18 @@ const AppPicker: FC = ({ -
+
setSearchText(e.target.value)} - onClear={() => setSearchText('')} + onChange={e => onSearchChange(e.target.value)} + onClear={() => onSearchChange('')} />
-
- {filteredAppList.map(app => ( +
+ {apps.map(app => (
= ({
{getAppType(app)}
))} +
+ {isLoading && ( +
+
{t('common.loading')}
+
+ )} +
diff --git a/web/app/components/plugins/plugin-detail-panel/app-selector/index.tsx b/web/app/components/plugins/plugin-detail-panel/app-selector/index.tsx index 39ca957953..9c16c66a70 100644 --- a/web/app/components/plugins/plugin-detail-panel/app-selector/index.tsx +++ b/web/app/components/plugins/plugin-detail-panel/app-selector/index.tsx @@ -1,6 +1,6 @@ 'use client' import type { FC } from 'react' -import React, { useMemo, useState } from 'react' +import React, { useCallback, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' import { PortalToFollowElem, @@ -10,12 +10,36 @@ import { import AppTrigger from '@/app/components/plugins/plugin-detail-panel/app-selector/app-trigger' import AppPicker from '@/app/components/plugins/plugin-detail-panel/app-selector/app-picker' import AppInputsPanel from '@/app/components/plugins/plugin-detail-panel/app-selector/app-inputs-panel' -import { useAppFullList } from '@/service/use-apps' import type { App } from '@/types/app' import type { OffsetOptions, Placement, } from '@floating-ui/react' +import useSWRInfinite from 'swr/infinite' +import { fetchAppList } from '@/service/apps' +import type { AppListResponse } from '@/models/app' + +const PAGE_SIZE = 20 + +const getKey = ( + pageIndex: number, + previousPageData: AppListResponse, + searchText: string, +) => { + if (pageIndex === 0 || (previousPageData && previousPageData.has_more)) { + const params: any = { + url: 'apps', + params: { + page: pageIndex + 1, + limit: PAGE_SIZE, + name: searchText, + }, + } + + return params + } + return null +} type Props = { value?: { @@ -34,6 +58,7 @@ type Props = { }) => void supportAddCustomTool?: boolean } + const AppSelector: FC = ({ value, scope, @@ -44,18 +69,47 @@ const AppSelector: FC = ({ }) => { const { t } = useTranslation() const [isShow, onShowChange] = useState(false) + const [searchText, setSearchText] = useState('') + const [isLoadingMore, setIsLoadingMore] = useState(false) + + const { data, isLoading, setSize } = useSWRInfinite( + (pageIndex: number, previousPageData: AppListResponse) => getKey(pageIndex, previousPageData, searchText), + fetchAppList, + { + revalidateFirstPage: true, + shouldRetryOnError: false, + dedupingInterval: 500, + errorRetryCount: 3, + }, + ) + + const displayedApps = useMemo(() => { + if (!data) return [] + return data.flatMap(({ data: apps }) => apps) + }, [data]) + + const hasMore = data?.at(-1)?.has_more ?? true + + const handleLoadMore = useCallback(async () => { + if (isLoadingMore || !hasMore) return + + setIsLoadingMore(true) + try { + await setSize((size: number) => size + 1) + } + finally { + // Add a small delay to ensure state updates are complete + setTimeout(() => { + setIsLoadingMore(false) + }, 300) + } + }, [isLoadingMore, hasMore, setSize]) + const handleTriggerClick = () => { if (disabled) return onShowChange(true) } - const { data: appList } = useAppFullList() - const currentAppInfo = useMemo(() => { - if (!appList?.data || !value) - return undefined - return appList.data.find(app => app.id === value.app_id) - }, [appList?.data, value]) - const [isShowChooseApp, setIsShowChooseApp] = useState(false) const handleSelectApp = (app: App) => { const clearValue = app.id !== value?.app_id @@ -67,6 +121,7 @@ const AppSelector: FC = ({ onSelect(appValue) setIsShowChooseApp(false) } + const handleFormChange = (inputs: Record) => { const newFiles = inputs['#image#'] delete inputs['#image#'] @@ -88,6 +143,12 @@ const AppSelector: FC = ({ } }, [value]) + const currentAppInfo = useMemo(() => { + if (!displayedApps || !value) + return undefined + return displayedApps.find(app => app.id === value.app_id) + }, [displayedApps, value]) + return ( <> = ({ isShow={isShowChooseApp} onShowChange={setIsShowChooseApp} disabled={false} - appList={appList?.data || []} onSelect={handleSelectApp} scope={scope || 'all'} + apps={displayedApps} + isLoading={isLoading || isLoadingMore} + hasMore={hasMore} + onLoadMore={handleLoadMore} + searchText={searchText} + onSearchChange={setSearchText} />
{/* app inputs config panel */} @@ -140,4 +206,5 @@ const AppSelector: FC = ({ ) } + export default React.memo(AppSelector) From 449e16782e797981d90e1df7f4c374f7fe0e9437 Mon Sep 17 00:00:00 2001 From: Covfefeable <58316364+Covfefeable@users.noreply.github.com> Date: Wed, 25 Jun 2025 11:17:31 +0800 Subject: [PATCH 06/12] fix: update DocumentList to include max height and overflow styles (#21466) --- .../datasets/common/document-picker/document-list.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/app/components/datasets/common/document-picker/document-list.tsx b/web/app/components/datasets/common/document-picker/document-list.tsx index 5ac98b97fe..6b51973ccc 100644 --- a/web/app/components/datasets/common/document-picker/document-list.tsx +++ b/web/app/components/datasets/common/document-picker/document-list.tsx @@ -21,7 +21,7 @@ const DocumentList: FC = ({ }, [onChange]) return ( -
+
{list.map((item) => { const { id, name, extension } = item return ( From 819c02f1f5e3e7f8a1cc60ea2e211c6e0e89122f Mon Sep 17 00:00:00 2001 From: zxhlyh Date: Wed, 25 Jun 2025 11:34:58 +0800 Subject: [PATCH 07/12] fix: prompt editor click to insert variable (#21470) --- .../plugins/component-picker-block/variable-option.tsx | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/web/app/components/base/prompt-editor/plugins/component-picker-block/variable-option.tsx b/web/app/components/base/prompt-editor/plugins/component-picker-block/variable-option.tsx index 20c03768d0..57de34a701 100644 --- a/web/app/components/base/prompt-editor/plugins/component-picker-block/variable-option.tsx +++ b/web/app/components/base/prompt-editor/plugins/component-picker-block/variable-option.tsx @@ -44,6 +44,10 @@ export const VariableMenuItem = memo(({ tabIndex={-1} ref={setRefElement} onMouseEnter={onMouseEnter} + onMouseDown={(e) => { + e.preventDefault() + e.stopPropagation() + }} onClick={onClick}>
{icon} From 94f8e48647877f6fe295b6f4bfca40e291319178 Mon Sep 17 00:00:00 2001 From: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com> Date: Wed, 25 Jun 2025 11:44:35 +0800 Subject: [PATCH 08/12] Refactor update dataset (fix #21401) (#21402) Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- api/services/dataset_service.py | 460 ++++++---- ...t_service_batch_update_document_status.py} | 0 .../test_dataset_service_update_dataset.py | 826 ++++++++++++++++++ 3 files changed, 1133 insertions(+), 153 deletions(-) rename api/tests/unit_tests/services/{test_dataset_service.py => test_dataset_service_batch_update_document_status.py} (100%) create mode 100644 api/tests/unit_tests/services/test_dataset_service_update_dataset.py diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 49ca98624a..af1210c19d 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -280,174 +280,328 @@ class DatasetService: @staticmethod def update_dataset(dataset_id, data, user): + """ + Update dataset configuration and settings. + + Args: + dataset_id: The unique identifier of the dataset to update + data: Dictionary containing the update data + user: The user performing the update operation + + Returns: + Dataset: The updated dataset object + + Raises: + ValueError: If dataset not found or validation fails + NoPermissionError: If user lacks permission to update the dataset + """ + # Retrieve and validate dataset existence dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise ValueError("Dataset not found") + # Verify user has permission to update this dataset DatasetService.check_dataset_permission(dataset, user) + + # Handle external dataset updates if dataset.provider == "external": - external_retrieval_model = data.get("external_retrieval_model", None) - if external_retrieval_model: - dataset.retrieval_model = external_retrieval_model - dataset.name = data.get("name", dataset.name) - dataset.description = data.get("description", "") - permission = data.get("permission") - if permission: - dataset.permission = permission - external_knowledge_id = data.get("external_knowledge_id", None) - db.session.add(dataset) - if not external_knowledge_id: - raise ValueError("External knowledge id is required.") - external_knowledge_api_id = data.get("external_knowledge_api_id", None) - if not external_knowledge_api_id: - raise ValueError("External knowledge api id is required.") - - with Session(db.engine) as session: - external_knowledge_binding = ( - session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first() - ) - - if not external_knowledge_binding: - raise ValueError("External knowledge binding not found.") - - if ( - external_knowledge_binding.external_knowledge_id != external_knowledge_id - or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id - ): - external_knowledge_binding.external_knowledge_id = external_knowledge_id - external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id - db.session.add(external_knowledge_binding) - db.session.commit() + return DatasetService._update_external_dataset(dataset, data, user) else: - data.pop("partial_member_list", None) - data.pop("external_knowledge_api_id", None) - data.pop("external_knowledge_id", None) - data.pop("external_retrieval_model", None) - filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"} - action = None - if dataset.indexing_technique != data["indexing_technique"]: - # if update indexing_technique - if data["indexing_technique"] == "economy": - action = "remove" - filtered_data["embedding_model"] = None - filtered_data["embedding_model_provider"] = None - filtered_data["collection_binding_id"] = None - elif data["indexing_technique"] == "high_quality": - action = "add" - # get embedding model setting - try: - model_manager = ModelManager() - embedding_model = model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, - provider=data["embedding_model_provider"], - model_type=ModelType.TEXT_EMBEDDING, - model=data["embedding_model"], - ) - filtered_data["embedding_model"] = embedding_model.model - filtered_data["embedding_model_provider"] = embedding_model.provider - dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, embedding_model.model - ) - filtered_data["collection_binding_id"] = dataset_collection_binding.id - except LLMBadRequestError: - raise ValueError( - "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider." - ) - except ProviderTokenNotInitError as ex: - raise ValueError(ex.description) - else: - # add default plugin id to both setting sets, to make sure the plugin model provider is consistent - # Skip embedding model checks if not provided in the update request - if ( - "embedding_model_provider" not in data - or "embedding_model" not in data - or not data.get("embedding_model_provider") - or not data.get("embedding_model") - ): - # If the dataset already has embedding model settings, use those - if dataset.embedding_model_provider and dataset.embedding_model: - # Keep existing values - filtered_data["embedding_model_provider"] = dataset.embedding_model_provider - filtered_data["embedding_model"] = dataset.embedding_model - # If collection_binding_id exists, keep it too - if dataset.collection_binding_id: - filtered_data["collection_binding_id"] = dataset.collection_binding_id - # Otherwise, don't try to update embedding model settings at all - # Remove these fields from filtered_data if they exist but are None/empty - if "embedding_model_provider" in filtered_data and not filtered_data["embedding_model_provider"]: - del filtered_data["embedding_model_provider"] - if "embedding_model" in filtered_data and not filtered_data["embedding_model"]: - del filtered_data["embedding_model"] - else: - skip_embedding_update = False - try: - # Handle existing model provider - plugin_model_provider = dataset.embedding_model_provider - plugin_model_provider_str = None - if plugin_model_provider: - plugin_model_provider_str = str(ModelProviderID(plugin_model_provider)) + return DatasetService._update_internal_dataset(dataset, data, user) - # Handle new model provider from request - new_plugin_model_provider = data["embedding_model_provider"] - new_plugin_model_provider_str = None - if new_plugin_model_provider: - new_plugin_model_provider_str = str(ModelProviderID(new_plugin_model_provider)) + @staticmethod + def _update_external_dataset(dataset, data, user): + """ + Update external dataset configuration. - # Only update embedding model if both values are provided and different from current - if ( - plugin_model_provider_str != new_plugin_model_provider_str - or data["embedding_model"] != dataset.embedding_model - ): - action = "update" - model_manager = ModelManager() - try: - embedding_model = model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, - provider=data["embedding_model_provider"], - model_type=ModelType.TEXT_EMBEDDING, - model=data["embedding_model"], - ) - except ProviderTokenNotInitError: - # If we can't get the embedding model, skip updating it - # and keep the existing settings if available - if dataset.embedding_model_provider and dataset.embedding_model: - filtered_data["embedding_model_provider"] = dataset.embedding_model_provider - filtered_data["embedding_model"] = dataset.embedding_model - if dataset.collection_binding_id: - filtered_data["collection_binding_id"] = dataset.collection_binding_id - # Skip the rest of the embedding model update - skip_embedding_update = True - if not skip_embedding_update: - filtered_data["embedding_model"] = embedding_model.model - filtered_data["embedding_model_provider"] = embedding_model.provider - dataset_collection_binding = ( - DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, embedding_model.model - ) - ) - filtered_data["collection_binding_id"] = dataset_collection_binding.id - except LLMBadRequestError: - raise ValueError( - "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider." - ) - except ProviderTokenNotInitError as ex: - raise ValueError(ex.description) + Args: + dataset: The dataset object to update + data: Update data dictionary + user: User performing the update - filtered_data["updated_by"] = user.id - filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + Returns: + Dataset: Updated dataset object + """ + # Update retrieval model if provided + external_retrieval_model = data.get("external_retrieval_model", None) + if external_retrieval_model: + dataset.retrieval_model = external_retrieval_model - # update Retrieval model - filtered_data["retrieval_model"] = data["retrieval_model"] + # Update basic dataset properties + dataset.name = data.get("name", dataset.name) + dataset.description = data.get("description", dataset.description) - db.session.query(Dataset).filter_by(id=dataset_id).update(filtered_data) + # Update permission if provided + permission = data.get("permission") + if permission: + dataset.permission = permission + + # Validate and update external knowledge configuration + external_knowledge_id = data.get("external_knowledge_id", None) + external_knowledge_api_id = data.get("external_knowledge_api_id", None) + + if not external_knowledge_id: + raise ValueError("External knowledge id is required.") + if not external_knowledge_api_id: + raise ValueError("External knowledge api id is required.") + # Update metadata fields + dataset.updated_by = user.id if user else None + dataset.updated_at = datetime.datetime.utcnow() + db.session.add(dataset) + + # Update external knowledge binding + DatasetService._update_external_knowledge_binding(dataset.id, external_knowledge_id, external_knowledge_api_id) + + # Commit changes to database + db.session.commit() - db.session.commit() - if action: - deal_dataset_vector_index_task.delay(dataset_id, action) return dataset + @staticmethod + def _update_external_knowledge_binding(dataset_id, external_knowledge_id, external_knowledge_api_id): + """ + Update external knowledge binding configuration. + + Args: + dataset_id: Dataset identifier + external_knowledge_id: External knowledge identifier + external_knowledge_api_id: External knowledge API identifier + """ + with Session(db.engine) as session: + external_knowledge_binding = ( + session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first() + ) + + if not external_knowledge_binding: + raise ValueError("External knowledge binding not found.") + + # Update binding if values have changed + if ( + external_knowledge_binding.external_knowledge_id != external_knowledge_id + or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id + ): + external_knowledge_binding.external_knowledge_id = external_knowledge_id + external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id + db.session.add(external_knowledge_binding) + + @staticmethod + def _update_internal_dataset(dataset, data, user): + """ + Update internal dataset configuration. + + Args: + dataset: The dataset object to update + data: Update data dictionary + user: User performing the update + + Returns: + Dataset: Updated dataset object + """ + # Remove external-specific fields from update data + data.pop("partial_member_list", None) + data.pop("external_knowledge_api_id", None) + data.pop("external_knowledge_id", None) + data.pop("external_retrieval_model", None) + + # Filter out None values except for description field + filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"} + + # Handle indexing technique changes and embedding model updates + action = DatasetService._handle_indexing_technique_change(dataset, data, filtered_data) + + # Add metadata fields + filtered_data["updated_by"] = user.id + filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + # update Retrieval model + filtered_data["retrieval_model"] = data["retrieval_model"] + + # Update dataset in database + db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data) + db.session.commit() + + # Trigger vector index task if indexing technique changed + if action: + deal_dataset_vector_index_task.delay(dataset.id, action) + + return dataset + + @staticmethod + def _handle_indexing_technique_change(dataset, data, filtered_data): + """ + Handle changes in indexing technique and configure embedding models accordingly. + + Args: + dataset: Current dataset object + data: Update data dictionary + filtered_data: Filtered update data + + Returns: + str: Action to perform ('add', 'remove', 'update', or None) + """ + if dataset.indexing_technique != data["indexing_technique"]: + if data["indexing_technique"] == "economy": + # Remove embedding model configuration for economy mode + filtered_data["embedding_model"] = None + filtered_data["embedding_model_provider"] = None + filtered_data["collection_binding_id"] = None + return "remove" + elif data["indexing_technique"] == "high_quality": + # Configure embedding model for high quality mode + DatasetService._configure_embedding_model_for_high_quality(data, filtered_data) + return "add" + else: + # Handle embedding model updates when indexing technique remains the same + return DatasetService._handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data) + return None + + @staticmethod + def _configure_embedding_model_for_high_quality(data, filtered_data): + """ + Configure embedding model settings for high quality indexing. + + Args: + data: Update data dictionary + filtered_data: Filtered update data to modify + """ + try: + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=data["embedding_model_provider"], + model_type=ModelType.TEXT_EMBEDDING, + model=data["embedding_model"], + ) + filtered_data["embedding_model"] = embedding_model.model + filtered_data["embedding_model_provider"] = embedding_model.provider + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_model.provider, embedding_model.model + ) + filtered_data["collection_binding_id"] = dataset_collection_binding.id + except LLMBadRequestError: + raise ValueError( + "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." + ) + except ProviderTokenNotInitError as ex: + raise ValueError(ex.description) + + @staticmethod + def _handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data): + """ + Handle embedding model updates when indexing technique remains the same. + + Args: + dataset: Current dataset object + data: Update data dictionary + filtered_data: Filtered update data to modify + + Returns: + str: Action to perform ('update' or None) + """ + # Skip embedding model checks if not provided in the update request + if ( + "embedding_model_provider" not in data + or "embedding_model" not in data + or not data.get("embedding_model_provider") + or not data.get("embedding_model") + ): + DatasetService._preserve_existing_embedding_settings(dataset, filtered_data) + return None + else: + return DatasetService._update_embedding_model_settings(dataset, data, filtered_data) + + @staticmethod + def _preserve_existing_embedding_settings(dataset, filtered_data): + """ + Preserve existing embedding model settings when not provided in update. + + Args: + dataset: Current dataset object + filtered_data: Filtered update data to modify + """ + # If the dataset already has embedding model settings, use those + if dataset.embedding_model_provider and dataset.embedding_model: + filtered_data["embedding_model_provider"] = dataset.embedding_model_provider + filtered_data["embedding_model"] = dataset.embedding_model + # If collection_binding_id exists, keep it too + if dataset.collection_binding_id: + filtered_data["collection_binding_id"] = dataset.collection_binding_id + # Otherwise, don't try to update embedding model settings at all + # Remove these fields from filtered_data if they exist but are None/empty + if "embedding_model_provider" in filtered_data and not filtered_data["embedding_model_provider"]: + del filtered_data["embedding_model_provider"] + if "embedding_model" in filtered_data and not filtered_data["embedding_model"]: + del filtered_data["embedding_model"] + + @staticmethod + def _update_embedding_model_settings(dataset, data, filtered_data): + """ + Update embedding model settings with new values. + + Args: + dataset: Current dataset object + data: Update data dictionary + filtered_data: Filtered update data to modify + + Returns: + str: Action to perform ('update' or None) + """ + try: + # Compare current and new model provider settings + current_provider_str = ( + str(ModelProviderID(dataset.embedding_model_provider)) if dataset.embedding_model_provider else None + ) + new_provider_str = ( + str(ModelProviderID(data["embedding_model_provider"])) if data["embedding_model_provider"] else None + ) + + # Only update if values are different + if current_provider_str != new_provider_str or data["embedding_model"] != dataset.embedding_model: + DatasetService._apply_new_embedding_settings(dataset, data, filtered_data) + return "update" + except LLMBadRequestError: + raise ValueError( + "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." + ) + except ProviderTokenNotInitError as ex: + raise ValueError(ex.description) + return None + + @staticmethod + def _apply_new_embedding_settings(dataset, data, filtered_data): + """ + Apply new embedding model settings to the dataset. + + Args: + dataset: Current dataset object + data: Update data dictionary + filtered_data: Filtered update data to modify + """ + model_manager = ModelManager() + try: + embedding_model = model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=data["embedding_model_provider"], + model_type=ModelType.TEXT_EMBEDDING, + model=data["embedding_model"], + ) + except ProviderTokenNotInitError: + # If we can't get the embedding model, preserve existing settings + if dataset.embedding_model_provider and dataset.embedding_model: + filtered_data["embedding_model_provider"] = dataset.embedding_model_provider + filtered_data["embedding_model"] = dataset.embedding_model + if dataset.collection_binding_id: + filtered_data["collection_binding_id"] = dataset.collection_binding_id + # Skip the rest of the embedding model update + return + + # Apply new embedding model settings + filtered_data["embedding_model"] = embedding_model.model + filtered_data["embedding_model_provider"] = embedding_model.provider + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_model.provider, embedding_model.model + ) + filtered_data["collection_binding_id"] = dataset_collection_binding.id + @staticmethod def delete_dataset(dataset_id, user): dataset = DatasetService.get_dataset(dataset_id) diff --git a/api/tests/unit_tests/services/test_dataset_service.py b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py similarity index 100% rename from api/tests/unit_tests/services/test_dataset_service.py rename to api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py diff --git a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py new file mode 100644 index 0000000000..15e1b7569f --- /dev/null +++ b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py @@ -0,0 +1,826 @@ +import datetime + +# Mock redis_client before importing dataset_service +from unittest.mock import Mock, patch + +import pytest + +from core.model_runtime.entities.model_entities import ModelType +from models.dataset import Dataset, ExternalKnowledgeBindings +from services.dataset_service import DatasetService +from services.errors.account import NoPermissionError +from tests.unit_tests.conftest import redis_mock + + +class TestDatasetServiceUpdateDataset: + """ + Comprehensive unit tests for DatasetService.update_dataset method. + + This test suite covers all supported scenarios including: + - External dataset updates + - Internal dataset updates with different indexing techniques + - Embedding model updates + - Permission checks + - Error conditions and edge cases + """ + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_external_dataset_success(self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db): + """ + Test successful update of external dataset. + + Verifies that: + 1. External dataset attributes are updated correctly + 2. External knowledge binding is updated when values change + 3. Database changes are committed + 4. Permission check is performed + """ + from unittest.mock import Mock, patch + + from extensions.ext_database import db + + with patch.object(db.__class__, "engine", new_callable=Mock): + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "external" + mock_dataset.name = "old_name" + mock_dataset.description = "old_description" + mock_dataset.retrieval_model = "old_model" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create mock external knowledge binding + mock_binding = Mock(spec=ExternalKnowledgeBindings) + mock_binding.external_knowledge_id = "old_knowledge_id" + mock_binding.external_knowledge_api_id = "old_api_id" + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Mock external knowledge binding query + with patch("services.dataset_service.Session") as mock_session: + mock_session_instance = Mock() + mock_session.return_value.__enter__.return_value = mock_session_instance + mock_session_instance.query.return_value.filter_by.return_value.first.return_value = mock_binding + + # Test data + update_data = { + "name": "new_name", + "description": "new_description", + "external_retrieval_model": "new_model", + "permission": "only_me", + "external_knowledge_id": "new_knowledge_id", + "external_knowledge_api_id": "new_api_id", + } + + # Call the method + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify permission check was called + mock_check_permission.assert_called_once_with(mock_dataset, mock_user) + + # Verify dataset attributes were updated + assert mock_dataset.name == "new_name" + assert mock_dataset.description == "new_description" + assert mock_dataset.retrieval_model == "new_model" + + # Verify external knowledge binding was updated + assert mock_binding.external_knowledge_id == "new_knowledge_id" + assert mock_binding.external_knowledge_api_id == "new_api_id" + + # Verify database operations + mock_db.add.assert_any_call(mock_dataset) + mock_db.add.assert_any_call(mock_binding) + mock_db.commit.assert_called_once() + + # Verify return value + assert result == mock_dataset + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + def test_update_external_dataset_missing_knowledge_id_error(self, mock_check_permission, mock_get_dataset, mock_db): + """ + Test error when external knowledge id is missing. + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.provider = "external" + + # Create mock user + mock_user = Mock() + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Test data without external_knowledge_id + update_data = {"name": "new_name", "external_knowledge_api_id": "api_id"} + + # Call the method and expect ValueError + with pytest.raises(ValueError) as context: + DatasetService.update_dataset("dataset-123", update_data, mock_user) + + assert "External knowledge id is required" in str(context.value) + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + def test_update_external_dataset_missing_api_id_error(self, mock_check_permission, mock_get_dataset, mock_db): + """ + Test error when external knowledge api id is missing. + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.provider = "external" + + # Create mock user + mock_user = Mock() + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Test data without external_knowledge_api_id + update_data = {"name": "new_name", "external_knowledge_id": "knowledge_id"} + + # Call the method and expect ValueError + with pytest.raises(ValueError) as context: + DatasetService.update_dataset("dataset-123", update_data, mock_user) + + assert "External knowledge api id is required" in str(context.value) + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.Session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + def test_update_external_dataset_binding_not_found_error( + self, mock_check_permission, mock_get_dataset, mock_session, mock_db + ): + from unittest.mock import Mock, patch + + from extensions.ext_database import db + + with patch.object(db.__class__, "engine", new_callable=Mock): + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.provider = "external" + + # Create mock user + mock_user = Mock() + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Mock external knowledge binding query returning None + mock_session_instance = Mock() + mock_session.return_value.__enter__.return_value = mock_session_instance + mock_session_instance.query.return_value.filter_by.return_value.first.return_value = None + + # Test data + update_data = { + "name": "new_name", + "external_knowledge_id": "knowledge_id", + "external_knowledge_api_id": "api_id", + } + + # Call the method and expect ValueError + with pytest.raises(ValueError) as context: + DatasetService.update_dataset("dataset-123", update_data, mock_user) + + assert "External knowledge binding not found" in str(context.value) + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_internal_dataset_basic_success( + self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db + ): + """ + Test successful update of internal dataset with basic fields. + + Verifies that: + 1. Basic dataset attributes are updated correctly + 2. Filtered data excludes None values except description + 3. Timestamp fields are updated + 4. Database changes are committed + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.name = "old_name" + mock_dataset.description = "old_description" + mock_dataset.indexing_technique = "high_quality" + mock_dataset.retrieval_model = "old_model" + mock_dataset.embedding_model_provider = "openai" + mock_dataset.embedding_model = "text-embedding-ada-002" + mock_dataset.collection_binding_id = "binding-123" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Test data + update_data = { + "name": "new_name", + "description": "new_description", + "indexing_technique": "high_quality", + "retrieval_model": "new_model", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-ada-002", + } + + # Call the method + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify permission check was called + mock_check_permission.assert_called_once_with(mock_dataset, mock_user) + + # Verify database update was called with correct filtered data + expected_filtered_data = { + "name": "new_name", + "description": "new_description", + "indexing_technique": "high_quality", + "retrieval_model": "new_model", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-ada-002", + "updated_by": mock_user.id, + "updated_at": current_time.replace(tzinfo=None), + } + + mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data) + mock_db.commit.assert_called_once() + + # Verify return value + assert result == mock_dataset + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.deal_dataset_vector_index_task") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_internal_dataset_indexing_technique_to_economy( + self, mock_datetime, mock_check_permission, mock_get_dataset, mock_task, mock_db + ): + """ + Test updating internal dataset indexing technique to economy. + + Verifies that: + 1. Embedding model fields are cleared when switching to economy + 2. Vector index task is triggered with 'remove' action + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.indexing_technique = "high_quality" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Test data + update_data = {"indexing_technique": "economy", "retrieval_model": "new_model"} + + # Call the method + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify database update was called with embedding model fields cleared + expected_filtered_data = { + "indexing_technique": "economy", + "embedding_model": None, + "embedding_model_provider": None, + "collection_binding_id": None, + "retrieval_model": "new_model", + "updated_by": mock_user.id, + "updated_at": current_time.replace(tzinfo=None), + } + + mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data) + mock_db.commit.assert_called_once() + + # Verify vector index task was triggered + mock_task.delay.assert_called_once_with("dataset-123", "remove") + + # Verify return value + assert result == mock_dataset + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + def test_update_dataset_not_found_error(self, mock_check_permission, mock_get_dataset, mock_db): + """ + Test error when dataset is not found. + """ + # Create mock user + mock_user = Mock() + + # Mock dataset retrieval returning None + mock_get_dataset.return_value = None + + # Test data + update_data = {"name": "new_name"} + + # Call the method and expect ValueError + with pytest.raises(ValueError) as context: + DatasetService.update_dataset("dataset-123", update_data, mock_user) + + assert "Dataset not found" in str(context.value) + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + def test_update_dataset_permission_error(self, mock_check_permission, mock_get_dataset, mock_db): + """ + Test error when user doesn't have permission. + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + + # Create mock user + mock_user = Mock() + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Mock permission check to raise error + mock_check_permission.side_effect = NoPermissionError("No permission") + + # Test data + update_data = {"name": "new_name"} + + # Call the method and expect NoPermissionError + with pytest.raises(NoPermissionError): + DatasetService.update_dataset("dataset-123", update_data, mock_user) + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_internal_dataset_keep_existing_embedding_model( + self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db + ): + """ + Test updating internal dataset without changing embedding model. + + Verifies that: + 1. Existing embedding model settings are preserved when not provided in update + 2. No vector index task is triggered + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.indexing_technique = "high_quality" + mock_dataset.embedding_model_provider = "openai" + mock_dataset.embedding_model = "text-embedding-ada-002" + mock_dataset.collection_binding_id = "binding-123" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Test data without embedding model fields + update_data = {"name": "new_name", "indexing_technique": "high_quality", "retrieval_model": "new_model"} + + # Call the method + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify database update was called with existing embedding model preserved + expected_filtered_data = { + "name": "new_name", + "indexing_technique": "high_quality", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-ada-002", + "collection_binding_id": "binding-123", + "retrieval_model": "new_model", + "updated_by": mock_user.id, + "updated_at": current_time.replace(tzinfo=None), + } + + mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data) + mock_db.commit.assert_called_once() + + # Verify return value + assert result == mock_dataset + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding") + @patch("services.dataset_service.ModelManager") + @patch("services.dataset_service.deal_dataset_vector_index_task") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_internal_dataset_indexing_technique_to_high_quality( + self, + mock_datetime, + mock_check_permission, + mock_get_dataset, + mock_task, + mock_model_manager, + mock_collection_binding, + mock_db, + ): + """ + Test updating internal dataset indexing technique to high_quality. + + Verifies that: + 1. Embedding model is validated and set + 2. Collection binding is retrieved + 3. Vector index task is triggered with 'add' action + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.indexing_technique = "economy" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Mock embedding model + mock_embedding_model = Mock() + mock_embedding_model.model = "text-embedding-ada-002" + mock_embedding_model.provider = "openai" + + # Mock collection binding + mock_collection_binding_instance = Mock() + mock_collection_binding_instance.id = "binding-456" + + # Mock model manager + mock_model_manager_instance = Mock() + mock_model_manager_instance.get_model_instance.return_value = mock_embedding_model + mock_model_manager.return_value = mock_model_manager_instance + + # Mock collection binding service + mock_collection_binding.return_value = mock_collection_binding_instance + + # Mock current_user + mock_current_user = Mock() + mock_current_user.current_tenant_id = "tenant-123" + + # Test data + update_data = { + "indexing_technique": "high_quality", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-ada-002", + "retrieval_model": "new_model", + } + + # Call the method with current_user mock + with patch("services.dataset_service.current_user", mock_current_user): + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify embedding model was validated + mock_model_manager_instance.get_model_instance.assert_called_once_with( + tenant_id=mock_current_user.current_tenant_id, + provider="openai", + model_type=ModelType.TEXT_EMBEDDING, + model="text-embedding-ada-002", + ) + + # Verify collection binding was retrieved + mock_collection_binding.assert_called_once_with("openai", "text-embedding-ada-002") + + # Verify database update was called with correct data + expected_filtered_data = { + "indexing_technique": "high_quality", + "embedding_model": "text-embedding-ada-002", + "embedding_model_provider": "openai", + "collection_binding_id": "binding-456", + "retrieval_model": "new_model", + "updated_by": mock_user.id, + "updated_at": current_time.replace(tzinfo=None), + } + + mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data) + mock_db.commit.assert_called_once() + + # Verify vector index task was triggered + mock_task.delay.assert_called_once_with("dataset-123", "add") + + # Verify return value + assert result == mock_dataset + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + def test_update_internal_dataset_embedding_model_error(self, mock_check_permission, mock_get_dataset, mock_db): + """ + Test error when embedding model is not available. + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.indexing_technique = "economy" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Mock current_user + mock_current_user = Mock() + mock_current_user.current_tenant_id = "tenant-123" + + # Mock model manager to raise error + with ( + patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.current_user", mock_current_user), + ): + mock_model_manager_instance = Mock() + mock_model_manager_instance.get_model_instance.side_effect = Exception("No Embedding Model available") + mock_model_manager.return_value = mock_model_manager_instance + + # Test data + update_data = { + "indexing_technique": "high_quality", + "embedding_model_provider": "invalid_provider", + "embedding_model": "invalid_model", + "retrieval_model": "new_model", + } + + # Call the method and expect ValueError + with pytest.raises(Exception) as context: + DatasetService.update_dataset("dataset-123", update_data, mock_user) + + assert "No Embedding Model available".lower() in str(context.value).lower() + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_internal_dataset_filter_none_values( + self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db + ): + """ + Test that None values are filtered out except for description field. + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.indexing_technique = "high_quality" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Test data with None values + update_data = { + "name": "new_name", + "description": None, # Should be included + "indexing_technique": "high_quality", + "retrieval_model": "new_model", + "embedding_model_provider": None, # Should be filtered out + "embedding_model": None, # Should be filtered out + } + + # Call the method + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify database update was called with filtered data + expected_filtered_data = { + "name": "new_name", + "description": None, # Description should be included even if None + "indexing_technique": "high_quality", + "retrieval_model": "new_model", + "updated_by": mock_user.id, + "updated_at": mock_db.query.return_value.filter_by.return_value.update.call_args[0][0]["updated_at"], + } + + actual_call_args = mock_db.query.return_value.filter_by.return_value.update.call_args[0][0] + # Remove timestamp for comparison as it's dynamic + del actual_call_args["updated_at"] + del expected_filtered_data["updated_at"] + + del actual_call_args["collection_binding_id"] + del actual_call_args["embedding_model"] + del actual_call_args["embedding_model_provider"] + + assert actual_call_args == expected_filtered_data + + # Verify return value + assert result == mock_dataset + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.deal_dataset_vector_index_task") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_internal_dataset_embedding_model_update( + self, mock_datetime, mock_check_permission, mock_get_dataset, mock_task, mock_db + ): + """ + Test updating internal dataset with new embedding model. + + Verifies that: + 1. Embedding model is updated when different from current + 2. Vector index task is triggered with 'update' action + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.indexing_technique = "high_quality" + mock_dataset.embedding_model_provider = "openai" + mock_dataset.embedding_model = "text-embedding-ada-002" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Mock embedding model + mock_embedding_model = Mock() + mock_embedding_model.model = "text-embedding-3-small" + mock_embedding_model.provider = "openai" + + # Mock collection binding + mock_collection_binding_instance = Mock() + mock_collection_binding_instance.id = "binding-789" + + # Mock current_user + mock_current_user = Mock() + mock_current_user.current_tenant_id = "tenant-123" + + # Mock model manager + with patch("services.dataset_service.ModelManager") as mock_model_manager: + mock_model_manager_instance = Mock() + mock_model_manager_instance.get_model_instance.return_value = mock_embedding_model + mock_model_manager.return_value = mock_model_manager_instance + + # Mock collection binding service + with ( + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" + ) as mock_collection_binding, + patch("services.dataset_service.current_user", mock_current_user), + ): + mock_collection_binding.return_value = mock_collection_binding_instance + + # Test data + update_data = { + "indexing_technique": "high_quality", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-3-small", + "retrieval_model": "new_model", + } + + # Call the method + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify embedding model was validated + mock_model_manager_instance.get_model_instance.assert_called_once_with( + tenant_id=mock_current_user.current_tenant_id, + provider="openai", + model_type=ModelType.TEXT_EMBEDDING, + model="text-embedding-3-small", + ) + + # Verify collection binding was retrieved + mock_collection_binding.assert_called_once_with("openai", "text-embedding-3-small") + + # Verify database update was called with correct data + expected_filtered_data = { + "indexing_technique": "high_quality", + "embedding_model": "text-embedding-3-small", + "embedding_model_provider": "openai", + "collection_binding_id": "binding-789", + "retrieval_model": "new_model", + "updated_by": mock_user.id, + "updated_at": current_time.replace(tzinfo=None), + } + + mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data) + mock_db.commit.assert_called_once() + + # Verify vector index task was triggered + mock_task.delay.assert_called_once_with("dataset-123", "update") + + # Verify return value + assert result == mock_dataset + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_internal_dataset_no_indexing_technique_change( + self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db + ): + """ + Test updating internal dataset without changing indexing technique. + + Verifies that: + 1. No vector index task is triggered when indexing technique doesn't change + 2. Database update is performed normally + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.indexing_technique = "high_quality" + mock_dataset.embedding_model_provider = "openai" + mock_dataset.embedding_model = "text-embedding-ada-002" + mock_dataset.collection_binding_id = "binding-123" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Test data with same indexing technique + update_data = { + "name": "new_name", + "indexing_technique": "high_quality", # Same as current + "retrieval_model": "new_model", + } + + # Call the method + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify database update was called with correct data + expected_filtered_data = { + "name": "new_name", + "indexing_technique": "high_quality", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-ada-002", + "collection_binding_id": "binding-123", + "retrieval_model": "new_model", + "updated_by": mock_user.id, + "updated_at": current_time.replace(tzinfo=None), + } + + mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data) + mock_db.commit.assert_called_once() + + # Verify no vector index task was triggered + mock_db.query.return_value.filter_by.return_value.update.assert_called_once() + + # Verify return value + assert result == mock_dataset From 268da3133282b20f01d3bd5a46e3f69210aab399 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Wed, 25 Jun 2025 12:39:22 +0800 Subject: [PATCH 09/12] fix(api): adding variable to variable pool recursively while loading draft variables. (#21478) This PR fix the issue that `ObjectSegment` are not recursively added to the draft variable pool while loading draft variables from database. It also fixes an issue about loading variables with more than two elements in the its selector. Enhances #19735. Closes #21477. --- .../workflow/graph_engine/graph_engine.py | 17 +- api/core/workflow/utils/variable_utils.py | 28 ++++ api/core/workflow/variable_loader.py | 7 +- .../workflow_draft_variable_service.py | 3 +- .../workflow/utils/test_variable_utils.py | 148 ++++++++++++++++++ 5 files changed, 191 insertions(+), 12 deletions(-) create mode 100644 api/core/workflow/utils/variable_utils.py create mode 100644 api/tests/unit_tests/core/workflow/utils/test_variable_utils.py diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 2809afad02..61a7a26652 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -53,6 +53,7 @@ from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from core.workflow.utils import variable_utils from libs.flask_utils import preserve_flask_contexts from models.enums import UserFrom from models.workflow import WorkflowType @@ -856,16 +857,12 @@ class GraphEngine: :param variable_value: variable value :return: """ - self.graph_runtime_state.variable_pool.add([node_id] + variable_key_list, variable_value) - - # if variable_value is a dict, then recursively append variables - if isinstance(variable_value, dict): - for key, value in variable_value.items(): - # construct new key list - new_key_list = variable_key_list + [key] - self._append_variables_recursively( - node_id=node_id, variable_key_list=new_key_list, variable_value=value - ) + variable_utils.append_variables_recursively( + self.graph_runtime_state.variable_pool, + node_id, + variable_key_list, + variable_value, + ) def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: """ diff --git a/api/core/workflow/utils/variable_utils.py b/api/core/workflow/utils/variable_utils.py new file mode 100644 index 0000000000..e68d990b60 --- /dev/null +++ b/api/core/workflow/utils/variable_utils.py @@ -0,0 +1,28 @@ +from core.variables.segments import ObjectSegment, Segment +from core.workflow.entities.variable_pool import VariablePool, VariableValue + + +def append_variables_recursively( + pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue | Segment +): + """ + Append variables recursively + :param node_id: node id + :param variable_key_list: variable key list + :param variable_value: variable value + :return: + """ + pool.add([node_id] + variable_key_list, variable_value) + + # if variable_value is a dict, then recursively append variables + if isinstance(variable_value, ObjectSegment): + variable_dict = variable_value.value + elif isinstance(variable_value, dict): + variable_dict = variable_value + else: + return + + for key, value in variable_dict.items(): + # construct new key list + new_key_list = variable_key_list + [key] + append_variables_recursively(pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value) diff --git a/api/core/workflow/variable_loader.py b/api/core/workflow/variable_loader.py index 4842ee00a5..1e13871d0a 100644 --- a/api/core/workflow/variable_loader.py +++ b/api/core/workflow/variable_loader.py @@ -3,7 +3,9 @@ from collections.abc import Mapping, Sequence from typing import Any, Protocol from core.variables import Variable +from core.variables.consts import MIN_SELECTORS_LENGTH from core.workflow.entities.variable_pool import VariablePool +from core.workflow.utils import variable_utils class VariableLoader(Protocol): @@ -76,4 +78,7 @@ def load_into_variable_pool( variables_to_load.append(list(selector)) loaded = variable_loader.load_variables(variables_to_load) for var in loaded: - variable_pool.add(var.selector, var) + assert len(var.selector) >= MIN_SELECTORS_LENGTH, f"Invalid variable {var}" + variable_utils.append_variables_recursively( + variable_pool, node_id=var.selector[0], variable_key_list=list(var.selector[1:]), variable_value=var + ) diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index cd30440b4f..164693c2e1 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -129,7 +129,8 @@ class WorkflowDraftVariableService: ) -> list[WorkflowDraftVariable]: ors = [] for selector in selectors: - node_id, name = selector + assert len(selector) >= MIN_SELECTORS_LENGTH, f"Invalid selector to get: {selector}" + node_id, name = selector[:2] ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name)) # NOTE(QuantumGhost): Although the number of `or` expressions may be large, as long as diff --git a/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py b/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py new file mode 100644 index 0000000000..f1cb937bb3 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py @@ -0,0 +1,148 @@ +from typing import Any + +from core.variables.segments import ObjectSegment, StringSegment +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.utils.variable_utils import append_variables_recursively + + +class TestAppendVariablesRecursively: + """Test cases for append_variables_recursively function""" + + def test_append_simple_dict_value(self): + """Test appending a simple dictionary value""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["output"] + variable_value = {"name": "John", "age": 30} + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check that the main variable is added + main_var = pool.get([node_id] + variable_key_list) + assert main_var is not None + assert main_var.value == variable_value + + # Check that nested variables are added recursively + name_var = pool.get([node_id] + variable_key_list + ["name"]) + assert name_var is not None + assert name_var.value == "John" + + age_var = pool.get([node_id] + variable_key_list + ["age"]) + assert age_var is not None + assert age_var.value == 30 + + def test_append_object_segment_value(self): + """Test appending an ObjectSegment value""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["result"] + + # Create an ObjectSegment + obj_data = {"status": "success", "code": 200} + variable_value = ObjectSegment(value=obj_data) + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check that the main variable is added + main_var = pool.get([node_id] + variable_key_list) + assert main_var is not None + assert isinstance(main_var, ObjectSegment) + assert main_var.value == obj_data + + # Check that nested variables are added recursively + status_var = pool.get([node_id] + variable_key_list + ["status"]) + assert status_var is not None + assert status_var.value == "success" + + code_var = pool.get([node_id] + variable_key_list + ["code"]) + assert code_var is not None + assert code_var.value == 200 + + def test_append_nested_dict_value(self): + """Test appending a nested dictionary value""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["data"] + + variable_value = { + "user": { + "profile": {"name": "Alice", "email": "alice@example.com"}, + "settings": {"theme": "dark", "notifications": True}, + }, + "metadata": {"version": "1.0", "timestamp": 1234567890}, + } + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check deeply nested variables + name_var = pool.get([node_id] + variable_key_list + ["user", "profile", "name"]) + assert name_var is not None + assert name_var.value == "Alice" + + email_var = pool.get([node_id] + variable_key_list + ["user", "profile", "email"]) + assert email_var is not None + assert email_var.value == "alice@example.com" + + theme_var = pool.get([node_id] + variable_key_list + ["user", "settings", "theme"]) + assert theme_var is not None + assert theme_var.value == "dark" + + notifications_var = pool.get([node_id] + variable_key_list + ["user", "settings", "notifications"]) + assert notifications_var is not None + assert notifications_var.value == 1 # Boolean True is converted to integer 1 + + version_var = pool.get([node_id] + variable_key_list + ["metadata", "version"]) + assert version_var is not None + assert version_var.value == "1.0" + + def test_append_non_dict_value(self): + """Test appending a non-dictionary value (should not recurse)""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["simple"] + variable_value = "simple_string" + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check that only the main variable is added + main_var = pool.get([node_id] + variable_key_list) + assert main_var is not None + assert main_var.value == variable_value + + # Ensure no additional variables are created + assert len(pool.variable_dictionary[node_id]) == 1 + + def test_append_segment_non_object_value(self): + """Test appending a Segment that is not ObjectSegment (should not recurse)""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["text"] + variable_value = StringSegment(value="Hello World") + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check that only the main variable is added + main_var = pool.get([node_id] + variable_key_list) + assert main_var is not None + assert isinstance(main_var, StringSegment) + assert main_var.value == "Hello World" + + # Ensure no additional variables are created + assert len(pool.variable_dictionary[node_id]) == 1 + + def test_append_empty_dict_value(self): + """Test appending an empty dictionary value""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["empty"] + variable_value: dict[str, Any] = {} + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check that the main variable is added + main_var = pool.get([node_id] + variable_key_list) + assert main_var is not None + assert main_var.value == {} + + # Ensure only the main variable is created (no recursion for empty dict) + assert len(pool.variable_dictionary[node_id]) == 1 From 0f5417ab849024c78ad92a1d316af225d64880ab Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 25 Jun 2025 13:06:07 +0800 Subject: [PATCH 10/12] feat(web): Contains sys.files in the default template. (#21476) --- web/app/components/workflow-app/hooks/use-workflow-template.ts | 2 +- .../workflow/nodes/_base/components/memory-config.tsx | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/web/app/components/workflow-app/hooks/use-workflow-template.ts b/web/app/components/workflow-app/hooks/use-workflow-template.ts index 9f47b981dc..2bab08f205 100644 --- a/web/app/components/workflow-app/hooks/use-workflow-template.ts +++ b/web/app/components/workflow-app/hooks/use-workflow-template.ts @@ -22,7 +22,7 @@ export const useWorkflowTemplate = () => { ...nodesInitialData.llm, memory: { window: { enabled: false, size: 10 }, - query_prompt_template: '{{#sys.query#}}', + query_prompt_template: '{{#sys.query#}}\n\n{{#sys.files#}}', }, selected: true, }, diff --git a/web/app/components/workflow/nodes/_base/components/memory-config.tsx b/web/app/components/workflow/nodes/_base/components/memory-config.tsx index 446fcfa8ae..168ad656b6 100644 --- a/web/app/components/workflow/nodes/_base/components/memory-config.tsx +++ b/web/app/components/workflow/nodes/_base/components/memory-config.tsx @@ -53,7 +53,7 @@ type Props = { const MEMORY_DEFAULT: Memory = { window: { enabled: false, size: WINDOW_SIZE_DEFAULT }, - query_prompt_template: '{{#sys.query#}}', + query_prompt_template: '{{#sys.query#}}\n\n{{#sys.files#}}', } const MemoryConfig: FC = ({ From a098825fcc49e5fcff628d296ca922138ec197d5 Mon Sep 17 00:00:00 2001 From: AuditAIH <145266260+AuditAIH@users.noreply.github.com> Date: Wed, 25 Jun 2025 13:50:35 +0800 Subject: [PATCH 11/12] Update smtp.py (#21335) --- api/libs/smtp.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/api/libs/smtp.py b/api/libs/smtp.py index 35561f071c..b94386660e 100644 --- a/api/libs/smtp.py +++ b/api/libs/smtp.py @@ -22,7 +22,11 @@ class SMTPClient: if self.use_tls: if self.opportunistic_tls: smtp = smtplib.SMTP(self.server, self.port, timeout=10) + # Send EHLO command with the HELO domain name as the server address + smtp.ehlo(self.server) smtp.starttls() + # Resend EHLO command to identify the TLS session + smtp.ehlo(self.server) else: smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10) else: From 164e5481c5dab7a449cf063158d5d1a9cfea9506 Mon Sep 17 00:00:00 2001 From: Maries Date: Wed, 25 Jun 2025 14:14:30 +0800 Subject: [PATCH 12/12] feat(oauth): plugin oauth service (#21480) --- api/core/plugin/impl/oauth.py | 16 +++-- api/services/plugin/oauth_service.py | 62 +++++++++++++++++-- .../test_oauth_convert_request_to_raw_data.py | 57 ++++++++++++++++- 3 files changed, 126 insertions(+), 9 deletions(-) diff --git a/api/core/plugin/impl/oauth.py b/api/core/plugin/impl/oauth.py index 91774984c8..b006bf1d4b 100644 --- a/api/core/plugin/impl/oauth.py +++ b/api/core/plugin/impl/oauth.py @@ -1,3 +1,4 @@ +import binascii from collections.abc import Mapping from typing import Any @@ -16,7 +17,7 @@ class OAuthHandler(BasePluginClient): provider: str, system_credentials: Mapping[str, Any], ) -> PluginOAuthAuthorizationUrlResponse: - return self._request_with_plugin_daemon_response( + response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url", PluginOAuthAuthorizationUrlResponse, @@ -32,6 +33,9 @@ class OAuthHandler(BasePluginClient): "Content-Type": "application/json", }, ) + for resp in response: + return resp + raise ValueError("No response received from plugin daemon for authorization URL request.") def get_credentials( self, @@ -49,7 +53,7 @@ class OAuthHandler(BasePluginClient): # encode request to raw http request raw_request_bytes = self._convert_request_to_raw_data(request) - return self._request_with_plugin_daemon_response( + response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/oauth/get_credentials", PluginOAuthCredentialsResponse, @@ -58,7 +62,8 @@ class OAuthHandler(BasePluginClient): "data": { "provider": provider, "system_credentials": system_credentials, - "raw_request_bytes": raw_request_bytes, + # for json serialization + "raw_http_request": binascii.hexlify(raw_request_bytes).decode(), }, }, headers={ @@ -66,6 +71,9 @@ class OAuthHandler(BasePluginClient): "Content-Type": "application/json", }, ) + for resp in response: + return resp + raise ValueError("No response received from plugin daemon for authorization URL request.") def _convert_request_to_raw_data(self, request: Request) -> bytes: """ @@ -79,7 +87,7 @@ class OAuthHandler(BasePluginClient): """ # Start with the request line method = request.method - path = request.path + path = request.full_path protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1") raw_data = f"{method} {path} {protocol}\r\n".encode() diff --git a/api/services/plugin/oauth_service.py b/api/services/plugin/oauth_service.py index 461247419b..4077ec38d3 100644 --- a/api/services/plugin/oauth_service.py +++ b/api/services/plugin/oauth_service.py @@ -1,7 +1,61 @@ +import json +import uuid + from core.plugin.impl.base import BasePluginClient +from extensions.ext_redis import redis_client -class OAuthService(BasePluginClient): - @classmethod - def get_authorization_url(cls, tenant_id: str, user_id: str, provider_name: str) -> str: - return "1234567890" +class OAuthProxyService(BasePluginClient): + # Default max age for proxy context parameter in seconds + __MAX_AGE__ = 5 * 60 # 5 minutes + + @staticmethod + def create_proxy_context(user_id, tenant_id, plugin_id, provider): + """ + Create a proxy context for an OAuth 2.0 authorization request. + + This parameter is a crucial security measure to prevent Cross-Site Request + Forgery (CSRF) attacks. It works by generating a unique nonce and storing it + in a distributed cache (Redis) along with the user's session context. + + The returned nonce should be included as the 'proxy_context' parameter in the + authorization URL. Upon callback, the `use_proxy_context` method + is used to verify the state, ensuring the request's integrity and authenticity, + and mitigating replay attacks. + """ + seconds, _ = redis_client.time() + context_id = str(uuid.uuid4()) + data = { + "user_id": user_id, + "plugin_id": plugin_id, + "tenant_id": tenant_id, + "provider": provider, + # encode redis time to avoid distribution time skew + "timestamp": seconds, + } + # ignore nonce collision + redis_client.setex( + f"oauth_proxy_context:{context_id}", + OAuthProxyService.__MAX_AGE__, + json.dumps(data), + ) + return context_id + + @staticmethod + def use_proxy_context(context_id, max_age=__MAX_AGE__): + """ + Validate the proxy context parameter. + This checks if the context_id is valid and not expired. + """ + if not context_id: + raise ValueError("context_id is required") + # get data from redis + data = redis_client.getdel(f"oauth_proxy_context:{context_id}") + if not data: + raise ValueError("context_id is invalid") + # check if data is expired + seconds, _ = redis_client.time() + state = json.loads(data) + if state.get("timestamp") < seconds - max_age: + raise ValueError("context_id is expired") + return state diff --git a/api/tests/unit_tests/utils/http_parser/test_oauth_convert_request_to_raw_data.py b/api/tests/unit_tests/utils/http_parser/test_oauth_convert_request_to_raw_data.py index f788a9756b..293ac253f5 100644 --- a/api/tests/unit_tests/utils/http_parser/test_oauth_convert_request_to_raw_data.py +++ b/api/tests/unit_tests/utils/http_parser/test_oauth_convert_request_to_raw_data.py @@ -1,3 +1,5 @@ +import json + from werkzeug import Request from werkzeug.datastructures import Headers from werkzeug.test import EnvironBuilder @@ -15,6 +17,59 @@ def test_oauth_convert_request_to_raw_data(): request = Request(builder.get_environ()) raw_request_bytes = oauth_handler._convert_request_to_raw_data(request) - assert b"GET /test HTTP/1.1" in raw_request_bytes + assert b"GET /test? HTTP/1.1" in raw_request_bytes assert b"Content-Type: application/json" in raw_request_bytes assert b"\r\n\r\n" in raw_request_bytes + + +def test_oauth_convert_request_to_raw_data_with_query_params(): + oauth_handler = OAuthHandler() + builder = EnvironBuilder( + method="GET", + path="/test", + query_string="code=abc123&state=xyz789", + headers=Headers({"Content-Type": "application/json"}), + ) + request = Request(builder.get_environ()) + raw_request_bytes = oauth_handler._convert_request_to_raw_data(request) + + assert b"GET /test?code=abc123&state=xyz789 HTTP/1.1" in raw_request_bytes + assert b"Content-Type: application/json" in raw_request_bytes + assert b"\r\n\r\n" in raw_request_bytes + + +def test_oauth_convert_request_to_raw_data_with_post_body(): + oauth_handler = OAuthHandler() + builder = EnvironBuilder( + method="POST", + path="/test", + data="param1=value1¶m2=value2", + headers=Headers({"Content-Type": "application/x-www-form-urlencoded"}), + ) + request = Request(builder.get_environ()) + raw_request_bytes = oauth_handler._convert_request_to_raw_data(request) + + assert b"POST /test? HTTP/1.1" in raw_request_bytes + assert b"Content-Type: application/x-www-form-urlencoded" in raw_request_bytes + assert b"\r\n\r\n" in raw_request_bytes + assert b"param1=value1¶m2=value2" in raw_request_bytes + + +def test_oauth_convert_request_to_raw_data_with_json_body(): + oauth_handler = OAuthHandler() + json_data = {"code": "abc123", "state": "xyz789", "grant_type": "authorization_code"} + builder = EnvironBuilder( + method="POST", + path="/test", + data=json.dumps(json_data), + headers=Headers({"Content-Type": "application/json"}), + ) + request = Request(builder.get_environ()) + raw_request_bytes = oauth_handler._convert_request_to_raw_data(request) + + assert b"POST /test? HTTP/1.1" in raw_request_bytes + assert b"Content-Type: application/json" in raw_request_bytes + assert b"\r\n\r\n" in raw_request_bytes + assert b'"code": "abc123"' in raw_request_bytes + assert b'"state": "xyz789"' in raw_request_bytes + assert b'"grant_type": "authorization_code"' in raw_request_bytes