From f295c7532c1e0a03027c67b573faccd18436f469 Mon Sep 17 00:00:00 2001 From: Xin Zhang Date: Thu, 16 Oct 2025 10:58:28 +0800 Subject: [PATCH 01/46] fix plugin installation permissions when using a local pkg (#26822) Co-authored-by: zhangx1n --- api/services/plugin/plugin_service.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py index 604adeb7b5..525ccc9417 100644 --- a/api/services/plugin/plugin_service.py +++ b/api/services/plugin/plugin_service.py @@ -336,6 +336,8 @@ class PluginService: pkg, verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only, ) + PluginService._check_plugin_installation_scope(response.verification) + return response @staticmethod @@ -358,6 +360,8 @@ class PluginService: pkg, verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only, ) + PluginService._check_plugin_installation_scope(response.verification) + return response @staticmethod @@ -377,6 +381,10 @@ class PluginService: manager = PluginInstaller() + for plugin_unique_identifier in plugin_unique_identifiers: + resp = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier) + PluginService._check_plugin_installation_scope(resp.verification) + return manager.install_from_identifiers( tenant_id, plugin_unique_identifiers, @@ -393,6 +401,9 @@ class PluginService: PluginService._check_marketplace_only_permission() manager = PluginInstaller() + plugin_decode_response = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier) + PluginService._check_plugin_installation_scope(plugin_decode_response.verification) + return manager.install_from_identifiers( tenant_id, [plugin_unique_identifier], From 35011b810d777cadc6aea1a2775c48357546c1ef Mon Sep 17 00:00:00 2001 From: wellCh4n Date: Thu, 16 Oct 2025 11:01:11 +0800 Subject: [PATCH 02/46] feat: run with params from logs (#26787) Co-authored-by: lyzno1 Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com> --- .../components/app/workflow-log/detail.tsx | 27 ++++++- web/app/components/workflow-app/index.tsx | 71 +++++++++++++++++++ .../workflow/store/workflow/form-slice.ts | 4 +- web/i18n/de-DE/app-log.ts | 1 + web/i18n/en-US/app-log.ts | 1 + web/i18n/es-ES/app-log.ts | 1 + web/i18n/fa-IR/app-log.ts | 1 + web/i18n/fr-FR/app-log.ts | 1 + web/i18n/hi-IN/app-log.ts | 1 + web/i18n/id-ID/app-log.ts | 1 + web/i18n/it-IT/app-log.ts | 1 + web/i18n/ja-JP/app-log.ts | 1 + web/i18n/ko-KR/app-log.ts | 1 + web/i18n/pl-PL/app-log.ts | 1 + web/i18n/pt-BR/app-log.ts | 1 + web/i18n/ro-RO/app-log.ts | 1 + web/i18n/ru-RU/app-log.ts | 1 + web/i18n/sl-SI/app-log.ts | 1 + web/i18n/th-TH/app-log.ts | 1 + web/i18n/tr-TR/app-log.ts | 1 + web/i18n/uk-UA/app-log.ts | 1 + web/i18n/vi-VN/app-log.ts | 1 + web/i18n/zh-Hans/app-log.ts | 1 + web/i18n/zh-Hant/app-log.ts | 1 + 24 files changed, 119 insertions(+), 4 deletions(-) diff --git a/web/app/components/app/workflow-log/detail.tsx b/web/app/components/app/workflow-log/detail.tsx index 812438c0ed..7ce701dd68 100644 --- a/web/app/components/app/workflow-log/detail.tsx +++ b/web/app/components/app/workflow-log/detail.tsx @@ -1,9 +1,11 @@ 'use client' import type { FC } from 'react' import { useTranslation } from 'react-i18next' -import { RiCloseLine } from '@remixicon/react' +import { RiCloseLine, RiPlayLargeLine } from '@remixicon/react' import Run from '@/app/components/workflow/run' import { useStore } from '@/app/components/app/store' +import 
TooltipPlus from '@/app/components/base/tooltip'
+import { useRouter } from 'next/navigation'
 
 type ILogDetail = {
   runID: string
@@ -13,13 +15,34 @@ const DetailPanel: FC<ILogDetail> = ({ runID, onClose }) => {
   const { t } = useTranslation()
   const appDetail = useStore(state => state.appDetail)
+  const router = useRouter()
+
+  const handleReplay = () => {
+    if (!appDetail?.id) return
+    router.push(`/app/${appDetail.id}/workflow?replayRunId=${runID}`)
+  }
 
   return (
-      <h1>{t('appLog.runDetail.workflowTitle')}</h1>
+      <div>
+        <h1>{t('appLog.runDetail.workflowTitle')}</h1>
+      </div>
+      <TooltipPlus popupContent={t('appLog.runDetail.testWithParams')}>
+        <div onClick={handleReplay}>
+          <RiPlayLargeLine />
+        </div>
+      </TooltipPlus>
{ const { @@ -47,6 +53,71 @@ const WorkflowAppWithAdditionalContext = () => { return [] }, [data]) + const searchParams = useSearchParams() + const workflowStore = useWorkflowStore() + const { getWorkflowRunAndTraceUrl } = useGetRunAndTraceUrl() + const replayRunId = searchParams.get('replayRunId') + + useEffect(() => { + if (!replayRunId) + return + const { runUrl } = getWorkflowRunAndTraceUrl(replayRunId) + if (!runUrl) + return + fetchRunDetail(runUrl).then((res) => { + const { setInputs, setShowInputsPanel, setShowDebugAndPreviewPanel } = workflowStore.getState() + const rawInputs = res.inputs + let parsedInputs: Record | null = null + + if (typeof rawInputs === 'string') { + try { + const maybeParsed = JSON.parse(rawInputs) as unknown + if (maybeParsed && typeof maybeParsed === 'object' && !Array.isArray(maybeParsed)) + parsedInputs = maybeParsed as Record + } + catch (error) { + console.error('Failed to parse workflow run inputs', error) + } + } + else if (rawInputs && typeof rawInputs === 'object' && !Array.isArray(rawInputs)) { + parsedInputs = rawInputs as Record + } + + if (!parsedInputs) + return + + const userInputs: Record = {} + Object.entries(parsedInputs).forEach(([key, value]) => { + if (key.startsWith('sys.')) + return + + if (value == null) { + userInputs[key] = '' + return + } + + if (typeof value === 'string' || typeof value === 'number' || typeof value === 'boolean') { + userInputs[key] = value + return + } + + try { + userInputs[key] = JSON.stringify(value) + } + catch { + userInputs[key] = String(value) + } + }) + + if (!Object.keys(userInputs).length) + return + + setInputs(userInputs) + setShowInputsPanel(true) + setShowDebugAndPreviewPanel(true) + }) + }, [replayRunId, workflowStore, getWorkflowRunAndTraceUrl]) + if (!data || isLoading || isLoadingCurrentWorkspace || !currentWorkspace.id) { return (
diff --git a/web/app/components/workflow/store/workflow/form-slice.ts b/web/app/components/workflow/store/workflow/form-slice.ts index a6c607d2af..46391eddff 100644 --- a/web/app/components/workflow/store/workflow/form-slice.ts +++ b/web/app/components/workflow/store/workflow/form-slice.ts @@ -4,8 +4,8 @@ import type { } from '@/app/components/workflow/types' export type FormSliceShape = { - inputs: Record - setInputs: (inputs: Record) => void + inputs: Record + setInputs: (inputs: Record) => void files: RunFile[] setFiles: (files: RunFile[]) => void } diff --git a/web/i18n/de-DE/app-log.ts b/web/i18n/de-DE/app-log.ts index 9d63a58259..0fbdcca0bf 100644 --- a/web/i18n/de-DE/app-log.ts +++ b/web/i18n/de-DE/app-log.ts @@ -83,6 +83,7 @@ const translation = { workflowTitle: 'Protokolldetail', fileListLabel: 'Details zur Datei', fileListDetail: 'Detail', + testWithParams: 'Test mit Parametern', }, promptLog: 'Prompt-Protokoll', agentLog: 'Agentenprotokoll', diff --git a/web/i18n/en-US/app-log.ts b/web/i18n/en-US/app-log.ts index 97773e6efd..946d8ffcb7 100644 --- a/web/i18n/en-US/app-log.ts +++ b/web/i18n/en-US/app-log.ts @@ -83,6 +83,7 @@ const translation = { workflowTitle: 'Log Detail', fileListLabel: 'File Details', fileListDetail: 'Detail', + testWithParams: 'Test With Params', }, promptLog: 'Prompt Log', agentLog: 'Agent Log', diff --git a/web/i18n/es-ES/app-log.ts b/web/i18n/es-ES/app-log.ts index 74ebee902e..0044dee709 100644 --- a/web/i18n/es-ES/app-log.ts +++ b/web/i18n/es-ES/app-log.ts @@ -82,6 +82,7 @@ const translation = { workflowTitle: 'Detalle del Registro', fileListLabel: 'Detalles del archivo', fileListDetail: 'Detalle', + testWithParams: 'Prueba con parámetros', }, promptLog: 'Registro de Indicación', agentLog: 'Registro de Agente', diff --git a/web/i18n/fa-IR/app-log.ts b/web/i18n/fa-IR/app-log.ts index 38f7267c6e..526fa01e76 100644 --- a/web/i18n/fa-IR/app-log.ts +++ b/web/i18n/fa-IR/app-log.ts @@ -82,6 +82,7 @@ const translation = { workflowTitle: 'جزئیات لاگ', fileListLabel: 'جزئیات فایل', fileListDetail: 'جزئیات', + testWithParams: 'تست با پارامترها', }, promptLog: 'لاگ درخواست', agentLog: 'لاگ عامل', diff --git a/web/i18n/fr-FR/app-log.ts b/web/i18n/fr-FR/app-log.ts index 50f2ff358d..42e25ba21f 100644 --- a/web/i18n/fr-FR/app-log.ts +++ b/web/i18n/fr-FR/app-log.ts @@ -82,6 +82,7 @@ const translation = { workflowTitle: 'Détail du journal', fileListDetail: 'Détail', fileListLabel: 'Détails du fichier', + testWithParams: 'Test avec paramètres', }, promptLog: 'Journal de consigne', agentLog: 'Journal des agents', diff --git a/web/i18n/hi-IN/app-log.ts b/web/i18n/hi-IN/app-log.ts index bb95b46052..02e062df2e 100644 --- a/web/i18n/hi-IN/app-log.ts +++ b/web/i18n/hi-IN/app-log.ts @@ -84,6 +84,7 @@ const translation = { workflowTitle: 'लॉग विवरण', fileListDetail: 'विस्तार', fileListLabel: 'फ़ाइल विवरण', + testWithParams: 'पैरामीटर्स के साथ परीक्षण', }, promptLog: 'प्रॉम्प्ट लॉग', agentLog: 'एजेंट लॉग', diff --git a/web/i18n/id-ID/app-log.ts b/web/i18n/id-ID/app-log.ts index 1ccf8dec1e..8192e1f40d 100644 --- a/web/i18n/id-ID/app-log.ts +++ b/web/i18n/id-ID/app-log.ts @@ -74,6 +74,7 @@ const translation = { workflowTitle: 'Log Detail', title: 'Log Percakapan', fileListLabel: 'Rincian File', + testWithParams: 'Uji Dengan Param', }, agentLogDetail: { iterations: 'Iterasi', diff --git a/web/i18n/it-IT/app-log.ts b/web/i18n/it-IT/app-log.ts index 8653b765bd..98cb6afd84 100644 --- a/web/i18n/it-IT/app-log.ts +++ b/web/i18n/it-IT/app-log.ts @@ -86,6 +86,7 @@ const translation = { 
workflowTitle: 'Dettagli Registro', fileListDetail: 'Dettaglio', fileListLabel: 'Dettagli del file', + testWithParams: 'Test con parametri', }, promptLog: 'Registro Prompt', agentLog: 'Registro Agente', diff --git a/web/i18n/ja-JP/app-log.ts b/web/i18n/ja-JP/app-log.ts index db42d317f4..714481c8d1 100644 --- a/web/i18n/ja-JP/app-log.ts +++ b/web/i18n/ja-JP/app-log.ts @@ -83,6 +83,7 @@ const translation = { workflowTitle: 'ログの詳細', fileListLabel: 'ファイルの詳細', fileListDetail: '詳細', + testWithParams: 'パラメータ付きテスト', }, promptLog: 'プロンプトログ', agentLog: 'エージェントログ', diff --git a/web/i18n/ko-KR/app-log.ts b/web/i18n/ko-KR/app-log.ts index 366d2cc1c2..1701d588b0 100644 --- a/web/i18n/ko-KR/app-log.ts +++ b/web/i18n/ko-KR/app-log.ts @@ -83,6 +83,7 @@ const translation = { workflowTitle: '로그 세부 정보', fileListDetail: '세부', fileListLabel: '파일 세부 정보', + testWithParams: '매개변수로 테스트', }, promptLog: '프롬프트 로그', agentLog: '에이전트 로그', diff --git a/web/i18n/pl-PL/app-log.ts b/web/i18n/pl-PL/app-log.ts index 09fdb426fc..90ad14ad0c 100644 --- a/web/i18n/pl-PL/app-log.ts +++ b/web/i18n/pl-PL/app-log.ts @@ -86,6 +86,7 @@ const translation = { workflowTitle: 'Szczegół dziennika', fileListDetail: 'Detal', fileListLabel: 'Szczegóły pliku', + testWithParams: 'Test z parametrami', }, promptLog: 'Dziennik monitów', agentLog: 'Dziennik agenta', diff --git a/web/i18n/pt-BR/app-log.ts b/web/i18n/pt-BR/app-log.ts index 428291d871..9e2ff80759 100644 --- a/web/i18n/pt-BR/app-log.ts +++ b/web/i18n/pt-BR/app-log.ts @@ -82,6 +82,7 @@ const translation = { workflowTitle: 'Detalhes do Registro', fileListLabel: 'Detalhes do arquivo', fileListDetail: 'Detalhe', + testWithParams: 'Teste com parâmetros', }, promptLog: 'Registro de Prompt', agentLog: 'Registro do agente', diff --git a/web/i18n/ro-RO/app-log.ts b/web/i18n/ro-RO/app-log.ts index 3f9609d832..4a6e9bd96e 100644 --- a/web/i18n/ro-RO/app-log.ts +++ b/web/i18n/ro-RO/app-log.ts @@ -82,6 +82,7 @@ const translation = { workflowTitle: 'Detalii jurnal', fileListDetail: 'Amănunt', fileListLabel: 'Detalii fișier', + testWithParams: 'Test cu parametri', }, promptLog: 'Jurnal prompt', agentLog: 'Jurnal agent', diff --git a/web/i18n/ru-RU/app-log.ts b/web/i18n/ru-RU/app-log.ts index 4ca48e723b..f874f5f523 100644 --- a/web/i18n/ru-RU/app-log.ts +++ b/web/i18n/ru-RU/app-log.ts @@ -82,6 +82,7 @@ const translation = { workflowTitle: 'Подробная информация о журнале', fileListLabel: 'Сведения о файле', fileListDetail: 'Подробность', + testWithParams: 'Тест с параметрами', }, promptLog: 'Журнал подсказок', agentLog: 'Журнал агента', diff --git a/web/i18n/sl-SI/app-log.ts b/web/i18n/sl-SI/app-log.ts index 598f4f04c6..7f7cba0fa3 100644 --- a/web/i18n/sl-SI/app-log.ts +++ b/web/i18n/sl-SI/app-log.ts @@ -82,6 +82,7 @@ const translation = { workflowTitle: 'Podrobnosti dnevnika', fileListDetail: 'Podrobnosti', fileListLabel: 'Podrobnosti o datoteki', + testWithParams: 'Preizkus s parametri', }, promptLog: 'Dnevnik PROMPT-ov', agentLog: 'Dnevnik pomočnika', diff --git a/web/i18n/th-TH/app-log.ts b/web/i18n/th-TH/app-log.ts index a4952afeb9..8c47eaafdd 100644 --- a/web/i18n/th-TH/app-log.ts +++ b/web/i18n/th-TH/app-log.ts @@ -82,6 +82,7 @@ const translation = { workflowTitle: 'รายละเอียดบันทึก', fileListDetail: 'รายละเอียด', fileListLabel: 'รายละเอียดไฟล์', + testWithParams: 'ทดสอบด้วยพารามิเตอร์', }, promptLog: 'บันทึกพร้อมท์', agentLog: 'บันทึกตัวแทน', diff --git a/web/i18n/tr-TR/app-log.ts b/web/i18n/tr-TR/app-log.ts index dbc3c5708b..380af8fd59 100644 --- a/web/i18n/tr-TR/app-log.ts +++ 
b/web/i18n/tr-TR/app-log.ts @@ -82,6 +82,7 @@ const translation = { workflowTitle: 'Günlük Detayı', fileListDetail: 'Ayrıntı', fileListLabel: 'Dosya Detayları', + testWithParams: 'Parametrelerle Test', }, promptLog: 'Prompt Günlüğü', agentLog: 'Agent Günlüğü', diff --git a/web/i18n/uk-UA/app-log.ts b/web/i18n/uk-UA/app-log.ts index 1d361150db..8f8f3db5da 100644 --- a/web/i18n/uk-UA/app-log.ts +++ b/web/i18n/uk-UA/app-log.ts @@ -82,6 +82,7 @@ const translation = { workflowTitle: 'Деталі Журналу', fileListDetail: 'Деталь', fileListLabel: 'Подробиці файлу', + testWithParams: 'Тест з параметрами', }, promptLog: 'Журнал Запитань', agentLog: 'Журнал агента', diff --git a/web/i18n/vi-VN/app-log.ts b/web/i18n/vi-VN/app-log.ts index e53b98b004..167b8747e2 100644 --- a/web/i18n/vi-VN/app-log.ts +++ b/web/i18n/vi-VN/app-log.ts @@ -82,6 +82,7 @@ const translation = { workflowTitle: 'Chi tiết nhật ký', fileListDetail: 'Chi tiết', fileListLabel: 'Chi tiết tệp', + testWithParams: 'Kiểm tra với các tham số', }, promptLog: 'Nhật ký lời nhắc', viewLog: 'Xem nhật ký', diff --git a/web/i18n/zh-Hans/app-log.ts b/web/i18n/zh-Hans/app-log.ts index 51b7ebb1e0..26c5c915c5 100644 --- a/web/i18n/zh-Hans/app-log.ts +++ b/web/i18n/zh-Hans/app-log.ts @@ -83,6 +83,7 @@ const translation = { workflowTitle: '日志详情', fileListLabel: '文件详情', fileListDetail: '详情', + testWithParams: '按此参数测试', }, promptLog: 'Prompt 日志', agentLog: 'Agent 日志', diff --git a/web/i18n/zh-Hant/app-log.ts b/web/i18n/zh-Hant/app-log.ts index 9d577c682f..d24b4a1cce 100644 --- a/web/i18n/zh-Hant/app-log.ts +++ b/web/i18n/zh-Hant/app-log.ts @@ -82,6 +82,7 @@ const translation = { workflowTitle: '日誌詳情', fileListDetail: '細節', fileListLabel: '檔詳細資訊', + testWithParams: '使用參數測試', }, promptLog: 'Prompt 日誌', agentLog: 'Agent 日誌', From bd01af64150dc1887d4be17db22af06b3d6e8c1a Mon Sep 17 00:00:00 2001 From: Xiyuan Chen <52963600+GareArc@users.noreply.github.com> Date: Wed, 15 Oct 2025 21:15:26 -0700 Subject: [PATCH 03/46] =?UTF-8?q?fix:=20update=20load=20balancing=20config?= =?UTF-8?q?urations=20with=20new=20credential=20IDs=20and=E2=80=A6=20(#269?= =?UTF-8?q?00)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/core/entities/provider_configuration.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index b069955836..c4be429219 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -1148,6 +1148,15 @@ class ProviderConfiguration(BaseModel): raise ValueError("Can't add same credential") provider_model_record.credential_id = credential_record.id provider_model_record.updated_at = naive_utc_now() + + # clear cache + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_model_record.id, + cache_type=ProviderCredentialsCacheType.MODEL, + ) + provider_model_credentials_cache.delete() + session.add(provider_model_record) session.commit() @@ -1181,6 +1190,14 @@ class ProviderConfiguration(BaseModel): session.add(provider_model_record) session.commit() + # clear cache + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_model_record.id, + cache_type=ProviderCredentialsCacheType.MODEL, + ) + provider_model_credentials_cache.delete() + def delete_custom_model(self, model_type: ModelType, model: str): """ Delete custom model. 
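A minimal sketch of the invalidation pattern patch 03 applies, for readers skimming the series: when a provider model record switches to a new credential, the cache entry keyed by that record has to be dropped so the next read rebuilds it from the database. SimpleCredentialsCache below is a hypothetical stand-in with simplified in-process semantics, not Dify's real ProviderCredentialsCache API.

class SimpleCredentialsCache:
    """Tenant-scoped credentials cache keyed by the owning record's id."""

    # In-process store standing in for the external cache (Redis in Dify).
    _store: dict[tuple[str, str], dict] = {}

    def __init__(self, tenant_id: str, identity_id: str):
        self._key = (tenant_id, identity_id)

    def get(self) -> dict | None:
        return self._store.get(self._key)

    def set(self, credentials: dict) -> None:
        self._store[self._key] = credentials

    def delete(self) -> None:
        # Drop the cached copy so later reads fall through to the database.
        self._store.pop(self._key, None)


def switch_credential(record, new_credential_id: str, tenant_id: str) -> None:
    """Update the record, then invalidate the cache entry it shadows."""
    record.credential_id = new_credential_id
    # Without this delete, readers keep serving the old credential until the
    # entry expires; that stale read is what the patch above fixes.
    SimpleCredentialsCache(tenant_id, record.id).delete()

Note that the patch clears the cache in both the credential-switch and credential-removal paths, since either one leaves the cached copy stale.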
From cced33d068aa330d9dcaa6b4261c892c096ded28 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Thu, 16 Oct 2025 15:45:51 +0900 Subject: [PATCH 04/46] use deco to avoid current_user (#26077) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .github/workflows/api-tests.yml | 30 ++-- api/controllers/console/apikey.py | 8 +- api/controllers/console/app/annotation.py | 82 ++-------- api/controllers/console/app/app.py | 63 +++----- api/controllers/console/app/app_import.py | 14 +- api/controllers/console/app/completion.py | 11 +- api/controllers/console/app/conversation.py | 40 +++-- api/controllers/console/app/mcp_server.py | 32 ++-- api/controllers/console/app/message.py | 34 ++-- api/controllers/console/app/site.py | 6 +- api/controllers/console/app/statistic.py | 34 ++-- api/controllers/console/app/workflow.py | 150 ++++-------------- .../console/app/workflow_draft_variable.py | 3 +- api/controllers/console/app/workflow_run.py | 3 +- .../console/app/workflow_statistic.py | 19 ++- api/controllers/console/app/wraps.py | 13 +- api/controllers/console/auth/activate.py | 2 +- .../console/auth/data_source_oauth.py | 4 +- .../console/auth/email_register.py | 2 +- .../console/auth/forgot_password.py | 2 +- api/controllers/console/auth/login.py | 7 +- api/controllers/console/auth/oauth.py | 3 +- api/controllers/console/auth/oauth_server.py | 10 +- api/controllers/console/billing/billing.py | 16 +- api/controllers/console/billing/compliance.py | 9 +- api/controllers/console/datasets/datasets.py | 57 ++++--- .../console/datasets/datasets_document.py | 41 +++-- api/controllers/console/datasets/external.py | 28 ++-- api/controllers/console/datasets/metadata.py | 8 +- .../datasets/rag_pipeline/datasource_auth.py | 41 ++--- .../rag_pipeline_draft_variable.py | 2 +- .../rag_pipeline/rag_pipeline_import.py | 22 +-- .../rag_pipeline/rag_pipeline_workflow.py | 102 ++++++------ api/controllers/console/datasets/wraps.py | 8 +- api/controllers/console/explore/message.py | 15 +- .../console/explore/saved_message.py | 12 +- api/controllers/console/explore/wraps.py | 12 +- api/controllers/console/extension.py | 12 +- api/controllers/console/files.py | 9 +- api/controllers/console/tag/tags.py | 23 +-- api/controllers/console/workspace/__init__.py | 5 +- api/controllers/console/workspace/account.py | 61 +++---- .../console/workspace/agent_providers.py | 16 +- .../workspace/load_balancing_config.py | 14 +- .../console/workspace/workspace.py | 33 ++-- api/controllers/console/wraps.py | 22 ++- api/controllers/inner_api/mail.py | 2 +- api/controllers/inner_api/plugin/plugin.py | 2 +- .../inner_api/workspace/workspace.py | 2 +- api/controllers/service_api/app/annotation.py | 2 +- .../rag_pipeline/rag_pipeline_workflow.py | 2 +- api/controllers/service_api/wraps.py | 2 +- api/controllers/web/forgot_password.py | 2 +- .../advanced_chat/generate_task_pipeline.py | 3 +- api/core/app/apps/chat/app_generator.py | 2 +- .../apps/workflow/generate_task_pipeline.py | 2 +- api/core/ops/ops_trace_manager.py | 2 +- api/core/plugin/backwards_invocation/app.py | 2 +- .../celery_workflow_execution_repository.py | 2 +- api/core/tools/utils/message_transformer.py | 2 +- .../clean_when_dataset_deleted.py | 5 +- api/extensions/ext_login.py | 2 +- api/libs/external_api.py | 4 +- api/libs/helper.py | 6 +- api/libs/login.py | 2 +- .../mail_clean_document_notify_task.py | 2 +- api/services/agent_service.py | 
2 +- api/services/app_service.py | 2 +- api/services/billing_service.py | 2 +- api/services/conversation_service.py | 3 +- api/services/dataset_service.py | 2 +- api/services/file_service.py | 2 +- api/services/hit_testing_service.py | 2 +- api/services/message_service.py | 2 +- api/services/metadata_service.py | 15 +- api/services/oauth_server.py | 2 +- .../customized/customized_retrieval.py | 7 +- api/services/rag_pipeline/rag_pipeline.py | 2 +- api/services/saved_message_service.py | 2 +- api/services/web_conversation_service.py | 2 +- api/services/webapp_auth_service.py | 2 +- api/services/workflow/workflow_converter.py | 2 +- .../workflow_draft_variable_service.py | 3 +- api/services/workflow_service.py | 2 +- api/tasks/delete_account_task.py | 2 +- .../priority_rag_pipeline_run_task.py | 2 +- .../rag_pipeline/rag_pipeline_run_task.py | 2 +- api/tasks/retry_document_indexing_task.py | 2 +- .../services/test_account_service.py | 16 +- .../services/test_agent_service.py | 2 +- .../services/test_annotation_service.py | 2 +- .../services/test_app_service.py | 2 +- .../services/test_file_service.py | 2 +- .../services/test_metadata_service.py | 6 +- .../services/test_model_provider_service.py | 2 +- .../services/test_tag_service.py | 2 +- .../services/test_web_conversation_service.py | 2 +- .../services/test_webapp_auth_service.py | 2 +- .../services/test_workspace_service.py | 2 +- .../tools/test_api_tools_manage_service.py | 2 +- .../tools/test_mcp_tools_manage_service.py | 2 +- .../workflow/test_workflow_converter.py | 2 +- .../tasks/test_add_document_to_index_task.py | 2 +- .../tasks/test_batch_clean_document_task.py | 2 +- ...test_batch_create_segment_to_index_task.py | 2 +- .../tasks/test_clean_dataset_task.py | 2 +- .../test_create_segment_to_index_task.py | 2 +- .../test_disable_segment_from_index_task.py | 2 +- .../tasks/test_document_indexing_task.py | 2 +- 109 files changed, 526 insertions(+), 788 deletions(-) diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index 116fc59ee8..37d351627b 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -39,25 +39,11 @@ jobs: - name: Install dependencies run: uv sync --project api --dev - - name: Run Unit tests - run: | - uv run --project api bash dev/pytest/pytest_unit_tests.sh - - name: Run pyrefly check run: | cd api uv add --dev pyrefly uv run pyrefly check || true - - name: Coverage Summary - run: | - set -x - # Extract coverage percentage and create a summary - TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])') - - # Create a detailed coverage summary - echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY - echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY - uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY - name: Run dify config tests run: uv run --project api dev/pytest/pytest_config_tests.py @@ -93,3 +79,19 @@ jobs: - name: Run TestContainers run: uv run --project api bash dev/pytest/pytest_testcontainers.sh + + - name: Run Unit tests + run: | + uv run --project api bash dev/pytest/pytest_unit_tests.sh + + - name: Coverage Summary + run: | + set -x + # Extract coverage percentage and create a summary + TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])') + + # Create a detailed coverage summary + echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY + echo "Total 
Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY + uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY + diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 7c39b04464..4f04af7932 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -12,7 +12,7 @@ from models.dataset import Dataset from models.model import ApiToken, App from . import api, console_ns -from .wraps import account_initialization_required, setup_required +from .wraps import account_initialization_required, edit_permission_required, setup_required api_key_fields = { "id": fields.String, @@ -67,14 +67,12 @@ class BaseApiKeyListResource(Resource): return {"items": keys} @marshal_with(api_key_fields) + @edit_permission_required def post(self, resource_id): assert self.resource_id_field is not None, "resource_id_field must be set" resource_id = str(resource_id) - current_user, current_tenant_id = current_account_with_tenant() + _, current_tenant_id = current_account_with_tenant() _get_resource(resource_id, current_tenant_id, self.resource_model) - if not current_user.has_edit_permission: - raise Forbidden() - current_key_count = ( db.session.query(ApiToken) .where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 589955738c..3e549d869e 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -2,13 +2,13 @@ from typing import Literal from flask import request from flask_restx import Resource, fields, marshal, marshal_with, reqparse -from werkzeug.exceptions import Forbidden from controllers.common.errors import NoFileUploadedError, TooManyFilesError from controllers.console import api, console_ns from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_resource_check, + edit_permission_required, setup_required, ) from extensions.ext_redis import redis_client @@ -16,7 +16,7 @@ from fields.annotation_fields import ( annotation_fields, annotation_hit_history_fields, ) -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required from services.annotation_service import AppAnnotationService @@ -41,12 +41,8 @@ class AnnotationReplyActionApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_resource_check("annotation") + @edit_permission_required def post(self, app_id, action: Literal["enable", "disable"]): - current_user, _ = current_account_with_tenant() - - if not current_user.has_edit_permission: - raise Forbidden() - app_id = str(app_id) parser = reqparse.RequestParser() parser.add_argument("score_threshold", required=True, type=float, location="json") @@ -70,12 +66,8 @@ class AppAnnotationSettingDetailApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required def get(self, app_id): - current_user, _ = current_account_with_tenant() - - if not current_user.has_edit_permission: - raise Forbidden() - app_id = str(app_id) result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id) return result, 200 @@ -101,12 +93,8 @@ class AppAnnotationSettingUpdateApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required def post(self, app_id, annotation_setting_id): - current_user, _ = current_account_with_tenant() - - if not 
current_user.has_edit_permission: - raise Forbidden() - app_id = str(app_id) annotation_setting_id = str(annotation_setting_id) @@ -129,12 +117,8 @@ class AnnotationReplyActionStatusApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_resource_check("annotation") + @edit_permission_required def get(self, app_id, job_id, action): - current_user, _ = current_account_with_tenant() - - if not current_user.has_edit_permission: - raise Forbidden() - job_id = str(job_id) app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}" cache_result = redis_client.get(app_annotation_job_key) @@ -166,12 +150,8 @@ class AnnotationApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required def get(self, app_id): - current_user, _ = current_account_with_tenant() - - if not current_user.has_edit_permission: - raise Forbidden() - page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) keyword = request.args.get("keyword", default="", type=str) @@ -207,12 +187,8 @@ class AnnotationApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") @marshal_with(annotation_fields) + @edit_permission_required def post(self, app_id): - current_user, _ = current_account_with_tenant() - - if not current_user.has_edit_permission: - raise Forbidden() - app_id = str(app_id) parser = reqparse.RequestParser() parser.add_argument("question", required=True, type=str, location="json") @@ -224,12 +200,8 @@ class AnnotationApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required def delete(self, app_id): - current_user, _ = current_account_with_tenant() - - if not current_user.has_edit_permission: - raise Forbidden() - app_id = str(app_id) # Use request.args.getlist to get annotation_ids array directly @@ -262,12 +234,8 @@ class AnnotationExportApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required def get(self, app_id): - current_user, _ = current_account_with_tenant() - - if not current_user.has_edit_permission: - raise Forbidden() - app_id = str(app_id) annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) response = {"data": marshal(annotation_list, annotation_fields)} @@ -286,13 +254,9 @@ class AnnotationUpdateDeleteApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_resource_check("annotation") + @edit_permission_required @marshal_with(annotation_fields) def post(self, app_id, annotation_id): - current_user, _ = current_account_with_tenant() - - if not current_user.has_edit_permission: - raise Forbidden() - app_id = str(app_id) annotation_id = str(annotation_id) parser = reqparse.RequestParser() @@ -305,12 +269,8 @@ class AnnotationUpdateDeleteApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required def delete(self, app_id, annotation_id): - current_user, _ = current_account_with_tenant() - - if not current_user.has_edit_permission: - raise Forbidden() - app_id = str(app_id) annotation_id = str(annotation_id) AppAnnotationService.delete_app_annotation(app_id, annotation_id) @@ -329,12 +289,8 @@ class AnnotationBatchImportApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_resource_check("annotation") + @edit_permission_required def post(self, app_id): - current_user, _ = 
current_account_with_tenant() - - if not current_user.has_edit_permission: - raise Forbidden() - app_id = str(app_id) # check file if "file" not in request.files: @@ -362,12 +318,8 @@ class AnnotationBatchImportStatusApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_resource_check("annotation") + @edit_permission_required def get(self, app_id, job_id): - current_user, _ = current_account_with_tenant() - - if not current_user.has_edit_permission: - raise Forbidden() - job_id = str(job_id) indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}" cache_result = redis_client.get(indexing_cache_key) @@ -399,12 +351,8 @@ class AnnotationHitHistoryListApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required def get(self, app_id, annotation_id): - current_user, _ = current_account_with_tenant() - - if not current_user.has_edit_permission: - raise Forbidden() - page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) app_id = str(app_id) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 3927685af3..3900f5a6eb 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,7 +1,5 @@ import uuid -from typing import cast -from flask_login import current_user from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from sqlalchemy import select from sqlalchemy.orm import Session @@ -12,15 +10,16 @@ from controllers.console.app.wraps import get_app_model from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_resource_check, + edit_permission_required, enterprise_license_required, setup_required, ) from core.ops.ops_trace_manager import OpsTraceManager from extensions.ext_database import db from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields -from libs.login import login_required +from libs.login import current_account_with_tenant, login_required from libs.validators import validate_description_length -from models import Account, App +from models import App from services.app_dsl_service import AppDslService, ImportMode from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService @@ -56,6 +55,7 @@ class AppListApi(Resource): @enterprise_license_required def get(self): """Get app list""" + current_user, current_tenant_id = current_account_with_tenant() def uuid_list(value): try: @@ -90,7 +90,7 @@ class AppListApi(Resource): # get app list app_service = AppService() - app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args) + app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args) if not app_pagination: return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} @@ -129,8 +129,10 @@ class AppListApi(Resource): @account_initialization_required @marshal_with(app_detail_fields) @cloud_edition_billing_resource_check("apps") + @edit_permission_required def post(self): """Create app""" + current_user, current_tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, location="json") parser.add_argument("description", type=validate_description_length, location="json") @@ -140,19 +142,11 @@ class AppListApi(Resource): parser.add_argument("icon_background", type=str, 
location="json") args = parser.parse_args() - # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: - raise Forbidden() - if "mode" not in args or args["mode"] is None: raise BadRequest("mode is required") app_service = AppService() - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") - if current_user.current_tenant_id is None: - raise ValueError("current_user.current_tenant_id cannot be None") - app = app_service.create_app(current_user.current_tenant_id, args, current_user) + app = app_service.create_app(current_tenant_id, args, current_user) return app, 201 @@ -205,13 +199,10 @@ class AppApi(Resource): @login_required @account_initialization_required @get_app_model + @edit_permission_required @marshal_with(app_detail_fields_with_site) def put(self, app_model): """Update app""" - # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: - raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, nullable=False, location="json") parser.add_argument("description", type=validate_description_length, location="json") @@ -248,12 +239,9 @@ class AppApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required def delete(self, app_model): """Delete app""" - # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: - raise Forbidden() - app_service = AppService() app_service.delete_app(app_model) @@ -283,12 +271,12 @@ class AppCopyApi(Resource): @login_required @account_initialization_required @get_app_model + @edit_permission_required @marshal_with(app_detail_fields_with_site) def post(self, app_model): """Copy app""" # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: - raise Forbidden() + current_user, _ = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("name", type=str, location="json") @@ -301,9 +289,8 @@ class AppCopyApi(Resource): with Session(db.engine) as session: import_service = AppDslService(session) yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True) - account = cast(Account, current_user) result = import_service.import_app( - account=account, + account=current_user, import_mode=ImportMode.YAML_CONTENT, yaml_content=yaml_content, name=args.get("name"), @@ -340,12 +327,9 @@ class AppExportApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required def get(self, app_model): """Export app""" - # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: - raise Forbidden() - # Add include_secret params parser = reqparse.RequestParser() parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args") @@ -371,11 +355,8 @@ class AppNameApi(Resource): @account_initialization_required @get_app_model @marshal_with(app_detail_fields) + @edit_permission_required def post(self, app_model): - # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: - raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, location="json") args = parser.parse_args() @@ -408,11 +389,8 @@ class AppIconApi(Resource): @account_initialization_required 
@get_app_model @marshal_with(app_detail_fields) + @edit_permission_required def post(self, app_model): - # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: - raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("icon", type=str, location="json") parser.add_argument("icon_background", type=str, location="json") @@ -441,11 +419,8 @@ class AppSiteStatus(Resource): @account_initialization_required @get_app_model @marshal_with(app_detail_fields) + @edit_permission_required def post(self, app_model): - # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: - raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("enable_site", type=bool, required=True, location="json") args = parser.parse_args() @@ -475,6 +450,7 @@ class AppApiStatus(Resource): @marshal_with(app_detail_fields) def post(self, app_model): # The role of the current user in the ta table must be admin or owner + current_user, _ = current_account_with_tenant() if not current_user.is_admin_or_owner: raise Forbidden() @@ -520,10 +496,9 @@ class AppTraceApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required def post(self, app_id): # add app trace - if not current_user.is_editor: - raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("enabled", type=bool, required=True, location="json") parser.add_argument("tracing_provider", type=str, required=True, location="json") diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index 5751ff1f86..5e7ea6d481 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -1,11 +1,11 @@ from flask_restx import Resource, marshal_with, reqparse from sqlalchemy.orm import Session -from werkzeug.exceptions import Forbidden from controllers.console.app.wraps import get_app_model from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_resource_check, + edit_permission_required, setup_required, ) from extensions.ext_database import db @@ -26,12 +26,10 @@ class AppImportApi(Resource): @account_initialization_required @marshal_with(app_import_fields) @cloud_edition_billing_resource_check("apps") + @edit_permission_required def post(self): # Check user role first current_user, _ = current_account_with_tenant() - if not current_user.has_edit_permission: - raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("mode", type=str, required=True, location="json") parser.add_argument("yaml_content", type=str, location="json") @@ -80,11 +78,10 @@ class AppImportConfirmApi(Resource): @login_required @account_initialization_required @marshal_with(app_import_fields) + @edit_permission_required def post(self, import_id): # Check user role first current_user, _ = current_account_with_tenant() - if not current_user.has_edit_permission: - raise Forbidden() # Create service with session with Session(db.engine) as session: @@ -107,11 +104,8 @@ class AppImportCheckDependenciesApi(Resource): @get_app_model @account_initialization_required @marshal_with(app_import_check_dependencies_fields) + @edit_permission_required def get(self, app_model: App): - current_user, _ = current_account_with_tenant() - if not current_user.has_edit_permission: - raise Forbidden() - with Session(db.engine) as session: import_service = AppDslService(session) result = 
import_service.check_dependencies(app_model=app_model) diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 2f7b90e7fb..d69f05f23e 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -2,7 +2,7 @@ import logging from flask import request from flask_restx import Resource, fields, reqparse -from werkzeug.exceptions import Forbidden, InternalServerError, NotFound +from werkzeug.exceptions import InternalServerError, NotFound import services from controllers.console import api, console_ns @@ -15,7 +15,7 @@ from controllers.console.app.error import ( ProviderQuotaExceededError, ) from controllers.console.app.wraps import get_app_model -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom @@ -151,13 +151,8 @@ class ChatMessageApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) + @edit_permission_required def post(self, app_model): - if not isinstance(current_user, Account): - raise Forbidden() - - if not current_user.has_edit_permission: - raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, location="json") parser.add_argument("query", type=str, required=True, location="json") diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 3b8dff613b..779be62973 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -1,17 +1,16 @@ from datetime import datetime -import pytz # pip install pytz +import pytz import sqlalchemy as sa -from flask_login import current_user from flask_restx import Resource, marshal_with, reqparse from flask_restx.inputs import int_range from sqlalchemy import func, or_ from sqlalchemy.orm import joinedload -from werkzeug.exceptions import Forbidden, NotFound +from werkzeug.exceptions import NotFound from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from fields.conversation_fields import ( @@ -22,8 +21,8 @@ from fields.conversation_fields import ( ) from libs.datetime_utils import naive_utc_now from libs.helper import DatetimeString -from libs.login import login_required -from models import Account, Conversation, EndUser, Message, MessageAnnotation +from libs.login import current_account_with_tenant, login_required +from models import Conversation, EndUser, Message, MessageAnnotation from models.model import AppMode from services.conversation_service import ConversationService from services.errors.conversation import ConversationNotExistsError @@ -57,9 +56,9 @@ class CompletionConversationApi(Resource): @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) @marshal_with(conversation_pagination_fields) + @edit_permission_required def 
get(self, app_model): - if not current_user.is_editor: - raise Forbidden() + current_user, _ = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("keyword", type=str, location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") @@ -84,6 +83,7 @@ class CompletionConversationApi(Resource): ) account = current_user + assert account.timezone is not None timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc @@ -137,9 +137,8 @@ class CompletionConversationDetailApi(Resource): @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) @marshal_with(conversation_message_detail_fields) + @edit_permission_required def get(self, app_model, conversation_id): - if not current_user.is_editor: - raise Forbidden() conversation_id = str(conversation_id) return _get_conversation(app_model, conversation_id) @@ -154,14 +153,12 @@ class CompletionConversationDetailApi(Resource): @login_required @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) + @edit_permission_required def delete(self, app_model, conversation_id): - if not current_user.is_editor: - raise Forbidden() + current_user, _ = current_account_with_tenant() conversation_id = str(conversation_id) try: - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") ConversationService.delete(app_model, conversation_id, current_user) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -206,9 +203,9 @@ class ChatConversationApi(Resource): @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @marshal_with(conversation_with_summary_pagination_fields) + @edit_permission_required def get(self, app_model): - if not current_user.is_editor: - raise Forbidden() + current_user, _ = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("keyword", type=str, location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") @@ -260,6 +257,7 @@ class ChatConversationApi(Resource): ) account = current_user + assert account.timezone is not None timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc @@ -341,9 +339,8 @@ class ChatConversationDetailApi(Resource): @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @marshal_with(conversation_detail_fields) + @edit_permission_required def get(self, app_model, conversation_id): - if not current_user.is_editor: - raise Forbidden() conversation_id = str(conversation_id) return _get_conversation(app_model, conversation_id) @@ -358,14 +355,12 @@ class ChatConversationDetailApi(Resource): @login_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @account_initialization_required + @edit_permission_required def delete(self, app_model, conversation_id): - if not current_user.is_editor: - raise Forbidden() + current_user, _ = current_account_with_tenant() conversation_id = str(conversation_id) try: - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") ConversationService.delete(app_model, conversation_id, current_user) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -374,6 +369,7 @@ class ChatConversationDetailApi(Resource): def _get_conversation(app_model, conversation_id): + current_user, _ = current_account_with_tenant() 
conversation = ( db.session.query(Conversation) .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id) diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index b9a383ee61..599f5adb34 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -1,16 +1,15 @@ import json from enum import StrEnum -from flask_login import current_user from flask_restx import Resource, fields, marshal_with, reqparse from werkzeug.exceptions import NotFound from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from extensions.ext_database import db from fields.app_fields import app_server_fields -from libs.login import login_required +from libs.login import current_account_with_tenant, login_required from models.model import AppMCPServer @@ -25,9 +24,9 @@ class AppMCPServerController(Resource): @api.doc(description="Get MCP server configuration for an application") @api.doc(params={"app_id": "Application ID"}) @api.response(200, "MCP server configuration retrieved successfully", app_server_fields) - @setup_required @login_required @account_initialization_required + @setup_required @get_app_model @marshal_with(app_server_fields) def get(self, app_model): @@ -48,14 +47,14 @@ class AppMCPServerController(Resource): ) @api.response(201, "MCP server configuration created successfully", app_server_fields) @api.response(403, "Insufficient permissions") - @setup_required - @login_required @account_initialization_required @get_app_model + @login_required + @setup_required @marshal_with(app_server_fields) + @edit_permission_required def post(self, app_model): - if not current_user.is_editor: - raise NotFound() + _, current_tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("description", type=str, required=False, location="json") parser.add_argument("parameters", type=dict, required=True, location="json") @@ -71,7 +70,7 @@ class AppMCPServerController(Resource): parameters=json.dumps(args["parameters"], ensure_ascii=False), status=AppMCPServerStatus.ACTIVE, app_id=app_model.id, - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, server_code=AppMCPServer.generate_server_code(16), ) db.session.add(server) @@ -95,14 +94,13 @@ class AppMCPServerController(Resource): @api.response(200, "MCP server configuration updated successfully", app_server_fields) @api.response(403, "Insufficient permissions") @api.response(404, "Server not found") - @setup_required - @login_required - @account_initialization_required @get_app_model + @login_required + @setup_required + @account_initialization_required @marshal_with(app_server_fields) + @edit_permission_required def put(self, app_model): - if not current_user.is_editor: - raise NotFound() parser = reqparse.RequestParser() parser.add_argument("id", type=str, required=True, location="json") parser.add_argument("description", type=str, required=False, location="json") @@ -142,13 +140,13 @@ class AppMCPServerRefreshController(Resource): @login_required @account_initialization_required @marshal_with(app_server_fields) + @edit_permission_required def get(self, server_id): - if not current_user.is_editor: - raise NotFound() + _, current_tenant_id = 
current_account_with_tenant() server = ( db.session.query(AppMCPServer) .where(AppMCPServer.id == server_id) - .where(AppMCPServer.tenant_id == current_user.current_tenant_id) + .where(AppMCPServer.tenant_id == current_tenant_id) .first() ) if not server: diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 46523feccc..005cff75fc 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -3,7 +3,7 @@ import logging from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx.inputs import int_range from sqlalchemy import exists, select -from werkzeug.exceptions import Forbidden, InternalServerError, NotFound +from werkzeug.exceptions import InternalServerError, NotFound from controllers.console import api, console_ns from controllers.console.app.error import ( @@ -17,6 +17,7 @@ from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDi from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_resource_check, + edit_permission_required, setup_required, ) from core.app.entities.app_invoke_entities import InvokeFrom @@ -26,8 +27,7 @@ from extensions.ext_database import db from fields.conversation_fields import annotation_fields, message_detail_fields from libs.helper import uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination -from libs.login import current_user, login_required -from models.account import Account +from libs.login import current_account_with_tenant, login_required from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback from services.annotation_service import AppAnnotationService from services.errors.conversation import ConversationNotExistsError @@ -56,15 +56,13 @@ class ChatMessageListApi(Resource): ) @api.response(200, "Success", message_infinite_scroll_pagination_fields) @api.response(404, "Conversation not found") - @setup_required @login_required - @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @account_initialization_required + @setup_required + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @marshal_with(message_infinite_scroll_pagination_fields) + @edit_permission_required def get(self, app_model): - if not isinstance(current_user, Account) or not current_user.has_edit_permission: - raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") parser.add_argument("first_id", type=uuid_value, location="args") @@ -154,8 +152,7 @@ class MessageFeedbackApi(Resource): @login_required @account_initialization_required def post(self, app_model): - if current_user is None: - raise Forbidden() + current_user, _ = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("message_id", required=True, type=uuid_value, location="json") @@ -211,18 +208,14 @@ class MessageAnnotationApi(Resource): ) @api.response(200, "Annotation created successfully", annotation_fields) @api.response(403, "Insufficient permissions") + @marshal_with(annotation_fields) + @get_app_model @setup_required @login_required - @account_initialization_required @cloud_edition_billing_resource_check("annotation") - @get_app_model - @marshal_with(annotation_fields) + @account_initialization_required + @edit_permission_required def post(self, app_model): - if not isinstance(current_user, Account): - raise Forbidden() - if 
not current_user.has_edit_permission: - raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("message_id", required=False, type=uuid_value, location="json") parser.add_argument("question", required=True, type=str, location="json") @@ -270,6 +263,7 @@ class MessageSuggestedQuestionApi(Resource): @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) def get(self, app_model, message_id): + current_user, _ = current_account_with_tenant() message_id = str(message_id) try: @@ -304,12 +298,12 @@ class MessageApi(Resource): @api.doc(params={"app_id": "Application ID", "message_id": "Message ID"}) @api.response(200, "Message retrieved successfully", message_detail_fields) @api.response(404, "Message not found") + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model @marshal_with(message_detail_fields) - def get(self, app_model, message_id): + def get(self, app_model, message_id: str): message_id = str(message_id) message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first() diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 6537ea5a50..1da704efcc 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -9,7 +9,7 @@ from extensions.ext_database import db from fields.app_fields import app_site_fields from libs.datetime_utils import naive_utc_now from libs.login import current_account_with_tenant, login_required -from models import Account, Site +from models import Site def parse_app_site_args(): @@ -107,8 +107,6 @@ class AppSite(Resource): if value is not None: setattr(site, attr_name, value) - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") site.updated_by = current_user.id site.updated_at = naive_utc_now() db.session.commit() @@ -142,8 +140,6 @@ class AppSiteAccessTokenReset(Resource): raise NotFound site.code = Site.generate_code(16) - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") site.updated_by = current_user.id site.updated_at = naive_utc_now() db.session.commit() diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 5974395c6a..cfe5b3ff17 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -4,7 +4,6 @@ from decimal import Decimal import pytz import sqlalchemy as sa from flask import jsonify -from flask_login import current_user from flask_restx import Resource, fields, reqparse from controllers.console import api, console_ns @@ -13,7 +12,7 @@ from controllers.console.wraps import account_initialization_required, setup_req from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from libs.helper import DatetimeString -from libs.login import login_required +from libs.login import current_account_with_tenant, login_required from models import AppMode, Message @@ -37,7 +36,7 @@ class DailyMessageStatistic(Resource): @login_required @account_initialization_required def get(self, app_model): - account = current_user + account, _ = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") @@ -53,6 +52,7 @@ WHERE app_id = :app_id AND invoke_from != :invoke_from""" arg_dict = {"tz": account.timezone, "app_id": 
diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py
index 5974395c6a..cfe5b3ff17 100644
--- a/api/controllers/console/app/statistic.py
+++ b/api/controllers/console/app/statistic.py
@@ -4,7 +4,6 @@ from decimal import Decimal
 import pytz
 import sqlalchemy as sa
 from flask import jsonify
-from flask_login import current_user
 from flask_restx import Resource, fields, reqparse
 
 from controllers.console import api, console_ns
@@ -13,7 +12,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
 from core.app.entities.app_invoke_entities import InvokeFrom
 from extensions.ext_database import db
 from libs.helper import DatetimeString
-from libs.login import login_required
+from libs.login import current_account_with_tenant, login_required
 from models import AppMode, Message
 
 
@@ -37,7 +36,7 @@ class DailyMessageStatistic(Resource):
     @login_required
     @account_initialization_required
     def get(self, app_model):
-        account = current_user
+        account, _ = current_account_with_tenant()
 
         parser = reqparse.RequestParser()
         parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@@ -53,6 +52,7 @@ WHERE
     app_id = :app_id
     AND invoke_from != :invoke_from"""
 
         arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
+        assert account.timezone is not None
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
@@ -109,13 +109,13 @@ class DailyConversationStatistic(Resource):
     @login_required
     @account_initialization_required
     def get(self, app_model):
-        account = current_user
+        account, _ = current_account_with_tenant()
 
         parser = reqparse.RequestParser()
         parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
         parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()
-
+        assert account.timezone is not None
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
@@ -175,7 +175,7 @@ class DailyTerminalsStatistic(Resource):
     @login_required
     @account_initialization_required
     def get(self, app_model):
-        account = current_user
+        account, _ = current_account_with_tenant()
 
         parser = reqparse.RequestParser()
         parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@@ -191,7 +191,7 @@ WHERE
     app_id = :app_id
     AND invoke_from != :invoke_from"""
 
         arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
-
+        assert account.timezone is not None
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
@@ -247,7 +247,7 @@ class DailyTokenCostStatistic(Resource):
     @login_required
     @account_initialization_required
     def get(self, app_model):
-        account = current_user
+        account, _ = current_account_with_tenant()
 
         parser = reqparse.RequestParser()
         parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@@ -264,7 +264,7 @@ WHERE
     app_id = :app_id
     AND invoke_from != :invoke_from"""
 
         arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
-
+        assert account.timezone is not None
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
@@ -322,7 +322,7 @@ class AverageSessionInteractionStatistic(Resource):
     @account_initialization_required
     @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
     def get(self, app_model):
-        account = current_user
+        account, _ = current_account_with_tenant()
 
         parser = reqparse.RequestParser()
         parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@@ -346,7 +346,7 @@ FROM
     c.app_id = :app_id
     AND m.invoke_from != :invoke_from"""
 
         arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
-
+        assert account.timezone is not None
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
@@ -413,7 +413,7 @@ class UserSatisfactionRateStatistic(Resource):
     @login_required
     @account_initialization_required
     def get(self, app_model):
-        account = current_user
+        account, _ = current_account_with_tenant()
 
         parser = reqparse.RequestParser()
         parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@@ -433,7 +433,7 @@ WHERE
     m.app_id = :app_id
     AND m.invoke_from != :invoke_from"""
 
         arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
-
+        assert account.timezone is not None
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
@@ -494,7 +494,7 @@ class AverageResponseTimeStatistic(Resource):
     @account_initialization_required
     @get_app_model(mode=AppMode.COMPLETION)
     def get(self, app_model):
-        account = current_user
+        account, _ = current_account_with_tenant()
 
         parser = reqparse.RequestParser()
         parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@@ -510,7 +510,7 @@ WHERE
     app_id = :app_id
     AND invoke_from != :invoke_from"""
 
         arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
-
+        assert account.timezone is not None
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
@@ -566,7 +566,7 @@ class TokensPerSecondStatistic(Resource):
     @login_required
     @account_initialization_required
     def get(self, app_model):
-        account = current_user
+        account, _ = current_account_with_tenant()
 
         parser = reqparse.RequestParser()
         parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@@ -585,7 +585,7 @@ WHERE
     app_id = :app_id
     AND invoke_from != :invoke_from"""
 
         arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
-
+        assert account.timezone is not None
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py
index 578d864b80..172a80736f 100644
--- a/api/controllers/console/app/workflow.py
+++ b/api/controllers/console/app/workflow.py
@@ -12,7 +12,7 @@ import services
 from controllers.console import api, console_ns
 from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
 from controllers.console.app.wraps import get_app_model
-from controllers.console.wraps import account_initialization_required, setup_required
+from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
 from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.app.apps.base_app_queue_manager import AppQueueManager
@@ -27,9 +27,8 @@ from fields.workflow_run_fields import workflow_run_node_execution_fields
 from libs import helper
 from libs.datetime_utils import naive_utc_now
 from libs.helper import TimestampField, uuid_value
-from libs.login import current_user, login_required
+from libs.login import current_account_with_tenant, login_required
 from models import App
-from models.account import Account
 from models.model import AppMode
 from models.workflow import Workflow
 from services.app_generate_service import AppGenerateService
@@ -70,15 +69,11 @@ class DraftWorkflowApi(Resource):
     @account_initialization_required
     @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
     @marshal_with(workflow_fields)
+    @edit_permission_required
     def get(self, app_model: App):
         """
        Get draft workflow
         """
-        # The role of the current user in the ta table must be admin, owner, or editor
-        assert isinstance(current_user, Account)
-        if not current_user.has_edit_permission:
-            raise Forbidden()
-
         # fetch draft workflow by app_model
         workflow_service = WorkflowService()
         workflow = workflow_service.get_draft_workflow(app_model=app_model)
@@ -110,14 +105,12 @@ class DraftWorkflowApi(Resource):
     @api.response(200, "Draft workflow synced successfully", workflow_fields)
     @api.response(400, "Invalid workflow configuration")
     @api.response(403, "Permission denied")
+    @edit_permission_required
     def post(self, app_model: App):
         """
        Sync draft workflow
         """
-        # The role of the current user in the ta table must be admin, owner, or editor
-        assert isinstance(current_user, Account)
-        if not current_user.has_edit_permission:
-            raise Forbidden()
+        current_user, _ = current_account_with_tenant()
 
         content_type = request.headers.get("Content-Type", "")
@@ -149,10 +142,6 @@ class DraftWorkflowApi(Resource):
                 return {"message": "Invalid JSON data"}, 400
         else:
             abort(415)
-
-        if not isinstance(current_user, Account):
-            raise Forbidden()
-
         workflow_service = WorkflowService()
 
         try:
@@ -206,17 +195,12 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
     @login_required
     @account_initialization_required
     @get_app_model(mode=[AppMode.ADVANCED_CHAT])
+    @edit_permission_required
     def post(self, app_model: App):
         """
        Run draft workflow
         """
-        # The role of the current user in the ta table must be admin, owner, or editor
-        assert isinstance(current_user, Account)
-        if not current_user.has_edit_permission:
-            raise Forbidden()
-
-        if not isinstance(current_user, Account):
-            raise Forbidden()
+        current_user, _ = current_account_with_tenant()
 
         parser = reqparse.RequestParser()
         parser.add_argument("inputs", type=dict, location="json")
@@ -271,16 +255,12 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
     @login_required
     @account_initialization_required
     @get_app_model(mode=[AppMode.ADVANCED_CHAT])
+    @edit_permission_required
     def post(self, app_model: App, node_id: str):
         """
        Run draft workflow iteration node
         """
-        if not isinstance(current_user, Account):
-            raise Forbidden()
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.has_edit_permission:
-            raise Forbidden()
-
+        current_user, _ = current_account_with_tenant()
         parser = reqparse.RequestParser()
         parser.add_argument("inputs", type=dict, location="json")
         args = parser.parse_args()
@@ -323,16 +303,12 @@ class WorkflowDraftRunIterationNodeApi(Resource):
     @login_required
     @account_initialization_required
     @get_app_model(mode=[AppMode.WORKFLOW])
+    @edit_permission_required
     def post(self, app_model: App, node_id: str):
         """
        Run draft workflow iteration node
         """
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not isinstance(current_user, Account):
-            raise Forbidden()
-        if not current_user.has_edit_permission:
-            raise Forbidden()
-
+        current_user, _ = current_account_with_tenant()
         parser = reqparse.RequestParser()
         parser.add_argument("inputs", type=dict, location="json")
         args = parser.parse_args()
@@ -375,17 +351,12 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
     @login_required
     @account_initialization_required
     @get_app_model(mode=[AppMode.ADVANCED_CHAT])
+    @edit_permission_required
     def post(self, app_model: App, node_id: str):
         """
        Run draft workflow loop node
         """
-
-        if not isinstance(current_user, Account):
-            raise Forbidden()
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.has_edit_permission:
-            raise Forbidden()
-
+        current_user, _ = current_account_with_tenant()
         parser = reqparse.RequestParser()
         parser.add_argument("inputs", type=dict, location="json")
         args = parser.parse_args()
@@ -428,17 +399,12 @@ class WorkflowDraftRunLoopNodeApi(Resource):
     @login_required
     @account_initialization_required
     @get_app_model(mode=[AppMode.WORKFLOW])
+    @edit_permission_required
     def post(self, app_model: App, node_id: str):
         """
        Run draft workflow loop node
         """
-
-        if not isinstance(current_user, Account):
-            raise Forbidden()
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.has_edit_permission:
-            raise Forbidden()
-
+        current_user, _ = current_account_with_tenant()
         parser = reqparse.RequestParser()
         parser.add_argument("inputs", type=dict, location="json")
         args = parser.parse_args()
@@ -480,17 +446,12 @@ class DraftWorkflowRunApi(Resource):
     @login_required
     @account_initialization_required
     @get_app_model(mode=[AppMode.WORKFLOW])
+    @edit_permission_required
     def post(self, app_model: App):
         """
        Run draft workflow
         """
-
-        if not isinstance(current_user, Account):
-            raise Forbidden()
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.has_edit_permission:
-            raise Forbidden()
-
+        current_user, _ = current_account_with_tenant()
         parser = reqparse.RequestParser()
         parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
         parser.add_argument("files", type=list, required=False, location="json")
@@ -526,17 +487,11 @@ class WorkflowTaskStopApi(Resource):
     @login_required
     @account_initialization_required
     @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+    @edit_permission_required
     def post(self, app_model: App, task_id: str):
         """
        Stop workflow task
         """
-
-        if not isinstance(current_user, Account):
-            raise Forbidden()
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.has_edit_permission:
-            raise Forbidden()
-
         # Stop using both mechanisms for backward compatibility
         # Legacy stop flag mechanism (without user check)
         AppQueueManager.set_stop_flag_no_user_check(task_id)
@@ -568,17 +523,12 @@ class DraftWorkflowNodeRunApi(Resource):
     @account_initialization_required
     @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
     @marshal_with(workflow_run_node_execution_fields)
+    @edit_permission_required
     def post(self, app_model: App, node_id: str):
         """
        Run draft workflow node
         """
-
-        if not isinstance(current_user, Account):
-            raise Forbidden()
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.has_edit_permission:
-            raise Forbidden()
-
+        current_user, _ = current_account_with_tenant()
         parser = reqparse.RequestParser()
         parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
         parser.add_argument("query", type=str, required=False, location="json", default="")
@@ -622,17 +572,11 @@ class PublishedWorkflowApi(Resource):
     @account_initialization_required
     @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
     @marshal_with(workflow_fields)
+    @edit_permission_required
     def get(self, app_model: App):
         """
        Get published workflow
         """
-
-        if not isinstance(current_user, Account):
-            raise Forbidden()
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.has_edit_permission:
-            raise Forbidden()
-
         # fetch published workflow by app_model
         workflow_service = WorkflowService()
         workflow = workflow_service.get_published_workflow(app_model=app_model)
@@ -644,16 +588,12 @@ class PublishedWorkflowApi(Resource):
     @login_required
     @account_initialization_required
     @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+    @edit_permission_required
     def post(self, app_model: App):
         """
        Publish workflow
         """
-        if not isinstance(current_user, Account):
-            raise Forbidden()
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.has_edit_permission:
-            raise Forbidden()
-
+        current_user, _ = current_account_with_tenant()
         parser = reqparse.RequestParser()
         parser.add_argument("marked_name", type=str, required=False, default="", location="json")
         parser.add_argument("marked_comment", type=str, required=False, default="", location="json")
@@ -702,17 +642,11 @@ class DefaultBlockConfigsApi(Resource):
     @login_required
     @account_initialization_required
     @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+    @edit_permission_required
     def get(self, app_model: App):
         """
        Get default block config
         """
-
-        if not isinstance(current_user, Account):
-            raise Forbidden()
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.has_edit_permission:
-            raise Forbidden()
-
         # Get default block configs
         workflow_service = WorkflowService()
         return workflow_service.get_default_block_configs()
@@ -729,16 +663,11 @@ class DefaultBlockConfigApi(Resource):
     @login_required
     @account_initialization_required
     @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+    @edit_permission_required
     def get(self, app_model: App, block_type: str):
         """
        Get default block config
         """
-        if not isinstance(current_user, Account):
-            raise Forbidden()
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.has_edit_permission:
-            raise Forbidden()
-
         parser = reqparse.RequestParser()
         parser.add_argument("q", type=str, location="args")
         args = parser.parse_args()
@@ -769,17 +698,14 @@ class ConvertToWorkflowApi(Resource):
     @login_required
     @account_initialization_required
     @get_app_model(mode=[AppMode.CHAT, AppMode.COMPLETION])
+    @edit_permission_required
     def post(self, app_model: App):
         """
        Convert basic mode of chatbot app to workflow mode
        Convert expert mode of chatbot app to workflow mode
        Convert Completion App to Workflow App
         """
-        if not isinstance(current_user, Account):
-            raise Forbidden()
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.has_edit_permission:
-            raise Forbidden()
+        current_user, _ = current_account_with_tenant()
 
         if request.data:
             parser = reqparse.RequestParser()
@@ -812,15 +738,12 @@ class PublishedAllWorkflowApi(Resource):
     @account_initialization_required
     @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
     @marshal_with(workflow_pagination_fields)
+    @edit_permission_required
     def get(self, app_model: App):
         """
        Get published workflows
         """
-
-        if not isinstance(current_user, Account):
-            raise Forbidden()
-        if not current_user.has_edit_permission:
-            raise Forbidden()
+        current_user, _ = current_account_with_tenant()
 
         parser = reqparse.RequestParser()
         parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
@@ -879,16 +802,12 @@ class WorkflowByIdApi(Resource):
     @account_initialization_required
     @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
     @marshal_with(workflow_fields)
+    @edit_permission_required
     def patch(self, app_model: App, workflow_id: str):
         """
        Update workflow attributes
         """
-        if not isinstance(current_user, Account):
-            raise Forbidden()
-        # Check permission
-        if not current_user.has_edit_permission:
-            raise Forbidden()
-
+        current_user, _ = current_account_with_tenant()
         parser = reqparse.RequestParser()
         parser.add_argument("marked_name", type=str, required=False, location="json")
         parser.add_argument("marked_comment", type=str, required=False, location="json")
@@ -934,16 +853,11 @@ class WorkflowByIdApi(Resource):
     @login_required
     @account_initialization_required
     @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+    @edit_permission_required
     def delete(self, app_model: App, workflow_id: str):
         """
        Delete workflow
         """
-        if not isinstance(current_user, Account):
-            raise Forbidden()
-        # Check permission
-        if not current_user.has_edit_permission:
-            raise Forbidden()
-
         workflow_service = WorkflowService()
 
         # Create a session and manage the transaction
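Throughout workflow.py, `@edit_permission_required` is kept innermost, directly above each handler. Python applies decorators bottom-up, so at request time the outer wrappers (`setup_required`, `login_required`, `get_app_model`, ...) run first and the permission check runs last, after the session and app model are resolved. A toy illustration of that ordering (names are illustrative):

    from functools import wraps


    def tag(label):
        def deco(fn):
            @wraps(fn)
            def wrapper(*args, **kwargs):
                print(f"enter {label}")
                return fn(*args, **kwargs)

            return wrapper

        return deco


    @tag("outer")  # runs first at call time
    @tag("inner")  # runs second, immediately before the handler
    def handler():
        print("handler body")


    handler()  # prints: enter outer / enter inner / handler body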
diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py
index da6b56d026..5e865dc4c1 100644
--- a/api/controllers/console/app/workflow_draft_variable.py
+++ b/api/controllers/console/app/workflow_draft_variable.py
@@ -22,8 +22,7 @@ from extensions.ext_database import db
 from factories.file_factory import build_from_mapping, build_from_mappings
 from factories.variable_factory import build_segment_with_type
 from libs.login import current_user, login_required
-from models import App, AppMode
-from models.account import Account
+from models import Account, App, AppMode
 from models.workflow import WorkflowDraftVariable
 from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
 from services.workflow_service import WorkflowService
diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py
index 23ba63845c..286ba65a7f 100644
--- a/api/controllers/console/app/workflow_run.py
+++ b/api/controllers/console/app/workflow_run.py
@@ -1,6 +1,5 @@
 from typing import cast
 
-from flask_login import current_user
 from flask_restx import Resource, marshal_with, reqparse
 from flask_restx.inputs import int_range
 
@@ -14,7 +13,7 @@ from fields.workflow_run_fields import (
     workflow_run_pagination_fields,
 )
 from libs.helper import uuid_value
-from libs.login import login_required
+from libs.login import current_user, login_required
 from models import Account, App, AppMode, EndUser
 from services.workflow_run_service import WorkflowRunService
diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py
index b8904bf3d9..8f7f936c9b 100644
--- a/api/controllers/console/app/workflow_statistic.py
+++ b/api/controllers/console/app/workflow_statistic.py
@@ -4,7 +4,6 @@ from decimal import Decimal
 import pytz
 import sqlalchemy as sa
 from flask import jsonify
-from flask_login import current_user
 from flask_restx import Resource, reqparse
 
 from controllers.console import api, console_ns
@@ -12,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
 from controllers.console.wraps import account_initialization_required, setup_required
 from extensions.ext_database import db
 from libs.helper import DatetimeString
-from libs.login import login_required
+from libs.login import current_account_with_tenant, login_required
 from models.enums import WorkflowRunTriggeredFrom
 from models.model import AppMode
@@ -29,7 +28,7 @@ class WorkflowDailyRunsStatistic(Resource):
     @login_required
     @account_initialization_required
     def get(self, app_model):
-        account = current_user
+        account, _ = current_account_with_tenant()
 
         parser = reqparse.RequestParser()
         parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@@ -49,7 +48,7 @@ WHERE
             "app_id": app_model.id,
             "triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
         }
-
+        assert account.timezone is not None
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
@@ -97,7 +96,7 @@ class WorkflowDailyTerminalsStatistic(Resource):
     @login_required
     @account_initialization_required
     def get(self, app_model):
-        account = current_user
+        account, _ = current_account_with_tenant()
 
         parser = reqparse.RequestParser()
         parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@@ -117,7 +116,7 @@ WHERE
             "app_id": app_model.id,
             "triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
         }
-
+        assert account.timezone is not None
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
@@ -165,7 +164,7 @@ class WorkflowDailyTokenCostStatistic(Resource):
     @login_required
     @account_initialization_required
     def get(self, app_model):
-        account = current_user
+        account, _ = current_account_with_tenant()
 
         parser = reqparse.RequestParser()
         parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@@ -185,7 +184,7 @@ WHERE
             "app_id": app_model.id,
             "triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
         }
-
+        assert account.timezone is not None
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
@@ -238,7 +237,7 @@ class WorkflowAverageAppInteractionStatistic(Resource):
     @account_initialization_required
     @get_app_model(mode=[AppMode.WORKFLOW])
     def get(self, app_model):
-        account = current_user
+        account, _ = current_account_with_tenant()
 
         parser = reqparse.RequestParser()
         parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@@ -271,7 +270,7 @@ GROUP BY
             "app_id": app_model.id,
             "triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
         }
-
+        assert account.timezone is not None
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py
index 44aba01820..9bb2718f89 100644
--- a/api/controllers/console/app/wraps.py
+++ b/api/controllers/console/app/wraps.py
@@ -4,28 +4,29 @@ from typing import ParamSpec, TypeVar, Union
 
 from controllers.console.app.error import AppNotFoundError
 from extensions.ext_database import db
-from libs.login import current_user
+from libs.login import current_account_with_tenant
 from models import App, AppMode
-from models.account import Account
 
 P = ParamSpec("P")
 R = TypeVar("R")
+P1 = ParamSpec("P1")
+R1 = TypeVar("R1")
 
 
 def _load_app_model(app_id: str) -> App | None:
-    assert isinstance(current_user, Account)
+    _, current_tenant_id = current_account_with_tenant()
     app_model = (
         db.session.query(App)
-        .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
+        .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
         .first()
     )
     return app_model
 
 
 def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
-    def decorator(view_func: Callable[P, R]):
+    def decorator(view_func: Callable[P1, R1]):
         @wraps(view_func)
-        def decorated_view(*args: P.args, **kwargs: P.kwargs):
+        def decorated_view(*args: P1.args, **kwargs: P1.kwargs):
             if not kwargs.get("app_id"):
                 raise ValueError("missing app_id in path parameters")
diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py
index 76171e3f8a..06d2b936b7 100644
--- a/api/controllers/console/auth/activate.py
+++ b/api/controllers/console/auth/activate.py
@@ -7,7 +7,7 @@ from controllers.console.error import AlreadyActivateError
 from extensions.ext_database import db
 from libs.datetime_utils import naive_utc_now
 from libs.helper import StrLen, email, extract_remote_ip, timezone
-from models.account import AccountStatus
+from models import AccountStatus
 from services.account_service import AccountService, RegisterService
 
 active_check_parser = reqparse.RequestParser()
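The wraps.py hunk above also gives `get_app_model`'s inner decorator its own `ParamSpec`/`TypeVar` pair (`P1`/`R1`) instead of reusing the factory-level `P`/`R`. A simplified sketch of why separate variables matter in a decorator factory that supports both bare and parameterized use (signatures trimmed; names mirror the patch):

    from collections.abc import Callable
    from functools import wraps
    from typing import ParamSpec, TypeVar

    P1 = ParamSpec("P1")
    R1 = TypeVar("R1")


    def get_app_model_sketch(view: Callable | None = None, *, mode: str | None = None):
        # P1/R1 are bound per wrapped view rather than once for the whole
        # factory, so each decorated endpoint keeps its own precise signature.
        def decorator(view_func: Callable[P1, R1]) -> Callable[P1, R1]:
            @wraps(view_func)
            def decorated_view(*args: P1.args, **kwargs: P1.kwargs) -> R1:
                # ... resolve app_model from kwargs["app_id"] here ...
                return view_func(*args, **kwargs)

            return decorated_view

        # Supports both @get_app_model and @get_app_model(mode=...).
        if view is not None:
            return decorator(view)
        return decorator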
diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py
index 6f1fd2f11a..0fd433d718 100644
--- a/api/controllers/console/auth/data_source_oauth.py
+++ b/api/controllers/console/auth/data_source_oauth.py
@@ -2,13 +2,12 @@ import logging
 
 import httpx
 from flask import current_app, redirect, request
-from flask_login import current_user
 from flask_restx import Resource, fields
 from werkzeug.exceptions import Forbidden
 
 from configs import dify_config
 from controllers.console import api, console_ns
-from libs.login import login_required
+from libs.login import current_account_with_tenant, login_required
 from libs.oauth_data_source import NotionOAuth
 
 from ..wraps import account_initialization_required, setup_required
@@ -45,6 +44,7 @@ class OAuthDataSource(Resource):
     @api.response(403, "Admin privileges required")
     def get(self, provider: str):
         # The role of the current user in the table must be admin or owner
+        current_user, _ = current_account_with_tenant()
         if not current_user.is_admin_or_owner:
             raise Forbidden()
         OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py
index d3613d9183..cabd118d23 100644
--- a/api/controllers/console/auth/email_register.py
+++ b/api/controllers/console/auth/email_register.py
@@ -19,7 +19,7 @@ from controllers.console.wraps import email_password_login_enabled, email_regist
 from extensions.ext_database import db
 from libs.helper import email, extract_remote_ip
 from libs.password import valid_password
-from models.account import Account
+from models import Account
 from services.account_service import AccountService
 from services.billing_service import BillingService
 from services.errors.account import AccountNotFoundError, AccountRegisterError
diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py
index 704bcf8fb8..102d33966e 100644
--- a/api/controllers/console/auth/forgot_password.py
+++ b/api/controllers/console/auth/forgot_password.py
@@ -20,7 +20,7 @@ from events.tenant_event import tenant_was_created
 from extensions.ext_database import db
 from libs.helper import email, extract_remote_ip
 from libs.password import hash_password, valid_password
-from models.account import Account
+from models import Account
 from services.account_service import AccountService, TenantService
 from services.feature_service import FeatureService
diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py
index ba614aa828..e4bbbf107b 100644
--- a/api/controllers/console/auth/login.py
+++ b/api/controllers/console/auth/login.py
@@ -1,5 +1,3 @@
-from typing import cast
-
 import flask_login
 from flask import request
 from flask_restx import Resource, reqparse
@@ -26,7 +24,7 @@ from controllers.console.error import (
 from controllers.console.wraps import email_password_login_enabled, setup_required
 from events.tenant_event import tenant_was_created
 from libs.helper import email, extract_remote_ip
-from models.account import Account
+from libs.login import current_account_with_tenant
 from services.account_service import AccountService, RegisterService, TenantService
 from services.billing_service import BillingService
 from services.errors.account import AccountRegisterError
@@ -96,7 +94,8 @@ class LoginApi(Resource):
 class LogoutApi(Resource):
     @setup_required
     def get(self):
-        account = cast(Account, flask_login.current_user)
+        current_user, _ = current_account_with_tenant()
+        account = current_user
         if isinstance(account, flask_login.AnonymousUserMixin):
             return {"result": "success"}
         AccountService.logout(account=account)
diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py
index 4efeceb676..52459ad5eb 100644
--- a/api/controllers/console/auth/oauth.py
+++ b/api/controllers/console/auth/oauth.py
@@ -14,8 +14,7 @@ from extensions.ext_database import db
 from libs.datetime_utils import naive_utc_now
 from libs.helper import extract_remote_ip
 from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
-from models import Account
-from models.account import AccountStatus
+from models import Account, AccountStatus
 from services.account_service import AccountService, RegisterService, TenantService
 from services.billing_service import BillingService
 from services.errors.account import AccountNotFoundError, AccountRegisterError
diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py
index 46281860ae..188ef7f622 100644
--- a/api/controllers/console/auth/oauth_server.py
+++ b/api/controllers/console/auth/oauth_server.py
@@ -1,16 +1,15 @@
 from collections.abc import Callable
 from functools import wraps
-from typing import Concatenate, ParamSpec, TypeVar, cast
+from typing import Concatenate, ParamSpec, TypeVar
 
-import flask_login
 from flask import jsonify, request
 from flask_restx import Resource, reqparse
 from werkzeug.exceptions import BadRequest, NotFound
 
 from controllers.console.wraps import account_initialization_required, setup_required
 from core.model_runtime.utils.encoders import jsonable_encoder
-from libs.login import login_required
-from models.account import Account
+from libs.login import current_account_with_tenant, login_required
+from models import Account
 from models.model import OAuthProviderApp
 from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
@@ -116,7 +115,8 @@ class OAuthServerUserAuthorizeApi(Resource):
     @account_initialization_required
     @oauth_server_client_id_required
     def post(self, oauth_provider_app: OAuthProviderApp):
-        account = cast(Account, flask_login.current_user)
+        current_user, _ = current_account_with_tenant()
+        account = current_user
         user_account_id = account.id
 
         code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py
index fa89f45122..5c89b29057 100644
--- a/api/controllers/console/billing/billing.py
+++ b/api/controllers/console/billing/billing.py
@@ -2,8 +2,7 @@ from flask_restx import Resource, reqparse
 
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
-from libs.login import current_user, login_required
-from models.model import Account
+from libs.login import current_account_with_tenant, login_required
 from services.billing_service import BillingService
 
 
@@ -14,17 +13,13 @@ class Subscription(Resource):
     @account_initialization_required
     @only_edition_cloud
     def get(self):
+        current_user, current_tenant_id = current_account_with_tenant()
         parser = reqparse.RequestParser()
         parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
         parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
         args = parser.parse_args()
-        assert isinstance(current_user, Account)
-
         BillingService.is_tenant_owner_or_admin(current_user)
-        assert current_user.current_tenant_id is not None
-        return BillingService.get_subscription(
-            args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
-        )
+        return BillingService.get_subscription(args["plan"], args["interval"], current_user.email, current_tenant_id)
 
 
 @console_ns.route("/billing/invoices")
@@ -34,7 +29,6 @@ class Invoices(Resource):
     @account_initialization_required
     @only_edition_cloud
     def get(self):
-        assert isinstance(current_user, Account)
+        current_user, current_tenant_id = current_account_with_tenant()
         BillingService.is_tenant_owner_or_admin(current_user)
-        assert current_user.current_tenant_id is not None
-        return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
+        return BillingService.get_invoices(current_user.email, current_tenant_id)
diff --git a/api/controllers/console/billing/compliance.py b/api/controllers/console/billing/compliance.py
index c0d104e0d4..3b32fe29a1 100644
--- a/api/controllers/console/billing/compliance.py
+++ b/api/controllers/console/billing/compliance.py
@@ -2,8 +2,7 @@ from flask import request
 from flask_restx import Resource, reqparse
 
 from libs.helper import extract_remote_ip
-from libs.login import current_user, login_required
-from models.account import Account
+from libs.login import current_account_with_tenant, login_required
 from services.billing_service import BillingService
 
 from .. import console_ns
@@ -17,19 +16,17 @@ class ComplianceApi(Resource):
     @account_initialization_required
     @only_edition_cloud
     def get(self):
-        assert isinstance(current_user, Account)
-        assert current_user.current_tenant_id is not None
+        current_user, current_tenant_id = current_account_with_tenant()
         parser = reqparse.RequestParser()
         parser.add_argument("doc_name", type=str, required=True, location="args")
         args = parser.parse_args()
 
         ip_address = extract_remote_ip(request)
         device_info = request.headers.get("User-Agent", "Unknown device")
-
         return BillingService.get_compliance_download_link(
             doc_name=args.doc_name,
             account_id=current_user.id,
-            tenant_id=current_user.current_tenant_id,
+            tenant_id=current_tenant_id,
             ip=ip_address,
             device_info=device_info,
         )
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index f86c5dfc3c..c03767d2e6 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -1,7 +1,6 @@
 from typing import Any, cast
 
 from flask import request
-from flask_login import current_user
 from flask_restx import Resource, fields, marshal, marshal_with, reqparse
 from sqlalchemy import select
 from werkzeug.exceptions import Forbidden, NotFound
@@ -30,10 +29,9 @@ from extensions.ext_database import db
 from fields.app_fields import related_app_list
 from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
 from fields.document_fields import document_status_fields
-from libs.login import login_required
+from libs.login import current_account_with_tenant, login_required
 from libs.validators import validate_description_length
 from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
-from models.account import Account
 from models.dataset import DatasetPermissionEnum
 from models.provider_ids import ModelProviderID
 from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
@@ -138,6 +136,7 @@ class DatasetListApi(Resource):
     @account_initialization_required
     @enterprise_license_required
     def get(self):
+        current_user, current_tenant_id = current_account_with_tenant()
         page = request.args.get("page", default=1, type=int)
         limit = request.args.get("limit", default=20, type=int)
         ids = request.args.getlist("ids")
@@ -146,15 +145,15 @@ class DatasetListApi(Resource):
         tag_ids = request.args.getlist("tag_ids")
         include_all = request.args.get("include_all", default="false").lower() == "true"
         if ids:
-            datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
+            datasets, total = DatasetService.get_datasets_by_ids(ids, current_tenant_id)
         else:
             datasets, total = DatasetService.get_datasets(
-                page, limit, current_user.current_tenant_id, current_user, search, tag_ids, include_all
+                page, limit, current_tenant_id, current_user, search, tag_ids, include_all
             )
 
         # check embedding setting
         provider_manager = ProviderManager()
-        configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
+        configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
 
         embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@@ -251,6 +250,7 @@ class DatasetListApi(Resource):
             required=False,
         )
         args = parser.parse_args()
+        current_user, current_tenant_id = current_account_with_tenant()
 
         # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
         if not current_user.is_dataset_editor:
@@ -258,11 +258,11 @@
             raise Forbidden()
 
         try:
             dataset = DatasetService.create_empty_dataset(
-                tenant_id=current_user.current_tenant_id,
+                tenant_id=current_tenant_id,
                 name=args["name"],
                 description=args["description"],
                 indexing_technique=args["indexing_technique"],
-                account=cast(Account, current_user),
+                account=current_user,
                 permission=DatasetPermissionEnum.ONLY_ME,
                 provider=args["provider"],
                 external_knowledge_api_id=args["external_knowledge_api_id"],
@@ -286,6 +286,7 @@ class DatasetApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, dataset_id):
+        current_user, current_tenant_id = current_account_with_tenant()
         dataset_id_str = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id_str)
         if dataset is None:
@@ -305,7 +306,7 @@ class DatasetApi(Resource):
 
         # check embedding setting
         provider_manager = ProviderManager()
-        configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
+        configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
 
         embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@@ -418,6 +419,7 @@ class DatasetApi(Resource):
         )
         args = parser.parse_args()
         data = request.get_json()
+        current_user, current_tenant_id = current_account_with_tenant()
 
         # check embedding model setting
         if (
@@ -440,7 +442,7 @@ class DatasetApi(Resource):
             raise NotFound("Dataset not found.")
 
         result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
-        tenant_id = current_user.current_tenant_id
+        tenant_id = current_tenant_id
 
         if data.get("partial_member_list") and data.get("permission") == "partial_members":
             DatasetPermissionService.update_partial_member_list(
@@ -464,9 +466,10 @@ class DatasetApi(Resource):
     @cloud_edition_billing_rate_limit_check("knowledge")
     def delete(self, dataset_id):
         dataset_id_str = str(dataset_id)
+        current_user, _ = current_account_with_tenant()
 
         # The role of the current user in the ta table must be admin, owner, or editor
-        if not (current_user.is_editor or current_user.is_dataset_operator):
+        if not (current_user.has_edit_permission or current_user.is_dataset_operator):
             raise Forbidden()
 
         try:
@@ -505,6 +508,7 @@ class DatasetQueryApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, dataset_id):
+        current_user, _ = current_account_with_tenant()
        dataset_id_str = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id_str)
         if dataset is None:
@@ -556,15 +560,14 @@ class DatasetIndexingEstimateApi(Resource):
             "doc_language", type=str, default="English", required=False, nullable=False, location="json"
         )
         args = parser.parse_args()
+        _, current_tenant_id = current_account_with_tenant()
         # validate args
         DocumentService.estimate_args_validate(args)
         extract_settings = []
         if args["info_list"]["data_source_type"] == "upload_file":
             file_ids = args["info_list"]["file_info_list"]["file_ids"]
             file_details = db.session.scalars(
-                select(UploadFile).where(
-                    UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)
-                )
+                select(UploadFile).where(UploadFile.tenant_id == current_tenant_id, UploadFile.id.in_(file_ids))
             ).all()
 
             if file_details is None:
@@ -592,7 +595,7 @@ class DatasetIndexingEstimateApi(Resource):
                             "notion_workspace_id": workspace_id,
                             "notion_obj_id": page["page_id"],
                             "notion_page_type": page["type"],
-                            "tenant_id": current_user.current_tenant_id,
+                            "tenant_id": current_tenant_id,
                         }
                     ),
                     document_model=args["doc_form"],
@@ -608,7 +611,7 @@ class DatasetIndexingEstimateApi(Resource):
                             "provider": website_info_list["provider"],
                             "job_id": website_info_list["job_id"],
                             "url": url,
-                            "tenant_id": current_user.current_tenant_id,
+                            "tenant_id": current_tenant_id,
                             "mode": "crawl",
                             "only_main_content": website_info_list["only_main_content"],
                         }
@@ -621,7 +624,7 @@ class DatasetIndexingEstimateApi(Resource):
         indexing_runner = IndexingRunner()
         try:
             response = indexing_runner.indexing_estimate(
-                current_user.current_tenant_id,
+                current_tenant_id,
                 extract_settings,
                 args["process_rule"],
                 args["doc_form"],
@@ -652,6 +655,7 @@ class DatasetRelatedAppListApi(Resource):
     @account_initialization_required
     @marshal_with(related_app_list)
     def get(self, dataset_id):
+        current_user, _ = current_account_with_tenant()
         dataset_id_str = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id_str)
         if dataset is None:
@@ -683,11 +687,10 @@ class DatasetIndexingStatusApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, dataset_id):
+        _, current_tenant_id = current_account_with_tenant()
         dataset_id = str(dataset_id)
         documents = db.session.scalars(
-            select(Document).where(
-                Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id
-            )
+            select(Document).where(Document.dataset_id == dataset_id, Document.tenant_id == current_tenant_id)
         ).all()
         documents_status = []
         for document in documents:
@@ -739,10 +742,9 @@ class DatasetApiKeyApi(Resource):
     @account_initialization_required
     @marshal_with(api_key_list)
     def get(self):
+        _, current_tenant_id = current_account_with_tenant()
         keys = db.session.scalars(
-            select(ApiToken).where(
-                ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id
-            )
+            select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
         ).all()
         return {"items": keys}
@@ -752,12 +754,13 @@ class DatasetApiKeyApi(Resource):
     @marshal_with(api_key_fields)
     def post(self):
         # The role of the current user in the ta table must be admin or owner
+        current_user, current_tenant_id = current_account_with_tenant()
         if not current_user.is_admin_or_owner:
             raise Forbidden()
 
         current_key_count = (
             db.session.query(ApiToken)
-            .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
+            .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
             .count()
         )
@@ -770,7 +773,7 @@ class DatasetApiKeyApi(Resource):
 
         key = ApiToken.generate_api_key(self.token_prefix, 24)
         api_token = ApiToken()
-        api_token.tenant_id = current_user.current_tenant_id
+        api_token.tenant_id = current_tenant_id
         api_token.token = key
         api_token.type = self.resource_type
         db.session.add(api_token)
@@ -790,6 +793,7 @@ class DatasetApiDeleteApi(Resource):
     @login_required
     @account_initialization_required
     def delete(self, api_key_id):
+        current_user, current_tenant_id = current_account_with_tenant()
         api_key_id = str(api_key_id)
 
         # The role of the current user in the ta table must be admin or owner
@@ -799,7 +803,7 @@ class DatasetApiDeleteApi(Resource):
         key = (
             db.session.query(ApiToken)
             .where(
-                ApiToken.tenant_id == current_user.current_tenant_id,
+                ApiToken.tenant_id == current_tenant_id,
                 ApiToken.type == self.resource_type,
                 ApiToken.id == api_key_id,
             )
@@ -898,6 +902,7 @@ class DatasetPermissionUserListApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, dataset_id):
+        current_user, _ = current_account_with_tenant()
         dataset_id_str = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id_str)
         if dataset is None:
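A recurring shape in the datasets hunks above: unpack the helper once, then scope every query to the returned tenant id rather than dereferencing `current_user.current_tenant_id` inline. Condensed to its core (the function is illustrative; `ApiToken` fields are as used in the diff):

    from sqlalchemy import select

    from extensions.ext_database import db
    from libs.login import current_account_with_tenant
    from models import ApiToken


    def list_api_tokens(resource_type: str):
        # Both values are guaranteed non-None by the helper, so no asserts needed.
        _, current_tenant_id = current_account_with_tenant()
        return db.session.scalars(
            select(ApiToken).where(
                ApiToken.type == resource_type,
                ApiToken.tenant_id == current_tenant_id,  # tenant isolation
            )
        ).all()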
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index 011dacde76..9c0c54833e 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -6,7 +6,6 @@ from typing import Literal, cast
 
 import sqlalchemy as sa
 from flask import request
-from flask_login import current_user
 from flask_restx import Resource, fields, marshal, marshal_with, reqparse
 from sqlalchemy import asc, desc, select
 from werkzeug.exceptions import Forbidden, NotFound
@@ -53,9 +52,8 @@ from fields.document_fields import (
     document_with_segments_fields,
 )
 from libs.datetime_utils import naive_utc_now
-from libs.login import login_required
+from libs.login import current_account_with_tenant, login_required
 from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
-from models.account import Account
 from models.dataset import DocumentPipelineExecutionLog
 from services.dataset_service import DatasetService, DocumentService
 from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
@@ -65,6 +63,7 @@ logger = logging.getLogger(__name__)
 
 class DocumentResource(Resource):
     def get_document(self, dataset_id: str, document_id: str) -> Document:
+        current_user, current_tenant_id = current_account_with_tenant()
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound("Dataset not found.")
@@ -79,12 +78,13 @@ class DocumentResource(Resource):
         if not document:
             raise NotFound("Document not found.")
 
-        if document.tenant_id != current_user.current_tenant_id:
+        if document.tenant_id != current_tenant_id:
             raise Forbidden("No permission.")
 
         return document
 
     def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]:
+        current_user, _ = current_account_with_tenant()
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound("Dataset not found.")
@@ -112,6 +112,7 @@ class GetProcessRuleApi(Resource):
     @login_required
     @account_initialization_required
     def get(self):
+        current_user, _ = current_account_with_tenant()
         req_data = request.args
 
         document_id = req_data.get("document_id")
@@ -168,6 +169,7 @@ class DatasetDocumentListApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, dataset_id):
+        current_user, current_tenant_id = current_account_with_tenant()
         dataset_id = str(dataset_id)
         page = request.args.get("page", default=1, type=int)
         limit = request.args.get("limit", default=20, type=int)
@@ -199,7 +201,7 @@ class DatasetDocumentListApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
 
-        query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
+        query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_tenant_id)
 
         if search:
             search = f"%{search}%"
@@ -273,6 +275,7 @@ class DatasetDocumentListApi(Resource):
     @cloud_edition_billing_resource_check("vector_space")
     @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self, dataset_id):
+        current_user, _ = current_account_with_tenant()
         dataset_id = str(dataset_id)
 
         dataset = DatasetService.get_dataset(dataset_id)
@@ -372,6 +375,7 @@ class DatasetInitApi(Resource):
     @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self):
         # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
+        current_user, current_tenant_id = current_account_with_tenant()
         if not current_user.is_dataset_editor:
             raise Forbidden()
@@ -402,7 +406,7 @@ class DatasetInitApi(Resource):
             try:
                 model_manager = ModelManager()
                 model_manager.get_model_instance(
-                    tenant_id=current_user.current_tenant_id,
+                    tenant_id=current_tenant_id,
                     provider=args["embedding_model_provider"],
                     model_type=ModelType.TEXT_EMBEDDING,
                     model=args["embedding_model"],
@@ -419,9 +423,9 @@ class DatasetInitApi(Resource):
 
         try:
             dataset, documents, batch = DocumentService.save_document_without_dataset_id(
-                tenant_id=current_user.current_tenant_id,
+                tenant_id=current_tenant_id,
                 knowledge_config=knowledge_config,
-                account=cast(Account, current_user),
+                account=current_user,
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -447,6 +451,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
     @login_required
     @account_initialization_required
     def get(self, dataset_id, document_id):
+        _, current_tenant_id = current_account_with_tenant()
         dataset_id = str(dataset_id)
         document_id = str(document_id)
         document = self.get_document(dataset_id, document_id)
@@ -482,7 +487,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
 
             try:
                 estimate_response = indexing_runner.indexing_estimate(
-                    current_user.current_tenant_id,
+                    current_tenant_id,
                     [extract_setting],
                     data_process_rule_dict,
                     document.doc_form,
@@ -511,6 +516,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
     @login_required
     @account_initialization_required
     def get(self, dataset_id, batch):
+        _, current_tenant_id = current_account_with_tenant()
         dataset_id = str(dataset_id)
         batch = str(batch)
         documents = self.get_batch_documents(dataset_id, batch)
@@ -530,7 +536,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                 file_id = data_source_info["upload_file_id"]
                 file_detail = (
                     db.session.query(UploadFile)
-                    .where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id)
+                    .where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
                     .first()
                 )
@@ -553,7 +559,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                             "notion_workspace_id": data_source_info["notion_workspace_id"],
                             "notion_obj_id": data_source_info["notion_page_id"],
                             "notion_page_type": data_source_info["type"],
-                            "tenant_id": current_user.current_tenant_id,
+                            "tenant_id": current_tenant_id,
                         }
                     ),
                     document_model=document.doc_form,
@@ -569,7 +575,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                             "provider": data_source_info["provider"],
                             "job_id": data_source_info["job_id"],
                             "url": data_source_info["url"],
-                            "tenant_id": current_user.current_tenant_id,
+                            "tenant_id": current_tenant_id,
                             "mode": data_source_info["mode"],
                             "only_main_content": data_source_info["only_main_content"],
                         }
@@ -583,7 +589,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
         indexing_runner = IndexingRunner()
         try:
             response = indexing_runner.indexing_estimate(
-                current_user.current_tenant_id,
+                current_tenant_id,
                 extract_settings,
                 data_process_rule_dict,
                 document.doc_form,
@@ -834,6 +840,7 @@ class DocumentProcessingApi(DocumentResource):
     @account_initialization_required
     @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]):
+        current_user, _ = current_account_with_tenant()
         dataset_id = str(dataset_id)
         document_id = str(document_id)
         document = self.get_document(dataset_id, document_id)
@@ -884,6 +891,7 @@ class DocumentMetadataApi(DocumentResource):
     @login_required
     @account_initialization_required
     def put(self, dataset_id, document_id):
+        current_user, _ = current_account_with_tenant()
         dataset_id = str(dataset_id)
         document_id = str(document_id)
         document = self.get_document(dataset_id, document_id)
@@ -931,6 +939,7 @@ class DocumentStatusApi(DocumentResource):
     @cloud_edition_billing_resource_check("vector_space")
     @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
+        current_user, _ = current_account_with_tenant()
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
         if dataset is None:
@@ -1077,12 +1086,13 @@ class DocumentRenameApi(DocumentResource):
     @marshal_with(document_fields)
     def post(self, dataset_id, document_id):
         # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
+        current_user, _ = current_account_with_tenant()
         if not current_user.is_dataset_editor:
             raise Forbidden()
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound("Dataset not found.")
-        DatasetService.check_dataset_operator_permission(cast(Account, current_user), dataset)
+        DatasetService.check_dataset_operator_permission(current_user, dataset)
         parser = reqparse.RequestParser()
         parser.add_argument("name", type=str, required=True, nullable=False, location="json")
         args = parser.parse_args()
@@ -1102,6 +1112,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
     @account_initialization_required
     def get(self, dataset_id, document_id):
         """sync website document."""
+        _, current_tenant_id = current_account_with_tenant()
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
@@ -1110,7 +1121,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
         document = DocumentService.get_document(dataset.id, document_id)
         if not document:
             raise NotFound("Document not found.")
-        if document.tenant_id != current_user.current_tenant_id:
+        if document.tenant_id != current_tenant_id:
             raise Forbidden("No permission.")
         if document.data_source_type != "website_crawl":
             raise ValueError("Document is not a website document.")
diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py
index adf9f53523..f590919180 100644
--- a/api/controllers/console/datasets/external.py
+++ b/api/controllers/console/datasets/external.py
@@ -1,7 +1,4 @@
-from typing import cast
-
 from flask import request
-from flask_login import current_user
 from flask_restx import Resource, fields, marshal, reqparse
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@@ -10,8 +7,7 @@ from controllers.console import api, console_ns
 from controllers.console.datasets.error import DatasetNameDuplicateError
 from controllers.console.wraps import account_initialization_required, setup_required
 from fields.dataset_fields import dataset_detail_fields
-from libs.login import login_required
-from models.account import Account
+from libs.login import current_account_with_tenant, login_required
 from services.dataset_service import DatasetService
 from services.external_knowledge_service import ExternalDatasetService
 from services.hit_testing_service import HitTestingService
@@ -40,12 +36,13 @@ class ExternalApiTemplateListApi(Resource):
     @login_required
     @account_initialization_required
     def get(self):
+        _, current_tenant_id = current_account_with_tenant()
         page = request.args.get("page", default=1, type=int)
         limit = request.args.get("limit", default=20, type=int)
         search = request.args.get("keyword", default=None, type=str)
 
         external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
-            page, limit, current_user.current_tenant_id, search
+            page, limit, current_tenant_id, search
         )
         response = {
             "data": [item.to_dict() for item in external_knowledge_apis],
@@ -60,6 +57,7 @@ class ExternalApiTemplateListApi(Resource):
     @login_required
     @account_initialization_required
     def post(self):
+        current_user, current_tenant_id = current_account_with_tenant()
         parser = reqparse.RequestParser()
         parser.add_argument(
             "name",
@@ -85,7 +83,7 @@ class ExternalApiTemplateListApi(Resource):
 
         try:
             external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
-                tenant_id=current_user.current_tenant_id, user_id=current_user.id, args=args
+                tenant_id=current_tenant_id, user_id=current_user.id, args=args
             )
         except services.errors.dataset.DatasetNameDuplicateError:
             raise DatasetNameDuplicateError()
@@ -115,6 +113,7 @@ class ExternalApiTemplateApi(Resource):
     @login_required
     @account_initialization_required
     def patch(self, external_knowledge_api_id):
+        current_user, current_tenant_id = current_account_with_tenant()
         external_knowledge_api_id = str(external_knowledge_api_id)
 
         parser = reqparse.RequestParser()
@@ -136,7 +135,7 @@ class ExternalApiTemplateApi(Resource):
         ExternalDatasetService.validate_api_list(args["settings"])
 
         external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
-            tenant_id=current_user.current_tenant_id,
+            tenant_id=current_tenant_id,
             user_id=current_user.id,
             external_knowledge_api_id=external_knowledge_api_id,
             args=args,
@@ -148,13 +147,14 @@ class ExternalApiTemplateApi(Resource):
     @login_required
     @account_initialization_required
     def delete(self, external_knowledge_api_id):
+        current_user, current_tenant_id = current_account_with_tenant()
         external_knowledge_api_id = str(external_knowledge_api_id)
 
         # The role of the current user in the ta table must be admin, owner, or editor
-        if not (current_user.is_editor or current_user.is_dataset_operator):
+        if not (current_user.has_edit_permission or current_user.is_dataset_operator):
             raise Forbidden()
 
-        ExternalDatasetService.delete_external_knowledge_api(current_user.current_tenant_id, external_knowledge_api_id)
+        ExternalDatasetService.delete_external_knowledge_api(current_tenant_id, external_knowledge_api_id)
         return {"result": "success"}, 204
@@ -199,7 +199,8 @@ class ExternalDatasetCreateApi(Resource):
     @account_initialization_required
     def post(self):
         # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
+        current_user, current_tenant_id = current_account_with_tenant()
+        if not current_user.has_edit_permission:
             raise Forbidden()
 
         parser = reqparse.RequestParser()
@@ -223,7 +224,7 @@ class ExternalDatasetCreateApi(Resource):
 
         try:
             dataset = ExternalDatasetService.create_external_dataset(
-                tenant_id=current_user.current_tenant_id,
+                tenant_id=current_tenant_id,
                 user_id=current_user.id,
                 args=args,
             )
@@ -255,6 +256,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
     @login_required
     @account_initialization_required
     def post(self, dataset_id):
+        current_user, _ = current_account_with_tenant()
         dataset_id_str = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id_str)
         if dataset is None:
@@ -277,7 +279,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
             response = HitTestingService.external_retrieve(
                 dataset=dataset,
                 query=args["query"],
-                account=cast(Account, current_user),
+                account=current_user,
                 external_retrieval_model=args["external_retrieval_model"],
                 metadata_filtering_conditions=args["metadata_filtering_conditions"],
             )
current_account_with_tenant() dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: @@ -128,6 +131,7 @@ class DocumentMetadataEditApi(Resource): @account_initialization_required @enterprise_license_required def post(self, dataset_id): + current_user, _ = current_account_with_tenant() dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index 4c34df7eff..194bd98fa3 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -4,10 +4,7 @@ from werkzeug.exceptions import Forbidden, NotFound from configs import dify_config from controllers.console import console_ns -from controllers.console.wraps import ( - account_initialization_required, - setup_required, -) +from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.oauth import OAuthHandler @@ -23,12 +20,11 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required def get(self, provider_id: str): current_user, current_tenant_id = current_account_with_tenant() tenant_id = current_tenant_id - if not current_user.has_edit_permission: - raise Forbidden() credential_id = request.args.get("credential_id") datasource_provider_id = DatasourceProviderID(provider_id) @@ -130,11 +126,9 @@ class DatasourceAuth(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required def post(self, provider_id: str): - current_user, current_tenant_id = current_account_with_tenant() - - if not current_user.has_edit_permission: - raise Forbidden() + _, current_tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument( @@ -177,14 +171,14 @@ class DatasourceAuthDeleteApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required def post(self, provider_id: str): - current_user, current_tenant_id = current_account_with_tenant() + _, current_tenant_id = current_account_with_tenant() datasource_provider_id = DatasourceProviderID(provider_id) plugin_id = datasource_provider_id.plugin_id provider_name = datasource_provider_id.provider_name - if not current_user.has_edit_permission: - raise Forbidden() + parser = reqparse.RequestParser() parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") args = parser.parse_args() @@ -203,8 +197,9 @@ class DatasourceAuthUpdateApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required def post(self, provider_id: str): - current_user, current_tenant_id = current_account_with_tenant() + _, current_tenant_id = current_account_with_tenant() datasource_provider_id = DatasourceProviderID(provider_id) parser = reqparse.RequestParser() @@ -212,8 +207,7 @@ class DatasourceAuthUpdateApi(Resource): parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json") parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") args 
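
datasource_auth.py shows the second half of the cleanup: the same three-line permission guard opened every handler, and it now moves into the decorator stack via `edit_permission_required`, which is added to controllers/console/wraps.py later in this patch. A condensed before/after, reusing the `Account` stub and helper sketched above and additionally assuming a boolean `has_edit_permission` attribute on the stand-in:

    from functools import wraps

    from werkzeug.exceptions import Forbidden


    def edit_permission_required(f):  # matches the definition added in wraps.py
        @wraps(f)
        def decorated_function(*args, **kwargs):
            current_user, _ = current_account_with_tenant()
            if not current_user.has_edit_permission:
                raise Forbidden()
            return f(*args, **kwargs)
        return decorated_function


    def post_before(provider_id: str):
        current_user, current_tenant_id = current_account_with_tenant()
        if not current_user.has_edit_permission:  # repeated in every handler
            raise Forbidden()
        ...


    @edit_permission_required
    def post_after(provider_id: str):
        _, current_tenant_id = current_account_with_tenant()  # guard already ran
        ...
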
= parser.parse_args() - if not current_user.has_edit_permission: - raise Forbidden() + datasource_provider_service = DatasourceProviderService() datasource_provider_service.update_datasource_credentials( tenant_id=current_tenant_id, @@ -257,11 +251,10 @@ class DatasourceAuthOauthCustomClient(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required def post(self, provider_id: str): - current_user, current_tenant_id = current_account_with_tenant() + _, current_tenant_id = current_account_with_tenant() - if not current_user.has_edit_permission: - raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json") parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json") @@ -296,11 +289,10 @@ class DatasourceAuthDefaultApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required def post(self, provider_id: str): - current_user, current_tenant_id = current_account_with_tenant() + _, current_tenant_id = current_account_with_tenant() - if not current_user.has_edit_permission: - raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("id", type=str, required=True, nullable=False, location="json") args = parser.parse_args() @@ -319,11 +311,10 @@ class DatasourceUpdateProviderNameApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required def post(self, provider_id: str): - current_user, current_tenant_id = current_account_with_tenant() + _, current_tenant_id = current_account_with_tenant() - if not current_user.has_edit_permission: - raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json") parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index bef6bfd13e..2e8cc16dc1 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -23,7 +23,7 @@ from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings from factories.variable_factory import build_segment_with_type from libs.login import current_user, login_required -from models.account import Account +from models import Account from models.dataset import Pipeline from models.workflow import WorkflowDraftVariable from services.rag_pipeline.rag_pipeline import RagPipelineService diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py index a82872ba2b..ca767dbb10 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py @@ -1,6 +1,3 @@ -from typing import cast - -from flask_login import current_user # type: ignore from flask_restx import Resource, marshal_with, reqparse # type: ignore from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden @@ -13,8 +10,7 @@ from controllers.console.wraps import ( ) from extensions.ext_database import db from fields.rag_pipeline_fields import 
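
rag_pipeline_import.py also drops its `cast(Account, current_user)` calls. `typing.cast` is a no-op at runtime, so it only papered over the untyped login proxy; a helper with a precise return annotation gives the type checker the same information while actually verifying it. A short sketch, reusing the stubs above:

    from typing import cast


    def untyped_proxy() -> object:          # the shape flask_login.current_user presents
        return Account(id="acc-1", current_tenant_id="tenant-1")


    account_old = cast(Account, untyped_proxy())    # trusted, never checked
    account_new, _ = current_account_with_tenant()  # checked and precisely typed
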
pipeline_import_check_dependencies_fields, pipeline_import_fields -from libs.login import login_required -from models import Account +from libs.login import current_account_with_tenant, login_required from models.dataset import Pipeline from services.app_dsl_service import ImportStatus from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService @@ -28,7 +24,8 @@ class RagPipelineImportApi(Resource): @marshal_with(pipeline_import_fields) def post(self): # Check user role first - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -47,7 +44,7 @@ class RagPipelineImportApi(Resource): with Session(db.engine) as session: import_service = RagPipelineDslService(session) # Import app - account = cast(Account, current_user) + account = current_user result = import_service.import_rag_pipeline( account=account, import_mode=args["mode"], @@ -74,15 +71,16 @@ class RagPipelineImportConfirmApi(Resource): @account_initialization_required @marshal_with(pipeline_import_fields) def post(self, import_id): + current_user, _ = current_account_with_tenant() # Check user role first - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() # Create service with session with Session(db.engine) as session: import_service = RagPipelineDslService(session) # Confirm import - account = cast(Account, current_user) + account = current_user result = import_service.confirm_import(import_id=import_id, account=account) session.commit() @@ -100,7 +98,8 @@ class RagPipelineImportCheckDependenciesApi(Resource): @account_initialization_required @marshal_with(pipeline_import_check_dependencies_fields) def get(self, pipeline: Pipeline): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + if not current_user.has_edit_permission: raise Forbidden() with Session(db.engine) as session: @@ -117,7 +116,8 @@ class RagPipelineExportApi(Resource): @get_rag_pipeline @account_initialization_required def get(self, pipeline: Pipeline): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + if not current_user.has_edit_permission: raise Forbidden() # Add include_secret params diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index a75c121fbe..d4d6da7fe2 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -18,6 +18,7 @@ from controllers.console.app.error import ( from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import ( account_initialization_required, + edit_permission_required, setup_required, ) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError @@ -36,8 +37,8 @@ from fields.workflow_run_fields import ( ) from libs import helper from libs.helper import TimestampField, uuid_value -from libs.login import current_user, login_required -from models.account import Account +from libs.login import current_account_with_tenant, current_user, login_required +from models import Account from models.dataset import Pipeline from models.model import EndUser from services.errors.app import WorkflowHashNotEqualError @@ -56,15 +57,12 @@ class DraftRagPipelineApi(Resource): @login_required @account_initialization_required 
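
In rag_pipeline_workflow.py the new guard slots in below `@get_rag_pipeline`. Ordering matters here: stacked decorators run top-down at request time, so the permission check only fires once setup, login, and account initialization have succeeded. A minimal, runnable illustration of that ordering:

    from functools import wraps


    def tag(name):
        def deco(f):
            @wraps(f)
            def inner(*args, **kwargs):
                print("enter", name)
                return f(*args, **kwargs)
            return inner
        return deco


    @tag("login_required")
    @tag("account_initialization_required")
    @tag("edit_permission_required")
    def handler():
        print("handler body")


    handler()
    # enter login_required
    # enter account_initialization_required
    # enter edit_permission_required
    # handler body
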
@get_rag_pipeline + @edit_permission_required @marshal_with(workflow_fields) def get(self, pipeline: Pipeline): """ Get draft rag pipeline's workflow """ - # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.has_edit_permission: - raise Forbidden() - # fetch draft workflow by app_model rag_pipeline_service = RagPipelineService() workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) @@ -79,13 +77,13 @@ class DraftRagPipelineApi(Resource): @login_required @account_initialization_required @get_rag_pipeline + @edit_permission_required def post(self, pipeline: Pipeline): """ Sync draft workflow """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.has_edit_permission: - raise Forbidden() + current_user, _ = current_account_with_tenant() content_type = request.headers.get("Content-Type", "") @@ -154,13 +152,13 @@ class RagPipelineDraftRunIterationNodeApi(Resource): @login_required @account_initialization_required @get_rag_pipeline + @edit_permission_required def post(self, pipeline: Pipeline, node_id: str): """ Run draft workflow iteration node """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.has_edit_permission: - raise Forbidden() + current_user, _ = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, location="json") @@ -194,7 +192,8 @@ class RagPipelineDraftRunLoopNodeApi(Resource): Run draft workflow loop node """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.has_edit_permission: + current_user, _ = current_account_with_tenant() + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -229,7 +228,8 @@ class DraftRagPipelineRunApi(Resource): Run draft workflow """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.has_edit_permission: + current_user, _ = current_account_with_tenant() + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -264,7 +264,8 @@ class PublishedRagPipelineRunApi(Resource): Run published workflow """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.has_edit_permission: + current_user, _ = current_account_with_tenant() + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -303,7 +304,7 @@ class PublishedRagPipelineRunApi(Resource): # Run rag pipeline datasource # """ # # The role of the current user in the ta table must be admin, owner, or editor -# if not current_user.is_editor: +# if not current_user.has_edit_permission: # raise Forbidden() # # if not isinstance(current_user, Account): @@ -344,7 +345,7 @@ class PublishedRagPipelineRunApi(Resource): # Run rag pipeline datasource # """ # # The role of the current user in the ta table must be admin, owner, or editor -# if not current_user.is_editor: +# if not current_user.has_edit_permission: # raise Forbidden() # # if not isinstance(current_user, Account): @@ -385,7 +386,8 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource): Run rag pipeline datasource """ # The role of 
the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.has_edit_permission: + current_user, _ = current_account_with_tenant() + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -428,7 +430,8 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource): Run rag pipeline datasource """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.has_edit_permission: + current_user, _ = current_account_with_tenant() + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -472,7 +475,8 @@ class RagPipelineDraftNodeRunApi(Resource): Run draft workflow node """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.has_edit_permission: + current_user, _ = current_account_with_tenant() + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -505,7 +509,8 @@ class RagPipelineTaskStopApi(Resource): Stop workflow task """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.has_edit_permission: + current_user, _ = current_account_with_tenant() + if not current_user.has_edit_permission: raise Forbidden() AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) @@ -525,7 +530,8 @@ class PublishedRagPipelineApi(Resource): Get published pipeline """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.has_edit_permission: + current_user, _ = current_account_with_tenant() + if not current_user.has_edit_permission: raise Forbidden() if not pipeline.is_published: return None @@ -545,7 +551,8 @@ class PublishedRagPipelineApi(Resource): Publish workflow """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.has_edit_permission: + current_user, _ = current_account_with_tenant() + if not current_user.has_edit_permission: raise Forbidden() rag_pipeline_service = RagPipelineService() @@ -580,7 +587,8 @@ class DefaultRagPipelineBlockConfigsApi(Resource): Get default block config """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.has_edit_permission: + current_user, _ = current_account_with_tenant() + if not current_user.has_edit_permission: raise Forbidden() # Get default block configs @@ -599,7 +607,8 @@ class DefaultRagPipelineBlockConfigApi(Resource): Get default block config """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.has_edit_permission: + current_user, _ = current_account_with_tenant() + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -631,7 +640,8 @@ class PublishedAllRagPipelineApi(Resource): """ Get published workflows """ - if not isinstance(current_user, Account) or not current_user.has_edit_permission: + current_user, _ = current_account_with_tenant() + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -681,7 +691,8 @@ class RagPipelineByIdApi(Resource): Update workflow 
attributes """ # Check permission - if not isinstance(current_user, Account) or not current_user.has_edit_permission: + current_user, _ = current_account_with_tenant() + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -733,13 +744,11 @@ class PublishedRagPipelineSecondStepApi(Resource): @login_required @account_initialization_required @get_rag_pipeline + @edit_permission_required def get(self, pipeline: Pipeline): """ Get second step parameters of rag pipeline """ - # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.has_edit_permission: - raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("node_id", type=str, required=True, location="args") args = parser.parse_args() @@ -759,13 +768,11 @@ class PublishedRagPipelineFirstStepApi(Resource): @login_required @account_initialization_required @get_rag_pipeline + @edit_permission_required def get(self, pipeline: Pipeline): """ Get first step parameters of rag pipeline """ - # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.has_edit_permission: - raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("node_id", type=str, required=True, location="args") args = parser.parse_args() @@ -785,13 +792,11 @@ class DraftRagPipelineFirstStepApi(Resource): @login_required @account_initialization_required @get_rag_pipeline + @edit_permission_required def get(self, pipeline: Pipeline): """ Get first step parameters of rag pipeline """ - # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.has_edit_permission: - raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("node_id", type=str, required=True, location="args") args = parser.parse_args() @@ -811,13 +816,11 @@ class DraftRagPipelineSecondStepApi(Resource): @login_required @account_initialization_required @get_rag_pipeline + @edit_permission_required def get(self, pipeline: Pipeline): """ Get second step parameters of rag pipeline """ - # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.has_edit_permission: - raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("node_id", type=str, required=True, location="args") args = parser.parse_args() @@ -880,7 +883,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource): @account_initialization_required @get_rag_pipeline @marshal_with(workflow_run_node_execution_list_fields) - def get(self, pipeline: Pipeline, run_id): + def get(self, pipeline: Pipeline, run_id: str): """ Get workflow run node execution list """ @@ -903,14 +906,8 @@ class DatasourceListApi(Resource): @login_required @account_initialization_required def get(self): - user = current_user - if not isinstance(user, Account): - raise Forbidden() - tenant_id = user.current_tenant_id - if not tenant_id: - raise Forbidden() - - return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id)) + _, current_tenant_id = current_account_with_tenant() + return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(current_tenant_id)) @console_ns.route("/rag/pipelines//workflows/draft/nodes//last-run") @@ -940,11 +937,11 @@ class RagPipelineTransformApi(Resource): @setup_required 
@login_required @account_initialization_required - def post(self, dataset_id): - if not isinstance(current_user, Account): - raise Forbidden() + @edit_permission_required + def post(self, dataset_id: str): + current_user, _ = current_account_with_tenant() - if not (current_user.has_edit_permission or current_user.is_dataset_operator): + if not current_user.is_dataset_operator: raise Forbidden() dataset_id = str(dataset_id) @@ -959,14 +956,13 @@ class RagPipelineDatasourceVariableApi(Resource): @login_required @account_initialization_required @get_rag_pipeline + @edit_permission_required @marshal_with(workflow_run_node_execution_fields) def post(self, pipeline: Pipeline): """ Set datasource variables """ - if not isinstance(current_user, Account) or not current_user.has_edit_permission: - raise Forbidden() - + current_user, _ = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("datasource_type", type=str, required=True, location="json") parser.add_argument("datasource_info", type=dict, required=True, location="json") diff --git a/api/controllers/console/datasets/wraps.py b/api/controllers/console/datasets/wraps.py index 98abb3ef8d..a8c1298e3e 100644 --- a/api/controllers/console/datasets/wraps.py +++ b/api/controllers/console/datasets/wraps.py @@ -3,8 +3,7 @@ from functools import wraps from controllers.console.datasets.error import PipelineNotFoundError from extensions.ext_database import db -from libs.login import current_user -from models.account import Account +from libs.login import current_account_with_tenant from models.dataset import Pipeline @@ -17,8 +16,7 @@ def get_rag_pipeline( if not kwargs.get("pipeline_id"): raise ValueError("missing pipeline_id in path parameters") - if not isinstance(current_user, Account): - raise ValueError("current_user is not an account") + _, current_tenant_id = current_account_with_tenant() pipeline_id = kwargs.get("pipeline_id") pipeline_id = str(pipeline_id) @@ -27,7 +25,7 @@ def get_rag_pipeline( pipeline = ( db.session.query(Pipeline) - .where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id) + .where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id) .first() ) diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index b045e47846..064e026753 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -23,8 +23,7 @@ from core.model_runtime.errors.invoke import InvokeError from fields.message_fields import message_infinite_scroll_pagination_fields from libs import helper from libs.helper import uuid_value -from libs.login import current_user -from models import Account +from libs.login import current_account_with_tenant from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.app import MoreLikeThisDisabledError @@ -48,6 +47,7 @@ logger = logging.getLogger(__name__) class MessageListApi(InstalledAppResource): @marshal_with(message_infinite_scroll_pagination_fields) def get(self, installed_app): + current_user, _ = current_account_with_tenant() app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) @@ -61,8 +61,6 @@ class MessageListApi(InstalledAppResource): args = parser.parse_args() try: - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") return MessageService.pagination_by_first_id( app_model, current_user, args["conversation_id"], 
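
The reworked `get_rag_pipeline` wrapper (datasets/wraps.py above) keeps the important property of the old assertion-heavy version: the pipeline is looked up by id *and* tenant, so a valid id from another workspace yields a 404 rather than leaking data. A self-contained SQLAlchemy sketch of that scoping, with a trimmed stand-in model:

    from sqlalchemy import Column, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()


    class Pipeline(Base):  # stand-in for models.dataset.Pipeline
        __tablename__ = "pipelines"
        id = Column(String, primary_key=True)
        tenant_id = Column(String, nullable=False)


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(Pipeline(id="p1", tenant_id="tenant-a"))
        session.commit()
        # Same id, wrong tenant -> None; the wrapper raises PipelineNotFoundError.
        leaked = session.query(Pipeline).where(
            Pipeline.id == "p1", Pipeline.tenant_id == "tenant-b"
        ).first()
        print(leaked)  # None
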
args["first_id"], args["limit"] ) @@ -78,6 +76,7 @@ class MessageListApi(InstalledAppResource): ) class MessageFeedbackApi(InstalledAppResource): def post(self, installed_app, message_id): + current_user, _ = current_account_with_tenant() app_model = installed_app.app message_id = str(message_id) @@ -88,8 +87,6 @@ class MessageFeedbackApi(InstalledAppResource): args = parser.parse_args() try: - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") MessageService.create_feedback( app_model=app_model, message_id=message_id, @@ -109,6 +106,7 @@ class MessageFeedbackApi(InstalledAppResource): ) class MessageMoreLikeThisApi(InstalledAppResource): def get(self, installed_app, message_id): + current_user, _ = current_account_with_tenant() app_model = installed_app.app if app_model.mode != "completion": raise NotCompletionAppError() @@ -124,8 +122,6 @@ class MessageMoreLikeThisApi(InstalledAppResource): streaming = args["response_mode"] == "streaming" try: - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") response = AppGenerateService.generate_more_like_this( app_model=app_model, user=current_user, @@ -159,6 +155,7 @@ class MessageMoreLikeThisApi(InstalledAppResource): ) class MessageSuggestedQuestionApi(InstalledAppResource): def get(self, installed_app, message_id): + current_user, _ = current_account_with_tenant() app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: @@ -167,8 +164,6 @@ class MessageSuggestedQuestionApi(InstalledAppResource): message_id = str(message_id) try: - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") questions = MessageService.get_suggested_questions_after_answer( app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE ) diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 79e4a4339e..830685975b 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -7,8 +7,7 @@ from controllers.console.explore.error import NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource from fields.conversation_fields import message_file_fields from libs.helper import TimestampField, uuid_value -from libs.login import current_user -from models import Account +from libs.login import current_account_with_tenant from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService @@ -35,6 +34,7 @@ class SavedMessageListApi(InstalledAppResource): @marshal_with(saved_message_infinite_scroll_pagination_fields) def get(self, installed_app): + current_user, _ = current_account_with_tenant() app_model = installed_app.app if app_model.mode != "completion": raise NotCompletionAppError() @@ -44,11 +44,10 @@ class SavedMessageListApi(InstalledAppResource): parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"]) def post(self, installed_app): + current_user, _ = current_account_with_tenant() app_model = installed_app.app if 
app_model.mode != "completion": raise NotCompletionAppError() @@ -58,8 +57,6 @@ class SavedMessageListApi(InstalledAppResource): args = parser.parse_args() try: - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") SavedMessageService.save(app_model, current_user, args["message_id"]) except MessageNotExistsError: raise NotFound("Message Not Exists.") @@ -72,6 +69,7 @@ class SavedMessageListApi(InstalledAppResource): ) class SavedMessageApi(InstalledAppResource): def delete(self, installed_app, message_id): + current_user, _ = current_account_with_tenant() app_model = installed_app.app message_id = str(message_id) @@ -79,8 +77,6 @@ class SavedMessageApi(InstalledAppResource): if app_model.mode != "completion": raise NotCompletionAppError() - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") SavedMessageService.delete(app_model, current_user, message_id) return {"result": "success"}, 204 diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index 5956eb52c4..df4eed18eb 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -8,9 +8,8 @@ from werkzeug.exceptions import NotFound from controllers.console.explore.error import AppAccessDeniedError from controllers.console.wraps import account_initialization_required from extensions.ext_database import db -from libs.login import current_user, login_required +from libs.login import current_account_with_tenant, login_required from models import InstalledApp -from models.account import Account from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService @@ -24,13 +23,10 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non def decorator(view: Callable[Concatenate[InstalledApp, P], R]): @wraps(view) def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs): - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + _, current_tenant_id = current_account_with_tenant() installed_app = ( db.session.query(InstalledApp) - .where( - InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id - ) + .where(InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_tenant_id) .first() ) @@ -56,9 +52,9 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] | def decorator(view: Callable[Concatenate[InstalledApp, P], R]): @wraps(view) def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs): + current_user, _ = current_account_with_tenant() feature = FeatureService.get_system_features() if feature.webapp_auth.enabled: - assert isinstance(current_user, Account) app_id = installed_app.app_id app_code = AppService.get_app_code_by_id(app_id) res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp( diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index dd618307e9..f77996eb6a 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -4,8 +4,7 @@ from constants import HIDDEN_VALUE from controllers.console import api, console_ns from controllers.console.wraps import account_initialization_required, setup_required from fields.api_based_extension_fields import api_based_extension_fields -from libs.login 
import current_account_with_tenant, current_user, login_required -from models.account import Account +from libs.login import current_account_with_tenant, login_required from models.api_based_extension import APIBasedExtension from services.api_based_extension_service import APIBasedExtensionService from services.code_based_extension_service import CodeBasedExtensionService @@ -68,8 +67,7 @@ class APIBasedExtensionAPI(Resource): @account_initialization_required @marshal_with(api_based_extension_fields) def post(self): - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + _, current_tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, location="json") parser.add_argument("api_endpoint", type=str, required=True, location="json") @@ -98,8 +96,6 @@ class APIBasedExtensionDetailAPI(Resource): @account_initialization_required @marshal_with(api_based_extension_fields) def get(self, id): - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None api_based_extension_id = str(id) _, tenant_id = current_account_with_tenant() @@ -124,8 +120,6 @@ class APIBasedExtensionDetailAPI(Resource): @account_initialization_required @marshal_with(api_based_extension_fields) def post(self, id): - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None api_based_extension_id = str(id) _, current_tenant_id = current_account_with_tenant() @@ -153,8 +147,6 @@ class APIBasedExtensionDetailAPI(Resource): @login_required @account_initialization_required def delete(self, id): - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None api_based_extension_id = str(id) _, current_tenant_id = current_account_with_tenant() diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index 2b63f6febc..1cd193f7ad 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -1,7 +1,6 @@ from typing import Literal from flask import request -from flask_login import current_user from flask_restx import Resource, marshal_with from werkzeug.exceptions import Forbidden @@ -22,8 +21,7 @@ from controllers.console.wraps import ( ) from extensions.ext_database import db from fields.file_fields import file_fields, upload_config_fields -from libs.login import login_required -from models import Account +from libs.login import current_account_with_tenant, login_required from services.file_service import FileService from . 
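
explore/wraps.py keeps its `Concatenate[InstalledApp, P]` typing through the refactor; it is what lets the decorator swallow the `installed_app_id` path argument and hand the view a loaded `InstalledApp` without the type checker losing the remaining parameters. A condensed, runnable sketch in which the DB load is faked:

    from collections.abc import Callable
    from functools import wraps
    from typing import Concatenate, ParamSpec, TypeVar

    P = ParamSpec("P")
    R = TypeVar("R")


    class InstalledApp:  # stand-in for models.InstalledApp
        def __init__(self, id: str):
            self.id = id


    def installed_app_required(
        view: Callable[Concatenate[InstalledApp, P], R],
    ) -> Callable[Concatenate[str, P], R]:
        @wraps(view)
        def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs) -> R:
            installed_app = InstalledApp(installed_app_id)  # real code loads from DB, 404s if missing
            return view(installed_app, *args, **kwargs)
        return decorated


    @installed_app_required
    def show(installed_app: InstalledApp, verbose: bool = False) -> str:
        return installed_app.id


    print(show("app-123"))  # caller passes the id; the view receives the object
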
import console_ns @@ -53,6 +51,7 @@ class FileApi(Resource): @marshal_with(file_fields) @cloud_edition_billing_resource_check("documents") def post(self): + current_user, _ = current_account_with_tenant() source_str = request.form.get("source") source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None @@ -65,16 +64,12 @@ class FileApi(Resource): if not file.filename: raise FilenameNotExistsError - if source == "datasets" and not current_user.is_dataset_editor: raise Forbidden() if source not in ("datasets", None): source = None - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") - try: upload_file = FileService(db.engine).upload_file( filename=file.filename, diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index b6086c5766..5748ca110d 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -5,8 +5,7 @@ from werkzeug.exceptions import Forbidden from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required from fields.tag_fields import dataset_tag_fields -from libs.login import current_user, login_required -from models.account import Account +from libs.login import current_account_with_tenant, login_required from models.model import Tag from services.tag_service import TagService @@ -24,11 +23,10 @@ class TagListApi(Resource): @account_initialization_required @marshal_with(dataset_tag_fields) def get(self): - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + _, current_tenant_id = current_account_with_tenant() tag_type = request.args.get("type", type=str, default="") keyword = request.args.get("keyword", default=None, type=str) - tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword) + tags = TagService.get_tags(tag_type, current_tenant_id, keyword) return tags, 200 @@ -36,8 +34,7 @@ class TagListApi(Resource): @login_required @account_initialization_required def post(self): - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + current_user, _ = current_account_with_tenant() # The role of the current user in the ta table must be admin, owner, or editor if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() @@ -63,8 +60,7 @@ class TagUpdateDeleteApi(Resource): @login_required @account_initialization_required def patch(self, tag_id): - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + current_user, _ = current_account_with_tenant() tag_id = str(tag_id) # The role of the current user in the ta table must be admin, owner, or editor if not (current_user.has_edit_permission or current_user.is_dataset_editor): @@ -87,8 +83,7 @@ class TagUpdateDeleteApi(Resource): @login_required @account_initialization_required def delete(self, tag_id): - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + current_user, _ = current_account_with_tenant() tag_id = str(tag_id) # The role of the current user in the ta table must be admin, owner, or editor if not current_user.has_edit_permission: @@ -105,8 +100,7 @@ class TagBindingCreateApi(Resource): @login_required @account_initialization_required def post(self): - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + current_user, _ = current_account_with_tenant() # The role of the current 
user in the ta table must be admin, owner, editor, or dataset_operator if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() @@ -133,8 +127,7 @@ class TagBindingDeleteApi(Resource): @login_required @account_initialization_required def post(self): - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + current_user, _ = current_account_with_tenant() # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() diff --git a/api/controllers/console/workspace/__init__.py b/api/controllers/console/workspace/__init__.py index 4a048f3c5e..876e2301f2 100644 --- a/api/controllers/console/workspace/__init__.py +++ b/api/controllers/console/workspace/__init__.py @@ -2,11 +2,11 @@ from collections.abc import Callable from functools import wraps from typing import ParamSpec, TypeVar -from flask_login import current_user from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden from extensions.ext_database import db +from libs.login import current_account_with_tenant from models.account import TenantPluginPermission P = ParamSpec("P") @@ -20,8 +20,9 @@ def plugin_permission_required( def interceptor(view: Callable[P, R]): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): + current_user, current_tenant_id = current_account_with_tenant() user = current_user - tenant_id = user.current_tenant_id + tenant_id = current_tenant_id with Session(db.engine) as session: permission = ( diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index e2b0e3f84d..a5e6b8f473 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -2,7 +2,6 @@ from datetime import datetime import pytz from flask import request -from flask_login import current_user from flask_restx import Resource, fields, marshal_with, reqparse from sqlalchemy import select from sqlalchemy.orm import Session @@ -37,9 +36,8 @@ from extensions.ext_database import db from fields.member_fields import account_fields from libs.datetime_utils import naive_utc_now from libs.helper import TimestampField, email, extract_remote_ip, timezone -from libs.login import login_required -from models import AccountIntegrate, InvitationCode -from models.account import Account +from libs.login import current_account_with_tenant, login_required +from models import Account, AccountIntegrate, InvitationCode from services.account_service import AccountService from services.billing_service import BillingService from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError @@ -50,9 +48,7 @@ class AccountInitApi(Resource): @setup_required @login_required def post(self): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") - account = current_user + account, _ = current_account_with_tenant() if account.status == "active": raise AccountAlreadyInitedError() @@ -106,8 +102,7 @@ class AccountProfileApi(Resource): @marshal_with(account_fields) @enterprise_license_required def get(self): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, _ = current_account_with_tenant() return current_user @@ -118,8 +113,7 @@ class AccountNameApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): - 
if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, _ = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, location="json") args = parser.parse_args() @@ -140,8 +134,7 @@ class AccountAvatarApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, _ = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("avatar", type=str, required=True, location="json") args = parser.parse_args() @@ -158,8 +151,7 @@ class AccountInterfaceLanguageApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, _ = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("interface_language", type=supported_language, required=True, location="json") args = parser.parse_args() @@ -176,8 +168,7 @@ class AccountInterfaceThemeApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, _ = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json") args = parser.parse_args() @@ -194,8 +185,7 @@ class AccountTimezoneApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, _ = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("timezone", type=str, required=True, location="json") args = parser.parse_args() @@ -216,8 +206,7 @@ class AccountPasswordApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, _ = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("password", type=str, required=False, location="json") parser.add_argument("new_password", type=str, required=True, location="json") @@ -253,9 +242,7 @@ class AccountIntegrateApi(Resource): @account_initialization_required @marshal_with(integrate_list_fields) def get(self): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") - account = current_user + account, _ = current_account_with_tenant() account_integrates = db.session.scalars( select(AccountIntegrate).where(AccountIntegrate.account_id == account.id) @@ -298,9 +285,7 @@ class AccountDeleteVerifyApi(Resource): @login_required @account_initialization_required def get(self): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") - account = current_user + account, _ = current_account_with_tenant() token, code = AccountService.generate_account_deletion_verification_code(account) AccountService.send_account_deletion_verification_email(account, code) @@ -314,9 +299,7 @@ class AccountDeleteApi(Resource): @login_required @account_initialization_required def post(self): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") - account = current_user + account, _ = current_account_with_tenant() parser 
= reqparse.RequestParser() parser.add_argument("token", type=str, required=True, location="json") @@ -358,9 +341,7 @@ class EducationVerifyApi(Resource): @cloud_edition_billing_enabled @marshal_with(verify_fields) def get(self): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") - account = current_user + account, _ = current_account_with_tenant() return BillingService.EducationIdentity.verify(account.id, account.email) @@ -380,9 +361,7 @@ class EducationApi(Resource): @only_edition_cloud @cloud_edition_billing_enabled def post(self): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") - account = current_user + account, _ = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("token", type=str, required=True, location="json") @@ -399,9 +378,7 @@ class EducationApi(Resource): @cloud_edition_billing_enabled @marshal_with(status_fields) def get(self): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") - account = current_user + account, _ = current_account_with_tenant() res = BillingService.EducationIdentity.status(account.id) # convert expire_at to UTC timestamp from isoformat @@ -441,6 +418,7 @@ class ChangeEmailSendEmailApi(Resource): @login_required @account_initialization_required def post(self): + current_user, _ = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("email", type=email, required=True, location="json") parser.add_argument("language", type=str, required=False, location="json") @@ -467,8 +445,6 @@ class ChangeEmailSendEmailApi(Resource): raise InvalidTokenError() user_email = reset_data.get("email", "") - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") if user_email != current_user.email: raise InvalidEmailError() else: @@ -551,8 +527,7 @@ class ChangeEmailResetApi(Resource): AccountService.revoke_change_email_token(args["token"]) old_email = reset_data.get("old_email", "") - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, _ = current_account_with_tenant() if current_user.email != old_email: raise AccountNotFound() diff --git a/api/controllers/console/workspace/agent_providers.py b/api/controllers/console/workspace/agent_providers.py index e044b2db5b..0a8f49d2e5 100644 --- a/api/controllers/console/workspace/agent_providers.py +++ b/api/controllers/console/workspace/agent_providers.py @@ -3,8 +3,7 @@ from flask_restx import Resource, fields from controllers.console import api, console_ns from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder -from libs.login import current_user, login_required -from models.account import Account +from libs.login import current_account_with_tenant, login_required from services.agent_service import AgentService @@ -21,12 +20,11 @@ class AgentProviderListApi(Resource): @login_required @account_initialization_required def get(self): - assert isinstance(current_user, Account) + current_user, current_tenant_id = current_account_with_tenant() user = current_user - assert user.current_tenant_id is not None user_id = user.id - tenant_id = user.current_tenant_id + tenant_id = current_tenant_id return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id)) @@ -45,9 +43,5 @@ class AgentProviderApi(Resource): @login_required @account_initialization_required def get(self, provider_name: 
str): - assert isinstance(current_user, Account) - user = current_user - assert user.current_tenant_id is not None - user_id = user.id - tenant_id = user.current_tenant_id - return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name)) + current_user, current_tenant_id = current_account_with_tenant() + return jsonable_encoder(AgentService.get_agent_provider(current_user.id, current_tenant_id, provider_name)) diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index 99a1c1f032..4e6f1fa3a5 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -5,8 +5,8 @@ from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError -from libs.login import current_user, login_required -from models.account import Account, TenantAccountRole +from libs.login import current_account_with_tenant, login_required +from models import TenantAccountRole from services.model_load_balancing_service import ModelLoadBalancingService @@ -18,12 +18,11 @@ class LoadBalancingCredentialsValidateApi(Resource): @login_required @account_initialization_required def post(self, provider: str): - assert isinstance(current_user, Account) + current_user, current_tenant_id = current_account_with_tenant() if not TenantAccountRole.is_privileged_role(current_user.current_role): raise Forbidden() - tenant_id = current_user.current_tenant_id - assert tenant_id is not None + tenant_id = current_tenant_id parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="json") @@ -72,12 +71,11 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): @login_required @account_initialization_required def post(self, provider: str, config_id: str): - assert isinstance(current_user, Account) + current_user, current_tenant_id = current_account_with_tenant() if not TenantAccountRole.is_privileged_role(current_user.current_role): raise Forbidden() - tenant_id = current_user.current_tenant_id - assert tenant_id is not None + tenant_id = current_tenant_id parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="json") diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 4a0539785a..5be427e9bb 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -23,8 +23,8 @@ from controllers.console.wraps import ( ) from extensions.ext_database import db from libs.helper import TimestampField -from libs.login import current_user, login_required -from models.account import Account, Tenant, TenantStatus +from libs.login import current_account_with_tenant, login_required +from models.account import Tenant, TenantStatus from services.account_service import TenantService from services.feature_service import FeatureService from services.file_service import FileService @@ -70,8 +70,7 @@ class TenantListApi(Resource): @login_required @account_initialization_required def get(self): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, current_tenant_id = current_account_with_tenant() tenants = 
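
load_balancing_config.py gates credential validation on `TenantAccountRole.is_privileged_role(current_user.current_role)` instead of an `isinstance` assertion. The real enum lives in models; the sketch below assumes, from the method name and the admin/owner comments elsewhere in this series, that "privileged" means owner or admin:

    from enum import Enum


    class TenantAccountRole(str, Enum):  # stand-in; the membership below is an assumption
        OWNER = "owner"
        ADMIN = "admin"
        EDITOR = "editor"
        NORMAL = "normal"

        @classmethod
        def is_privileged_role(cls, role: "TenantAccountRole | None") -> bool:
            return role in {cls.OWNER, cls.ADMIN}


    assert TenantAccountRole.is_privileged_role(TenantAccountRole.ADMIN)
    assert not TenantAccountRole.is_privileged_role(TenantAccountRole.EDITOR)
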
TenantService.get_join_tenants(current_user) tenant_dicts = [] @@ -85,7 +84,7 @@ class TenantListApi(Resource): "status": tenant.status, "created_at": tenant.created_at, "plan": features.billing.subscription.plan if features.billing.enabled else "sandbox", - "current": tenant.id == current_user.current_tenant_id if current_user.current_tenant_id else False, + "current": tenant.id == current_tenant_id if current_tenant_id else False, } tenant_dicts.append(tenant_dict) @@ -130,8 +129,7 @@ class TenantApi(Resource): if request.path == "/info": logger.warning("Deprecated URL /info was used.") - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, _ = current_account_with_tenant() tenant = current_user.current_tenant if not tenant: raise ValueError("No current tenant") @@ -155,8 +153,7 @@ class SwitchWorkspaceApi(Resource): @login_required @account_initialization_required def post(self): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, _ = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("tenant_id", type=str, required=True, location="json") args = parser.parse_args() @@ -181,16 +178,12 @@ class CustomConfigWorkspaceApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("workspace_custom") def post(self): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + _, current_tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("remove_webapp_brand", type=bool, location="json") parser.add_argument("replace_webapp_logo", type=str, location="json") args = parser.parse_args() - - if not current_user.current_tenant_id: - raise ValueError("No current tenant") - tenant = db.get_or_404(Tenant, current_user.current_tenant_id) + tenant = db.get_or_404(Tenant, current_tenant_id) custom_config_dict = { "remove_webapp_brand": args["remove_webapp_brand"], @@ -212,8 +205,7 @@ class WebappLogoWorkspaceApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("workspace_custom") def post(self): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, _ = current_account_with_tenant() # check file if "file" not in request.files: raise NoFileUploadedError() @@ -253,15 +245,14 @@ class WorkspaceInfoApi(Resource): @account_initialization_required # Change workspace name def post(self): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + _, current_tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, location="json") args = parser.parse_args() - if not current_user.current_tenant_id: + if not current_tenant_id: raise ValueError("No current tenant") - tenant = db.get_or_404(Tenant, current_user.current_tenant_id) + tenant = db.get_or_404(Tenant, current_tenant_id) tenant.name = args["name"] db.session.commit() diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 4158f0524f..2fa28711c3 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -30,10 +30,7 @@ def account_initialization_required(view: Callable[P, R]): def decorated(*args: P.args, **kwargs: P.kwargs): # check account initialization current_user, _ = current_account_with_tenant() - - account = current_user - - if account.status == AccountStatus.UNINITIALIZED: + if 
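
workspace.py replaces the old fetch-then-validate dance with Flask-SQLAlchemy's `db.get_or_404`, which turns a missing row into an HTTP 404 at the ORM boundary. A minimal demo with a stand-in model:

    from flask import Flask
    from flask_sqlalchemy import SQLAlchemy

    app = Flask(__name__)
    app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite://"
    db = SQLAlchemy(app)


    class Tenant(db.Model):  # stand-in for models.account.Tenant
        id = db.Column(db.String, primary_key=True)
        name = db.Column(db.String)


    with app.app_context():
        db.create_all()
        db.session.add(Tenant(id="t1", name="demo"))
        db.session.commit()
        print(db.get_or_404(Tenant, "t1").name)  # "demo"
        # db.get_or_404(Tenant, "missing") raises NotFound -> 404 in a request
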
current_user.status == AccountStatus.UNINITIALIZED: raise AccountNotInitializedError() return view(*args, **kwargs) @@ -249,9 +246,9 @@ def email_password_login_enabled(view: Callable[P, R]): return decorated -def email_register_enabled(view): +def email_register_enabled(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): features = FeatureService.get_system_features() if features.is_allow_register: return view(*args, **kwargs) @@ -299,3 +296,16 @@ def knowledge_pipeline_publish_enabled(view: Callable[P, R]): abort(403) return decorated + + +def edit_permission_required(f: Callable[P, R]): + @wraps(f) + def decorated_function(*args: P.args, **kwargs: P.kwargs): + from werkzeug.exceptions import Forbidden + + current_user, _ = current_account_with_tenant() + if not current_user.has_edit_permission: + raise Forbidden() + return f(*args, **kwargs) + + return decorated_function diff --git a/api/controllers/inner_api/mail.py b/api/controllers/inner_api/mail.py index 0b2be03e43..39411a077a 100644 --- a/api/controllers/inner_api/mail.py +++ b/api/controllers/inner_api/mail.py @@ -17,7 +17,7 @@ class BaseMail(Resource): def post(self): args = _mail_parser.parse_args() - send_inner_email_task.delay( + send_inner_email_task.delay( # type: ignore to=args["to"], subject=args["subject"], body=args["body"], diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index deab50076d..e4fe8d44bf 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -31,7 +31,7 @@ from core.plugin.entities.request import ( ) from core.tools.entities.tool_entities import ToolProviderType from libs.helper import length_prefixed_response -from models.account import Account, Tenant +from models import Account, Tenant from models.model import EndUser diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py index 47f0240cd2..861da57708 100644 --- a/api/controllers/inner_api/workspace/workspace.py +++ b/api/controllers/inner_api/workspace/workspace.py @@ -7,7 +7,7 @@ from controllers.inner_api import inner_api_ns from controllers.inner_api.wraps import enterprise_inner_api_only from events.tenant_event import tenant_was_created from extensions.ext_database import db -from models.account import Account +from models import Account from services.account_service import TenantService diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index ad1bdc7334..0521f1537c 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -10,7 +10,7 @@ from controllers.service_api.wraps import validate_app_token from extensions.ext_redis import redis_client from fields.annotation_fields import annotation_fields, build_annotation_model from libs.login import current_user -from models.account import Account +from models import Account from models.model import App from services.annotation_service import AppAnnotationService diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py index 13ef8abc2d..38891f0180 100644 --- a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py @@ -17,7 +17,7 @@ from 
core.app.apps.pipeline.pipeline_generator import PipelineGenerator from core.app.entities.app_invoke_entities import InvokeFrom from libs import helper from libs.login import current_user -from models.account import Account +from models import Account from models.dataset import Pipeline from models.engine import db from services.errors.file import FileTooLargeError, UnsupportedFileTypeError diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 2c9be4e887..638ab528f3 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -17,7 +17,7 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from libs.login import current_user -from models.account import Account, Tenant, TenantAccountJoin, TenantStatus +from models import Account, Tenant, TenantAccountJoin, TenantStatus from models.dataset import Dataset, RateLimitLog from models.model import ApiToken, App, DefaultEndUserSessionID, EndUser from services.feature_service import FeatureService diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py index c743d0f52b..cbafd70e99 100644 --- a/api/controllers/web/forgot_password.py +++ b/api/controllers/web/forgot_password.py @@ -20,7 +20,7 @@ from controllers.web import web_ns from extensions.ext_database import db from libs.helper import email, extract_remote_ip from libs.password import hash_password, valid_password -from models.account import Account +from models import Account from services.account_service import AccountService diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index e021b0aca7..b5af6382e8 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -70,8 +70,7 @@ from core.workflow.system_variable import SystemVariable from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager from extensions.ext_database import db from libs.datetime_utils import naive_utc_now -from models import Conversation, EndUser, Message, MessageFile -from models.account import Account +from models import Account, Conversation, EndUser, Message, MessageFile from models.enums import CreatorUserRole from models.workflow import Workflow diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 8bd956b314..c1251d2feb 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -23,7 +23,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory -from models.account import Account +from models import Account from models.model import App, EndUser from services.conversation_service import ConversationService diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 56b0d91141..ec4dc87643 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -61,7 +61,7 @@ from core.workflow.repositories.workflow_node_execution_repository import Workfl from core.workflow.system_variable import SystemVariable from core.workflow.workflow_cycle_manager import 
CycleManagerWorkflowInfo, WorkflowCycleManager from extensions.ext_database import db -from models.account import Account +from models import Account from models.enums import CreatorUserRole from models.model import EndUser from models.workflow import ( diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 526e78f1d0..7db9b076d2 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -913,4 +913,4 @@ class TraceQueueManager: "file_id": file_id, "app_id": task.app_id, } - process_trace_tasks.delay(file_info) + process_trace_tasks.delay(file_info) # type: ignore diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index 8b08b09eb9..32ac132e1e 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -14,7 +14,7 @@ from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.entities.app_invoke_entities import InvokeFrom from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from extensions.ext_database import db -from models.account import Account +from models import Account from models.model import App, AppMode, EndUser diff --git a/api/core/repositories/celery_workflow_execution_repository.py b/api/core/repositories/celery_workflow_execution_repository.py index eda7b54d6a..460bb75722 100644 --- a/api/core/repositories/celery_workflow_execution_repository.py +++ b/api/core/repositories/celery_workflow_execution_repository.py @@ -108,7 +108,7 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository): execution_data = execution.model_dump() # Queue the save operation as a Celery task (fire and forget) - save_workflow_execution_task.delay( + save_workflow_execution_task.delay( # type: ignore execution_data=execution_data, tenant_id=self._tenant_id, app_id=self._app_id or "", diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 0851a54338..ca2aa39861 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -12,7 +12,7 @@ from core.file import File, FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool_file_manager import ToolFileManager from libs.login import current_user -from models.account import Account +from models import Account logger = logging.getLogger(__name__) diff --git a/api/events/event_handlers/clean_when_dataset_deleted.py b/api/events/event_handlers/clean_when_dataset_deleted.py index 7caa2d1cc9..0f6aa0e778 100644 --- a/api/events/event_handlers/clean_when_dataset_deleted.py +++ b/api/events/event_handlers/clean_when_dataset_deleted.py @@ -1,10 +1,13 @@ from events.dataset_event import dataset_was_deleted +from models import Dataset from tasks.clean_dataset_task import clean_dataset_task @dataset_was_deleted.connect -def handle(sender, **kwargs): +def handle(sender: Dataset, **kwargs): dataset = sender + assert dataset.doc_form + assert dataset.indexing_technique clean_dataset_task.delay( dataset.id, dataset.tenant_id, diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index 5571c0d9ba..836a5d938c 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -9,7 +9,7 @@ from configs import dify_config from dify_app import DifyApp from extensions.ext_database import db from libs.passport import PassportService -from models.account import Account, Tenant, 
TenantAccountJoin +from models import Account, Tenant, TenantAccountJoin from models.model import AppMCPServer, EndUser from services.account_service import AccountService diff --git a/api/libs/external_api.py b/api/libs/external_api.py index 25a82f8a96..a59230caaa 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -22,7 +22,7 @@ def register_external_error_handlers(api: Api): got_request_exception.send(current_app, exception=e) # If Werkzeug already prepared a Response, just use it. - if getattr(e, "response", None) is not None: + if e.response is not None: return e.response status_code = getattr(e, "code", 500) or 500 @@ -106,7 +106,7 @@ def register_external_error_handlers(api: Api): # Log stack exc_info: Any = sys.exc_info() if exc_info[1] is None: - exc_info = None + exc_info = (None, None, None) current_app.log_exception(exc_info) return data, status_code diff --git a/api/libs/helper.py b/api/libs/helper.py index 0551470f65..b878141d8e 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -24,7 +24,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_redis import redis_client if TYPE_CHECKING: - from models.account import Account + from models import Account from models.model import EndUser logger = logging.getLogger(__name__) @@ -43,7 +43,7 @@ def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None: Raises: ValueError: If user is neither Account nor EndUser """ - from models.account import Account + from models import Account from models.model import EndUser if isinstance(user, Account): @@ -78,7 +78,7 @@ class AvatarUrlField(fields.Raw): if obj is None: return None - from models.account import Account + from models import Account if isinstance(obj, Account) and obj.avatar is not None: return file_helpers.get_signed_file_url(obj.avatar) diff --git a/api/libs/login.py b/api/libs/login.py index 24b8c4011a..2c75ef9297 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -7,7 +7,7 @@ from flask_login.config import EXEMPT_METHODS # type: ignore from werkzeug.local import LocalProxy from configs import dify_config -from models.account import Account +from models import Account from models.model import EndUser #: A proxy for the current user. 
If no user is logged in, this will be an diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py index ef6edd6709..b70707b17e 100644 --- a/api/schedule/mail_clean_document_notify_task.py +++ b/api/schedule/mail_clean_document_notify_task.py @@ -10,7 +10,7 @@ from configs import dify_config from extensions.ext_database import db from extensions.ext_mail import mail from libs.email_i18n import EmailType, get_email_i18n_service -from models.account import Account, Tenant, TenantAccountJoin +from models import Account, Tenant, TenantAccountJoin from models.dataset import Dataset, DatasetAutoDisableLog from services.feature_service import FeatureService diff --git a/api/services/agent_service.py b/api/services/agent_service.py index d631ce812f..b2db895a5a 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -10,7 +10,7 @@ from core.plugin.impl.exc import PluginDaemonClientSideError from core.tools.tool_manager import ToolManager from extensions.ext_database import db from libs.login import current_user -from models.account import Account +from models import Account from models.model import App, Conversation, EndUser, Message, MessageAgentThought diff --git a/api/services/app_service.py b/api/services/app_service.py index 4fc6cf2494..5f8c5089c9 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -18,7 +18,7 @@ from events.app_event import app_was_created from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from libs.login import current_user -from models.account import Account +from models import Account from models.model import App, AppMode, AppModelConfig, Site from models.tools import ApiToolProvider from services.billing_service import BillingService diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 9d6c5b4b31..a6851d2638 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -7,7 +7,7 @@ from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fix from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.helper import RateLimiter -from models.account import Account, TenantAccountJoin, TenantAccountRole +from models import Account, TenantAccountJoin, TenantAccountRole class BillingService: diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index a8e51a426d..39d6c81621 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -14,8 +14,7 @@ from extensions.ext_database import db from factories import variable_factory from libs.datetime_utils import naive_utc_now from libs.infinite_scroll_pagination import InfiniteScrollPagination -from models import ConversationVariable -from models.account import Account +from models import Account, ConversationVariable from models.model import App, Conversation, EndUser, Message from services.errors.conversation import ( ConversationNotExistsError, diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 53216e4fdd..f4047da6b8 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -29,7 +29,7 @@ from extensions.ext_redis import redis_client from libs import helper from libs.datetime_utils import naive_utc_now from libs.login import current_user -from models.account import Account, TenantAccountRole +from models import Account, TenantAccountRole from models.dataset 
import ( AppDatasetJoin, ChildChunk, diff --git a/api/services/file_service.py b/api/services/file_service.py index f0bb68766d..dd6a829ea2 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -19,7 +19,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from libs.helper import extract_tenant_id -from models.account import Account +from models import Account from models.enums import CreatorUserRole from models.model import EndUser, UploadFile diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index c6ea35076e..7fa82c6d22 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -9,7 +9,7 @@ from core.rag.models.document import Document from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db -from models.account import Account +from models import Account from models.dataset import Dataset, DatasetQuery logger = logging.getLogger(__name__) diff --git a/api/services/message_service.py b/api/services/message_service.py index 5e356bf925..9fdff18622 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -12,7 +12,7 @@ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination -from models.account import Account +from models import Account from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback from services.conversation_service import ConversationService from services.errors.message import ( diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py index 6add830813..5f280c9e57 100644 --- a/api/services/metadata_service.py +++ b/api/services/metadata_service.py @@ -1,12 +1,11 @@ import copy import logging -from flask_login import current_user - from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now +from libs.login import current_account_with_tenant from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding from services.dataset_service import DocumentService from services.entities.knowledge_entities.knowledge_entities import ( @@ -23,11 +22,11 @@ class MetadataService: # check if metadata name is too long if len(metadata_args.name) > 255: raise ValueError("Metadata name cannot exceed 255 characters.") - + current_user, current_tenant_id = current_account_with_tenant() # check if metadata name already exists if ( db.session.query(DatasetMetadata) - .filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name) + .filter_by(tenant_id=current_tenant_id, dataset_id=dataset_id, name=metadata_args.name) .first() ): raise ValueError("Metadata name already exists.") @@ -35,7 +34,7 @@ class MetadataService: if field.value == metadata_args.name: raise ValueError("Metadata name already exists in Built-in fields.") metadata = DatasetMetadata( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, dataset_id=dataset_id, type=metadata_args.type, name=metadata_args.name, @@ -53,9 +52,10 @@ class MetadataService: 
lock_key = f"dataset_metadata_lock_{dataset_id}" # check if metadata name already exists + current_user, current_tenant_id = current_account_with_tenant() if ( db.session.query(DatasetMetadata) - .filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name) + .filter_by(tenant_id=current_tenant_id, dataset_id=dataset_id, name=name) .first() ): raise ValueError("Metadata name already exists.") @@ -220,9 +220,10 @@ class MetadataService: db.session.commit() # deal metadata binding db.session.query(DatasetMetadataBinding).filter_by(document_id=operation.document_id).delete() + current_user, current_tenant_id = current_account_with_tenant() for metadata_value in operation.metadata_list: dataset_metadata_binding = DatasetMetadataBinding( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, dataset_id=dataset.id, document_id=operation.document_id, metadata_id=metadata_value.id, diff --git a/api/services/oauth_server.py b/api/services/oauth_server.py index b722dbee22..b05b43d76e 100644 --- a/api/services/oauth_server.py +++ b/api/services/oauth_server.py @@ -7,7 +7,7 @@ from werkzeug.exceptions import BadRequest from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.account import Account +from models import Account from models.model import OAuthProviderApp from services.account_service import AccountService diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py index ca871bcaa1..4ac2e0792b 100644 --- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py @@ -1,7 +1,7 @@ import yaml -from flask_login import current_user from extensions.ext_database import db +from libs.login import current_account_with_tenant from models.dataset import PipelineCustomizedTemplate from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType @@ -13,9 +13,8 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ def get_pipeline_templates(self, language: str) -> dict: - result = self.fetch_pipeline_templates_from_customized( - tenant_id=current_user.current_tenant_id, language=language - ) + _, current_tenant_id = current_account_with_tenant() + result = self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language) return result def get_pipeline_template_detail(self, template_id: str): diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 13c0ca7392..d2ba462a37 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -54,7 +54,7 @@ from core.workflow.system_variable import SystemVariable from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination -from models.account import Account +from models import Account from models.dataset import ( # type: ignore Dataset, Document, diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py index 67a0106bbd..4dd6c8107b 100644 --- a/api/services/saved_message_service.py +++ b/api/services/saved_message_service.py @@ -2,7 +2,7 @@ from typing 
import Union from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination -from models.account import Account +from models import Account from models.model import App, EndUser from models.web import SavedMessage from services.message_service import MessageService diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index 0f54e838f3..560aec2330 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination -from models.account import Account +from models import Account from models.model import App, EndUser from models.web import PinnedConversation from services.conversation_service import ConversationService diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index d30e14f7a1..693bfb95b6 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -10,7 +10,7 @@ from extensions.ext_database import db from libs.helper import TokenManager from libs.passport import PassportService from libs.password import compare_password -from models.account import Account, AccountStatus +from models import Account, AccountStatus from models.model import App, EndUser, Site from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 9c09f54bf5..e70b2b5c95 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -22,7 +22,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.workflow.nodes import NodeType from events.app_event import app_was_created from extensions.ext_database import db -from models.account import Account +from models import Account from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import App, AppMode, AppModelConfig from models.workflow import Workflow, WorkflowType diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 344b7486ee..5e63a83bb1 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -32,8 +32,7 @@ from factories.file_factory import StorageKeyLoader from factories.variable_factory import build_segment, segment_to_variable from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 -from models import App, Conversation -from models.account import Account +from models import Account, App, Conversation from models.enums import DraftVariableType from models.workflow import Workflow, WorkflowDraftVariable, WorkflowDraftVariableFile, is_system_variable_editable from repositories.factory import DifyAPIRepositoryFactory diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index dea6a657a4..f765c229ab 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -30,7 +30,7 @@ from extensions.ext_database import db from extensions.ext_storage import storage from factories.file_factory import build_from_mapping, build_from_mappings from libs.datetime_utils import naive_utc_now -from 
models.account import Account +from models import Account from models.model import App, AppMode from models.tools import WorkflowToolProvider from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType diff --git a/api/tasks/delete_account_task.py b/api/tasks/delete_account_task.py index 611aef86ad..fb5eb1d691 100644 --- a/api/tasks/delete_account_task.py +++ b/api/tasks/delete_account_task.py @@ -3,7 +3,7 @@ import logging from celery import shared_task from extensions.ext_database import db -from models.account import Account +from models import Account from services.billing_service import BillingService from tasks.mail_account_deletion_task import send_deletion_success_task diff --git a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py index 4171656131..6de95a3b85 100644 --- a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py @@ -16,7 +16,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerat from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity from core.repositories.factory import DifyCoreRepositoryFactory from extensions.ext_database import db -from models.account import Account, Tenant +from models import Account, Tenant from models.dataset import Pipeline from models.enums import WorkflowRunTriggeredFrom from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py b/api/tasks/rag_pipeline/rag_pipeline_run_task.py index 90ebe80daf..f4a092d97e 100644 --- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py @@ -17,7 +17,7 @@ from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEnti from core.repositories.factory import DifyCoreRepositoryFactory from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.account import Account, Tenant +from models import Account, Tenant from models.dataset import Pipeline from models.enums import WorkflowRunTriggeredFrom from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py index 9c12696824..9d208647e6 100644 --- a/api/tasks/retry_document_indexing_task.py +++ b/api/tasks/retry_document_indexing_task.py @@ -10,7 +10,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now -from models.account import Account, Tenant +from models import Account, Tenant from models.dataset import Dataset, Document, DocumentSegment from services.feature_service import FeatureService from services.rag_pipeline.rag_pipeline import RagPipelineService diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index 6eff73a8f3..c59fc50f08 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -8,7 +8,7 @@ from werkzeug.exceptions import Unauthorized from configs import dify_config from controllers.console.error import AccountNotFound, 
NotAllowedCreateWorkspace -from models.account import AccountStatus, TenantAccountJoin +from models import AccountStatus, TenantAccountJoin from services.account_service import AccountService, RegisterService, TenantService, TokenPair from services.errors.account import ( AccountAlreadyInTenantError, @@ -470,7 +470,7 @@ class TestAccountService: # Verify integration was created from extensions.ext_database import db - from models.account import AccountIntegrate + from models import AccountIntegrate integration = db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider="new-google").first() assert integration is not None @@ -505,7 +505,7 @@ class TestAccountService: # Verify integration was updated from extensions.ext_database import db - from models.account import AccountIntegrate + from models import AccountIntegrate integration = ( db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider="exists-google").first() @@ -2303,7 +2303,7 @@ class TestRegisterService: # Verify account was created from extensions.ext_database import db - from models.account import Account + from models import Account from models.model import DifySetup account = db.session.query(Account).filter_by(email=admin_email).first() @@ -2352,7 +2352,7 @@ class TestRegisterService: # Verify no entities were created (rollback worked) from extensions.ext_database import db - from models.account import Account, Tenant, TenantAccountJoin + from models import Account, Tenant, TenantAccountJoin from models.model import DifySetup account = db.session.query(Account).filter_by(email=admin_email).first() @@ -2446,7 +2446,7 @@ class TestRegisterService: # Verify OAuth integration was created from extensions.ext_database import db - from models.account import AccountIntegrate + from models import AccountIntegrate integration = db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first() assert integration is not None @@ -2472,7 +2472,7 @@ class TestRegisterService: mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False # Execute registration with pending status - from models.account import AccountStatus + from models import AccountStatus account = RegisterService.register( email=email, @@ -2661,7 +2661,7 @@ class TestRegisterService: # Verify new account was created with pending status from extensions.ext_database import db - from models.account import Account, TenantAccountJoin + from models import Account, TenantAccountJoin new_account = db.session.query(Account).filter_by(email=new_member_email).first() assert new_account is not None diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py index c572ddc925..ca513319b2 100644 --- a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -5,7 +5,7 @@ import pytest from faker import Faker from core.plugin.impl.exc import PluginDaemonClientSideError -from models.account import Account +from models import Account from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought from services.account_service import AccountService, TenantService from services.agent_service import AgentService diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py 
b/api/tests/test_containers_integration_tests/services/test_annotation_service.py index 4768b981cc..2b03ec1c26 100644 --- a/api/tests/test_containers_integration_tests/services/test_annotation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -4,7 +4,7 @@ import pytest from faker import Faker from werkzeug.exceptions import NotFound -from models.account import Account +from models import Account from models.model import MessageAnnotation from services.annotation_service import AppAnnotationService from services.app_service import AppService diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index cbbbbddb21..e53392bcef 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -4,7 +4,7 @@ import pytest from faker import Faker from constants.model_template import default_app_templates -from models.account import Account +from models import Account from models.model import App, Site from services.account_service import AccountService, TenantService from services.app_service import AppService diff --git a/api/tests/test_containers_integration_tests/services/test_file_service.py b/api/tests/test_containers_integration_tests/services/test_file_service.py index e6bfc157c7..4c94e42f3e 100644 --- a/api/tests/test_containers_integration_tests/services/test_file_service.py +++ b/api/tests/test_containers_integration_tests/services/test_file_service.py @@ -8,7 +8,7 @@ from sqlalchemy import Engine from werkzeug.exceptions import NotFound from configs import dify_config -from models.account import Account, Tenant +from models import Account, Tenant from models.enums import CreatorUserRole from models.model import EndUser, UploadFile from services.errors.file import FileTooLargeError, UnsupportedFileTypeError diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_service.py b/api/tests/test_containers_integration_tests/services/test_metadata_service.py index 253791cc2d..c8ced3f3a5 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_service.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_service.py @@ -4,7 +4,7 @@ import pytest from faker import Faker from core.rag.index_processor.constant.built_in_field import BuiltInField -from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding, Document from services.entities.knowledge_entities.knowledge_entities import MetadataArgs from services.metadata_service import MetadataService @@ -17,9 +17,7 @@ class TestMetadataService: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch( - "services.metadata_service.current_user", create_autospec(Account, instance=True) - ) as mock_current_user, + patch("libs.login.current_user", create_autospec(Account, instance=True)) as mock_current_user, patch("services.metadata_service.redis_client") as mock_redis_client, patch("services.dataset_service.DocumentService") as mock_document_service, ): diff --git a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py 
b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py index fb319a4963..8cb3572c47 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py @@ -5,7 +5,7 @@ from faker import Faker from core.entities.model_entities import ModelStatus from core.model_runtime.entities.model_entities import FetchFrom, ModelType -from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.provider import Provider, ProviderModel, ProviderModelSetting, ProviderType from services.model_provider_service import ModelProviderService diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index 3d1226019b..6732b8d558 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -5,7 +5,7 @@ from faker import Faker from sqlalchemy import select from werkzeug.exceptions import NotFound -from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset from models.model import App, Tag, TagBinding from services.tag_service import TagService diff --git a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py index 5db7901cbc..bbbf48ede9 100644 --- a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py @@ -5,7 +5,7 @@ from faker import Faker from sqlalchemy import select from core.app.entities.app_invoke_entities import InvokeFrom -from models.account import Account +from models import Account from models.model import Conversation, EndUser from models.web import PinnedConversation from services.account_service import AccountService, TenantService diff --git a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py index 059767458a..9fc16d9eb7 100644 --- a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py @@ -7,7 +7,7 @@ from faker import Faker from werkzeug.exceptions import NotFound, Unauthorized from libs.password import hash_password -from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole from models.model import App, Site from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError from services.webapp_auth_service import WebAppAuthService, WebAppAuthType diff --git a/api/tests/test_containers_integration_tests/services/test_workspace_service.py b/api/tests/test_containers_integration_tests/services/test_workspace_service.py index 814d1908bd..4249642bc9 100644 --- a/api/tests/test_containers_integration_tests/services/test_workspace_service.py +++ 
b/api/tests/test_containers_integration_tests/services/test_workspace_service.py @@ -3,7 +3,7 @@ from unittest.mock import patch import pytest from faker import Faker -from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from services.workspace_service import WorkspaceService diff --git a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py index 7366b08439..0871467a05 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py @@ -3,7 +3,7 @@ from unittest.mock import patch import pytest from faker import Faker -from models.account import Account, Tenant +from models import Account, Tenant from models.tools import ApiToolProvider from services.tools.api_tools_manage_service import ApiToolManageService diff --git a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py index f7a4c53318..71d55c3ade 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py @@ -4,7 +4,7 @@ import pytest from faker import Faker from core.tools.entities.tool_entities import ToolProviderType -from models.account import Account, Tenant +from models import Account, Tenant from models.tools import MCPToolProvider from services.tools.mcp_tools_manage_service import UNCHANGED_SERVER_URL_PLACEHOLDER, MCPToolManageService diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py index 88aa0b6e72..2c5e719a58 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py @@ -15,7 +15,7 @@ from core.app.app_config.entities import ( ) from core.model_runtime.entities.llm_entities import LLMMode from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from models.account import Account, Tenant +from models import Account, Tenant from models.api_based_extension import APIBasedExtension from models.model import App, AppMode, AppModelConfig from models.workflow import Workflow diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index 96e673d855..68e485107c 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -6,7 +6,7 @@ from faker import Faker from core.rag.index_processor.constant.index_type import IndexType from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetAutoDisableLog, Document, DocumentSegment from 
tasks.add_document_to_index_task import add_document_to_index_task diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py index 8628e2af7f..f94c5b19e6 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py @@ -14,7 +14,7 @@ from faker import Faker from extensions.ext_database import db from libs.datetime_utils import naive_utc_now -from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment from models.model import UploadFile from tasks.batch_clean_document_task import batch_clean_document_task diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index a9cfb6ffd4..1b844d6357 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -18,7 +18,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker -from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment from models.enums import CreatorUserRole from models.model import UploadFile diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index 99061d215f..45eb9d4f78 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -17,7 +17,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker -from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( AppDatasetJoin, Dataset, diff --git a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py index 987ebf8aca..8004175b2d 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py @@ -13,7 +13,7 @@ import pytest from faker import Faker from extensions.ext_redis import redis_client -from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment from tasks.create_segment_to_index_task import create_segment_to_index_task diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py index bc3701d098..8785c948d1 100644 --- 
a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py @@ -16,7 +16,7 @@ from faker import Faker from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment from tasks.disable_segment_from_index_task import disable_segment_from_index_task diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py index a315577b78..448f6da5ec 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py @@ -4,7 +4,7 @@ import pytest from faker import Faker from extensions.ext_database import db -from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document from tasks.document_indexing_task import document_indexing_task From 6432898e7a0b9a79720b43263bae753662c649dd Mon Sep 17 00:00:00 2001 From: GuanMu Date: Thu, 16 Oct 2025 15:51:39 +0800 Subject: [PATCH 05/46] refactor: update TypeScript definitions for custom JSX elements and clean up global declarations in emoji picker (#26985) --- web/app/components/base/emoji-picker/Inner.tsx | 10 ---------- web/global.d.ts | 5 ++++- web/types/jsx.d.ts | 13 +++++++++++++ 3 files changed, 17 insertions(+), 11 deletions(-) create mode 100644 web/types/jsx.d.ts diff --git a/web/app/components/base/emoji-picker/Inner.tsx b/web/app/components/base/emoji-picker/Inner.tsx index 6c747c1583..6299ea7aef 100644 --- a/web/app/components/base/emoji-picker/Inner.tsx +++ b/web/app/components/base/emoji-picker/Inner.tsx @@ -14,16 +14,6 @@ import Divider from '@/app/components/base/divider' import { searchEmoji } from '@/utils/emoji' import cn from '@/utils/classnames' -declare global { - // eslint-disable-next-line ts/no-namespace - namespace JSX { - // eslint-disable-next-line ts/consistent-type-definitions - interface IntrinsicElements { - 'em-emoji': React.DetailedHTMLProps, HTMLElement> - } - } -} - init({ data }) const backgroundColors = [ diff --git a/web/global.d.ts b/web/global.d.ts index eb39fe0c39..20b84a5327 100644 --- a/web/global.d.ts +++ b/web/global.d.ts @@ -1,3 +1,6 @@ +import './types/i18n' +import './types/jsx' + declare module 'lamejs'; declare module 'lamejs/src/js/MPEGMode'; declare module 'lamejs/src/js/Lame'; @@ -9,4 +12,4 @@ declare module '*.mdx' { export default MDXComponent } -import './types/i18n' +export {} diff --git a/web/types/jsx.d.ts b/web/types/jsx.d.ts new file mode 100644 index 0000000000..41f1b63250 --- /dev/null +++ b/web/types/jsx.d.ts @@ -0,0 +1,13 @@ +// TypeScript type definitions for custom JSX elements +// Custom JSX elements for emoji-mart web components + +import 'react' + +declare module 'react' { + namespace JSX { + // eslint-disable-next-line ts/consistent-type-definitions + interface IntrinsicElements { + 'em-emoji': React.DetailedHTMLProps, HTMLElement> + } + } +} From 8b61f5e9c4129f0cf40db8c9c9d55d36fe4eef09 Mon Sep 17 00:00:00 2001 From: Yongtao Huang Date: Thu, 16 Oct 2025 15:53:07 +0800 
Subject: [PATCH 06/46] Fix: avoid duplicate response_chunk update in `convert_stream_simple_response` (#26965) Signed-off-by: Yongtao Huang Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> --- api/core/app/apps/completion/generate_response_converter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/app/apps/completion/generate_response_converter.py b/api/core/app/apps/completion/generate_response_converter.py index d7e9ebdf24..a4f574642d 100644 --- a/api/core/app/apps/completion/generate_response_converter.py +++ b/api/core/app/apps/completion/generate_response_converter.py @@ -112,7 +112,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): metadata = {} sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) - if isinstance(sub_stream_response, ErrorStreamResponse): + elif isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: From 06649f6c2171f3212c224a264d59ce4925dc40c0 Mon Sep 17 00:00:00 2001 From: Xiyuan Chen <52963600+GareArc@users.noreply.github.com> Date: Thu, 16 Oct 2025 01:42:22 -0700 Subject: [PATCH 07/46] =?UTF-8?q?Update=20email=20templates=20to=20improve?= =?UTF-8?q?=20clarity=20and=20consistency=20in=20messagin=E2=80=A6=20(#269?= =?UTF-8?q?70)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...hange_mail_confirm_new_template_en-US.html | 14 ++++++----- ...hange_mail_confirm_new_template_zh-CN.html | 14 ++++++----- ...hange_mail_confirm_old_template_en-US.html | 14 ++++++----- ...hange_mail_confirm_old_template_zh-CN.html | 14 ++++++----- .../invite_member_mail_template_en-US.html | 23 +++++++++++++++---- ...space_new_owner_notify_template_en-US.html | 11 +++++---- ...space_new_owner_notify_template_zh-CN.html | 11 +++++---- ...space_old_owner_notify_template_en-US.html | 21 ++++++++++------- ...space_old_owner_notify_template_zh-CN.html | 14 ++++++----- 9 files changed, 83 insertions(+), 53 deletions(-) diff --git a/api/templates/without-brand/change_mail_confirm_new_template_en-US.html b/api/templates/without-brand/change_mail_confirm_new_template_en-US.html index 69a8978f42..861b1bcdb6 100644 --- a/api/templates/without-brand/change_mail_confirm_new_template_en-US.html +++ b/api/templates/without-brand/change_mail_confirm_new_template_en-US.html @@ -42,7 +42,8 @@ font-family: Inter; font-style: normal; font-weight: 600; - line-height: 120%; /* 28.8px */ + line-height: 120%; + /* 28.8px */ } .description { @@ -51,7 +52,8 @@ font-family: Inter; font-style: normal; font-weight: 400; - line-height: 20px; /* 142.857% */ + line-height: 20px; + /* 142.857% */ letter-spacing: -0.07px; } @@ -96,7 +98,8 @@ font-family: Inter; font-style: normal; font-weight: 400; - line-height: 20px; /* 142.857% */ + line-height: 20px; + /* 142.857% */ letter-spacing: -0.07px; } @@ -107,7 +110,7 @@

Confirm Your New Email Address

-

You’re updating the email address linked to your Dify account.

+

You're updating the email address linked to your account.

To confirm this action, please use the verification code below.

This code will only be valid for the next 5 minutes:

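The copy above promises a verification code that stays valid for five minutes. This patch only touches the template text; how that window is enforced server-side is not shown. The sketch below is a minimal illustration of that counterpart, assuming an in-memory store and hypothetical helper names (`issue_code`, `verify_code`); a production service would more likely keep the code in a shared cache such as Redis with a server-side TTL.

```python
import secrets
import time

CODE_TTL_SECONDS = 5 * 60  # mirrors the five-minute window stated in the template

# Demo-only store mapping email -> (code, expiry timestamp).
_pending_codes: dict[str, tuple[str, float]] = {}


def issue_code(email: str) -> str:
    """Create a six-digit code and record when it stops being valid."""
    code = f"{secrets.randbelow(1_000_000):06d}"
    _pending_codes[email] = (code, time.time() + CODE_TTL_SECONDS)
    return code


def verify_code(email: str, submitted: str) -> bool:
    """Accept a code at most once, and only inside its validity window."""
    entry = _pending_codes.pop(email, None)
    if entry is None:
        return False
    code, expires_at = entry
    if time.time() > expires_at or not submitted.isascii():
        return False
    return secrets.compare_digest(code, submitted)
```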
@@ -118,5 +121,4 @@
-
-
+
\ No newline at end of file
diff --git a/api/templates/without-brand/change_mail_confirm_new_template_zh-CN.html b/api/templates/without-brand/change_mail_confirm_new_template_zh-CN.html
index e3e9e7c45a..e411680e89 100644
--- a/api/templates/without-brand/change_mail_confirm_new_template_zh-CN.html
+++ b/api/templates/without-brand/change_mail_confirm_new_template_zh-CN.html
@@ -42,7 +42,8 @@
        font-family: Inter;
        font-style: normal;
        font-weight: 600;
-       line-height: 120%; /* 28.8px */
+       line-height: 120%;
+       /* 28.8px */
    }

    .description {
@@ -51,7 +52,8 @@
        font-family: Inter;
        font-style: normal;
        font-weight: 400;
-       line-height: 20px; /* 142.857% */
+       line-height: 20px;
+       /* 142.857% */
        letter-spacing: -0.07px;
    }

@@ -96,7 +98,8 @@
        font-family: Inter;
        font-style: normal;
        font-weight: 400;
-       line-height: 20px; /* 142.857% */
+       line-height: 20px;
+       /* 142.857% */
        letter-spacing: -0.07px;
    }

@@ -107,7 +110,7 @@

确认您的邮箱地址变更

-

您正在更新与您的 Dify 账户关联的邮箱地址。

+

您正在更新与您的账户关联的邮箱地址。

为了确认此操作,请使用以下验证码。

此验证码仅在接下来的5分钟内有效:

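Each template above exists in parallel `en-US` and `zh-CN` variants that share a `<name>_template_<locale>.html` naming scheme. A plausible resolution step is sketched below; the fallback-to-English behavior is an assumption, since the lookup logic itself is outside this diff.

```python
from pathlib import Path

TEMPLATE_DIR = Path("api/templates/without-brand")
DEFAULT_LOCALE = "en-US"


def resolve_template(name: str, locale: str) -> Path:
    """Pick the localized file when it exists, else fall back to en-US."""
    candidate = TEMPLATE_DIR / f"{name}_template_{locale}.html"
    if candidate.is_file():
        return candidate
    return TEMPLATE_DIR / f"{name}_template_{DEFAULT_LOCALE}.html"
```

With that helper, `resolve_template("change_mail_confirm_new", "zh-CN")` would select the file patched in this hunk.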
@@ -118,5 +121,4 @@
-
-
+
\ No newline at end of file
diff --git a/api/templates/without-brand/change_mail_confirm_old_template_en-US.html b/api/templates/without-brand/change_mail_confirm_old_template_en-US.html
index 9d79fa7ff9..9fe52255a5 100644
--- a/api/templates/without-brand/change_mail_confirm_old_template_en-US.html
+++ b/api/templates/without-brand/change_mail_confirm_old_template_en-US.html
@@ -42,7 +42,8 @@
        font-family: Inter;
        font-style: normal;
        font-weight: 600;
-       line-height: 120%; /* 28.8px */
+       line-height: 120%;
+       /* 28.8px */
    }

    .description {
@@ -51,7 +52,8 @@
        font-family: Inter;
        font-style: normal;
        font-weight: 400;
-       line-height: 20px; /* 142.857% */
+       line-height: 20px;
+       /* 142.857% */
        letter-spacing: -0.07px;
    }

@@ -96,7 +98,8 @@
        font-family: Inter;
        font-style: normal;
        font-weight: 400;
-       line-height: 20px; /* 142.857% */
+       line-height: 20px;
+       /* 142.857% */
        letter-spacing: -0.07px;
    }

@@ -107,7 +110,7 @@

Verify Your Request to Change Email

-

We received a request to change the email address associated with your Dify account.

+

We received a request to change the email address associated with your account.

To confirm this action, please use the verification code below.

This code will only be valid for the next 5 minutes:

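The `{{ ... }}` tokens in these bodies are Jinja-style placeholders, so rendering one of these files reduces to loading it and supplying a context. The sketch below assumes Jinja2 and an illustrative `code` variable; the real mail-sending wiring, and the exact context keys each template expects, are not visible in this diff.

```python
from jinja2 import Environment, FileSystemLoader, select_autoescape

env = Environment(
    loader=FileSystemLoader("api/templates/without-brand"),
    autoescape=select_autoescape(["html"]),
)


def render_change_mail_confirm(code: str, locale: str = "en-US") -> str:
    # Template file names follow the paths shown in this patch; the
    # context keys passed here are assumptions for illustration.
    template = env.get_template(f"change_mail_confirm_old_template_{locale}.html")
    return template.render(code=code)
```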
@@ -118,5 +121,4 @@
-
-
+
\ No newline at end of file
diff --git a/api/templates/without-brand/change_mail_confirm_old_template_zh-CN.html b/api/templates/without-brand/change_mail_confirm_old_template_zh-CN.html
index 41f0839190..98cbd2f0c6 100644
--- a/api/templates/without-brand/change_mail_confirm_old_template_zh-CN.html
+++ b/api/templates/without-brand/change_mail_confirm_old_template_zh-CN.html
@@ -42,7 +42,8 @@
        font-family: Inter;
        font-style: normal;
        font-weight: 600;
-       line-height: 120%; /* 28.8px */
+       line-height: 120%;
+       /* 28.8px */
    }

    .description {
@@ -51,7 +52,8 @@
        font-family: Inter;
        font-style: normal;
        font-weight: 400;
-       line-height: 20px; /* 142.857% */
+       line-height: 20px;
+       /* 142.857% */
        letter-spacing: -0.07px;
    }

@@ -96,7 +98,8 @@
        font-family: Inter;
        font-style: normal;
        font-weight: 400;
-       line-height: 20px; /* 142.857% */
+       line-height: 20px;
+       /* 142.857% */
        letter-spacing: -0.07px;
    }

@@ -107,7 +110,7 @@

验证您的邮箱变更请求

-

我们收到了一个变更您 Dify 账户关联邮箱地址的请求。

+

我们收到了一个变更您账户关联邮箱地址的请求。

此验证码仅在接下来的5分钟内有效:

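These files all live under `api/templates/without-brand/`, the vendor-neutral variants, which is why this commit strips "Dify" and `support@dify.ai` from their copy. Selecting between branded and unbranded template roots could be as simple as the sketch below; the feature flag and the location of the branded variants are assumptions, as only the `without-brand/` directory appears in this patch.

```python
from pathlib import Path

TEMPLATES_BASE = Path("api/templates")


def email_template_root(branding_enabled: bool) -> Path:
    # Where the branded templates live is assumed here; only the
    # without-brand/ directory is visible in this patch series.
    return TEMPLATES_BASE if branding_enabled else TEMPLATES_BASE / "without-brand"
```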
@@ -117,5 +120,4 @@
-
-
+
\ No newline at end of file
diff --git a/api/templates/without-brand/invite_member_mail_template_en-US.html b/api/templates/without-brand/invite_member_mail_template_en-US.html
index b78a6a0760..f9157284fa 100644
--- a/api/templates/without-brand/invite_member_mail_template_en-US.html
+++ b/api/templates/without-brand/invite_member_mail_template_en-US.html
@@ -1,5 +1,6 @@
+
+

Dear {{ to }},

-

{{ inviter_name }} is pleased to invite you to join our workspace on {{application_title}}, a platform specifically designed for LLM application development. On {{application_title}}, you can explore, create, and collaborate to build and operate AI applications.

+

{{ inviter_name }} is pleased to invite you to join our workspace on {{application_title}}, a
+ platform specifically designed for LLM application development. On {{application_title}}, you can explore,
+ create, and collaborate to build and operate AI applications.

Click the button below to log in to {{application_title}} and join the workspace.

-

Login Here

+

Login Here

Best regards,

{{application_title}} Team

-
+
\ No newline at end of file
diff --git a/api/templates/without-brand/transfer_workspace_new_owner_notify_template_en-US.html b/api/templates/without-brand/transfer_workspace_new_owner_notify_template_en-US.html
index a5758a2184..659c285324 100644
--- a/api/templates/without-brand/transfer_workspace_new_owner_notify_template_en-US.html
+++ b/api/templates/without-brand/transfer_workspace_new_owner_notify_template_en-US.html
@@ -42,7 +42,8 @@
        font-family: Inter;
        font-style: normal;
        font-weight: 600;
-       line-height: 120%; /* 28.8px */
+       line-height: 120%;
+       /* 28.8px */
    }

    .description {
@@ -51,7 +52,8 @@
        font-family: Inter;
        font-style: normal;
        font-weight: 400;
-       line-height: 20px; /* 142.857% */
+       line-height: 20px;
+       /* 142.857% */
        letter-spacing: -0.07px;
    }

@@ -80,10 +82,9 @@

You have been assigned as the new owner of the workspace "{{WorkspaceName}}".

As the new owner, you now have full administrative privileges for this workspace.

-

If you have any questions, please contact support@dify.ai.

+

If you have any questions, please contact support.

- - + \ No newline at end of file diff --git a/api/templates/without-brand/transfer_workspace_new_owner_notify_template_zh-CN.html b/api/templates/without-brand/transfer_workspace_new_owner_notify_template_zh-CN.html index 53bab92552..f710dbb289 100644 --- a/api/templates/without-brand/transfer_workspace_new_owner_notify_template_zh-CN.html +++ b/api/templates/without-brand/transfer_workspace_new_owner_notify_template_zh-CN.html @@ -42,7 +42,8 @@ font-family: Inter; font-style: normal; font-weight: 600; - line-height: 120%; /* 28.8px */ + line-height: 120%; + /* 28.8px */ } .description { @@ -51,7 +52,8 @@ font-family: Inter; font-style: normal; font-weight: 400; - line-height: 20px; /* 142.857% */ + line-height: 20px; + /* 142.857% */ letter-spacing: -0.07px; } @@ -80,10 +82,9 @@

您已被分配为工作空间“{{WorkspaceName}}”的新所有者。

作为新所有者,您现在对该工作空间拥有完全的管理权限。

-

如果您有任何问题,请联系support@dify.ai。

+

如果您有任何问题,请联系支持团队。

- - + \ No newline at end of file diff --git a/api/templates/without-brand/transfer_workspace_old_owner_notify_template_en-US.html b/api/templates/without-brand/transfer_workspace_old_owner_notify_template_en-US.html index 3e7faeb01e..149ec77aea 100644 --- a/api/templates/without-brand/transfer_workspace_old_owner_notify_template_en-US.html +++ b/api/templates/without-brand/transfer_workspace_old_owner_notify_template_en-US.html @@ -42,7 +42,8 @@ font-family: Inter; font-style: normal; font-weight: 600; - line-height: 120%; /* 28.8px */ + line-height: 120%; + /* 28.8px */ } .description { @@ -51,7 +52,8 @@ font-family: Inter; font-style: normal; font-weight: 400; - line-height: 20px; /* 142.857% */ + line-height: 20px; + /* 142.857% */ letter-spacing: -0.07px; } @@ -97,7 +99,8 @@ font-family: Inter; font-style: normal; font-weight: 400; - line-height: 20px; /* 142.857% */ + line-height: 20px; + /* 142.857% */ letter-spacing: -0.07px; } @@ -108,12 +111,14 @@

Workspace ownership has been transferred

-

You have successfully transferred ownership of the workspace "{{WorkspaceName}}" to {{NewOwnerEmail}}.

-

You no longer have owner privileges for this workspace. Your access level has been changed to Admin.

-

If you did not initiate this transfer or have concerns about this change, please contact support@dify.ai immediately.

+

You have successfully transferred ownership of the workspace "{{WorkspaceName}}" to + {{NewOwnerEmail}}.

+

You no longer have owner privileges for this workspace. Your access level has been changed to + Admin.

+

If you did not initiate this transfer or have concerns about this change, please contact + support immediately.

- - + \ No newline at end of file diff --git a/api/templates/without-brand/transfer_workspace_old_owner_notify_template_zh-CN.html b/api/templates/without-brand/transfer_workspace_old_owner_notify_template_zh-CN.html index 31e3c23140..d7aed40068 100644 --- a/api/templates/without-brand/transfer_workspace_old_owner_notify_template_zh-CN.html +++ b/api/templates/without-brand/transfer_workspace_old_owner_notify_template_zh-CN.html @@ -42,7 +42,8 @@ font-family: Inter; font-style: normal; font-weight: 600; - line-height: 120%; /* 28.8px */ + line-height: 120%; + /* 28.8px */ } .description { @@ -51,7 +52,8 @@ font-family: Inter; font-style: normal; font-weight: 400; - line-height: 20px; /* 142.857% */ + line-height: 20px; + /* 142.857% */ letter-spacing: -0.07px; } @@ -97,7 +99,8 @@ font-family: Inter; font-style: normal; font-weight: 400; - line-height: 20px; /* 142.857% */ + line-height: 20px; + /* 142.857% */ letter-spacing: -0.07px; } @@ -110,10 +113,9 @@

您已成功将工作空间“{{WorkspaceName}}”的所有权转移给{{NewOwnerEmail}}。

您不再拥有此工作空间的拥有者权限。您的访问级别已更改为管理员。

-

如果您没有发起此转移或对此变更有任何疑问,请立即联系support@dify.ai。

+

如果您没有发起此转移或对此变更有任何疑问,请立即联系支持团队。

- - + \ No newline at end of file From 24612adf2c1ad4b359f8a3604b63591e364ff5f6 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 16 Oct 2025 22:15:03 +0800 Subject: [PATCH 08/46] Fix dispatcher idle hang and add pytest timeouts (#26998) --- .../graph_engine/orchestration/dispatcher.py | 2 ++ api/pyproject.toml | 1 + .../core/workflow/graph_engine/test_dispatcher.py | 8 ++++---- .../services/test_metadata_bug_complete.py | 10 ++++++++-- .../services/test_metadata_nullable_bug.py | 15 ++++++++++++--- api/uv.lock | 14 ++++++++++++++ dev/pytest/pytest_artifacts.sh | 4 +++- dev/pytest/pytest_model_runtime.sh | 4 +++- dev/pytest/pytest_testcontainers.sh | 4 +++- dev/pytest/pytest_tools.sh | 4 +++- dev/pytest/pytest_unit_tests.sh | 4 +++- dev/pytest/pytest_vdb.sh | 4 +++- dev/pytest/pytest_workflow.sh | 4 +++- 13 files changed, 62 insertions(+), 16 deletions(-) diff --git a/api/core/workflow/graph_engine/orchestration/dispatcher.py b/api/core/workflow/graph_engine/orchestration/dispatcher.py index 8340c10b49..f3570855ce 100644 --- a/api/core/workflow/graph_engine/orchestration/dispatcher.py +++ b/api/core/workflow/graph_engine/orchestration/dispatcher.py @@ -99,6 +99,8 @@ class Dispatcher: self._execution_coordinator.check_commands() self._event_queue.task_done() except queue.Empty: + # Process commands even when no new events arrive so abort requests are not missed + self._execution_coordinator.check_commands() # Check if execution is complete if self._execution_coordinator.is_execution_complete(): break diff --git a/api/pyproject.toml b/api/pyproject.toml index 62af88a1b2..f2de966a57 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -166,6 +166,7 @@ dev = [ "mypy~=1.17.1", # "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved. 
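Note on the dispatcher change in patch 08 above: moving `check_commands()` into the `queue.Empty` branch is what lets an abort issued between events be observed, instead of the loop hanging until the next event arrives. A minimal, self-contained sketch of that idle-polling pattern; `FakeCoordinator` below is an illustrative stand-in, not the real `ExecutionCoordinator`:

```python
import queue
from dataclasses import dataclass


@dataclass
class FakeCoordinator:
    """Stand-in for ExecutionCoordinator; aborts after a few idle polls."""
    idle_polls: int = 0
    aborted: bool = False

    def check_commands(self) -> None:
        self.idle_polls += 1
        if self.idle_polls >= 3:  # pretend an abort command arrived
            self.aborted = True

    def is_execution_complete(self) -> bool:
        return self.aborted


def dispatch_loop(events: queue.Queue, coordinator: FakeCoordinator) -> None:
    while True:
        try:
            event = events.get(timeout=0.05)
        except queue.Empty:
            # The fix: poll commands while idle so aborts are not missed.
            coordinator.check_commands()
            if coordinator.is_execution_complete():
                break
            continue
        print("handled", event)  # placeholder for real event handling
        coordinator.check_commands()
        events.task_done()


dispatch_loop(queue.Queue(), FakeCoordinator())  # returns instead of hanging
```

Run as-is, the loop exits after a few idle polls rather than blocking forever on an empty queue.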
"sseclient-py>=1.8.0", + "pytest-timeout>=2.4.0", ] ############################################################ diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher.py b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher.py index 830fc0884d..0d612e054f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher.py @@ -95,10 +95,10 @@ def _make_succeeded_event() -> NodeRunSucceededEvent: ) -def test_dispatcher_checks_commands_after_node_completion() -> None: - """Dispatcher should only check commands after node completion events.""" +def test_dispatcher_checks_commands_during_idle_and_on_completion() -> None: + """Dispatcher polls commands when idle and re-checks after completion events.""" started_checks = _run_dispatcher_for_event(_make_started_event()) succeeded_checks = _run_dispatcher_for_event(_make_succeeded_event()) - assert started_checks == 0 - assert succeeded_checks == 1 + assert started_checks == 1 + assert succeeded_checks == 2 diff --git a/api/tests/unit_tests/services/test_metadata_bug_complete.py b/api/tests/unit_tests/services/test_metadata_bug_complete.py index 31fe9b2868..ee96305070 100644 --- a/api/tests/unit_tests/services/test_metadata_bug_complete.py +++ b/api/tests/unit_tests/services/test_metadata_bug_complete.py @@ -41,7 +41,10 @@ class TestMetadataBugCompleteValidation: mock_user.current_tenant_id = "tenant-123" mock_user.id = "user-456" - with patch("services.metadata_service.current_user", mock_user): + with patch( + "services.metadata_service.current_account_with_tenant", + return_value=(mock_user, mock_user.current_tenant_id), + ): # Should crash with TypeError with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): MetadataService.create_metadata("dataset-123", mock_metadata_args) @@ -51,7 +54,10 @@ class TestMetadataBugCompleteValidation: mock_user.current_tenant_id = "tenant-123" mock_user.id = "user-456" - with patch("services.metadata_service.current_user", mock_user): + with patch( + "services.metadata_service.current_account_with_tenant", + return_value=(mock_user, mock_user.current_tenant_id), + ): with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): MetadataService.update_metadata_name("dataset-123", "metadata-456", None) diff --git a/api/tests/unit_tests/services/test_metadata_nullable_bug.py b/api/tests/unit_tests/services/test_metadata_nullable_bug.py index c8cd7025c2..3d57737943 100644 --- a/api/tests/unit_tests/services/test_metadata_nullable_bug.py +++ b/api/tests/unit_tests/services/test_metadata_nullable_bug.py @@ -29,7 +29,10 @@ class TestMetadataNullableBug: mock_user.current_tenant_id = "tenant-123" mock_user.id = "user-456" - with patch("services.metadata_service.current_user", mock_user): + with patch( + "services.metadata_service.current_account_with_tenant", + return_value=(mock_user, mock_user.current_tenant_id), + ): # This should crash with TypeError when calling len(None) with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): MetadataService.create_metadata("dataset-123", mock_metadata_args) @@ -40,7 +43,10 @@ class TestMetadataNullableBug: mock_user.current_tenant_id = "tenant-123" mock_user.id = "user-456" - with patch("services.metadata_service.current_user", mock_user): + with patch( + "services.metadata_service.current_account_with_tenant", + return_value=(mock_user, mock_user.current_tenant_id), + ): # This should crash with TypeError 
when calling len(None) with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): MetadataService.update_metadata_name("dataset-123", "metadata-456", None) @@ -88,7 +94,10 @@ class TestMetadataNullableBug: mock_user.current_tenant_id = "tenant-123" mock_user.id = "user-456" - with patch("services.metadata_service.current_user", mock_user): + with patch( + "services.metadata_service.current_account_with_tenant", + return_value=(mock_user, mock_user.current_tenant_id), + ): # Step 4: Service layer crashes on len(None) with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): MetadataService.create_metadata("dataset-123", mock_metadata_args) diff --git a/api/uv.lock b/api/uv.lock index 96aee8a97b..e7facf8248 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1394,6 +1394,7 @@ dev = [ { name = "pytest-cov" }, { name = "pytest-env" }, { name = "pytest-mock" }, + { name = "pytest-timeout" }, { name = "ruff" }, { name = "scipy-stubs" }, { name = "sseclient-py" }, @@ -1583,6 +1584,7 @@ dev = [ { name = "pytest-cov", specifier = "~=4.1.0" }, { name = "pytest-env", specifier = "~=1.1.3" }, { name = "pytest-mock", specifier = "~=3.14.0" }, + { name = "pytest-timeout", specifier = ">=2.4.0" }, { name = "ruff", specifier = "~=0.14.0" }, { name = "scipy-stubs", specifier = ">=1.15.3.0" }, { name = "sseclient-py", specifier = ">=1.8.0" }, @@ -4979,6 +4981,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b2/05/77b60e520511c53d1c1ca75f1930c7dd8e971d0c4379b7f4b3f9644685ba/pytest_mock-3.14.1-py3-none-any.whl", hash = "sha256:178aefcd11307d874b4cd3100344e7e2d888d9791a6a1d9bfe90fbc1b74fd1d0", size = 9923, upload-time = "2025-05-26T13:58:43.487Z" }, ] +[[package]] +name = "pytest-timeout" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ac/82/4c9ecabab13363e72d880f2fb504c5f750433b2b6f16e99f4ec21ada284c/pytest_timeout-2.4.0.tar.gz", hash = "sha256:7e68e90b01f9eff71332b25001f85c75495fc4e3a836701876183c4bcfd0540a", size = 17973, upload-time = "2025-05-05T19:44:34.99Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl", hash = "sha256:c42667e5cdadb151aeb5b26d114aff6bdf5a907f176a007a30b940d3d865b5c2", size = 14382, upload-time = "2025-05-05T19:44:33.502Z" }, +] + [[package]] name = "python-calamine" version = "0.5.3" diff --git a/dev/pytest/pytest_artifacts.sh b/dev/pytest/pytest_artifacts.sh index 3086ef5cc4..29cacdcc07 100755 --- a/dev/pytest/pytest_artifacts.sh +++ b/dev/pytest/pytest_artifacts.sh @@ -4,4 +4,6 @@ set -x SCRIPT_DIR="$(dirname "$(realpath "$0")")" cd "$SCRIPT_DIR/../.." -pytest api/tests/artifact_tests/ +PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-120}" + +pytest --timeout "${PYTEST_TIMEOUT}" api/tests/artifact_tests/ diff --git a/dev/pytest/pytest_model_runtime.sh b/dev/pytest/pytest_model_runtime.sh index 2cbbbbfd81..fd68dbe697 100755 --- a/dev/pytest/pytest_model_runtime.sh +++ b/dev/pytest/pytest_model_runtime.sh @@ -4,7 +4,9 @@ set -x SCRIPT_DIR="$(dirname "$(realpath "$0")")" cd "$SCRIPT_DIR/../.." 
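The metadata tests above switch from patching the `current_user` proxy to patching `services.metadata_service.current_account_with_tenant` with an `(account, tenant_id)` tuple. The underlying pattern, patching a callable where it is looked up and pinning its return value, shown in isolation (all names below are illustrative, not Dify's):

```python
from unittest.mock import Mock, patch


def get_owner():
    # Imagine this lives in a service module and is patched in tests.
    raise RuntimeError("not available outside a request")


def describe_owner():
    account, tenant_id = get_owner()
    return f"{account.id}@{tenant_id}"


mock_account = Mock()
mock_account.id = "user-456"

# Patch the name at its lookup site and pin the tuple it returns.
with patch(f"{__name__}.get_owner", return_value=(mock_account, "tenant-123")):
    assert describe_owner() == "user-456@tenant-123"
```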
-pytest api/tests/integration_tests/model_runtime/anthropic \ +PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-180}" + +pytest --timeout "${PYTEST_TIMEOUT}" api/tests/integration_tests/model_runtime/anthropic \ api/tests/integration_tests/model_runtime/azure_openai \ api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm \ api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference \ diff --git a/dev/pytest/pytest_testcontainers.sh b/dev/pytest/pytest_testcontainers.sh index e55a436138..f92f8821bf 100755 --- a/dev/pytest/pytest_testcontainers.sh +++ b/dev/pytest/pytest_testcontainers.sh @@ -4,4 +4,6 @@ set -x SCRIPT_DIR="$(dirname "$(realpath "$0")")" cd "$SCRIPT_DIR/../.." -pytest api/tests/test_containers_integration_tests +PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-120}" + +pytest --timeout "${PYTEST_TIMEOUT}" api/tests/test_containers_integration_tests diff --git a/dev/pytest/pytest_tools.sh b/dev/pytest/pytest_tools.sh index d10934626f..989784f078 100755 --- a/dev/pytest/pytest_tools.sh +++ b/dev/pytest/pytest_tools.sh @@ -4,4 +4,6 @@ set -x SCRIPT_DIR="$(dirname "$(realpath "$0")")" cd "$SCRIPT_DIR/../.." -pytest api/tests/integration_tests/tools +PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-120}" + +pytest --timeout "${PYTEST_TIMEOUT}" api/tests/integration_tests/tools diff --git a/dev/pytest/pytest_unit_tests.sh b/dev/pytest/pytest_unit_tests.sh index 1a1819ca28..496cb40952 100755 --- a/dev/pytest/pytest_unit_tests.sh +++ b/dev/pytest/pytest_unit_tests.sh @@ -4,5 +4,7 @@ set -x SCRIPT_DIR="$(dirname "$(realpath "$0")")" cd "$SCRIPT_DIR/../.." +PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-20}" + # libs -pytest api/tests/unit_tests +pytest --timeout "${PYTEST_TIMEOUT}" api/tests/unit_tests diff --git a/dev/pytest/pytest_vdb.sh b/dev/pytest/pytest_vdb.sh index 7f617a9c05..3c11a079cc 100755 --- a/dev/pytest/pytest_vdb.sh +++ b/dev/pytest/pytest_vdb.sh @@ -4,7 +4,9 @@ set -x SCRIPT_DIR="$(dirname "$(realpath "$0")")" cd "$SCRIPT_DIR/../.." -pytest api/tests/integration_tests/vdb/chroma \ +PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-180}" + +pytest --timeout "${PYTEST_TIMEOUT}" api/tests/integration_tests/vdb/chroma \ api/tests/integration_tests/vdb/milvus \ api/tests/integration_tests/vdb/pgvecto_rs \ api/tests/integration_tests/vdb/pgvector \ diff --git a/dev/pytest/pytest_workflow.sh b/dev/pytest/pytest_workflow.sh index b63d49069f..941c8d3e7e 100755 --- a/dev/pytest/pytest_workflow.sh +++ b/dev/pytest/pytest_workflow.sh @@ -4,4 +4,6 @@ set -x SCRIPT_DIR="$(dirname "$(realpath "$0")")" cd "$SCRIPT_DIR/../.." 
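The scripts above thread a suite-wide limit through `pytest --timeout "${PYTEST_TIMEOUT}"`. With the `pytest-timeout` plugin from the dev dependencies installed, an individual test can still claim a different budget via the marker API, for example:

```python
import time

import pytest


@pytest.mark.timeout(5)  # per-test override of the suite-wide --timeout value
def test_slow_but_bounded():
    time.sleep(1)
```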
-pytest api/tests/integration_tests/workflow +PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-120}" + +pytest --timeout "${PYTEST_TIMEOUT}" api/tests/integration_tests/workflow From 650e38e17f63b58111ff1714998cdb7b84774573 Mon Sep 17 00:00:00 2001 From: GuanMu Date: Thu, 16 Oct 2025 22:16:01 +0800 Subject: [PATCH 09/46] refactor: improve TypeScript types for NodeCardProps and debug configuration context (#27001) --- .../workflow/workflow-preview/components/nodes/base.tsx | 4 +++- web/context/debug-configuration.ts | 2 +- web/service/base.ts | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/web/app/components/workflow/workflow-preview/components/nodes/base.tsx b/web/app/components/workflow/workflow-preview/components/nodes/base.tsx index 55dfac467e..c7483c11bb 100644 --- a/web/app/components/workflow/workflow-preview/components/nodes/base.tsx +++ b/web/app/components/workflow/workflow-preview/components/nodes/base.tsx @@ -23,8 +23,10 @@ import { } from '../node-handle' import ErrorHandleOnNode from '../error-handle-on-node' +type NodeChildElement = ReactElement> + type NodeCardProps = NodeProps & { - children?: ReactElement + children?: NodeChildElement } const BaseCard = ({ diff --git a/web/context/debug-configuration.ts b/web/context/debug-configuration.ts index bbf7be8099..dba2e7a231 100644 --- a/web/context/debug-configuration.ts +++ b/web/context/debug-configuration.ts @@ -242,7 +242,7 @@ const DebugConfigurationContext = createContext({ }, datasetConfigsRef: { current: null, - }, + } as unknown as RefObject, setDatasetConfigs: noop, hasSetContextVar: false, isShowVisionConfig: false, diff --git a/web/service/base.ts b/web/service/base.ts index 358f54183b..1cb99e38d3 100644 --- a/web/service/base.ts +++ b/web/service/base.ts @@ -324,7 +324,7 @@ const baseFetch = base type UploadOptions = { xhr: XMLHttpRequest - method: string + method?: string url?: string headers?: Record data: FormData From a8ad80c405423b03feffed1a157d543acb305fac Mon Sep 17 00:00:00 2001 From: Dhruv Gorasiya <80987415+DhruvGorasiya@users.noreply.github.com> Date: Thu, 16 Oct 2025 10:41:48 -0400 Subject: [PATCH 10/46] Fixed Weaviate no module found issue (issue #26938) (#26964) Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .github/workflows/expose_service_ports.sh | 2 +- .../vdb/weaviate/weaviate_vector.py | 1 - api/pyproject.toml | 3 +- api/uv.lock | 6 +- docker/docker-compose-template.yaml | 15 +- docker/docker-compose.override.yml | 9 - docker/docker-compose.yaml | 15 +- .../WEAVIATE_MIGRATION_GUIDE/README.md | 187 ++++++++++++++++++ 8 files changed, 222 insertions(+), 16 deletions(-) delete mode 100644 docker/docker-compose.override.yml create mode 100644 docs/weaviate/WEAVIATE_MIGRATION_GUIDE/README.md diff --git a/.github/workflows/expose_service_ports.sh b/.github/workflows/expose_service_ports.sh index fa0fd2ee8c..e7d5f60288 100755 --- a/.github/workflows/expose_service_ports.sh +++ b/.github/workflows/expose_service_ports.sh @@ -14,4 +14,4 @@ yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/tidb/docker-compose.ya yq eval '.services.oceanbase.ports += ["2881:2881"]' -i docker/docker-compose.yaml yq eval '.services.opengauss.ports += ["6600:6600"]' -i docker/docker-compose.yaml -echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss" +echo "Ports exposed for 
sandbox, weaviate (HTTP 8080, gRPC 50051), tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss" diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 4793d2bb50..15207add18 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -250,7 +250,6 @@ class WeaviateVector(BaseVector): ) ) - batch_size = max(1, int(dify_config.WEAVIATE_BATCH_SIZE or 100)) with col.batch.dynamic() as batch: for obj in objs: batch.add_object(properties=obj.properties, uuid=obj.uuid, vector=obj.vector) diff --git a/api/pyproject.toml b/api/pyproject.toml index f2de966a57..74e6782d83 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -86,6 +86,7 @@ dependencies = [ "sendgrid~=6.12.3", "flask-restx~=1.3.0", "packaging~=23.2", + "weaviate-client==4.17.0", ] # Before adding new dependency, consider place it in # alphabet order (a-z) and suitable group. @@ -215,7 +216,7 @@ vdb = [ "tidb-vector==0.0.9", "upstash-vector==0.6.0", "volcengine-compat~=1.0.0", - "weaviate-client>=4.0.0,<5.0.0", + "weaviate-client==4.17.0", "xinference-client~=1.2.2", "mo-vector~=0.1.13", "mysql-connector-python>=9.3.0", diff --git a/api/uv.lock b/api/uv.lock index e7facf8248..8f28fa36a8 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.11, <3.13" resolution-markers = [ "python_full_version >= '3.12.4' and platform_python_implementation != 'PyPy' and sys_platform == 'linux'", @@ -1372,6 +1372,7 @@ dependencies = [ { name = "transformers" }, { name = "unstructured", extra = ["docx", "epub", "md", "ppt", "pptx"] }, { name = "weave" }, + { name = "weaviate-client" }, { name = "webvtt-py" }, { name = "yarl" }, ] @@ -1562,6 +1563,7 @@ requires-dist = [ { name = "transformers", specifier = "~=4.56.1" }, { name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.16.1" }, { name = "weave", specifier = "~=0.51.0" }, + { name = "weaviate-client", specifier = "==4.17.0" }, { name = "webvtt-py", specifier = "~=0.5.1" }, { name = "yarl", specifier = "~=1.18.3" }, ] @@ -1669,7 +1671,7 @@ vdb = [ { name = "tidb-vector", specifier = "==0.0.9" }, { name = "upstash-vector", specifier = "==0.6.0" }, { name = "volcengine-compat", specifier = "~=1.0.0" }, - { name = "weaviate-client", specifier = ">=4.0.0,<5.0.0" }, + { name = "weaviate-client", specifier = "==4.17.0" }, { name = "xinference-client", specifier = "~=1.2.2" }, ] diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 5483e2d554..5a67c080cc 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -24,6 +24,13 @@ services: volumes: # Mount the storage directory to the container, for storing user files. - ./volumes/app/storage:/app/api/storage + # TODO: Remove this entrypoint override when weaviate-client 4.17.0 is included in the next Dify release + entrypoint: + - /bin/bash + - -c + - | + uv pip install --system weaviate-client==4.17.0 + exec /bin/bash /app/api/docker/entrypoint.sh networks: - ssrf_proxy_network - default @@ -51,6 +58,13 @@ services: volumes: # Mount the storage directory to the container, for storing user files. 
- ./volumes/app/storage:/app/api/storage + # TODO: Remove this entrypoint override when weaviate-client 4.17.0 is included in the next Dify release + entrypoint: + - /bin/bash + - -c + - | + uv pip install --system weaviate-client==4.17.0 + exec /bin/bash /app/api/docker/entrypoint.sh networks: - ssrf_proxy_network - default @@ -331,7 +345,6 @@ services: weaviate: image: semitechnologies/weaviate:1.27.0 profiles: - - "" - weaviate restart: always volumes: diff --git a/docker/docker-compose.override.yml b/docker/docker-compose.override.yml deleted file mode 100644 index 8f2ab1cb43..0000000000 --- a/docker/docker-compose.override.yml +++ /dev/null @@ -1,9 +0,0 @@ -services: - api: - volumes: - - ../api/core/rag/datasource/vdb/weaviate/weaviate_vector.py:/app/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py:ro - command: > - sh -c " - pip install --no-cache-dir 'weaviate>=4.0.0' && - /bin/bash /entrypoint.sh - " diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 46b4a750ea..421b733e2b 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -631,6 +631,13 @@ services: volumes: # Mount the storage directory to the container, for storing user files. - ./volumes/app/storage:/app/api/storage + # TODO: Remove this entrypoint override when weaviate-client 4.17.0 is included in the next Dify release + entrypoint: + - /bin/bash + - -c + - | + uv pip install --system weaviate-client==4.17.0 + exec /bin/bash /app/api/docker/entrypoint.sh networks: - ssrf_proxy_network - default @@ -658,6 +665,13 @@ services: volumes: # Mount the storage directory to the container, for storing user files. - ./volumes/app/storage:/app/api/storage + # TODO: Remove this entrypoint override when weaviate-client 4.17.0 is included in the next Dify release + entrypoint: + - /bin/bash + - -c + - | + uv pip install --system weaviate-client==4.17.0 + exec /bin/bash /app/api/docker/entrypoint.sh networks: - ssrf_proxy_network - default @@ -938,7 +952,6 @@ services: weaviate: image: semitechnologies/weaviate:1.27.0 profiles: - - "" - weaviate restart: always volumes: diff --git a/docs/weaviate/WEAVIATE_MIGRATION_GUIDE/README.md b/docs/weaviate/WEAVIATE_MIGRATION_GUIDE/README.md new file mode 100644 index 0000000000..b2599e8c2e --- /dev/null +++ b/docs/weaviate/WEAVIATE_MIGRATION_GUIDE/README.md @@ -0,0 +1,187 @@ +# Weaviate Migration Guide: v1.19 → v1.27 + +## Overview + +Dify has upgraded from Weaviate v1.19 to v1.27 with the Python client updated from v3.24 to v4.17. + +## What Changed + +### Breaking Changes + +1. **Weaviate Server**: `1.19.0` → `1.27.0` +1. **Python Client**: `weaviate-client~=3.24.0` → `weaviate-client==4.17.0` +1. **gRPC Required**: Weaviate v1.27 requires gRPC port `50051` (in addition to HTTP port `8080`) +1. 
**Docker Compose**: Added temporary entrypoint overrides for client installation + +### Key Improvements + +- Faster vector operations via gRPC +- Improved batch processing +- Better error handling + +## Migration Steps + +### For Docker Users + +#### Step 1: Backup Your Data + +```bash +cd docker +docker compose down +sudo cp -r ./volumes/weaviate ./volumes/weaviate_backup_$(date +%Y%m%d) +``` + +#### Step 2: Update Dify + +```bash +git pull origin main +docker compose pull +``` + +#### Step 3: Start Services + +```bash +docker compose up -d +sleep 30 +curl http://localhost:8080/v1/meta +``` + +#### Step 4: Verify Migration + +```bash +# Check both ports are accessible +curl http://localhost:8080/v1/meta +netstat -tulpn | grep 50051 + +# Test in Dify UI: +# 1. Go to Knowledge Base +# 2. Test search functionality +# 3. Upload a test document +``` + +### For Source Installation + +#### Step 1: Update Dependencies + +```bash +cd api +uv sync --dev +uv run python -c "import weaviate; print(weaviate.__version__)" +# Should show: 4.17.0 +``` + +#### Step 2: Update Weaviate Server + +```bash +cd docker +docker compose -f docker-compose.middleware.yaml --profile weaviate up -d weaviate +curl http://localhost:8080/v1/meta +netstat -tulpn | grep 50051 +``` + +## Troubleshooting + +### Error: "No module named 'weaviate.classes'" + +**Solution**: + +```bash +cd api +uv sync --reinstall-package weaviate-client +uv run python -c "import weaviate; print(weaviate.__version__)" +# Should show: 4.17.0 +``` + +### Error: "gRPC health check failed" + +**Solution**: + +```bash +# Check Weaviate ports +docker ps | grep weaviate +# Should show: 0.0.0.0:8080->8080/tcp, 0.0.0.0:50051->50051/tcp + +# If missing gRPC port, add to docker-compose: +# ports: +# - "8080:8080" +# - "50051:50051" +``` + +### Error: "Weaviate version 1.19.0 is not supported" + +**Solution**: + +```bash +# Update Weaviate image in docker-compose +# Change: semitechnologies/weaviate:1.19.0 +# To: semitechnologies/weaviate:1.27.0 +docker compose down +docker compose up -d +``` + +### Data Migration Failed + +**Solution**: + +```bash +cd docker +docker compose down +sudo rm -rf ./volumes/weaviate +sudo cp -r ./volumes/weaviate_backup_YYYYMMDD ./volumes/weaviate +docker compose up -d +``` + +## Rollback Instructions + +```bash +# 1. Stop services +docker compose down + +# 2. Restore data backup +sudo rm -rf ./volumes/weaviate +sudo cp -r ./volumes/weaviate_backup_YYYYMMDD ./volumes/weaviate + +# 3. Checkout previous version +git checkout + +# 4. Restart services +docker compose up -d +``` + +## Compatibility + +| Component | Old Version | New Version | Compatible | +|-----------|-------------|-------------|------------| +| Weaviate Server | 1.19.0 | 1.27.0 | ✅ Yes | +| weaviate-client | ~3.24.0 | ==4.17.0 | ✅ Yes | +| Existing Data | v1.19 format | v1.27 format | ✅ Yes | + +## Testing Checklist + +Before deploying to production: + +- [ ] Backup all Weaviate data +- [ ] Test in staging environment +- [ ] Verify existing collections are accessible +- [ ] Test vector search functionality +- [ ] Test document upload and retrieval +- [ ] Monitor gRPC connection stability +- [ ] Check performance metrics + +## Support + +If you encounter issues: + +1. Check GitHub Issues: https://github.com/langgenius/dify/issues +1. 
Create a bug report with: + - Error messages + - Docker logs: `docker compose logs weaviate` + - Dify version + - Migration steps attempted + +## Important Notes + +- **Data Safety**: Existing vector data remains fully compatible +- **No Re-indexing**: No need to rebuild vector indexes +- **Temporary Workaround**: The entrypoint overrides are temporary until next Dify release +- **Performance**: May see improved performance due to gRPC usage From d19c10016650923c17edb9e8d90ea2637f6700ba Mon Sep 17 00:00:00 2001 From: Dhruv Gorasiya <80987415+DhruvGorasiya@users.noreply.github.com> Date: Thu, 16 Oct 2025 21:06:50 -0400 Subject: [PATCH 11/46] fix: logical error in Weaviate distance calculation (#27019) Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- api/core/rag/datasource/vdb/weaviate/weaviate_vector.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 15207add18..d2d8fcf964 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -347,7 +347,10 @@ class WeaviateVector(BaseVector): for obj in res.objects: properties = dict(obj.properties or {}) text = properties.pop(Field.TEXT_KEY.value, "") - distance = (obj.metadata.distance if obj.metadata else None) or 1.0 + if obj.metadata and obj.metadata.distance is not None: + distance = obj.metadata.distance + else: + distance = 1.0 score = 1.0 - distance if score > score_threshold: From 312974aa20f25e8f9254bfc3b594b68edfe0a7ac Mon Sep 17 00:00:00 2001 From: Yongtao Huang Date: Fri, 17 Oct 2025 09:07:28 +0800 Subject: [PATCH 12/46] Chore: remove unused class-level variables in DatasourceManager (#27011) Signed-off-by: Yongtao Huang Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/core/datasource/datasource_manager.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index 47d297e194..002415a7db 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -1,11 +1,9 @@ import logging from threading import Lock -from typing import Union import contexts from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_provider import DatasourcePluginProviderController -from core.datasource.entities.common_entities import I18nObject from core.datasource.entities.datasource_entities import DatasourceProviderType from core.datasource.errors import DatasourceProviderNotFoundError from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController @@ -18,11 +16,6 @@ logger = logging.getLogger(__name__) class DatasourceManager: - _builtin_provider_lock = Lock() - _hardcoded_providers: dict[str, DatasourcePluginProviderController] = {} - _builtin_providers_loaded = False - _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {} - @classmethod def get_datasource_plugin_provider( cls, provider_id: str, tenant_id: str, datasource_type: DatasourceProviderType From d7f0a31e2464881d49da0c776ba27107f14c6f39 Mon Sep 17 00:00:00 2001 
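The one-line distance fix in patch 11 above is worth spelling out: with `or` coalescing, a distance of `0.0` (an exact vector match) is falsy and silently falls through to `1.0`, so the best possible hit scores as the worst. A stripped-down illustration, using plain functions rather than the real `WeaviateVector` method:

```python
def score_old(distance: float | None) -> float:
    # pre-fix: `or` also coalesces a legitimate 0.0 distance to 1.0
    return 1.0 - (distance or 1.0)


def score_new(distance: float | None) -> float:
    # post-fix: only a truly missing distance falls back to 1.0
    return 1.0 - (distance if distance is not None else 1.0)


print(score_old(0.0), score_new(0.0))    # 0.0 1.0  (exact match scored worst before)
print(score_old(None), score_new(None))  # 0.0 0.0  (missing metadata unchanged)
```

And since the migration guide stresses that the v4 client needs gRPC alongside HTTP, a quick post-upgrade connectivity check might look like this, assuming the default local ports exposed in the compose files:

```python
import weaviate

# Assumes weaviate-client==4.17.0 and a Weaviate 1.27 node with HTTP on
# 8080 and gRPC on 50051, matching the ports exposed in the compose files.
client = weaviate.connect_to_local(host="localhost", port=8080, grpc_port=50051)
try:
    assert client.is_ready()
finally:
    client.close()
```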
From: quicksand Date: Fri, 17 Oct 2025 09:09:45 +0800 Subject: [PATCH 13/46] =?UTF-8?q?Fix:=20User=20Context=20Loss=20When=20Inv?= =?UTF-8?q?oking=20Workflow=20Tool=20Node=20in=20Knowledge=20=E2=80=A6=20(?= =?UTF-8?q?#26495)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/core/tools/workflow_as_tool/tool.py | 58 ++++++++++++++++++- .../core/tools/workflow_as_tool/test_tool.py | 7 ++- 2 files changed, 61 insertions(+), 4 deletions(-) diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 5adf04611d..50c2327004 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -3,6 +3,7 @@ import logging from collections.abc import Generator from typing import Any +from flask import has_request_context from sqlalchemy import select from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod @@ -18,7 +19,8 @@ from core.tools.errors import ToolInvokeError from extensions.ext_database import db from factories.file_factory import build_from_mapping from libs.login import current_user -from models.model import App +from models import Account, Tenant +from models.model import App, EndUser from models.workflow import Workflow logger = logging.getLogger(__name__) @@ -79,11 +81,16 @@ class WorkflowTool(Tool): generator = WorkflowAppGenerator() assert self.runtime is not None assert self.runtime.invoke_from is not None - assert current_user is not None + + user = self._resolve_user(user_id=user_id) + + if user is None: + raise ToolInvokeError("User not found") + result = generator.generate( app_model=app, workflow=workflow, - user=current_user, + user=user, args={"inputs": tool_parameters, "files": files}, invoke_from=self.runtime.invoke_from, streaming=False, @@ -123,6 +130,51 @@ class WorkflowTool(Tool): label=self.label, ) + def _resolve_user(self, user_id: str) -> Account | EndUser | None: + """ + Resolve user object in both HTTP and worker contexts. + + In HTTP context: dereference the current_user LocalProxy (can return Account or EndUser). + In worker context: load Account from database by user_id (only returns Account, never EndUser). + + Returns: + Account | EndUser | None: The resolved user object, or None if resolution fails. + """ + if has_request_context(): + return self._resolve_user_from_request() + else: + return self._resolve_user_from_database(user_id=user_id) + + def _resolve_user_from_request(self) -> Account | EndUser | None: + """ + Resolve user from Flask request context. + """ + try: + # Note: `current_user` is a LocalProxy. Never compare it with None directly. + return getattr(current_user, "_get_current_object", lambda: current_user)() + except Exception as e: + logger.warning("Failed to resolve user from request context: %s", e) + return None + + def _resolve_user_from_database(self, user_id: str) -> Account | None: + """ + Resolve user from database (worker/Celery context). 
+ """ + + user_stmt = select(Account).where(Account.id == user_id) + user = db.session.scalar(user_stmt) + if not user: + return None + + tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id) + tenant = db.session.scalar(tenant_stmt) + if not tenant: + return None + + user.current_tenant = tenant + + return user + def _get_workflow(self, app_id: str, version: str) -> Workflow: """ get the workflow by app id and version diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index 17e3ebeea0..c68aad0b22 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -34,12 +34,17 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) + # Mock user resolution to avoid database access + from unittest.mock import Mock + + mock_user = Mock() + monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user) + # replace `WorkflowAppGenerator.generate` 's return value. monkeypatch.setattr( "core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", lambda *args, **kwargs: {"data": {"error": "oops"}}, ) - monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None) with pytest.raises(ToolInvokeError) as exc_info: # WorkflowTool always returns a generator, so we need to iterate to From 19cc6ea9930092d3049bde8de8906664a93db86c Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Fri, 17 Oct 2025 10:10:16 +0900 Subject: [PATCH 14/46] fix 27003 (#27005) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- api/controllers/console/explore/workflow.py | 4 ++- api/controllers/console/wraps.py | 7 +++++- api/libs/login.py | 27 +++++++++++++++------ api/models/model.py | 2 +- api/services/datasource_provider_service.py | 21 ++++++++++------ 5 files changed, 43 insertions(+), 18 deletions(-) diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index e32f2814eb..aeea446c6e 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -22,7 +22,7 @@ from core.errors.error import ( from core.model_runtime.errors.invoke import InvokeError from core.workflow.graph_engine.manager import GraphEngineManager from libs import helper -from libs.login import current_user +from libs.login import current_user as current_user_ from models.model import AppMode, InstalledApp from services.app_generate_service import AppGenerateService from services.errors.llm import InvokeRateLimitError @@ -31,6 +31,8 @@ from .. 
import console_ns logger = logging.getLogger(__name__) +current_user = current_user_._get_current_object() # type: ignore + @console_ns.route("/installed-apps//workflows/run") class InstalledAppWorkflowRunApi(InstalledAppResource): diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 2fa28711c3..8572a6dc9b 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -303,7 +303,12 @@ def edit_permission_required(f: Callable[P, R]): def decorated_function(*args: P.args, **kwargs: P.kwargs): from werkzeug.exceptions import Forbidden - current_user, _ = current_account_with_tenant() + from libs.login import current_user + from models import Account + + user = current_user._get_current_object() # type: ignore + if not isinstance(user, Account): + raise Forbidden() if not current_user.has_edit_permission: raise Forbidden() return f(*args, **kwargs) diff --git a/api/libs/login.py b/api/libs/login.py index 2c75ef9297..d0e81a3441 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -1,6 +1,6 @@ from collections.abc import Callable from functools import wraps -from typing import Union, cast +from typing import Any from flask import current_app, g, has_request_context, request from flask_login.config import EXEMPT_METHODS # type: ignore @@ -10,16 +10,21 @@ from configs import dify_config from models import Account from models.model import EndUser -#: A proxy for the current user. If no user is logged in, this will be an -#: anonymous user -current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user())) - def current_account_with_tenant(): - if not isinstance(current_user, Account): + """ + Resolve the underlying account for the current user proxy and ensure tenant context exists. + Allows tests to supply plain Account mocks without the LocalProxy helper. + """ + user_proxy = current_user + + get_current_object = getattr(user_proxy, "_get_current_object", None) + user = get_current_object() if callable(get_current_object) else user_proxy # type: ignore + + if not isinstance(user, Account): raise ValueError("current_user must be an Account instance") - assert current_user.current_tenant_id is not None, "The tenant information should be loaded." - return current_user, current_user.current_tenant_id + assert user.current_tenant_id is not None, "The tenant information should be loaded." + return user, user.current_tenant_id from typing import ParamSpec, TypeVar @@ -81,3 +86,9 @@ def _get_user() -> EndUser | Account | None: return g._login_user # type: ignore return None + + +#: A proxy for the current user. 
If no user is logged in, this will be an +#: anonymous user +# NOTE: Any here, but use _get_current_object to check the fields +current_user: Any = LocalProxy(lambda: _get_user()) diff --git a/api/models/model.py b/api/models/model.py index 2373421e7d..af22ab9538 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1479,7 +1479,7 @@ class EndUser(Base, UserMixin): sa.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"), ) - id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id = mapped_column(StringUUID, nullable=True) type: Mapped[str] = mapped_column(String(255), nullable=False) diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index fcb6ab1d40..1b690e2266 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -17,7 +17,6 @@ from core.tools.entities.tool_entities import CredentialType from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter from extensions.ext_database import db from extensions.ext_redis import redis_client -from libs.login import current_account_with_tenant from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider from models.provider_ids import DatasourceProviderID from services.plugin.plugin_service import PluginService @@ -25,6 +24,16 @@ from services.plugin.plugin_service import PluginService logger = logging.getLogger(__name__) +def get_current_user(): + from libs.login import current_user + from models.account import Account + from models.model import EndUser + + if not isinstance(current_user._get_current_object(), (Account, EndUser)): # type: ignore + raise TypeError(f"current_user must be Account or EndUser, got {type(current_user).__name__}") + return current_user + + class DatasourceProviderService: """ Model Provider Service @@ -93,8 +102,6 @@ class DatasourceProviderService: """ get credential by id """ - current_user, _ = current_account_with_tenant() - with Session(db.engine) as session: if credential_id: datasource_provider = ( @@ -111,6 +118,7 @@ class DatasourceProviderService: return {} # refresh the credentials if datasource_provider.expires_at != -1 and (datasource_provider.expires_at - 60) < int(time.time()): + current_user = get_current_user() decrypted_credentials = self.decrypt_datasource_provider_credentials( tenant_id=tenant_id, datasource_provider=datasource_provider, @@ -159,8 +167,6 @@ class DatasourceProviderService: """ get all datasource credentials by provider """ - current_user, _ = current_account_with_tenant() - with Session(db.engine) as session: datasource_providers = ( session.query(DatasourceProvider) @@ -170,6 +176,7 @@ class DatasourceProviderService: ) if not datasource_providers: return [] + current_user = get_current_user() # refresh the credentials real_credentials_list = [] for datasource_provider in datasource_providers: @@ -608,7 +615,6 @@ class DatasourceProviderService: """ provider_name = provider_id.provider_name plugin_id = provider_id.plugin_id - current_user, _ = current_account_with_tenant() with Session(db.engine) as session: lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}" @@ -630,6 +636,7 @@ class DatasourceProviderService: raise 
ValueError("Authorization name is already exists") try: + current_user = get_current_user() self.provider_manager.validate_provider_credentials( tenant_id=tenant_id, user_id=current_user.id, @@ -907,7 +914,6 @@ class DatasourceProviderService: """ update datasource credentials. """ - current_user, _ = current_account_with_tenant() with Session(db.engine) as session: datasource_provider = ( @@ -944,6 +950,7 @@ class DatasourceProviderService: for key, value in credentials.items() } try: + current_user = get_current_user() self.provider_manager.validate_provider_credentials( tenant_id=tenant_id, user_id=current_user.id, From 58524d6d2b93c11c6a9a779624d58268b290e7af Mon Sep 17 00:00:00 2001 From: Guangdong Liu Date: Fri, 17 Oct 2025 09:11:03 +0800 Subject: [PATCH 15/46] fix: remove unnecessary properties from condition draft (#27009) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../nodes/if-else/components/condition-list/condition-item.tsx | 3 +++ .../nodes/loop/components/condition-list/condition-item.tsx | 3 +++ 2 files changed, 6 insertions(+) diff --git a/web/app/components/workflow/nodes/if-else/components/condition-list/condition-item.tsx b/web/app/components/workflow/nodes/if-else/components/condition-list/condition-item.tsx index 1a76622c57..252c9a7d77 100644 --- a/web/app/components/workflow/nodes/if-else/components/condition-list/condition-item.tsx +++ b/web/app/components/workflow/nodes/if-else/components/condition-list/condition-item.tsx @@ -234,6 +234,9 @@ const ConditionItem = ({ draft.varType = resolvedVarType draft.value = resolvedVarType === VarType.boolean ? false : '' draft.comparison_operator = getOperators(resolvedVarType)[0] + delete draft.key + delete draft.sub_variable_condition + delete draft.numberVarType setTimeout(() => setControlPromptEditorRerenderKey(Date.now())) }) doUpdateCondition(newCondition) diff --git a/web/app/components/workflow/nodes/loop/components/condition-list/condition-item.tsx b/web/app/components/workflow/nodes/loop/components/condition-list/condition-item.tsx index 6e573093b7..4de07cc3da 100644 --- a/web/app/components/workflow/nodes/loop/components/condition-list/condition-item.tsx +++ b/web/app/components/workflow/nodes/loop/components/condition-list/condition-item.tsx @@ -196,6 +196,9 @@ const ConditionItem = ({ draft.varType = varItem.type draft.value = '' draft.comparison_operator = getOperators(varItem.type)[0] + delete draft.key + delete draft.sub_variable_condition + delete draft.numberVarType }) doUpdateCondition(newCondition) setOpen(false) From 9d5300440c43d78152a54b7e91570b129bb5bcd4 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 17 Oct 2025 09:11:48 +0800 Subject: [PATCH 16/46] Restore coverage for skipped workflow tests (#27018) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../test_complex_branch_workflow.py | 62 +++++++++---------- .../core/workflow/nodes/llm/test_node.py | 17 +---- .../services/auth/test_auth_integration.py | 11 ++-- 3 files changed, 38 insertions(+), 52 deletions(-) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py index fc38393e75..96926797ec 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py @@ -7,14 +7,11 @@ This test suite validates the behavior of a 
workflow that: 3. Handles multiple answer nodes with different outputs """ -import pytest - from core.workflow.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, - NodeRunSucceededEvent, ) from .test_mock_config import MockConfigBuilder @@ -29,7 +26,6 @@ class TestComplexBranchWorkflow: self.runner = TableTestRunner() self.fixture_path = "test_complex_branch" - @pytest.mark.skip(reason="output in this workflow can be random") def test_hello_branch_with_llm(self): """ Test when query contains 'hello' - should trigger true branch. @@ -41,42 +37,17 @@ class TestComplexBranchWorkflow: fixture_path=self.fixture_path, query="hello world", expected_outputs={ - "answer": f"{mock_text_1}contains 'hello'", + "answer": f"contains 'hello'{mock_text_1}", }, description="Basic hello case with parallel LLM execution", use_auto_mock=True, mock_config=(MockConfigBuilder().with_node_output("1755502777322", {"text": mock_text_1}).build()), - expected_event_sequence=[ - GraphRunStartedEvent, - # Start - NodeRunStartedEvent, - NodeRunSucceededEvent, - # If/Else (no streaming) - NodeRunStartedEvent, - NodeRunSucceededEvent, - # LLM (with streaming) - NodeRunStartedEvent, - ] - # LLM - + [NodeRunStreamChunkEvent] * (mock_text_1.count(" ") + 2) - + [ - # Answer's text - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - # Answer - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Answer 2 - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], ), WorkflowTestCase( fixture_path=self.fixture_path, query="say hello to everyone", expected_outputs={ - "answer": "Mocked response for greetingcontains 'hello'", + "answer": "contains 'hello'Mocked response for greeting", }, description="Hello in middle of sentence", use_auto_mock=True, @@ -93,6 +64,35 @@ class TestComplexBranchWorkflow: for result in suite_result.results: assert result.success, f"Test '{result.test_case.description}' failed: {result.error}" assert result.actual_outputs + assert any(isinstance(event, GraphRunStartedEvent) for event in result.events) + assert any(isinstance(event, GraphRunSucceededEvent) for event in result.events) + + start_index = next( + idx for idx, event in enumerate(result.events) if isinstance(event, GraphRunStartedEvent) + ) + success_index = max( + idx for idx, event in enumerate(result.events) if isinstance(event, GraphRunSucceededEvent) + ) + assert start_index < success_index + + started_node_ids = {event.node_id for event in result.events if isinstance(event, NodeRunStartedEvent)} + assert {"1755502773326", "1755502777322"}.issubset(started_node_ids), ( + f"Branch or LLM nodes missing in events: {started_node_ids}" + ) + + assert any(isinstance(event, NodeRunStreamChunkEvent) for event in result.events), ( + "Expected streaming chunks from LLM execution" + ) + + llm_start_index = next( + idx + for idx, event in enumerate(result.events) + if isinstance(event, NodeRunStartedEvent) and event.node_id == "1755502777322" + ) + assert any( + idx > llm_start_index and isinstance(event, NodeRunStreamChunkEvent) + for idx, event in enumerate(result.events) + ), "Streaming chunks should follow LLM node start" def test_non_hello_branch_with_llm(self): """ diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 61ce640edd..94c638bb0f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -21,7 
+21,6 @@ from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool -from core.workflow.graph import Graph from core.workflow.nodes.llm import llm_utils from core.workflow.nodes.llm.entities import ( ContextConfig, @@ -83,14 +82,6 @@ def graph_init_params() -> GraphInitParams: ) -@pytest.fixture -def graph() -> Graph: - # TODO: This fixture uses old Graph constructor parameters that are incompatible - # with the new queue-based engine. Need to rewrite for new engine architecture. - pytest.skip("Graph fixture incompatible with new queue-based engine - needs rewrite for ResponseStreamCoordinator") - return Graph() - - @pytest.fixture def graph_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( @@ -105,7 +96,7 @@ def graph_runtime_state() -> GraphRuntimeState: @pytest.fixture def llm_node( - llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState + llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState ) -> LLMNode: mock_file_saver = mock.MagicMock(spec=LLMFileSaver) node_config = { @@ -493,9 +484,7 @@ def test_handle_list_messages_basic(llm_node): @pytest.fixture -def llm_node_for_multimodal( - llm_node_data, graph_init_params, graph, graph_runtime_state -) -> tuple[LLMNode, LLMFileSaver]: +def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_state) -> tuple[LLMNode, LLMFileSaver]: mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver) node_config = { "id": "1", @@ -655,7 +644,7 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown: gen = llm_node._save_multimodal_output_and_convert_result_to_markdown( contents=frozenset(["hello world"]), file_saver=mock_file_saver, file_outputs=[] ) - assert list(gen) == ["frozenset({'hello world'})"] + assert list(gen) == ["hello world"] mock_file_saver.save_binary_string.assert_not_called() mock_file_saver.save_remote_url.assert_not_called() diff --git a/api/tests/unit_tests/services/auth/test_auth_integration.py b/api/tests/unit_tests/services/auth/test_auth_integration.py index acfc5cc526..3832a0b8b2 100644 --- a/api/tests/unit_tests/services/auth/test_auth_integration.py +++ b/api/tests/unit_tests/services/auth/test_auth_integration.py @@ -181,14 +181,11 @@ class TestAuthIntegration: ) def test_all_providers_factory_creation(self, provider, credentials): """Test factory creation for all supported providers""" - try: - auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider) - assert auth_class is not None + auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider) + assert auth_class is not None - factory = ApiKeyAuthFactory(provider, credentials) - assert factory.auth is not None - except ImportError: - pytest.skip(f"Provider {provider} not implemented yet") + factory = ApiKeyAuthFactory(provider, credentials) + assert factory.auth is not None def _create_success_response(self, status_code=200): """Create successful HTTP response mock""" From 8cafc2009856c95265daf4d0656e9a7c30257673 Mon Sep 17 00:00:00 2001 From: GuanMu Date: Fri, 17 Oct 2025 10:46:43 +0800 Subject: [PATCH 17/46] Fix type error (#27024) --- .../workflow/nodes/parameter-extractor/use-config.ts | 2 +- 
.../workflow/nodes/question-classifier/use-config.ts | 2 +- .../workflow/nodes/template-transform/use-config.ts | 2 +- .../workflow/run/agent-log/agent-log-nav-more.tsx | 6 +++--- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/web/app/components/workflow/nodes/parameter-extractor/use-config.ts b/web/app/components/workflow/nodes/parameter-extractor/use-config.ts index 2caae83f2a..d3699853c2 100644 --- a/web/app/components/workflow/nodes/parameter-extractor/use-config.ts +++ b/web/app/components/workflow/nodes/parameter-extractor/use-config.ts @@ -27,7 +27,7 @@ const useConfig = (id: string, payload: ParameterExtractorNodeType) => { const { handleOutVarRenameChange } = useWorkflow() const isChatMode = useIsChatMode() - const defaultConfig = useStore(s => s.nodesDefaultConfigs)[payload.type] + const defaultConfig = useStore(s => s.nodesDefaultConfigs)?.[payload.type] const [defaultRolePrefix, setDefaultRolePrefix] = useState<{ user: string; assistant: string }>({ user: '', assistant: '' }) const { inputs, setInputs: doSetInputs } = useNodeCrud(id, payload) diff --git a/web/app/components/workflow/nodes/question-classifier/use-config.ts b/web/app/components/workflow/nodes/question-classifier/use-config.ts index b4907641b5..5106f373a8 100644 --- a/web/app/components/workflow/nodes/question-classifier/use-config.ts +++ b/web/app/components/workflow/nodes/question-classifier/use-config.ts @@ -20,7 +20,7 @@ const useConfig = (id: string, payload: QuestionClassifierNodeType) => { const updateNodeInternals = useUpdateNodeInternals() const { nodesReadOnly: readOnly } = useNodesReadOnly() const isChatMode = useIsChatMode() - const defaultConfig = useStore(s => s.nodesDefaultConfigs)[payload.type] + const defaultConfig = useStore(s => s.nodesDefaultConfigs)?.[payload.type] const { getBeforeNodesInSameBranch } = useWorkflow() const startNode = getBeforeNodesInSameBranch(id).find(node => node.data.type === BlockEnum.Start) const startNodeId = startNode?.id diff --git a/web/app/components/workflow/nodes/template-transform/use-config.ts b/web/app/components/workflow/nodes/template-transform/use-config.ts index fa7eb81baf..dc012a3844 100644 --- a/web/app/components/workflow/nodes/template-transform/use-config.ts +++ b/web/app/components/workflow/nodes/template-transform/use-config.ts @@ -13,7 +13,7 @@ import useAvailableVarList from '@/app/components/workflow/nodes/_base/hooks/use const useConfig = (id: string, payload: TemplateTransformNodeType) => { const { nodesReadOnly: readOnly } = useNodesReadOnly() - const defaultConfig = useStore(s => s.nodesDefaultConfigs)[payload.type] + const defaultConfig = useStore(s => s.nodesDefaultConfigs)?.[payload.type] const { inputs, setInputs: doSetInputs } = useNodeCrud(id, payload) const inputsRef = useRef(inputs) diff --git a/web/app/components/workflow/run/agent-log/agent-log-nav-more.tsx b/web/app/components/workflow/run/agent-log/agent-log-nav-more.tsx index 9f14aa1210..6062946ede 100644 --- a/web/app/components/workflow/run/agent-log/agent-log-nav-more.tsx +++ b/web/app/components/workflow/run/agent-log/agent-log-nav-more.tsx @@ -9,7 +9,7 @@ import Button from '@/app/components/base/button' import type { AgentLogItemWithChildren } from '@/types/workflow' type AgentLogNavMoreProps = { - options: { id: string; label: string }[] + options: AgentLogItemWithChildren[] onShowAgentOrToolLog: (detail?: AgentLogItemWithChildren) => void } const AgentLogNavMore = ({ @@ -41,10 +41,10 @@ const AgentLogNavMore = ({ { options.map(option => (
{ - onShowAgentOrToolLog(option as AgentLogItemWithChildren) + onShowAgentOrToolLog(option) setOpen(false) }} > From 91bb8ae4d207b87b3962cf9bd5450ccb91f2bb40 Mon Sep 17 00:00:00 2001 From: Joel Date: Fri, 17 Oct 2025 13:42:56 +0800 Subject: [PATCH 18/46] fix: happy-dom security issues (#27037) --- web/package.json | 2 +- web/pnpm-lock.yaml | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/web/package.json b/web/package.json index 7c2f30aa61..1721e54d73 100644 --- a/web/package.json +++ b/web/package.json @@ -144,7 +144,7 @@ "@babel/core": "^7.28.3", "@chromatic-com/storybook": "^3.1.0", "@eslint-react/eslint-plugin": "^1.15.0", - "@happy-dom/jest-environment": "^20.0.0", + "@happy-dom/jest-environment": "^20.0.2", "@mdx-js/loader": "^3.1.0", "@mdx-js/react": "^3.1.0", "@next/bundle-analyzer": "15.5.4", diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index 28758f1142..4f75b6e93e 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -345,8 +345,8 @@ importers: specifier: ^1.15.0 version: 1.52.3(eslint@9.35.0(jiti@2.6.1))(ts-api-utils@2.1.0(typescript@5.8.3))(typescript@5.8.3) '@happy-dom/jest-environment': - specifier: ^20.0.0 - version: 20.0.0(@jest/environment@29.7.0)(@jest/fake-timers@29.7.0)(@jest/types@29.6.3)(jest-mock@29.7.0)(jest-util@29.7.0) + specifier: ^20.0.2 + version: 20.0.4(@jest/environment@29.7.0)(@jest/fake-timers@29.7.0)(@jest/types@29.6.3)(jest-mock@29.7.0)(jest-util@29.7.0) '@mdx-js/loader': specifier: ^3.1.0 version: 3.1.0(acorn@8.15.0)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)) @@ -1644,8 +1644,8 @@ packages: '@formatjs/intl-localematcher@0.5.10': resolution: {integrity: sha512-af3qATX+m4Rnd9+wHcjJ4w2ijq+rAVP3CCinJQvFv1kgSu1W6jypUmvleJxcewdxmutM8dmIRZFxO/IQBZmP2Q==} - '@happy-dom/jest-environment@20.0.0': - resolution: {integrity: sha512-dUyMDNJzPDFopSDyzKdbeYs8z9B4jLj9kXnru8TjYdGeLsQKf+6r0lq/9T2XVcu04QFxXMykt64A+KjsaJTaNA==} + '@happy-dom/jest-environment@20.0.4': + resolution: {integrity: sha512-75OcYtjO+jqxWiYiXvbwR8JZITX1/8iAjRSRpZ/rNjO6UnYebwX6HdI91Ix09xYZEO1X/xOof6HX1EiZnrgnXA==} engines: {node: '>=20.0.0'} peerDependencies: '@jest/environment': '>=25.0.0' @@ -5575,8 +5575,8 @@ packages: hachure-fill@0.5.2: resolution: {integrity: sha512-3GKBOn+m2LX9iq+JC1064cSFprJY4jL1jCXTcpnfER5HYE2l/4EfWSGzkPa/ZDBmYI0ZOEj5VHV/eKnPGkHuOg==} - happy-dom@20.0.0: - resolution: {integrity: sha512-GkWnwIFxVGCf2raNrxImLo397RdGhLapj5cT3R2PT7FwL62Ze1DROhzmYW7+J3p9105DYMVenEejEbnq5wA37w==} + happy-dom@20.0.4: + resolution: {integrity: sha512-WxFtvnij6G64/MtMimnZhF0nKx3LUQKc20zjATD6tKiqOykUwQkd+2FW/DZBAFNjk4oWh0xdv/HBleGJmSY/Iw==} engines: {node: '>=20.0.0'} has-flag@4.0.0: @@ -10132,12 +10132,12 @@ snapshots: dependencies: tslib: 2.8.1 - '@happy-dom/jest-environment@20.0.0(@jest/environment@29.7.0)(@jest/fake-timers@29.7.0)(@jest/types@29.6.3)(jest-mock@29.7.0)(jest-util@29.7.0)': + '@happy-dom/jest-environment@20.0.4(@jest/environment@29.7.0)(@jest/fake-timers@29.7.0)(@jest/types@29.6.3)(jest-mock@29.7.0)(jest-util@29.7.0)': dependencies: '@jest/environment': 29.7.0 '@jest/fake-timers': 29.7.0 '@jest/types': 29.6.3 - happy-dom: 20.0.0 + happy-dom: 20.0.4 jest-mock: 29.7.0 jest-util: 29.7.0 @@ -14787,7 +14787,7 @@ snapshots: hachure-fill@0.5.2: {} - happy-dom@20.0.0: + happy-dom@20.0.4: dependencies: '@types/node': 20.19.20 '@types/whatwg-mimetype': 3.0.2 From 531a0b755a07bf4df30849120efde4ac4ec1f72f Mon Sep 17 00:00:00 2001 From: NFish Date: Fri, 17 Oct 2025 14:03:34 +0800 Subject: [PATCH 19/46] fix: show 'Invalid email or 
password' error tip when web app login failed (#27034) --- .../webapp-signin/components/mail-and-password-auth.tsx | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx index 2201b28a2f..2b6bd73df0 100644 --- a/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx @@ -100,7 +100,10 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut }) } } - + catch (e: any) { + if (e.code === 'authentication_failed') + Toast.notify({ type: 'error', message: e.message }) + } finally { setIsLoading(false) } From 6517323addd64507a69bd10568be356bcad032b5 Mon Sep 17 00:00:00 2001 From: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com> Date: Fri, 17 Oct 2025 14:29:56 +0800 Subject: [PATCH 20/46] Feature: add test containers based tests for mail register tasks (#27040) Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com> --- .../tasks/test_mail_register_task.py | 134 ++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py new file mode 100644 index 0000000000..e4db14623d --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py @@ -0,0 +1,134 @@ +""" +TestContainers-based integration tests for mail_register_task.py + +This module provides integration tests for email registration tasks +using TestContainers to ensure real database and service interactions. 
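+The mail client and the email i18n service are both mocked, so these
+tests exercise task behavior without sending real mail.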
+""" + +from unittest.mock import MagicMock, patch + +import pytest +from faker import Faker + +from libs.email_i18n import EmailType +from tasks.mail_register_task import send_email_register_mail_task, send_email_register_mail_task_when_account_exist + + +class TestMailRegisterTask: + """Integration tests for mail_register_task using testcontainers.""" + + @pytest.fixture + def mock_mail_dependencies(self): + """Mock setup for mail service dependencies.""" + with ( + patch("tasks.mail_register_task.mail") as mock_mail, + patch("tasks.mail_register_task.get_email_i18n_service") as mock_get_email_service, + ): + # Setup mock mail service + mock_mail.is_inited.return_value = True + + # Setup mock email i18n service + mock_email_service = MagicMock() + mock_get_email_service.return_value = mock_email_service + + yield { + "mail": mock_mail, + "email_service": mock_email_service, + "get_email_service": mock_get_email_service, + } + + def test_send_email_register_mail_task_success(self, db_session_with_containers, mock_mail_dependencies): + """Test successful email registration mail sending.""" + fake = Faker() + language = "en-US" + to_email = fake.email() + code = fake.numerify("######") + + send_email_register_mail_task(language=language, to=to_email, code=code) + + mock_mail_dependencies["mail"].is_inited.assert_called_once() + mock_mail_dependencies["email_service"].send_email.assert_called_once_with( + email_type=EmailType.EMAIL_REGISTER, + language_code=language, + to=to_email, + template_context={ + "to": to_email, + "code": code, + }, + ) + + def test_send_email_register_mail_task_mail_not_initialized( + self, db_session_with_containers, mock_mail_dependencies + ): + """Test email registration task when mail service is not initialized.""" + mock_mail_dependencies["mail"].is_inited.return_value = False + + send_email_register_mail_task(language="en-US", to="test@example.com", code="123456") + + mock_mail_dependencies["get_email_service"].assert_not_called() + mock_mail_dependencies["email_service"].send_email.assert_not_called() + + def test_send_email_register_mail_task_exception_handling(self, db_session_with_containers, mock_mail_dependencies): + """Test email registration task exception handling.""" + mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error") + + fake = Faker() + to_email = fake.email() + code = fake.numerify("######") + + with patch("tasks.mail_register_task.logger") as mock_logger: + send_email_register_mail_task(language="en-US", to=to_email, code=code) + mock_logger.exception.assert_called_once_with("Send email register mail to %s failed", to_email) + + def test_send_email_register_mail_task_when_account_exist_success( + self, db_session_with_containers, mock_mail_dependencies + ): + """Test successful email registration mail sending when account exists.""" + fake = Faker() + language = "en-US" + to_email = fake.email() + account_name = fake.name() + + with patch("tasks.mail_register_task.dify_config") as mock_config: + mock_config.CONSOLE_WEB_URL = "https://console.dify.ai" + + send_email_register_mail_task_when_account_exist(language=language, to=to_email, account_name=account_name) + + mock_mail_dependencies["email_service"].send_email.assert_called_once_with( + email_type=EmailType.EMAIL_REGISTER_WHEN_ACCOUNT_EXIST, + language_code=language, + to=to_email, + template_context={ + "to": to_email, + "login_url": "https://console.dify.ai/signin", + "reset_password_url": "https://console.dify.ai/reset-password", + 
"account_name": account_name, + }, + ) + + def test_send_email_register_mail_task_when_account_exist_mail_not_initialized( + self, db_session_with_containers, mock_mail_dependencies + ): + """Test account exist email task when mail service is not initialized.""" + mock_mail_dependencies["mail"].is_inited.return_value = False + + send_email_register_mail_task_when_account_exist( + language="en-US", to="test@example.com", account_name="Test User" + ) + + mock_mail_dependencies["get_email_service"].assert_not_called() + mock_mail_dependencies["email_service"].send_email.assert_not_called() + + def test_send_email_register_mail_task_when_account_exist_exception_handling( + self, db_session_with_containers, mock_mail_dependencies + ): + """Test account exist email task exception handling.""" + mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error") + + fake = Faker() + to_email = fake.email() + account_name = fake.name() + + with patch("tasks.mail_register_task.logger") as mock_logger: + send_email_register_mail_task_when_account_exist(language="en-US", to=to_email, account_name=account_name) + mock_logger.exception.assert_called_once_with("Send email register mail to %s failed", to_email) From 4f7cb7cd2a1fdec3b994617caac1d47b6b88e182 Mon Sep 17 00:00:00 2001 From: GuanMu Date: Fri, 17 Oct 2025 14:42:58 +0800 Subject: [PATCH 21/46] Fix type error (#27044) --- web/app/components/base/textarea/index.tsx | 4 ++-- .../metadata/condition-list/condition-number.tsx | 2 +- .../metadata/condition-list/condition-string.tsx | 2 +- .../workflow/nodes/knowledge-retrieval/types.ts | 4 ++-- .../json-schema-config-modal/json-importer.tsx | 2 +- .../visual-editor/context.tsx | 2 +- web/app/components/workflow/nodes/llm/use-config.ts | 2 +- web/app/components/workflow/nodes/llm/utils.ts | 2 +- .../extract-parameter/import-from-tool.tsx | 12 +++++++++--- 9 files changed, 19 insertions(+), 13 deletions(-) diff --git a/web/app/components/base/textarea/index.tsx b/web/app/components/base/textarea/index.tsx index 63eae48e31..7813eb7209 100644 --- a/web/app/components/base/textarea/index.tsx +++ b/web/app/components/base/textarea/index.tsx @@ -25,8 +25,8 @@ export type TextareaProps = { destructive?: boolean styleCss?: CSSProperties ref?: React.Ref - onFocus?: () => void - onBlur?: () => void + onFocus?: React.FocusEventHandler + onBlur?: React.FocusEventHandler } & React.TextareaHTMLAttributes & VariantProps const Textarea = React.forwardRef( diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-number.tsx b/web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-number.tsx index 7016e8bd2a..6421401a2a 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-number.tsx +++ b/web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-number.tsx @@ -18,7 +18,7 @@ type ConditionNumberProps = { nodesOutputVars: NodeOutPutVar[] availableNodes: Node[] isCommonVariable?: boolean - commonVariables: { name: string, type: string }[] + commonVariables: { name: string; type: string; value: string }[] } & ConditionValueMethodProps const ConditionNumber = ({ value, diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-string.tsx b/web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-string.tsx 
index cf85f1259b..d5cb06e690 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-string.tsx +++ b/web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-string.tsx @@ -18,7 +18,7 @@ type ConditionStringProps = { nodesOutputVars: NodeOutPutVar[] availableNodes: Node[] isCommonVariable?: boolean - commonVariables: { name: string, type: string }[] + commonVariables: { name: string; type: string; value: string }[] } & ConditionValueMethodProps const ConditionString = ({ value, diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/types.ts b/web/app/components/workflow/nodes/knowledge-retrieval/types.ts index 1cae4ecd3b..65f7dc2493 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/types.ts +++ b/web/app/components/workflow/nodes/knowledge-retrieval/types.ts @@ -128,6 +128,6 @@ export type MetadataShape = { availableNumberVars?: NodeOutPutVar[] availableNumberNodesWithParent?: Node[] isCommonVariable?: boolean - availableCommonStringVars?: { name: string; type: string; }[] - availableCommonNumberVars?: { name: string; type: string; }[] + availableCommonStringVars?: { name: string; type: string; value: string }[] + availableCommonNumberVars?: { name: string; type: string; value: string }[] } diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-importer.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-importer.tsx index 6368397b74..463d87d7d1 100644 --- a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-importer.tsx +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-importer.tsx @@ -24,7 +24,7 @@ const JsonImporter: FC = ({ const [open, setOpen] = useState(false) const [json, setJson] = useState('') const [parseError, setParseError] = useState(null) - const importBtnRef = useRef(null) + const importBtnRef = useRef(null) const advancedEditing = useVisualEditorStore(state => state.advancedEditing) const isAddingNewField = useVisualEditorStore(state => state.isAddingNewField) const { emit } = useMittContext() diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/context.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/context.tsx index 5bf4b22f11..268683aec3 100644 --- a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/context.tsx +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/context.tsx @@ -18,7 +18,7 @@ type VisualEditorProviderProps = { export const VisualEditorContext = createContext(null) export const VisualEditorContextProvider = ({ children }: VisualEditorProviderProps) => { - const storeRef = useRef() + const storeRef = useRef(null) if (!storeRef.current) storeRef.current = createVisualEditorStore() diff --git a/web/app/components/workflow/nodes/llm/use-config.ts b/web/app/components/workflow/nodes/llm/use-config.ts index 45635be3f2..d11fb6db28 100644 --- a/web/app/components/workflow/nodes/llm/use-config.ts +++ b/web/app/components/workflow/nodes/llm/use-config.ts @@ -23,7 +23,7 @@ const useConfig = (id: string, payload: LLMNodeType) => { const { nodesReadOnly: readOnly } = useNodesReadOnly() const isChatMode = useIsChatMode() - const defaultConfig = useStore(s => s.nodesDefaultConfigs)[payload.type] + const defaultConfig = useStore(s => 
s.nodesDefaultConfigs)?.[payload.type] const [defaultRolePrefix, setDefaultRolePrefix] = useState<{ user: string; assistant: string }>({ user: '', assistant: '' }) const { inputs, setInputs: doSetInputs } = useNodeCrud(id, payload) const inputRef = useRef(inputs) diff --git a/web/app/components/workflow/nodes/llm/utils.ts b/web/app/components/workflow/nodes/llm/utils.ts index 29591d76ad..10c287f86b 100644 --- a/web/app/components/workflow/nodes/llm/utils.ts +++ b/web/app/components/workflow/nodes/llm/utils.ts @@ -10,7 +10,7 @@ export const checkNodeValid = (_payload: LLMNodeType) => { export const getFieldType = (field: Field) => { const { type, items } = field - if(field.schemaType === 'file') return 'file' + if(field.schemaType === 'file') return Type.file if (type !== Type.array || !items) return type diff --git a/web/app/components/workflow/nodes/parameter-extractor/components/extract-parameter/import-from-tool.tsx b/web/app/components/workflow/nodes/parameter-extractor/components/extract-parameter/import-from-tool.tsx index bfb664aef1..d93d08a0ac 100644 --- a/web/app/components/workflow/nodes/parameter-extractor/components/extract-parameter/import-from-tool.tsx +++ b/web/app/components/workflow/nodes/parameter-extractor/components/extract-parameter/import-from-tool.tsx @@ -9,7 +9,10 @@ import BlockSelector from '../../../../block-selector' import type { Param, ParamType } from '../../types' import cn from '@/utils/classnames' import { useStore } from '@/app/components/workflow/store' -import type { ToolDefaultValue } from '@/app/components/workflow/block-selector/types' +import type { + DataSourceDefaultValue, + ToolDefaultValue, +} from '@/app/components/workflow/block-selector/types' import type { ToolParameter } from '@/app/components/tools/types' import { CollectionType } from '@/app/components/tools/types' import type { BlockEnum } from '@/app/components/workflow/types' @@ -43,8 +46,11 @@ const ImportFromTool: FC = ({ const customTools = useStore(s => s.customTools) const workflowTools = useStore(s => s.workflowTools) - const handleSelectTool = useCallback((_type: BlockEnum, toolInfo?: ToolDefaultValue) => { - const { provider_id, provider_type, tool_name } = toolInfo! 
+ const handleSelectTool = useCallback((_type: BlockEnum, toolInfo?: ToolDefaultValue | DataSourceDefaultValue) => { + if (!toolInfo || 'datasource_name' in toolInfo) + return + + const { provider_id, provider_type, tool_name } = toolInfo const currentTools = (() => { switch (provider_type) { case CollectionType.builtIn: From bfda4ce7e6f39d43a4420e97e23a18edcfe3e3d3 Mon Sep 17 00:00:00 2001 From: 2h0ng <60600792+superboy-zjc@users.noreply.github.com> Date: Thu, 16 Oct 2025 23:58:15 -0700 Subject: [PATCH 22/46] Merge commit from fork --- web/hooks/use-oauth.ts | 2 ++ web/utils/urlValidation.ts | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 web/utils/urlValidation.ts diff --git a/web/hooks/use-oauth.ts b/web/hooks/use-oauth.ts index ae9c1cda66..9f21a476b3 100644 --- a/web/hooks/use-oauth.ts +++ b/web/hooks/use-oauth.ts @@ -1,5 +1,6 @@ 'use client' import { useEffect } from 'react' +import { validateRedirectUrl } from '@/utils/urlValidation' export const useOAuthCallback = () => { useEffect(() => { @@ -18,6 +19,7 @@ export const openOAuthPopup = (url: string, callback: () => void) => { const left = window.screenX + (window.outerWidth - width) / 2 const top = window.screenY + (window.outerHeight - height) / 2 + validateRedirectUrl(url) const popup = window.open( url, 'OAuth', diff --git a/web/utils/urlValidation.ts b/web/utils/urlValidation.ts new file mode 100644 index 0000000000..372dd54cb4 --- /dev/null +++ b/web/utils/urlValidation.ts @@ -0,0 +1,24 @@ +/** + * Validates that a URL is safe for redirection. + * Only allows HTTP and HTTPS protocols to prevent XSS attacks. + * + * @param url - The URL string to validate + * @throws Error if the URL has an unsafe protocol + */ +export function validateRedirectUrl(url: string): void { + try { + const parsedUrl = new URL(url); + if (parsedUrl.protocol !== "http:" && parsedUrl.protocol !== "https:") { + throw new Error("Authorization URL must be HTTP or HTTPS"); + } + } catch (error) { + if ( + error instanceof Error && + error.message === "Authorization URL must be HTTP or HTTPS" + ) { + throw error; + } + // If URL parsing fails, it's also invalid + throw new Error(`Invalid URL: ${url}`); + } +} \ No newline at end of file From 64f55d55a1bb964118b26c248c4530d6cb27e153 Mon Sep 17 00:00:00 2001 From: Wu Tianwei <30284043+WTW0313@users.noreply.github.com> Date: Fri, 17 Oct 2025 14:58:30 +0800 Subject: [PATCH 23/46] fix: update TopK and Score Threshold components to use InputNumber and improve value handling (#27045) --- .../components/base/param-item/top-k-item.tsx | 2 +- .../top-k-and-score-threshold.tsx | 56 +++++++++++++------ 2 files changed, 40 insertions(+), 18 deletions(-) diff --git a/web/app/components/base/param-item/top-k-item.tsx b/web/app/components/base/param-item/top-k-item.tsx index f59c0f273e..0bd552e48d 100644 --- a/web/app/components/base/param-item/top-k-item.tsx +++ b/web/app/components/base/param-item/top-k-item.tsx @@ -32,7 +32,7 @@ const TopKItem: FC = ({ }) => { const { t } = useTranslation() const handleParamChange = (key: string, value: number) => { - let notOutRangeValue = Number.parseFloat(value.toFixed(2)) + let notOutRangeValue = Number.parseInt(value.toFixed(0)) notOutRangeValue = Math.max(VALUE_LIMIT.min, notOutRangeValue) notOutRangeValue = Math.min(VALUE_LIMIT.max, notOutRangeValue) onChange(key, notOutRangeValue) diff --git a/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/top-k-and-score-threshold.tsx 
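+  // Data-source entries are rejected below; this importer only accepts tool selections.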
b/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/top-k-and-score-threshold.tsx index 9d46037d84..9eb1cb39c9 100644 --- a/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/top-k-and-score-threshold.tsx +++ b/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/top-k-and-score-threshold.tsx @@ -1,8 +1,8 @@ -import { memo } from 'react' +import { memo, useCallback } from 'react' import { useTranslation } from 'react-i18next' import Tooltip from '@/app/components/base/tooltip' -import Input from '@/app/components/base/input' import Switch from '@/app/components/base/switch' +import { InputNumber } from '@/app/components/base/input-number' export type TopKAndScoreThresholdProps = { topK: number @@ -14,6 +14,24 @@ export type TopKAndScoreThresholdProps = { readonly?: boolean hiddenScoreThreshold?: boolean } + +const maxTopK = (() => { + const configValue = Number.parseInt(globalThis.document?.body?.getAttribute('data-public-top-k-max-value') || '', 10) + if (configValue && !isNaN(configValue)) + return configValue + return 10 +})() +const TOP_K_VALUE_LIMIT = { + amount: 1, + min: 1, + max: maxTopK, +} +const SCORE_THRESHOLD_VALUE_LIMIT = { + step: 0.01, + min: 0, + max: 1, +} + const TopKAndScoreThreshold = ({ topK, onTopKChange, @@ -25,18 +43,18 @@ const TopKAndScoreThreshold = ({ hiddenScoreThreshold, }: TopKAndScoreThresholdProps) => { const { t } = useTranslation() - const handleTopKChange = (e: React.ChangeEvent) => { - const value = Number(e.target.value) - if (Number.isNaN(value)) - return - onTopKChange?.(value) - } + const handleTopKChange = useCallback((value: number) => { + let notOutRangeValue = Number.parseInt(value.toFixed(0)) + notOutRangeValue = Math.max(TOP_K_VALUE_LIMIT.min, notOutRangeValue) + notOutRangeValue = Math.min(TOP_K_VALUE_LIMIT.max, notOutRangeValue) + onTopKChange?.(notOutRangeValue) + }, [onTopKChange]) - const handleScoreThresholdChange = (e: React.ChangeEvent) => { - const value = Number(e.target.value) - if (Number.isNaN(value)) - return - onScoreThresholdChange?.(value) + const handleScoreThresholdChange = (value: number) => { + let notOutRangeValue = Number.parseFloat(value.toFixed(2)) + notOutRangeValue = Math.max(SCORE_THRESHOLD_VALUE_LIMIT.min, notOutRangeValue) + notOutRangeValue = Math.min(SCORE_THRESHOLD_VALUE_LIMIT.max, notOutRangeValue) + onScoreThresholdChange?.(notOutRangeValue) } return ( @@ -49,11 +67,13 @@ const TopKAndScoreThreshold = ({ popupContent={t('appDebug.datasetConfig.top_kTip')} />
- { @@ -74,11 +94,13 @@ const TopKAndScoreThreshold = ({ popupContent={t('appDebug.datasetConfig.score_thresholdTip')} /> - ) From fea2ffb3ba583bd1a3d734c3a676ab76c5d313d7 Mon Sep 17 00:00:00 2001 From: GuanMu Date: Fri, 17 Oct 2025 17:46:28 +0800 Subject: [PATCH 24/46] fix: improve URL validation logic in validateRedirectUrl function (#27058) --- web/utils/urlValidation.ts | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/web/utils/urlValidation.ts b/web/utils/urlValidation.ts index 372dd54cb4..abc15a1365 100644 --- a/web/utils/urlValidation.ts +++ b/web/utils/urlValidation.ts @@ -7,18 +7,17 @@ */ export function validateRedirectUrl(url: string): void { try { - const parsedUrl = new URL(url); - if (parsedUrl.protocol !== "http:" && parsedUrl.protocol !== "https:") { - throw new Error("Authorization URL must be HTTP or HTTPS"); - } - } catch (error) { - if ( - error instanceof Error && - error.message === "Authorization URL must be HTTP or HTTPS" - ) { - throw error; - } - // If URL parsing fails, it's also invalid - throw new Error(`Invalid URL: ${url}`); + const parsedUrl = new URL(url) + if (parsedUrl.protocol !== 'http:' && parsedUrl.protocol !== 'https:') + throw new Error('Authorization URL must be HTTP or HTTPS') } -} \ No newline at end of file + catch (error) { + if ( + error instanceof Error + && error.message === 'Authorization URL must be HTTP or HTTPS' + ) + throw error + // If URL parsing fails, it's also invalid + throw new Error(`Invalid URL: ${url}`) + } +} From 35e24d4d1465e3016f7a17a4f83cb8add1279132 Mon Sep 17 00:00:00 2001 From: Yongtao Huang Date: Sat, 18 Oct 2025 09:54:52 +0800 Subject: [PATCH 25/46] Chore: remove redundant tenant lookup in APIBasedExtensionAPI.post (#27067) Signed-off-by: Yongtao Huang --- api/controllers/console/extension.py | 1 - 1 file changed, 1 deletion(-) diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index f77996eb6a..e5b7611c44 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -67,7 +67,6 @@ class APIBasedExtensionAPI(Resource): @account_initialization_required @marshal_with(api_based_extension_fields) def post(self): - _, current_tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, location="json") parser.add_argument("api_endpoint", type=str, required=True, location="json") From 598dd1f816f4661cdceff5292d3a603c46b528b3 Mon Sep 17 00:00:00 2001 From: Guangdong Liu Date: Sat, 18 Oct 2025 11:43:24 +0800 Subject: [PATCH 26/46] fix: allow optional config parameter and conditionally include message file ID (#26960) --- api/factories/file_factory.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 735fff53d1..231e805460 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -21,7 +21,7 @@ def build_from_message_files( *, message_files: Sequence["MessageFile"], tenant_id: str, - config: FileUploadConfig, + config: FileUploadConfig | None = None, ) -> Sequence[File]: results = [ build_from_message_file(message_file=file, tenant_id=tenant_id, config=config) @@ -35,15 +35,18 @@ def build_from_message_file( *, message_file: "MessageFile", tenant_id: str, - config: FileUploadConfig, + config: FileUploadConfig | None, ): mapping = { "transfer_method": message_file.transfer_method, "url": message_file.url, - "id": 
message_file.id, "type": message_file.type, } + # Only include id if it exists (message_file has been committed to DB) + if message_file.id: + mapping["id"] = message_file.id + # Set the correct ID field based on transfer method if message_file.transfer_method == FileTransferMethod.TOOL_FILE: mapping["tool_file_id"] = message_file.upload_file_id From e4b5b0e5fd5a7d2fd351a31007810186dd25e377 Mon Sep 17 00:00:00 2001 From: Guangdong Liu Date: Sat, 18 Oct 2025 11:44:11 +0800 Subject: [PATCH 27/46] feat: implement strict type validation for remote file uploads (#27010) --- api/factories/file_factory.py | 33 ++++++++++++----- .../factories/test_build_from_mapping.py | 36 +++++++++++++++++++ 2 files changed, 61 insertions(+), 8 deletions(-) diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 231e805460..2316e45179 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -166,7 +166,10 @@ def _build_from_local_file( if strict_type_validation and detected_file_type.value != specified_type: raise ValueError("Detected file type does not match the specified type. Please verify the file.") - file_type = FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type + if specified_type and specified_type != "custom": + file_type = FileType(specified_type) + else: + file_type = detected_file_type return File( id=mapping.get("id"), @@ -214,9 +217,10 @@ def _build_from_remote_url( if strict_type_validation and specified_type and detected_file_type.value != specified_type: raise ValueError("Detected file type does not match the specified type. Please verify the file.") - file_type = ( - FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type - ) + if specified_type and specified_type != "custom": + file_type = FileType(specified_type) + else: + file_type = detected_file_type return File( id=mapping.get("id"), @@ -238,10 +242,17 @@ def _build_from_remote_url( mime_type, filename, file_size = _get_remote_file_info(url) extension = mimetypes.guess_extension(mime_type) or ("." + filename.split(".")[-1] if "." in filename else ".bin") - file_type = _standardize_file_type(extension=extension, mime_type=mime_type) - if file_type.value != mapping.get("type", "custom"): + detected_file_type = _standardize_file_type(extension=extension, mime_type=mime_type) + specified_type = mapping.get("type") + + if strict_type_validation and specified_type and detected_file_type.value != specified_type: raise ValueError("Detected file type does not match the specified type. Please verify the file.") + if specified_type and specified_type != "custom": + file_type = FileType(specified_type) + else: + file_type = detected_file_type + return File( id=mapping.get("id"), filename=filename, @@ -331,7 +342,10 @@ def _build_from_tool_file( if strict_type_validation and specified_type and detected_file_type.value != specified_type: raise ValueError("Detected file type does not match the specified type. 
Please verify the file.") - file_type = FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type + if specified_type and specified_type != "custom": + file_type = FileType(specified_type) + else: + file_type = detected_file_type return File( id=mapping.get("id"), @@ -376,7 +390,10 @@ def _build_from_datasource_file( if strict_type_validation and specified_type and detected_file_type.value != specified_type: raise ValueError("Detected file type does not match the specified type. Please verify the file.") - file_type = FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type + if specified_type and specified_type != "custom": + file_type = FileType(specified_type) + else: + file_type = detected_file_type return File( id=mapping.get("datasource_file_id"), diff --git a/api/tests/unit_tests/factories/test_build_from_mapping.py b/api/tests/unit_tests/factories/test_build_from_mapping.py index 39280c9267..77c4956c04 100644 --- a/api/tests/unit_tests/factories/test_build_from_mapping.py +++ b/api/tests/unit_tests/factories/test_build_from_mapping.py @@ -150,6 +150,42 @@ def test_build_from_remote_url(mock_http_head): assert file.size == 2048 +@pytest.mark.parametrize( + ("file_type", "should_pass", "expected_error"), + [ + ("image", True, None), + ("document", False, "Detected file type does not match the specified type"), + ("video", False, "Detected file type does not match the specified type"), + ], +) +def test_build_from_remote_url_strict_validation(mock_http_head, file_type, should_pass, expected_error): + """Test strict type validation for remote_url.""" + mapping = { + "transfer_method": "remote_url", + "url": TEST_REMOTE_URL, + "type": file_type, + } + if should_pass: + file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True) + assert file.type == FileType(file_type) + else: + with pytest.raises(ValueError, match=expected_error): + build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True) + + +def test_build_from_remote_url_without_strict_validation(mock_http_head): + """Test that remote_url allows type mismatch when strict_type_validation is False.""" + mapping = { + "transfer_method": "remote_url", + "url": TEST_REMOTE_URL, + "type": "document", + } + file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=False) + assert file.transfer_method == FileTransferMethod.REMOTE_URL + assert file.type == FileType.DOCUMENT + assert file.filename == "remote_test.jpg" + + def test_tool_file_not_found(): """Test ToolFile not found in database.""" with patch("factories.file_factory.db.session.scalar", return_value=None): From 894e38f713f937173bd80b019c2fc65481205fdc Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Sat, 18 Oct 2025 11:47:04 +0800 Subject: [PATCH 28/46] fix: https://github.com/langgenius/dify/issues/27063 (#27074) --- .../edit-annotation-modal/edit-item/index.tsx | 13 +++++-- .../view-annotation-modal/index.tsx | 38 +++++++++++++++---- 2 files changed, 39 insertions(+), 12 deletions(-) diff --git a/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.tsx b/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.tsx index 17cb456558..e808d0b48a 100644 --- a/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.tsx +++ b/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.tsx @@ -1,6 +1,6 @@ 'use client' import 
type { FC } from 'react'
-import React, { useState } from 'react'
+import React, { useEffect, useState } from 'react'
 import { useTranslation } from 'react-i18next'
 import { RiDeleteBinLine, RiEditFill, RiEditLine } from '@remixicon/react'
 import { Robot, User } from '@/app/components/base/icons/src/public/avatar'
@@ -16,7 +16,7 @@ type Props = {
   type: EditItemType
   content: string
   readonly?: boolean
-  onSave: (content: string) => void
+  onSave: (content: string) => Promise<void>
 }
 
 export const EditTitle: FC<{ className?: string; title: string }> = ({ className, title }) => (
@@ -46,8 +46,13 @@ const EditItem: FC<Props> = ({
   const placeholder = type === EditItemType.Query ? t('appAnnotation.editModal.queryPlaceholder') : t('appAnnotation.editModal.answerPlaceholder')
   const [isEdit, setIsEdit] = useState(false)
 
-  const handleSave = () => {
-    onSave(newContent)
+  // Reset newContent when content prop changes
+  useEffect(() => {
+    setNewContent('')
+  }, [content])
+
+  const handleSave = async () => {
+    await onSave(newContent)
     setIsEdit(false)
   }
 
diff --git a/web/app/components/app/annotation/view-annotation-modal/index.tsx b/web/app/components/app/annotation/view-annotation-modal/index.tsx
index 08904d23d4..8426ab0005 100644
--- a/web/app/components/app/annotation/view-annotation-modal/index.tsx
+++ b/web/app/components/app/annotation/view-annotation-modal/index.tsx
@@ -21,7 +21,7 @@ type Props = {
   isShow: boolean
   onHide: () => void
   item: AnnotationItem
-  onSave: (editedQuery: string, editedAnswer: string) => void
+  onSave: (editedQuery: string, editedAnswer: string) => Promise<void>
   onRemove: () => void
 }
 
@@ -46,6 +46,16 @@ const ViewAnnotationModal: FC<Props> = ({
   const [currPage, setCurrPage] = React.useState(0)
   const [total, setTotal] = useState(0)
   const [hitHistoryList, setHitHistoryList] = useState([])
+
+  // Update local state when item prop changes (e.g., when modal is reopened with updated data)
+  useEffect(() => {
+    setNewQuery(question)
+    setNewAnswer(answer)
+    setCurrPage(0)
+    setTotal(0)
+    setHitHistoryList([])
+  }, [question, answer, id])
+
   const fetchHitHistory = async (page = 1) => {
     try {
       const { data, total }: any = await fetchHitHistoryList(appId, id, {
@@ -63,6 +73,12 @@
     fetchHitHistory(currPage + 1)
   }, [currPage])
 
+  // Fetch hit history when item changes
+  useEffect(() => {
+    if (isShow && id)
+      fetchHitHistory(1)
+  }, [id, isShow])
+
   const tabs = [
     { value: TabType.annotation, text: t('appAnnotation.viewModal.annotatedResponse') },
     {
@@ -82,14 +98,20 @@
     },
   ]
   const [activeTab, setActiveTab] = useState(TabType.annotation)
-  const handleSave = (type: EditItemType, editedContent: string) => {
-    if (type === EditItemType.Query) {
-      setNewQuery(editedContent)
-      onSave(editedContent, newAnswer)
+  const handleSave = async (type: EditItemType, editedContent: string) => {
+    try {
+      if (type === EditItemType.Query) {
+        await onSave(editedContent, newAnswer)
+        setNewQuery(editedContent)
+      }
+      else {
+        await onSave(newQuestion, editedContent)
+        setNewAnswer(editedContent)
+      }
     }
-    else {
-      setNewAnswer(editedContent)
-      onSave(newQuestion, editedContent)
+    catch (error) {
+      // If save fails, don't update local state
+      console.error('Failed to save annotation:', error)
     }
   }
   const [showModal, setShowModal] = useState(false)

From 5937a66e226fc595d5463fe47363f488302eab10 Mon Sep 17 00:00:00 2001
From: Eric Guo
Date: Sat, 18 Oct 2025 11:49:20 +0800
Subject: [PATCH 29/46] Sync same logic for datasets.
 (#27056)
---
 api/controllers/console/datasets/datasets.py | 1 -
 api/controllers/console/datasets/external.py | 1 -
 .../console/datasets/rag_pipeline/rag_pipeline_workflow.py | 3 +--
 3 files changed, 1 insertion(+), 4 deletions(-)

diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index c03767d2e6..4a9e0789fb 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -468,7 +468,6 @@ class DatasetApi(Resource):
         dataset_id_str = str(dataset_id)
         current_user, _ = current_account_with_tenant()
 
-        # The role of the current user in the ta table must be admin, owner, or editor
         if not (current_user.has_edit_permission or current_user.is_dataset_operator):
             raise Forbidden()
 
diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py
index f590919180..1ebd7101e4 100644
--- a/api/controllers/console/datasets/external.py
+++ b/api/controllers/console/datasets/external.py
@@ -150,7 +150,6 @@ class ExternalApiTemplateApi(Resource):
         current_user, current_tenant_id = current_account_with_tenant()
         external_knowledge_api_id = str(external_knowledge_api_id)
 
-        # The role of the current user in the ta table must be admin, owner, or editor
         if not (current_user.has_edit_permission or current_user.is_dataset_operator):
             raise Forbidden()
 
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index d4d6da7fe2..77dcf30a78 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -937,11 +937,10 @@ class RagPipelineTransformApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @edit_permission_required
     def post(self, dataset_id: str):
         current_user, _ = current_account_with_tenant()
 
-        if not current_user.is_dataset_operator:
+        if not (current_user.has_edit_permission or current_user.is_dataset_operator):
             raise Forbidden()
 
         dataset_id = str(dataset_id)

From 830f891a7475a53d4a55bd84613697994c6b890d Mon Sep 17 00:00:00 2001
From: Amy <1530140574@qq.com>
Date: Sat, 18 Oct 2025 11:58:40 +0800
Subject: [PATCH 30/46] Fix json in md when using question classifier node
 (#26992)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
---
 .../question_classifier_node.py | 3 +++
 api/libs/json_in_md_parser.py | 14 +++++++++----
 .../unit_tests/libs/test_json_in_md_parser.py | 21 +++++++++++++++++++
 3 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
index 592a6566fd..31b1cd4966 100644
--- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py
+++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
@@ -1,4 +1,5 @@
 import json
+import re
 from collections.abc import Mapping, Sequence
 from typing import TYPE_CHECKING, Any
 
@@ -194,6 +195,8 @@ class QuestionClassifierNode(Node):
             category_name = node_data.classes[0].name
             category_id = node_data.classes[0].id
 
+            if "</think>" in result_text:
+                result_text = re.sub(r"<think[^>]*>[\s\S]*?</think>", "", result_text, flags=re.IGNORECASE)
             result_text_json = parse_and_check_json_markdown(result_text, [])
             # result_text_json =
json.loads(result_text.strip('```JSON\n')) if "category_name" in result_text_json and "category_id" in result_text_json: diff --git a/api/libs/json_in_md_parser.py b/api/libs/json_in_md_parser.py index 0c642041bf..310e677747 100644 --- a/api/libs/json_in_md_parser.py +++ b/api/libs/json_in_md_parser.py @@ -6,22 +6,22 @@ from core.llm_generator.output_parser.errors import OutputParserError def parse_json_markdown(json_string: str): # Get json from the backticks/braces json_string = json_string.strip() - starts = ["```json", "```", "``", "`", "{"] - ends = ["```", "``", "`", "}"] + starts = ["```json", "```", "``", "`", "{", "["] + ends = ["```", "``", "`", "}", "]"] end_index = -1 start_index = 0 parsed: dict = {} for s in starts: start_index = json_string.find(s) if start_index != -1: - if json_string[start_index] != "{": + if json_string[start_index] not in ("{", "["): start_index += len(s) break if start_index != -1: for e in ends: end_index = json_string.rfind(e, start_index) if end_index != -1: - if json_string[end_index] == "}": + if json_string[end_index] in ("}", "]"): end_index += 1 break if start_index != -1 and end_index != -1 and start_index < end_index: @@ -38,6 +38,12 @@ def parse_and_check_json_markdown(text: str, expected_keys: list[str]): json_obj = parse_json_markdown(text) except json.JSONDecodeError as e: raise OutputParserError(f"got invalid json object. error: {e}") + + if isinstance(json_obj, list): + if len(json_obj) == 1 and isinstance(json_obj[0], dict): + json_obj = json_obj[0] + else: + raise OutputParserError(f"got invalid return object. obj:{json_obj}") for key in expected_keys: if key not in json_obj: raise OutputParserError( diff --git a/api/tests/unit_tests/libs/test_json_in_md_parser.py b/api/tests/unit_tests/libs/test_json_in_md_parser.py index 53fd0bea16..953f203e89 100644 --- a/api/tests/unit_tests/libs/test_json_in_md_parser.py +++ b/api/tests/unit_tests/libs/test_json_in_md_parser.py @@ -86,3 +86,24 @@ def test_parse_and_check_json_markdown_multiple_blocks_fails(): # opening fence to the last closing fence, causing JSON decode failure. 
with pytest.raises(OutputParserError): parse_and_check_json_markdown(src, []) + + +def test_parse_and_check_json_markdown_handles_think_fenced_and_raw_variants(): + expected = {"keywords": ["2"], "category_id": "2", "category_name": "2"} + cases = [ + """ + ```json + [{"keywords": ["2"], "category_id": "2", "category_name": "2"}] + ```, error: Expecting value: line 1 column 1 (char 0) + """, + """ + ```json + {"keywords": ["2"], "category_id": "2", "category_name": "2"} + ```, error: Extra data: line 2 column 5 (char 66) + """, + '{"keywords": ["2"], "category_id": "2", "category_name": "2"}', + '[{"keywords": ["2"], "category_id": "2", "category_name": "2"}]', + ] + for src in cases: + obj = parse_and_check_json_markdown(src, ["keywords", "category_id", "category_name"]) + assert obj == expected From 1a379897697b2877ca5130e6480f0d4b3b1f83bd Mon Sep 17 00:00:00 2001 From: GuanMu Date: Sat, 18 Oct 2025 12:03:40 +0800 Subject: [PATCH 31/46] Fix type-check error (#27051) --- .../workflow/hooks/use-nodes-interactions.ts | 4 ++-- .../nodes/_base/components/input-var-type-icon.tsx | 1 - .../nodes/_base/components/node-handle.tsx | 6 +++--- .../components/variable/variable-label/hooks.ts | 14 +++++++------- web/app/components/workflow/nodes/_base/node.tsx | 7 ++++++- web/app/components/workflow/nodes/code/types.ts | 5 +++++ .../components/workflow/nodes/http/use-config.ts | 2 +- .../components/condition-list/condition-item.tsx | 6 +++--- 8 files changed, 27 insertions(+), 18 deletions(-) diff --git a/web/app/components/workflow/hooks/use-nodes-interactions.ts b/web/app/components/workflow/hooks/use-nodes-interactions.ts index c721442d86..afd13a73fb 100644 --- a/web/app/components/workflow/hooks/use-nodes-interactions.ts +++ b/web/app/components/workflow/hooks/use-nodes-interactions.ts @@ -16,7 +16,7 @@ import { useReactFlow, useStoreApi, } from 'reactflow' -import type { ToolDefaultValue } from '../block-selector/types' +import type { DataSourceDefaultValue, ToolDefaultValue } from '../block-selector/types' import type { Edge, Node, OnNodeAdd } from '../types' import { BlockEnum } from '../types' import { useWorkflowStore } from '../store' @@ -1286,7 +1286,7 @@ export const useNodesInteractions = () => { currentNodeId: string, nodeType: BlockEnum, sourceHandle: string, - toolDefaultValue?: ToolDefaultValue, + toolDefaultValue?: ToolDefaultValue | DataSourceDefaultValue, ) => { if (getNodesReadOnly()) return diff --git a/web/app/components/workflow/nodes/_base/components/input-var-type-icon.tsx b/web/app/components/workflow/nodes/_base/components/input-var-type-icon.tsx index 566528b5c2..70fd1051b9 100644 --- a/web/app/components/workflow/nodes/_base/components/input-var-type-icon.tsx +++ b/web/app/components/workflow/nodes/_base/components/input-var-type-icon.tsx @@ -28,7 +28,6 @@ const getIcon = (type: InputVarType) => { [InputVarType.jsonObject]: RiBracesLine, [InputVarType.singleFile]: RiFileList2Line, [InputVarType.multiFiles]: RiFileCopy2Line, - [InputVarType.checkbox]: RiCheckboxLine, } as any)[type] || RiTextSnippet } diff --git a/web/app/components/workflow/nodes/_base/components/node-handle.tsx b/web/app/components/workflow/nodes/_base/components/node-handle.tsx index d1d79a0faa..90968a4580 100644 --- a/web/app/components/workflow/nodes/_base/components/node-handle.tsx +++ b/web/app/components/workflow/nodes/_base/components/node-handle.tsx @@ -16,7 +16,7 @@ import { } from '../../../types' import type { Node } from '../../../types' import BlockSelector from '../../../block-selector' 
-import type { ToolDefaultValue } from '../../../block-selector/types' +import type { DataSourceDefaultValue, ToolDefaultValue } from '../../../block-selector/types' import { useAvailableBlocks, useIsChatMode, @@ -57,7 +57,7 @@ export const NodeTargetHandle = memo(({ if (!connected) setOpen(v => !v) }, [connected]) - const handleSelect = useCallback((type: BlockEnum, toolDefaultValue?: ToolDefaultValue) => { + const handleSelect = useCallback((type: BlockEnum, toolDefaultValue?: ToolDefaultValue | DataSourceDefaultValue) => { handleNodeAdd( { nodeType: type, @@ -140,7 +140,7 @@ export const NodeSourceHandle = memo(({ e.stopPropagation() setOpen(v => !v) }, []) - const handleSelect = useCallback((type: BlockEnum, toolDefaultValue?: ToolDefaultValue) => { + const handleSelect = useCallback((type: BlockEnum, toolDefaultValue?: ToolDefaultValue | DataSourceDefaultValue) => { handleNodeAdd( { nodeType: type, diff --git a/web/app/components/workflow/nodes/_base/components/variable/variable-label/hooks.ts b/web/app/components/workflow/nodes/_base/components/variable/variable-label/hooks.ts index 19690edcba..fef6d8c396 100644 --- a/web/app/components/workflow/nodes/_base/components/variable/variable-label/hooks.ts +++ b/web/app/components/workflow/nodes/_base/components/variable/variable-label/hooks.ts @@ -42,17 +42,17 @@ export const useVarColor = (variables: string[], isExceptionVariable?: boolean, return 'text-util-colors-teal-teal-700' return 'text-text-accent' - }, [variables, isExceptionVariable]) + }, [variables, isExceptionVariable, variableCategory]) } export const useVarName = (variables: string[], notShowFullPath?: boolean) => { - let variableFullPathName = variables.slice(1).join('.') - - if (isRagVariableVar(variables)) - variableFullPathName = variables.slice(2).join('.') - - const variablesLength = variables.length const varName = useMemo(() => { + let variableFullPathName = variables.slice(1).join('.') + + if (isRagVariableVar(variables)) + variableFullPathName = variables.slice(2).join('.') + + const variablesLength = variables.length const isSystem = isSystemVar(variables) const varName = notShowFullPath ? variables[variablesLength - 1] : variableFullPathName return `${isSystem ? 'sys.' 
: ''}${varName}` diff --git a/web/app/components/workflow/nodes/_base/node.tsx b/web/app/components/workflow/nodes/_base/node.tsx index 9fd9c3ce72..4725f86ad5 100644 --- a/web/app/components/workflow/nodes/_base/node.tsx +++ b/web/app/components/workflow/nodes/_base/node.tsx @@ -48,8 +48,13 @@ import Tooltip from '@/app/components/base/tooltip' import useInspectVarsCrud from '../../hooks/use-inspect-vars-crud' import { ToolTypeEnum } from '../../block-selector/types' +type NodeChildProps = { + id: string + data: NodeProps['data'] +} + type BaseNodeProps = { - children: ReactElement + children: ReactElement> id: NodeProps['id'] data: NodeProps['data'] } diff --git a/web/app/components/workflow/nodes/code/types.ts b/web/app/components/workflow/nodes/code/types.ts index 9c055f3969..265fd9d25d 100644 --- a/web/app/components/workflow/nodes/code/types.ts +++ b/web/app/components/workflow/nodes/code/types.ts @@ -11,6 +11,11 @@ export type OutputVar = Record +export type CodeDependency = { + name: string + version?: string +} + export type CodeNodeType = CommonNodeType & { variables: Variable[] code_language: CodeLanguage diff --git a/web/app/components/workflow/nodes/http/use-config.ts b/web/app/components/workflow/nodes/http/use-config.ts index 63d14794b2..761ce99b26 100644 --- a/web/app/components/workflow/nodes/http/use-config.ts +++ b/web/app/components/workflow/nodes/http/use-config.ts @@ -16,7 +16,7 @@ import { const useConfig = (id: string, payload: HttpNodeType) => { const { nodesReadOnly: readOnly } = useNodesReadOnly() - const defaultConfig = useStore(s => s.nodesDefaultConfigs)[payload.type] + const defaultConfig = useStore(s => s.nodesDefaultConfigs?.[payload.type]) const { inputs, setInputs } = useNodeCrud(id, payload) diff --git a/web/app/components/workflow/nodes/if-else/components/condition-list/condition-item.tsx b/web/app/components/workflow/nodes/if-else/components/condition-list/condition-item.tsx index 252c9a7d77..45973122e8 100644 --- a/web/app/components/workflow/nodes/if-else/components/condition-list/condition-item.tsx +++ b/web/app/components/workflow/nodes/if-else/components/condition-list/condition-item.tsx @@ -209,7 +209,7 @@ const ConditionItem = ({ onRemoveCondition?.(caseId, condition.id) }, [caseId, condition, conditionId, isSubVariableKey, onRemoveCondition, onRemoveSubVariableCondition]) - const { getMatchedSchemaType } = useMatchSchemaType() + const { schemaTypeDefinitions } = useMatchSchemaType() const handleVarChange = useCallback((valueSelector: ValueSelector, _varItem: Var) => { const { conversationVariables, @@ -226,7 +226,7 @@ const ConditionItem = ({ workflowTools, dataSourceList: dataSourceList ?? 
[], }, - getMatchedSchemaType, + schemaTypeDefinitions, }) const newCondition = produce(condition, (draft) => { @@ -241,7 +241,7 @@ const ConditionItem = ({ }) doUpdateCondition(newCondition) setOpen(false) - }, [condition, doUpdateCondition, availableNodes, isChatMode, setControlPromptEditorRerenderKey]) + }, [condition, doUpdateCondition, availableNodes, isChatMode, setControlPromptEditorRerenderKey, schemaTypeDefinitions]) const showBooleanInput = useMemo(() => { if(condition.varType === VarType.boolean) From ac79691d69536cac157a90a3e3dd4c983b89d0a1 Mon Sep 17 00:00:00 2001 From: Jacky Su Date: Sat, 18 Oct 2025 12:15:29 +0800 Subject: [PATCH 32/46] Feat/add status filter to workflow runs (#26850) Co-authored-by: Jacky Su --- api/controllers/console/app/workflow_run.py | 194 +++++++++++++- api/fields/workflow_run_fields.py | 9 + api/libs/custom_inputs.py | 32 +++ api/libs/time_parser.py | 67 +++++ .../api_workflow_run_repository.py | 39 +++ .../sqlalchemy_api_workflow_run_repository.py | 75 +++++- api/services/workflow_run_service.py | 44 ++- .../unit_tests/libs/test_custom_inputs.py | 68 +++++ api/tests/unit_tests/libs/test_time_parser.py | 91 +++++++ .../test_workflow_run_repository.py | 251 ++++++++++++++++++ 10 files changed, 851 insertions(+), 19 deletions(-) create mode 100644 api/libs/custom_inputs.py create mode 100644 api/libs/time_parser.py create mode 100644 api/tests/unit_tests/libs/test_custom_inputs.py create mode 100644 api/tests/unit_tests/libs/test_time_parser.py create mode 100644 api/tests/unit_tests/repositories/test_workflow_run_repository.py diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 286ba65a7f..311aa81279 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -8,15 +8,81 @@ from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from fields.workflow_run_fields import ( advanced_chat_workflow_run_pagination_fields, + workflow_run_count_fields, workflow_run_detail_fields, workflow_run_node_execution_list_fields, workflow_run_pagination_fields, ) +from libs.custom_inputs import time_duration from libs.helper import uuid_value from libs.login import current_user, login_required -from models import Account, App, AppMode, EndUser +from models import Account, App, AppMode, EndUser, WorkflowRunTriggeredFrom from services.workflow_run_service import WorkflowRunService +# Workflow run status choices for filtering +WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"] + + +def _parse_workflow_run_list_args(): + """ + Parse common arguments for workflow run list endpoints. + + Returns: + Parsed arguments containing last_id, limit, status, and triggered_from filters + """ + parser = reqparse.RequestParser() + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + parser.add_argument( + "status", + type=str, + choices=WORKFLOW_RUN_STATUS_CHOICES, + location="args", + required=False, + ) + parser.add_argument( + "triggered_from", + type=str, + choices=["debugging", "app-run"], + location="args", + required=False, + help="Filter by trigger source: debugging or app-run", + ) + return parser.parse_args() + + +def _parse_workflow_run_count_args(): + """ + Parse common arguments for workflow run count endpoints. 
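+    Shared by the advanced-chat and workflow run count endpoints.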
+
+    Returns:
+        Parsed arguments containing status, time_range, and triggered_from filters
+    """
+    parser = reqparse.RequestParser()
+    parser.add_argument(
+        "status",
+        type=str,
+        choices=WORKFLOW_RUN_STATUS_CHOICES,
+        location="args",
+        required=False,
+    )
+    parser.add_argument(
+        "time_range",
+        type=time_duration,
+        location="args",
+        required=False,
+        help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
+    )
+    parser.add_argument(
+        "triggered_from",
+        type=str,
+        choices=["debugging", "app-run"],
+        location="args",
+        required=False,
+        help="Filter by trigger source: debugging or app-run",
+    )
+    return parser.parse_args()
+
 
 @console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
 class AdvancedChatAppWorkflowRunListApi(Resource):
     @api.doc("get_advanced_chat_workflow_runs")
     @api.doc(description="Get advanced chat workflow run list")
     @api.doc(params={"app_id": "Application ID"})
     @api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
+    @api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
+    @api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
     @api.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_fields)
     @setup_required
     @login_required
     @account_initialization_required
     @get_app_model(mode=[AppMode.ADVANCED_CHAT])
     @marshal_with(advanced_chat_workflow_run_pagination_fields)
     def get(self, app_model: App):
         """
         Get advanced chat app workflow run list
         """
-        parser = reqparse.RequestParser()
-        parser.add_argument("last_id", type=uuid_value, location="args")
-        parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
-        args = parser.parse_args()
+        args = _parse_workflow_run_list_args()
+
+        # Default to DEBUGGING if not specified
+        triggered_from = (
+            WorkflowRunTriggeredFrom(args.get("triggered_from"))
+            if args.get("triggered_from")
+            else WorkflowRunTriggeredFrom.DEBUGGING
+        )
 
         workflow_run_service = WorkflowRunService()
-        result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args=args)
+        result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(
+            app_model=app_model, args=args, triggered_from=triggered_from
+        )
 
         return result
 
+
+@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs/count")
+class AdvancedChatAppWorkflowRunCountApi(Resource):
+    @api.doc("get_advanced_chat_workflow_runs_count")
+    @api.doc(description="Get advanced chat workflow runs count statistics")
+    @api.doc(params={"app_id": "Application ID"})
+    @api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
+    @api.doc(
+        params={
+            "time_range": (
+                "Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
+                "30m (30 minutes), 30s (30 seconds). Filters by created_at field."
+            )
+        }
+    )
+    @api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
Default: debugging"}) + @api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields) + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT]) + @marshal_with(workflow_run_count_fields) + def get(self, app_model: App): + """ + Get advanced chat workflow runs count statistics + """ + args = _parse_workflow_run_count_args() + + # Default to DEBUGGING if not specified + triggered_from = ( + WorkflowRunTriggeredFrom(args.get("triggered_from")) + if args.get("triggered_from") + else WorkflowRunTriggeredFrom.DEBUGGING + ) + + workflow_run_service = WorkflowRunService() + result = workflow_run_service.get_workflow_runs_count( + app_model=app_model, + status=args.get("status"), + time_range=args.get("time_range"), + triggered_from=triggered_from, + ) return result @@ -51,6 +170,8 @@ class WorkflowRunListApi(Resource): @api.doc(description="Get workflow run list") @api.doc(params={"app_id": "Application ID"}) @api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"}) + @api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}) + @api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}) @api.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_fields) @setup_required @login_required @@ -61,13 +182,64 @@ class WorkflowRunListApi(Resource): """ Get workflow run list """ - parser = reqparse.RequestParser() - parser.add_argument("last_id", type=uuid_value, location="args") - parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - args = parser.parse_args() + args = _parse_workflow_run_list_args() + + # Default to DEBUGGING for workflow if not specified (backward compatibility) + triggered_from = ( + WorkflowRunTriggeredFrom(args.get("triggered_from")) + if args.get("triggered_from") + else WorkflowRunTriggeredFrom.DEBUGGING + ) workflow_run_service = WorkflowRunService() - result = workflow_run_service.get_paginate_workflow_runs(app_model=app_model, args=args) + result = workflow_run_service.get_paginate_workflow_runs( + app_model=app_model, args=args, triggered_from=triggered_from + ) + + return result + + +@console_ns.route("/apps//workflow-runs/count") +class WorkflowRunCountApi(Resource): + @api.doc("get_workflow_runs_count") + @api.doc(description="Get workflow runs count statistics") + @api.doc(params={"app_id": "Application ID"}) + @api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}) + @api.doc( + params={ + "time_range": ( + "Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), " + "30m (30 minutes), 30s (30 seconds). Filters by created_at field." + ) + } + ) + @api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. 
Default: debugging"}) + @api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields) + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @marshal_with(workflow_run_count_fields) + def get(self, app_model: App): + """ + Get workflow runs count statistics + """ + args = _parse_workflow_run_count_args() + + # Default to DEBUGGING for workflow if not specified (backward compatibility) + triggered_from = ( + WorkflowRunTriggeredFrom(args.get("triggered_from")) + if args.get("triggered_from") + else WorkflowRunTriggeredFrom.DEBUGGING + ) + + workflow_run_service = WorkflowRunService() + result = workflow_run_service.get_workflow_runs_count( + app_model=app_model, + status=args.get("status"), + time_range=args.get("time_range"), + triggered_from=triggered_from, + ) return result diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 649e881848..79594beeed 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -64,6 +64,15 @@ workflow_run_pagination_fields = { "data": fields.List(fields.Nested(workflow_run_for_list_fields), attribute="data"), } +workflow_run_count_fields = { + "total": fields.Integer, + "running": fields.Integer, + "succeeded": fields.Integer, + "failed": fields.Integer, + "stopped": fields.Integer, + "partial_succeeded": fields.Integer(attribute="partial-succeeded"), +} + workflow_run_detail_fields = { "id": fields.String, "version": fields.String, diff --git a/api/libs/custom_inputs.py b/api/libs/custom_inputs.py new file mode 100644 index 0000000000..10d550ed65 --- /dev/null +++ b/api/libs/custom_inputs.py @@ -0,0 +1,32 @@ +"""Custom input types for Flask-RESTX request parsing.""" + +import re + + +def time_duration(value: str) -> str: + """ + Validate and return time duration string. + + Accepts formats: d (days), h (hours), m (minutes), s (seconds) + Examples: 7d, 4h, 30m, 30s + + Args: + value: The time duration string + + Returns: + The validated time duration string + + Raises: + ValueError: If the format is invalid + """ + if not value: + raise ValueError("Time duration cannot be empty") + + pattern = r"^(\d+)([dhms])$" + if not re.match(pattern, value.lower()): + raise ValueError( + "Invalid time duration format. Use: d (days), h (hours), " + "m (minutes), or s (seconds). Examples: 7d, 4h, 30m, 30s" + ) + + return value.lower() diff --git a/api/libs/time_parser.py b/api/libs/time_parser.py new file mode 100644 index 0000000000..1d9dd92a08 --- /dev/null +++ b/api/libs/time_parser.py @@ -0,0 +1,67 @@ +"""Time duration parser utility.""" + +import re +from datetime import UTC, datetime, timedelta + + +def parse_time_duration(duration_str: str) -> timedelta | None: + """ + Parse time duration string to timedelta. 
+ + Supported formats: + - 7d: 7 days + - 4h: 4 hours + - 30m: 30 minutes + - 30s: 30 seconds + + Args: + duration_str: Duration string (e.g., "7d", "4h", "30m", "30s") + + Returns: + timedelta object or None if invalid format + """ + if not duration_str: + return None + + # Pattern: number followed by unit (d, h, m, s) + pattern = r"^(\d+)([dhms])$" + match = re.match(pattern, duration_str.lower()) + + if not match: + return None + + value = int(match.group(1)) + unit = match.group(2) + + if unit == "d": + return timedelta(days=value) + elif unit == "h": + return timedelta(hours=value) + elif unit == "m": + return timedelta(minutes=value) + elif unit == "s": + return timedelta(seconds=value) + + return None + + +def get_time_threshold(duration_str: str | None) -> datetime | None: + """ + Get datetime threshold from duration string. + + Calculates the datetime that is duration_str ago from now. + + Args: + duration_str: Duration string (e.g., "7d", "4h", "30m", "30s") + + Returns: + datetime object representing the threshold time, or None if no duration + """ + if not duration_str: + return None + + duration = parse_time_duration(duration_str) + if duration is None: + return None + + return datetime.now(UTC) - duration diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index 3ac28fad75..72de9fed31 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -59,6 +59,7 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): triggered_from: str, limit: int = 20, last_id: str | None = None, + status: str | None = None, ) -> InfiniteScrollPagination: """ Get paginated workflow runs with filtering. @@ -73,6 +74,7 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): triggered_from: Filter by trigger source (e.g., "debugging", "app-run") limit: Maximum number of records to return (default: 20) last_id: Cursor for pagination - ID of the last record from previous page + status: Optional filter by status (e.g., "running", "succeeded", "failed") Returns: InfiniteScrollPagination object containing: @@ -107,6 +109,43 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): """ ... + def get_workflow_runs_count( + self, + tenant_id: str, + app_id: str, + triggered_from: str, + status: str | None = None, + time_range: str | None = None, + ) -> dict[str, int]: + """ + Get workflow runs count statistics. + + Retrieves total count and count by status for workflow runs + matching the specified filters. + + Args: + tenant_id: Tenant identifier for multi-tenant isolation + app_id: Application identifier + triggered_from: Filter by trigger source (e.g., "debugging", "app-run") + status: Optional filter by specific status + time_range: Optional time range filter (e.g., "7d", "4h", "30m", "30s") + Filters records based on created_at field + + Returns: + Dictionary containing: + - total: Total count of all workflow runs (or filtered by status) + - running: Count of workflow runs with status "running" + - succeeded: Count of workflow runs with status "succeeded" + - failed: Count of workflow runs with status "failed" + - stopped: Count of workflow runs with status "stopped" + - partial_succeeded: Count of workflow runs with status "partial-succeeded" + + Note: If a status is provided, 'total' will be the count for that status, + and the specific status count will also be set to this value, with all + other status counts being 0. + """ + ... 
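+
+        # A minimal usage sketch (illustrative, not part of the protocol
+        # contract; ``repo`` stands for any concrete implementation, such
+        # as the SQLAlchemy repository added later in this patch):
+        #
+        #     counts = repo.get_workflow_runs_count(
+        #         tenant_id=tenant_id,
+        #         app_id=app_id,
+        #         triggered_from="debugging",
+        #         time_range="7d",
+        #     )
+        #     # e.g. {"total": 8, "running": 1, "succeeded": 5, "failed": 2,
+        #     #       "stopped": 0, "partial-succeeded": 0}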
+ def get_expired_runs_batch( self, tenant_id: str, diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index 6154273f33..68affb59f3 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -24,11 +24,12 @@ from collections.abc import Sequence from datetime import datetime from typing import cast -from sqlalchemy import delete, select +from sqlalchemy import delete, func, select from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, sessionmaker from libs.infinite_scroll_pagination import InfiniteScrollPagination +from libs.time_parser import get_time_threshold from models.workflow import WorkflowRun from repositories.api_workflow_run_repository import APIWorkflowRunRepository @@ -63,6 +64,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): triggered_from: str, limit: int = 20, last_id: str | None = None, + status: str | None = None, ) -> InfiniteScrollPagination: """ Get paginated workflow runs with filtering. @@ -79,6 +81,10 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): WorkflowRun.triggered_from == triggered_from, ) + # Add optional status filter + if status: + base_stmt = base_stmt.where(WorkflowRun.status == status) + if last_id: # Get the last workflow run for cursor-based pagination last_run_stmt = base_stmt.where(WorkflowRun.id == last_id) @@ -120,6 +126,73 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): ) return session.scalar(stmt) + def get_workflow_runs_count( + self, + tenant_id: str, + app_id: str, + triggered_from: str, + status: str | None = None, + time_range: str | None = None, + ) -> dict[str, int]: + """ + Get workflow runs count statistics grouped by status. 
+ """ + _initial_status_counts = { + "running": 0, + "succeeded": 0, + "failed": 0, + "stopped": 0, + "partial-succeeded": 0, + } + + with self._session_maker() as session: + # Build base where conditions + base_conditions = [ + WorkflowRun.tenant_id == tenant_id, + WorkflowRun.app_id == app_id, + WorkflowRun.triggered_from == triggered_from, + ] + + # Add time range filter if provided + if time_range: + time_threshold = get_time_threshold(time_range) + if time_threshold: + base_conditions.append(WorkflowRun.created_at >= time_threshold) + + # If status filter is provided, return simple count + if status: + count_stmt = select(func.count(WorkflowRun.id)).where(*base_conditions, WorkflowRun.status == status) + total = session.scalar(count_stmt) or 0 + + result = {"total": total} | _initial_status_counts + + # Set the count for the filtered status + if status in result: + result[status] = total + + return result + + # No status filter - get counts grouped by status + base_stmt = ( + select(WorkflowRun.status, func.count(WorkflowRun.id).label("count")) + .where(*base_conditions) + .group_by(WorkflowRun.status) + ) + + # Execute query + results = session.execute(base_stmt).all() + + # Build response dictionary + status_counts = _initial_status_counts.copy() + + total = 0 + for status_val, count in results: + total += count + if status_val in status_counts: + status_counts[status_val] = count + + return {"total": total} | status_counts + def get_expired_runs_batch( self, tenant_id: str, diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index 6a2edd912a..5c8719b499 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -26,13 +26,15 @@ class WorkflowRunService: ) self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) - def get_paginate_advanced_chat_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination: + def get_paginate_advanced_chat_workflow_runs( + self, app_model: App, args: dict, triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING + ) -> InfiniteScrollPagination: """ Get advanced chat app workflow run list - Only return triggered_from == advanced_chat :param app_model: app model :param args: request args + :param triggered_from: workflow run triggered from (default: DEBUGGING for preview runs) """ class WorkflowWithMessage: @@ -45,7 +47,7 @@ class WorkflowRunService: def __getattr__(self, item): return getattr(self._workflow_run, item) - pagination = self.get_paginate_workflow_runs(app_model, args) + pagination = self.get_paginate_workflow_runs(app_model, args, triggered_from) with_message_workflow_runs = [] for workflow_run in pagination.data: @@ -60,23 +62,27 @@ class WorkflowRunService: pagination.data = with_message_workflow_runs return pagination - def get_paginate_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination: + def get_paginate_workflow_runs( + self, app_model: App, args: dict, triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING + ) -> InfiniteScrollPagination: """ - Get debug workflow run list - Only return triggered_from == debugging + Get workflow run list :param app_model: app model :param args: request args + :param triggered_from: workflow run triggered from (default: DEBUGGING) """ limit = int(args.get("limit", 20)) last_id = args.get("last_id") + status = args.get("status") return self._workflow_run_repo.get_paginated_workflow_runs( 
tenant_id=app_model.tenant_id, app_id=app_model.id, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + triggered_from=triggered_from, limit=limit, last_id=last_id, + status=status, ) def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun | None: @@ -92,6 +98,30 @@ class WorkflowRunService: run_id=run_id, ) + def get_workflow_runs_count( + self, + app_model: App, + status: str | None = None, + time_range: str | None = None, + triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING, + ) -> dict[str, int]: + """ + Get workflow runs count statistics + + :param app_model: app model + :param status: optional status filter + :param time_range: optional time range filter (e.g., "7d", "4h", "30m", "30s") + :param triggered_from: workflow run triggered from (default: DEBUGGING) + :return: dict with total and status counts + """ + return self._workflow_run_repo.get_workflow_runs_count( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + triggered_from=triggered_from, + status=status, + time_range=time_range, + ) + def get_workflow_run_node_executions( self, app_model: App, diff --git a/api/tests/unit_tests/libs/test_custom_inputs.py b/api/tests/unit_tests/libs/test_custom_inputs.py new file mode 100644 index 0000000000..7e4c3b4ff0 --- /dev/null +++ b/api/tests/unit_tests/libs/test_custom_inputs.py @@ -0,0 +1,68 @@ +"""Unit tests for custom input types.""" + +import pytest + +from libs.custom_inputs import time_duration + + +class TestTimeDuration: + """Test time_duration input validator.""" + + def test_valid_days(self): + """Test valid days format.""" + result = time_duration("7d") + assert result == "7d" + + def test_valid_hours(self): + """Test valid hours format.""" + result = time_duration("4h") + assert result == "4h" + + def test_valid_minutes(self): + """Test valid minutes format.""" + result = time_duration("30m") + assert result == "30m" + + def test_valid_seconds(self): + """Test valid seconds format.""" + result = time_duration("30s") + assert result == "30s" + + def test_uppercase_conversion(self): + """Test uppercase units are converted to lowercase.""" + result = time_duration("7D") + assert result == "7d" + + result = time_duration("4H") + assert result == "4h" + + def test_invalid_format_no_unit(self): + """Test invalid format without unit.""" + with pytest.raises(ValueError, match="Invalid time duration format"): + time_duration("7") + + def test_invalid_format_wrong_unit(self): + """Test invalid format with wrong unit.""" + with pytest.raises(ValueError, match="Invalid time duration format"): + time_duration("7days") + + with pytest.raises(ValueError, match="Invalid time duration format"): + time_duration("7x") + + def test_invalid_format_no_number(self): + """Test invalid format without number.""" + with pytest.raises(ValueError, match="Invalid time duration format"): + time_duration("d") + + with pytest.raises(ValueError, match="Invalid time duration format"): + time_duration("abc") + + def test_empty_string(self): + """Test empty string.""" + with pytest.raises(ValueError, match="Time duration cannot be empty"): + time_duration("") + + def test_none(self): + """Test None value.""" + with pytest.raises(ValueError, match="Time duration cannot be empty"): + time_duration(None) diff --git a/api/tests/unit_tests/libs/test_time_parser.py b/api/tests/unit_tests/libs/test_time_parser.py new file mode 100644 index 0000000000..83ff251272 --- /dev/null +++ b/api/tests/unit_tests/libs/test_time_parser.py @@ -0,0 +1,91 @@ +"""Unit tests for 
time parser utility.""" + +from datetime import UTC, datetime, timedelta + +from libs.time_parser import get_time_threshold, parse_time_duration + + +class TestParseTimeDuration: + """Test parse_time_duration function.""" + + def test_parse_days(self): + """Test parsing days.""" + result = parse_time_duration("7d") + assert result == timedelta(days=7) + + def test_parse_hours(self): + """Test parsing hours.""" + result = parse_time_duration("4h") + assert result == timedelta(hours=4) + + def test_parse_minutes(self): + """Test parsing minutes.""" + result = parse_time_duration("30m") + assert result == timedelta(minutes=30) + + def test_parse_seconds(self): + """Test parsing seconds.""" + result = parse_time_duration("30s") + assert result == timedelta(seconds=30) + + def test_parse_uppercase(self): + """Test parsing uppercase units.""" + result = parse_time_duration("7D") + assert result == timedelta(days=7) + + def test_parse_invalid_format(self): + """Test parsing invalid format.""" + result = parse_time_duration("7days") + assert result is None + + result = parse_time_duration("abc") + assert result is None + + result = parse_time_duration("7") + assert result is None + + def test_parse_empty_string(self): + """Test parsing empty string.""" + result = parse_time_duration("") + assert result is None + + def test_parse_none(self): + """Test parsing None.""" + result = parse_time_duration(None) + assert result is None + + +class TestGetTimeThreshold: + """Test get_time_threshold function.""" + + def test_get_threshold_days(self): + """Test getting threshold for days.""" + before = datetime.now(UTC) + result = get_time_threshold("7d") + after = datetime.now(UTC) + + assert result is not None + # Result should be approximately 7 days ago + expected = before - timedelta(days=7) + # Allow 1 second tolerance for test execution time + assert abs((result - expected).total_seconds()) < 1 + + def test_get_threshold_hours(self): + """Test getting threshold for hours.""" + before = datetime.now(UTC) + result = get_time_threshold("4h") + after = datetime.now(UTC) + + assert result is not None + expected = before - timedelta(hours=4) + assert abs((result - expected).total_seconds()) < 1 + + def test_get_threshold_invalid(self): + """Test getting threshold with invalid duration.""" + result = get_time_threshold("invalid") + assert result is None + + def test_get_threshold_none(self): + """Test getting threshold with None.""" + result = get_time_threshold(None) + assert result is None diff --git a/api/tests/unit_tests/repositories/test_workflow_run_repository.py b/api/tests/unit_tests/repositories/test_workflow_run_repository.py new file mode 100644 index 0000000000..8f47f0df48 --- /dev/null +++ b/api/tests/unit_tests/repositories/test_workflow_run_repository.py @@ -0,0 +1,251 @@ +"""Unit tests for workflow run repository with status filter.""" + +import uuid +from unittest.mock import MagicMock + +import pytest +from sqlalchemy.orm import sessionmaker + +from models import WorkflowRun, WorkflowRunTriggeredFrom +from repositories.sqlalchemy_api_workflow_run_repository import DifyAPISQLAlchemyWorkflowRunRepository + + +class TestDifyAPISQLAlchemyWorkflowRunRepository: + """Test workflow run repository with status filtering.""" + + @pytest.fixture + def mock_session_maker(self): + """Create a mock session maker.""" + return MagicMock(spec=sessionmaker) + + @pytest.fixture + def repository(self, mock_session_maker): + """Create repository instance with mock session.""" + return 
DifyAPISQLAlchemyWorkflowRunRepository(mock_session_maker) + + def test_get_paginated_workflow_runs_without_status(self, repository, mock_session_maker): + """Test getting paginated workflow runs without status filter.""" + # Arrange + tenant_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + mock_session = MagicMock() + mock_session_maker.return_value.__enter__.return_value = mock_session + + mock_runs = [MagicMock(spec=WorkflowRun) for _ in range(3)] + mock_session.scalars.return_value.all.return_value = mock_runs + + # Act + result = repository.get_paginated_workflow_runs( + tenant_id=tenant_id, + app_id=app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + limit=20, + last_id=None, + status=None, + ) + + # Assert + assert len(result.data) == 3 + assert result.limit == 20 + assert result.has_more is False + + def test_get_paginated_workflow_runs_with_status_filter(self, repository, mock_session_maker): + """Test getting paginated workflow runs with status filter.""" + # Arrange + tenant_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + mock_session = MagicMock() + mock_session_maker.return_value.__enter__.return_value = mock_session + + mock_runs = [MagicMock(spec=WorkflowRun, status="succeeded") for _ in range(2)] + mock_session.scalars.return_value.all.return_value = mock_runs + + # Act + result = repository.get_paginated_workflow_runs( + tenant_id=tenant_id, + app_id=app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + limit=20, + last_id=None, + status="succeeded", + ) + + # Assert + assert len(result.data) == 2 + assert all(run.status == "succeeded" for run in result.data) + + def test_get_workflow_runs_count_without_status(self, repository, mock_session_maker): + """Test getting workflow runs count without status filter.""" + # Arrange + tenant_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + mock_session = MagicMock() + mock_session_maker.return_value.__enter__.return_value = mock_session + + # Mock the GROUP BY query results + mock_results = [ + ("succeeded", 5), + ("failed", 2), + ("running", 1), + ] + mock_session.execute.return_value.all.return_value = mock_results + + # Act + result = repository.get_workflow_runs_count( + tenant_id=tenant_id, + app_id=app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + status=None, + ) + + # Assert + assert result["total"] == 8 + assert result["succeeded"] == 5 + assert result["failed"] == 2 + assert result["running"] == 1 + assert result["stopped"] == 0 + assert result["partial-succeeded"] == 0 + + def test_get_workflow_runs_count_with_status_filter(self, repository, mock_session_maker): + """Test getting workflow runs count with status filter.""" + # Arrange + tenant_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + mock_session = MagicMock() + mock_session_maker.return_value.__enter__.return_value = mock_session + + # Mock the count query for succeeded status + mock_session.scalar.return_value = 5 + + # Act + result = repository.get_workflow_runs_count( + tenant_id=tenant_id, + app_id=app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + status="succeeded", + ) + + # Assert + assert result["total"] == 5 + assert result["succeeded"] == 5 + assert result["running"] == 0 + assert result["failed"] == 0 + assert result["stopped"] == 0 + assert result["partial-succeeded"] == 0 + + def test_get_workflow_runs_count_with_invalid_status(self, repository, mock_session_maker): + """Test that invalid status is still counted in total but not in any specific status.""" + # Arrange + tenant_id = 
str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + mock_session = MagicMock() + mock_session_maker.return_value.__enter__.return_value = mock_session + + # Mock count query returning 0 for invalid status + mock_session.scalar.return_value = 0 + + # Act + result = repository.get_workflow_runs_count( + tenant_id=tenant_id, + app_id=app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + status="invalid_status", + ) + + # Assert + assert result["total"] == 0 + assert all(result[status] == 0 for status in ["running", "succeeded", "failed", "stopped", "partial-succeeded"]) + + def test_get_workflow_runs_count_with_time_range(self, repository, mock_session_maker): + """Test getting workflow runs count with time range filter verifies SQL query construction.""" + # Arrange + tenant_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + mock_session = MagicMock() + mock_session_maker.return_value.__enter__.return_value = mock_session + + # Mock the GROUP BY query results + mock_results = [ + ("succeeded", 3), + ("running", 2), + ] + mock_session.execute.return_value.all.return_value = mock_results + + # Act + result = repository.get_workflow_runs_count( + tenant_id=tenant_id, + app_id=app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + status=None, + time_range="1d", + ) + + # Assert results + assert result["total"] == 5 + assert result["succeeded"] == 3 + assert result["running"] == 2 + assert result["failed"] == 0 + + # Verify that execute was called (which means GROUP BY query was used) + assert mock_session.execute.called, "execute should have been called for GROUP BY query" + + # Verify SQL query includes time filter by checking the statement + call_args = mock_session.execute.call_args + assert call_args is not None, "execute should have been called with a statement" + + # The first argument should be the SQL statement + stmt = call_args[0][0] + # Convert to string to inspect the query + query_str = str(stmt.compile(compile_kwargs={"literal_binds": True})) + + # Verify the query includes created_at filter + # The query should have a WHERE clause with created_at comparison + assert "created_at" in query_str.lower() or "workflow_runs.created_at" in query_str.lower(), ( + "Query should include created_at filter for time range" + ) + + def test_get_workflow_runs_count_with_status_and_time_range(self, repository, mock_session_maker): + """Test getting workflow runs count with both status and time range filters verifies SQL query.""" + # Arrange + tenant_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + mock_session = MagicMock() + mock_session_maker.return_value.__enter__.return_value = mock_session + + # Mock the count query for running status within time range + mock_session.scalar.return_value = 2 + + # Act + result = repository.get_workflow_runs_count( + tenant_id=tenant_id, + app_id=app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + status="running", + time_range="1d", + ) + + # Assert results + assert result["total"] == 2 + assert result["running"] == 2 + assert result["succeeded"] == 0 + assert result["failed"] == 0 + + # Verify that scalar was called (which means COUNT query was used) + assert mock_session.scalar.called, "scalar should have been called for count query" + + # Verify SQL query includes both status and time filter + call_args = mock_session.scalar.call_args + assert call_args is not None, "scalar should have been called with a statement" + + # The first argument should be the SQL statement + stmt = call_args[0][0] + # Convert to string to 
inspect the query + query_str = str(stmt.compile(compile_kwargs={"literal_binds": True})) + + # Verify the query includes both filters + assert "created_at" in query_str.lower() or "workflow_runs.created_at" in query_str.lower(), ( + "Query should include created_at filter for time range" + ) + assert "status" in query_str.lower() or "workflow_runs.status" in query_str.lower(), ( + "Query should include status filter" + ) From cf7ff76165651db81f881fb40a0db8824caf7f9d Mon Sep 17 00:00:00 2001 From: GuanMu Date: Sat, 18 Oct 2025 23:09:00 +0800 Subject: [PATCH 33/46] fix(web): resolve TypeScript type errors in workflow components (#27086) --- web/app/components/workflow/block-selector/tool-picker.tsx | 1 + .../workflow/block-selector/tool/tool-list-tree-view/item.tsx | 1 - web/app/components/workflow/block-selector/types.ts | 1 + .../components/workflow/block-selector/use-sticky-scroll.ts | 4 ++-- .../components/workflow/datasets-detail-store/provider.tsx | 2 +- .../nodes/_base/components/agent-strategy-selector.tsx | 2 +- 6 files changed, 6 insertions(+), 5 deletions(-) diff --git a/web/app/components/workflow/block-selector/tool-picker.tsx b/web/app/components/workflow/block-selector/tool-picker.tsx index ced6d3e88f..ae4b0d4f02 100644 --- a/web/app/components/workflow/block-selector/tool-picker.tsx +++ b/web/app/components/workflow/block-selector/tool-picker.tsx @@ -178,6 +178,7 @@ const ToolPicker: FC = ({ mcpTools={mcpTools || []} selectedTools={selectedTools} canChooseMCPTool={canChooseMCPTool} + onTagsChange={setTags} /> diff --git a/web/app/components/workflow/block-selector/tool/tool-list-tree-view/item.tsx b/web/app/components/workflow/block-selector/tool/tool-list-tree-view/item.tsx index b3f7aab4df..ac0955da0b 100644 --- a/web/app/components/workflow/block-selector/tool/tool-list-tree-view/item.tsx +++ b/web/app/components/workflow/block-selector/tool/tool-list-tree-view/item.tsx @@ -39,7 +39,6 @@ const Item: FC = ({ key={tool.id} payload={tool} viewType={ViewType.tree} - isShowLetterIndex={false} hasSearchText={hasSearchText} onSelect={onSelect} canNotSelectMultiple={canNotSelectMultiple} diff --git a/web/app/components/workflow/block-selector/types.ts b/web/app/components/workflow/block-selector/types.ts index be960b1246..48fbf6a500 100644 --- a/web/app/components/workflow/block-selector/types.ts +++ b/web/app/components/workflow/block-selector/types.ts @@ -37,6 +37,7 @@ export type ToolDefaultValue = { paramSchemas: Record[] credential_id?: string meta?: PluginMeta + output_schema?: Record } export type DataSourceDefaultValue = { diff --git a/web/app/components/workflow/block-selector/use-sticky-scroll.ts b/web/app/components/workflow/block-selector/use-sticky-scroll.ts index c828e9ce92..7933d63b39 100644 --- a/web/app/components/workflow/block-selector/use-sticky-scroll.ts +++ b/web/app/components/workflow/block-selector/use-sticky-scroll.ts @@ -8,8 +8,8 @@ export enum ScrollPosition { } type Params = { - wrapElemRef: React.RefObject - nextToStickyELemRef: React.RefObject + wrapElemRef: React.RefObject + nextToStickyELemRef: React.RefObject } const useStickyScroll = ({ wrapElemRef, diff --git a/web/app/components/workflow/datasets-detail-store/provider.tsx b/web/app/components/workflow/datasets-detail-store/provider.tsx index 1f5749bc3c..a75b7e1d29 100644 --- a/web/app/components/workflow/datasets-detail-store/provider.tsx +++ b/web/app/components/workflow/datasets-detail-store/provider.tsx @@ -21,7 +21,7 @@ const DatasetsDetailProvider: FC = ({ nodes, children, }) => { - const 
storeRef = useRef() + const storeRef = useRef(undefined) if (!storeRef.current) storeRef.current = createDatasetsDetailStore() diff --git a/web/app/components/workflow/nodes/_base/components/agent-strategy-selector.tsx b/web/app/components/workflow/nodes/_base/components/agent-strategy-selector.tsx index 7635e0faf0..0c24dcfd2c 100644 --- a/web/app/components/workflow/nodes/_base/components/agent-strategy-selector.tsx +++ b/web/app/components/workflow/nodes/_base/components/agent-strategy-selector.tsx @@ -212,7 +212,7 @@ export const AgentStrategySelector = memo((props: AgentStrategySelectorProps) => agent_strategy_name: tool!.tool_name, agent_strategy_provider_name: tool!.provider_name, agent_strategy_label: tool!.tool_label, - agent_output_schema: tool!.output_schema, + agent_output_schema: tool!.output_schema || {}, plugin_unique_identifier: tool!.provider_id, meta: tool!.meta, }) From 59c1fde3519627b81259bb57661a0d47fa1f17c8 Mon Sep 17 00:00:00 2001 From: Bowen Liang Date: Sat, 18 Oct 2025 23:24:35 +0800 Subject: [PATCH 34/46] doc: add Grafana dashboard template link to docs (#27079) Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- README.md | 10 ++++++++++ docs/ar-SA/README.md | 8 ++++++++ docs/bn-BD/README.md | 8 ++++++++ docs/de-DE/README.md | 8 ++++++++ docs/es-ES/README.md | 8 ++++++++ docs/fr-FR/README.md | 8 ++++++++ docs/ja-JP/README.md | 8 ++++++++ docs/ko-KR/README.md | 8 ++++++++ docs/pt-BR/README.md | 8 ++++++++ docs/sl-SI/README.md | 8 ++++++++ docs/tr-TR/README.md | 8 ++++++++ docs/vi-VN/README.md | 8 ++++++++ docs/zh-CN/README.md | 6 ++++++ docs/zh-TW/README.md | 8 ++++++++ 14 files changed, 112 insertions(+) diff --git a/README.md b/README.md index aadced582d..7c194e065a 100644 --- a/README.md +++ b/README.md @@ -129,8 +129,18 @@ Star Dify on GitHub and be instantly notified of new releases. ## Advanced Setup +### Custom configurations + If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments). +### Metrics Monitoring with Grafana + +Import the dashboard to Grafana, using Dify's PostgreSQL database as data source, to monitor metrics in granularity of apps, tenants, messages, and more. + +- [Grafana Dashboard by @bowenliang123](https://github.com/bowenliang123/dify-grafana-dashboard) + +### Deployment with Kubernetes + If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes. - [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) diff --git a/docs/ar-SA/README.md b/docs/ar-SA/README.md index afa494c5d3..30920ed983 100644 --- a/docs/ar-SA/README.md +++ b/docs/ar-SA/README.md @@ -115,6 +115,14 @@ docker compose up -d إذا كنت بحاجة إلى تخصيص الإعدادات، فيرجى الرجوع إلى التعليقات في ملف [.env.example](../../docker/.env.example) وتحديث القيم المقابلة في ملف `.env`. 
بالإضافة إلى ذلك، قد تحتاج إلى إجراء تعديلات على ملف `docker-compose.yaml` نفسه، مثل تغيير إصدارات الصور أو تعيينات المنافذ أو نقاط تحميل وحدات التخزين، بناءً على بيئة النشر ومتطلباتك الخاصة. بعد إجراء أي تغييرات، يرجى إعادة تشغيل `docker-compose up -d`. يمكنك العثور على قائمة كاملة بمتغيرات البيئة المتاحة [هنا](https://docs.dify.ai/getting-started/install-self-hosted/environments). +### مراقبة المقاييس باستخدام Grafana + +استيراد لوحة التحكم إلى Grafana، باستخدام قاعدة بيانات PostgreSQL الخاصة بـ Dify كمصدر للبيانات، لمراقبة المقاييس بدقة للتطبيقات والمستأجرين والرسائل وغير ذلك. + +- [لوحة تحكم Grafana بواسطة @bowenliang123](https://github.com/bowenliang123/dify-grafana-dashboard) + +### النشر باستخدام Kubernetes + يوجد مجتمع خاص بـ [Helm Charts](https://helm.sh/) وملفات YAML التي تسمح بتنفيذ Dify على Kubernetes للنظام من الإيجابيات العلوية. - [رسم بياني Helm من قبل @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) diff --git a/docs/bn-BD/README.md b/docs/bn-BD/README.md index 318853a8de..5430364ef9 100644 --- a/docs/bn-BD/README.md +++ b/docs/bn-BD/README.md @@ -132,6 +132,14 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন যদি আপনার কনফিগারেশনটি কাস্টমাইজ করার প্রয়োজন হয়, তাহলে অনুগ্রহ করে আমাদের [.env.example](../../docker/.env.example) ফাইল দেখুন এবং আপনার `.env` ফাইলে সংশ্লিষ্ট মানগুলি আপডেট করুন। এছাড়াও, আপনার নির্দিষ্ট এনভায়রনমেন্ট এবং প্রয়োজনীয়তার উপর ভিত্তি করে আপনাকে `docker-compose.yaml` ফাইলে সমন্বয় করতে হতে পারে, যেমন ইমেজ ভার্সন পরিবর্তন করা, পোর্ট ম্যাপিং করা, অথবা ভলিউম মাউন্ট করা। যেকোনো পরিবর্তন করার পর, অনুগ্রহ করে `docker-compose up -d` পুনরায় চালান। ভেরিয়েবলের সম্পূর্ণ তালিকা [এখানে] (https://docs.dify.ai/getting-started/install-self-hosted/environments) খুঁজে পেতে পারেন। +### Grafana দিয়ে মেট্রিক্স মনিটরিং + +Dify-এর PostgreSQL ডাটাবেসকে ডেটা সোর্স হিসাবে ব্যবহার করে, অ্যাপ, টেন্যান্ট, মেসেজ ইত্যাদির গ্র্যানুলারিটিতে মেট্রিক্স মনিটর করার জন্য Grafana-তে ড্যাশবোর্ড ইম্পোর্ট করুন। + +- [@bowenliang123 কর্তৃক Grafana ড্যাশবোর্ড](https://github.com/bowenliang123/dify-grafana-dashboard) + +### Kubernetes এর সাথে ডেপ্লয়মেন্ট + যদি আপনি একটি হাইলি এভেইলেবল সেটআপ কনফিগার করতে চান, তাহলে কমিউনিটি [Helm Charts](https://helm.sh/) এবং YAML ফাইল রয়েছে যা Dify কে Kubernetes-এ ডিপ্লয় করার প্রক্রিয়া বর্ণনা করে। - [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) diff --git a/docs/de-DE/README.md b/docs/de-DE/README.md index 8907d914d3..6c49fbdfc3 100644 --- a/docs/de-DE/README.md +++ b/docs/de-DE/README.md @@ -130,6 +130,14 @@ Star Dify auf GitHub und lassen Sie sich sofort über neue Releases benachrichti Falls Sie die Konfiguration anpassen müssen, lesen Sie bitte die Kommentare in unserer [.env.example](../../docker/.env.example)-Datei und aktualisieren Sie die entsprechenden Werte in Ihrer `.env`-Datei. Zusätzlich müssen Sie eventuell Anpassungen an der `docker-compose.yaml`-Datei vornehmen, wie zum Beispiel das Ändern von Image-Versionen, Portzuordnungen oder Volumen-Mounts, je nach Ihrer spezifischen Einsatzumgebung und Ihren Anforderungen. Nachdem Sie Änderungen vorgenommen haben, starten Sie `docker-compose up -d` erneut. Eine vollständige Liste der verfügbaren Umgebungsvariablen finden Sie [hier](https://docs.dify.ai/getting-started/install-self-hosted/environments). +### Metriküberwachung mit Grafana + +Importieren Sie das Dashboard in Grafana, wobei Sie die PostgreSQL-Datenbank von Dify als Datenquelle verwenden, um Metriken in der Granularität von Apps, Mandanten, Nachrichten und mehr zu überwachen. 
+ +- [Grafana-Dashboard von @bowenliang123](https://github.com/bowenliang123/dify-grafana-dashboard) + +### Bereitstellung mit Kubernetes + Falls Sie eine hochverfügbare Konfiguration einrichten möchten, gibt es von der Community bereitgestellte [Helm Charts](https://helm.sh/) und YAML-Dateien, die es ermöglichen, Dify auf Kubernetes bereitzustellen. - [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) diff --git a/docs/es-ES/README.md b/docs/es-ES/README.md index b005691fea..ae83d416e3 100644 --- a/docs/es-ES/README.md +++ b/docs/es-ES/README.md @@ -128,6 +128,14 @@ Si necesita personalizar la configuración, consulte los comentarios en nuestro . Después de realizar los cambios, ejecuta `docker-compose up -d` nuevamente. Puedes ver la lista completa de variables de entorno [aquí](https://docs.dify.ai/getting-started/install-self-hosted/environments). +### Monitorización de Métricas con Grafana + +Importe el panel a Grafana, utilizando la base de datos PostgreSQL de Dify como fuente de datos, para monitorizar métricas en granularidad de aplicaciones, inquilinos, mensajes y más. + +- [Panel de Grafana por @bowenliang123](https://github.com/bowenliang123/dify-grafana-dashboard) + +### Implementación con Kubernetes + Si desea configurar una configuración de alta disponibilidad, la comunidad proporciona [Gráficos Helm](https://helm.sh/) y archivos YAML, a través de los cuales puede desplegar Dify en Kubernetes. - [Gráfico Helm por @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) diff --git a/docs/fr-FR/README.md b/docs/fr-FR/README.md index 3aca9a9672..b7d006a927 100644 --- a/docs/fr-FR/README.md +++ b/docs/fr-FR/README.md @@ -126,6 +126,14 @@ Après l'exécution, vous pouvez accéder au tableau de bord Dify dans votre nav Si vous devez personnaliser la configuration, veuillez vous référer aux commentaires dans notre fichier [.env.example](../../docker/.env.example) et mettre à jour les valeurs correspondantes dans votre fichier `.env`. De plus, vous devrez peut-être apporter des modifications au fichier `docker-compose.yaml` lui-même, comme changer les versions d'image, les mappages de ports ou les montages de volumes, en fonction de votre environnement de déploiement et de vos exigences spécifiques. Après avoir effectué des modifications, veuillez réexécuter `docker-compose up -d`. Vous pouvez trouver la liste complète des variables d'environnement disponibles [ici](https://docs.dify.ai/getting-started/install-self-hosted/environments). +### Surveillance des Métriques avec Grafana + +Importez le tableau de bord dans Grafana, en utilisant la base de données PostgreSQL de Dify comme source de données, pour surveiller les métriques avec une granularité d'applications, de locataires, de messages et plus. + +- [Tableau de bord Grafana par @bowenliang123](https://github.com/bowenliang123/dify-grafana-dashboard) + +### Déploiement avec Kubernetes + Si vous souhaitez configurer une configuration haute disponibilité, la communauté fournit des [Helm Charts](https://helm.sh/) et des fichiers YAML, à travers lesquels vous pouvez déployer Dify sur Kubernetes. 
- [Helm Chart par @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) diff --git a/docs/ja-JP/README.md b/docs/ja-JP/README.md index 66831285d6..f9e700d1df 100644 --- a/docs/ja-JP/README.md +++ b/docs/ja-JP/README.md @@ -127,6 +127,14 @@ docker compose up -d 設定をカスタマイズする必要がある場合は、[.env.example](../../docker/.env.example) ファイルのコメントを参照し、`.env` ファイルの対応する値を更新してください。さらに、デプロイ環境や要件に応じて、`docker-compose.yaml` ファイル自体を調整する必要がある場合があります。たとえば、イメージのバージョン、ポートのマッピング、ボリュームのマウントなどを変更します。変更を加えた後は、`docker-compose up -d` を再実行してください。利用可能な環境変数の全一覧は、[こちら](https://docs.dify.ai/getting-started/install-self-hosted/environments)で確認できます。 +### Grafanaを使用したメトリクス監視 + +Grafanaにダッシュボードをインポートし、DifyのPostgreSQLデータベースをデータソースとして使用して、アプリ、テナント、メッセージなどの粒度でメトリクスを監視します。 + +- [@bowenliang123によるGrafanaダッシュボード](https://github.com/bowenliang123/dify-grafana-dashboard) + +### Kubernetesでのデプロイ + 高可用性設定を設定する必要がある場合、コミュニティは[Helm Charts](https://helm.sh/)とYAMLファイルにより、DifyをKubernetesにデプロイすることができます。 - [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) diff --git a/docs/ko-KR/README.md b/docs/ko-KR/README.md index ec67bc90ed..4e4b82e920 100644 --- a/docs/ko-KR/README.md +++ b/docs/ko-KR/README.md @@ -120,6 +120,14 @@ docker compose up -d 구성을 사용자 정의해야 하는 경우 [.env.example](../../docker/.env.example) 파일의 주석을 참조하고 `.env` 파일에서 해당 값을 업데이트하십시오. 또한 특정 배포 환경 및 요구 사항에 따라 `docker-compose.yaml` 파일 자체를 조정해야 할 수도 있습니다. 예를 들어 이미지 버전, 포트 매핑 또는 볼륨 마운트를 변경합니다. 변경 한 후 `docker-compose up -d`를 다시 실행하십시오. 사용 가능한 환경 변수의 전체 목록은 [여기](https://docs.dify.ai/getting-started/install-self-hosted/environments)에서 찾을 수 있습니다. +### Grafana를 사용한 메트릭 모니터링 + +Dify의 PostgreSQL 데이터베이스를 데이터 소스로 사용하여 앱, 테넌트, 메시지 등에 대한 세분화된 메트릭을 모니터링하기 위해 대시보드를 Grafana로 가져옵니다. + +- [@bowenliang123의 Grafana 대시보드](https://github.com/bowenliang123/dify-grafana-dashboard) + +### Kubernetes를 통한 배포 + Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했다는 커뮤니티가 제공하는 [Helm Charts](https://helm.sh/)와 YAML 파일이 존재합니다. - [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) diff --git a/docs/pt-BR/README.md b/docs/pt-BR/README.md index 78383a3c76..f96b18eabb 100644 --- a/docs/pt-BR/README.md +++ b/docs/pt-BR/README.md @@ -126,6 +126,14 @@ Após a execução, você pode acessar o painel do Dify no navegador em [http:// Se precisar personalizar a configuração, consulte os comentários no nosso arquivo [.env.example](../../docker/.env.example) e atualize os valores correspondentes no seu arquivo `.env`. Além disso, talvez seja necessário fazer ajustes no próprio arquivo `docker-compose.yaml`, como alterar versões de imagem, mapeamentos de portas ou montagens de volumes, com base no seu ambiente de implantação específico e nas suas necessidades. Após fazer quaisquer alterações, execute novamente `docker-compose up -d`. Você pode encontrar a lista completa de variáveis de ambiente disponíveis [aqui](https://docs.dify.ai/getting-started/install-self-hosted/environments). +### Monitoramento de Métricas com Grafana + +Importe o dashboard para o Grafana, usando o banco de dados PostgreSQL do Dify como fonte de dados, para monitorar métricas na granularidade de aplicativos, inquilinos, mensagens e muito mais. + +- [Dashboard do Grafana por @bowenliang123](https://github.com/bowenliang123/dify-grafana-dashboard) + +### Implantação com Kubernetes + Se deseja configurar uma instalação de alta disponibilidade, há [Helm Charts](https://helm.sh/) e arquivos YAML contribuídos pela comunidade que permitem a implantação do Dify no Kubernetes. 
 - [Helm Chart de @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
diff --git a/docs/sl-SI/README.md b/docs/sl-SI/README.md
index 65aedb7703..04dc3b5dff 100644
--- a/docs/sl-SI/README.md
+++ b/docs/sl-SI/README.md
@@ -128,6 +128,14 @@ Star Dify on GitHub and be instantly notified of new releases.
 
 Če morate prilagoditi konfiguracijo, si oglejte komentarje v naši datoteki .env.example in posodobite ustrezne vrednosti v svoji .env datoteki. Poleg tega boste morda morali prilagoditi docker-compose.yamlsamo datoteko, na primer spremeniti različice slike, preslikave vrat ali namestitve nosilca, glede na vaše specifično okolje in zahteve za uvajanje. Po kakršnih koli spremembah ponovno zaženite docker-compose up -d. Celoten seznam razpoložljivih spremenljivk okolja najdete tukaj .
 
+### Spremljanje metrik z Grafana
+
+Uvoz nadzorne plošče v Grafana, z uporabo Difyjeve PostgreSQL baze podatkov kot vir podatkov, za spremljanje metrike glede na podrobnost aplikacij, najemnikov, sporočil in drugega.
+
+- [Nadzorna plošča Grafana avtorja @bowenliang123](https://github.com/bowenliang123/dify-grafana-dashboard)
+
+### Namestitev s Kubernetes
+
 Če želite konfigurirati visoko razpoložljivo nastavitev, so na voljo Helm Charts in datoteke YAML, ki jih prispeva skupnost, ki omogočajo uvedbo Difyja v Kubernetes.
 
 - [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
diff --git a/docs/tr-TR/README.md b/docs/tr-TR/README.md
index a044da1f4e..965a1704be 100644
--- a/docs/tr-TR/README.md
+++ b/docs/tr-TR/README.md
@@ -120,6 +120,14 @@ docker compose up -d
 
 Yapılandırmayı özelleştirmeniz gerekiyorsa, lütfen [.env.example](../../docker/.env.example) dosyamızdaki yorumlara bakın ve `.env` dosyanızdaki ilgili değerleri güncelleyin. Ayrıca, spesifik dağıtım ortamınıza ve gereksinimlerinize bağlı olarak `docker-compose.yaml` dosyasının kendisinde de, imaj sürümlerini, port eşlemelerini veya hacim bağlantılarını değiştirmek gibi ayarlamalar yapmanız gerekebilir. Herhangi bir değişiklik yaptıktan sonra, lütfen `docker-compose up -d` komutunu tekrar çalıştırın. Kullanılabilir tüm ortam değişkenlerinin tam listesini [burada](https://docs.dify.ai/getting-started/install-self-hosted/environments) bulabilirsiniz.
 
+### Grafana ile Metrik İzleme
+
+Uygulamalar, kiracılar, mesajlar ve daha fazlasının granularitesinde metrikleri izlemek için Dify'nin PostgreSQL veritabanını veri kaynağı olarak kullanarak panoyu Grafana'ya aktarın.
+
+- [@bowenliang123 tarafından Grafana Panosu](https://github.com/bowenliang123/dify-grafana-dashboard)
+
+### Kubernetes ile Dağıtım
+
 Yüksek kullanılabilirliğe sahip bir kurulum yapılandırmak isterseniz, Dify'ın Kubernetes üzerine dağıtılmasına olanak tanıyan topluluk katkılı [Helm Charts](https://helm.sh/) ve YAML dosyaları mevcuttur.
 
 - [@LeoQuote tarafından Helm Chart](https://github.com/douban/charts/tree/master/charts/dify)
diff --git a/docs/vi-VN/README.md b/docs/vi-VN/README.md
index 847641da12..51f7c5d994 100644
--- a/docs/vi-VN/README.md
+++ b/docs/vi-VN/README.md
@@ -121,6 +121,14 @@ Sau khi chạy, bạn có thể truy cập bảng điều khiển Dify trong tr
 
 Nếu bạn cần tùy chỉnh cấu hình, vui lòng tham khảo các nhận xét trong tệp [.env.example](../../docker/.env.example) của chúng tôi và cập nhật các giá trị tương ứng trong tệp `.env` của bạn. Ngoài ra, bạn có thể cần điều chỉnh tệp `docker-compose.yaml`, chẳng hạn như thay đổi phiên bản hình ảnh, ánh xạ cổng hoặc gắn kết khối lượng, dựa trên môi trường triển khai cụ thể và yêu cầu của bạn. 
Sau khi thực hiện bất kỳ thay đổi nào, vui lòng chạy lại `docker-compose up -d`. Bạn có thể tìm thấy danh sách đầy đủ các biến môi trường có sẵn [tại đây](https://docs.dify.ai/getting-started/install-self-hosted/environments). +### Giám sát Số liệu với Grafana + +Nhập bảng điều khiển vào Grafana, sử dụng cơ sở dữ liệu PostgreSQL của Dify làm nguồn dữ liệu, để giám sát số liệu theo mức độ chi tiết của ứng dụng, người thuê, tin nhắn và hơn thế nữa. + +- [Bảng điều khiển Grafana của @bowenliang123](https://github.com/bowenliang123/dify-grafana-dashboard) + +### Triển khai với Kubernetes + Nếu bạn muốn cấu hình một cài đặt có độ sẵn sàng cao, có các [Helm Charts](https://helm.sh/) và tệp YAML do cộng đồng đóng góp cho phép Dify được triển khai trên Kubernetes. - [Helm Chart bởi @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) diff --git a/docs/zh-CN/README.md b/docs/zh-CN/README.md index 202b99a6b1..888a0d7f12 100644 --- a/docs/zh-CN/README.md +++ b/docs/zh-CN/README.md @@ -127,6 +127,12 @@ docker compose up -d 如果您需要自定义配置,请参考 [.env.example](../../docker/.env.example) 文件中的注释,并更新 `.env` 文件中对应的值。此外,您可能需要根据您的具体部署环境和需求对 `docker-compose.yaml` 文件本身进行调整,例如更改镜像版本、端口映射或卷挂载。完成任何更改后,请重新运行 `docker-compose up -d`。您可以在[此处](https://docs.dify.ai/getting-started/install-self-hosted/environments)找到可用环境变量的完整列表。 +### 使用 Grafana 进行指标监控 + +将仪表板导入 Grafana,使用 Dify 的 PostgreSQL 数据库作为数据源,以监控应用、租户、消息等粒度的指标。 + +- [由 @bowenliang123 提供的 Grafana 仪表板](https://github.com/bowenliang123/dify-grafana-dashboard) + #### 使用 Helm Chart 或 Kubernetes 资源清单(YAML)部署 使用 [Helm Chart](https://helm.sh/) 版本或者 Kubernetes 资源清单(YAML),可以在 Kubernetes 上部署 Dify。 diff --git a/docs/zh-TW/README.md b/docs/zh-TW/README.md index 526e8d9c8c..d8c484a6d4 100644 --- a/docs/zh-TW/README.md +++ b/docs/zh-TW/README.md @@ -130,6 +130,14 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify 如果您需要自定義配置,請參考我們的 [.env.example](../../docker/.env.example) 文件中的註釋,並在您的 `.env` 文件中更新相應的值。此外,根據您特定的部署環境和需求,您可能需要調整 `docker-compose.yaml` 文件本身,例如更改映像版本、端口映射或卷掛載。進行任何更改後,請重新運行 `docker-compose up -d`。您可以在[這裡](https://docs.dify.ai/getting-started/install-self-hosted/environments)找到可用環境變數的完整列表。 +### 使用 Grafana 進行指標監控 + +將儀表板匯入 Grafana,使用 Dify 的 PostgreSQL 資料庫作為資料來源,以監控應用程式、租戶、訊息等顆粒度的指標。 + +- [由 @bowenliang123 提供的 Grafana 儀表板](https://github.com/bowenliang123/dify-grafana-dashboard) + +### 使用 Kubernetes 部署 + 如果您想配置高可用性設置,社區貢獻的 [Helm Charts](https://helm.sh/) 和 Kubernetes 資源清單(YAML)允許在 Kubernetes 上部署 Dify。 - [由 @LeoQuote 提供的 Helm Chart](https://github.com/douban/charts/tree/master/charts/dify) From 4488c090b20e9ae6d1373aa4be40eca615641cf4 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Sun, 19 Oct 2025 12:54:41 +0900 Subject: [PATCH 35/46] fluent api (#27093) Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- api/controllers/console/admin.py | 20 +- .../console/app/advanced_prompt_template.py | 12 +- api/controllers/console/app/agent.py | 8 +- api/controllers/console/app/annotation.py | 29 +- api/controllers/console/app/app.py | 125 ++++---- api/controllers/console/app/app_import.py | 22 +- api/controllers/console/app/audio.py | 15 +- api/controllers/console/app/completion.py | 36 ++- api/controllers/console/app/conversation.py | 62 ++-- .../console/app/conversation_variables.py | 3 +- api/controllers/console/app/generator.py | 51 ++-- api/controllers/console/app/mcp_server.py | 20 +- api/controllers/console/app/message.py | 30 +- api/controllers/console/app/ops_trace.py | 22 +- api/controllers/console/app/site.py | 42 +-- 
api/controllers/console/app/statistic.py | 64 ++-- api/controllers/console/app/workflow.py | 101 ++++--- .../console/app/workflow_app_log.py | 54 ++-- .../console/app/workflow_draft_variable.py | 29 +- .../console/app/workflow_statistic.py | 32 +- api/controllers/console/auth/activate.py | 30 +- .../console/auth/data_source_bearer_auth.py | 10 +- .../console/auth/email_register.py | 28 +- .../console/auth/forgot_password.py | 28 +- api/controllers/console/auth/login.py | 41 +-- api/controllers/console/auth/oauth_server.py | 20 +- api/controllers/console/billing/billing.py | 8 +- api/controllers/console/billing/compliance.py | 3 +- .../console/datasets/data_source.py | 12 +- api/controllers/console/datasets/datasets.py | 246 ++++++++-------- .../console/datasets/datasets_document.py | 70 ++--- .../console/datasets/datasets_segments.py | 70 +++-- api/controllers/console/datasets/external.py | 112 +++---- .../console/datasets/hit_testing_base.py | 11 +- api/controllers/console/datasets/metadata.py | 16 +- .../datasets/rag_pipeline/datasource_auth.py | 44 +-- .../datasource_content_preview.py | 10 +- .../datasets/rag_pipeline/rag_pipeline.py | 84 +++--- .../rag_pipeline/rag_pipeline_datasets.py | 4 +- .../rag_pipeline_draft_variable.py | 29 +- .../rag_pipeline/rag_pipeline_import.py | 25 +- .../rag_pipeline/rag_pipeline_workflow.py | 144 ++++----- api/controllers/console/datasets/website.py | 25 +- api/controllers/console/explore/audio.py | 12 +- api/controllers/console/explore/completion.py | 30 +- .../console/explore/conversation.py | 18 +- .../console/explore/installed_app.py | 6 +- api/controllers/console/explore/message.py | 21 +- .../console/explore/recommended_app.py | 3 +- .../console/explore/saved_message.py | 11 +- api/controllers/console/explore/workflow.py | 8 +- api/controllers/console/extension.py | 23 +- api/controllers/console/init_validate.py | 3 +- api/controllers/console/remote_files.py | 3 +- api/controllers/console/setup.py | 10 +- api/controllers/console/tag/tags.py | 53 ++-- api/controllers/console/version.py | 3 +- api/controllers/console/workspace/account.py | 104 ++++--- api/controllers/console/workspace/endpoint.py | 47 +-- .../workspace/load_balancing_config.py | 44 +-- api/controllers/console/workspace/members.py | 29 +- .../console/workspace/model_providers.py | 44 +-- api/controllers/console/workspace/models.py | 239 ++++++++------- api/controllers/console/workspace/plugin.py | 130 +++++---- .../console/workspace/tool_providers.py | 251 ++++++++-------- .../console/workspace/workspace.py | 22 +- api/controllers/files/image_preview.py | 12 +- api/controllers/files/tool_files.py | 13 +- api/controllers/files/upload.py | 22 +- api/controllers/inner_api/mail.py | 12 +- api/controllers/inner_api/plugin/wraps.py | 8 +- .../inner_api/workspace/workspace.py | 11 +- api/controllers/mcp/mcp.py | 14 +- api/controllers/service_api/app/annotation.py | 24 +- api/controllers/service_api/app/audio.py | 12 +- api/controllers/service_api/app/completion.py | 54 ++-- .../service_api/app/conversation.py | 87 +++--- .../service_api/app/file_preview.py | 3 +- api/controllers/service_api/app/message.py | 42 ++- api/controllers/service_api/app/workflow.py | 52 ++-- .../service_api/dataset/dataset.py | 274 +++++++++--------- .../service_api/dataset/document.py | 50 ++-- .../service_api/dataset/metadata.py | 16 +- .../rag_pipeline/rag_pipeline_workflow.py | 40 +-- .../service_api/dataset/segment.py | 38 ++- api/controllers/web/app.py | 11 +- api/controllers/web/audio.py | 12 +- 
api/controllers/web/completion.py | 32 +- api/controllers/web/conversation.py | 32 +- api/controllers/web/forgot_password.py | 28 +- api/controllers/web/login.py | 26 +- api/controllers/web/message.py | 21 +- api/controllers/web/remote_files.py | 3 +- api/controllers/web/saved_message.py | 11 +- api/controllers/web/workflow.py | 8 +- .../services/test_metadata_bug_complete.py | 24 +- .../services/test_metadata_nullable_bug.py | 24 +- 97 files changed, 2179 insertions(+), 1798 deletions(-) diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index 93f242ad28..ef96184678 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -70,15 +70,17 @@ class InsertExploreAppListApi(Resource): @only_edition_cloud @admin_required def post(self): - parser = reqparse.RequestParser() - parser.add_argument("app_id", type=str, required=True, nullable=False, location="json") - parser.add_argument("desc", type=str, location="json") - parser.add_argument("copyright", type=str, location="json") - parser.add_argument("privacy_policy", type=str, location="json") - parser.add_argument("custom_disclaimer", type=str, location="json") - parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json") - parser.add_argument("category", type=str, required=True, nullable=False, location="json") - parser.add_argument("position", type=int, required=True, nullable=False, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("app_id", type=str, required=True, nullable=False, location="json") + .add_argument("desc", type=str, location="json") + .add_argument("copyright", type=str, location="json") + .add_argument("privacy_policy", type=str, location="json") + .add_argument("custom_disclaimer", type=str, location="json") + .add_argument("language", type=supported_language, required=True, nullable=False, location="json") + .add_argument("category", type=str, required=True, nullable=False, location="json") + .add_argument("position", type=int, required=True, nullable=False, location="json") + ) args = parser.parse_args() app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none() diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py index 315825db79..5885d7b447 100644 --- a/api/controllers/console/app/advanced_prompt_template.py +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -25,11 +25,13 @@ class AdvancedPromptTemplateList(Resource): @login_required @account_initialization_required def get(self): - parser = reqparse.RequestParser() - parser.add_argument("app_mode", type=str, required=True, location="args") - parser.add_argument("model_mode", type=str, required=True, location="args") - parser.add_argument("has_context", type=str, required=False, default="true", location="args") - parser.add_argument("model_name", type=str, required=True, location="args") + parser = ( + reqparse.RequestParser() + .add_argument("app_mode", type=str, required=True, location="args") + .add_argument("model_mode", type=str, required=True, location="args") + .add_argument("has_context", type=str, required=False, default="true", location="args") + .add_argument("model_name", type=str, required=True, location="args") + ) args = parser.parse_args() return AdvancedPromptTemplateService.get_prompt(args) diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py index c063f336c7..717263a74d 
100644 --- a/api/controllers/console/app/agent.py +++ b/api/controllers/console/app/agent.py @@ -27,9 +27,11 @@ class AgentLogApi(Resource): @get_app_model(mode=[AppMode.AGENT_CHAT]) def get(self, app_model): """Get agent logs""" - parser = reqparse.RequestParser() - parser.add_argument("message_id", type=uuid_value, required=True, location="args") - parser.add_argument("conversation_id", type=uuid_value, required=True, location="args") + parser = ( + reqparse.RequestParser() + .add_argument("message_id", type=uuid_value, required=True, location="args") + .add_argument("conversation_id", type=uuid_value, required=True, location="args") + ) args = parser.parse_args() diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 3e549d869e..932214058a 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -44,10 +44,12 @@ class AnnotationReplyActionApi(Resource): @edit_permission_required def post(self, app_id, action: Literal["enable", "disable"]): app_id = str(app_id) - parser = reqparse.RequestParser() - parser.add_argument("score_threshold", required=True, type=float, location="json") - parser.add_argument("embedding_provider_name", required=True, type=str, location="json") - parser.add_argument("embedding_model_name", required=True, type=str, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("score_threshold", required=True, type=float, location="json") + .add_argument("embedding_provider_name", required=True, type=str, location="json") + .add_argument("embedding_model_name", required=True, type=str, location="json") + ) args = parser.parse_args() if action == "enable": result = AppAnnotationService.enable_app_annotation(args, app_id) @@ -98,8 +100,7 @@ class AppAnnotationSettingUpdateApi(Resource): app_id = str(app_id) annotation_setting_id = str(annotation_setting_id) - parser = reqparse.RequestParser() - parser.add_argument("score_threshold", required=True, type=float, location="json") + parser = reqparse.RequestParser().add_argument("score_threshold", required=True, type=float, location="json") args = parser.parse_args() result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args) @@ -190,9 +191,11 @@ class AnnotationApi(Resource): @edit_permission_required def post(self, app_id): app_id = str(app_id) - parser = reqparse.RequestParser() - parser.add_argument("question", required=True, type=str, location="json") - parser.add_argument("answer", required=True, type=str, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("question", required=True, type=str, location="json") + .add_argument("answer", required=True, type=str, location="json") + ) args = parser.parse_args() annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id) return annotation @@ -259,9 +262,11 @@ class AnnotationUpdateDeleteApi(Resource): def post(self, app_id, annotation_id): app_id = str(app_id) annotation_id = str(annotation_id) - parser = reqparse.RequestParser() - parser.add_argument("question", required=True, type=str, location="json") - parser.add_argument("answer", required=True, type=str, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("question", required=True, type=str, location="json") + .add_argument("answer", required=True, type=str, location="json") + ) args = parser.parse_args() annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id) return 
annotation diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 3900f5a6eb..17505d69b2 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -63,28 +63,30 @@ class AppListApi(Resource): except ValueError: abort(400, message="Invalid UUID format in tag_ids.") - parser = reqparse.RequestParser() - parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") - parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") - parser.add_argument( - "mode", - type=str, - choices=[ - "completion", - "chat", - "advanced-chat", - "workflow", - "agent-chat", - "channel", - "all", - ], - default="all", - location="args", - required=False, + parser = ( + reqparse.RequestParser() + .add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") + .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") + .add_argument( + "mode", + type=str, + choices=[ + "completion", + "chat", + "advanced-chat", + "workflow", + "agent-chat", + "channel", + "all", + ], + default="all", + location="args", + required=False, + ) + .add_argument("name", type=str, location="args", required=False) + .add_argument("tag_ids", type=uuid_list, location="args", required=False) + .add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False) ) - parser.add_argument("name", type=str, location="args", required=False) - parser.add_argument("tag_ids", type=uuid_list, location="args", required=False) - parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False) args = parser.parse_args() @@ -133,13 +135,15 @@ class AppListApi(Resource): def post(self): """Create app""" current_user, current_tenant_id = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("name", type=str, required=True, location="json") - parser.add_argument("description", type=validate_description_length, location="json") - parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json") - parser.add_argument("icon_type", type=str, location="json") - parser.add_argument("icon", type=str, location="json") - parser.add_argument("icon_background", type=str, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("name", type=str, required=True, location="json") + .add_argument("description", type=validate_description_length, location="json") + .add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json") + .add_argument("icon_type", type=str, location="json") + .add_argument("icon", type=str, location="json") + .add_argument("icon_background", type=str, location="json") + ) args = parser.parse_args() if "mode" not in args or args["mode"] is None: @@ -203,14 +207,16 @@ class AppApi(Resource): @marshal_with(app_detail_fields_with_site) def put(self, app_model): """Update app""" - parser = reqparse.RequestParser() - parser.add_argument("name", type=str, required=True, nullable=False, location="json") - parser.add_argument("description", type=validate_description_length, location="json") - parser.add_argument("icon_type", type=str, location="json") - parser.add_argument("icon", type=str, location="json") - parser.add_argument("icon_background", type=str, location="json") - parser.add_argument("use_icon_as_answer_icon", type=bool, location="json") - parser.add_argument("max_active_requests", 
type=int, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("name", type=str, required=True, nullable=False, location="json") + .add_argument("description", type=validate_description_length, location="json") + .add_argument("icon_type", type=str, location="json") + .add_argument("icon", type=str, location="json") + .add_argument("icon_background", type=str, location="json") + .add_argument("use_icon_as_answer_icon", type=bool, location="json") + .add_argument("max_active_requests", type=int, location="json") + ) args = parser.parse_args() app_service = AppService() @@ -278,12 +284,14 @@ class AppCopyApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("name", type=str, location="json") - parser.add_argument("description", type=validate_description_length, location="json") - parser.add_argument("icon_type", type=str, location="json") - parser.add_argument("icon", type=str, location="json") - parser.add_argument("icon_background", type=str, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("name", type=str, location="json") + .add_argument("description", type=validate_description_length, location="json") + .add_argument("icon_type", type=str, location="json") + .add_argument("icon", type=str, location="json") + .add_argument("icon_background", type=str, location="json") + ) args = parser.parse_args() with Session(db.engine) as session: @@ -331,9 +339,11 @@ class AppExportApi(Resource): def get(self, app_model): """Export app""" # Add include_secret params - parser = reqparse.RequestParser() - parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args") - parser.add_argument("workflow_id", type=str, location="args") + parser = ( + reqparse.RequestParser() + .add_argument("include_secret", type=inputs.boolean, default=False, location="args") + .add_argument("workflow_id", type=str, location="args") + ) args = parser.parse_args() return { @@ -357,8 +367,7 @@ class AppNameApi(Resource): @marshal_with(app_detail_fields) @edit_permission_required def post(self, app_model): - parser = reqparse.RequestParser() - parser.add_argument("name", type=str, required=True, location="json") + parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json") args = parser.parse_args() app_service = AppService() @@ -391,9 +400,11 @@ class AppIconApi(Resource): @marshal_with(app_detail_fields) @edit_permission_required def post(self, app_model): - parser = reqparse.RequestParser() - parser.add_argument("icon", type=str, location="json") - parser.add_argument("icon_background", type=str, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("icon", type=str, location="json") + .add_argument("icon_background", type=str, location="json") + ) args = parser.parse_args() app_service = AppService() @@ -421,8 +432,7 @@ class AppSiteStatus(Resource): @marshal_with(app_detail_fields) @edit_permission_required def post(self, app_model): - parser = reqparse.RequestParser() - parser.add_argument("enable_site", type=bool, required=True, location="json") + parser = reqparse.RequestParser().add_argument("enable_site", type=bool, required=True, location="json") args = parser.parse_args() app_service = AppService() @@ -454,8 +464,7 @@ class AppApiStatus(Resource): if not current_user.is_admin_or_owner: raise Forbidden() - parser = reqparse.RequestParser() - 
parser.add_argument("enable_api", type=bool, required=True, location="json") + parser = reqparse.RequestParser().add_argument("enable_api", type=bool, required=True, location="json") args = parser.parse_args() app_service = AppService() @@ -499,9 +508,11 @@ class AppTraceApi(Resource): @edit_permission_required def post(self, app_id): # add app trace - parser = reqparse.RequestParser() - parser.add_argument("enabled", type=bool, required=True, location="json") - parser.add_argument("tracing_provider", type=str, required=True, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("enabled", type=bool, required=True, location="json") + .add_argument("tracing_provider", type=str, required=True, location="json") + ) args = parser.parse_args() OpsTraceManager.update_app_tracing_config( diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index 5e7ea6d481..d902c129ad 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -30,16 +30,18 @@ class AppImportApi(Resource): def post(self): # Check user role first current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("mode", type=str, required=True, location="json") - parser.add_argument("yaml_content", type=str, location="json") - parser.add_argument("yaml_url", type=str, location="json") - parser.add_argument("name", type=str, location="json") - parser.add_argument("description", type=str, location="json") - parser.add_argument("icon_type", type=str, location="json") - parser.add_argument("icon", type=str, location="json") - parser.add_argument("icon_background", type=str, location="json") - parser.add_argument("app_id", type=str, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("mode", type=str, required=True, location="json") + .add_argument("yaml_content", type=str, location="json") + .add_argument("yaml_url", type=str, location="json") + .add_argument("name", type=str, location="json") + .add_argument("description", type=str, location="json") + .add_argument("icon_type", type=str, location="json") + .add_argument("icon", type=str, location="json") + .add_argument("icon_background", type=str, location="json") + .add_argument("app_id", type=str, location="json") + ) args = parser.parse_args() # Create service with session diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 7d659dae0d..8170ba271a 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -111,11 +111,13 @@ class ChatMessageTextApi(Resource): @account_initialization_required def post(self, app_model: App): try: - parser = reqparse.RequestParser() - parser.add_argument("message_id", type=str, location="json") - parser.add_argument("text", type=str, location="json") - parser.add_argument("voice", type=str, location="json") - parser.add_argument("streaming", type=bool, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("message_id", type=str, location="json") + .add_argument("text", type=str, location="json") + .add_argument("voice", type=str, location="json") + .add_argument("streaming", type=bool, location="json") + ) args = parser.parse_args() message_id = args.get("message_id", None) @@ -166,8 +168,7 @@ class TextModesApi(Resource): @account_initialization_required def get(self, app_model): try: - parser = reqparse.RequestParser() - parser.add_argument("language", type=str, required=True, 
location="args") + parser = reqparse.RequestParser().add_argument("language", type=str, required=True, location="args") args = parser.parse_args() response = AudioService.transcript_tts_voices( diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index d69f05f23e..d7bc3cc20d 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -64,13 +64,15 @@ class CompletionMessageApi(Resource): @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) def post(self, app_model): - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, location="json") - parser.add_argument("query", type=str, location="json", default="") - parser.add_argument("files", type=list, required=False, location="json") - parser.add_argument("model_config", type=dict, required=True, location="json") - parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") + parser = ( + reqparse.RequestParser() + .add_argument("inputs", type=dict, required=True, location="json") + .add_argument("query", type=str, location="json", default="") + .add_argument("files", type=list, required=False, location="json") + .add_argument("model_config", type=dict, required=True, location="json") + .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + .add_argument("retriever_from", type=str, required=False, default="dev", location="json") + ) args = parser.parse_args() streaming = args["response_mode"] != "blocking" @@ -153,15 +155,17 @@ class ChatMessageApi(Resource): @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) @edit_permission_required def post(self, app_model): - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, location="json") - parser.add_argument("query", type=str, required=True, location="json") - parser.add_argument("files", type=list, required=False, location="json") - parser.add_argument("model_config", type=dict, required=True, location="json") - parser.add_argument("conversation_id", type=uuid_value, location="json") - parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json") - parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") + parser = ( + reqparse.RequestParser() + .add_argument("inputs", type=dict, required=True, location="json") + .add_argument("query", type=str, required=True, location="json") + .add_argument("files", type=list, required=False, location="json") + .add_argument("model_config", type=dict, required=True, location="json") + .add_argument("conversation_id", type=uuid_value, location="json") + .add_argument("parent_message_id", type=uuid_value, required=False, location="json") + .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + .add_argument("retriever_from", type=str, required=False, default="dev", location="json") + ) args = parser.parse_args() streaming = args["response_mode"] != "blocking" diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 779be62973..d5fa70d678 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py 
@@ -59,15 +59,21 @@ class CompletionConversationApi(Resource): @edit_permission_required def get(self, app_model): current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("keyword", type=str, location="args") - parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - parser.add_argument( - "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" + parser = ( + reqparse.RequestParser() + .add_argument("keyword", type=str, location="args") + .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + .add_argument( + "annotation_status", + type=str, + choices=["annotated", "not_annotated", "all"], + default="all", + location="args", + ) + .add_argument("page", type=int_range(1, 99999), default=1, location="args") + .add_argument("limit", type=int_range(1, 100), default=20, location="args") ) - parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") - parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") args = parser.parse_args() query = sa.select(Conversation).where( @@ -206,23 +212,29 @@ class ChatConversationApi(Resource): @edit_permission_required def get(self, app_model): current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("keyword", type=str, location="args") - parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - parser.add_argument( - "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" - ) - parser.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args") - parser.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args") - parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - parser.add_argument( - "sort_by", - type=str, - choices=["created_at", "-created_at", "updated_at", "-updated_at"], - required=False, - default="-updated_at", - location="args", + parser = ( + reqparse.RequestParser() + .add_argument("keyword", type=str, location="args") + .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + .add_argument( + "annotation_status", + type=str, + choices=["annotated", "not_annotated", "all"], + default="all", + location="args", + ) + .add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args") + .add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args") + .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + .add_argument( + "sort_by", + type=str, + choices=["created_at", "-created_at", "updated_at", "-updated_at"], + required=False, + default="-updated_at", + location="args", + ) ) args = parser.parse_args() diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index 8a65a89963..d4c0b5697f 100644 --- a/api/controllers/console/app/conversation_variables.py +++ 
b/api/controllers/console/app/conversation_variables.py @@ -29,8 +29,7 @@ class ConversationVariablesApi(Resource): @get_app_model(mode=AppMode.ADVANCED_CHAT) @marshal_with(paginated_conversation_variable_fields) def get(self, app_model): - parser = reqparse.RequestParser() - parser.add_argument("conversation_id", type=str, location="args") + parser = reqparse.RequestParser().add_argument("conversation_id", type=str, location="args") args = parser.parse_args() stmt = ( diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 4a9b6e7801..b6ca97ab4f 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -42,10 +42,12 @@ class RuleGenerateApi(Resource): @login_required @account_initialization_required def post(self): - parser = reqparse.RequestParser() - parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") - parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") - parser.add_argument("no_variable", type=bool, required=True, default=False, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("instruction", type=str, required=True, nullable=False, location="json") + .add_argument("model_config", type=dict, required=True, nullable=False, location="json") + .add_argument("no_variable", type=bool, required=True, default=False, location="json") + ) args = parser.parse_args() _, current_tenant_id = current_account_with_tenant() @@ -92,11 +94,13 @@ class RuleCodeGenerateApi(Resource): @login_required @account_initialization_required def post(self): - parser = reqparse.RequestParser() - parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") - parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") - parser.add_argument("no_variable", type=bool, required=True, default=False, location="json") - parser.add_argument("code_language", type=str, required=False, default="javascript", location="json") + parser = ( + reqparse.RequestParser() + .add_argument("instruction", type=str, required=True, nullable=False, location="json") + .add_argument("model_config", type=dict, required=True, nullable=False, location="json") + .add_argument("no_variable", type=bool, required=True, default=False, location="json") + .add_argument("code_language", type=str, required=False, default="javascript", location="json") + ) args = parser.parse_args() _, current_tenant_id = current_account_with_tenant() @@ -139,9 +143,11 @@ class RuleStructuredOutputGenerateApi(Resource): @login_required @account_initialization_required def post(self): - parser = reqparse.RequestParser() - parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") - parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("instruction", type=str, required=True, nullable=False, location="json") + .add_argument("model_config", type=dict, required=True, nullable=False, location="json") + ) args = parser.parse_args() _, current_tenant_id = current_account_with_tenant() @@ -188,14 +194,16 @@ class InstructionGenerateApi(Resource): @login_required @account_initialization_required def post(self): - parser = reqparse.RequestParser() - parser.add_argument("flow_id", type=str, required=True, default="", location="json") - parser.add_argument("node_id", type=str, required=False, 
default="", location="json") - parser.add_argument("current", type=str, required=False, default="", location="json") - parser.add_argument("language", type=str, required=False, default="javascript", location="json") - parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") - parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") - parser.add_argument("ideal_output", type=str, required=False, default="", location="json") + parser = ( + reqparse.RequestParser() + .add_argument("flow_id", type=str, required=True, default="", location="json") + .add_argument("node_id", type=str, required=False, default="", location="json") + .add_argument("current", type=str, required=False, default="", location="json") + .add_argument("language", type=str, required=False, default="javascript", location="json") + .add_argument("instruction", type=str, required=True, nullable=False, location="json") + .add_argument("model_config", type=dict, required=True, nullable=False, location="json") + .add_argument("ideal_output", type=str, required=False, default="", location="json") + ) args = parser.parse_args() _, current_tenant_id = current_account_with_tenant() code_template = ( @@ -293,8 +301,7 @@ class InstructionGenerationTemplateApi(Resource): @login_required @account_initialization_required def post(self): - parser = reqparse.RequestParser() - parser.add_argument("type", type=str, required=True, default=False, location="json") + parser = reqparse.RequestParser().add_argument("type", type=str, required=True, default=False, location="json") args = parser.parse_args() match args["type"]: case "prompt": diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index 599f5adb34..3700c6b1d0 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -55,9 +55,11 @@ class AppMCPServerController(Resource): @edit_permission_required def post(self, app_model): _, current_tenant_id = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("description", type=str, required=False, location="json") - parser.add_argument("parameters", type=dict, required=True, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("description", type=str, required=False, location="json") + .add_argument("parameters", type=dict, required=True, location="json") + ) args = parser.parse_args() description = args.get("description") @@ -101,11 +103,13 @@ class AppMCPServerController(Resource): @marshal_with(app_server_fields) @edit_permission_required def put(self, app_model): - parser = reqparse.RequestParser() - parser.add_argument("id", type=str, required=True, location="json") - parser.add_argument("description", type=str, required=False, location="json") - parser.add_argument("parameters", type=dict, required=True, location="json") - parser.add_argument("status", type=str, required=False, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("id", type=str, required=True, location="json") + .add_argument("description", type=str, required=False, location="json") + .add_argument("parameters", type=dict, required=True, location="json") + .add_argument("status", type=str, required=False, location="json") + ) args = parser.parse_args() server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first() if not server: diff --git a/api/controllers/console/app/message.py 
b/api/controllers/console/app/message.py index 005cff75fc..7e0ae370ef 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -63,10 +63,12 @@ class ChatMessageListApi(Resource): @marshal_with(message_infinite_scroll_pagination_fields) @edit_permission_required def get(self, app_model): - parser = reqparse.RequestParser() - parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") - parser.add_argument("first_id", type=uuid_value, location="args") - parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + parser = ( + reqparse.RequestParser() + .add_argument("conversation_id", required=True, type=uuid_value, location="args") + .add_argument("first_id", type=uuid_value, location="args") + .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + ) args = parser.parse_args() conversation = ( @@ -154,9 +156,11 @@ class MessageFeedbackApi(Resource): def post(self, app_model): current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("message_id", required=True, type=uuid_value, location="json") - parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") + parser = ( + reqparse.RequestParser() + .add_argument("message_id", required=True, type=uuid_value, location="json") + .add_argument("rating", type=str, choices=["like", "dislike", None], location="json") + ) args = parser.parse_args() message_id = str(args["message_id"]) @@ -216,11 +220,13 @@ class MessageAnnotationApi(Resource): @account_initialization_required @edit_permission_required def post(self, app_model): - parser = reqparse.RequestParser() - parser.add_argument("message_id", required=False, type=uuid_value, location="json") - parser.add_argument("question", required=True, type=str, location="json") - parser.add_argument("answer", required=True, type=str, location="json") - parser.add_argument("annotation_reply", required=False, type=dict, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("message_id", required=False, type=uuid_value, location="json") + .add_argument("question", required=True, type=str, location="json") + .add_argument("answer", required=True, type=str, location="json") + .add_argument("annotation_reply", required=False, type=dict, location="json") + ) args = parser.parse_args() annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id) diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index 981974e842..1d80314774 100644 --- a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -30,8 +30,7 @@ class TraceAppConfigApi(Resource): @login_required @account_initialization_required def get(self, app_id): - parser = reqparse.RequestParser() - parser.add_argument("tracing_provider", type=str, required=True, location="args") + parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args") args = parser.parse_args() try: @@ -63,9 +62,11 @@ class TraceAppConfigApi(Resource): @account_initialization_required def post(self, app_id): """Create a new trace app configuration""" - parser = reqparse.RequestParser() - parser.add_argument("tracing_provider", type=str, required=True, location="json") - parser.add_argument("tracing_config", type=dict, required=True, location="json") + parser = ( + reqparse.RequestParser() + 
.add_argument("tracing_provider", type=str, required=True, location="json") + .add_argument("tracing_config", type=dict, required=True, location="json") + ) args = parser.parse_args() try: @@ -99,9 +100,11 @@ class TraceAppConfigApi(Resource): @account_initialization_required def patch(self, app_id): """Update an existing trace app configuration""" - parser = reqparse.RequestParser() - parser.add_argument("tracing_provider", type=str, required=True, location="json") - parser.add_argument("tracing_config", type=dict, required=True, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("tracing_provider", type=str, required=True, location="json") + .add_argument("tracing_config", type=dict, required=True, location="json") + ) args = parser.parse_args() try: @@ -129,8 +132,7 @@ class TraceAppConfigApi(Resource): @account_initialization_required def delete(self, app_id): """Delete an existing trace app configuration""" - parser = reqparse.RequestParser() - parser.add_argument("tracing_provider", type=str, required=True, location="args") + parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args") args = parser.parse_args() try: diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 1da704efcc..c4d640bf0e 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -13,25 +13,31 @@ from models import Site def parse_app_site_args(): - parser = reqparse.RequestParser() - parser.add_argument("title", type=str, required=False, location="json") - parser.add_argument("icon_type", type=str, required=False, location="json") - parser.add_argument("icon", type=str, required=False, location="json") - parser.add_argument("icon_background", type=str, required=False, location="json") - parser.add_argument("description", type=str, required=False, location="json") - parser.add_argument("default_language", type=supported_language, required=False, location="json") - parser.add_argument("chat_color_theme", type=str, required=False, location="json") - parser.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json") - parser.add_argument("customize_domain", type=str, required=False, location="json") - parser.add_argument("copyright", type=str, required=False, location="json") - parser.add_argument("privacy_policy", type=str, required=False, location="json") - parser.add_argument("custom_disclaimer", type=str, required=False, location="json") - parser.add_argument( - "customize_token_strategy", type=str, choices=["must", "allow", "not_allow"], required=False, location="json" + parser = ( + reqparse.RequestParser() + .add_argument("title", type=str, required=False, location="json") + .add_argument("icon_type", type=str, required=False, location="json") + .add_argument("icon", type=str, required=False, location="json") + .add_argument("icon_background", type=str, required=False, location="json") + .add_argument("description", type=str, required=False, location="json") + .add_argument("default_language", type=supported_language, required=False, location="json") + .add_argument("chat_color_theme", type=str, required=False, location="json") + .add_argument("chat_color_theme_inverted", type=bool, required=False, location="json") + .add_argument("customize_domain", type=str, required=False, location="json") + .add_argument("copyright", type=str, required=False, location="json") + .add_argument("privacy_policy", type=str, required=False, location="json") + 
.add_argument("custom_disclaimer", type=str, required=False, location="json") + .add_argument( + "customize_token_strategy", + type=str, + choices=["must", "allow", "not_allow"], + required=False, + location="json", + ) + .add_argument("prompt_public", type=bool, required=False, location="json") + .add_argument("show_workflow_steps", type=bool, required=False, location="json") + .add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json") ) - parser.add_argument("prompt_public", type=bool, required=False, location="json") - parser.add_argument("show_workflow_steps", type=bool, required=False, location="json") - parser.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json") return parser.parse_args() diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index cfe5b3ff17..0917a6e53c 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -38,9 +38,11 @@ class DailyMessageStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser = ( + reqparse.RequestParser() + .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + ) args = parser.parse_args() sql_query = """SELECT @@ -111,9 +113,11 @@ class DailyConversationStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser = ( + reqparse.RequestParser() + .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + ) args = parser.parse_args() assert account.timezone is not None timezone = pytz.timezone(account.timezone) @@ -177,9 +181,11 @@ class DailyTerminalsStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser = ( + reqparse.RequestParser() + .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + ) args = parser.parse_args() sql_query = """SELECT @@ -249,9 +255,11 @@ class DailyTokenCostStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser = ( + reqparse.RequestParser() + .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + ) args = parser.parse_args() sql_query = """SELECT @@ -324,9 +332,11 @@ class AverageSessionInteractionStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() 
- parser = reqparse.RequestParser() - parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser = ( + reqparse.RequestParser() + .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + ) args = parser.parse_args() sql_query = """SELECT @@ -415,9 +425,11 @@ class UserSatisfactionRateStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser = ( + reqparse.RequestParser() + .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + ) args = parser.parse_args() sql_query = """SELECT @@ -496,9 +508,11 @@ class AverageResponseTimeStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser = ( + reqparse.RequestParser() + .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + ) args = parser.parse_args() sql_query = """SELECT @@ -568,9 +582,11 @@ class TokensPerSecondStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser = ( + reqparse.RequestParser() + .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + ) args = parser.parse_args() sql_query = """SELECT diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 172a80736f..56771ed420 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -115,12 +115,14 @@ class DraftWorkflowApi(Resource): content_type = request.headers.get("Content-Type", "") if "application/json" in content_type: - parser = reqparse.RequestParser() - parser.add_argument("graph", type=dict, required=True, nullable=False, location="json") - parser.add_argument("features", type=dict, required=True, nullable=False, location="json") - parser.add_argument("hash", type=str, required=False, location="json") - parser.add_argument("environment_variables", type=list, required=True, location="json") - parser.add_argument("conversation_variables", type=list, required=False, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("graph", type=dict, required=True, nullable=False, location="json") + .add_argument("features", type=dict, required=True, nullable=False, location="json") + .add_argument("hash", type=str, required=False, location="json") + .add_argument("environment_variables", type=list, required=True, location="json") + .add_argument("conversation_variables", type=list, required=False, location="json") + ) args = 
parser.parse_args() elif "text/plain" in content_type: try: @@ -202,12 +204,14 @@ class AdvancedChatDraftWorkflowRunApi(Resource): """ current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, location="json") - parser.add_argument("query", type=str, required=True, location="json", default="") - parser.add_argument("files", type=list, location="json") - parser.add_argument("conversation_id", type=uuid_value, location="json") - parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("inputs", type=dict, location="json") + .add_argument("query", type=str, required=True, location="json", default="") + .add_argument("files", type=list, location="json") + .add_argument("conversation_id", type=uuid_value, location="json") + .add_argument("parent_message_id", type=uuid_value, required=False, location="json") + ) args = parser.parse_args() @@ -261,8 +265,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource): Run draft workflow iteration node """ current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, location="json") + parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json") args = parser.parse_args() try: @@ -309,8 +312,7 @@ class WorkflowDraftRunIterationNodeApi(Resource): Run draft workflow iteration node """ current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, location="json") + parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json") args = parser.parse_args() try: @@ -357,8 +359,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource): Run draft workflow loop node """ current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, location="json") + parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json") args = parser.parse_args() try: @@ -405,8 +406,7 @@ class WorkflowDraftRunLoopNodeApi(Resource): Run draft workflow loop node """ current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, location="json") + parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json") args = parser.parse_args() try: @@ -452,9 +452,11 @@ class DraftWorkflowRunApi(Resource): Run draft workflow """ current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") - parser.add_argument("files", type=list, required=False, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("inputs", type=dict, required=True, nullable=False, location="json") + .add_argument("files", type=list, required=False, location="json") + ) args = parser.parse_args() external_trace_id = get_external_trace_id(request) @@ -529,10 +531,12 @@ class DraftWorkflowNodeRunApi(Resource): Run draft workflow node """ current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") - parser.add_argument("query", type=str, required=False, location="json", default="") - parser.add_argument("files", type=list, location="json", default=[]) + parser = ( + reqparse.RequestParser() + 
.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + .add_argument("query", type=str, required=False, location="json", default="") + .add_argument("files", type=list, location="json", default=[]) + ) args = parser.parse_args() user_inputs = args.get("inputs") @@ -594,9 +598,11 @@ class PublishedWorkflowApi(Resource): Publish workflow """ current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("marked_name", type=str, required=False, default="", location="json") - parser.add_argument("marked_comment", type=str, required=False, default="", location="json") + parser = ( + reqparse.RequestParser() + .add_argument("marked_name", type=str, required=False, default="", location="json") + .add_argument("marked_comment", type=str, required=False, default="", location="json") + ) args = parser.parse_args() # Validate name and comment length @@ -668,8 +674,7 @@ class DefaultBlockConfigApi(Resource): """ Get default block config """ - parser = reqparse.RequestParser() - parser.add_argument("q", type=str, location="args") + parser = reqparse.RequestParser().add_argument("q", type=str, location="args") args = parser.parse_args() q = args.get("q") @@ -708,11 +713,13 @@ class ConvertToWorkflowApi(Resource): current_user, _ = current_account_with_tenant() if request.data: - parser = reqparse.RequestParser() - parser.add_argument("name", type=str, required=False, nullable=True, location="json") - parser.add_argument("icon_type", type=str, required=False, nullable=True, location="json") - parser.add_argument("icon", type=str, required=False, nullable=True, location="json") - parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("name", type=str, required=False, nullable=True, location="json") + .add_argument("icon_type", type=str, required=False, nullable=True, location="json") + .add_argument("icon", type=str, required=False, nullable=True, location="json") + .add_argument("icon_background", type=str, required=False, nullable=True, location="json") + ) args = parser.parse_args() else: args = {} @@ -745,11 +752,13 @@ class PublishedAllWorkflowApi(Resource): """ current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") - parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") - parser.add_argument("user_id", type=str, required=False, location="args") - parser.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args") + parser = ( + reqparse.RequestParser() + .add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") + .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") + .add_argument("user_id", type=str, required=False, location="args") + .add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args") + ) args = parser.parse_args() page = int(args.get("page", 1)) limit = int(args.get("limit", 10)) @@ -808,9 +817,11 @@ class WorkflowByIdApi(Resource): Update workflow attributes """ current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("marked_name", type=str, required=False, location="json") - parser.add_argument("marked_comment", type=str, 
required=False, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("marked_name", type=str, required=False, location="json") + .add_argument("marked_comment", type=str, required=False, location="json") + ) args = parser.parse_args() # Validate name and comment length diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 8e24be4fa7..cbf4e84ff0 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -42,33 +42,35 @@ class WorkflowAppLogApi(Resource): """ Get workflow app logs """ - parser = reqparse.RequestParser() - parser.add_argument("keyword", type=str, location="args") - parser.add_argument( - "status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args" + parser = ( + reqparse.RequestParser() + .add_argument("keyword", type=str, location="args") + .add_argument( + "status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args" + ) + .add_argument( + "created_at__before", type=str, location="args", help="Filter logs created before this timestamp" + ) + .add_argument( + "created_at__after", type=str, location="args", help="Filter logs created after this timestamp" + ) + .add_argument( + "created_by_end_user_session_id", + type=str, + location="args", + required=False, + default=None, + ) + .add_argument( + "created_by_account", + type=str, + location="args", + required=False, + default=None, + ) + .add_argument("page", type=int_range(1, 99999), default=1, location="args") + .add_argument("limit", type=int_range(1, 100), default=20, location="args") ) - parser.add_argument( - "created_at__before", type=str, location="args", help="Filter logs created before this timestamp" - ) - parser.add_argument( - "created_at__after", type=str, location="args", help="Filter logs created after this timestamp" - ) - parser.add_argument( - "created_by_end_user_session_id", - type=str, - location="args", - required=False, - default=None, - ) - parser.add_argument( - "created_by_account", - type=str, - location="args", - required=False, - default=None, - ) - parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") - parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") args = parser.parse_args() args.status = WorkflowExecutionStatus(args.status) if args.status else None diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 5e865dc4c1..0722eb40d2 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -57,16 +57,18 @@ def _serialize_var_value(variable: WorkflowDraftVariable): def _create_pagination_parser(): - parser = reqparse.RequestParser() - parser.add_argument( - "page", - type=inputs.int_range(1, 100_000), - required=False, - default=1, - location="args", - help="the page of data requested", + parser = ( + reqparse.RequestParser() + .add_argument( + "page", + type=inputs.int_range(1, 100_000), + required=False, + default=1, + location="args", + help="the page of data requested", + ) + .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") ) - parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") return parser @@ -319,10 +321,11 @@ class VariableApi(Resource): # "upload_file_id": 
"1602650a-4fe4-423c-85a2-af76c083e3c4" # } - parser = reqparse.RequestParser() - parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json") - # Parse 'value' field as-is to maintain its original data structure - parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json") + parser = ( + reqparse.RequestParser() + .add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json") + .add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json") + ) draft_var_srv = WorkflowDraftVariableService( session=db.session(), diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index 8f7f936c9b..bbea04640a 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -30,9 +30,11 @@ class WorkflowDailyRunsStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser = ( + reqparse.RequestParser() + .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + ) args = parser.parse_args() sql_query = """SELECT @@ -98,9 +100,11 @@ class WorkflowDailyTerminalsStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser = ( + reqparse.RequestParser() + .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + ) args = parser.parse_args() sql_query = """SELECT @@ -166,9 +170,11 @@ class WorkflowDailyTokenCostStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser = ( + reqparse.RequestParser() + .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + ) args = parser.parse_args() sql_query = """SELECT @@ -239,9 +245,11 @@ class WorkflowAverageAppInteractionStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser = ( + reqparse.RequestParser() + .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + ) args = parser.parse_args() sql_query = """SELECT diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index 06d2b936b7..2eeef079a1 100644 --- a/api/controllers/console/auth/activate.py +++ 
b/api/controllers/console/auth/activate.py @@ -10,15 +10,11 @@ from libs.helper import StrLen, email, extract_remote_ip, timezone from models import AccountStatus from services.account_service import AccountService, RegisterService -active_check_parser = reqparse.RequestParser() -active_check_parser.add_argument( - "workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID" -) -active_check_parser.add_argument( - "email", type=email, required=False, nullable=True, location="args", help="Email address" -) -active_check_parser.add_argument( - "token", type=str, required=True, nullable=False, location="args", help="Activation token" +active_check_parser = ( + reqparse.RequestParser() + .add_argument("workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID") + .add_argument("email", type=email, required=False, nullable=True, location="args", help="Email address") + .add_argument("token", type=str, required=True, nullable=False, location="args", help="Activation token") ) @@ -60,15 +56,15 @@ class ActivateCheckApi(Resource): return {"is_valid": False} -active_parser = reqparse.RequestParser() -active_parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json") -active_parser.add_argument("email", type=email, required=False, nullable=True, location="json") -active_parser.add_argument("token", type=str, required=True, nullable=False, location="json") -active_parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") -active_parser.add_argument( - "interface_language", type=supported_language, required=True, nullable=False, location="json" +active_parser = ( + reqparse.RequestParser() + .add_argument("workspace_id", type=str, required=False, nullable=True, location="json") + .add_argument("email", type=email, required=False, nullable=True, location="json") + .add_argument("token", type=str, required=True, nullable=False, location="json") + .add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") + .add_argument("interface_language", type=supported_language, required=True, nullable=False, location="json") + .add_argument("timezone", type=timezone, required=True, nullable=False, location="json") ) -active_parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json") @console_ns.route("/activate") diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index d9ab7de29b..a06435267b 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -45,10 +45,12 @@ class ApiKeyAuthDataSourceBinding(Resource): if not current_user.is_admin_or_owner: raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument("category", type=str, required=True, nullable=False, location="json") - parser.add_argument("provider", type=str, required=True, nullable=False, location="json") - parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("category", type=str, required=True, nullable=False, location="json") + .add_argument("provider", type=str, required=True, nullable=False, location="json") + .add_argument("credentials", type=dict, required=True, nullable=False, location="json") + ) args = parser.parse_args() ApiKeyAuthService.validate_api_key_auth_args(args) try: diff --git 
index cabd118d23..fe2bb54e0b 100644
--- a/api/controllers/console/auth/email_register.py
+++ b/api/controllers/console/auth/email_register.py
@@ -31,9 +31,11 @@ class EmailRegisterSendEmailApi(Resource):
     @email_password_login_enabled
     @email_register_enabled
     def post(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument("email", type=email, required=True, location="json")
-        parser.add_argument("language", type=str, required=False, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("email", type=email, required=True, location="json")
+            .add_argument("language", type=str, required=False, location="json")
+        )
         args = parser.parse_args()

         ip_address = extract_remote_ip(request)
@@ -59,10 +61,12 @@ class EmailRegisterCheckApi(Resource):
     @email_password_login_enabled
     @email_register_enabled
     def post(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument("email", type=str, required=True, location="json")
-        parser.add_argument("code", type=str, required=True, location="json")
-        parser.add_argument("token", type=str, required=True, nullable=False, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("email", type=str, required=True, location="json")
+            .add_argument("code", type=str, required=True, location="json")
+            .add_argument("token", type=str, required=True, nullable=False, location="json")
+        )
         args = parser.parse_args()

         user_email = args["email"]
@@ -100,10 +104,12 @@ class EmailRegisterResetApi(Resource):
     @email_password_login_enabled
     @email_register_enabled
     def post(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument("token", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
-        parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("token", type=str, required=True, nullable=False, location="json")
+            .add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
+            .add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
+        )
         args = parser.parse_args()

         # Validate passwords match
diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py
index 102d33966e..6be6ad51fe 100644
--- a/api/controllers/console/auth/forgot_password.py
+++ b/api/controllers/console/auth/forgot_password.py
@@ -54,9 +54,11 @@ class ForgotPasswordSendEmailApi(Resource):
     @setup_required
     @email_password_login_enabled
     def post(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument("email", type=email, required=True, location="json")
-        parser.add_argument("language", type=str, required=False, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("email", type=email, required=True, location="json")
+            .add_argument("language", type=str, required=False, location="json")
+        )
         args = parser.parse_args()

         ip_address = extract_remote_ip(request)
@@ -111,10 +113,12 @@ class ForgotPasswordCheckApi(Resource):
     @setup_required
     @email_password_login_enabled
     def post(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument("email", type=str, required=True, location="json")
-        parser.add_argument("code", type=str, required=True, location="json")
-        parser.add_argument("token", type=str, required=True, nullable=False, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("email", type=str, required=True, location="json")
+            .add_argument("code", type=str, required=True, location="json")
+            .add_argument("token", type=str, required=True, nullable=False, location="json")
+        )
         args = parser.parse_args()

         user_email = args["email"]
@@ -169,10 +173,12 @@ class ForgotPasswordResetApi(Resource):
     @setup_required
     @email_password_login_enabled
     def post(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument("token", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
-        parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("token", type=str, required=True, nullable=False, location="json")
+            .add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
+            .add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
+        )
         args = parser.parse_args()

         # Validate passwords match
diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py
index e4bbbf107b..3696c88346 100644
--- a/api/controllers/console/auth/login.py
+++ b/api/controllers/console/auth/login.py
@@ -40,11 +40,13 @@ class LoginApi(Resource):
     @email_password_login_enabled
     def post(self):
         """Authenticate user and login."""
-        parser = reqparse.RequestParser()
-        parser.add_argument("email", type=email, required=True, location="json")
-        parser.add_argument("password", type=str, required=True, location="json")
-        parser.add_argument("remember_me", type=bool, required=False, default=False, location="json")
-        parser.add_argument("invite_token", type=str, required=False, default=None, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("email", type=email, required=True, location="json")
+            .add_argument("password", type=str, required=True, location="json")
+            .add_argument("remember_me", type=bool, required=False, default=False, location="json")
+            .add_argument("invite_token", type=str, required=False, default=None, location="json")
+        )
         args = parser.parse_args()

         if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
@@ -108,9 +110,11 @@ class ResetPasswordSendEmailApi(Resource):
     @setup_required
     @email_password_login_enabled
     def post(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument("email", type=email, required=True, location="json")
-        parser.add_argument("language", type=str, required=False, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("email", type=email, required=True, location="json")
+            .add_argument("language", type=str, required=False, location="json")
+        )
         args = parser.parse_args()

         if args["language"] is not None and args["language"] == "zh-Hans":
@@ -136,9 +140,11 @@ class EmailCodeLoginSendEmailApi(Resource):
     @setup_required
     def post(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument("email", type=email, required=True, location="json")
-        parser.add_argument("language", type=str, required=False, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("email", type=email, required=True, location="json")
+            .add_argument("language", type=str, required=False, location="json")
+        )
         args = parser.parse_args()

         ip_address = extract_remote_ip(request)
@@ -169,10 +175,12 @@ class EmailCodeLoginApi(Resource):
     @setup_required
     def post(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument("email", type=str, required=True, location="json")
-        parser.add_argument("code", type=str, required=True, location="json")
-        parser.add_argument("token", type=str, required=True, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("email", type=str, required=True, location="json")
+            .add_argument("code", type=str, required=True, location="json")
+            .add_argument("token", type=str, required=True, location="json")
+        )
         args = parser.parse_args()

         user_email = args["email"]
@@ -225,8 +233,7 @@ class EmailCodeLoginApi(Resource):
 @console_ns.route("/refresh-token")
 class RefreshTokenApi(Resource):
     def post(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument("refresh_token", type=str, required=True, location="json")
+        parser = reqparse.RequestParser().add_argument("refresh_token", type=str, required=True, location="json")
         args = parser.parse_args()

         try:
diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py
index 188ef7f622..5e12aa7d03 100644
--- a/api/controllers/console/auth/oauth_server.py
+++ b/api/controllers/console/auth/oauth_server.py
@@ -23,8 +23,7 @@ T = TypeVar("T")
 def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
     @wraps(view)
     def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
-        parser = reqparse.RequestParser()
-        parser.add_argument("client_id", type=str, required=True, location="json")
+        parser = reqparse.RequestParser().add_argument("client_id", type=str, required=True, location="json")
         parsed_args = parser.parse_args()
         client_id = parsed_args.get("client_id")
         if not client_id:
@@ -90,8 +89,7 @@ class OAuthServerAppApi(Resource):
     @setup_required
     @oauth_server_client_id_required
     def post(self, oauth_provider_app: OAuthProviderApp):
-        parser = reqparse.RequestParser()
-        parser.add_argument("redirect_uri", type=str, required=True, location="json")
+        parser = reqparse.RequestParser().add_argument("redirect_uri", type=str, required=True, location="json")
         parsed_args = parser.parse_args()

         redirect_uri = parsed_args.get("redirect_uri")
@@ -132,12 +130,14 @@ class OAuthServerUserTokenApi(Resource):
     @setup_required
     @oauth_server_client_id_required
     def post(self, oauth_provider_app: OAuthProviderApp):
-        parser = reqparse.RequestParser()
-        parser.add_argument("grant_type", type=str, required=True, location="json")
-        parser.add_argument("code", type=str, required=False, location="json")
-        parser.add_argument("client_secret", type=str, required=False, location="json")
-        parser.add_argument("redirect_uri", type=str, required=False, location="json")
-        parser.add_argument("refresh_token", type=str, required=False, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("grant_type", type=str, required=True, location="json")
+            .add_argument("code", type=str, required=False, location="json")
+            .add_argument("client_secret", type=str, required=False, location="json")
+            .add_argument("redirect_uri", type=str, required=False, location="json")
+            .add_argument("refresh_token", type=str, required=False, location="json")
+        )
         parsed_args = parser.parse_args()

         try:
diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py
index 5c89b29057..705f5970dd 100644
--- a/api/controllers/console/billing/billing.py
+++ b/api/controllers/console/billing/billing.py
@@ -14,9 +14,11 @@ class Subscription(Resource):
     @only_edition_cloud
     def get(self):
         current_user, current_tenant_id = current_account_with_tenant()
-        parser = reqparse.RequestParser()
-        parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
-        parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
+            .add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
+        )
         args = parser.parse_args()
         BillingService.is_tenant_owner_or_admin(current_user)
         return BillingService.get_subscription(args["plan"], args["interval"], current_user.email, current_tenant_id)
diff --git a/api/controllers/console/billing/compliance.py b/api/controllers/console/billing/compliance.py
index 3b32fe29a1..2a6889968c 100644
--- a/api/controllers/console/billing/compliance.py
+++ b/api/controllers/console/billing/compliance.py
@@ -17,8 +17,7 @@ class ComplianceApi(Resource):
     @only_edition_cloud
     def get(self):
         current_user, current_tenant_id = current_account_with_tenant()
-        parser = reqparse.RequestParser()
-        parser.add_argument("doc_name", type=str, required=True, location="args")
+        parser = reqparse.RequestParser().add_argument("doc_name", type=str, required=True, location="args")
         args = parser.parse_args()

         ip_address = extract_remote_ip(request)
diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py
index 058ef4408a..ef66053075 100644
--- a/api/controllers/console/datasets/data_source.py
+++ b/api/controllers/console/datasets/data_source.py
@@ -246,12 +246,12 @@ class DataSourceNotionApi(Resource):
     def post(self):
         _, current_tenant_id = current_account_with_tenant()

-        parser = reqparse.RequestParser()
-        parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
-        parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
-        parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
-        parser.add_argument(
-            "doc_language", type=str, default="English", required=False, nullable=False, location="json"
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
+            .add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
+            .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
+            .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
         )
         args = parser.parse_args()
         # validate args
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index 4a9e0789fb..50bf48450c 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -206,48 +206,50 @@ class DatasetListApi(Resource):
     @account_initialization_required
     @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument(
-            "name",
-            nullable=False,
-            required=True,
-            help="type is required. Name must be between 1 to 40 characters.",
-            type=_validate_name,
-        )
-        parser.add_argument(
-            "description",
-            type=validate_description_length,
-            nullable=True,
-            required=False,
-            default="",
-        )
-        parser.add_argument(
-            "indexing_technique",
-            type=str,
-            location="json",
-            choices=Dataset.INDEXING_TECHNIQUE_LIST,
-            nullable=True,
-            help="Invalid indexing technique.",
-        )
-        parser.add_argument(
-            "external_knowledge_api_id",
-            type=str,
-            nullable=True,
-            required=False,
-        )
-        parser.add_argument(
-            "provider",
-            type=str,
-            nullable=True,
-            choices=Dataset.PROVIDER_LIST,
-            required=False,
-            default="vendor",
-        )
-        parser.add_argument(
-            "external_knowledge_id",
-            type=str,
-            nullable=True,
-            required=False,
+        parser = (
+            reqparse.RequestParser()
+            .add_argument(
+                "name",
+                nullable=False,
+                required=True,
+                help="Name is required. Name must be between 1 and 40 characters.",
+                type=_validate_name,
+            )
+            .add_argument(
+                "description",
+                type=validate_description_length,
+                nullable=True,
+                required=False,
+                default="",
+            )
+            .add_argument(
+                "indexing_technique",
+                type=str,
+                location="json",
+                choices=Dataset.INDEXING_TECHNIQUE_LIST,
+                nullable=True,
+                help="Invalid indexing technique.",
+            )
+            .add_argument(
+                "external_knowledge_api_id",
+                type=str,
+                nullable=True,
+                required=False,
+            )
+            .add_argument(
+                "provider",
+                type=str,
+                nullable=True,
+                choices=Dataset.PROVIDER_LIST,
+                required=False,
+                default="vendor",
+            )
+            .add_argument(
+                "external_knowledge_id",
+                type=str,
+                nullable=True,
+                required=False,
+            )
         )
         args = parser.parse_args()
         current_user, current_tenant_id = current_account_with_tenant()
@@ -352,70 +354,72 @@ class DatasetApi(Resource):
         if dataset is None:
             raise NotFound("Dataset not found.")

-        parser = reqparse.RequestParser()
-        parser.add_argument(
-            "name",
-            nullable=False,
-            help="type is required. Name must be between 1 to 40 characters.",
-            type=_validate_name,
-        )
-        parser.add_argument("description", location="json", store_missing=False, type=validate_description_length)
-        parser.add_argument(
-            "indexing_technique",
-            type=str,
-            location="json",
-            choices=Dataset.INDEXING_TECHNIQUE_LIST,
-            nullable=True,
-            help="Invalid indexing technique.",
-        )
-        parser.add_argument(
-            "permission",
-            type=str,
-            location="json",
-            choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
-            help="Invalid permission.",
-        )
-        parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
-        parser.add_argument(
-            "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
-        )
-        parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
-        parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
-
-        parser.add_argument(
-            "external_retrieval_model",
-            type=dict,
-            required=False,
-            nullable=True,
-            location="json",
-            help="Invalid external retrieval model.",
-        )
-
-        parser.add_argument(
-            "external_knowledge_id",
-            type=str,
-            required=False,
-            nullable=True,
-            location="json",
-            help="Invalid external knowledge id.",
-        )
-
-        parser.add_argument(
-            "external_knowledge_api_id",
-            type=str,
-            required=False,
-            nullable=True,
-            location="json",
-            help="Invalid external knowledge api id.",
-        )
-
-        parser.add_argument(
-            "icon_info",
-            type=dict,
-            required=False,
-            nullable=True,
-            location="json",
-            help="Invalid icon info.",
+        parser = (
+            reqparse.RequestParser()
+            .add_argument(
+                "name",
+                nullable=False,
+                help="Name is required. Name must be between 1 and 40 characters.",
+                type=_validate_name,
+            )
+            .add_argument("description", location="json", store_missing=False, type=validate_description_length)
+            .add_argument(
+                "indexing_technique",
+                type=str,
+                location="json",
+                choices=Dataset.INDEXING_TECHNIQUE_LIST,
+                nullable=True,
+                help="Invalid indexing technique.",
+            )
+            .add_argument(
+                "permission",
+                type=str,
+                location="json",
+                choices=(
+                    DatasetPermissionEnum.ONLY_ME,
+                    DatasetPermissionEnum.ALL_TEAM,
+                    DatasetPermissionEnum.PARTIAL_TEAM,
+                ),
+                help="Invalid permission.",
+            )
+            .add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
+            .add_argument(
+                "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
+            )
+            .add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
+            .add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
+            .add_argument(
+                "external_retrieval_model",
+                type=dict,
+                required=False,
+                nullable=True,
+                location="json",
+                help="Invalid external retrieval model.",
+            )
+            .add_argument(
+                "external_knowledge_id",
+                type=str,
+                required=False,
+                nullable=True,
+                location="json",
+                help="Invalid external knowledge id.",
+            )
+            .add_argument(
+                "external_knowledge_api_id",
+                type=str,
+                required=False,
+                nullable=True,
+                location="json",
+                help="Invalid external knowledge api id.",
+            )
+            .add_argument(
+                "icon_info",
+                type=dict,
+                required=False,
+                nullable=True,
+                location="json",
+                help="Invalid icon info.",
+            )
         )
         args = parser.parse_args()
         data = request.get_json()
@@ -542,21 +546,21 @@ class DatasetIndexingEstimateApi(Resource):
     @login_required
     @account_initialization_required
     def post(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument("info_list", type=dict, required=True, nullable=True, location="json")
-        parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
-        parser.add_argument(
-            "indexing_technique",
-            type=str,
-            required=True,
-            choices=Dataset.INDEXING_TECHNIQUE_LIST,
-            nullable=True,
-            location="json",
-        )
-        parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
-        parser.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
-        parser.add_argument(
-            "doc_language", type=str, default="English", required=False, nullable=False, location="json"
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("info_list", type=dict, required=True, nullable=True, location="json")
+            .add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
+            .add_argument(
+                "indexing_technique",
+                type=str,
+                required=True,
+                choices=Dataset.INDEXING_TECHNIQUE_LIST,
+                nullable=True,
+                location="json",
+            )
+            .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
+            .add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
+            .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
         )
         args = parser.parse_args()
         _, current_tenant_id = current_account_with_tenant()
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index 9c0c54833e..85fd0535c7 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -292,20 +292,20 @@ class DatasetDocumentListApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))

-        parser = reqparse.RequestParser()
-        parser.add_argument(
-            "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
-        )
-        parser.add_argument("data_source", type=dict, required=False, location="json")
-        parser.add_argument("process_rule", type=dict, required=False, location="json")
-        parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
-        parser.add_argument("original_document_id", type=str, required=False, location="json")
-        parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
-        parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
-        parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
-        parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
-        parser.add_argument(
-            "doc_language", type=str, default="English", required=False, nullable=False, location="json"
+        parser = (
+            reqparse.RequestParser()
+            .add_argument(
+                "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
+            )
+            .add_argument("data_source", type=dict, required=False, location="json")
+            .add_argument("process_rule", type=dict, required=False, location="json")
+            .add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
+            .add_argument("original_document_id", type=str, required=False, location="json")
+            .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
+            .add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
+            .add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
+            .add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
+            .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
         )
         args = parser.parse_args()
         knowledge_config = KnowledgeConfig.model_validate(args)
@@ -379,24 +379,24 @@ class DatasetInitApi(Resource):
         if not current_user.is_dataset_editor:
             raise Forbidden()

-        parser = reqparse.RequestParser()
-        parser.add_argument(
-            "indexing_technique",
-            type=str,
-            choices=Dataset.INDEXING_TECHNIQUE_LIST,
-            required=True,
-            nullable=False,
-            location="json",
+        parser = (
+            reqparse.RequestParser()
+            .add_argument(
+                "indexing_technique",
+                type=str,
+                choices=Dataset.INDEXING_TECHNIQUE_LIST,
+                required=True,
+                nullable=False,
+                location="json",
+            )
+            .add_argument("data_source", type=dict, required=True, nullable=True, location="json")
+            .add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
+            .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
+            .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
+            .add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
+            .add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
+            .add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
         )
-        parser.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
-        parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
-        parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
-        parser.add_argument(
-            "doc_language", type=str, default="English", required=False, nullable=False, location="json"
-        )
-        parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
-        parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
-        parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
         args = parser.parse_args()

         knowledge_config = KnowledgeConfig.model_validate(args)
@@ -1043,8 +1043,9 @@ class DocumentRetryApi(DocumentResource):
     def post(self, dataset_id):
         """retry document."""

-        parser = reqparse.RequestParser()
-        parser.add_argument("document_ids", type=list, required=True, nullable=False, location="json")
+        parser = reqparse.RequestParser().add_argument(
+            "document_ids", type=list, required=True, nullable=False, location="json"
+        )
         args = parser.parse_args()
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
@@ -1093,8 +1094,7 @@ class DocumentRenameApi(DocumentResource):
         if not dataset:
             raise NotFound("Dataset not found.")
         DatasetService.check_dataset_operator_permission(current_user, dataset)
-        parser = reqparse.RequestParser()
-        parser.add_argument("name", type=str, required=True, nullable=False, location="json")
+        parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
         args = parser.parse_args()

         try:
diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py
index d4d484a2e2..2fe7d42e46 100644
--- a/api/controllers/console/datasets/datasets_segments.py
+++ b/api/controllers/console/datasets/datasets_segments.py
@@ -60,13 +60,15 @@ class DatasetDocumentSegmentListApi(Resource):
         if not document:
             raise NotFound("Document not found.")

-        parser = reqparse.RequestParser()
-        parser.add_argument("limit", type=int, default=20, location="args")
-        parser.add_argument("status", type=str, action="append", default=[], location="args")
-        parser.add_argument("hit_count_gte", type=int, default=None, location="args")
-        parser.add_argument("enabled", type=str, default="all", location="args")
-        parser.add_argument("keyword", type=str, default=None, location="args")
-        parser.add_argument("page", type=int, default=1, location="args")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("limit", type=int, default=20, location="args")
+            .add_argument("status", type=str, action="append", default=[], location="args")
+            .add_argument("hit_count_gte", type=int, default=None, location="args")
+            .add_argument("enabled", type=str, default="all", location="args")
+            .add_argument("keyword", type=str, default=None, location="args")
+            .add_argument("page", type=int, default=1, location="args")
+        )

         args = parser.parse_args()

@@ -244,10 +246,12 @@ class DatasetDocumentSegmentAddApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
         # validate args
-        parser = reqparse.RequestParser()
-        parser.add_argument("content", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
-        parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("content", type=str, required=True, nullable=False, location="json")
+            .add_argument("answer", type=str, required=False, nullable=True, location="json")
+            .add_argument("keywords", type=list, required=False, nullable=True, location="json")
+        )
         args = parser.parse_args()
         SegmentService.segment_create_args_validate(args, document)
         segment = SegmentService.create_segment(args, document, dataset)
@@ -309,12 +313,14 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
         # validate args
-        parser = reqparse.RequestParser()
-        parser.add_argument("content", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
-        parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
-        parser.add_argument(
-            "regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("content", type=str, required=True, nullable=False, location="json")
+            .add_argument("answer", type=str, required=False, nullable=True, location="json")
+            .add_argument("keywords", type=list, required=False, nullable=True, location="json")
+            .add_argument(
+                "regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
+            )
         )
         args = parser.parse_args()
         SegmentService.segment_create_args_validate(args, document)
@@ -385,8 +391,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
         if not document:
             raise NotFound("Document not found.")

-        parser = reqparse.RequestParser()
-        parser.add_argument("upload_file_id", type=str, required=True, nullable=False, location="json")
+        parser = reqparse.RequestParser().add_argument(
+            "upload_file_id", type=str, required=True, nullable=False, location="json"
+        )
         args = parser.parse_args()

         upload_file_id = args["upload_file_id"]
@@ -484,8 +491,9 @@ class ChildChunkAddApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
         # validate args
-        parser = reqparse.RequestParser()
-        parser.add_argument("content", type=str, required=True, nullable=False, location="json")
+        parser = reqparse.RequestParser().add_argument(
+            "content", type=str, required=True, nullable=False, location="json"
+        )
         args = parser.parse_args()
         try:
             content = args["content"]
@@ -521,10 +529,12 @@ class ChildChunkAddApi(Resource):
         )
         if not segment:
             raise NotFound("Segment not found.")
-        parser = reqparse.RequestParser()
-        parser.add_argument("limit", type=int, default=20, location="args")
-        parser.add_argument("keyword", type=str, default=None, location="args")
-        parser.add_argument("page", type=int, default=1, location="args")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("limit", type=int, default=20, location="args")
+            .add_argument("keyword", type=str, default=None, location="args")
+            .add_argument("page", type=int, default=1, location="args")
+        )

         args = parser.parse_args()

@@ -578,8 +588,9 @@ class ChildChunkAddApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
         # validate args
-        parser = reqparse.RequestParser()
-        parser.add_argument("chunks", type=list, required=True, nullable=False, location="json")
+        parser = reqparse.RequestParser().add_argument(
+            "chunks", type=list, required=True, nullable=False, location="json"
+        )
         args = parser.parse_args()
         try:
             chunks_data = args["chunks"]
@@ -700,8 +711,9 @@ class ChildChunkUpdateApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
         # validate args
-        parser = reqparse.RequestParser()
-        parser.add_argument("content", type=str, required=True, nullable=False, location="json")
+        parser = reqparse.RequestParser().add_argument(
+            "content", type=str, required=True, nullable=False, location="json"
+        )
         args = parser.parse_args()
         try:
             content = args["content"]
diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py
index 1ebd7101e4..4f738db0e5 100644
--- a/api/controllers/console/datasets/external.py
+++ b/api/controllers/console/datasets/external.py
@@ -58,20 +58,22 @@ class ExternalApiTemplateListApi(Resource):
     @account_initialization_required
     def post(self):
         current_user, current_tenant_id = current_account_with_tenant()
-        parser = reqparse.RequestParser()
-        parser.add_argument(
-            "name",
-            nullable=False,
-            required=True,
-            help="Name is required. Name must be between 1 to 100 characters.",
-            type=_validate_name,
-        )
-        parser.add_argument(
-            "settings",
-            type=dict,
-            location="json",
-            nullable=False,
-            required=True,
+        parser = (
+            reqparse.RequestParser()
+            .add_argument(
+                "name",
+                nullable=False,
+                required=True,
+                help="Name is required. Name must be between 1 and 100 characters.",
+                type=_validate_name,
+            )
+            .add_argument(
+                "settings",
+                type=dict,
+                location="json",
+                nullable=False,
+                required=True,
+            )
         )

         args = parser.parse_args()
@@ -116,20 +118,22 @@ class ExternalApiTemplateApi(Resource):
         current_user, current_tenant_id = current_account_with_tenant()
         external_knowledge_api_id = str(external_knowledge_api_id)

-        parser = reqparse.RequestParser()
-        parser.add_argument(
-            "name",
-            nullable=False,
-            required=True,
-            help="type is required. Name must be between 1 to 100 characters.",
-            type=_validate_name,
-        )
-        parser.add_argument(
-            "settings",
-            type=dict,
-            location="json",
-            nullable=False,
-            required=True,
+        parser = (
+            reqparse.RequestParser()
+            .add_argument(
+                "name",
+                nullable=False,
+                required=True,
+                help="Name is required. Name must be between 1 and 100 characters.",
+                type=_validate_name,
+            )
+            .add_argument(
+                "settings",
+                type=dict,
+                location="json",
+                nullable=False,
+                required=True,
+            )
         )
         args = parser.parse_args()
         ExternalDatasetService.validate_api_list(args["settings"])
@@ -202,18 +206,20 @@ class ExternalDatasetCreateApi(Resource):
         if not current_user.has_edit_permission:
             raise Forbidden()

-        parser = reqparse.RequestParser()
-        parser.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
-        parser.add_argument(
-            "name",
-            nullable=False,
-            required=True,
-            help="name is required. Name must be between 1 to 100 characters.",
-            type=_validate_name,
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
+            .add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
+            .add_argument(
+                "name",
+                nullable=False,
+                required=True,
+                help="Name is required. Name must be between 1 and 100 characters.",
+                type=_validate_name,
+            )
+            .add_argument("description", type=str, required=False, nullable=True, location="json")
+            .add_argument("external_retrieval_model", type=dict, required=False, location="json")
         )
-        parser.add_argument("description", type=str, required=False, nullable=True, location="json")
-        parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")

         args = parser.parse_args()

@@ -266,10 +272,12 @@ class ExternalKnowledgeHitTestingApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))

-        parser = reqparse.RequestParser()
-        parser.add_argument("query", type=str, location="json")
-        parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
-        parser.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("query", type=str, location="json")
+            .add_argument("external_retrieval_model", type=dict, required=False, location="json")
+            .add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
+        )
         args = parser.parse_args()

         HitTestingService.hit_testing_args_check(args)
@@ -305,15 +313,17 @@ class BedrockRetrievalApi(Resource):
     )
     @api.response(200, "Bedrock retrieval test completed")
     def post(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
-        parser.add_argument(
-            "query",
-            nullable=False,
-            required=True,
-            type=str,
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
+            .add_argument(
+                "query",
+                nullable=False,
+                required=True,
+                type=str,
+            )
+            .add_argument("knowledge_id", nullable=False, required=True, type=str)
         )
-        parser.add_argument("knowledge_id", nullable=False, required=True, type=str)
         args = parser.parse_args()

         # Call the knowledge retrieval service
diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py
index 6113f1fd17..99d4d5a29c 100644
--- a/api/controllers/console/datasets/hit_testing_base.py
+++ b/api/controllers/console/datasets/hit_testing_base.py
@@ -48,11 +48,12 @@ class DatasetsHitTestingBase:

     @staticmethod
     def parse_args():
-        parser = reqparse.RequestParser()
-
-        parser.add_argument("query", type=str, location="json")
-        parser.add_argument("retrieval_model", type=dict, required=False, location="json")
-        parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("query", type=str, location="json")
+            .add_argument("retrieval_model", type=dict, required=False, location="json")
+            .add_argument("external_retrieval_model", type=dict, required=False, location="json")
+        )
         return parser.parse_args()

     @staticmethod
diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py
index 673bac1add..72b2ff0ff8 100644
--- a/api/controllers/console/datasets/metadata.py
+++ b/api/controllers/console/datasets/metadata.py
@@ -24,9 +24,11 @@ class DatasetMetadataCreateApi(Resource):
     @marshal_with(dataset_metadata_fields)
     def post(self, dataset_id):
         current_user, _ = current_account_with_tenant()
-        parser = reqparse.RequestParser()
-        parser.add_argument("type", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("name", type=str, required=True, nullable=False, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("type", type=str, required=True, nullable=False, location="json")
+            .add_argument("name", type=str, required=True, nullable=False, location="json")
+        )
        args = parser.parse_args()

         metadata_args = MetadataArgs.model_validate(args)
@@ -60,8 +62,7 @@ class DatasetMetadataApi(Resource):
     @marshal_with(dataset_metadata_fields)
     def patch(self, dataset_id, metadata_id):
         current_user, _ = current_account_with_tenant()
-        parser = reqparse.RequestParser()
-        parser.add_argument("name", type=str, required=True, nullable=False, location="json")
+        parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
         args = parser.parse_args()

         name = args["name"]
@@ -138,8 +139,9 @@ class DocumentMetadataEditApi(Resource):
             raise NotFound("Dataset not found.")
         DatasetService.check_dataset_permission(dataset, current_user)

-        parser = reqparse.RequestParser()
-        parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json")
+        parser = reqparse.RequestParser().add_argument(
+            "operation_data", type=list, required=True, nullable=False, location="json"
+        )
         args = parser.parse_args()

         metadata_args = MetadataOperationData.model_validate(args)
diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
index 194bd98fa3..2111ee2ecf 100644
--- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
+++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
@@ -130,11 +130,13 @@ class DatasourceAuth(Resource):
     def post(self, provider_id: str):
         _, current_tenant_id = current_account_with_tenant()

-        parser = reqparse.RequestParser()
-        parser.add_argument(
-            "name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None
+        parser = (
+            reqparse.RequestParser()
+            .add_argument(
+                "name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None
+            )
+            .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
         )
-        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
         args = parser.parse_args()
         datasource_provider_id = DatasourceProviderID(provider_id)
         datasource_provider_service = DatasourceProviderService()
@@ -179,8 +181,9 @@ class DatasourceAuthDeleteApi(Resource):
         plugin_id = datasource_provider_id.plugin_id
         provider_name = datasource_provider_id.provider_name

-        parser = reqparse.RequestParser()
-        parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
+        parser = reqparse.RequestParser().add_argument(
+            "credential_id", type=str, required=True, nullable=False, location="json"
+        )
         args = parser.parse_args()
         datasource_provider_service = DatasourceProviderService()
         datasource_provider_service.remove_datasource_credentials(
@@ -202,10 +205,12 @@ class DatasourceAuthUpdateApi(Resource):
         _, current_tenant_id = current_account_with_tenant()
         datasource_provider_id = DatasourceProviderID(provider_id)

-        parser = reqparse.RequestParser()
-        parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
-        parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
-        parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("credentials", type=dict, required=False, nullable=True, location="json")
+            .add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
+            .add_argument("credential_id", type=str, required=True, nullable=False, location="json")
+        )
         args = parser.parse_args()

         datasource_provider_service = DatasourceProviderService()
@@ -255,9 +260,11 @@ class DatasourceAuthOauthCustomClient(Resource):
     def post(self, provider_id: str):
         _, current_tenant_id = current_account_with_tenant()

-        parser = reqparse.RequestParser()
-        parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
-        parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("client_params", type=dict, required=False, nullable=True, location="json")
+            .add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
+        )
         args = parser.parse_args()
         datasource_provider_id = DatasourceProviderID(provider_id)
         datasource_provider_service = DatasourceProviderService()
@@ -293,8 +300,7 @@ class DatasourceAuthDefaultApi(Resource):
     def post(self, provider_id: str):
         _, current_tenant_id = current_account_with_tenant()

-        parser = reqparse.RequestParser()
-        parser.add_argument("id", type=str, required=True, nullable=False, location="json")
+        parser = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
         args = parser.parse_args()
         datasource_provider_id = DatasourceProviderID(provider_id)
         datasource_provider_service = DatasourceProviderService()
@@ -315,9 +321,11 @@ class DatasourceUpdateProviderNameApi(Resource):
     def post(self, provider_id: str):
         _, current_tenant_id = current_account_with_tenant()

-        parser = reqparse.RequestParser()
-        parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
-        parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
+            .add_argument("credential_id", type=str, required=True, nullable=False, location="json")
+        )
         args = parser.parse_args()
         datasource_provider_id = DatasourceProviderID(provider_id)
         datasource_provider_service = DatasourceProviderService()
diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py
index 6c04cc877a..856e4a1c70 100644
--- a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py
+++ b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py
@@ -26,10 +26,12 @@ class DataSourceContentPreviewApi(Resource):
         if not isinstance(current_user, Account):
             raise Forbidden()

-        parser = reqparse.RequestParser()
-        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
-        parser.add_argument("datasource_type", type=str, required=True, location="json")
-        parser.add_argument("credential_id", type=str, required=False, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
+            .add_argument("datasource_type", type=str, required=True, location="json")
+            .add_argument("credential_id", type=str, required=False, location="json")
+        )
         args = parser.parse_args()

         inputs = args.get("inputs")
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py
index e021f95283..f589bba3bf 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py
@@ -66,26 +66,28 @@ class CustomizedPipelineTemplateApi(Resource):
     @account_initialization_required
     @enterprise_license_required
     def patch(self, template_id: str):
-        parser = reqparse.RequestParser()
-        parser.add_argument(
-            "name",
-            nullable=False,
-            required=True,
-            help="Name must be between 1 to 40 characters.",
-            type=_validate_name,
-        )
-        parser.add_argument(
-            "description",
-            type=_validate_description_length,
-            nullable=True,
-            required=False,
-            default="",
-        )
-        parser.add_argument(
-            "icon_info",
-            type=dict,
-            location="json",
-            nullable=True,
+        parser = (
+            reqparse.RequestParser()
+            .add_argument(
+                "name",
+                nullable=False,
+                required=True,
+                help="Name must be between 1 and 40 characters.",
+                type=_validate_name,
+            )
+            .add_argument(
+                "description",
+                type=_validate_description_length,
+                nullable=True,
+                required=False,
+                default="",
+            )
+            .add_argument(
+                "icon_info",
+                type=dict,
+                location="json",
+                nullable=True,
+            )
         )
         args = parser.parse_args()
         pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args)
@@ -123,26 +125,28 @@ class PublishCustomizedPipelineTemplateApi(Resource):
     @enterprise_license_required
     @knowledge_pipeline_publish_enabled
     def post(self, pipeline_id: str):
-        parser = reqparse.RequestParser()
-        parser.add_argument(
-            "name",
-            nullable=False,
-            required=True,
-            help="Name must be between 1 to 40 characters.",
-            type=_validate_name,
-        )
-        parser.add_argument(
-            "description",
-            type=_validate_description_length,
-            nullable=True,
-            required=False,
-            default="",
-        )
-        parser.add_argument(
-            "icon_info",
-            type=dict,
-            location="json",
-            nullable=True,
+        parser = (
+            reqparse.RequestParser()
+            .add_argument(
+                "name",
+                nullable=False,
+                required=True,
+                help="Name must be between 1 and 40 characters.",
+                type=_validate_name,
+            )
+            .add_argument(
+                "description",
+                type=_validate_description_length,
+                nullable=True,
+                required=False,
+                default="",
+            )
+            .add_argument(
+                "icon_info",
+                type=dict,
+                location="json",
+                nullable=True,
+            )
         )
         args = parser.parse_args()
         rag_pipeline_service = RagPipelineService()
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py
index b394887783..98876e9f5e 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py
@@ -26,9 +26,7 @@ class CreateRagPipelineDatasetApi(Resource):
     @account_initialization_required
     @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self):
-        parser = reqparse.RequestParser()
-
-        parser.add_argument(
+        parser = reqparse.RequestParser().add_argument(
             "yaml_content",
             type=str,
             nullable=False,
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py
index 2e8cc16dc1..858ba94bf8 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py
@@ -33,16 +33,18 @@ logger = logging.getLogger(__name__)


 def _create_pagination_parser():
-    parser = reqparse.RequestParser()
-    parser.add_argument(
-        "page",
-        type=inputs.int_range(1, 100_000),
-        required=False,
-        default=1,
-        location="args",
-        help="the page of data requested",
+    parser = (
+        reqparse.RequestParser()
+        .add_argument(
+            "page",
+            type=inputs.int_range(1, 100_000),
+            required=False,
+            default=1,
+            location="args",
+            help="the page of data requested",
+        )
+        .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
     )
-    parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
     return parser

@@ -206,10 +208,11 @@ class RagPipelineVariableApi(Resource):
         # "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
         # }

-        parser = reqparse.RequestParser()
-        parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
-        # Parse 'value' field as-is to maintain its original data structure
-        parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
+            .add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
+        )

         draft_var_srv = WorkflowDraftVariableService(
             session=db.session(),
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py
index ca767dbb10..2c28120e65 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py
@@ -28,16 +28,18 @@ class RagPipelineImportApi(Resource):
         if not current_user.has_edit_permission:
             raise Forbidden()

-        parser = reqparse.RequestParser()
-        parser.add_argument("mode", type=str, required=True, location="json")
-        parser.add_argument("yaml_content", type=str, location="json")
-        parser.add_argument("yaml_url", type=str, location="json")
-        parser.add_argument("name", type=str, location="json")
-        parser.add_argument("description", type=str, location="json")
-        parser.add_argument("icon_type", type=str, location="json")
-        parser.add_argument("icon", type=str, location="json")
-        parser.add_argument("icon_background", type=str, location="json")
-        parser.add_argument("pipeline_id", type=str, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("mode", type=str, required=True, location="json")
+            .add_argument("yaml_content", type=str, location="json")
+            .add_argument("yaml_url", type=str, location="json")
+            .add_argument("name", type=str, location="json")
+            .add_argument("description", type=str, location="json")
+            .add_argument("icon_type", type=str, location="json")
+            .add_argument("icon", type=str, location="json")
+            .add_argument("icon_background", type=str, location="json")
+            .add_argument("pipeline_id", type=str, location="json")
+        )
         args = parser.parse_args()

         # Create service with session
@@ -121,8 +123,7 @@ class RagPipelineExportApi(Resource):
             raise Forbidden()

         # Add include_secret params
-        parser = reqparse.RequestParser()
-        parser.add_argument("include_secret", type=str, default="false", location="args")
+        parser = reqparse.RequestParser().add_argument("include_secret", type=str, default="false", location="args")
         args = parser.parse_args()

         with Session(db.engine) as session:
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index 77dcf30a78..5fe8572dfa 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -88,12 +88,14 @@ class DraftRagPipelineApi(Resource):
         content_type = request.headers.get("Content-Type", "")

         if "application/json" in content_type:
-            parser = reqparse.RequestParser()
-            parser.add_argument("graph", type=dict, required=True, nullable=False, location="json")
-            parser.add_argument("hash", type=str, required=False, location="json")
-            parser.add_argument("environment_variables", type=list, required=False, location="json")
-            parser.add_argument("conversation_variables", type=list, required=False, location="json")
-            parser.add_argument("rag_pipeline_variables", type=list, required=False, location="json")
+            parser = (
+                reqparse.RequestParser()
+                .add_argument("graph", type=dict, required=True, nullable=False, location="json")
+                .add_argument("hash", type=str, required=False, location="json")
+                .add_argument("environment_variables", type=list, required=False, location="json")
+                .add_argument("conversation_variables", type=list, required=False, location="json")
+                .add_argument("rag_pipeline_variables", type=list, required=False, location="json")
+            )
             args = parser.parse_args()
         elif "text/plain" in content_type:
             try:
@@ -160,8 +162,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         current_user, _ = current_account_with_tenant()

-        parser = reqparse.RequestParser()
-        parser.add_argument("inputs", type=dict, location="json")
+        parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
         args = parser.parse_args()

         try:
@@ -196,8 +197,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
         if not current_user.has_edit_permission:
             raise Forbidden()

-        parser = reqparse.RequestParser()
-        parser.add_argument("inputs", type=dict, location="json")
+        parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
         args = parser.parse_args()

         try:
@@ -232,11 +232,13 @@ class DraftRagPipelineRunApi(Resource):
         if not current_user.has_edit_permission:
             raise Forbidden()

-        parser = reqparse.RequestParser()
-        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
-        parser.add_argument("datasource_type", type=str, required=True, location="json")
-        parser.add_argument("datasource_info_list", type=list, required=True, location="json")
-        parser.add_argument("start_node_id", type=str, required=True, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
+            .add_argument("datasource_type", type=str, required=True, location="json")
+            .add_argument("datasource_info_list", type=list, required=True, location="json")
+            .add_argument("start_node_id", type=str, required=True, location="json")
+        )
         args = parser.parse_args()

         try:
@@ -268,14 +270,16 @@ class PublishedRagPipelineRunApi(Resource):
         if not current_user.has_edit_permission:
             raise Forbidden()

-        parser = reqparse.RequestParser()
-        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
-        parser.add_argument("datasource_type", type=str, required=True, location="json")
-        parser.add_argument("datasource_info_list", type=list, required=True, location="json")
-        parser.add_argument("start_node_id", type=str, required=True, location="json")
-        parser.add_argument("is_preview", type=bool, required=True, location="json", default=False)
-        parser.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
-        parser.add_argument("original_document_id", type=str, required=False, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
+            .add_argument("datasource_type", type=str, required=True, location="json")
+            .add_argument("datasource_info_list", type=list, required=True, location="json")
+            .add_argument("start_node_id", type=str, required=True, location="json")
+            .add_argument("is_preview", type=bool, required=True, location="json", default=False)
+            .add_argument("response_mode", type=str, required=True, location="json", default="streaming")
+            .add_argument("original_document_id", type=str, required=False, location="json")
+        )
         args = parser.parse_args()

         streaming = args["response_mode"] == "streaming"
@@ -310,9 +314,10 @@ class PublishedRagPipelineRunApi(Resource):
 #         if not isinstance(current_user, Account):
 #             raise Forbidden()
 #
-#         parser = reqparse.RequestParser()
-#         parser.add_argument("job_id", type=str, required=True, nullable=False, location="json")
-#         parser.add_argument("datasource_type", type=str, required=True, location="json")
+#         parser = (reqparse.RequestParser()
+#             .add_argument("job_id", type=str, required=True, nullable=False, location="json")
+#             .add_argument("datasource_type", type=str, required=True, location="json")
+#         )
 #         args = parser.parse_args()
 #
 #         job_id = args.get("job_id")
@@ -351,9 +356,10 @@ class PublishedRagPipelineRunApi(Resource):
 #         if not isinstance(current_user, Account):
 #             raise Forbidden()
 #
-#         parser = reqparse.RequestParser()
-#         parser.add_argument("job_id", type=str, required=True, nullable=False, location="json")
-#         parser.add_argument("datasource_type", type=str, required=True, location="json")
+#         parser = (reqparse.RequestParser()
+#             .add_argument("job_id", type=str, required=True, nullable=False, location="json")
+#             .add_argument("datasource_type", type=str, required=True, location="json")
+#         )
 #         args = parser.parse_args()
 #
 #         job_id = args.get("job_id")
@@ -390,10 +396,12 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
         if not current_user.has_edit_permission:
             raise Forbidden()

-        parser = reqparse.RequestParser()
-        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
-        parser.add_argument("datasource_type", type=str, required=True, location="json")
-        parser.add_argument("credential_id", type=str, required=False, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
+            .add_argument("datasource_type", type=str, required=True, location="json")
+            .add_argument("credential_id", type=str, required=False, location="json")
+        )
         args = parser.parse_args()

         inputs = args.get("inputs")
@@ -434,10 +442,12 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
         if not current_user.has_edit_permission:
             raise Forbidden()

-        parser = reqparse.RequestParser()
-        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
-        parser.add_argument("datasource_type", type=str, required=True, location="json")
-        parser.add_argument("credential_id", type=str, required=False, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
+            .add_argument("datasource_type", type=str, required=True, location="json")
+            .add_argument("credential_id", type=str, required=False, location="json")
+        )
         args = parser.parse_args()

         inputs = args.get("inputs")
@@ -479,8 +489,9 @@ class RagPipelineDraftNodeRunApi(Resource):
         if not current_user.has_edit_permission:
             raise Forbidden()

-        parser = reqparse.RequestParser()
-        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
+        parser = reqparse.RequestParser().add_argument(
+            "inputs", type=dict, required=True, nullable=False, location="json"
+        )
         args = parser.parse_args()

         inputs = args.get("inputs")
@@ -611,8 +622,7 @@ class DefaultRagPipelineBlockConfigApi(Resource):
         if not current_user.has_edit_permission:
             raise Forbidden()

-        parser = reqparse.RequestParser()
-        parser.add_argument("q", type=str, location="args")
+        parser = reqparse.RequestParser().add_argument("q", type=str, location="args")
         args = parser.parse_args()

         q = args.get("q")
@@ -644,11 +654,13 @@ class PublishedAllRagPipelineApi(Resource):
         if not current_user.has_edit_permission:
             raise Forbidden()

-        parser = reqparse.RequestParser()
-        parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
-        parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
-        parser.add_argument("user_id", type=str, required=False, location="args")
-        parser.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
+            .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
+            .add_argument("user_id", type=str, required=False, location="args")
+            .add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
+        )
         args = parser.parse_args()
         page = int(args.get("page", 1))
         limit = int(args.get("limit", 10))
@@ -695,9 +707,11 @@ class RagPipelineByIdApi(Resource):
         if not current_user.has_edit_permission:
             raise Forbidden()

-        parser = reqparse.RequestParser()
-        parser.add_argument("marked_name", type=str, required=False, location="json")
-        parser.add_argument("marked_comment", type=str, required=False, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("marked_name", type=str, required=False, location="json")
+            .add_argument("marked_comment", type=str, required=False, location="json")
+        )
         args = parser.parse_args()

         # Validate name and comment length
@@ -749,8 +763,7 @@ class PublishedRagPipelineSecondStepApi(Resource):
         """
         Get second step parameters of rag pipeline
         """
-        parser = reqparse.RequestParser()
-        parser.add_argument("node_id", type=str, required=True, location="args")
+        parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
         args = parser.parse_args()
         node_id = args.get("node_id")
         if not node_id:
@@ -773,8 +786,7 @@ class PublishedRagPipelineFirstStepApi(Resource):
         """
        Get first step parameters of rag pipeline
         """
-        parser = reqparse.RequestParser()
-        parser.add_argument("node_id", type=str, required=True, location="args")
+        parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
         args = parser.parse_args()
         node_id = args.get("node_id")
         if not node_id:
@@ -797,8 +809,7 @@ class DraftRagPipelineFirstStepApi(Resource):
         """
         Get first step parameters of rag pipeline
         """
-        parser = reqparse.RequestParser()
-        parser.add_argument("node_id", type=str, required=True, location="args")
+        parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
         args = parser.parse_args()
         node_id = args.get("node_id")
         if not node_id:
@@ -821,8 +832,7 @@ class DraftRagPipelineSecondStepApi(Resource):
         """
         Get second step parameters of rag pipeline
         """
-        parser = reqparse.RequestParser()
-        parser.add_argument("node_id", type=str, required=True, location="args")
+        parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
         args = parser.parse_args()
         node_id = args.get("node_id")
         if not node_id:
@@ -846,9 +856,11 @@ class RagPipelineWorkflowRunListApi(Resource):
         """
         Get workflow run list
         """
-        parser = reqparse.RequestParser()
-        parser.add_argument("last_id", type=uuid_value, location="args")
-        parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("last_id", type=uuid_value, location="args")
+            .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
+        )
         args = parser.parse_args()

         rag_pipeline_service = RagPipelineService()
@@ -962,11 +974,13 @@ class RagPipelineDatasourceVariableApi(Resource):
         Set datasource variables
         """
         current_user, _ = current_account_with_tenant()
-        parser = reqparse.RequestParser()
-        parser.add_argument("datasource_type", type=str, required=True, location="json")
-        parser.add_argument("datasource_info", type=dict, required=True, location="json")
-        parser.add_argument("start_node_id", type=str, required=True, location="json")
-        parser.add_argument("start_node_title", type=str, required=True, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("datasource_type", type=str, required=True, location="json")
+            .add_argument("datasource_info", type=dict, required=True, location="json")
+            .add_argument("start_node_id", type=str, required=True, location="json")
+            .add_argument("start_node_title", type=str, required=True, location="json")
+        )
         args = parser.parse_args()

         rag_pipeline_service = RagPipelineService()
diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py
index b9c1f65bfd..fe6eaaa0de 100644
--- 
a/api/controllers/console/datasets/website.py +++ b/api/controllers/console/datasets/website.py @@ -31,17 +31,19 @@ class WebsiteCrawlApi(Resource): @login_required @account_initialization_required def post(self): - parser = reqparse.RequestParser() - parser.add_argument( - "provider", - type=str, - choices=["firecrawl", "watercrawl", "jinareader"], - required=True, - nullable=True, - location="json", + parser = ( + reqparse.RequestParser() + .add_argument( + "provider", + type=str, + choices=["firecrawl", "watercrawl", "jinareader"], + required=True, + nullable=True, + location="json", + ) + .add_argument("url", type=str, required=True, nullable=True, location="json") + .add_argument("options", type=dict, required=True, nullable=True, location="json") ) - parser.add_argument("url", type=str, required=True, nullable=True, location="json") - parser.add_argument("options", type=dict, required=True, nullable=True, location="json") args = parser.parse_args() # Create typed request and validate @@ -70,8 +72,7 @@ class WebsiteCrawlStatusApi(Resource): @login_required @account_initialization_required def get(self, job_id: str): - parser = reqparse.RequestParser() - parser.add_argument( + parser = reqparse.RequestParser().add_argument( "provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args" ) args = parser.parse_args() diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index 7c20fb49d8..2a248cf20d 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -81,11 +81,13 @@ class ChatTextApi(InstalledAppResource): app_model = installed_app.app try: - parser = reqparse.RequestParser() - parser.add_argument("message_id", type=str, required=False, location="json") - parser.add_argument("voice", type=str, location="json") - parser.add_argument("text", type=str, location="json") - parser.add_argument("streaming", type=bool, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("message_id", type=str, required=False, location="json") + .add_argument("voice", type=str, location="json") + .add_argument("text", type=str, location="json") + .add_argument("streaming", type=bool, location="json") + ) args = parser.parse_args() message_id = args.get("message_id", None) diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 1102b815eb..9386ecebae 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -49,12 +49,14 @@ class CompletionApi(InstalledAppResource): if app_model.mode != "completion": raise NotCompletionAppError() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, location="json") - parser.add_argument("query", type=str, location="json", default="") - parser.add_argument("files", type=list, required=False, location="json") - parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") + parser = ( + reqparse.RequestParser() + .add_argument("inputs", type=dict, required=True, location="json") + .add_argument("query", type=str, location="json", default="") + .add_argument("files", type=list, required=False, location="json") + .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + .add_argument("retriever_from", 
type=str, required=False, default="explore_app", location="json") + ) args = parser.parse_args() streaming = args["response_mode"] == "streaming" @@ -121,13 +123,15 @@ class ChatApi(InstalledAppResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, location="json") - parser.add_argument("query", type=str, required=True, location="json") - parser.add_argument("files", type=list, required=False, location="json") - parser.add_argument("conversation_id", type=uuid_value, location="json") - parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json") - parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") + parser = ( + reqparse.RequestParser() + .add_argument("inputs", type=dict, required=True, location="json") + .add_argument("query", type=str, required=True, location="json") + .add_argument("files", type=list, required=False, location="json") + .add_argument("conversation_id", type=uuid_value, location="json") + .add_argument("parent_message_id", type=uuid_value, required=False, location="json") + .add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") + ) args = parser.parse_args() args["auto_generate_name"] = False diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index feabea2524..5a39363cc2 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -31,10 +31,12 @@ class ConversationListApi(InstalledAppResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = reqparse.RequestParser() - parser.add_argument("last_id", type=uuid_value, location="args") - parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args") + parser = ( + reqparse.RequestParser() + .add_argument("last_id", type=uuid_value, location="args") + .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + .add_argument("pinned", type=str, choices=["true", "false", None], location="args") + ) args = parser.parse_args() pinned = None @@ -94,9 +96,11 @@ class ConversationRenameApi(InstalledAppResource): conversation_id = str(c_id) - parser = reqparse.RequestParser() - parser.add_argument("name", type=str, required=False, location="json") - parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("name", type=str, required=False, location="json") + .add_argument("auto_generate", type=bool, required=False, default=False, location="json") + ) args = parser.parse_args() try: diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 7ead93a1b6..dec84b68f4 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -111,8 +111,7 @@ class InstalledAppsListApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("apps") def post(self): - parser = reqparse.RequestParser() - parser.add_argument("app_id", type=str, required=True, help="Invalid app_id") + parser = 
reqparse.RequestParser().add_argument("app_id", type=str, required=True, help="Invalid app_id") args = parser.parse_args() recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first() @@ -170,8 +169,7 @@ class InstalledAppApi(InstalledAppResource): return {"result": "success", "message": "App uninstalled successfully"}, 204 def patch(self, installed_app): - parser = reqparse.RequestParser() - parser.add_argument("is_pinned", type=inputs.boolean) + parser = reqparse.RequestParser().add_argument("is_pinned", type=inputs.boolean) args = parser.parse_args() commit_args = False diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 064e026753..db854e09bb 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -54,10 +54,12 @@ class MessageListApi(InstalledAppResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = reqparse.RequestParser() - parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") - parser.add_argument("first_id", type=uuid_value, location="args") - parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + parser = ( + reqparse.RequestParser() + .add_argument("conversation_id", required=True, type=uuid_value, location="args") + .add_argument("first_id", type=uuid_value, location="args") + .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + ) args = parser.parse_args() try: @@ -81,9 +83,11 @@ class MessageFeedbackApi(InstalledAppResource): message_id = str(message_id) - parser = reqparse.RequestParser() - parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") - parser.add_argument("content", type=str, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("rating", type=str, choices=["like", "dislike", None], location="json") + .add_argument("content", type=str, location="json") + ) args = parser.parse_args() try: @@ -113,8 +117,7 @@ class MessageMoreLikeThisApi(InstalledAppResource): message_id = str(message_id) - parser = reqparse.RequestParser() - parser.add_argument( + parser = reqparse.RequestParser().add_argument( "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args" ) args = parser.parse_args() diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 6d627a929a..751012757a 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -42,8 +42,7 @@ class RecommendedAppListApi(Resource): @marshal_with(recommended_app_list_fields) def get(self): # language args - parser = reqparse.RequestParser() - parser.add_argument("language", type=str, location="args") + parser = reqparse.RequestParser().add_argument("language", type=str, location="args") args = parser.parse_args() language = args.get("language") diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 830685975b..9775c951f7 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -39,9 +39,11 @@ class SavedMessageListApi(InstalledAppResource): if app_model.mode != "completion": raise NotCompletionAppError() - parser = reqparse.RequestParser() - 
parser.add_argument("last_id", type=uuid_value, location="args") - parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + parser = ( + reqparse.RequestParser() + .add_argument("last_id", type=uuid_value, location="args") + .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + ) args = parser.parse_args() return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"]) @@ -52,8 +54,7 @@ class SavedMessageListApi(InstalledAppResource): if app_model.mode != "completion": raise NotCompletionAppError() - parser = reqparse.RequestParser() - parser.add_argument("message_id", type=uuid_value, required=True, location="json") + parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json") args = parser.parse_args() try: diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index aeea446c6e..3022d937b9 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -47,9 +47,11 @@ class InstalledAppWorkflowRunApi(InstalledAppResource): if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") - parser.add_argument("files", type=list, required=False, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("inputs", type=dict, required=True, nullable=False, location="json") + .add_argument("files", type=list, required=False, location="json") + ) args = parser.parse_args() assert current_user is not None try: diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index e5b7611c44..4e1a8aeb3e 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -29,8 +29,7 @@ class CodeBasedExtensionAPI(Resource): @login_required @account_initialization_required def get(self): - parser = reqparse.RequestParser() - parser.add_argument("module", type=str, required=True, location="args") + parser = reqparse.RequestParser().add_argument("module", type=str, required=True, location="args") args = parser.parse_args() return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])} @@ -67,10 +66,12 @@ class APIBasedExtensionAPI(Resource): @account_initialization_required @marshal_with(api_based_extension_fields) def post(self): - parser = reqparse.RequestParser() - parser.add_argument("name", type=str, required=True, location="json") - parser.add_argument("api_endpoint", type=str, required=True, location="json") - parser.add_argument("api_key", type=str, required=True, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("name", type=str, required=True, location="json") + .add_argument("api_endpoint", type=str, required=True, location="json") + .add_argument("api_key", type=str, required=True, location="json") + ) args = parser.parse_args() _, current_tenant_id = current_account_with_tenant() @@ -124,10 +125,12 @@ class APIBasedExtensionDetailAPI(Resource): extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id) - parser = reqparse.RequestParser() - parser.add_argument("name", type=str, required=True, location="json") - parser.add_argument("api_endpoint", type=str, required=True, location="json") - 
parser.add_argument("api_key", type=str, required=True, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("name", type=str, required=True, location="json") + .add_argument("api_endpoint", type=str, required=True, location="json") + .add_argument("api_key", type=str, required=True, location="json") + ) args = parser.parse_args() extension_data_from_db.name = args["name"] diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index 30b53458b2..f219425d07 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -57,8 +57,7 @@ class InitValidateAPI(Resource): if tenant_count > 0: raise AlreadySetupError() - parser = reqparse.RequestParser() - parser.add_argument("password", type=StrLen(30), required=True, location="json") + parser = reqparse.RequestParser().add_argument("password", type=StrLen(30), required=True, location="json") input_password = parser.parse_args()["password"] if input_password != os.environ.get("INIT_PASSWORD"): diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index b053f222df..96c86dc0db 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -40,8 +40,7 @@ class RemoteFileInfoApi(Resource): class RemoteFileUploadApi(Resource): @marshal_with(file_fields_with_signed_url) def post(self): - parser = reqparse.RequestParser() - parser.add_argument("url", type=str, required=True, help="URL is required") + parser = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required") args = parser.parse_args() url = args["url"] diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index bff5fc1651..6d2b22bde3 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -69,10 +69,12 @@ class SetupApi(Resource): if not get_init_validate_status(): raise NotInitValidateError() - parser = reqparse.RequestParser() - parser.add_argument("email", type=email, required=True, location="json") - parser.add_argument("name", type=StrLen(30), required=True, location="json") - parser.add_argument("password", type=valid_password, required=True, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("email", type=email, required=True, location="json") + .add_argument("name", type=StrLen(30), required=True, location="json") + .add_argument("password", type=valid_password, required=True, location="json") + ) args = parser.parse_args() # setup diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 5748ca110d..40ae7fb4d0 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -39,12 +39,18 @@ class TagListApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument( - "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name - ) - parser.add_argument( - "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." + parser = ( + reqparse.RequestParser() + .add_argument( + "name", + nullable=False, + required=True, + help="Name must be between 1 to 50 characters.", + type=_validate_name, + ) + .add_argument( + "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." 
+ ) ) args = parser.parse_args() tag = TagService.save_tags(args) @@ -66,8 +72,7 @@ class TagUpdateDeleteApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument( + parser = reqparse.RequestParser().add_argument( "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name ) args = parser.parse_args() @@ -105,15 +110,17 @@ class TagBindingCreateApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument( - "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required." - ) - parser.add_argument( - "target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required." - ) - parser.add_argument( - "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." + parser = ( + reqparse.RequestParser() + .add_argument( + "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required." + ) + .add_argument( + "target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required." + ) + .add_argument( + "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." + ) ) args = parser.parse_args() TagService.save_tag_binding(args) @@ -132,11 +139,13 @@ class TagBindingDeleteApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") - parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") - parser.add_argument( - "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." + parser = ( + reqparse.RequestParser() + .add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") + .add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") + .add_argument( + "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." 
+ ) ) args = parser.parse_args() TagService.delete_tag_binding(args) diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index 965a520f70..417486f59e 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -37,8 +37,7 @@ class VersionApi(Resource): ) def get(self): """Check for application version updates""" - parser = reqparse.RequestParser() - parser.add_argument("current_version", type=str, required=True, location="args") + parser = reqparse.RequestParser().add_argument("current_version", type=str, required=True, location="args") args = parser.parse_args() check_update_url = dify_config.CHECK_UPDATE_URL diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index a5e6b8f473..499a52370f 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -57,9 +57,9 @@ class AccountInitApi(Resource): if dify_config.EDITION == "CLOUD": parser.add_argument("invitation_code", type=str, location="json") - - parser.add_argument("interface_language", type=supported_language, required=True, location="json") - parser.add_argument("timezone", type=timezone, required=True, location="json") + parser.add_argument("interface_language", type=supported_language, required=True, location="json").add_argument( + "timezone", type=timezone, required=True, location="json" + ) args = parser.parse_args() if dify_config.EDITION == "CLOUD": @@ -114,8 +114,7 @@ class AccountNameApi(Resource): @marshal_with(account_fields) def post(self): current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("name", type=str, required=True, location="json") + parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json") args = parser.parse_args() # Validate account name length @@ -135,8 +134,7 @@ class AccountAvatarApi(Resource): @marshal_with(account_fields) def post(self): current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("avatar", type=str, required=True, location="json") + parser = reqparse.RequestParser().add_argument("avatar", type=str, required=True, location="json") args = parser.parse_args() updated_account = AccountService.update_account(current_user, avatar=args["avatar"]) @@ -152,8 +150,9 @@ class AccountInterfaceLanguageApi(Resource): @marshal_with(account_fields) def post(self): current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("interface_language", type=supported_language, required=True, location="json") + parser = reqparse.RequestParser().add_argument( + "interface_language", type=supported_language, required=True, location="json" + ) args = parser.parse_args() updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"]) @@ -169,8 +168,9 @@ class AccountInterfaceThemeApi(Resource): @marshal_with(account_fields) def post(self): current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json") + parser = reqparse.RequestParser().add_argument( + "interface_theme", type=str, choices=["light", "dark"], required=True, location="json" + ) args = parser.parse_args() updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"]) @@ -186,8 +186,7 @@ class 
AccountTimezoneApi(Resource): @marshal_with(account_fields) def post(self): current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("timezone", type=str, required=True, location="json") + parser = reqparse.RequestParser().add_argument("timezone", type=str, required=True, location="json") args = parser.parse_args() # Validate timezone string, e.g. America/New_York, Asia/Shanghai @@ -207,10 +206,12 @@ class AccountPasswordApi(Resource): @marshal_with(account_fields) def post(self): current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("password", type=str, required=False, location="json") - parser.add_argument("new_password", type=str, required=True, location="json") - parser.add_argument("repeat_new_password", type=str, required=True, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("password", type=str, required=False, location="json") + .add_argument("new_password", type=str, required=True, location="json") + .add_argument("repeat_new_password", type=str, required=True, location="json") + ) args = parser.parse_args() if args["new_password"] != args["repeat_new_password"]: @@ -301,9 +302,11 @@ class AccountDeleteApi(Resource): def post(self): account, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("token", type=str, required=True, location="json") - parser.add_argument("code", type=str, required=True, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("token", type=str, required=True, location="json") + .add_argument("code", type=str, required=True, location="json") + ) args = parser.parse_args() if not AccountService.verify_account_deletion_code(args["token"], args["code"]): @@ -318,9 +321,11 @@ class AccountDeleteApi(Resource): class AccountDeleteUpdateFeedbackApi(Resource): @setup_required def post(self): - parser = reqparse.RequestParser() - parser.add_argument("email", type=str, required=True, location="json") - parser.add_argument("feedback", type=str, required=True, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("email", type=str, required=True, location="json") + .add_argument("feedback", type=str, required=True, location="json") + ) args = parser.parse_args() BillingService.update_account_deletion_feedback(args["email"], args["feedback"]) @@ -363,10 +368,12 @@ class EducationApi(Resource): def post(self): account, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("token", type=str, required=True, location="json") - parser.add_argument("institution", type=str, required=True, location="json") - parser.add_argument("role", type=str, required=True, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("token", type=str, required=True, location="json") + .add_argument("institution", type=str, required=True, location="json") + .add_argument("role", type=str, required=True, location="json") + ) args = parser.parse_args() return BillingService.EducationIdentity.activate(account, args["token"], args["institution"], args["role"]) @@ -402,10 +409,12 @@ class EducationAutoCompleteApi(Resource): @cloud_edition_billing_enabled @marshal_with(data_fields) def get(self): - parser = reqparse.RequestParser() - parser.add_argument("keywords", type=str, required=True, location="args") - parser.add_argument("page", type=int, required=False, location="args", default=0) - parser.add_argument("limit", type=int, required=False, 
location="args", default=20) + parser = ( + reqparse.RequestParser() + .add_argument("keywords", type=str, required=True, location="args") + .add_argument("page", type=int, required=False, location="args", default=0) + .add_argument("limit", type=int, required=False, location="args", default=20) + ) args = parser.parse_args() return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"]) @@ -419,11 +428,13 @@ class ChangeEmailSendEmailApi(Resource): @account_initialization_required def post(self): current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("email", type=email, required=True, location="json") - parser.add_argument("language", type=str, required=False, location="json") - parser.add_argument("phase", type=str, required=False, location="json") - parser.add_argument("token", type=str, required=False, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("email", type=email, required=True, location="json") + .add_argument("language", type=str, required=False, location="json") + .add_argument("phase", type=str, required=False, location="json") + .add_argument("token", type=str, required=False, location="json") + ) args = parser.parse_args() ip_address = extract_remote_ip(request) @@ -466,10 +477,12 @@ class ChangeEmailCheckApi(Resource): @login_required @account_initialization_required def post(self): - parser = reqparse.RequestParser() - parser.add_argument("email", type=email, required=True, location="json") - parser.add_argument("code", type=str, required=True, location="json") - parser.add_argument("token", type=str, required=True, nullable=False, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("email", type=email, required=True, location="json") + .add_argument("code", type=str, required=True, location="json") + .add_argument("token", type=str, required=True, nullable=False, location="json") + ) args = parser.parse_args() user_email = args["email"] @@ -509,9 +522,11 @@ class ChangeEmailResetApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): - parser = reqparse.RequestParser() - parser.add_argument("new_email", type=email, required=True, location="json") - parser.add_argument("token", type=str, required=True, nullable=False, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("new_email", type=email, required=True, location="json") + .add_argument("token", type=str, required=True, nullable=False, location="json") + ) args = parser.parse_args() if AccountService.is_account_in_freeze(args["new_email"]): @@ -544,8 +559,7 @@ class ChangeEmailResetApi(Resource): class CheckEmailUnique(Resource): @setup_required def post(self): - parser = reqparse.RequestParser() - parser.add_argument("email", type=email, required=True, location="json") + parser = reqparse.RequestParser().add_argument("email", type=email, required=True, location="json") args = parser.parse_args() if AccountService.is_account_in_freeze(args["email"]): raise AccountInFreezeError() diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index b31011b4a3..d115f62d73 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -37,10 +37,12 @@ class EndpointCreateApi(Resource): if not user.is_admin_or_owner: raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument("plugin_unique_identifier", type=str, 
required=True) - parser.add_argument("settings", type=dict, required=True) - parser.add_argument("name", type=str, required=True) + parser = ( + reqparse.RequestParser() + .add_argument("plugin_unique_identifier", type=str, required=True) + .add_argument("settings", type=dict, required=True) + .add_argument("name", type=str, required=True) + ) args = parser.parse_args() plugin_unique_identifier = args["plugin_unique_identifier"] @@ -81,9 +83,11 @@ class EndpointListApi(Resource): def get(self): user, tenant_id = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("page", type=int, required=True, location="args") - parser.add_argument("page_size", type=int, required=True, location="args") + parser = ( + reqparse.RequestParser() + .add_argument("page", type=int, required=True, location="args") + .add_argument("page_size", type=int, required=True, location="args") + ) args = parser.parse_args() page = args["page"] @@ -124,10 +128,12 @@ class EndpointListForSinglePluginApi(Resource): def get(self): user, tenant_id = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("page", type=int, required=True, location="args") - parser.add_argument("page_size", type=int, required=True, location="args") - parser.add_argument("plugin_id", type=str, required=True, location="args") + parser = ( + reqparse.RequestParser() + .add_argument("page", type=int, required=True, location="args") + .add_argument("page_size", type=int, required=True, location="args") + .add_argument("plugin_id", type=str, required=True, location="args") + ) args = parser.parse_args() page = args["page"] @@ -166,8 +172,7 @@ class EndpointDeleteApi(Resource): def post(self): user, tenant_id = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("endpoint_id", type=str, required=True) + parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True) args = parser.parse_args() if not user.is_admin_or_owner: @@ -206,10 +211,12 @@ class EndpointUpdateApi(Resource): def post(self): user, tenant_id = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("endpoint_id", type=str, required=True) - parser.add_argument("settings", type=dict, required=True) - parser.add_argument("name", type=str, required=True) + parser = ( + reqparse.RequestParser() + .add_argument("endpoint_id", type=str, required=True) + .add_argument("settings", type=dict, required=True) + .add_argument("name", type=str, required=True) + ) args = parser.parse_args() endpoint_id = args["endpoint_id"] @@ -249,8 +256,7 @@ class EndpointEnableApi(Resource): def post(self): user, tenant_id = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("endpoint_id", type=str, required=True) + parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True) args = parser.parse_args() endpoint_id = args["endpoint_id"] @@ -282,8 +288,7 @@ class EndpointDisableApi(Resource): def post(self): user, tenant_id = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("endpoint_id", type=str, required=True) + parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True) args = parser.parse_args() endpoint_id = args["endpoint_id"] diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index 4e6f1fa3a5..9bf393ea2e 100644 --- 
a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -24,17 +24,19 @@ class LoadBalancingCredentialsValidateApi(Resource): tenant_id = current_tenant_id - parser = reqparse.RequestParser() - parser.add_argument("model", type=str, required=True, nullable=False, location="json") - parser.add_argument( - "model_type", - type=str, - required=True, - nullable=False, - choices=[mt.value for mt in ModelType], - location="json", + parser = ( + reqparse.RequestParser() + .add_argument("model", type=str, required=True, nullable=False, location="json") + .add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + .add_argument("credentials", type=dict, required=True, nullable=False, location="json") ) - parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() # validate model load balancing credentials @@ -77,17 +79,19 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): tenant_id = current_tenant_id - parser = reqparse.RequestParser() - parser.add_argument("model", type=str, required=True, nullable=False, location="json") - parser.add_argument( - "model_type", - type=str, - required=True, - nullable=False, - choices=[mt.value for mt in ModelType], - location="json", + parser = ( + reqparse.RequestParser() + .add_argument("model", type=str, required=True, nullable=False, location="json") + .add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + .add_argument("credentials", type=dict, required=True, nullable=False, location="json") ) - parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() # validate model load balancing config credentials diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 4f080708cc..d66f861799 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -57,10 +57,12 @@ class MemberInviteEmailApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("members") def post(self): - parser = reqparse.RequestParser() - parser.add_argument("emails", type=list, required=True, location="json") - parser.add_argument("role", type=str, required=True, default="admin", location="json") - parser.add_argument("language", type=str, required=False, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("emails", type=list, required=True, location="json") + .add_argument("role", type=str, required=True, default="admin", location="json") + .add_argument("language", type=str, required=False, location="json") + ) args = parser.parse_args() invitee_emails = args["emails"] @@ -149,8 +151,7 @@ class MemberUpdateRoleApi(Resource): @login_required @account_initialization_required def put(self, member_id): - parser = reqparse.RequestParser() - parser.add_argument("role", type=str, required=True, location="json") + parser = reqparse.RequestParser().add_argument("role", type=str, required=True, location="json") args = parser.parse_args() new_role = args["role"] @@ -199,8 +200,7 @@ class SendOwnerTransferEmailApi(Resource): @account_initialization_required @is_allow_transfer_owner def post(self): - parser = reqparse.RequestParser() - parser.add_argument("language", type=str, 
required=False, location="json") + parser = reqparse.RequestParser().add_argument("language", type=str, required=False, location="json") args = parser.parse_args() ip_address = extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): @@ -236,9 +236,11 @@ class OwnerTransferCheckApi(Resource): @account_initialization_required @is_allow_transfer_owner def post(self): - parser = reqparse.RequestParser() - parser.add_argument("code", type=str, required=True, location="json") - parser.add_argument("token", type=str, required=True, nullable=False, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("code", type=str, required=True, location="json") + .add_argument("token", type=str, required=True, nullable=False, location="json") + ) args = parser.parse_args() # check if the current user is the owner of the workspace current_user, _ = current_account_with_tenant() @@ -281,8 +283,9 @@ class OwnerTransfer(Resource): @account_initialization_required @is_allow_transfer_owner def post(self, member_id): - parser = reqparse.RequestParser() - parser.add_argument("token", type=str, required=True, nullable=False, location="json") + parser = reqparse.RequestParser().add_argument( + "token", type=str, required=True, nullable=False, location="json" + ) args = parser.parse_args() # check if the current user is the owner of the workspace diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index acdd467b30..04db975fc2 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -24,8 +24,7 @@ class ModelProviderListApi(Resource): _, current_tenant_id = current_account_with_tenant() tenant_id = current_tenant_id - parser = reqparse.RequestParser() - parser.add_argument( + parser = reqparse.RequestParser().add_argument( "model_type", type=str, required=False, @@ -50,8 +49,9 @@ class ModelProviderCredentialApi(Resource): _, current_tenant_id = current_account_with_tenant() tenant_id = current_tenant_id # if credential_id is not provided, return current used credential - parser = reqparse.RequestParser() - parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args") + parser = reqparse.RequestParser().add_argument( + "credential_id", type=uuid_value, required=False, nullable=True, location="args" + ) args = parser.parse_args() model_provider_service = ModelProviderService() @@ -69,9 +69,11 @@ class ModelProviderCredentialApi(Resource): if not current_user.is_admin_or_owner: raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") - parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("credentials", type=dict, required=True, nullable=False, location="json") + .add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") + ) args = parser.parse_args() model_provider_service = ModelProviderService() @@ -96,10 +98,12 @@ class ModelProviderCredentialApi(Resource): if not current_user.is_admin_or_owner: raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") - parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") - parser.add_argument("name", 
type=StrLen(30), required=False, nullable=True, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") + .add_argument("credentials", type=dict, required=True, nullable=False, location="json") + .add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") + ) args = parser.parse_args() model_provider_service = ModelProviderService() @@ -124,8 +128,9 @@ class ModelProviderCredentialApi(Resource): current_user, current_tenant_id = current_account_with_tenant() if not current_user.is_admin_or_owner: raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") + parser = reqparse.RequestParser().add_argument( + "credential_id", type=uuid_value, required=True, nullable=False, location="json" + ) args = parser.parse_args() model_provider_service = ModelProviderService() @@ -145,8 +150,9 @@ class ModelProviderCredentialSwitchApi(Resource): current_user, current_tenant_id = current_account_with_tenant() if not current_user.is_admin_or_owner: raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") + parser = reqparse.RequestParser().add_argument( + "credential_id", type=str, required=True, nullable=False, location="json" + ) args = parser.parse_args() service = ModelProviderService() @@ -165,8 +171,9 @@ class ModelProviderValidateApi(Resource): @account_initialization_required def post(self, provider: str): _, current_tenant_id = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser = reqparse.RequestParser().add_argument( + "credentials", type=dict, required=True, nullable=False, location="json" + ) args = parser.parse_args() tenant_id = current_tenant_id @@ -223,8 +230,7 @@ class PreferredProviderTypeUpdateApi(Resource): tenant_id = current_tenant_id - parser = reqparse.RequestParser() - parser.add_argument( + parser = reqparse.RequestParser().add_argument( "preferred_provider_type", type=str, required=True, diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index d5d1aed00e..5ab958d585 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -24,8 +24,7 @@ class DefaultModelApi(Resource): def get(self): _, tenant_id = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument( + parser = reqparse.RequestParser().add_argument( "model_type", type=str, required=True, @@ -51,8 +50,9 @@ class DefaultModelApi(Resource): if not current_user.is_admin_or_owner: raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json") + parser = reqparse.RequestParser().add_argument( + "model_settings", type=list, required=True, nullable=False, location="json" + ) args = parser.parse_args() model_provider_service = ModelProviderService() model_settings = args["model_settings"] @@ -107,19 +107,21 @@ class ModelProviderModelApi(Resource): if not current_user.is_admin_or_owner: raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument("model", type=str, required=True, nullable=False, location="json") - parser.add_argument( - "model_type", - type=str, - 
required=True, - nullable=False, - choices=[mt.value for mt in ModelType], - location="json", + parser = ( + reqparse.RequestParser() + .add_argument("model", type=str, required=True, nullable=False, location="json") + .add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + .add_argument("load_balancing", type=dict, required=False, nullable=True, location="json") + .add_argument("config_from", type=str, required=False, nullable=True, location="json") + .add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json") ) - parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json") - parser.add_argument("config_from", type=str, required=False, nullable=True, location="json") - parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json") args = parser.parse_args() if args.get("config_from", "") == "custom-model": @@ -167,15 +169,17 @@ class ModelProviderModelApi(Resource): if not current_user.is_admin_or_owner: raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument("model", type=str, required=True, nullable=False, location="json") - parser.add_argument( - "model_type", - type=str, - required=True, - nullable=False, - choices=[mt.value for mt in ModelType], - location="json", + parser = ( + reqparse.RequestParser() + .add_argument("model", type=str, required=True, nullable=False, location="json") + .add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) ) args = parser.parse_args() @@ -195,18 +199,20 @@ class ModelProviderModelCredentialApi(Resource): def get(self, provider: str): _, tenant_id = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("model", type=str, required=True, nullable=False, location="args") - parser.add_argument( - "model_type", - type=str, - required=True, - nullable=False, - choices=[mt.value for mt in ModelType], - location="args", + parser = ( + reqparse.RequestParser() + .add_argument("model", type=str, required=True, nullable=False, location="args") + .add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="args", + ) + .add_argument("config_from", type=str, required=False, nullable=True, location="args") + .add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args") ) - parser.add_argument("config_from", type=str, required=False, nullable=True, location="args") - parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args") args = parser.parse_args() model_provider_service = ModelProviderService() @@ -260,18 +266,20 @@ class ModelProviderModelCredentialApi(Resource): if not current_user.is_admin_or_owner: raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument("model", type=str, required=True, nullable=False, location="json") - parser.add_argument( - "model_type", - type=str, - required=True, - nullable=False, - choices=[mt.value for mt in ModelType], - location="json", + parser = ( + reqparse.RequestParser() + .add_argument("model", type=str, required=True, nullable=False, location="json") + .add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + .add_argument("name", 
type=StrLen(30), required=False, nullable=True, location="json") + .add_argument("credentials", type=dict, required=True, nullable=False, location="json") ) - parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") - parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() model_provider_service = ModelProviderService() @@ -305,19 +313,21 @@ class ModelProviderModelCredentialApi(Resource): if not current_user.is_admin_or_owner: raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument("model", type=str, required=True, nullable=False, location="json") - parser.add_argument( - "model_type", - type=str, - required=True, - nullable=False, - choices=[mt.value for mt in ModelType], - location="json", + parser = ( + reqparse.RequestParser() + .add_argument("model", type=str, required=True, nullable=False, location="json") + .add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + .add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") + .add_argument("credentials", type=dict, required=True, nullable=False, location="json") + .add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") ) - parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") - parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") - parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") args = parser.parse_args() model_provider_service = ModelProviderService() @@ -345,17 +355,19 @@ class ModelProviderModelCredentialApi(Resource): if not current_user.is_admin_or_owner: raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument("model", type=str, required=True, nullable=False, location="json") - parser.add_argument( - "model_type", - type=str, - required=True, - nullable=False, - choices=[mt.value for mt in ModelType], - location="json", + parser = ( + reqparse.RequestParser() + .add_argument("model", type=str, required=True, nullable=False, location="json") + .add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + .add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") ) - parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") args = parser.parse_args() model_provider_service = ModelProviderService() @@ -380,17 +392,19 @@ class ModelProviderModelCredentialSwitchApi(Resource): if not current_user.is_admin_or_owner: raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument("model", type=str, required=True, nullable=False, location="json") - parser.add_argument( - "model_type", - type=str, - required=True, - nullable=False, - choices=[mt.value for mt in ModelType], - location="json", + parser = ( + reqparse.RequestParser() + .add_argument("model", type=str, required=True, nullable=False, location="json") + .add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + .add_argument("credential_id", type=str, required=True, nullable=False, location="json") ) - parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") 
@@ -414,15 +428,17 @@ class ModelProviderModelEnableApi(Resource):
     def patch(self, provider: str):
         _, tenant_id = current_account_with_tenant()

-        parser = reqparse.RequestParser()
-        parser.add_argument("model", type=str, required=True, nullable=False, location="json")
-        parser.add_argument(
-            "model_type",
-            type=str,
-            required=True,
-            nullable=False,
-            choices=[mt.value for mt in ModelType],
-            location="json",
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("model", type=str, required=True, nullable=False, location="json")
+            .add_argument(
+                "model_type",
+                type=str,
+                required=True,
+                nullable=False,
+                choices=[mt.value for mt in ModelType],
+                location="json",
+            )
         )
         args = parser.parse_args()

@@ -444,15 +460,17 @@ class ModelProviderModelDisableApi(Resource):
     def patch(self, provider: str):
         _, tenant_id = current_account_with_tenant()

-        parser = reqparse.RequestParser()
-        parser.add_argument("model", type=str, required=True, nullable=False, location="json")
-        parser.add_argument(
-            "model_type",
-            type=str,
-            required=True,
-            nullable=False,
-            choices=[mt.value for mt in ModelType],
-            location="json",
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("model", type=str, required=True, nullable=False, location="json")
+            .add_argument(
+                "model_type",
+                type=str,
+                required=True,
+                nullable=False,
+                choices=[mt.value for mt in ModelType],
+                location="json",
+            )
         )
         args = parser.parse_args()

@@ -472,17 +490,19 @@ class ModelProviderModelValidateApi(Resource):
     def post(self, provider: str):
         _, tenant_id = current_account_with_tenant()

-        parser = reqparse.RequestParser()
-        parser.add_argument("model", type=str, required=True, nullable=False, location="json")
-        parser.add_argument(
-            "model_type",
-            type=str,
-            required=True,
-            nullable=False,
-            choices=[mt.value for mt in ModelType],
-            location="json",
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("model", type=str, required=True, nullable=False, location="json")
+            .add_argument(
+                "model_type",
+                type=str,
+                required=True,
+                nullable=False,
+                choices=[mt.value for mt in ModelType],
+                location="json",
+            )
+            .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
         )
-        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
         args = parser.parse_args()

         model_provider_service = ModelProviderService()
@@ -516,8 +536,9 @@ class ModelProviderModelParameterRuleApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, provider: str):
-        parser = reqparse.RequestParser()
-        parser.add_argument("model", type=str, required=True, nullable=False, location="args")
+        parser = reqparse.RequestParser().add_argument(
+            "model", type=str, required=True, nullable=False, location="args"
+        )
         args = parser.parse_args()

         _, tenant_id = current_account_with_tenant()
diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py
index ed5426376f..e8bc312caf 100644
--- a/api/controllers/console/workspace/plugin.py
+++ b/api/controllers/console/workspace/plugin.py
@@ -44,9 +44,11 @@ class PluginListApi(Resource):
     @account_initialization_required
     def get(self):
         _, tenant_id = current_account_with_tenant()
-        parser = reqparse.RequestParser()
-        parser.add_argument("page", type=int, required=False, location="args", default=1)
-        parser.add_argument("page_size", type=int, required=False, location="args", default=256)
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("page", type=int, required=False, location="args", default=1)
+            .add_argument("page_size", type=int, required=False, location="args", default=256)
+        )
         args = parser.parse_args()
         try:
             plugins_with_total = PluginService.list_with_total(tenant_id, args["page"], args["page_size"])
@@ -62,8 +64,7 @@ class PluginListLatestVersionsApi(Resource):
     @login_required
     @account_initialization_required
     def post(self):
-        req = reqparse.RequestParser()
-        req.add_argument("plugin_ids", type=list, required=True, location="json")
+        req = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
         args = req.parse_args()

         try:
@@ -82,8 +83,7 @@ class PluginListInstallationsFromIdsApi(Resource):
     def post(self):
         _, tenant_id = current_account_with_tenant()

-        parser = reqparse.RequestParser()
-        parser.add_argument("plugin_ids", type=list, required=True, location="json")
+        parser = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
         args = parser.parse_args()

         try:
@@ -98,9 +98,11 @@ class PluginIconApi(Resource):
     @setup_required
     def get(self):
-        req = reqparse.RequestParser()
-        req.add_argument("tenant_id", type=str, required=True, location="args")
-        req.add_argument("filename", type=str, required=True, location="args")
+        req = (
+            reqparse.RequestParser()
+            .add_argument("tenant_id", type=str, required=True, location="args")
+            .add_argument("filename", type=str, required=True, location="args")
+        )
         args = req.parse_args()

         try:
@@ -145,10 +147,12 @@ class PluginUploadFromGithubApi(Resource):
     def post(self):
         _, tenant_id = current_account_with_tenant()

-        parser = reqparse.RequestParser()
-        parser.add_argument("repo", type=str, required=True, location="json")
-        parser.add_argument("version", type=str, required=True, location="json")
-        parser.add_argument("package", type=str, required=True, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("repo", type=str, required=True, location="json")
+            .add_argument("version", type=str, required=True, location="json")
+            .add_argument("package", type=str, required=True, location="json")
+        )
         args = parser.parse_args()

         try:
@@ -192,8 +196,9 @@ class PluginInstallFromPkgApi(Resource):
     def post(self):
         _, tenant_id = current_account_with_tenant()

-        parser = reqparse.RequestParser()
-        parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
+        parser = reqparse.RequestParser().add_argument(
+            "plugin_unique_identifiers", type=list, required=True, location="json"
+        )
         args = parser.parse_args()

         # check if all plugin_unique_identifiers are valid string
@@ -218,11 +223,13 @@ class PluginInstallFromGithubApi(Resource):
     def post(self):
         _, tenant_id = current_account_with_tenant()

-        parser = reqparse.RequestParser()
-        parser.add_argument("repo", type=str, required=True, location="json")
-        parser.add_argument("version", type=str, required=True, location="json")
-        parser.add_argument("package", type=str, required=True, location="json")
-        parser.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("repo", type=str, required=True, location="json")
+            .add_argument("version", type=str, required=True, location="json")
+            .add_argument("package", type=str, required=True, location="json")
+            .add_argument("plugin_unique_identifier", type=str, required=True, location="json")
+        )
         args = parser.parse_args()

         try:
@@ -248,8 +255,9 @@ class PluginInstallFromMarketplaceApi(Resource):
     def post(self):
         _, tenant_id = current_account_with_tenant()

-        parser = reqparse.RequestParser()
-        parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
+        parser = reqparse.RequestParser().add_argument(
+            "plugin_unique_identifiers", type=list, required=True, location="json"
+        )
         args = parser.parse_args()

         # check if all plugin_unique_identifiers are valid string
@@ -274,8 +282,9 @@ class PluginFetchMarketplacePkgApi(Resource):
     def get(self):
         _, tenant_id = current_account_with_tenant()

-        parser = reqparse.RequestParser()
-        parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
+        parser = reqparse.RequestParser().add_argument(
+            "plugin_unique_identifier", type=str, required=True, location="args"
+        )
         args = parser.parse_args()

         try:
@@ -300,8 +309,9 @@ class PluginFetchManifestApi(Resource):
     def get(self):
         _, tenant_id = current_account_with_tenant()

-        parser = reqparse.RequestParser()
-        parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
+        parser = reqparse.RequestParser().add_argument(
+            "plugin_unique_identifier", type=str, required=True, location="args"
+        )
         args = parser.parse_args()

         try:
@@ -325,9 +335,11 @@ class PluginFetchInstallTasksApi(Resource):
     def get(self):
         _, tenant_id = current_account_with_tenant()

-        parser = reqparse.RequestParser()
-        parser.add_argument("page", type=int, required=True, location="args")
-        parser.add_argument("page_size", type=int, required=True, location="args")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("page", type=int, required=True, location="args")
+            .add_argument("page_size", type=int, required=True, location="args")
+        )
         args = parser.parse_args()

         try:
@@ -407,9 +419,11 @@ class PluginUpgradeFromMarketplaceApi(Resource):
     def post(self):
         _, tenant_id = current_account_with_tenant()

-        parser = reqparse.RequestParser()
-        parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
-        parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
+            .add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
+        )
         args = parser.parse_args()

         try:
@@ -431,12 +445,14 @@ class PluginUpgradeFromGithubApi(Resource):
     def post(self):
         _, tenant_id = current_account_with_tenant()

-        parser = reqparse.RequestParser()
-        parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
-        parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
-        parser.add_argument("repo", type=str, required=True, location="json")
-        parser.add_argument("version", type=str, required=True, location="json")
-        parser.add_argument("package", type=str, required=True, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
+            .add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
+            .add_argument("repo", type=str, required=True, location="json")
+            .add_argument("version", type=str, required=True, location="json")
+            .add_argument("package", type=str, required=True, location="json")
+        )
         args = parser.parse_args()

         try:
@@ -461,8 +477,7 @@ class PluginUninstallApi(Resource):
     @account_initialization_required
    @plugin_permission_required(install_required=True)
     def post(self):
-        req = reqparse.RequestParser()
-        req.add_argument("plugin_installation_id", type=str, required=True, location="json")
+        req = reqparse.RequestParser().add_argument("plugin_installation_id", type=str, required=True, location="json")
         args = req.parse_args()

         _, tenant_id = current_account_with_tenant()
@@ -484,9 +499,11 @@ class PluginChangePermissionApi(Resource):
         if not user.is_admin_or_owner:
             raise Forbidden()

-        req = reqparse.RequestParser()
-        req.add_argument("install_permission", type=str, required=True, location="json")
-        req.add_argument("debug_permission", type=str, required=True, location="json")
+        req = (
+            reqparse.RequestParser()
+            .add_argument("install_permission", type=str, required=True, location="json")
+            .add_argument("debug_permission", type=str, required=True, location="json")
+        )
         args = req.parse_args()

         install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
@@ -535,12 +552,14 @@ class PluginFetchDynamicSelectOptionsApi(Resource):

         user_id = current_user.id

-        parser = reqparse.RequestParser()
-        parser.add_argument("plugin_id", type=str, required=True, location="args")
-        parser.add_argument("provider", type=str, required=True, location="args")
-        parser.add_argument("action", type=str, required=True, location="args")
-        parser.add_argument("parameter", type=str, required=True, location="args")
-        parser.add_argument("provider_type", type=str, required=True, location="args")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("plugin_id", type=str, required=True, location="args")
+            .add_argument("provider", type=str, required=True, location="args")
+            .add_argument("action", type=str, required=True, location="args")
+            .add_argument("parameter", type=str, required=True, location="args")
+            .add_argument("provider_type", type=str, required=True, location="args")
+        )
         args = parser.parse_args()

         try:
@@ -569,9 +588,11 @@ class PluginChangePreferencesApi(Resource):
         if not user.is_admin_or_owner:
             raise Forbidden()

-        req = reqparse.RequestParser()
-        req.add_argument("permission", type=dict, required=True, location="json")
-        req.add_argument("auto_upgrade", type=dict, required=True, location="json")
+        req = (
+            reqparse.RequestParser()
+            .add_argument("permission", type=dict, required=True, location="json")
+            .add_argument("auto_upgrade", type=dict, required=True, location="json")
+        )
         args = req.parse_args()

         permission = args["permission"]
@@ -661,8 +682,7 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
         # exclude one single plugin
         _, tenant_id = current_account_with_tenant()

-        req = reqparse.RequestParser()
-        req.add_argument("plugin_id", type=str, required=True, location="json")
+        req = reqparse.RequestParser().add_argument("plugin_id", type=str, required=True, location="json")
         args = req.parse_args()

         return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})
diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py
index 17a935ade7..cc50131f0a 100644
--- a/api/controllers/console/workspace/tool_providers.py
+++ b/api/controllers/console/workspace/tool_providers.py
@@ -56,8 +56,7 @@ class ToolProviderListApi(Resource):

         user_id = user.id

-        req = reqparse.RequestParser()
-        req.add_argument(
+        req = reqparse.RequestParser().add_argument(
             "type",
             type=str,
             choices=["builtin", "model", "api", "workflow", "mcp"],
@@ -107,8 +106,9 @@ class ToolBuiltinProviderDeleteApi(Resource):
         if not user.is_admin_or_owner:
             raise Forbidden()

-        req = reqparse.RequestParser()
-        req.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
+        req = reqparse.RequestParser().add_argument(
+            "credential_id", type=str, required=True, nullable=False, location="json"
+        )
         args = req.parse_args()

         return BuiltinToolManageService.delete_builtin_tool_provider(
@@ -128,10 +128,12 @@ class ToolBuiltinProviderAddApi(Resource):

         user_id = user.id

-        parser = reqparse.RequestParser()
-        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
-        parser.add_argument("name", type=StrLen(30), required=False, nullable=False, location="json")
-        parser.add_argument("type", type=str, required=True, nullable=False, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
+            .add_argument("name", type=StrLen(30), required=False, nullable=False, location="json")
+            .add_argument("type", type=str, required=True, nullable=False, location="json")
+        )
         args = parser.parse_args()

         if args["type"] not in CredentialType.values():
@@ -160,10 +162,12 @@ class ToolBuiltinProviderUpdateApi(Resource):

         user_id = user.id

-        parser = reqparse.RequestParser()
-        parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
-        parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("credential_id", type=str, required=True, nullable=False, location="json")
+            .add_argument("credentials", type=dict, required=False, nullable=True, location="json")
+            .add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
+        )

         args = parser.parse_args()

@@ -216,15 +220,17 @@ class ToolApiProviderAddApi(Resource):

         user_id = user.id

-        parser = reqparse.RequestParser()
-        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
-        parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("icon", type=dict, required=True, nullable=False, location="json")
-        parser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json")
-        parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[])
-        parser.add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
+            .add_argument("schema_type", type=str, required=True, nullable=False, location="json")
+            .add_argument("schema", type=str, required=True, nullable=False, location="json")
+            .add_argument("provider", type=str, required=True, nullable=False, location="json")
+            .add_argument("icon", type=dict, required=True, nullable=False, location="json")
+            .add_argument("privacy_policy", type=str, required=False, nullable=True, location="json")
+            .add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[])
+            .add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json")
+        )

         args = parser.parse_args()

@@ -252,9 +258,7 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):

         user_id = user.id

-        parser = reqparse.RequestParser()
-
-        parser.add_argument("url", type=str, required=True, nullable=False, location="args")
+        parser = reqparse.RequestParser().add_argument("url", type=str, required=True, nullable=False, location="args")

         args = parser.parse_args()

@@ -275,9 +279,9 @@ class ToolApiProviderListToolsApi(Resource):

         user_id = user.id

-        parser = reqparse.RequestParser()
-
-        parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
+        parser = reqparse.RequestParser().add_argument(
+            "provider", type=str, required=True, nullable=False, location="args"
+        )

         args = parser.parse_args()

@@ -303,16 +307,18 @@ class ToolApiProviderUpdateApi(Resource):

         user_id = user.id

-        parser = reqparse.RequestParser()
-        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
-        parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("original_provider", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("icon", type=dict, required=True, nullable=False, location="json")
-        parser.add_argument("privacy_policy", type=str, required=True, nullable=True, location="json")
-        parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
-        parser.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
+            .add_argument("schema_type", type=str, required=True, nullable=False, location="json")
+            .add_argument("schema", type=str, required=True, nullable=False, location="json")
+            .add_argument("provider", type=str, required=True, nullable=False, location="json")
+            .add_argument("original_provider", type=str, required=True, nullable=False, location="json")
+            .add_argument("icon", type=dict, required=True, nullable=False, location="json")
+            .add_argument("privacy_policy", type=str, required=True, nullable=True, location="json")
+            .add_argument("labels", type=list[str], required=False, nullable=True, location="json")
+            .add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json")
+        )

         args = parser.parse_args()

@@ -344,9 +350,9 @@ class ToolApiProviderDeleteApi(Resource):

         user_id = user.id

-        parser = reqparse.RequestParser()
-
-        parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
+        parser = reqparse.RequestParser().add_argument(
+            "provider", type=str, required=True, nullable=False, location="json"
+        )

         args = parser.parse_args()

@@ -367,9 +373,9 @@ class ToolApiProviderGetApi(Resource):

         user_id = user.id

-        parser = reqparse.RequestParser()
-
-        parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
+        parser = reqparse.RequestParser().add_argument(
+            "provider", type=str, required=True, nullable=False, location="args"
+        )

         args = parser.parse_args()

@@ -401,9 +407,9 @@ class ToolApiProviderSchemaApi(Resource):
     @login_required
     @account_initialization_required
     def post(self):
-        parser = reqparse.RequestParser()
-
-        parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
+        parser = reqparse.RequestParser().add_argument(
+            "schema", type=str, required=True, nullable=False, location="json"
+        )

        args = parser.parse_args()

@@ -418,14 +424,15 @@ class ToolApiProviderPreviousTestApi(Resource):
     @login_required
     @account_initialization_required
     def post(self):
-        parser = reqparse.RequestParser()
-
-        parser.add_argument("tool_name", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("provider_name", type=str, required=False, nullable=False, location="json")
-        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
-        parser.add_argument("parameters", type=dict, required=True, nullable=False, location="json")
-        parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("tool_name", type=str, required=True, nullable=False, location="json")
+            .add_argument("provider_name", type=str, required=False, nullable=False, location="json")
+            .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
+            .add_argument("parameters", type=dict, required=True, nullable=False, location="json")
+            .add_argument("schema_type", type=str, required=True, nullable=False, location="json")
+            .add_argument("schema", type=str, required=True, nullable=False, location="json")
+        )
         args = parser.parse_args()

         _, current_tenant_id = current_account_with_tenant()
@@ -453,15 +460,17 @@ class ToolWorkflowProviderCreateApi(Resource):

         user_id = user.id

-        reqparser = reqparse.RequestParser()
-        reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
-        reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
-        reqparser.add_argument("label", type=str, required=True, nullable=False, location="json")
-        reqparser.add_argument("description", type=str, required=True, nullable=False, location="json")
-        reqparser.add_argument("icon", type=dict, required=True, nullable=False, location="json")
-        reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
-        reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
-        reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
+        reqparser = (
+            reqparse.RequestParser()
+            .add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
+            .add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
+            .add_argument("label", type=str, required=True, nullable=False, location="json")
+            .add_argument("description", type=str, required=True, nullable=False, location="json")
+            .add_argument("icon", type=dict, required=True, nullable=False, location="json")
+            .add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
+            .add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
+            .add_argument("labels", type=list[str], required=False, nullable=True, location="json")
+        )

         args = reqparser.parse_args()

@@ -492,15 +501,17 @@ class ToolWorkflowProviderUpdateApi(Resource):

         user_id = user.id

-        reqparser = reqparse.RequestParser()
-        reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
-        reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
-        reqparser.add_argument("label", type=str, required=True, nullable=False, location="json")
-        reqparser.add_argument("description", type=str, required=True, nullable=False, location="json")
-        reqparser.add_argument("icon", type=dict, required=True, nullable=False, location="json")
-        reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
-        reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
-        reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
+        reqparser = (
+            reqparse.RequestParser()
+            .add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
+            .add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
+            .add_argument("label", type=str, required=True, nullable=False, location="json")
+            .add_argument("description", type=str, required=True, nullable=False, location="json")
+            .add_argument("icon", type=dict, required=True, nullable=False, location="json")
+            .add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
+            .add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
+            .add_argument("labels", type=list[str], required=False, nullable=True, location="json")
+        )

         args = reqparser.parse_args()

@@ -534,8 +545,9 @@ class ToolWorkflowProviderDeleteApi(Resource):

         user_id = user.id

-        reqparser = reqparse.RequestParser()
-        reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
+        reqparser = reqparse.RequestParser().add_argument(
+            "workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json"
+        )

         args = reqparser.parse_args()

@@ -556,9 +568,11 @@ class ToolWorkflowProviderGetApi(Resource):

         user_id = user.id

-        parser = reqparse.RequestParser()
-        parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
-        parser.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
+            .add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args")
+        )

         args = parser.parse_args()

@@ -590,8 +604,9 @@ class ToolWorkflowProviderListToolApi(Resource):

         user_id = user.id

-        parser = reqparse.RequestParser()
-        parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args")
+        parser = reqparse.RequestParser().add_argument(
+            "workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args"
+        )

         args = parser.parse_args()

@@ -776,8 +791,7 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
     @account_initialization_required
     def post(self, provider):
         current_user, current_tenant_id = current_account_with_tenant()
-        parser = reqparse.RequestParser()
-        parser.add_argument("id", type=str, required=True, nullable=False, location="json")
+        parser = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
         args = parser.parse_args()
         return BuiltinToolManageService.set_default_provider(
             tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
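Several of these parsers pass custom callables (`StrLen(30)`, `uuid_value`, `alphanumeric`) as `type=`. reqparse treats `type` as any callable that returns the converted value or raises `ValueError`, which it surfaces as a 400 response. A hedged sketch of a StrLen-style validator, illustrating the contract rather than Dify's actual implementation:

```python
class StrLen:
    """Callable type-validator: accept a string no longer than max_length."""

    def __init__(self, max_length: int):
        self.max_length = max_length

    def __call__(self, value):
        # Raising ValueError is how reqparse type callables reject input.
        if not isinstance(value, str) or len(value) > self.max_length:
            raise ValueError(f"must be a string of at most {self.max_length} characters")
        return value

# Used exactly like the built-in types:
# parser.add_argument("name", type=StrLen(30), required=False, location="json")
```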
@@ -790,9 +804,11 @@ class ToolOAuthCustomClient(Resource):
     @login_required
     @account_initialization_required
     def post(self, provider):
-        parser = reqparse.RequestParser()
-        parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
-        parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("client_params", type=dict, required=False, nullable=True, location="json")
+            .add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
+        )
         args = parser.parse_args()

         user, tenant_id = current_account_with_tenant()
@@ -862,18 +878,18 @@ class ToolProviderMCPApi(Resource):
     @login_required
     @account_initialization_required
     def post(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument("server_url", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("name", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
-        parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30)
-        parser.add_argument(
-            "sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("server_url", type=str, required=True, nullable=False, location="json")
+            .add_argument("name", type=str, required=True, nullable=False, location="json")
+            .add_argument("icon", type=str, required=True, nullable=False, location="json")
+            .add_argument("icon_type", type=str, required=True, nullable=False, location="json")
+            .add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
+            .add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
+            .add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30)
+            .add_argument("sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300)
+            .add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
         )
-        parser.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
         args = parser.parse_args()
         user, tenant_id = current_account_with_tenant()
         if not is_valid_url(args["server_url"]):
@@ -898,17 +914,19 @@ class ToolProviderMCPApi(Resource):
     @login_required
     @account_initialization_required
     def put(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument("server_url", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("name", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
-        parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("timeout", type=float, required=False, nullable=True, location="json")
-        parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
-        parser.add_argument("headers", type=dict, required=False, nullable=True, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("server_url", type=str, required=True, nullable=False, location="json")
+            .add_argument("name", type=str, required=True, nullable=False, location="json")
+            .add_argument("icon", type=str, required=True, nullable=False, location="json")
+            .add_argument("icon_type", type=str, required=True, nullable=False, location="json")
+            .add_argument("icon_background", type=str, required=False, nullable=True, location="json")
+            .add_argument("provider_id", type=str, required=True, nullable=False, location="json")
+            .add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
+            .add_argument("timeout", type=float, required=False, nullable=True, location="json")
+            .add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
+            .add_argument("headers", type=dict, required=False, nullable=True, location="json")
+        )
         args = parser.parse_args()
         if not is_valid_url(args["server_url"]):
             if "[__HIDDEN__]" in args["server_url"]:
@@ -935,8 +953,9 @@ class ToolProviderMCPApi(Resource):
     @login_required
     @account_initialization_required
     def delete(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
+        parser = reqparse.RequestParser().add_argument(
+            "provider_id", type=str, required=True, nullable=False, location="json"
+        )
         args = parser.parse_args()
         _, current_tenant_id = current_account_with_tenant()
         MCPToolManageService.delete_mcp_tool(tenant_id=current_tenant_id, provider_id=args["provider_id"])
@@ -949,9 +968,11 @@ class ToolMCPAuthApi(Resource):
     @login_required
     @account_initialization_required
     def post(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("provider_id", type=str, required=True, nullable=False, location="json")
+            .add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
+        )
         args = parser.parse_args()
         provider_id = args["provider_id"]
         _, tenant_id = current_account_with_tenant()
@@ -1030,9 +1051,11 @@ class ToolMCPUpdateApi(Resource):
 @console_ns.route("/mcp/oauth/callback")
 class ToolMCPCallbackApi(Resource):
     def get(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument("code", type=str, required=True, nullable=False, location="args")
-        parser.add_argument("state", type=str, required=True, nullable=False, location="args")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("code", type=str, required=True, nullable=False, location="args")
+            .add_argument("state", type=str, required=True, nullable=False, location="args")
+        )
         args = parser.parse_args()
         state_key = args["state"]
         authorization_code = args["code"]
diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py
index 5be427e9bb..f9856df9ea 100644
--- a/api/controllers/console/workspace/workspace.py
+++ b/api/controllers/console/workspace/workspace.py
@@ -97,9 +97,11 @@ class WorkspaceListApi(Resource):
     @setup_required
     @admin_required
     def get(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
-        parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
+            .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
+        )
         args = parser.parse_args()

         stmt = select(Tenant).order_by(Tenant.created_at.desc())
@@ -154,8 +156,7 @@ class SwitchWorkspaceApi(Resource):
     @account_initialization_required
     def post(self):
         current_user, _ = current_account_with_tenant()
-        parser = reqparse.RequestParser()
-        parser.add_argument("tenant_id", type=str, required=True, location="json")
+        parser = reqparse.RequestParser().add_argument("tenant_id", type=str, required=True, location="json")
         args = parser.parse_args()

         # check if tenant_id is valid, 403 if not
@@ -179,9 +180,11 @@ class CustomConfigWorkspaceApi(Resource):
     @cloud_edition_billing_resource_check("workspace_custom")
     def post(self):
         _, current_tenant_id = current_account_with_tenant()
-        parser = reqparse.RequestParser()
-        parser.add_argument("remove_webapp_brand", type=bool, location="json")
-        parser.add_argument("replace_webapp_logo", type=str, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("remove_webapp_brand", type=bool, location="json")
+            .add_argument("replace_webapp_logo", type=str, location="json")
+        )
         args = parser.parse_args()

         tenant = db.get_or_404(Tenant, current_tenant_id)
@@ -246,8 +249,7 @@ class WorkspaceInfoApi(Resource):
     # Change workspace name
     def post(self):
         _, current_tenant_id = current_account_with_tenant()
-        parser = reqparse.RequestParser()
-        parser.add_argument("name", type=str, required=True, location="json")
+        parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
         args = parser.parse_args()

         if not current_tenant_id:
diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py
index 0efee0c377..3db82456d5 100644
--- a/api/controllers/files/image_preview.py
+++ b/api/controllers/files/image_preview.py
@@ -46,11 +46,13 @@ class FilePreviewApi(Resource):
     def get(self, file_id):
         file_id = str(file_id)

-        parser = reqparse.RequestParser()
-        parser.add_argument("timestamp", type=str, required=True, location="args")
-        parser.add_argument("nonce", type=str, required=True, location="args")
-        parser.add_argument("sign", type=str, required=True, location="args")
-        parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("timestamp", type=str, required=True, location="args")
+            .add_argument("nonce", type=str, required=True, location="args")
+            .add_argument("sign", type=str, required=True, location="args")
+            .add_argument("as_attachment", type=bool, required=False, default=False, location="args")
+        )

         args = parser.parse_args()
diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py
index 42207b878c..dec5a4a1b2 100644
--- a/api/controllers/files/tool_files.py
+++ b/api/controllers/files/tool_files.py
@@ -16,12 +16,13 @@ class ToolFileApi(Resource):
     def get(self, file_id, extension):
         file_id = str(file_id)

-        parser = reqparse.RequestParser()
-
-        parser.add_argument("timestamp", type=str, required=True, location="args")
-        parser.add_argument("nonce", type=str, required=True, location="args")
-        parser.add_argument("sign", type=str, required=True, location="args")
-        parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("timestamp", type=str, required=True, location="args")
+            .add_argument("nonce", type=str, required=True, location="args")
+            .add_argument("sign", type=str, required=True, location="args")
+            .add_argument("as_attachment", type=bool, required=False, default=False, location="args")
+        )
         args = parser.parse_args()

         if not verify_tool_file_signature(
diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py
index 206a5d1cc2..a09e24e2d9 100644
--- a/api/controllers/files/upload.py
+++ b/api/controllers/files/upload.py
@@ -18,19 +18,17 @@ from core.tools.tool_file_manager import ToolFileManager
 from fields.file_fields import build_file_model

 # Define parser for both documentation and validation
-upload_parser = reqparse.RequestParser()
-upload_parser.add_argument("file", location="files", type=FileStorage, required=True, help="File to upload")
-upload_parser.add_argument(
-    "timestamp", type=str, required=True, location="args", help="Unix timestamp for signature verification"
+upload_parser = (
+    reqparse.RequestParser()
+    .add_argument("file", location="files", type=FileStorage, required=True, help="File to upload")
+    .add_argument(
+        "timestamp", type=str, required=True, location="args", help="Unix timestamp for signature verification"
+    )
+    .add_argument("nonce", type=str, required=True, location="args", help="Random string for signature verification")
+    .add_argument("sign", type=str, required=True, location="args", help="HMAC signature for request validation")
+    .add_argument("tenant_id", type=str, required=True, location="args", help="Tenant identifier")
+    .add_argument("user_id", type=str, required=False, location="args", help="User identifier")
 )
-upload_parser.add_argument(
-    "nonce", type=str, required=True, location="args", help="Random string for signature verification"
-)
-upload_parser.add_argument(
-    "sign", type=str, required=True, location="args", help="HMAC signature for request validation"
-)
-upload_parser.add_argument("tenant_id", type=str, required=True, location="args", help="Tenant identifier")
-upload_parser.add_argument("user_id", type=str, required=False, location="args", help="User identifier")


 @files_ns.route("/upload/for-plugin")
diff --git a/api/controllers/inner_api/mail.py b/api/controllers/inner_api/mail.py
index 39411a077a..7e40d81706 100644
--- a/api/controllers/inner_api/mail.py
+++ b/api/controllers/inner_api/mail.py
@@ -5,11 +5,13 @@ from controllers.inner_api import inner_api_ns
 from controllers.inner_api.wraps import billing_inner_api_only, enterprise_inner_api_only
 from tasks.mail_inner_task import send_inner_email_task

-_mail_parser = reqparse.RequestParser()
-_mail_parser.add_argument("to", type=str, action="append", required=True)
-_mail_parser.add_argument("subject", type=str, required=True)
-_mail_parser.add_argument("body", type=str, required=True)
-_mail_parser.add_argument("substitutions", type=dict, required=False)
+_mail_parser = (
+    reqparse.RequestParser()
+    .add_argument("to", type=str, action="append", required=True)
+    .add_argument("subject", type=str, required=True)
+    .add_argument("body", type=str, required=True)
+    .add_argument("substitutions", type=dict, required=False)
+)


 class BaseMail(Resource):
diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py
index 1f588bedce..2a57bb745b 100644
--- a/api/controllers/inner_api/plugin/wraps.py
+++ b/api/controllers/inner_api/plugin/wraps.py
@@ -72,9 +72,11 @@ def get_user_tenant(view: Callable[P, R] | None = None):
     @wraps(view_func)
     def decorated_view(*args: P.args, **kwargs: P.kwargs):
         # fetch json body
-        parser = reqparse.RequestParser()
-        parser.add_argument("tenant_id", type=str, required=True, location="json")
-        parser.add_argument("user_id", type=str, required=True, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("tenant_id", type=str, required=True, location="json")
+            .add_argument("user_id", type=str, required=True, location="json")
+        )

         p = parser.parse_args()
diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py
index 861da57708..8391a15919 100644
--- a/api/controllers/inner_api/workspace/workspace.py
+++ b/api/controllers/inner_api/workspace/workspace.py
@@ -25,9 +25,11 @@ class EnterpriseWorkspace(Resource):
         }
     )
     def post(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument("name", type=str, required=True, location="json")
-        parser.add_argument("owner_email", type=str, required=True, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("name", type=str, required=True, location="json")
+            .add_argument("owner_email", type=str, required=True, location="json")
+        )
         args = parser.parse_args()

         account = db.session.query(Account).filter_by(email=args["owner_email"]).first()
@@ -68,8 +70,7 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource):
         }
     )
     def post(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument("name", type=str, required=True, location="json")
+        parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
         args = parser.parse_args()

         tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True)
diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py
index a8629dca20..85b7df229f 100644
--- a/api/controllers/mcp/mcp.py
+++ b/api/controllers/mcp/mcp.py
@@ -33,14 +33,12 @@ def int_or_str(value):


 # Define parser for both documentation and validation
-mcp_request_parser = reqparse.RequestParser()
-mcp_request_parser.add_argument(
-    "jsonrpc", type=str, required=True, location="json", help="JSON-RPC version (should be '2.0')"
-)
-mcp_request_parser.add_argument("method", type=str, required=True, location="json", help="The method to invoke")
-mcp_request_parser.add_argument("params", type=dict, required=False, location="json", help="Parameters for the method")
-mcp_request_parser.add_argument(
-    "id", type=int_or_str, required=False, location="json", help="Request ID for tracking responses"
+mcp_request_parser = (
+    reqparse.RequestParser()
+    .add_argument("jsonrpc", type=str, required=True, location="json", help="JSON-RPC version (should be '2.0')")
+    .add_argument("method", type=str, required=True, location="json", help="The method to invoke")
+    .add_argument("params", type=dict, required=False, location="json", help="Parameters for the method")
+    .add_argument("id", type=int_or_str, required=False, location="json", help="Request ID for tracking responses")
 )
diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py
index 0521f1537c..ed013b1674 100644
--- a/api/controllers/service_api/app/annotation.py
+++ b/api/controllers/service_api/app/annotation.py
@@ -15,19 +15,19 @@ from models.model import App
 from services.annotation_service import AppAnnotationService

 # Define parsers for annotation API
-annotation_create_parser = reqparse.RequestParser()
-annotation_create_parser.add_argument("question", required=True, type=str, location="json", help="Annotation question")
-annotation_create_parser.add_argument("answer", required=True, type=str, location="json", help="Annotation answer")
+annotation_create_parser = (
+    reqparse.RequestParser()
+    .add_argument("question", required=True, type=str, location="json", help="Annotation question")
+    .add_argument("answer", required=True, type=str, location="json", help="Annotation answer")
+)

-annotation_reply_action_parser = reqparse.RequestParser()
-annotation_reply_action_parser.add_argument(
-    "score_threshold", required=True, type=float, location="json", help="Score threshold for annotation matching"
-)
-annotation_reply_action_parser.add_argument(
-    "embedding_provider_name", required=True, type=str, location="json", help="Embedding provider name"
-)
-annotation_reply_action_parser.add_argument(
-    "embedding_model_name", required=True, type=str, location="json", help="Embedding model name"
+annotation_reply_action_parser = (
+    reqparse.RequestParser()
+    .add_argument(
+        "score_threshold", required=True, type=float, location="json", help="Score threshold for annotation matching"
+    )
+    .add_argument("embedding_provider_name", required=True, type=str, location="json", help="Embedding provider name")
+    .add_argument("embedding_model_name", required=True, type=str, location="json", help="Embedding model name")
 )
diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py
index 33035123d7..c069a7ddfb 100644
--- a/api/controllers/service_api/app/audio.py
+++ b/api/controllers/service_api/app/audio.py
@@ -85,11 +85,13 @@ class AudioApi(Resource):


 # Define parser for text-to-audio API
-text_to_audio_parser = reqparse.RequestParser()
-text_to_audio_parser.add_argument("message_id", type=str, required=False, location="json", help="Message ID")
-text_to_audio_parser.add_argument("voice", type=str, location="json", help="Voice to use for TTS")
-text_to_audio_parser.add_argument("text", type=str, location="json", help="Text to convert to audio")
-text_to_audio_parser.add_argument("streaming", type=bool, location="json", help="Enable streaming response")
+text_to_audio_parser = (
+    reqparse.RequestParser()
+    .add_argument("message_id", type=str, required=False, location="json", help="Message ID")
+    .add_argument("voice", type=str, location="json", help="Voice to use for TTS")
+    .add_argument("text", type=str, location="json", help="Text to convert to audio")
+    .add_argument("streaming", type=bool, location="json", help="Enable streaming response")
+)


 @service_api_ns.route("/text-to-audio")
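Unlike the per-request parsers in the console controllers, these service-api parsers (`text_to_audio_parser`, `annotation_create_parser`, and the parsers that follow) are built once at import time and shared across requests. That is safe because `parse_args()` reads from the current request context rather than mutating the parser. A hedged sketch of the reuse pattern (endpoint and field names are illustrative, not from Dify):

```python
from flask_restful import Resource, reqparse

# Built once at module import, like the parsers in this patch.
echo_parser = (
    reqparse.RequestParser()
    .add_argument("text", type=str, required=True, location="json", help="Text to echo")
    .add_argument("repeat", type=int, required=False, default=1, location="json")
)

class EchoApi(Resource):
    def post(self):
        # Safe to share: parse_args() pulls values from the active request.
        args = echo_parser.parse_args()
        return {"echo": args["text"] * args["repeat"]}
```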
-completion_parser.add_argument("files", type=list, required=False, location="json", help="List of file attachments") -completion_parser.add_argument( - "response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode" -) -completion_parser.add_argument( - "retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source" +completion_parser = ( + reqparse.RequestParser() + .add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for completion") + .add_argument("query", type=str, location="json", default="", help="The query string") + .add_argument("files", type=list, required=False, location="json", help="List of file attachments") + .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode") + .add_argument("retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source") ) # Define parser for chat API -chat_parser = reqparse.RequestParser() -chat_parser.add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for chat") -chat_parser.add_argument("query", type=str, required=True, location="json", help="The chat query") -chat_parser.add_argument("files", type=list, required=False, location="json", help="List of file attachments") -chat_parser.add_argument( - "response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode" +chat_parser = ( + reqparse.RequestParser() + .add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for chat") + .add_argument("query", type=str, required=True, location="json", help="The chat query") + .add_argument("files", type=list, required=False, location="json", help="List of file attachments") + .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode") + .add_argument("conversation_id", type=uuid_value, location="json", help="Existing conversation ID") + .add_argument("retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source") + .add_argument( + "auto_generate_name", + type=bool, + required=False, + default=True, + location="json", + help="Auto generate conversation name", + ) + .add_argument("workflow_id", type=str, required=False, location="json", help="Workflow ID for advanced chat") ) -chat_parser.add_argument("conversation_id", type=uuid_value, location="json", help="Existing conversation ID") -chat_parser.add_argument( - "retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source" -) -chat_parser.add_argument( - "auto_generate_name", - type=bool, - required=False, - default=True, - location="json", - help="Auto generate conversation name", -) -chat_parser.add_argument("workflow_id", type=str, required=False, location="json", help="Workflow ID for advanced chat") @service_api_ns.route("/completion-messages") diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 711dd5704c..c4e23dd2e7 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -24,48 +24,63 @@ from models.model import App, AppMode, EndUser from services.conversation_service import ConversationService # Define parsers for conversation APIs -conversation_list_parser = reqparse.RequestParser() -conversation_list_parser.add_argument( - "last_id", 
type=uuid_value, location="args", help="Last conversation ID for pagination" -) -conversation_list_parser.add_argument( - "limit", - type=int_range(1, 100), - required=False, - default=20, - location="args", - help="Number of conversations to return", -) -conversation_list_parser.add_argument( - "sort_by", - type=str, - choices=["created_at", "-created_at", "updated_at", "-updated_at"], - required=False, - default="-updated_at", - location="args", - help="Sort order for conversations", +conversation_list_parser = ( + reqparse.RequestParser() + .add_argument("last_id", type=uuid_value, location="args", help="Last conversation ID for pagination") + .add_argument( + "limit", + type=int_range(1, 100), + required=False, + default=20, + location="args", + help="Number of conversations to return", + ) + .add_argument( + "sort_by", + type=str, + choices=["created_at", "-created_at", "updated_at", "-updated_at"], + required=False, + default="-updated_at", + location="args", + help="Sort order for conversations", + ) ) -conversation_rename_parser = reqparse.RequestParser() -conversation_rename_parser.add_argument("name", type=str, required=False, location="json", help="New conversation name") -conversation_rename_parser.add_argument( - "auto_generate", type=bool, required=False, default=False, location="json", help="Auto-generate conversation name" +conversation_rename_parser = ( + reqparse.RequestParser() + .add_argument("name", type=str, required=False, location="json", help="New conversation name") + .add_argument( + "auto_generate", + type=bool, + required=False, + default=False, + location="json", + help="Auto-generate conversation name", + ) ) -conversation_variables_parser = reqparse.RequestParser() -conversation_variables_parser.add_argument( - "last_id", type=uuid_value, location="args", help="Last variable ID for pagination" -) -conversation_variables_parser.add_argument( - "limit", type=int_range(1, 100), required=False, default=20, location="args", help="Number of variables to return" +conversation_variables_parser = ( + reqparse.RequestParser() + .add_argument("last_id", type=uuid_value, location="args", help="Last variable ID for pagination") + .add_argument( + "limit", + type=int_range(1, 100), + required=False, + default=20, + location="args", + help="Number of variables to return", + ) ) -conversation_variable_update_parser = reqparse.RequestParser() -# using lambda is for passing the already-typed value without modification -# if no lambda, it will be converted to string -# the string cannot be converted using json.loads -conversation_variable_update_parser.add_argument( - "value", required=True, location="json", type=lambda x: x, help="New value for the conversation variable" +conversation_variable_update_parser = reqparse.RequestParser().add_argument( + # using lambda is for passing the already-typed value without modification + # if no lambda, it will be converted to string + # the string cannot be converted using json.loads + "value", + required=True, + location="json", + type=lambda x: x, + help="New value for the conversation variable", ) diff --git a/api/controllers/service_api/app/file_preview.py b/api/controllers/service_api/app/file_preview.py index 63b46f49f2..b8e91f0657 100644 --- a/api/controllers/service_api/app/file_preview.py +++ b/api/controllers/service_api/app/file_preview.py @@ -18,8 +18,7 @@ logger = logging.getLogger(__name__) # Define parser for file preview API -file_preview_parser = reqparse.RequestParser() -file_preview_parser.add_argument( 
+file_preview_parser = reqparse.RequestParser().add_argument( "as_attachment", type=bool, required=False, default=False, location="args", help="Download as attachment" ) diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index fc506ef723..b8e5ed28e4 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -26,25 +26,37 @@ logger = logging.getLogger(__name__) # Define parsers for message APIs -message_list_parser = reqparse.RequestParser() -message_list_parser.add_argument( - "conversation_id", required=True, type=uuid_value, location="args", help="Conversation ID" -) -message_list_parser.add_argument("first_id", type=uuid_value, location="args", help="First message ID for pagination") -message_list_parser.add_argument( - "limit", type=int_range(1, 100), required=False, default=20, location="args", help="Number of messages to return" +message_list_parser = ( + reqparse.RequestParser() + .add_argument("conversation_id", required=True, type=uuid_value, location="args", help="Conversation ID") + .add_argument("first_id", type=uuid_value, location="args", help="First message ID for pagination") + .add_argument( + "limit", + type=int_range(1, 100), + required=False, + default=20, + location="args", + help="Number of messages to return", + ) ) -message_feedback_parser = reqparse.RequestParser() -message_feedback_parser.add_argument( - "rating", type=str, choices=["like", "dislike", None], location="json", help="Feedback rating" +message_feedback_parser = ( + reqparse.RequestParser() + .add_argument("rating", type=str, choices=["like", "dislike", None], location="json", help="Feedback rating") + .add_argument("content", type=str, location="json", help="Feedback content") ) -message_feedback_parser.add_argument("content", type=str, location="json", help="Feedback content") -feedback_list_parser = reqparse.RequestParser() -feedback_list_parser.add_argument("page", type=int, default=1, location="args", help="Page number") -feedback_list_parser.add_argument( - "limit", type=int_range(1, 101), required=False, default=20, location="args", help="Number of feedbacks per page" +feedback_list_parser = ( + reqparse.RequestParser() + .add_argument("page", type=int, default=1, location="args", help="Page number") + .add_argument( + "limit", + type=int_range(1, 101), + required=False, + default=20, + location="args", + help="Number of feedbacks per page", + ) ) diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index e912563bc6..af5eae463d 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -42,32 +42,36 @@ from services.workflow_app_service import WorkflowAppService logger = logging.getLogger(__name__) # Define parsers for workflow APIs -workflow_run_parser = reqparse.RequestParser() -workflow_run_parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") -workflow_run_parser.add_argument("files", type=list, required=False, location="json") -workflow_run_parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") +workflow_run_parser = ( + reqparse.RequestParser() + .add_argument("inputs", type=dict, required=True, nullable=False, location="json") + .add_argument("files", type=list, required=False, location="json") + .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") +) -workflow_log_parser = 
reqparse.RequestParser() -workflow_log_parser.add_argument("keyword", type=str, location="args") -workflow_log_parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args") -workflow_log_parser.add_argument("created_at__before", type=str, location="args") -workflow_log_parser.add_argument("created_at__after", type=str, location="args") -workflow_log_parser.add_argument( - "created_by_end_user_session_id", - type=str, - location="args", - required=False, - default=None, +workflow_log_parser = ( + reqparse.RequestParser() + .add_argument("keyword", type=str, location="args") + .add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args") + .add_argument("created_at__before", type=str, location="args") + .add_argument("created_at__after", type=str, location="args") + .add_argument( + "created_by_end_user_session_id", + type=str, + location="args", + required=False, + default=None, + ) + .add_argument( + "created_by_account", + type=str, + location="args", + required=False, + default=None, + ) + .add_argument("page", type=int_range(1, 99999), default=1, location="args") + .add_argument("limit", type=int_range(1, 100), default=20, location="args") ) -workflow_log_parser.add_argument( - "created_by_account", - type=str, - location="args", - required=False, - default=None, -) -workflow_log_parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") -workflow_log_parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") workflow_run_fields = { "id": fields.String, diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 92bbb76f0f..9d5566919b 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -33,119 +33,118 @@ def _validate_name(name): # Define parsers for dataset operations -dataset_create_parser = reqparse.RequestParser() -dataset_create_parser.add_argument( - "name", - nullable=False, - required=True, - help="type is required. 
Name must be between 1 to 40 characters.", - type=_validate_name, -) -dataset_create_parser.add_argument( - "description", - type=validate_description_length, - nullable=True, - required=False, - default="", -) -dataset_create_parser.add_argument( - "indexing_technique", - type=str, - location="json", - choices=Dataset.INDEXING_TECHNIQUE_LIST, - help="Invalid indexing technique.", -) -dataset_create_parser.add_argument( - "permission", - type=str, - location="json", - choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), - help="Invalid permission.", - required=False, - nullable=False, -) -dataset_create_parser.add_argument( - "external_knowledge_api_id", - type=str, - nullable=True, - required=False, - default="_validate_name", -) -dataset_create_parser.add_argument( - "provider", - type=str, - nullable=True, - required=False, - default="vendor", -) -dataset_create_parser.add_argument( - "external_knowledge_id", - type=str, - nullable=True, - required=False, -) -dataset_create_parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json") -dataset_create_parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") -dataset_create_parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") - -dataset_update_parser = reqparse.RequestParser() -dataset_update_parser.add_argument( - "name", - nullable=False, - help="type is required. Name must be between 1 to 40 characters.", - type=_validate_name, -) -dataset_update_parser.add_argument( - "description", location="json", store_missing=False, type=validate_description_length -) -dataset_update_parser.add_argument( - "indexing_technique", - type=str, - location="json", - choices=Dataset.INDEXING_TECHNIQUE_LIST, - nullable=True, - help="Invalid indexing technique.", -) -dataset_update_parser.add_argument( - "permission", - type=str, - location="json", - choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), - help="Invalid permission.", -) -dataset_update_parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.") -dataset_update_parser.add_argument( - "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider." -) -dataset_update_parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.") -dataset_update_parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.") -dataset_update_parser.add_argument( - "external_retrieval_model", - type=dict, - required=False, - nullable=True, - location="json", - help="Invalid external retrieval model.", -) -dataset_update_parser.add_argument( - "external_knowledge_id", - type=str, - required=False, - nullable=True, - location="json", - help="Invalid external knowledge id.", -) -dataset_update_parser.add_argument( - "external_knowledge_api_id", - type=str, - required=False, - nullable=True, - location="json", - help="Invalid external knowledge api id.", +dataset_create_parser = ( + reqparse.RequestParser() + .add_argument( + "name", + nullable=False, + required=True, + help="type is required. 
Name must be between 1 to 40 characters.", + type=_validate_name, + ) + .add_argument( + "description", + type=validate_description_length, + nullable=True, + required=False, + default="", + ) + .add_argument( + "indexing_technique", + type=str, + location="json", + choices=Dataset.INDEXING_TECHNIQUE_LIST, + help="Invalid indexing technique.", + ) + .add_argument( + "permission", + type=str, + location="json", + choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), + help="Invalid permission.", + required=False, + nullable=False, + ) + .add_argument( + "external_knowledge_api_id", + type=str, + nullable=True, + required=False, + default="_validate_name", + ) + .add_argument( + "provider", + type=str, + nullable=True, + required=False, + default="vendor", + ) + .add_argument( + "external_knowledge_id", + type=str, + nullable=True, + required=False, + ) + .add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json") + .add_argument("embedding_model", type=str, required=False, nullable=True, location="json") + .add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") ) -tag_create_parser = reqparse.RequestParser() -tag_create_parser.add_argument( +dataset_update_parser = ( + reqparse.RequestParser() + .add_argument( + "name", + nullable=False, + help="type is required. Name must be between 1 to 40 characters.", + type=_validate_name, + ) + .add_argument("description", location="json", store_missing=False, type=validate_description_length) + .add_argument( + "indexing_technique", + type=str, + location="json", + choices=Dataset.INDEXING_TECHNIQUE_LIST, + nullable=True, + help="Invalid indexing technique.", + ) + .add_argument( + "permission", + type=str, + location="json", + choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), + help="Invalid permission.", + ) + .add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.") + .add_argument("embedding_model_provider", type=str, location="json", help="Invalid embedding model provider.") + .add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.") + .add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.") + .add_argument( + "external_retrieval_model", + type=dict, + required=False, + nullable=True, + location="json", + help="Invalid external retrieval model.", + ) + .add_argument( + "external_knowledge_id", + type=str, + required=False, + nullable=True, + location="json", + help="Invalid external knowledge id.", + ) + .add_argument( + "external_knowledge_api_id", + type=str, + required=False, + nullable=True, + location="json", + help="Invalid external knowledge api id.", + ) +) + +tag_create_parser = reqparse.RequestParser().add_argument( "name", nullable=False, required=True, @@ -155,32 +154,37 @@ tag_create_parser.add_argument( else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")), ) -tag_update_parser = reqparse.RequestParser() -tag_update_parser.add_argument( - "name", - nullable=False, - required=True, - help="Name must be between 1 to 50 characters.", - type=lambda x: x - if x and 1 <= len(x) <= 50 - else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")), -) -tag_update_parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str) - -tag_delete_parser = 
reqparse.RequestParser() -tag_delete_parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str) - -tag_binding_parser = reqparse.RequestParser() -tag_binding_parser.add_argument( - "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required." -) -tag_binding_parser.add_argument( - "target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required." +tag_update_parser = ( + reqparse.RequestParser() + .add_argument( + "name", + nullable=False, + required=True, + help="Name must be between 1 to 50 characters.", + type=lambda x: x + if x and 1 <= len(x) <= 50 + else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")), + ) + .add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str) ) -tag_unbinding_parser = reqparse.RequestParser() -tag_unbinding_parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") -tag_unbinding_parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") +tag_delete_parser = reqparse.RequestParser().add_argument( + "tag_id", nullable=False, required=True, help="Id of a tag.", type=str +) + +tag_binding_parser = ( + reqparse.RequestParser() + .add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.") + .add_argument( + "target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required." + ) +) + +tag_unbinding_parser = ( + reqparse.RequestParser() + .add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") + .add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") +) @service_api_ns.route("/datasets") diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 961a338bc5..893cd7c923 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -35,37 +35,31 @@ from services.entities.knowledge_entities.knowledge_entities import KnowledgeCon from services.file_service import FileService # Define parsers for document operations -document_text_create_parser = reqparse.RequestParser() -document_text_create_parser.add_argument("name", type=str, required=True, nullable=False, location="json") -document_text_create_parser.add_argument("text", type=str, required=True, nullable=False, location="json") -document_text_create_parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") -document_text_create_parser.add_argument("original_document_id", type=str, required=False, location="json") -document_text_create_parser.add_argument( - "doc_form", type=str, default="text_model", required=False, nullable=False, location="json" -) -document_text_create_parser.add_argument( - "doc_language", type=str, default="English", required=False, nullable=False, location="json" -) -document_text_create_parser.add_argument( - "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" -) -document_text_create_parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json") -document_text_create_parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") -document_text_create_parser.add_argument( - "embedding_model_provider", 
type=str, required=False, nullable=True, location="json" +document_text_create_parser = ( + reqparse.RequestParser() + .add_argument("name", type=str, required=True, nullable=False, location="json") + .add_argument("text", type=str, required=True, nullable=False, location="json") + .add_argument("process_rule", type=dict, required=False, nullable=True, location="json") + .add_argument("original_document_id", type=str, required=False, location="json") + .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json") + .add_argument( + "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" + ) + .add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json") + .add_argument("embedding_model", type=str, required=False, nullable=True, location="json") + .add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") ) -document_text_update_parser = reqparse.RequestParser() -document_text_update_parser.add_argument("name", type=str, required=False, nullable=True, location="json") -document_text_update_parser.add_argument("text", type=str, required=False, nullable=True, location="json") -document_text_update_parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") -document_text_update_parser.add_argument( - "doc_form", type=str, default="text_model", required=False, nullable=False, location="json" +document_text_update_parser = ( + reqparse.RequestParser() + .add_argument("name", type=str, required=False, nullable=True, location="json") + .add_argument("text", type=str, required=False, nullable=True, location="json") + .add_argument("process_rule", type=dict, required=False, nullable=True, location="json") + .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json") + .add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") ) -document_text_update_parser.add_argument( - "doc_language", type=str, default="English", required=False, nullable=False, location="json" -) -document_text_update_parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") @service_api_ns.route( diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index 51420fdd5f..f646f1f4fa 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -15,21 +15,17 @@ from services.entities.knowledge_entities.knowledge_entities import ( from services.metadata_service import MetadataService # Define parsers for metadata APIs -metadata_create_parser = reqparse.RequestParser() -metadata_create_parser.add_argument( - "type", type=str, required=True, nullable=False, location="json", help="Metadata type" -) -metadata_create_parser.add_argument( - "name", type=str, required=True, nullable=False, location="json", help="Metadata name" +metadata_create_parser = ( + reqparse.RequestParser() + .add_argument("type", type=str, required=True, nullable=False, location="json", help="Metadata type") + .add_argument("name", type=str, required=True, nullable=False, location="json", help="Metadata name") ) 
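+# A minimal sketch of the chained parser style this patch adopts throughout
+# (`example_parser` and its arguments are illustrative only, not part of this API):
+#
+#     example_parser = (
+#         reqparse.RequestParser()
+#         .add_argument("page", type=int, default=1, location="args")
+#         .add_argument("limit", type=int, default=20, location="args")
+#     )
+#     args = example_parser.parse_args()
+#
+# The chain works because RequestParser.add_argument returns the parser itself,
+# so the whole expression evaluates to a fully configured parser.
+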
-metadata_update_parser = reqparse.RequestParser() -metadata_update_parser.add_argument( +metadata_update_parser = reqparse.RequestParser().add_argument( "name", type=str, required=True, nullable=False, location="json", help="New metadata name" ) -document_metadata_parser = reqparse.RequestParser() -document_metadata_parser.add_argument( +document_metadata_parser = reqparse.RequestParser().add_argument( "operation_data", type=list, required=True, nullable=False, location="json", help="Metadata operation data" ) diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py index 38891f0180..c177e9180a 100644 --- a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py @@ -91,11 +91,13 @@ class DatasourceNodeRunApi(DatasetApiResource): def post(self, tenant_id: str, dataset_id: str, node_id: str): """Resource for getting datasource plugins.""" # Get query parameter to determine published or draft - parser: RequestParser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") - parser.add_argument("datasource_type", type=str, required=True, location="json") - parser.add_argument("credential_id", type=str, required=False, location="json") - parser.add_argument("is_published", type=bool, required=True, location="json") + parser: RequestParser = ( + reqparse.RequestParser() + .add_argument("inputs", type=dict, required=True, nullable=False, location="json") + .add_argument("datasource_type", type=str, required=True, location="json") + .add_argument("credential_id", type=str, required=False, location="json") + .add_argument("is_published", type=bool, required=True, location="json") + ) args: ParseResult = parser.parse_args() datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(args) @@ -147,19 +149,21 @@ class PipelineRunApi(DatasetApiResource): ) def post(self, tenant_id: str, dataset_id: str): """Resource for running a rag pipeline.""" - parser: RequestParser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") - parser.add_argument("datasource_type", type=str, required=True, location="json") - parser.add_argument("datasource_info_list", type=list, required=True, location="json") - parser.add_argument("start_node_id", type=str, required=True, location="json") - parser.add_argument("is_published", type=bool, required=True, default=True, location="json") - parser.add_argument( - "response_mode", - type=str, - required=True, - choices=["streaming", "blocking"], - default="blocking", - location="json", + parser: RequestParser = ( + reqparse.RequestParser() + .add_argument("inputs", type=dict, required=True, nullable=False, location="json") + .add_argument("datasource_type", type=str, required=True, location="json") + .add_argument("datasource_info_list", type=list, required=True, location="json") + .add_argument("start_node_id", type=str, required=True, location="json") + .add_argument("is_published", type=bool, required=True, default=True, location="json") + .add_argument( + "response_mode", + type=str, + required=True, + choices=["streaming", "blocking"], + default="blocking", + location="json", + ) ) args: ParseResult = parser.parse_args() diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 
acbbf4531b..81abd19fed 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -24,26 +24,34 @@ from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDelete from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError # Define parsers for segment operations -segment_create_parser = reqparse.RequestParser() -segment_create_parser.add_argument("segments", type=list, required=False, nullable=True, location="json") +segment_create_parser = reqparse.RequestParser().add_argument( + "segments", type=list, required=False, nullable=True, location="json" +) -segment_list_parser = reqparse.RequestParser() -segment_list_parser.add_argument("status", type=str, action="append", default=[], location="args") -segment_list_parser.add_argument("keyword", type=str, default=None, location="args") +segment_list_parser = ( + reqparse.RequestParser() + .add_argument("status", type=str, action="append", default=[], location="args") + .add_argument("keyword", type=str, default=None, location="args") +) -segment_update_parser = reqparse.RequestParser() -segment_update_parser.add_argument("segment", type=dict, required=False, nullable=True, location="json") +segment_update_parser = reqparse.RequestParser().add_argument( + "segment", type=dict, required=False, nullable=True, location="json" +) -child_chunk_create_parser = reqparse.RequestParser() -child_chunk_create_parser.add_argument("content", type=str, required=True, nullable=False, location="json") +child_chunk_create_parser = reqparse.RequestParser().add_argument( + "content", type=str, required=True, nullable=False, location="json" +) -child_chunk_list_parser = reqparse.RequestParser() -child_chunk_list_parser.add_argument("limit", type=int, default=20, location="args") -child_chunk_list_parser.add_argument("keyword", type=str, default=None, location="args") -child_chunk_list_parser.add_argument("page", type=int, default=1, location="args") +child_chunk_list_parser = ( + reqparse.RequestParser() + .add_argument("limit", type=int, default=20, location="args") + .add_argument("keyword", type=str, default=None, location="args") + .add_argument("page", type=int, default=1, location="args") +) -child_chunk_update_parser = reqparse.RequestParser() -child_chunk_update_parser.add_argument("content", type=str, required=True, nullable=False, location="json") +child_chunk_update_parser = reqparse.RequestParser().add_argument( + "content", type=str, required=True, nullable=False, location="json" +) @service_api_ns.route("/datasets//documents//segments") diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 2bc068ec75..d7facdbbb3 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -94,9 +94,11 @@ class AppAccessMode(Resource): } ) def get(self): - parser = reqparse.RequestParser() - parser.add_argument("appId", type=str, required=False, location="args") - parser.add_argument("appCode", type=str, required=False, location="args") + parser = ( + reqparse.RequestParser() + .add_argument("appId", type=str, required=False, location="args") + .add_argument("appCode", type=str, required=False, location="args") + ) args = parser.parse_args() features = FeatureService.get_system_features() @@ -155,8 +157,7 @@ class AppWebAuthPermission(Resource): if not features.webapp_auth.enabled: return {"result": True} - parser = reqparse.RequestParser() - parser.add_argument("appId", type=str, required=True, location="args") + parser = 
reqparse.RequestParser().add_argument("appId", type=str, required=True, location="args") args = parser.parse_args() app_id = args["appId"] diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index c1c46891b6..3103851088 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -108,11 +108,13 @@ class TextApi(WebApiResource): def post(self, app_model: App, end_user): """Convert text to audio""" try: - parser = reqparse.RequestParser() - parser.add_argument("message_id", type=str, required=False, location="json") - parser.add_argument("voice", type=str, location="json") - parser.add_argument("text", type=str, location="json") - parser.add_argument("streaming", type=bool, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("message_id", type=str, required=False, location="json") + .add_argument("voice", type=str, location="json") + .add_argument("text", type=str, location="json") + .add_argument("streaming", type=bool, location="json") + ) args = parser.parse_args() message_id = args.get("message_id", None) diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 67ae970388..5e45beffc0 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -67,12 +67,14 @@ class CompletionApi(WebApiResource): if app_model.mode != "completion": raise NotCompletionAppError() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, location="json") - parser.add_argument("query", type=str, location="json", default="") - parser.add_argument("files", type=list, required=False, location="json") - parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - parser.add_argument("retriever_from", type=str, required=False, default="web_app", location="json") + parser = ( + reqparse.RequestParser() + .add_argument("inputs", type=dict, required=True, location="json") + .add_argument("query", type=str, location="json", default="") + .add_argument("files", type=list, required=False, location="json") + .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + .add_argument("retriever_from", type=str, required=False, default="web_app", location="json") + ) args = parser.parse_args() @@ -166,14 +168,16 @@ class ChatApi(WebApiResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, location="json") - parser.add_argument("query", type=str, required=True, location="json") - parser.add_argument("files", type=list, required=False, location="json") - parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - parser.add_argument("conversation_id", type=uuid_value, location="json") - parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json") - parser.add_argument("retriever_from", type=str, required=False, default="web_app", location="json") + parser = ( + reqparse.RequestParser() + .add_argument("inputs", type=dict, required=True, location="json") + .add_argument("query", type=str, required=True, location="json") + .add_argument("files", type=list, required=False, location="json") + .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + .add_argument("conversation_id", type=uuid_value, location="json") + 
.add_argument("parent_message_id", type=uuid_value, required=False, location="json") + .add_argument("retriever_from", type=str, required=False, default="web_app", location="json") + ) args = parser.parse_args() diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 03dd986aed..86e19423e5 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -60,17 +60,19 @@ class ConversationListApi(WebApiResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = reqparse.RequestParser() - parser.add_argument("last_id", type=uuid_value, location="args") - parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args") - parser.add_argument( - "sort_by", - type=str, - choices=["created_at", "-created_at", "updated_at", "-updated_at"], - required=False, - default="-updated_at", - location="args", + parser = ( + reqparse.RequestParser() + .add_argument("last_id", type=uuid_value, location="args") + .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + .add_argument("pinned", type=str, choices=["true", "false", None], location="args") + .add_argument( + "sort_by", + type=str, + choices=["created_at", "-created_at", "updated_at", "-updated_at"], + required=False, + default="-updated_at", + location="args", + ) ) args = parser.parse_args() @@ -161,9 +163,11 @@ class ConversationRenameApi(WebApiResource): conversation_id = str(c_id) - parser = reqparse.RequestParser() - parser.add_argument("name", type=str, required=False, location="json") - parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("name", type=str, required=False, location="json") + .add_argument("auto_generate", type=bool, required=False, default=False, location="json") + ) args = parser.parse_args() try: diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py index cbafd70e99..b9e391e049 100644 --- a/api/controllers/web/forgot_password.py +++ b/api/controllers/web/forgot_password.py @@ -40,9 +40,11 @@ class ForgotPasswordSendEmailApi(Resource): } ) def post(self): - parser = reqparse.RequestParser() - parser.add_argument("email", type=email, required=True, location="json") - parser.add_argument("language", type=str, required=False, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("email", type=email, required=True, location="json") + .add_argument("language", type=str, required=False, location="json") + ) args = parser.parse_args() ip_address = extract_remote_ip(request) @@ -76,10 +78,12 @@ class ForgotPasswordCheckApi(Resource): responses={200: "Token is valid", 400: "Bad request - invalid token format", 401: "Invalid or expired token"} ) def post(self): - parser = reqparse.RequestParser() - parser.add_argument("email", type=str, required=True, location="json") - parser.add_argument("code", type=str, required=True, location="json") - parser.add_argument("token", type=str, required=True, nullable=False, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("email", type=str, required=True, location="json") + .add_argument("code", type=str, required=True, location="json") + .add_argument("token", type=str, required=True, nullable=False, location="json") + ) args = 
parser.parse_args() user_email = args["email"] @@ -127,10 +131,12 @@ class ForgotPasswordResetApi(Resource): } ) def post(self): - parser = reqparse.RequestParser() - parser.add_argument("token", type=str, required=True, nullable=False, location="json") - parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") - parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("token", type=str, required=True, nullable=False, location="json") + .add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") + .add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") + ) args = parser.parse_args() # Validate passwords match diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py index a489101cc9..351f245f4a 100644 --- a/api/controllers/web/login.py +++ b/api/controllers/web/login.py @@ -35,9 +35,11 @@ class LoginApi(Resource): ) def post(self): """Authenticate user and login.""" - parser = reqparse.RequestParser() - parser.add_argument("email", type=email, required=True, location="json") - parser.add_argument("password", type=valid_password, required=True, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("email", type=email, required=True, location="json") + .add_argument("password", type=valid_password, required=True, location="json") + ) args = parser.parse_args() try: @@ -77,9 +79,11 @@ class EmailCodeLoginSendEmailApi(Resource): } ) def post(self): - parser = reqparse.RequestParser() - parser.add_argument("email", type=email, required=True, location="json") - parser.add_argument("language", type=str, required=False, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("email", type=email, required=True, location="json") + .add_argument("language", type=str, required=False, location="json") + ) args = parser.parse_args() if args["language"] is not None and args["language"] == "zh-Hans": @@ -111,10 +115,12 @@ class EmailCodeLoginApi(Resource): } ) def post(self): - parser = reqparse.RequestParser() - parser.add_argument("email", type=str, required=True, location="json") - parser.add_argument("code", type=str, required=True, location="json") - parser.add_argument("token", type=str, required=True, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("email", type=str, required=True, location="json") + .add_argument("code", type=str, required=True, location="json") + .add_argument("token", type=str, required=True, location="json") + ) args = parser.parse_args() user_email = args["email"] diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index a52cccac13..9f9aa4838c 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -93,10 +93,12 @@ class MessageListApi(WebApiResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = reqparse.RequestParser() - parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") - parser.add_argument("first_id", type=uuid_value, location="args") - parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + parser = ( + reqparse.RequestParser() + .add_argument("conversation_id", required=True, type=uuid_value, location="args") + .add_argument("first_id", type=uuid_value, 
location="args") + .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + ) args = parser.parse_args() try: @@ -143,9 +145,11 @@ class MessageFeedbackApi(WebApiResource): def post(self, app_model, end_user, message_id): message_id = str(message_id) - parser = reqparse.RequestParser() - parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") - parser.add_argument("content", type=str, location="json", default=None) + parser = ( + reqparse.RequestParser() + .add_argument("rating", type=str, choices=["like", "dislike", None], location="json") + .add_argument("content", type=str, location="json", default=None) + ) args = parser.parse_args() try: @@ -193,8 +197,7 @@ class MessageMoreLikeThisApi(WebApiResource): message_id = str(message_id) - parser = reqparse.RequestParser() - parser.add_argument( + parser = reqparse.RequestParser().add_argument( "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args" ) args = parser.parse_args() diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index 0983e30b9d..dac4b3da94 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -97,8 +97,7 @@ class RemoteFileUploadApi(WebApiResource): FileTooLargeError: File exceeds size limit UnsupportedFileTypeError: File type not supported """ - parser = reqparse.RequestParser() - parser.add_argument("url", type=str, required=True, help="URL is required") + parser = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required") args = parser.parse_args() url = args["url"] diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index 96f09c8d3c..865f3610a7 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -63,9 +63,11 @@ class SavedMessageListApi(WebApiResource): if app_model.mode != "completion": raise NotCompletionAppError() - parser = reqparse.RequestParser() - parser.add_argument("last_id", type=uuid_value, location="args") - parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + parser = ( + reqparse.RequestParser() + .add_argument("last_id", type=uuid_value, location="args") + .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + ) args = parser.parse_args() return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"]) @@ -92,8 +94,7 @@ class SavedMessageListApi(WebApiResource): if app_model.mode != "completion": raise NotCompletionAppError() - parser = reqparse.RequestParser() - parser.add_argument("message_id", type=uuid_value, required=True, location="json") + parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json") args = parser.parse_args() try: diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 9a980148d9..3cbb07a296 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -58,9 +58,11 @@ class WorkflowRunApi(WebApiResource): if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") - parser.add_argument("files", type=list, required=False, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("inputs", type=dict, required=True, 
nullable=False, location="json") + .add_argument("files", type=list, required=False, location="json") + ) args = parser.parse_args() try: diff --git a/api/tests/unit_tests/services/test_metadata_bug_complete.py b/api/tests/unit_tests/services/test_metadata_bug_complete.py index ee96305070..bbfa9da15e 100644 --- a/api/tests/unit_tests/services/test_metadata_bug_complete.py +++ b/api/tests/unit_tests/services/test_metadata_bug_complete.py @@ -80,9 +80,11 @@ class TestMetadataBugCompleteValidation: def test_4_fixed_api_layer_rejects_null(self, app): """Test Layer 4: Fixed API configuration properly rejects null values.""" # Test Console API create endpoint (fixed) - parser = reqparse.RequestParser() - parser.add_argument("type", type=str, required=True, nullable=False, location="json") - parser.add_argument("name", type=str, required=True, nullable=False, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("type", type=str, required=True, nullable=False, location="json") + .add_argument("name", type=str, required=True, nullable=False, location="json") + ) with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"): with pytest.raises(BadRequest): @@ -100,9 +102,11 @@ class TestMetadataBugCompleteValidation: def test_5_fixed_api_accepts_valid_values(self, app): """Test that fixed API still accepts valid non-null values.""" - parser = reqparse.RequestParser() - parser.add_argument("type", type=str, required=True, nullable=False, location="json") - parser.add_argument("name", type=str, required=True, nullable=False, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("type", type=str, required=True, nullable=False, location="json") + .add_argument("name", type=str, required=True, nullable=False, location="json") + ) with app.test_request_context(json={"type": "string", "name": "valid_name"}, content_type="application/json"): args = parser.parse_args() @@ -112,9 +116,11 @@ class TestMetadataBugCompleteValidation: def test_6_simulated_buggy_behavior(self, app): """Test simulating the original buggy behavior with nullable=True.""" # Simulate the old buggy configuration - buggy_parser = reqparse.RequestParser() - buggy_parser.add_argument("type", type=str, required=True, nullable=True, location="json") - buggy_parser.add_argument("name", type=str, required=True, nullable=True, location="json") + buggy_parser = ( + reqparse.RequestParser() + .add_argument("type", type=str, required=True, nullable=True, location="json") + .add_argument("name", type=str, required=True, nullable=True, location="json") + ) with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"): # This would pass in the buggy version diff --git a/api/tests/unit_tests/services/test_metadata_nullable_bug.py b/api/tests/unit_tests/services/test_metadata_nullable_bug.py index 3d57737943..c8a1a70422 100644 --- a/api/tests/unit_tests/services/test_metadata_nullable_bug.py +++ b/api/tests/unit_tests/services/test_metadata_nullable_bug.py @@ -54,9 +54,11 @@ class TestMetadataNullableBug: def test_api_parser_accepts_null_values(self, app): """Test that API parser configuration incorrectly accepts null values.""" # Simulate the current API parser configuration - parser = reqparse.RequestParser() - parser.add_argument("type", type=str, required=True, nullable=True, location="json") - parser.add_argument("name", type=str, required=True, nullable=True, location="json") + parser = ( + reqparse.RequestParser() + 
.add_argument("type", type=str, required=True, nullable=True, location="json") + .add_argument("name", type=str, required=True, nullable=True, location="json") + ) # Simulate request data with null values with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"): @@ -72,9 +74,11 @@ class TestMetadataNullableBug: def test_integration_bug_scenario(self, app): """Test the complete bug scenario from API to service layer.""" # Step 1: API parser accepts null values (current buggy behavior) - parser = reqparse.RequestParser() - parser.add_argument("type", type=str, required=True, nullable=True, location="json") - parser.add_argument("name", type=str, required=True, nullable=True, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("type", type=str, required=True, nullable=True, location="json") + .add_argument("name", type=str, required=True, nullable=True, location="json") + ) with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"): args = parser.parse_args() @@ -105,9 +109,11 @@ class TestMetadataNullableBug: def test_correct_nullable_false_configuration_works(self, app): """Test that the correct nullable=False configuration works as expected.""" # This tests the FIXED configuration - parser = reqparse.RequestParser() - parser.add_argument("type", type=str, required=True, nullable=False, location="json") - parser.add_argument("name", type=str, required=True, nullable=False, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("type", type=str, required=True, nullable=False, location="json") + .add_argument("name", type=str, required=True, nullable=False, location="json") + ) with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"): # This should fail with BadRequest due to nullable=False From 141ca8904a3fec5f4b59f6e1e46f65880c395fa6 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Sun, 19 Oct 2025 18:56:02 +0800 Subject: [PATCH 36/46] fix(api): ensure JSON responses are properly serialized in ApiTool (#27097) Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- api/core/tools/custom_tool/tool.py | 10 +- api/tests/unit_tests/tools/test_api_tool.py | 249 ++++++++++++++++++++ 2 files changed, 255 insertions(+), 4 deletions(-) create mode 100644 api/tests/unit_tests/tools/test_api_tool.py diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index f18f638f2d..54c266ffcc 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -395,11 +395,13 @@ class ApiTool(Tool): parsed_response = self.validate_and_parse_response(response) # assemble invoke message based on response type - if parsed_response.is_json and isinstance(parsed_response.content, dict): - yield self.create_json_message(parsed_response.content) + if parsed_response.is_json: + if isinstance(parsed_response.content, dict): + yield self.create_json_message(parsed_response.content) - # FIXES: https://github.com/langgenius/dify/pull/23456#issuecomment-3182413088 - # We need never break the original flows + # The yield below must be preserved to keep backward compatibility. 
+ # + # ref: https://github.com/langgenius/dify/pull/23456#issuecomment-3182413088 yield self.create_text_message(response.text) else: # Convert to string if needed and create text message diff --git a/api/tests/unit_tests/tools/test_api_tool.py b/api/tests/unit_tests/tools/test_api_tool.py new file mode 100644 index 0000000000..4d5683dcbd --- /dev/null +++ b/api/tests/unit_tests/tools/test_api_tool.py @@ -0,0 +1,249 @@ +import json +import operator +from typing import TypeVar +from unittest.mock import Mock, patch + +import httpx +import pytest + +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.custom_tool.tool import ApiTool +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_bundle import ApiToolBundle +from core.tools.entities.tool_entities import ( + ToolEntity, + ToolIdentity, + ToolInvokeMessage, +) + +_T = TypeVar("_T") + + +def _get_message_by_type(msgs: list[ToolInvokeMessage], msg_type: type[_T]) -> ToolInvokeMessage | None: + return next((i for i in msgs if isinstance(i.message, msg_type)), None) + + +class TestApiToolInvoke: + """Test suite for ApiTool._invoke method to ensure JSON responses are properly serialized.""" + + def setup_method(self): + """Setup test fixtures.""" + # Create a mock tool entity + self.mock_tool_identity = ToolIdentity( + author="test", + name="test_api_tool", + label=I18nObject(en_US="Test API Tool", zh_Hans="测试API工具"), + provider="test_provider", + ) + self.mock_tool_entity = ToolEntity(identity=self.mock_tool_identity) + + # Create a mock API bundle + self.mock_api_bundle = ApiToolBundle( + server_url="https://api.example.com/test", + method="GET", + openapi={}, + operation_id="test_operation", + parameters=[], + author="test_author", + ) + + # Create a mock runtime + self.mock_runtime = Mock(spec=ToolRuntime) + self.mock_runtime.credentials = {"auth_type": "none"} + + # Create the ApiTool instance + self.api_tool = ApiTool( + entity=self.mock_tool_entity, + api_bundle=self.mock_api_bundle, + runtime=self.mock_runtime, + provider_id="test_provider", + ) + + @patch("core.tools.custom_tool.tool.ssrf_proxy.get") + def test_invoke_with_json_response_creates_text_message_with_serialized_json(self, mock_get: Mock) -> None: + """Test that when upstream returns JSON, the output Text message contains JSON-serialized string.""" + # Setup mock response with JSON content + json_response_data = { + "key": "value", + "number": 123, + "nested": {"inner": "data"}, + } + mock_response = Mock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.content = json.dumps(json_response_data).encode("utf-8") + mock_response.json.return_value = json_response_data + mock_response.text = json.dumps(json_response_data) + mock_response.headers = {"content-type": "application/json"} + mock_get.return_value = mock_response + + # Invoke the tool + result_generator = self.api_tool._invoke(user_id="test_user", tool_parameters={}) + + # Get the result from the generator + result = list(result_generator) + assert len(result) == 2 + + # Verify _invoke yields text message + text_message = _get_message_by_type(result, ToolInvokeMessage.TextMessage) + assert text_message is not None, "_invoke should yield a text message" + assert isinstance(text_message, ToolInvokeMessage) + assert text_message.type == ToolInvokeMessage.MessageType.TEXT + assert text_message.message is not None + # Verify the text contains the JSON-serialized string + # Check if message is a TextMessage + assert 
isinstance(text_message.message, ToolInvokeMessage.TextMessage)
+        # Verify it's a valid JSON string that equals the mock response payload
+        parsed_back = json.loads(text_message.message.text)
+        assert parsed_back == json_response_data
+
+        # Verify _invoke yields a JSON message
+        json_message = _get_message_by_type(result, ToolInvokeMessage.JsonMessage)
+        assert json_message is not None, "_invoke should yield a JSON message"
+        assert isinstance(json_message, ToolInvokeMessage)
+        assert json_message.type == ToolInvokeMessage.MessageType.JSON
+        assert json_message.message is not None
+
+        assert isinstance(json_message.message, ToolInvokeMessage.JsonMessage)
+        assert json_message.message.json_object == json_response_data
+
+    @patch("core.tools.custom_tool.tool.ssrf_proxy.get")
+    @pytest.mark.parametrize(
+        "test_case",
+        [
+            (
+                "array",
+                [
+                    {"id": 1, "name": "Item 1", "active": True},
+                    {"id": 2, "name": "Item 2", "active": False},
+                    {"id": 3, "name": "项目 3", "active": True},
+                ],
+            ),
+            (
+                "string",
+                "string",
+            ),
+            (
+                "number",
+                123.456,
+            ),
+            (
+                "boolean",
+                True,
+            ),
+            (
+                "null",
+                None,
+            ),
+        ],
+        ids=operator.itemgetter(0),
+    )
+    def test_invoke_with_non_dict_json_response_creates_text_message_with_serialized_json(
+        self, mock_get: Mock, test_case
+    ) -> None:
+        """Test that when upstream returns non-dict JSON, the output Text message contains the JSON-serialized string."""
+        # Setup mock response with non-dict JSON content
+        _, json_value = test_case
+        mock_response = Mock(spec=httpx.Response)
+        mock_response.status_code = 200
+        mock_response.content = json.dumps(json_value).encode("utf-8")
+        mock_response.json.return_value = json_value
+        mock_response.text = json.dumps(json_value)
+        mock_response.headers = {"content-type": "application/json"}
+        mock_get.return_value = mock_response
+
+        # Invoke the tool
+        result_generator = self.api_tool._invoke(user_id="test_user", tool_parameters={})
+
+        # Get the result from the generator
+        result = list(result_generator)
+        assert len(result) == 1
+
+        # Verify _invoke yields a text message
+        text_message = _get_message_by_type(result, ToolInvokeMessage.TextMessage)
+        assert text_message is not None, "_invoke should yield a text message containing the serialized JSON."
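+        # The serialized text is checked below to round-trip through json.loads
+        # back to the original parametrized value.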
+        assert isinstance(text_message, ToolInvokeMessage)
+        assert text_message.type == ToolInvokeMessage.MessageType.TEXT
+        assert text_message.message is not None
+        # Verify the text contains the JSON-serialized string
+        # Check if message is a TextMessage
+        assert isinstance(text_message.message, ToolInvokeMessage.TextMessage)
+        # Verify it's a valid JSON string
+        parsed_back = json.loads(text_message.message.text)
+        assert parsed_back == json_value
+
+        # Verify _invoke yields no JSON message
+        json_message = _get_message_by_type(result, ToolInvokeMessage.JsonMessage)
+        assert json_message is None, "_invoke should not yield a JSON message for a non-dict JSON response"
+
+    @patch("core.tools.custom_tool.tool.ssrf_proxy.get")
+    def test_invoke_with_text_response_creates_text_message_with_original_text(self, mock_get: Mock) -> None:
+        """Test that when upstream returns plain text, the output Text message contains the original text."""
+        # Setup mock response with plain text content
+        text_response_data = "This is a plain text response"
+        mock_response = Mock(spec=httpx.Response)
+        mock_response.status_code = 200
+        mock_response.content = text_response_data.encode("utf-8")
+        mock_response.json.side_effect = json.JSONDecodeError("Expecting value", "doc", 0)
+        mock_response.text = text_response_data
+        mock_response.headers = {"content-type": "text/plain"}
+        mock_get.return_value = mock_response
+
+        # Invoke the tool
+        result_generator = self.api_tool._invoke(user_id="test_user", tool_parameters={})
+
+        # Get the result from the generator
+        result = list(result_generator)
+        assert len(result) == 1
+
+        # Verify it's a text message with the original text
+        message = result[0]
+        assert isinstance(message, ToolInvokeMessage)
+        assert message.type == ToolInvokeMessage.MessageType.TEXT
+        assert message.message is not None
+        # Check if message is a TextMessage
+        assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+        assert message.message.text == text_response_data
+
+    @patch("core.tools.custom_tool.tool.ssrf_proxy.get")
+    def test_invoke_with_empty_response(self, mock_get: Mock) -> None:
+        """Test that empty responses are handled correctly."""
+        # Setup mock response with empty content
+        mock_response = Mock(spec=httpx.Response)
+        mock_response.status_code = 200
+        mock_response.content = b""
+        mock_response.headers = {"content-type": "application/json"}
+        mock_get.return_value = mock_response
+
+        # Invoke the tool
+        result_generator = self.api_tool._invoke(user_id="test_user", tool_parameters={})
+
+        # Get the result from the generator
+        result = list(result_generator)
+        assert len(result) == 1
+
+        # Verify it's a text message with the empty response message
+        message = result[0]
+        assert isinstance(message, ToolInvokeMessage)
+        assert message.type == ToolInvokeMessage.MessageType.TEXT
+        assert message.message is not None
+        # Check if message is a TextMessage
+        assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+        assert "Empty response from the tool" in message.message.text
+
+    @patch("core.tools.custom_tool.tool.ssrf_proxy.get")
+    def test_invoke_with_error_response(self, mock_get: Mock) -> None:
+        """Test that error responses are handled correctly."""
+        # Setup mock response with error status code
+        mock_response = Mock(spec=httpx.Response)
+        mock_response.status_code = 404
+        mock_response.text = "Not Found"
+        mock_get.return_value = mock_response
+
+        result_generator = self.api_tool._invoke(user_id="test_user", tool_parameters={})
+
+        # Invoke the tool and expect an error
+        with
pytest.raises(Exception) as exc_info: + list(result_generator) # Consume the generator to trigger the error + + # Verify the error message + assert "Request failed with status code 404" in str(exc_info.value) From 9a5f21462361c5154a5e785c906bb22d1b3c6931 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sun, 19 Oct 2025 21:29:04 +0800 Subject: [PATCH 37/46] refactor: replace localStorage with HTTP-only cookies for auth tokens (#24365) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com> Signed-off-by: lyzno1 Signed-off-by: kenwoodjw Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Yunlu Wen Co-authored-by: Joel Co-authored-by: GareArc Co-authored-by: NFish Co-authored-by: Davide Delbianco Co-authored-by: minglu7 <1347866672@qq.com> Co-authored-by: Ponder Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: heyszt <270985384@qq.com> Co-authored-by: Asuka Minato Co-authored-by: Guangdong Liu Co-authored-by: Eric Guo Co-authored-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com> Co-authored-by: XlKsyt Co-authored-by: Dhruv Gorasiya <80987415+DhruvGorasiya@users.noreply.github.com> Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com> Co-authored-by: hj24 Co-authored-by: GuanMu Co-authored-by: 非法操作 Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Tonlo <123lzs123@gmail.com> Co-authored-by: Yusuke Yamada Co-authored-by: Novice Co-authored-by: kenwoodjw Co-authored-by: Ademílson Tonato Co-authored-by: znn Co-authored-by: yangzheli <43645580+yangzheli@users.noreply.github.com> --- api/constants/__init__.py | 9 + api/controllers/console/admin.py | 15 +- api/controllers/console/auth/login.py | 81 +++++-- api/controllers/console/auth/oauth.py | 14 +- .../console/explore/installed_app.py | 12 +- api/controllers/console/explore/wraps.py | 4 +- api/controllers/web/app.py | 33 ++- api/controllers/web/login.py | 90 +++++++- api/controllers/web/passport.py | 42 ++-- api/controllers/web/wraps.py | 32 ++- api/extensions/ext_blueprints.py | 9 +- api/extensions/ext_login.py | 15 +- api/libs/external_api.py | 15 ++ api/libs/login.py | 4 + api/libs/token.py | 208 ++++++++++++++++++ api/services/account_service.py | 8 +- api/services/enterprise/enterprise_service.py | 10 +- api/services/webapp_auth_service.py | 3 +- .../services/test_webapp_auth_service.py | 9 +- .../controllers/console/auth/test_oauth.py | 20 +- .../unit_tests/libs/test_external_api.py | 65 ++++++ api/tests/unit_tests/libs/test_login.py | 11 + api/tests/unit_tests/libs/test_token.py | 23 ++ .../components/authenticated-layout.tsx | 9 +- web/app/(shareLayout)/components/splash.tsx | 88 +++++--- .../webapp-signin/check-code/page.tsx | 8 +- .../components/mail-and-password-auth.tsx | 23 +- web/app/(shareLayout)/webapp-signin/page.tsx | 9 +- .../account-page/email-change-modal.tsx | 11 +- web/app/account/(commonLayout)/avatar.tsx | 11 +- .../delete-account/components/feed-back.tsx | 11 +- web/app/account/oauth/authorize/layout.tsx | 19 +- 
web/app/account/oauth/authorize/page.tsx | 15 +- .../access-control-dialog.tsx | 4 +- .../add-member-or-group-pop.tsx | 2 +- .../base/chat/chat-with-history/index.tsx | 33 --- .../header/account-dropdown/index.tsx | 11 +- .../hooks/use-nodes-sync-draft.ts | 2 +- .../share/text-generation/menu-dropdown.tsx | 9 +- web/app/components/share/utils.ts | 56 ----- web/app/components/swr-initializer.tsx | 28 +-- .../hooks/use-nodes-sync-draft.ts | 2 +- web/app/education-apply/user-info.tsx | 11 +- web/app/install/installForm.tsx | 2 - web/app/signin/check-code/page.tsx | 2 - .../components/mail-and-password-auth.tsx | 3 +- web/app/signin/invite-settings/page.tsx | 3 +- web/app/signin/normal-form.tsx | 18 +- web/app/signup/set-password/page.tsx | 4 +- web/config/index.ts | 11 + web/context/web-app-context.tsx | 17 +- web/models/app.ts | 57 ----- web/service/base.ts | 43 ++-- web/service/common.ts | 10 +- web/service/fetch.ts | 49 ++--- web/service/refresh-token.ts | 8 +- web/service/share.ts | 14 +- web/service/use-common.ts | 22 +- web/service/use-share.ts | 2 + web/service/webapp-auth.ts | 53 +++++ 60 files changed, 879 insertions(+), 533 deletions(-) create mode 100644 api/libs/token.py create mode 100644 api/tests/unit_tests/libs/test_token.py create mode 100644 web/service/webapp-auth.ts diff --git a/api/constants/__init__.py b/api/constants/__init__.py index 9141fbea95..248cdfc09f 100644 --- a/api/constants/__init__.py +++ b/api/constants/__init__.py @@ -55,3 +55,12 @@ else: "properties", } DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions) + +COOKIE_NAME_ACCESS_TOKEN = "access_token" +COOKIE_NAME_REFRESH_TOKEN = "refresh_token" +COOKIE_NAME_PASSPORT = "passport" +COOKIE_NAME_CSRF_TOKEN = "csrf_token" + +HEADER_NAME_CSRF_TOKEN = "X-CSRF-Token" +HEADER_NAME_APP_CODE = "X-App-Code" +HEADER_NAME_PASSPORT = "X-App-Passport" diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index ef96184678..2c4d8709eb 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -15,6 +15,7 @@ from constants.languages import supported_language from controllers.console import api, console_ns from controllers.console.wraps import only_edition_cloud from extensions.ext_database import db +from libs.token import extract_access_token from models.model import App, InstalledApp, RecommendedApp @@ -24,19 +25,9 @@ def admin_required(view: Callable[P, R]): if not dify_config.ADMIN_API_KEY: raise Unauthorized("API key is invalid.") - auth_header = request.headers.get("Authorization") - if auth_header is None: + auth_token = extract_access_token(request) + if not auth_token: raise Unauthorized("Authorization header is missing.") - - if " " not in auth_header: - raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - - auth_scheme, auth_token = auth_header.split(None, 1) - auth_scheme = auth_scheme.lower() - - if auth_scheme != "bearer": - raise Unauthorized("Invalid Authorization header format. 
Expected 'Bearer ' format.") - if auth_token != dify_config.ADMIN_API_KEY: raise Unauthorized("API key is invalid.") diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 3696c88346..277f9a60a8 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -1,5 +1,5 @@ import flask_login -from flask import request +from flask import make_response, request from flask_restx import Resource, reqparse import services @@ -25,6 +25,16 @@ from controllers.console.wraps import email_password_login_enabled, setup_requir from events.tenant_event import tenant_was_created from libs.helper import email, extract_remote_ip from libs.login import current_account_with_tenant +from libs.token import ( + clear_access_token_from_cookie, + clear_csrf_token_from_cookie, + clear_refresh_token_from_cookie, + extract_access_token, + extract_csrf_token, + set_access_token_to_cookie, + set_csrf_token_to_cookie, + set_refresh_token_to_cookie, +) from services.account_service import AccountService, RegisterService, TenantService from services.billing_service import BillingService from services.errors.account import AccountRegisterError @@ -89,20 +99,36 @@ class LoginApi(Resource): token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) AccountService.reset_login_error_rate_limit(args["email"]) - return {"result": "success", "data": token_pair.model_dump()} + + # Create response with cookies instead of returning tokens in body + response = make_response({"result": "success"}) + + set_access_token_to_cookie(request, response, token_pair.access_token) + set_refresh_token_to_cookie(request, response, token_pair.refresh_token) + set_csrf_token_to_cookie(request, response, token_pair.csrf_token) + + return response @console_ns.route("/logout") class LogoutApi(Resource): @setup_required - def get(self): + def post(self): current_user, _ = current_account_with_tenant() account = current_user if isinstance(account, flask_login.AnonymousUserMixin): - return {"result": "success"} - AccountService.logout(account=account) - flask_login.logout_user() - return {"result": "success"} + response = make_response({"result": "success"}) + else: + AccountService.logout(account=account) + flask_login.logout_user() + response = make_response({"result": "success"}) + + # Clear cookies on logout + clear_access_token_from_cookie(response) + clear_refresh_token_from_cookie(response) + clear_csrf_token_from_cookie(response) + + return response @console_ns.route("/reset-password") @@ -227,17 +253,46 @@ class EmailCodeLoginApi(Resource): raise WorkspacesLimitExceeded() token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) AccountService.reset_login_error_rate_limit(args["email"]) - return {"result": "success", "data": token_pair.model_dump()} + + # Create response with cookies instead of returning tokens in body + response = make_response({"result": "success"}) + + set_csrf_token_to_cookie(request, response, token_pair.csrf_token) + # Set HTTP-only secure cookies for tokens + set_access_token_to_cookie(request, response, token_pair.access_token) + set_refresh_token_to_cookie(request, response, token_pair.refresh_token) + return response @console_ns.route("/refresh-token") class RefreshTokenApi(Resource): def post(self): - parser = reqparse.RequestParser().add_argument("refresh_token", type=str, required=True, location="json") - args = parser.parse_args() + # Get refresh token from cookie instead of request 
body + refresh_token = request.cookies.get("refresh_token") + + if not refresh_token: + return {"result": "fail", "message": "No refresh token provided"}, 401 try: - new_token_pair = AccountService.refresh_token(args["refresh_token"]) - return {"result": "success", "data": new_token_pair.model_dump()} + new_token_pair = AccountService.refresh_token(refresh_token) + + # Create response with new cookies + response = make_response({"result": "success"}) + + # Update cookies with new tokens + set_csrf_token_to_cookie(request, response, new_token_pair.csrf_token) + set_access_token_to_cookie(request, response, new_token_pair.access_token) + set_refresh_token_to_cookie(request, response, new_token_pair.refresh_token) + return response except Exception as e: - return {"result": "fail", "data": str(e)}, 401 + return {"result": "fail", "message": str(e)}, 401 + + +# this api helps frontend to check whether user is authenticated +# TODO: remove in the future. frontend should redirect to login page by catching 401 status +@console_ns.route("/login/status") +class LoginStatus(Resource): + def get(self): + token = extract_access_token(request) + csrf_token = extract_csrf_token(request) + return {"logged_in": bool(token) and bool(csrf_token)} diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 52459ad5eb..29653b32ec 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -14,6 +14,11 @@ from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from libs.helper import extract_remote_ip from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo +from libs.token import ( + set_access_token_to_cookie, + set_csrf_token_to_cookie, + set_refresh_token_to_cookie, +) from models import Account, AccountStatus from services.account_service import AccountService, RegisterService, TenantService from services.billing_service import BillingService @@ -152,9 +157,12 @@ class OAuthCallback(Resource): ip_address=extract_remote_ip(request), ) - return redirect( - f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}" - ) + response = redirect(f"{dify_config.CONSOLE_WEB_URL}") + + set_access_token_to_cookie(request, response, token_pair.access_token) + set_refresh_token_to_cookie(request, response, token_pair.refresh_token) + set_csrf_token_to_cookie(request, response, token_pair.csrf_token) + return response def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Account | None: diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index dec84b68f4..3c95779475 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -15,7 +15,6 @@ from libs.datetime_utils import naive_utc_now from libs.login import current_account_with_tenant, login_required from models import App, InstalledApp, RecommendedApp from services.account_service import TenantService -from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService @@ -67,31 +66,26 @@ class InstalledAppsListApi(Resource): # Pre-filter out apps without setting or with sso_verified filtered_installed_apps = [] - app_id_to_app_code = {} for installed_app in installed_app_list: app_id = installed_app["app"].id webapp_setting = webapp_settings.get(app_id) if not 
webapp_setting or webapp_setting.access_mode == "sso_verified":
                 continue
-            app_code = AppService.get_app_code_by_id(str(app_id))
-            app_id_to_app_code[app_id] = app_code
             filtered_installed_apps.append(installed_app)

-        app_codes = list(app_id_to_app_code.values())
-
         # Batch permission check
+        app_ids = [installed_app["app"].id for installed_app in filtered_installed_apps]
         permissions = EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps(
             user_id=user_id,
-            app_codes=app_codes,
+            app_ids=app_ids,
         )

         # Keep only allowed apps
         res = []
         for installed_app in filtered_installed_apps:
             app_id = installed_app["app"].id
-            app_code = app_id_to_app_code[app_id]
-            if permissions.get(app_code):
+            if permissions.get(app_id):
                 res.append(installed_app)

         installed_app_list = res
diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py
index df4eed18eb..2a97d312aa 100644
--- a/api/controllers/console/explore/wraps.py
+++ b/api/controllers/console/explore/wraps.py
@@ -10,7 +10,6 @@ from controllers.console.wraps import account_initialization_required
 from extensions.ext_database import db
 from libs.login import current_account_with_tenant, login_required
 from models import InstalledApp
-from services.app_service import AppService
 from services.enterprise.enterprise_service import EnterpriseService
 from services.feature_service import FeatureService

@@ -56,10 +55,9 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
     feature = FeatureService.get_system_features()
     if feature.webapp_auth.enabled:
         app_id = installed_app.app_id
-        app_code = AppService.get_app_code_by_id(app_id)
         res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
             user_id=str(current_user.id),
-            app_code=app_code,
+            app_id=app_id,
         )
         if not res:
             raise AppAccessDeniedError()
diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py
index d7facdbbb3..60193f5f15 100644
--- a/api/controllers/web/app.py
+++ b/api/controllers/web/app.py
@@ -4,12 +4,14 @@ from flask import request
 from flask_restx import Resource, marshal_with, reqparse
 from werkzeug.exceptions import Unauthorized

+from constants import HEADER_NAME_APP_CODE
 from controllers.common import fields
 from controllers.web import web_ns
 from controllers.web.error import AppUnavailableError
 from controllers.web.wraps import WebApiResource
 from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
 from libs.passport import PassportService
+from libs.token import extract_webapp_passport
 from models.model import App, AppMode
 from services.app_service import AppService
 from services.enterprise.enterprise_service import EnterpriseService
@@ -133,18 +135,19 @@ class AppWebAuthPermission(Resource):
     )
     def get(self):
         user_id = "visitor"
+        app_code = request.headers.get(HEADER_NAME_APP_CODE)
+        app_id = request.args.get("appId")
+        if not app_id or not app_code:
+            raise ValueError("Both the appId query parameter and the X-App-Code header must be provided")
+
+        require_permission_check = WebAppAuthService.is_app_require_permission_check(app_id=app_id)
+        if not require_permission_check:
+            return {"result": True}
+
         try:
-            auth_header = request.headers.get("Authorization")
-            if auth_header is None:
-                raise Unauthorized("Authorization header is missing.")
-            if " " not in auth_header:
-                raise Unauthorized("Invalid Authorization header format. 
Expected 'Bearer ' format.") - - auth_scheme, tk = auth_header.split(None, 1) - auth_scheme = auth_scheme.lower() - if auth_scheme != "bearer": - raise Unauthorized("Authorization scheme must be 'Bearer'") - + tk = extract_webapp_passport(app_code, request) + if not tk: + raise Unauthorized("Access token is missing.") decoded = PassportService().verify(tk) user_id = decoded.get("user_id", "visitor") except Unauthorized: @@ -157,13 +160,7 @@ class AppWebAuthPermission(Resource): if not features.webapp_auth.enabled: return {"result": True} - parser = reqparse.RequestParser().add_argument("appId", type=str, required=True, location="args") - args = parser.parse_args() - - app_id = args["appId"] - app_code = AppService.get_app_code_by_id(app_id) - res = True if WebAppAuthService.is_app_require_permission_check(app_id=app_id): - res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code) + res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_id) return {"result": res} diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py index 351f245f4a..f213fd8c90 100644 --- a/api/controllers/web/login.py +++ b/api/controllers/web/login.py @@ -1,7 +1,9 @@ +from flask import make_response, request from flask_restx import Resource, reqparse from jwt import InvalidTokenError import services +from configs import dify_config from controllers.console.auth.error import ( AuthenticationFailedError, EmailCodeError, @@ -10,9 +12,16 @@ from controllers.console.auth.error import ( from controllers.console.error import AccountBannedError from controllers.console.wraps import only_edition_enterprise, setup_required from controllers.web import web_ns +from controllers.web.wraps import decode_jwt_token from libs.helper import email +from libs.passport import PassportService from libs.password import valid_password +from libs.token import ( + clear_access_token_from_cookie, + extract_access_token, +) from services.account_service import AccountService +from services.app_service import AppService from services.webapp_auth_service import WebAppAuthService @@ -52,17 +61,75 @@ class LoginApi(Resource): raise AuthenticationFailedError() token = WebAppAuthService.login(account=account) - return {"result": "success", "data": {"access_token": token}} + response = make_response({"result": "success", "data": {"access_token": token}}) + # set_access_token_to_cookie(request, response, token, samesite="None", httponly=False) + return response -# class LogoutApi(Resource): -# @setup_required -# def get(self): -# account = cast(Account, flask_login.current_user) -# if isinstance(account, flask_login.AnonymousUserMixin): -# return {"result": "success"} -# flask_login.logout_user() -# return {"result": "success"} +# this api helps frontend to check whether user is authenticated +# TODO: remove in the future. 
frontend should redirect to login page by catching 401 status +@web_ns.route("/login/status") +class LoginStatusApi(Resource): + @setup_required + @web_ns.doc("web_app_login_status") + @web_ns.doc(description="Check login status") + @web_ns.doc( + responses={ + 200: "Login status", + 401: "Login status", + } + ) + def get(self): + app_code = request.args.get("app_code") + token = extract_access_token(request) + if not app_code: + return { + "logged_in": bool(token), + "app_logged_in": False, + } + app_id = AppService.get_app_id_by_code(app_code) + is_public = not dify_config.ENTERPRISE_ENABLED or not WebAppAuthService.is_app_require_permission_check( + app_id=app_id + ) + user_logged_in = False + + if is_public: + user_logged_in = True + else: + try: + PassportService().verify(token=token) + user_logged_in = True + except Exception: + user_logged_in = False + + try: + _ = decode_jwt_token(app_code=app_code) + app_logged_in = True + except Exception: + app_logged_in = False + + return { + "logged_in": user_logged_in, + "app_logged_in": app_logged_in, + } + + +@web_ns.route("/logout") +class LogoutApi(Resource): + @setup_required + @web_ns.doc("web_app_logout") + @web_ns.doc(description="Logout user from web application") + @web_ns.doc( + responses={ + 200: "Logout successful", + } + ) + def post(self): + response = make_response({"result": "success"}) + # enterprise SSO sets same site to None in https deployment + # so we need to logout by calling api + clear_access_token_from_cookie(response, samesite="None") + return response @web_ns.route("/email-code-login") @@ -96,7 +163,6 @@ class EmailCodeLoginSendEmailApi(Resource): raise AuthenticationFailedError() else: token = WebAppAuthService.send_email_code_login_email(account=account, language=language) - return {"result": "success", "data": token} @@ -142,4 +208,6 @@ class EmailCodeLoginApi(Resource): token = WebAppAuthService.login(account=account) AccountService.reset_login_error_rate_limit(args["email"]) - return {"result": "success", "data": {"access_token": token}} + response = make_response({"result": "success", "data": {"access_token": token}}) + # set_access_token_to_cookie(request, response, token, samesite="None", httponly=False) + return response diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index 7190f06426..776b743e92 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -1,17 +1,20 @@ import uuid from datetime import UTC, datetime, timedelta -from flask import request +from flask import make_response, request from flask_restx import Resource from sqlalchemy import func, select from werkzeug.exceptions import NotFound, Unauthorized from configs import dify_config +from constants import HEADER_NAME_APP_CODE from controllers.web import web_ns from controllers.web.error import WebAppAuthRequiredError from extensions.ext_database import db from libs.passport import PassportService +from libs.token import extract_access_token from models.model import App, EndUser, Site +from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService from services.webapp_auth_service import WebAppAuthService, WebAppAuthType @@ -32,15 +35,15 @@ class PassportResource(Resource): ) def get(self): system_features = FeatureService.get_system_features() - app_code = request.headers.get("X-App-Code") + app_code = request.headers.get(HEADER_NAME_APP_CODE) user_id = request.args.get("user_id") - 
web_app_access_token = request.args.get("web_app_access_token") + access_token = extract_access_token(request) if app_code is None: raise Unauthorized("X-App-Code header is missing.") - + app_id = AppService.get_app_id_by_code(app_code) # exchange token for enterprise logined web user - enterprise_user_decoded = decode_enterprise_webapp_user_id(web_app_access_token) + enterprise_user_decoded = decode_enterprise_webapp_user_id(access_token) if enterprise_user_decoded: # a web user has already logged in, exchange a token for this app without redirecting to the login page return exchange_token_for_existing_web_user( @@ -48,7 +51,7 @@ class PassportResource(Resource): ) if system_features.webapp_auth.enabled: - app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code) + app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id) if not app_settings or not app_settings.access_mode == "public": raise WebAppAuthRequiredError() @@ -99,9 +102,12 @@ class PassportResource(Resource): tk = PassportService().issue(payload) - return { - "access_token": tk, - } + response = make_response( + { + "access_token": tk, + } + ) + return response def decode_enterprise_webapp_user_id(jwt_token: str | None): @@ -189,9 +195,12 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: "exp": exp, } token: str = PassportService().issue(payload) - return { - "access_token": token, - } + resp = make_response( + { + "access_token": token, + } + ) + return resp def _exchange_for_public_app_token(app_model, site, token_decoded): @@ -224,9 +233,12 @@ def _exchange_for_public_app_token(app_model, site, token_decoded): tk = PassportService().issue(payload) - return { - "access_token": tk, - } + resp = make_response( + { + "access_token": tk, + } + ) + return resp def generate_session_id(): diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index ba03c4eae4..9efd9f25d1 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -9,10 +9,13 @@ from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, NotFound, Unauthorized +from constants import HEADER_NAME_APP_CODE from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError from extensions.ext_database import db from libs.passport import PassportService +from libs.token import extract_webapp_passport from models.model import App, EndUser, Site +from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService, WebAppSettings from services.feature_service import FeatureService from services.webapp_auth_service import WebAppAuthService @@ -35,22 +38,14 @@ def validate_jwt_token(view: Callable[Concatenate[App, EndUser, P], R] | None = return decorator -def decode_jwt_token(): +def decode_jwt_token(app_code: str | None = None): system_features = FeatureService.get_system_features() - app_code = str(request.headers.get("X-App-Code")) + if not app_code: + app_code = str(request.headers.get(HEADER_NAME_APP_CODE)) try: - auth_header = request.headers.get("Authorization") - if auth_header is None: - raise Unauthorized("Authorization header is missing.") - - if " " not in auth_header: - raise Unauthorized("Invalid Authorization header format. 
Expected 'Bearer ' format.") - - auth_scheme, tk = auth_header.split(None, 1) - auth_scheme = auth_scheme.lower() - - if auth_scheme != "bearer": - raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") + tk = extract_webapp_passport(app_code, request) + if not tk: + raise Unauthorized("App token is missing.") decoded = PassportService().verify(tk) app_code = decoded.get("app_code") app_id = decoded.get("app_id") @@ -72,7 +67,8 @@ def decode_jwt_token(): app_web_auth_enabled = False webapp_settings = None if system_features.webapp_auth.enabled: - webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code) + app_id = AppService.get_app_id_by_code(app_code) + webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id) if not webapp_settings: raise NotFound("Web app settings not found.") app_web_auth_enabled = webapp_settings.access_mode != "public" @@ -87,8 +83,9 @@ def decode_jwt_token(): if system_features.webapp_auth.enabled: if not app_code: raise Unauthorized("Please re-login to access the web app.") + app_id = AppService.get_app_id_by_code(app_code) app_web_auth_enabled = ( - EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=str(app_code)).access_mode != "public" + EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id).access_mode != "public" ) if app_web_auth_enabled: raise WebAppAuthRequiredError() @@ -129,7 +126,8 @@ def _validate_user_accessibility( raise WebAppAuthRequiredError("Web app settings not found.") if WebAppAuthService.is_app_require_permission_check(access_mode=webapp_settings.access_mode): - if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code): + app_id = AppService.get_app_id_by_code(app_code) + if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_id): raise WebAppAuthAccessDeniedError() auth_type = decoded.get("auth_type") diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py index 9c08a08c45..52fef4929f 100644 --- a/api/extensions/ext_blueprints.py +++ b/api/extensions/ext_blueprints.py @@ -1,4 +1,5 @@ from configs import dify_config +from constants import HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN from dify_app import DifyApp @@ -16,7 +17,7 @@ def init_app(app: DifyApp): CORS( service_api_bp, - allow_headers=["Content-Type", "Authorization", "X-App-Code"], + allow_headers=["Content-Type", "Authorization", HEADER_NAME_APP_CODE], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], ) app.register_blueprint(service_api_bp) @@ -25,7 +26,7 @@ def init_app(app: DifyApp): web_bp, resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}}, supports_credentials=True, - allow_headers=["Content-Type", "Authorization", "X-App-Code"], + allow_headers=["Content-Type", "Authorization", HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], expose_headers=["X-Version", "X-Env"], ) @@ -35,7 +36,7 @@ def init_app(app: DifyApp): console_app_bp, resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}}, supports_credentials=True, - allow_headers=["Content-Type", "Authorization"], + allow_headers=["Content-Type", "Authorization", HEADER_NAME_CSRF_TOKEN], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], expose_headers=["X-Version", "X-Env"], ) @@ -43,7 +44,7 @@ def init_app(app: DifyApp): CORS( files_bp, - allow_headers=["Content-Type"], + allow_headers=["Content-Type", 
HEADER_NAME_CSRF_TOKEN], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], ) app.register_blueprint(files_bp) diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index 836a5d938c..e7816a2e88 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -9,6 +9,7 @@ from configs import dify_config from dify_app import DifyApp from extensions.ext_database import db from libs.passport import PassportService +from libs.token import extract_access_token from models import Account, Tenant, TenantAccountJoin from models.model import AppMCPServer, EndUser from services.account_service import AccountService @@ -24,20 +25,10 @@ def load_user_from_request(request_from_flask_login): if dify_config.SWAGGER_UI_ENABLED and request.path.endswith((dify_config.SWAGGER_UI_PATH, "/swagger.json")): return None - auth_header = request.headers.get("Authorization", "") - auth_token: str | None = None - if auth_header: - if " " not in auth_header: - raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - auth_scheme, auth_token = auth_header.split(maxsplit=1) - auth_scheme = auth_scheme.lower() - if auth_scheme != "bearer": - raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - else: - auth_token = request.args.get("_token") + auth_token = extract_access_token(request) # Check for admin API key authentication first - if dify_config.ADMIN_API_KEY_ENABLE and auth_header: + if dify_config.ADMIN_API_KEY_ENABLE and auth_token: admin_api_key = dify_config.ADMIN_API_KEY if admin_api_key and admin_api_key == auth_token: workspace_id = request.headers.get("X-WORKSPACE-ID") diff --git a/api/libs/external_api.py b/api/libs/external_api.py index a59230caaa..f3ebcc4306 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -9,7 +9,9 @@ from werkzeug.exceptions import HTTPException from werkzeug.http import HTTP_STATUS_CODES from configs import dify_config +from constants import COOKIE_NAME_ACCESS_TOKEN, COOKIE_NAME_CSRF_TOKEN, COOKIE_NAME_REFRESH_TOKEN from core.errors.error import AppInvokeQuotaExceededError +from libs.token import is_secure def http_status_message(code): @@ -67,6 +69,19 @@ def register_external_error_handlers(api: Api): # If you need WWW-Authenticate for 401, add it to headers if status_code == 401: headers["WWW-Authenticate"] = 'Bearer realm="api"' + # Check if this is a forced logout error - clear cookies + error_code = getattr(e, "error_code", None) + if error_code == "unauthorized_and_force_logout": + # Add Set-Cookie headers to clear auth cookies + + secure = is_secure() + # response is not accessible, so we need to do it ugly + common_part = "Path=/; Expires=Thu, 01 Jan 1970 00:00:00 GMT; HttpOnly" + headers["Set-Cookie"] = [ + f'{COOKIE_NAME_ACCESS_TOKEN}=""; {common_part}{"; Secure" if secure else ""}; SameSite=Lax', + f'{COOKIE_NAME_CSRF_TOKEN}=""; {common_part}{"; Secure" if secure else ""}; SameSite=Lax', + f'{COOKIE_NAME_REFRESH_TOKEN}=""; {common_part}{"; Secure" if secure else ""}; SameSite=Lax', + ] return data, status_code, headers _ = handle_http_exception diff --git a/api/libs/login.py b/api/libs/login.py index d0e81a3441..5ed4bfae8f 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -7,6 +7,7 @@ from flask_login.config import EXEMPT_METHODS # type: ignore from werkzeug.local import LocalProxy from configs import dify_config +from libs.token import check_csrf_token from models import Account from models.model import EndUser @@ -73,6 +74,9 @@ def 
login_required(func: Callable[P, R]):
             pass
         elif current_user is not None and not current_user.is_authenticated:
             return current_app.login_manager.unauthorized()  # type: ignore
+        # we put csrf validation here for fewer merge conflicts
+        # TODO: maybe find a better place for it.
+        check_csrf_token(request, current_user.id)
         return current_app.ensure_sync(func)(*args, **kwargs)

     return decorated_view
diff --git a/api/libs/token.py b/api/libs/token.py
new file mode 100644
index 0000000000..4be25696e7
--- /dev/null
+++ b/api/libs/token.py
@@ -0,0 +1,208 @@
+import logging
+import re
+from datetime import UTC, datetime, timedelta
+
+from flask import Request
+from werkzeug.exceptions import Unauthorized
+from werkzeug.wrappers import Response
+
+from configs import dify_config
+from constants import (
+    COOKIE_NAME_ACCESS_TOKEN,
+    COOKIE_NAME_CSRF_TOKEN,
+    COOKIE_NAME_PASSPORT,
+    COOKIE_NAME_REFRESH_TOKEN,
+    HEADER_NAME_CSRF_TOKEN,
+    HEADER_NAME_PASSPORT,
+)
+from libs.passport import PassportService
+
+logger = logging.getLogger(__name__)
+
+CSRF_WHITE_LIST = [
+    re.compile(r"/console/api/apps/[a-f0-9-]+/workflows/draft"),
+]
+
+
+# the server may run behind a reverse proxy, so derive the scheme from the configured URLs
+# rather than from the incoming request
+def is_secure() -> bool:
+    return dify_config.CONSOLE_WEB_URL.startswith("https") and dify_config.CONSOLE_API_URL.startswith("https")
+
+
+def _real_cookie_name(cookie_name: str) -> str:
+    if is_secure():
+        return "__Host-" + cookie_name
+    else:
+        return cookie_name
+
+
+def _try_extract_from_header(request: Request) -> str | None:
+    """
+    Try to extract the access token from the Authorization header.
+    """
+    auth_header = request.headers.get("Authorization")
+    if auth_header:
+        if " " not in auth_header:
+            return None
+        else:
+            auth_scheme, auth_token = auth_header.split(None, 1)
+            auth_scheme = auth_scheme.lower()
+            if auth_scheme != "bearer":
+                return None
+            else:
+                return auth_token
+    return None
+
+
+def extract_csrf_token(request: Request) -> str | None:
+    """
+    Try to extract the CSRF token from the X-CSRF-Token header.
+    """
+    return request.headers.get(HEADER_NAME_CSRF_TOKEN)
+
+
+def extract_csrf_token_from_cookie(request: Request) -> str | None:
+    """
+    Try to extract the CSRF token from the cookie.
+    """
+    return request.cookies.get(_real_cookie_name(COOKIE_NAME_CSRF_TOKEN))
+
+
+def extract_access_token(request: Request) -> str | None:
+    """
+    Try to extract the access token from the cookie, falling back to the Authorization header.
+
+    Access token is either for console session or webapp passport exchange.
+    """
+
+    def _try_extract_from_cookie(request: Request) -> str | None:
+        return request.cookies.get(_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN))
+
+    return _try_extract_from_cookie(request) or _try_extract_from_header(request)
+
+
+def extract_webapp_passport(app_code: str, request: Request) -> str | None:
+    """
+    Try to extract the webapp passport token from the cookie or the X-App-Passport header.
+
+    Webapp access token (part of passport) is only used for webapp session.
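+    The passport cookie name is suffixed with the app code, so each webapp keeps its own passport.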
+ """ + + def _try_extract_passport_token_from_cookie(request: Request) -> str | None: + return request.cookies.get(_real_cookie_name(COOKIE_NAME_PASSPORT + "-" + app_code)) + + def _try_extract_passport_token_from_header(request: Request) -> str | None: + return request.headers.get(HEADER_NAME_PASSPORT) + + ret = _try_extract_passport_token_from_cookie(request) or _try_extract_passport_token_from_header(request) + return ret + + +def set_access_token_to_cookie(request: Request, response: Response, token: str, samesite: str = "Lax"): + response.set_cookie( + _real_cookie_name(COOKIE_NAME_ACCESS_TOKEN), + value=token, + httponly=True, + secure=is_secure(), + samesite=samesite, + max_age=int(dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 60), + path="/", + ) + + +def set_refresh_token_to_cookie(request: Request, response: Response, token: str): + response.set_cookie( + _real_cookie_name(COOKIE_NAME_REFRESH_TOKEN), + value=token, + httponly=True, + secure=is_secure(), + samesite="Lax", + max_age=int(60 * 60 * 24 * dify_config.REFRESH_TOKEN_EXPIRE_DAYS), + path="/", + ) + + +def set_csrf_token_to_cookie(request: Request, response: Response, token: str): + response.set_cookie( + _real_cookie_name(COOKIE_NAME_CSRF_TOKEN), + value=token, + httponly=False, + secure=is_secure(), + samesite="Lax", + max_age=int(60 * dify_config.ACCESS_TOKEN_EXPIRE_MINUTES), + path="/", + ) + + +def _clear_cookie( + response: Response, + cookie_name: str, + samesite: str = "Lax", + http_only: bool = True, +): + response.set_cookie( + _real_cookie_name(cookie_name), + "", + expires=0, + path="/", + secure=is_secure(), + httponly=http_only, + samesite=samesite, + ) + + +def clear_access_token_from_cookie(response: Response, samesite: str = "Lax"): + _clear_cookie(response, COOKIE_NAME_ACCESS_TOKEN, samesite) + + +def clear_refresh_token_from_cookie(response: Response): + _clear_cookie(response, COOKIE_NAME_REFRESH_TOKEN) + + +def clear_csrf_token_from_cookie(response: Response): + _clear_cookie(response, COOKIE_NAME_CSRF_TOKEN, http_only=False) + + +def check_csrf_token(request: Request, user_id: str): + # some apis are sent by beacon, so we need to bypass csrf token check + # since these APIs are post, they are already protected by SameSite: Lax, so csrf is not required. 
+    def _unauthorized():
+        raise Unauthorized("CSRF token is missing or invalid.")
+
+    for pattern in CSRF_WHITE_LIST:
+        if pattern.match(request.path):
+            return
+
+    csrf_token = extract_csrf_token(request)
+    csrf_token_from_cookie = extract_csrf_token_from_cookie(request)
+
+    if csrf_token != csrf_token_from_cookie:
+        _unauthorized()
+
+    if not csrf_token:
+        _unauthorized()
+    verified = {}
+    try:
+        verified = PassportService().verify(csrf_token)
+    except Exception:
+        _unauthorized()
+
+    if verified.get("sub") != user_id:
+        _unauthorized()
+
+    exp: int | None = verified.get("exp")
+    if not exp:
+        _unauthorized()
+    else:
+        time_now = int(datetime.now().timestamp())
+        if exp < time_now:
+            _unauthorized()
+
+
+def generate_csrf_token(user_id: str) -> str:
+    exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
+    payload = {
+        "exp": int(exp_dt.timestamp()),
+        "sub": user_id,
+    }
+    return PassportService().issue(payload)
diff --git a/api/services/account_service.py b/api/services/account_service.py
index 106bc0e77e..cb0eb7a9dd 100644
--- a/api/services/account_service.py
+++ b/api/services/account_service.py
@@ -22,6 +22,7 @@ from libs.helper import RateLimiter, TokenManager
 from libs.passport import PassportService
 from libs.password import compare_password, hash_password, valid_password
 from libs.rsa import generate_key_pair
+from libs.token import generate_csrf_token
 from models.account import (
     Account,
     AccountIntegrate,
@@ -76,6 +77,7 @@ logger = logging.getLogger(__name__)
 class TokenPair(BaseModel):
     access_token: str
     refresh_token: str
+    csrf_token: str


 REFRESH_TOKEN_PREFIX = "refresh_token:"
@@ -403,10 +405,11 @@ class AccountService:
         access_token = AccountService.get_account_jwt_token(account=account)
         refresh_token = _generate_refresh_token()
+        csrf_token = generate_csrf_token(account.id)

         AccountService._store_refresh_token(refresh_token, account.id)

-        return TokenPair(access_token=access_token, refresh_token=refresh_token)
+        return TokenPair(access_token=access_token, refresh_token=refresh_token, csrf_token=csrf_token)

     @staticmethod
     def logout(*, account: Account):
@@ -431,8 +434,9 @@
         AccountService._delete_refresh_token(refresh_token, account.id)
         AccountService._store_refresh_token(new_refresh_token, account.id)
+        csrf_token = generate_csrf_token(account.id)

-        return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token)
+        return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token, csrf_token=csrf_token)

     @staticmethod
     def load_logged_in_account(*, account_id: str):
diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py
index 4fbf33fd6f..974aa849db 100644
--- a/api/services/enterprise/enterprise_service.py
+++ b/api/services/enterprise/enterprise_service.py
@@ -46,17 +46,17 @@ class EnterpriseService:
     class WebAppAuth:
         @classmethod
-        def is_user_allowed_to_access_webapp(cls, user_id: str, app_code: str):
-            params = {"userId": user_id, "appCode": app_code}
+        def is_user_allowed_to_access_webapp(cls, user_id: str, app_id: str):
+            params = {"userId": user_id, "appId": app_id}
             data = EnterpriseRequest.send_request("GET", "/webapp/permission", params=params)

             return data.get("result", False)

         @classmethod
-        def batch_is_user_allowed_to_access_webapps(cls, user_id: str, app_codes: list[str]):
-            if not app_codes:
+        def batch_is_user_allowed_to_access_webapps(cls, user_id: str, app_ids: list[str]):
+            if not app_ids:
                 return {}
-            body = {"userId": user_id, 
"appCodes": app_codes} + body = {"userId": user_id, "appIds": app_ids} data = EnterpriseRequest.send_request("POST", "/webapp/permission/batch", json=body) if not data: raise ValueError("No data found.") diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index 693bfb95b6..9bd797a45f 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -172,7 +172,8 @@ class WebAppAuthService: return WebAppAuthType.EXTERNAL if app_code: - webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code) + app_id = AppService.get_app_id_by_code(app_code) + webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id) return cls.get_app_auth_type(access_mode=webapp_settings.access_mode) raise ValueError("Could not determine app authentication type.") diff --git a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py index 9fc16d9eb7..73e622b061 100644 --- a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py @@ -863,13 +863,14 @@ class TestWebAppAuthService: - Mock service integration """ # Arrange: Setup mock for enterprise service - mock_webapp_auth = type("MockWebAppAuth", (), {"access_mode": "sso_verified"})() + mock_external_service_dependencies["app_service"].get_app_id_by_code.return_value = "mock_app_id" + setting = type("MockWebAppAuth", (), {"access_mode": "sso_verified"})() mock_external_service_dependencies[ "enterprise_service" - ].WebAppAuth.get_app_access_mode_by_code.return_value = mock_webapp_auth + ].WebAppAuth.get_app_access_mode_by_id.return_value = setting # Act: Execute authentication type determination - result = WebAppAuthService.get_app_auth_type(app_code="mock_app_code") + result: WebAppAuthType = WebAppAuthService.get_app_auth_type(app_code="mock_app_code") # Assert: Verify correct result assert result == WebAppAuthType.EXTERNAL @@ -877,7 +878,7 @@ class TestWebAppAuthService: # Verify mock service was called correctly mock_external_service_dependencies[ "enterprise_service" - ].WebAppAuth.get_app_access_mode_by_code.assert_called_once_with("mock_app_code") + ].WebAppAuth.get_app_access_mode_by_id.assert_called_once_with(app_id="mock_app_id") def test_get_app_auth_type_no_parameters(self, db_session_with_containers, mock_external_service_dependencies): """ diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth.py b/api/tests/unit_tests/controllers/console/auth/test_oauth.py index 67f4b85413..399caf8c4d 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/unit_tests/controllers/console/auth/test_oauth.py @@ -179,9 +179,7 @@ class TestOAuthCallback: oauth_setup["provider"].get_access_token.assert_called_once_with("test_code") oauth_setup["provider"].get_user_info.assert_called_once_with("access_token") - mock_redirect.assert_called_once_with( - "http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token" - ) + mock_redirect.assert_called_once_with("http://localhost:3000") @pytest.mark.parametrize( ("exception", "expected_error"), @@ -224,8 +222,8 @@ class TestOAuthCallback: # CLOSED status: Currently NOT handled, will proceed to login (security issue) # This documents actual behavior. 
See test_defensive_check_for_closed_account_status for details ( - AccountStatus.CLOSED, - "http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token", + AccountStatus.CLOSED.value, + "http://localhost:3000", ), ], ) @@ -268,6 +266,7 @@ class TestOAuthCallback: mock_token_pair = MagicMock() mock_token_pair.access_token = "jwt_access_token" mock_token_pair.refresh_token = "jwt_refresh_token" + mock_token_pair.csrf_token = "csrf_token" mock_account_service.login.return_value = mock_token_pair with app.test_request_context("/auth/oauth/github/callback?code=test_code"): @@ -299,6 +298,12 @@ class TestOAuthCallback: mock_account.status = AccountStatus.PENDING mock_generate_account.return_value = mock_account + mock_token_pair = MagicMock() + mock_token_pair.access_token = "jwt_access_token" + mock_token_pair.refresh_token = "jwt_refresh_token" + mock_token_pair.csrf_token = "csrf_token" + mock_account_service.login.return_value = mock_token_pair + with app.test_request_context("/auth/oauth/github/callback?code=test_code"): resource.get("github") @@ -361,6 +366,7 @@ class TestOAuthCallback: mock_token_pair = MagicMock() mock_token_pair.access_token = "jwt_access_token" mock_token_pair.refresh_token = "jwt_refresh_token" + mock_token_pair.csrf_token = "csrf_token" mock_account_service.login.return_value = mock_token_pair # Execute OAuth callback @@ -368,9 +374,7 @@ class TestOAuthCallback: resource.get("github") # Verify current behavior: login succeeds (this is NOT ideal) - mock_redirect.assert_called_once_with( - "http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token" - ) + mock_redirect.assert_called_once_with("http://localhost:3000") mock_account_service.login.assert_called_once() # Document expected behavior in comments: diff --git a/api/tests/unit_tests/libs/test_external_api.py b/api/tests/unit_tests/libs/test_external_api.py index a9edb913ea..c4c376a070 100644 --- a/api/tests/unit_tests/libs/test_external_api.py +++ b/api/tests/unit_tests/libs/test_external_api.py @@ -2,7 +2,9 @@ from flask import Blueprint, Flask from flask_restx import Resource from werkzeug.exceptions import BadRequest, Unauthorized +from constants import COOKIE_NAME_ACCESS_TOKEN, COOKIE_NAME_CSRF_TOKEN, COOKIE_NAME_REFRESH_TOKEN from core.errors.error import AppInvokeQuotaExceededError +from libs.exception import BaseHTTPException from libs.external_api import ExternalApi @@ -120,3 +122,66 @@ def test_external_api_param_mapping_and_quota_and_exc_info_none(): assert res.status_code in (400, 429) finally: ext.sys.exc_info = orig_exc_info # type: ignore[assignment] + + +def test_unauthorized_and_force_logout_clears_cookies(): + """Test that UnauthorizedAndForceLogout error clears auth cookies""" + + class UnauthorizedAndForceLogout(BaseHTTPException): + error_code = "unauthorized_and_force_logout" + description = "Unauthorized and force logout." 
+        code = 401
+
+    app = Flask(__name__)
+    bp = Blueprint("test", __name__)
+    api = ExternalApi(bp)
+
+    @api.route("/force-logout")
+    class ForceLogout(Resource):  # type: ignore
+        def get(self):  # type: ignore
+            raise UnauthorizedAndForceLogout()
+
+    app.register_blueprint(bp, url_prefix="/api")
+    client = app.test_client()
+
+    # Set cookies first
+    client.set_cookie(COOKIE_NAME_ACCESS_TOKEN, "test_access_token")
+    client.set_cookie(COOKIE_NAME_CSRF_TOKEN, "test_csrf_token")
+    client.set_cookie(COOKIE_NAME_REFRESH_TOKEN, "test_refresh_token")
+
+    # Make request that should trigger cookie clearing
+    res = client.get("/api/force-logout")
+
+    # Verify response
+    assert res.status_code == 401
+    data = res.get_json()
+    assert data["code"] == "unauthorized_and_force_logout"
+    assert data["status"] == 401
+    assert "WWW-Authenticate" in res.headers
+
+    # Verify Set-Cookie headers are present to clear cookies
+    set_cookie_headers = res.headers.getlist("Set-Cookie")
+    assert len(set_cookie_headers) == 3, f"Expected 3 Set-Cookie headers, got {len(set_cookie_headers)}"
+
+    # Verify each cookie is being cleared (empty value and expired)
+    cookie_names_found = set()
+    for cookie_header in set_cookie_headers:
+        # Check for cookie names
+        if COOKIE_NAME_ACCESS_TOKEN in cookie_header:
+            cookie_names_found.add(COOKIE_NAME_ACCESS_TOKEN)
+            assert '=""' in cookie_header  # Empty value
+            assert "Expires=Thu, 01 Jan 1970" in cookie_header  # Expired
+        elif COOKIE_NAME_CSRF_TOKEN in cookie_header:
+            cookie_names_found.add(COOKIE_NAME_CSRF_TOKEN)
+            assert '=""' in cookie_header
+            assert "Expires=Thu, 01 Jan 1970" in cookie_header
+        elif COOKIE_NAME_REFRESH_TOKEN in cookie_header:
+            cookie_names_found.add(COOKIE_NAME_REFRESH_TOKEN)
+            assert '=""' in cookie_header
+            assert "Expires=Thu, 01 Jan 1970" in cookie_header
+
+    # Verify all three cookies are present
+    assert len(cookie_names_found) == 3
+    assert COOKIE_NAME_ACCESS_TOKEN in cookie_names_found
+    assert COOKIE_NAME_CSRF_TOKEN in cookie_names_found
+    assert COOKIE_NAME_REFRESH_TOKEN in cookie_names_found
diff --git a/api/tests/unit_tests/libs/test_login.py b/api/tests/unit_tests/libs/test_login.py
index 39671077d4..35155b4931 100644
--- a/api/tests/unit_tests/libs/test_login.py
+++ b/api/tests/unit_tests/libs/test_login.py
@@ -19,10 +19,15 @@ class MockUser(UserMixin):
         return self._is_authenticated


+def mock_csrf_check(*args, **kwargs):
+    return
+
+
 class TestLoginRequired:
     """Test cases for login_required decorator."""

     @pytest.fixture
+    @patch("libs.login.check_csrf_token", mock_csrf_check)
     def setup_app(self, app: Flask):
         """Set up Flask app with login manager."""
         # Initialize login manager
@@ -39,6 +44,7 @@

         return app

+    @patch("libs.login.check_csrf_token", mock_csrf_check)
     def test_authenticated_user_can_access_protected_view(self, setup_app: Flask):
         """Test that authenticated users can access protected views."""
@@ -53,6 +59,7 @@
         result = protected_view()
         assert result == "Protected content"

+    @patch("libs.login.check_csrf_token", mock_csrf_check)
     def test_unauthenticated_user_cannot_access_protected_view(self, setup_app: Flask):
         """Test that unauthenticated users are redirected."""
@@ -68,6 +75,7 @@
         assert result == "Unauthorized"
         setup_app.login_manager.unauthorized.assert_called_once()

+    @patch("libs.login.check_csrf_token", mock_csrf_check)
     def test_login_disabled_allows_unauthenticated_access(self, setup_app: 
Flask): """Test that LOGIN_DISABLED config bypasses authentication.""" @@ -87,6 +95,7 @@ class TestLoginRequired: # Ensure unauthorized was not called setup_app.login_manager.unauthorized.assert_not_called() + @patch("libs.login.check_csrf_token", mock_csrf_check) def test_options_request_bypasses_authentication(self, setup_app: Flask): """Test that OPTIONS requests are exempt from authentication.""" @@ -103,6 +112,7 @@ class TestLoginRequired: # Ensure unauthorized was not called setup_app.login_manager.unauthorized.assert_not_called() + @patch("libs.login.check_csrf_token", mock_csrf_check) def test_flask_2_compatibility(self, setup_app: Flask): """Test Flask 2.x compatibility with ensure_sync.""" @@ -120,6 +130,7 @@ class TestLoginRequired: assert result == "Synced content" setup_app.ensure_sync.assert_called_once() + @patch("libs.login.check_csrf_token", mock_csrf_check) def test_flask_1_compatibility(self, setup_app: Flask): """Test Flask 1.x compatibility without ensure_sync.""" diff --git a/api/tests/unit_tests/libs/test_token.py b/api/tests/unit_tests/libs/test_token.py new file mode 100644 index 0000000000..22790fa4a6 --- /dev/null +++ b/api/tests/unit_tests/libs/test_token.py @@ -0,0 +1,23 @@ +from constants import COOKIE_NAME_ACCESS_TOKEN +from libs.token import extract_access_token + + +class MockRequest: + def __init__(self, headers: dict[str, str], cookies: dict[str, str], args: dict[str, str]): + self.headers: dict[str, str] = headers + self.cookies: dict[str, str] = cookies + self.args: dict[str, str] = args + + +def test_extract_access_token(): + def _mock_request(headers: dict[str, str], cookies: dict[str, str], args: dict[str, str]): + return MockRequest(headers, cookies, args) + + test_cases = [ + (_mock_request({"Authorization": "Bearer 123"}, {}, {}), "123"), + (_mock_request({}, {COOKIE_NAME_ACCESS_TOKEN: "123"}, {}), "123"), + (_mock_request({}, {}, {}), None), + (_mock_request({"Authorization": "Bearer_aaa 123"}, {}, {}), None), + ] + for request, expected in test_cases: + assert extract_access_token(request) == expected # pyright: ignore[reportArgumentType] diff --git a/web/app/(shareLayout)/components/authenticated-layout.tsx b/web/app/(shareLayout)/components/authenticated-layout.tsx index e3cfc8e6a8..2185606a6d 100644 --- a/web/app/(shareLayout)/components/authenticated-layout.tsx +++ b/web/app/(shareLayout)/components/authenticated-layout.tsx @@ -2,16 +2,17 @@ import AppUnavailable from '@/app/components/base/app-unavailable' import Loading from '@/app/components/base/loading' -import { removeAccessToken } from '@/app/components/share/utils' import { useWebAppStore } from '@/context/web-app-context' import { useGetUserCanAccessApp } from '@/service/access-control' import { useGetWebAppInfo, useGetWebAppMeta, useGetWebAppParams } from '@/service/use-share' +import { webAppLogout } from '@/service/webapp-auth' import { usePathname, useRouter, useSearchParams } from 'next/navigation' import React, { useCallback, useEffect } from 'react' import { useTranslation } from 'react-i18next' const AuthenticatedLayout = ({ children }: { children: React.ReactNode }) => { const { t } = useTranslation() + const shareCode = useWebAppStore(s => s.shareCode) const updateAppInfo = useWebAppStore(s => s.updateAppInfo) const updateAppParams = useWebAppStore(s => s.updateAppParams) const updateWebAppMeta = useWebAppStore(s => s.updateWebAppMeta) @@ -41,11 +42,11 @@ const AuthenticatedLayout = ({ children }: { children: React.ReactNode }) => { return 
`/webapp-signin?${params.toString()}` }, [searchParams, pathname]) - const backToHome = useCallback(() => { - removeAccessToken() + const backToHome = useCallback(async () => { + await webAppLogout(shareCode!) const url = getSigninUrl() router.replace(url) - }, [getSigninUrl, router]) + }, [getSigninUrl, router, webAppLogout, shareCode]) if (appInfoError) { return
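The web-app layout changes above and the splash rewrite below are the client side of the new
cookie-based auth flow. A minimal sketch of the server contract they rely on, using the
/login/status and /passport endpoints from the API diffs above (the base URL and app code
below are illustrative assumptions, not part of the patch):

    import requests

    BASE = "http://localhost:5001/api"  # assumed web API base URL
    APP_CODE = "your-app-code"          # assumed webapp share code

    with requests.Session() as s:
        # Cookie-based status probe; both flags come from GET /login/status.
        status = s.get(f"{BASE}/login/status", params={"app_code": APP_CODE}).json()
        if status["logged_in"] and not status["app_logged_in"]:
            # The HTTP-only access_token cookie travels automatically; X-App-Code names the app.
            passport = s.get(f"{BASE}/passport", headers={"X-App-Code": APP_CODE}).json()
            # Subsequent web-app calls present the passport in the X-App-Passport header.
            s.headers["X-App-Passport"] = passport["access_token"]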
diff --git a/web/app/(shareLayout)/components/splash.tsx b/web/app/(shareLayout)/components/splash.tsx index 4fe9efe4dd..c26ea7e045 100644 --- a/web/app/(shareLayout)/components/splash.tsx +++ b/web/app/(shareLayout)/components/splash.tsx @@ -1,15 +1,16 @@ 'use client' import type { FC, PropsWithChildren } from 'react' -import { useEffect } from 'react' +import { useEffect, useState } from 'react' import { useCallback } from 'react' import { useWebAppStore } from '@/context/web-app-context' import { useRouter, useSearchParams } from 'next/navigation' import AppUnavailable from '@/app/components/base/app-unavailable' -import { checkOrSetAccessToken, removeAccessToken, setAccessToken } from '@/app/components/share/utils' import { useTranslation } from 'react-i18next' +import { AccessMode } from '@/models/access-control' +import { webAppLoginStatus, webAppLogout } from '@/service/webapp-auth' import { fetchAccessToken } from '@/service/share' import Loading from '@/app/components/base/loading' -import { AccessMode } from '@/models/access-control' +import { setWebAppAccessToken, setWebAppPassport } from '@/service/webapp-auth' const Splash: FC = ({ children }) => { const { t } = useTranslation() @@ -18,9 +19,9 @@ const Splash: FC = ({ children }) => { const searchParams = useSearchParams() const router = useRouter() const redirectUrl = searchParams.get('redirect_url') - const tokenFromUrl = searchParams.get('web_sso_token') const message = searchParams.get('message') const code = searchParams.get('code') + const tokenFromUrl = searchParams.get('web_sso_token') const getSigninUrl = useCallback(() => { const params = new URLSearchParams(searchParams) params.delete('message') @@ -28,35 +29,66 @@ const Splash: FC = ({ children }) => { return `/webapp-signin?${params.toString()}` }, [searchParams]) - const backToHome = useCallback(() => { - removeAccessToken() + const backToHome = useCallback(async () => { + await webAppLogout(shareCode!) const url = getSigninUrl() router.replace(url) - }, [getSigninUrl, router]) + }, [getSigninUrl, router, webAppLogout, shareCode]) + const needCheckIsLogin = webAppAccessMode !== AccessMode.PUBLIC + const [isLoading, setIsLoading] = useState(true) useEffect(() => { + if (message) { + setIsLoading(false) + return + } + + if(tokenFromUrl) + setWebAppAccessToken(tokenFromUrl) + + const redirectOrFinish = () => { + if (redirectUrl) + router.replace(decodeURIComponent(redirectUrl)) + else + setIsLoading(false) + } + + const proceedToAuth = () => { + setIsLoading(false) + } + (async () => { - if (message) - return - if (shareCode && tokenFromUrl && redirectUrl) { - localStorage.setItem('webapp_access_token', tokenFromUrl) - const tokenResp = await fetchAccessToken({ appCode: shareCode, webAppAccessToken: tokenFromUrl }) - await setAccessToken(shareCode, tokenResp.access_token) - router.replace(decodeURIComponent(redirectUrl)) - return + const { userLoggedIn, appLoggedIn } = await webAppLoginStatus(needCheckIsLogin, shareCode!) 
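+      // Four (userLoggedIn, appLoggedIn) combinations are handled below: both true -> redirect or
+      // finish; both false -> fall through to the auth UI; passport only -> the app session is
+      // still valid, so continue; user only -> exchange the user session for a fresh app passport,
+      // logging out and falling back to the auth UI if the exchange fails.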
+ + if (userLoggedIn && appLoggedIn) { + redirectOrFinish() } - if (shareCode && redirectUrl && localStorage.getItem('webapp_access_token')) { - const tokenResp = await fetchAccessToken({ appCode: shareCode, webAppAccessToken: localStorage.getItem('webapp_access_token') }) - await setAccessToken(shareCode, tokenResp.access_token) - router.replace(decodeURIComponent(redirectUrl)) - return + else if (!userLoggedIn && !appLoggedIn) { + proceedToAuth() } - if (webAppAccessMode === AccessMode.PUBLIC && redirectUrl) { - await checkOrSetAccessToken(shareCode) - router.replace(decodeURIComponent(redirectUrl)) + else if (!userLoggedIn && appLoggedIn) { + redirectOrFinish() + } + else if (userLoggedIn && !appLoggedIn) { + try { + const { access_token } = await fetchAccessToken({ appCode: shareCode! }) + setWebAppPassport(shareCode!, access_token) + redirectOrFinish() + } + catch (error) { + await webAppLogout(shareCode!) + proceedToAuth() + } } })() - }, [shareCode, redirectUrl, router, tokenFromUrl, message, webAppAccessMode]) + }, [ + shareCode, + redirectUrl, + router, + message, + webAppAccessMode, + needCheckIsLogin, + tokenFromUrl]) if (message) { return
@@ -64,12 +96,8 @@ const Splash: FC = ({ children }) => { {code === '403' ? t('common.userProfile.logout') : t('share.login.backToHome')}
  }

-  if (tokenFromUrl) {
-    return <div className='flex h-full items-center justify-center'>
-      <Loading />
-    </div>
-  }
-  if (webAppAccessMode === AccessMode.PUBLIC && redirectUrl) {
+
+  if (isLoading) {
    return <div className='flex h-full items-center justify-center'>
      <Loading />
    </div>
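The splash screen now resolves one of four (userLoggedIn, appLoggedIn) states instead of branching on a token in the URL. A condensed sketch of that decision table, assuming the `webAppLoginStatus`, `setWebAppPassport`, and `webAppLogout` helpers introduced in `web/service/webapp-auth.ts` later in this patch (the real component additionally handles `redirect_url` and the `message` error path):

```ts
import { fetchAccessToken } from '@/service/share'
import { setWebAppPassport, webAppLoginStatus, webAppLogout } from '@/service/webapp-auth'

// Condensed decision table for the splash flow; a sketch, not the component itself.
async function resolveSplashState(shareCode: string, needCheckIsLogin: boolean): Promise<'enter' | 'signin'> {
  const { userLoggedIn, appLoggedIn } = await webAppLoginStatus(needCheckIsLogin, shareCode)

  // An app-level passport already exists (or the app is public): let the user in.
  if (appLoggedIn)
    return 'enter'

  // User session exists but no app passport yet: exchange the session for one.
  if (userLoggedIn) {
    try {
      const { access_token } = await fetchAccessToken({ appCode: shareCode })
      setWebAppPassport(shareCode, access_token)
      return 'enter'
    }
    catch {
      // Exchange failed (e.g. access revoked): clear local state and re-authenticate.
      await webAppLogout(shareCode)
      return 'signin'
    }
  }

  // Neither session exists: show the sign-in form.
  return 'signin'
}
```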
diff --git a/web/app/(shareLayout)/webapp-signin/check-code/page.tsx b/web/app/(shareLayout)/webapp-signin/check-code/page.tsx index 3fc32fec71..4a1326fedf 100644 --- a/web/app/(shareLayout)/webapp-signin/check-code/page.tsx +++ b/web/app/(shareLayout)/webapp-signin/check-code/page.tsx @@ -10,7 +10,7 @@ import Input from '@/app/components/base/input' import Toast from '@/app/components/base/toast' import { sendWebAppEMailLoginCode, webAppEmailLoginWithCode } from '@/service/common' import I18NContext from '@/context/i18n' -import { setAccessToken } from '@/app/components/share/utils' +import { setWebAppAccessToken, setWebAppPassport } from '@/service/webapp-auth' import { fetchAccessToken } from '@/service/share' export default function CheckCode() { @@ -62,9 +62,9 @@ export default function CheckCode() { setIsLoading(true) const ret = await webAppEmailLoginWithCode({ email, code, token }) if (ret.result === 'success') { - localStorage.setItem('webapp_access_token', ret.data.access_token) - const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: ret.data.access_token }) - await setAccessToken(appCode, tokenResp.access_token) + setWebAppAccessToken(ret.data.access_token) + const { access_token } = await fetchAccessToken({ appCode: appCode! }) + setWebAppPassport(appCode!, access_token) router.replace(decodeURIComponent(redirectUrl)) } } diff --git a/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx index 2b6bd73df0..ce220b103e 100644 --- a/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx @@ -11,15 +11,13 @@ import { webAppLogin } from '@/service/common' import Input from '@/app/components/base/input' import I18NContext from '@/context/i18n' import { noop } from 'lodash-es' -import { setAccessToken } from '@/app/components/share/utils' import { fetchAccessToken } from '@/service/share' +import { setWebAppAccessToken, setWebAppPassport } from '@/service/webapp-auth' type MailAndPasswordAuthProps = { isEmailSetup: boolean } -const passwordRegex = /^(?=.*[a-zA-Z])(?=.*\d).{8,}$/ - export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAuthProps) { const { t } = useTranslation() const { locale } = useContext(I18NContext) @@ -43,8 +41,8 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut return appCode }, [redirectUrl]) + const appCode = getAppCodeFromRedirectUrl() const handleEmailPasswordLogin = async () => { - const appCode = getAppCodeFromRedirectUrl() if (!email) { Toast.notify({ type: 'error', message: t('login.error.emailEmpty') }) return @@ -60,13 +58,7 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut Toast.notify({ type: 'error', message: t('login.error.passwordEmpty') }) return } - if (!passwordRegex.test(password)) { - Toast.notify({ - type: 'error', - message: t('login.error.passwordInvalid'), - }) - return - } + if (!redirectUrl || !appCode) { Toast.notify({ type: 'error', @@ -88,9 +80,10 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut body: loginData, }) if (res.result === 'success') { - localStorage.setItem('webapp_access_token', res.data.access_token) - const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: res.data.access_token }) - await setAccessToken(appCode, tokenResp.access_token) + 
setWebAppAccessToken(res.data.access_token) + + const { access_token } = await fetchAccessToken({ appCode: appCode! }) + setWebAppPassport(appCode!, access_token) router.replace(decodeURIComponent(redirectUrl)) } else { @@ -141,9 +134,9 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut
setPassword(e.target.value)} + id="password" onKeyDown={(e) => { if (e.key === 'Enter') handleEmailPasswordLogin() diff --git a/web/app/(shareLayout)/webapp-signin/page.tsx b/web/app/(shareLayout)/webapp-signin/page.tsx index 1c6209b902..2ffa19c0c9 100644 --- a/web/app/(shareLayout)/webapp-signin/page.tsx +++ b/web/app/(shareLayout)/webapp-signin/page.tsx @@ -3,13 +3,13 @@ import { useRouter, useSearchParams } from 'next/navigation' import type { FC } from 'react' import React, { useCallback } from 'react' import { useTranslation } from 'react-i18next' -import { removeAccessToken } from '@/app/components/share/utils' import { useGlobalPublicStore } from '@/context/global-public-context' import AppUnavailable from '@/app/components/base/app-unavailable' import NormalForm from './normalForm' import { AccessMode } from '@/models/access-control' import ExternalMemberSsoAuth from './components/external-member-sso-auth' import { useWebAppStore } from '@/context/web-app-context' +import { webAppLogout } from '@/service/webapp-auth' const WebSSOForm: FC = () => { const { t } = useTranslation() @@ -26,11 +26,12 @@ const WebSSOForm: FC = () => { return `/webapp-signin?${params.toString()}` }, [redirectUrl]) - const backToHome = useCallback(() => { - removeAccessToken() + const shareCode = useWebAppStore(s => s.shareCode) + const backToHome = useCallback(async () => { + await webAppLogout(shareCode!) const url = getSigninUrl() router.replace(url) - }, [getSigninUrl, router]) + }, [getSigninUrl, router, webAppLogout, shareCode]) if (!redirectUrl) { return
diff --git a/web/app/account/(commonLayout)/account-page/email-change-modal.tsx b/web/app/account/(commonLayout)/account-page/email-change-modal.tsx index bd00f27ac5..d04cd18557 100644 --- a/web/app/account/(commonLayout)/account-page/email-change-modal.tsx +++ b/web/app/account/(commonLayout)/account-page/email-change-modal.tsx @@ -9,7 +9,6 @@ import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import { checkEmailExisted, - logout, resetEmail, sendVerifyCode, verifyEmail, @@ -17,6 +16,7 @@ import { import { noop } from 'lodash-es' import { asyncRunSafe } from '@/utils' import type { ResponseError } from '@/service/fetch' +import { useLogout } from '@/service/use-common' type Props = { show: boolean @@ -167,15 +167,12 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => { setStep(STEP.verifyNew) } + const { mutateAsync: logout } = useLogout() const handleLogout = async () => { - await logout({ - url: '/logout', - params: {}, - }) + await logout() localStorage.removeItem('setup_status') - localStorage.removeItem('console_token') - localStorage.removeItem('refresh_token') + // Tokens are now stored in cookies and cleared by backend router.push('/signin') } diff --git a/web/app/account/(commonLayout)/avatar.tsx b/web/app/account/(commonLayout)/avatar.tsx index ea897e639f..d8943b7879 100644 --- a/web/app/account/(commonLayout)/avatar.tsx +++ b/web/app/account/(commonLayout)/avatar.tsx @@ -7,11 +7,11 @@ import { } from '@remixicon/react' import { Menu, MenuButton, MenuItem, MenuItems, Transition } from '@headlessui/react' import Avatar from '@/app/components/base/avatar' -import { logout } from '@/service/common' import { useAppContext } from '@/context/app-context' import { useProviderContext } from '@/context/provider-context' import { LogOut01 } from '@/app/components/base/icons/src/vender/line/general' import PremiumBadge from '@/app/components/base/premium-badge' +import { useLogout } from '@/service/use-common' export type IAppSelector = { isMobile: boolean @@ -23,15 +23,12 @@ export default function AppSelector() { const { userProfile } = useAppContext() const { isEducationAccount } = useProviderContext() + const { mutateAsync: logout } = useLogout() const handleLogout = async () => { - await logout({ - url: '/logout', - params: {}, - }) + await logout() localStorage.removeItem('setup_status') - localStorage.removeItem('console_token') - localStorage.removeItem('refresh_token') + // Tokens are now stored in cookies and cleared by backend router.push('/signin') } diff --git a/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx b/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx index 2cd30bc3f2..64a378d2fe 100644 --- a/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx +++ b/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx @@ -8,7 +8,7 @@ import Button from '@/app/components/base/button' import CustomDialog from '@/app/components/base/dialog' import Textarea from '@/app/components/base/textarea' import Toast from '@/app/components/base/toast' -import { logout } from '@/service/common' +import { useLogout } from '@/service/use-common' type DeleteAccountProps = { onCancel: () => void @@ -22,14 +22,11 @@ export default function FeedBack(props: DeleteAccountProps) { const [userFeedback, setUserFeedback] = useState('') const { isPending, mutateAsync: sendFeedback } = useDeleteAccountFeedback() + const { mutateAsync: logout } = useLogout() const 
handleSuccess = useCallback(async () => { try { - await logout({ - url: '/logout', - params: {}, - }) - localStorage.removeItem('refresh_token') - localStorage.removeItem('console_token') + await logout() + // Tokens are now stored in cookies and cleared by backend router.push('/signin') Toast.notify({ type: 'info', message: t('common.account.deleteSuccessTip') }) } diff --git a/web/app/account/oauth/authorize/layout.tsx b/web/app/account/oauth/authorize/layout.tsx index 078d23114a..2ab676d6b6 100644 --- a/web/app/account/oauth/authorize/layout.tsx +++ b/web/app/account/oauth/authorize/layout.tsx @@ -5,17 +5,22 @@ import cn from '@/utils/classnames' import { useGlobalPublicStore } from '@/context/global-public-context' import useDocumentTitle from '@/hooks/use-document-title' import { AppContextProvider } from '@/context/app-context' -import { useMemo } from 'react' +import { useIsLogin } from '@/service/use-common' +import Loading from '@/app/components/base/loading' export default function SignInLayout({ children }: any) { const { systemFeatures } = useGlobalPublicStore() useDocumentTitle('') - const isLoggedIn = useMemo(() => { - try { - return Boolean(localStorage.getItem('console_token') && localStorage.getItem('refresh_token')) - } - catch { return false } - }, []) + const { isLoading, data: loginData } = useIsLogin() + const isLoggedIn = loginData?.logged_in + + if(isLoading) { + return ( +
+      <div className='flex h-full items-center justify-center'>
+        <Loading />
+      </div>
+    )
+  }

  return <>
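With console tokens moved into HttpOnly cookies, the client can no longer infer login state by inspecting `localStorage`; the check becomes a server round-trip against `/login/status`, surfaced through the `useIsLogin` hook added in `web/service/use-common.ts` below. A minimal consumer sketch, with `Loading` standing in for the spinner used in the layout above and `SignInPrompt` a hypothetical fallback:

```tsx
import type { ReactNode } from 'react'
import Loading from '@/app/components/base/loading'
import { useIsLogin } from '@/service/use-common'

// Hypothetical fallback component, for illustration only.
const SignInPrompt = () => <a href="/signin">Sign in</a>

// Gate children on the cookie-backed session: spinner while /login/status is
// in flight, sign-in prompt if the session is absent or expired.
function LoginGate({ children }: { children: ReactNode }) {
  const { isLoading, data } = useIsLogin()
  if (isLoading)
    return <Loading />
  if (!data?.logged_in)
    return <SignInPrompt />
  return <>{children}</>
}
```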
diff --git a/web/app/account/oauth/authorize/page.tsx b/web/app/account/oauth/authorize/page.tsx index 6ad63996ae..4aa5fa0b8e 100644 --- a/web/app/account/oauth/authorize/page.tsx +++ b/web/app/account/oauth/authorize/page.tsx @@ -1,6 +1,6 @@ 'use client' -import React, { useEffect, useMemo, useRef } from 'react' +import React, { useEffect, useRef } from 'react' import { useTranslation } from 'react-i18next' import { useRouter, useSearchParams } from 'next/navigation' import Button from '@/app/components/base/button' @@ -18,6 +18,7 @@ import { RiTranslate2, } from '@remixicon/react' import dayjs from 'dayjs' +import { useIsLogin } from '@/service/use-common' export const OAUTH_AUTHORIZE_PENDING_KEY = 'oauth_authorize_pending' export const REDIRECT_URL_KEY = 'oauth_redirect_url' @@ -74,17 +75,13 @@ export default function OAuthAuthorize() { const client_id = decodeURIComponent(searchParams.get('client_id') || '') const redirect_uri = decodeURIComponent(searchParams.get('redirect_uri') || '') const { userProfile } = useAppContext() - const { data: authAppInfo, isLoading, isError } = useOAuthAppInfo(client_id, redirect_uri) + const { data: authAppInfo, isLoading: isOAuthLoading, isError } = useOAuthAppInfo(client_id, redirect_uri) const { mutateAsync: authorize, isPending: authorizing } = useAuthorizeOAuthApp() const hasNotifiedRef = useRef(false) - const isLoggedIn = useMemo(() => { - try { - return Boolean(localStorage.getItem('console_token') && localStorage.getItem('refresh_token')) - } - catch { return false } - }, []) - + const { isLoading: isIsLoginLoading, data: loginData } = useIsLogin() + const isLoggedIn = loginData?.logged_in + const isLoading = isOAuthLoading || isIsLoginLoading const onLoginSwitchClick = () => { try { const returnUrl = buildReturnUrl('/account/oauth/authorize', `?client_id=${encodeURIComponent(client_id)}&redirect_uri=${encodeURIComponent(redirect_uri)}`) diff --git a/web/app/components/app/app-access-control/access-control-dialog.tsx b/web/app/components/app/app-access-control/access-control-dialog.tsx index 479eedc9cf..ee3fa9650b 100644 --- a/web/app/components/app/app-access-control/access-control-dialog.tsx +++ b/web/app/components/app/app-access-control/access-control-dialog.tsx @@ -22,7 +22,7 @@ const AccessControlDialog = ({ }, [onClose]) return ( - null}> + null}> -
+
diff --git a/web/app/components/app/app-access-control/add-member-or-group-pop.tsx b/web/app/components/app/app-access-control/add-member-or-group-pop.tsx index 0fad6cc740..e9519aeedf 100644 --- a/web/app/components/app/app-access-control/add-member-or-group-pop.tsx +++ b/web/app/components/app/app-access-control/add-member-or-group-pop.tsx @@ -52,7 +52,7 @@ export default function AddMemberOrGroupDialog() { {open && } - +
diff --git a/web/app/components/base/chat/chat-with-history/index.tsx b/web/app/components/base/chat/chat-with-history/index.tsx index 464e30a821..6953be4b3c 100644 --- a/web/app/components/base/chat/chat-with-history/index.tsx +++ b/web/app/components/base/chat/chat-with-history/index.tsx @@ -4,7 +4,6 @@ import { useEffect, useState, } from 'react' -import { useAsyncEffect } from 'ahooks' import { useThemeContext } from '../embedded-chatbot/theme/theme-context' import { ChatWithHistoryContext, @@ -18,8 +17,6 @@ import ChatWrapper from './chat-wrapper' import type { InstalledApp } from '@/models/explore' import Loading from '@/app/components/base/loading' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' -import { checkOrSetAccessToken } from '@/app/components/share/utils' -import AppUnavailable from '@/app/components/base/app-unavailable' import cn from '@/utils/classnames' import useDocumentTitle from '@/hooks/use-document-title' @@ -201,36 +198,6 @@ const ChatWithHistoryWrapWithCheckToken: FC = ({ installedAppInfo, className, }) => { - const [initialized, setInitialized] = useState(false) - const [appUnavailable, setAppUnavailable] = useState(false) - const [isUnknownReason, setIsUnknownReason] = useState(false) - - useAsyncEffect(async () => { - if (!initialized) { - if (!installedAppInfo) { - try { - await checkOrSetAccessToken() - } - catch (e: any) { - if (e.status === 404) { - setAppUnavailable(true) - } - else { - setIsUnknownReason(true) - setAppUnavailable(true) - } - } - } - setInitialized(true) - } - }, []) - - if (!initialized) - return null - - if (appUnavailable) - return - return ( { - await logout({ - url: '/logout', - params: {}, - }) + await logout() localStorage.removeItem('setup_status') - localStorage.removeItem('console_token') - localStorage.removeItem('refresh_token') + // Tokens are now stored in cookies and cleared by backend // To avoid use other account's education notice info localStorage.removeItem('education-reverify-prev-expire-at') diff --git a/web/app/components/rag-pipeline/hooks/use-nodes-sync-draft.ts b/web/app/components/rag-pipeline/hooks/use-nodes-sync-draft.ts index ad757f36a7..51782f3cbf 100644 --- a/web/app/components/rag-pipeline/hooks/use-nodes-sync-draft.ts +++ b/web/app/components/rag-pipeline/hooks/use-nodes-sync-draft.ts @@ -77,7 +77,7 @@ export const useNodesSyncDraft = () => { if (postParams) { navigator.sendBeacon( - `${API_PREFIX}${postParams.url}?_token=${localStorage.getItem('console_token')}`, + `${API_PREFIX}${postParams.url}`, JSON.stringify(postParams.params), ) } diff --git a/web/app/components/share/text-generation/menu-dropdown.tsx b/web/app/components/share/text-generation/menu-dropdown.tsx index 373e3b8699..e3b12b3d84 100644 --- a/web/app/components/share/text-generation/menu-dropdown.tsx +++ b/web/app/components/share/text-generation/menu-dropdown.tsx @@ -20,6 +20,7 @@ import type { SiteInfo } from '@/models/share' import cn from '@/utils/classnames' import { AccessMode } from '@/models/access-control' import { useWebAppStore } from '@/context/web-app-context' +import { webAppLogout } from '@/service/webapp-auth' type Props = { data?: SiteInfo @@ -49,11 +50,11 @@ const MenuDropdown: FC = ({ setOpen(!openRef.current) }, [setOpen]) - const handleLogout = useCallback(() => { - localStorage.removeItem('token') - localStorage.removeItem('webapp_access_token') + const shareCode = useWebAppStore(s => s.shareCode) + const handleLogout = useCallback(async () => { + await webAppLogout(shareCode!) 
router.replace(`/webapp-signin?redirect_url=${pathname}`) - }, [router, pathname]) + }, [router, pathname, webAppLogout, shareCode]) const [show, setShow] = useState(false) diff --git a/web/app/components/share/utils.ts b/web/app/components/share/utils.ts index 3f5303dfcc..491433322d 100644 --- a/web/app/components/share/utils.ts +++ b/web/app/components/share/utils.ts @@ -1,7 +1,3 @@ -import { CONVERSATION_ID_INFO } from '../base/chat/constants' -import { fetchAccessToken } from '@/service/share' -import { getProcessedSystemVariablesFromUrlParams } from '../base/chat/utils' - export const isTokenV1 = (token: Record) => { return !token.version } @@ -9,55 +5,3 @@ export const isTokenV1 = (token: Record) => { export const getInitialTokenV2 = (): Record => ({ version: 2, }) - -export const checkOrSetAccessToken = async (appCode?: string | null) => { - const sharedToken = appCode || globalThis.location.pathname.split('/').slice(-1)[0] - const userId = (await getProcessedSystemVariablesFromUrlParams()).user_id - const accessToken = localStorage.getItem('token') || JSON.stringify(getInitialTokenV2()) - let accessTokenJson = getInitialTokenV2() - try { - accessTokenJson = JSON.parse(accessToken) - if (isTokenV1(accessTokenJson)) - accessTokenJson = getInitialTokenV2() - } - catch { - - } - - if (!accessTokenJson[sharedToken]?.[userId || 'DEFAULT']) { - const webAppAccessToken = localStorage.getItem('webapp_access_token') - const res = await fetchAccessToken({ appCode: sharedToken, userId, webAppAccessToken }) - accessTokenJson[sharedToken] = { - ...accessTokenJson[sharedToken], - [userId || 'DEFAULT']: res.access_token, - } - localStorage.setItem('token', JSON.stringify(accessTokenJson)) - localStorage.removeItem(CONVERSATION_ID_INFO) - } -} - -export const setAccessToken = (sharedToken: string, token: string, user_id?: string) => { - const accessToken = localStorage.getItem('token') || JSON.stringify(getInitialTokenV2()) - let accessTokenJson = getInitialTokenV2() - try { - accessTokenJson = JSON.parse(accessToken) - if (isTokenV1(accessTokenJson)) - accessTokenJson = getInitialTokenV2() - } - catch { - - } - - localStorage.removeItem(CONVERSATION_ID_INFO) - - accessTokenJson[sharedToken] = { - ...accessTokenJson[sharedToken], - [user_id || 'DEFAULT']: token, - } - localStorage.setItem('token', JSON.stringify(accessTokenJson)) -} - -export const removeAccessToken = () => { - localStorage.removeItem('token') - localStorage.removeItem('webapp_access_token') -} diff --git a/web/app/components/swr-initializer.tsx b/web/app/components/swr-initializer.tsx index fd9432fdd8..1ab1567659 100644 --- a/web/app/components/swr-initializer.tsx +++ b/web/app/components/swr-initializer.tsx @@ -19,10 +19,7 @@ const SwrInitializer = ({ }: SwrInitializerProps) => { const router = useRouter() const searchParams = useSearchParams() - const consoleToken = decodeURIComponent(searchParams.get('access_token') || '') - const refreshToken = decodeURIComponent(searchParams.get('refresh_token') || '') - const consoleTokenFromLocalStorage = localStorage?.getItem('console_token') - const refreshTokenFromLocalStorage = localStorage?.getItem('refresh_token') + // Tokens are now stored in cookies, no need to check localStorage const pathname = usePathname() const [init, setInit] = useState(false) @@ -57,21 +54,12 @@ const SwrInitializer = ({ router.replace('/install') return } - if (!((consoleToken && refreshToken) || (consoleTokenFromLocalStorage && refreshTokenFromLocalStorage))) { - router.replace('/signin') - return - } - if 
(searchParams.has('access_token') || searchParams.has('refresh_token')) { - if (consoleToken) - localStorage.setItem('console_token', consoleToken) - if (refreshToken) - localStorage.setItem('refresh_token', refreshToken) - const redirectUrl = resolvePostLoginRedirect(searchParams) - if (redirectUrl) - location.replace(redirectUrl) - else - router.replace(pathname) - } + + const redirectUrl = resolvePostLoginRedirect(searchParams) + if (redirectUrl) + location.replace(redirectUrl) + else + router.replace(pathname) setInit(true) } @@ -79,7 +67,7 @@ const SwrInitializer = ({ router.replace('/signin') } })() - }, [isSetupFinished, router, pathname, searchParams, consoleToken, refreshToken, consoleTokenFromLocalStorage, refreshTokenFromLocalStorage]) + }, [isSetupFinished, router, pathname, searchParams]) return init ? ( diff --git a/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts b/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts index 5705deb0c0..d33bfcc8b8 100644 --- a/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts +++ b/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts @@ -97,7 +97,7 @@ export const useNodesSyncDraft = () => { if (postParams) { navigator.sendBeacon( - `${API_PREFIX}/apps/${params.appId}/workflows/draft?_token=${localStorage.getItem('console_token')}`, + `${API_PREFIX}/apps/${params.appId}/workflows/draft`, JSON.stringify(postParams.params), ) } diff --git a/web/app/education-apply/user-info.tsx b/web/app/education-apply/user-info.tsx index e1d60a5e94..96ff1aaae6 100644 --- a/web/app/education-apply/user-info.tsx +++ b/web/app/education-apply/user-info.tsx @@ -2,24 +2,21 @@ import { useTranslation } from 'react-i18next' import { useRouter } from 'next/navigation' import Button from '@/app/components/base/button' import { useAppContext } from '@/context/app-context' -import { logout } from '@/service/common' import Avatar from '@/app/components/base/avatar' import { Triangle } from '@/app/components/base/icons/src/public/education' +import { useLogout } from '@/service/use-common' const UserInfo = () => { const router = useRouter() const { t } = useTranslation() const { userProfile } = useAppContext() + const { mutateAsync: logout } = useLogout() const handleLogout = async () => { - await logout({ - url: '/logout', - params: {}, - }) + await logout() localStorage.removeItem('setup_status') - localStorage.removeItem('console_token') - localStorage.removeItem('refresh_token') + // Tokens are now stored in cookies and cleared by backend router.push('/signin') } diff --git a/web/app/install/installForm.tsx b/web/app/install/installForm.tsx index 65d1998fcc..0a534b72fe 100644 --- a/web/app/install/installForm.tsx +++ b/web/app/install/installForm.tsx @@ -72,8 +72,6 @@ const InstallForm = () => { // Store tokens and redirect to apps if login successful if (loginRes.result === 'success') { - localStorage.setItem('console_token', loginRes.data.access_token) - localStorage.setItem('refresh_token', loginRes.data.refresh_token) router.replace('/apps') } else { diff --git a/web/app/signin/check-code/page.tsx b/web/app/signin/check-code/page.tsx index 8f12d807db..da6bd426af 100644 --- a/web/app/signin/check-code/page.tsx +++ b/web/app/signin/check-code/page.tsx @@ -42,8 +42,6 @@ export default function CheckCode() { setIsLoading(true) const ret = await emailLoginWithCode({ email, code, token }) if (ret.result === 'success') { - localStorage.setItem('console_token', ret.data.access_token) - localStorage.setItem('refresh_token', 
ret.data.refresh_token) if (invite_token) { router.replace(`/signin/invite-settings?${searchParams.toString()}`) } diff --git a/web/app/signin/components/mail-and-password-auth.tsx b/web/app/signin/components/mail-and-password-auth.tsx index 5214b73ee0..2740a82782 100644 --- a/web/app/signin/components/mail-and-password-auth.tsx +++ b/web/app/signin/components/mail-and-password-auth.tsx @@ -30,6 +30,7 @@ export default function MailAndPasswordAuth({ isInvite, isEmailSetup, allowRegis const [password, setPassword] = useState('') const [isLoading, setIsLoading] = useState(false) + const handleEmailPasswordLogin = async () => { if (!email) { Toast.notify({ type: 'error', message: t('login.error.emailEmpty') }) @@ -66,8 +67,6 @@ export default function MailAndPasswordAuth({ isInvite, isEmailSetup, allowRegis router.replace(`/signin/invite-settings?${searchParams.toString()}`) } else { - localStorage.setItem('console_token', res.data.access_token) - localStorage.setItem('refresh_token', res.data.refresh_token) const redirectUrl = resolvePostLoginRedirect(searchParams) router.replace(redirectUrl || '/apps') } diff --git a/web/app/signin/invite-settings/page.tsx b/web/app/signin/invite-settings/page.tsx index cec51a70ef..cbd37f51f6 100644 --- a/web/app/signin/invite-settings/page.tsx +++ b/web/app/signin/invite-settings/page.tsx @@ -58,8 +58,7 @@ export default function InviteSettingsPage() { }, }) if (res.result === 'success') { - localStorage.setItem('console_token', res.data.access_token) - localStorage.setItem('refresh_token', res.data.refresh_token) + // Tokens are now stored in cookies by the backend await setLocaleOnClient(language, false) const redirectUrl = resolvePostLoginRedirect(searchParams) router.replace(redirectUrl || '/apps') diff --git a/web/app/signin/normal-form.tsx b/web/app/signin/normal-form.tsx index a5a30a0cdd..920a992b4f 100644 --- a/web/app/signin/normal-form.tsx +++ b/web/app/signin/normal-form.tsx @@ -16,16 +16,18 @@ import { IS_CE_EDITION } from '@/config' import { useGlobalPublicStore } from '@/context/global-public-context' import { resolvePostLoginRedirect } from './utils/post-login-redirect' import Split from './split' +import { useIsLogin } from '@/service/use-common' const NormalForm = () => { const { t } = useTranslation() const router = useRouter() const searchParams = useSearchParams() - const consoleToken = decodeURIComponent(searchParams.get('access_token') || '') - const refreshToken = decodeURIComponent(searchParams.get('refresh_token') || '') + const { isLoading: isCheckLoading, data: loginData } = useIsLogin() + const isLoggedIn = loginData?.logged_in const message = decodeURIComponent(searchParams.get('message') || '') const invite_token = decodeURIComponent(searchParams.get('invite_token') || '') - const [isLoading, setIsLoading] = useState(true) + const [isInitCheckLoading, setInitCheckLoading] = useState(true) + const isLoading = isCheckLoading || loginData?.logged_in || isInitCheckLoading const { systemFeatures } = useGlobalPublicStore() const [authType, updateAuthType] = useState<'code' | 'password'>('password') const [showORLine, setShowORLine] = useState(false) @@ -36,9 +38,7 @@ const NormalForm = () => { const init = useCallback(async () => { try { - if (consoleToken && refreshToken) { - localStorage.setItem('console_token', consoleToken) - localStorage.setItem('refresh_token', refreshToken) + if (isLoggedIn) { const redirectUrl = resolvePostLoginRedirect(searchParams) router.replace(redirectUrl || '/apps') return @@ -67,12 +67,12 @@ const 
NormalForm = () => { console.error(error) setAllMethodsAreDisabled(true) } - finally { setIsLoading(false) } - }, [consoleToken, refreshToken, message, router, invite_token, isInviteLink, systemFeatures]) + finally { setInitCheckLoading(false) } + }, [isLoggedIn, message, router, invite_token, isInviteLink, systemFeatures]) useEffect(() => { init() }, [init]) - if (isLoading || consoleToken) { + if (isLoading) { return
{ new_password: password, password_confirm: confirmPassword, }) - const { result, data } = res as MailRegisterResponse + const { result } = res as MailRegisterResponse if (result === 'success') { Toast.notify({ type: 'success', message: t('common.api.actionSuccess'), }) - localStorage.setItem('console_token', data.access_token) - localStorage.setItem('refresh_token', data.refresh_token) router.replace('/apps') } } diff --git a/web/config/index.ts b/web/config/index.ts index f818a1c0af..0e876b800e 100644 --- a/web/config/index.ts +++ b/web/config/index.ts @@ -144,6 +144,17 @@ export const getMaxToken = (modelId: string) => { export const LOCALE_COOKIE_NAME = 'locale' +export const CSRF_COOKIE_NAME = () => { + const isSecure = API_PREFIX.startsWith('https://') + return isSecure ? '__Host-csrf_token' : 'csrf_token' +} +export const CSRF_HEADER_NAME = 'X-CSRF-Token' +export const ACCESS_TOKEN_LOCAL_STORAGE_NAME = 'access_token' +export const PASSPORT_LOCAL_STORAGE_NAME = (appCode: string) => `passport-${appCode}` +export const PASSPORT_HEADER_NAME = 'X-App-Passport' + +export const WEB_APP_SHARE_CODE_HEADER_NAME = 'X-App-Code' + export const DEFAULT_VALUE_MAX_LEN = 48 export const DEFAULT_PARAGRAPH_VALUE_MAX_LEN = 1000 diff --git a/web/context/web-app-context.tsx b/web/context/web-app-context.tsx index 0fe1b56b0a..48de01f2df 100644 --- a/web/context/web-app-context.tsx +++ b/web/context/web-app-context.tsx @@ -2,14 +2,12 @@ import type { ChatConfig } from '@/app/components/base/chat/types' import Loading from '@/app/components/base/loading' -import { checkOrSetAccessToken } from '@/app/components/share/utils' import { AccessMode } from '@/models/access-control' import type { AppData, AppMeta } from '@/models/share' import { useGetWebAppAccessModeByCode } from '@/service/use-share' import { usePathname, useSearchParams } from 'next/navigation' import type { FC, PropsWithChildren } from 'react' import { useEffect } from 'react' -import { useState } from 'react' import { create } from 'zustand' import { useGlobalPublicStore } from './global-public-context' @@ -71,24 +69,13 @@ const WebAppStoreProvider: FC = ({ children }) => { }, [shareCode, updateShareCode]) const { isFetching, data: accessModeResult } = useGetWebAppAccessModeByCode(shareCode) - const [isFetchingAccessToken, setIsFetchingAccessToken] = useState(true) useEffect(() => { - if (accessModeResult?.accessMode) { + if (accessModeResult?.accessMode) updateWebAppAccessMode(accessModeResult.accessMode) - if (accessModeResult.accessMode === AccessMode.PUBLIC) { - setIsFetchingAccessToken(true) - checkOrSetAccessToken(shareCode).finally(() => { - setIsFetchingAccessToken(false) - }) - } - else { - setIsFetchingAccessToken(false) - } - } }, [accessModeResult, updateWebAppAccessMode, shareCode]) - if (isGlobalPending || isFetching || isFetchingAccessToken) { + if (isGlobalPending || isFetching) { return
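`CSRF_COOKIE_NAME` switches to the `__Host-` prefix whenever the API is served over https: browsers only accept a `__Host-` cookie if it is Secure, carries no Domain attribute, and is scoped to `Path=/`, which pins the token to the exact origin. The token itself travels in a readable cookie and is echoed back in the `X-CSRF-Token` header (the double-submit pattern), while the session cookie stays HttpOnly. A sketch of the client half, mirroring the header logic that `web/service/base.ts` and `web/service/fetch.ts` gain below; `withCsrf` is a name invented here:

```ts
import Cookies from 'js-cookie'
import { CSRF_COOKIE_NAME, CSRF_HEADER_NAME } from '@/config'

// Double-submit CSRF: read the non-HttpOnly CSRF cookie and echo it back as a
// header. A cross-site attacker can trigger a cookie-bearing request, but
// cannot read this cookie, so it cannot supply a matching header.
function withCsrf(init: RequestInit = {}): RequestInit {
  const headers = new Headers(init.headers)
  headers.set(CSRF_HEADER_NAME, Cookies.get(CSRF_COOKIE_NAME()) || '')
  return { ...init, headers, credentials: 'include' }
}

// Usage sketch: fetch(`${API_PREFIX}/apps`, withCsrf({ method: 'POST' }))
```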
diff --git a/web/models/app.ts b/web/models/app.ts
index 630dba9c19..26e6cba85b 100644
--- a/web/models/app.ts
+++ b/web/models/app.ts
@@ -2,63 +2,6 @@ import type { AliyunConfig, ArizeConfig, LangFuseConfig, LangSmithConfig, OpikCo
 import type { App, AppMode, AppTemplate, SiteConfig } from '@/types/app'
 import type { Dependency } from '@/app/components/plugins/types'
 
-/* export type App = {
-  id: string
-  name: string
-  description: string
-  mode: AppMode
-  enable_site: boolean
-  enable_api: boolean
-  api_rpm: number
-  api_rph: number
-  is_demo: boolean
-  model_config: AppModelConfig
-  providers: Array<{ provider: string; token_is_set: boolean }>
-  site: SiteConfig
-  created_at: string
-}
-
-export type AppModelConfig = {
-  provider: string
-  model_id: string
-  configs: {
-    prompt_template: string
-    prompt_variables: Array<PromptVariable>
-    completion_params: CompletionParam
-  }
-}
-
-export type PromptVariable = {
-  key: string
-  name: string
-  description: string
-  type: string | number
-  default: string
-  options: string[]
-}
-
-export type CompletionParam = {
-  max_tokens: number
-  temperature: number
-  top_p: number
-  echo: boolean
-  stop: string[]
-  presence_penalty: number
-  frequency_penalty: number
-}
-
-export type SiteConfig = {
-  access_token: string
-  title: string
-  author: string
-  support_email: string
-  default_language: string
-  customize_domain: string
-  theme: string
-  customize_token_strategy: 'must' | 'allow' | 'not_allow'
-  prompt_public: boolean
-} */
-
 export enum DSLImportMode {
   YAML_CONTENT = 'yaml-content',
   YAML_URL = 'yaml-url',
diff --git a/web/service/base.ts b/web/service/base.ts
index 1cb99e38d3..6e54e228e1 100644
--- a/web/service/base.ts
+++ b/web/service/base.ts
@@ -1,4 +1,4 @@
-import { API_PREFIX, IS_CE_EDITION, PUBLIC_API_PREFIX } from '@/config'
+import { API_PREFIX, CSRF_COOKIE_NAME, CSRF_HEADER_NAME, IS_CE_EDITION, PASSPORT_HEADER_NAME, PUBLIC_API_PREFIX, WEB_APP_SHARE_CODE_HEADER_NAME } from '@/config'
 import { refreshAccessTokenOrRelogin } from './refresh-token'
 import Toast from '@/app/components/base/toast'
 import { basePath } from '@/utils/var'
@@ -21,15 +21,16 @@ import type {
   WorkflowFinishedResponse,
   WorkflowStartedResponse,
 } from '@/types/workflow'
-import { removeAccessToken } from '@/app/components/share/utils'
 import type { FetchOptionType, ResponseError } from './fetch'
-import { ContentType, base, getAccessToken, getBaseOptions } from './fetch'
+import { ContentType, base, getBaseOptions } from './fetch'
 import { asyncRunSafe } from '@/utils'
 import type {
   DataSourceNodeCompletedResponse,
   DataSourceNodeErrorResponse,
   DataSourceNodeProcessingResponse,
 } from '@/types/pipeline'
+import Cookies from 'js-cookie'
+import { getWebAppPassport } from './webapp-auth'
 
 const TIME_OUT = 100000
 
 export type IOnDataMoreInfo = {
@@ -122,14 +123,19 @@ function unicodeToChar(text: string) {
   })
 }
 
+const WEB_APP_LOGIN_PATH = '/webapp-signin'
 function requiredWebSSOLogin(message?: string, code?: number) {
   const params = new URLSearchParams()
+  // prevent a redirect loop when we are already on the sign-in page
+  if (globalThis.location.pathname === WEB_APP_LOGIN_PATH)
+    return
+
   params.append('redirect_url', encodeURIComponent(`${globalThis.location.pathname}${globalThis.location.search}`))
   if (message)
     params.append('message', message)
   if (code)
     params.append('code', String(code))
-  globalThis.location.href = `${globalThis.location.origin}${basePath}/webapp-signin?${params.toString()}`
+  globalThis.location.href = `${globalThis.location.origin}${basePath}${WEB_APP_LOGIN_PATH}?${params.toString()}`
 }
 
 export
function format(text: string) { @@ -338,12 +344,14 @@ type UploadResponse = { export const upload = async (options: UploadOptions, isPublicAPI?: boolean, url?: string, searchParams?: string): Promise => { const urlPrefix = isPublicAPI ? PUBLIC_API_PREFIX : API_PREFIX - const token = await getAccessToken(isPublicAPI) + const shareCode = globalThis.location.pathname.split('/').slice(-1)[0] const defaultOptions = { method: 'POST', url: (url ? `${urlPrefix}${url}` : `${urlPrefix}/files/upload`) + (searchParams || ''), headers: { - Authorization: `Bearer ${token}`, + [CSRF_HEADER_NAME]: Cookies.get(CSRF_COOKIE_NAME()) || '', + [PASSPORT_HEADER_NAME]: getWebAppPassport(shareCode), + [WEB_APP_SHARE_CODE_HEADER_NAME]: shareCode, }, } const mergedOptions = { @@ -413,14 +421,17 @@ export const ssePost = async ( } = otherOptions const abortController = new AbortController() - const token = localStorage.getItem('console_token') + // No need to get token from localStorage, cookies will be sent automatically const baseOptions = getBaseOptions() + const shareCode = globalThis.location.pathname.split('/').slice(-1)[0] const options = Object.assign({}, baseOptions, { method: 'POST', signal: abortController.signal, headers: new Headers({ - Authorization: `Bearer ${token}`, + [CSRF_HEADER_NAME]: Cookies.get(CSRF_COOKIE_NAME()) || '', + [WEB_APP_SHARE_CODE_HEADER_NAME]: shareCode, + [PASSPORT_HEADER_NAME]: getWebAppPassport(shareCode), }), } as RequestInit, fetchOptions) @@ -439,9 +450,6 @@ export const ssePost = async ( if (body) options.body = JSON.stringify(body) - const accessToken = await getAccessToken(isPublicAPI) - ; (options.headers as Headers).set('Authorization', `Bearer ${accessToken}`) - globalThis.fetch(urlWithPrefix, options as RequestInit) .then((res) => { if (!/^[23]\d{2}$/.test(String(res.status))) { @@ -452,15 +460,11 @@ export const ssePost = async ( if (data.code === 'web_app_access_denied') requiredWebSSOLogin(data.message, 403) - if (data.code === 'web_sso_auth_required') { - removeAccessToken() + if (data.code === 'web_sso_auth_required') requiredWebSSOLogin() - } - if (data.code === 'unauthorized') { - removeAccessToken() + if (data.code === 'unauthorized') requiredWebSSOLogin() - } } }) } @@ -551,13 +555,11 @@ export const request = async(url: string, options = {}, otherOptions?: IOther return Promise.reject(err) } if (code === 'web_sso_auth_required') { - removeAccessToken() requiredWebSSOLogin() return Promise.reject(err) } if (code === 'unauthorized_and_force_logout') { - localStorage.removeItem('console_token') - localStorage.removeItem('refresh_token') + // Cookies will be cleared by the backend globalThis.location.reload() return Promise.reject(err) } @@ -566,7 +568,6 @@ export const request = async(url: string, options = {}, otherOptions?: IOther silent, } = otherOptionsForBaseFetch if (isPublicAPI && code === 'unauthorized') { - removeAccessToken() requiredWebSSOLogin() return Promise.reject(err) } diff --git a/web/service/common.ts b/web/service/common.ts index d70315f5c6..8f2adc329e 100644 --- a/web/service/common.ts +++ b/web/service/common.ts @@ -40,7 +40,7 @@ import type { SystemFeatures } from '@/types/feature' type LoginSuccess = { result: 'success' - data: { access_token: string; refresh_token: string } + data: { access_token: string } } type LoginFail = { result: 'fail' @@ -56,10 +56,6 @@ export const webAppLogin: Fetcher } -export const fetchNewToken: Fetcher }> = ({ body }) => { - return post('/refresh-token', { body }) as Promise -} - export const setup: Fetcher 
}> = ({ body }) => {
  return post('/setup', { body })
}

@@ -84,10 +80,6 @@ export const updateUserProfile: Fetcher(url, { body })
 }
 
-export const logout: Fetcher<CommonResponse, { url: string; params: Record<string, any> }> = ({ url, params }) => {
-  return get(url, params)
-}
-
 export const fetchLangGeniusVersion: Fetcher<LangGeniusVersionResponse, { url: string; params: Record<string, any> }> = ({ url, params }) => {
   return get(url, { params })
 }
diff --git a/web/service/fetch.ts b/web/service/fetch.ts
index 4e76843ba2..541b1246d4 100644
--- a/web/service/fetch.ts
+++ b/web/service/fetch.ts
@@ -2,9 +2,9 @@ import type { AfterResponseHook, BeforeErrorHook, BeforeRequestHook, Hooks } fro
 import ky from 'ky'
 import type { IOtherOptions } from './base'
 import Toast from '@/app/components/base/toast'
-import { API_PREFIX, APP_VERSION, MARKETPLACE_API_PREFIX, PUBLIC_API_PREFIX } from '@/config'
-import { getInitialTokenV2, isTokenV1 } from '@/app/components/share/utils'
-import { getProcessedSystemVariablesFromUrlParams } from '@/app/components/base/chat/utils'
+import { API_PREFIX, APP_VERSION, CSRF_COOKIE_NAME, CSRF_HEADER_NAME, MARKETPLACE_API_PREFIX, PASSPORT_HEADER_NAME, PUBLIC_API_PREFIX, WEB_APP_SHARE_CODE_HEADER_NAME } from '@/config'
+import Cookies from 'js-cookie'
+import { getWebAppAccessToken, getWebAppPassport } from './webapp-auth'
 
 const TIME_OUT = 100000
 
@@ -69,35 +69,15 @@ const beforeErrorToast = (otherOptions: IOtherOptions): BeforeErrorHook => {
   }
 }
 
-export async function getAccessToken(isPublicAPI?: boolean) {
-  if (isPublicAPI) {
-    const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0]
-    const userId = (await getProcessedSystemVariablesFromUrlParams()).user_id
-    const accessToken = localStorage.getItem('token') || JSON.stringify({ version: 2 })
-    let accessTokenJson: Record<string, any> = { version: 2 }
-    try {
-      accessTokenJson = JSON.parse(accessToken)
-      if (isTokenV1(accessTokenJson))
-        accessTokenJson = getInitialTokenV2()
-    }
-    catch {
-
-    }
-    return accessTokenJson[sharedToken]?.[userId || 'DEFAULT']
-  }
-  else {
-    return localStorage.getItem('console_token') || ''
-  }
-}
-
-const beforeRequestPublicAuthorization: BeforeRequestHook = async (request) => {
-  const token = await getAccessToken(true)
-  request.headers.set('Authorization', `Bearer ${token}`)
-}
-
-const beforeRequestAuthorization: BeforeRequestHook = async (request) => {
-  const accessToken = await getAccessToken()
-  request.headers.set('Authorization', `Bearer ${accessToken}`)
+const beforeRequestPublicWithCode = (request: Request) => {
+  request.headers.set('Authorization', `Bearer ${getWebAppAccessToken()}`)
+  const shareCode = globalThis.location.pathname.split('/').filter(Boolean).pop() || ''
+  // some pages do not end with a share code, so check before attaching it
+  // TODO: maybe find a better way to access the app code?
+  if (shareCode === 'webapp-signin' || shareCode === 'check-code')
+    return
+  request.headers.set(WEB_APP_SHARE_CODE_HEADER_NAME, shareCode)
+  request.headers.set(PASSPORT_HEADER_NAME, getWebAppPassport(shareCode))
 }
 
 const baseHooks: Hooks = {
@@ -148,6 +128,8 @@ async function base(url: string, options: FetchOptionType = {}, otherOptions:
   }
 
   const fetchPathname = base + (url.startsWith('/') ?
url : `/${url}`) + if (!isMarketplaceAPI) + (headers as any).set(CSRF_HEADER_NAME, Cookies.get(CSRF_COOKIE_NAME()) || '') if (deleteContentType) (headers as any).delete('Content-Type') @@ -165,8 +147,7 @@ async function base(url: string, options: FetchOptionType = {}, otherOptions: ], beforeRequest: [ ...baseHooks.beforeRequest || [], - isPublicAPI && beforeRequestPublicAuthorization, - !isPublicAPI && !isMarketplaceAPI && beforeRequestAuthorization, + isPublicAPI && beforeRequestPublicWithCode, ].filter((h): h is BeforeRequestHook => Boolean(h)), afterResponse: [ ...baseHooks.afterResponse || [], diff --git a/web/service/refresh-token.ts b/web/service/refresh-token.ts index 7eff08b52f..3f63f628a1 100644 --- a/web/service/refresh-token.ts +++ b/web/service/refresh-token.ts @@ -39,7 +39,6 @@ async function getNewAccessToken(timeout: number): Promise { globalThis.localStorage.setItem(LOCAL_STORAGE_KEY, '1') globalThis.localStorage.setItem('last_refresh_time', new Date().getTime().toString()) globalThis.addEventListener('beforeunload', releaseRefreshLock) - const refresh_token = globalThis.localStorage.getItem('refresh_token') // Do not use baseFetch to refresh tokens. // If a 401 response occurs and baseFetch itself attempts to refresh the token, @@ -48,10 +47,11 @@ async function getNewAccessToken(timeout: number): Promise { // that does not call baseFetch and uses a single retry mechanism. const [error, ret] = await fetchWithRetry(globalThis.fetch(`${API_PREFIX}/refresh-token`, { method: 'POST', + credentials: 'include', // Important: include cookies in the request headers: { 'Content-Type': 'application/json;utf-8', }, - body: JSON.stringify({ refresh_token }), + // No body needed - refresh token is in cookie })) if (error) { return Promise.reject(error) @@ -59,10 +59,6 @@ async function getNewAccessToken(timeout: number): Promise { else { if (ret.status === 401) return Promise.reject(ret) - - const { data } = await ret.json() - globalThis.localStorage.setItem('console_token', data.access_token) - globalThis.localStorage.setItem('refresh_token', data.refresh_token) } } } diff --git a/web/service/share.ts b/web/service/share.ts index ab8e0deb4a..ce03f508d1 100644 --- a/web/service/share.ts +++ b/web/service/share.ts @@ -34,6 +34,8 @@ import type { } from '@/models/share' import type { ChatConfig } from '@/app/components/base/chat/types' import type { AccessMode } from '@/models/access-control' +import { WEB_APP_SHARE_CODE_HEADER_NAME } from '@/config' +import { getWebAppAccessToken } from './webapp-auth' function getAction(action: 'get' | 'post' | 'del' | 'patch', isInstalledApp: boolean) { switch (action) { @@ -286,16 +288,14 @@ export const textToAudioStream = (url: string, isPublicAPI: boolean, header: { c return (getAction('post', !isPublicAPI))(url, { body, header }, { needAllResponseContent: true }) } -export const fetchAccessToken = async ({ appCode, userId, webAppAccessToken }: { appCode: string, userId?: string, webAppAccessToken?: string | null }) => { +export const fetchAccessToken = async ({ userId, appCode }: { userId?: string, appCode: string }) => { const headers = new Headers() - headers.append('X-App-Code', appCode) + headers.append(WEB_APP_SHARE_CODE_HEADER_NAME, appCode) + headers.append('Authorization', `Bearer ${getWebAppAccessToken()}`) const params = new URLSearchParams() - if (webAppAccessToken) - params.append('web_app_access_token', webAppAccessToken) - if (userId) - params.append('user_id', userId) + userId && params.append('user_id', userId) const url = 
`/passport?${params.toString()}`
-  return get(url, { headers }) as Promise<{ access_token: string }>
+  return get<{ access_token: string }>(url, { headers }) as Promise<{ access_token: string }>
 }
 
 export const getUserCanAccess = (appId: string, isInstalledApp: boolean) => {
diff --git a/web/service/use-common.ts b/web/service/use-common.ts
index 330ee674b0..3e01b721e8 100644
--- a/web/service/use-common.ts
+++ b/web/service/use-common.ts
@@ -50,7 +50,7 @@ export const useMailValidity = () => {
   })
 }
 
-export type MailRegisterResponse = { result: string, data: { access_token: string, refresh_token: string } }
+export type MailRegisterResponse = { result: string, data: {} }
 
 export const useMailRegister = () => {
   return useMutation({
@@ -106,3 +106,23 @@ export const useSchemaTypeDefinitions = () => {
     queryFn: () => get('/spec/schema-definitions'),
   })
 }
+
+type isLogin = {
+  logged_in: boolean
+}
+
+export const useIsLogin = () => {
+  return useQuery({
+    queryKey: [NAME_SPACE, 'is-login'],
+    staleTime: 0,
+    gcTime: 0,
+    queryFn: () => get<isLogin>('/login/status'),
+  })
+}
+
+export const useLogout = () => {
+  return useMutation({
+    mutationKey: [NAME_SPACE, 'logout'],
+    mutationFn: () => post('/logout'),
+  })
+}
diff --git a/web/service/use-share.ts b/web/service/use-share.ts
index 267975fd38..a5e0a11100 100644
--- a/web/service/use-share.ts
+++ b/web/service/use-share.ts
@@ -8,6 +8,8 @@ export const useGetWebAppAccessModeByCode = (code: string | null) => {
     queryKey: [NAME_SPACE, 'appAccessMode', code],
     queryFn: () => getAppAccessModeByAppCode(code!),
     enabled: !!code,
+    staleTime: 0, // the backend may change the access mode at any time, and the /permission API is not cached, so a stale value here would break the logic
+    gcTime: 0,
   })
 }
diff --git a/web/service/webapp-auth.ts b/web/service/webapp-auth.ts
new file mode 100644
index 0000000000..a7ce7153bf
--- /dev/null
+++ b/web/service/webapp-auth.ts
@@ -0,0 +1,53 @@
+import { ACCESS_TOKEN_LOCAL_STORAGE_NAME, PASSPORT_LOCAL_STORAGE_NAME } from '@/config'
+import { getPublic, postPublic } from './base'
+
+export function setWebAppAccessToken(token: string) {
+  localStorage.setItem(ACCESS_TOKEN_LOCAL_STORAGE_NAME, token)
+}
+
+export function setWebAppPassport(shareCode: string, token: string) {
+  localStorage.setItem(PASSPORT_LOCAL_STORAGE_NAME(shareCode), token)
+}
+
+export function getWebAppAccessToken() {
+  return localStorage.getItem(ACCESS_TOKEN_LOCAL_STORAGE_NAME) || ''
+}
+
+export function getWebAppPassport(shareCode: string) {
+  return localStorage.getItem(PASSPORT_LOCAL_STORAGE_NAME(shareCode)) || ''
+}
+
+export function clearWebAppAccessToken() {
+  localStorage.removeItem(ACCESS_TOKEN_LOCAL_STORAGE_NAME)
+}
+
+export function clearWebAppPassport(shareCode: string) {
+  localStorage.removeItem(PASSPORT_LOCAL_STORAGE_NAME(shareCode))
+}
+
+type isWebAppLogin = {
+  logged_in: boolean
+  app_logged_in: boolean
+}
+
+export async function webAppLoginStatus(enabled: boolean, shareCode: string) {
+  if (!enabled) {
+    return {
+      userLoggedIn: true,
+      appLoggedIn: true,
+    }
+  }
+
+  // check remotely; the access token could live in a cookie (enterprise SSO redirects over https)
+  const { logged_in, app_logged_in } = await getPublic<isWebAppLogin>(`/login/status?app_code=${shareCode}`)
+  return {
+    userLoggedIn: logged_in,
+    appLoggedIn: app_logged_in,
+  }
+}
+
+export async function webAppLogout(shareCode: string) {
+  clearWebAppAccessToken()
+  clearWebAppPassport(shareCode)
+  await postPublic('/logout')
+}
From 578247ffbcfef8e69f90ff6942c515f86fd51eda Mon Sep 17 00:00:00 2001
From: -LAN-
Date: Sun, 19 Oct
2025 21:33:41 +0800 Subject: [PATCH 38/46] feat(graph_engine): Support pausing workflow graph executions (#26585) Signed-off-by: -LAN- --- .../app/apps/advanced_chat/app_generator.py | 13 +- api/core/app/apps/advanced_chat/app_runner.py | 50 +- .../advanced_chat/generate_task_pipeline.py | 246 ++++----- api/core/app/apps/base_app_queue_manager.py | 12 + .../common/graph_runtime_state_support.py | 55 ++ .../common/workflow_response_converter.py | 379 +++++++++----- .../app/apps/pipeline/pipeline_generator.py | 13 +- api/core/app/apps/pipeline/pipeline_runner.py | 28 +- api/core/app/apps/workflow/app_generator.py | 12 +- api/core/app/apps/workflow/app_runner.py | 49 +- .../apps/workflow/generate_task_pipeline.py | 247 ++++----- api/core/app/apps/workflow_app_runner.py | 11 +- api/core/app/entities/queue_entities.py | 61 +-- api/core/app/entities/task_entities.py | 32 -- api/core/prompt/advanced_prompt_transform.py | 2 +- ...hemy_workflow_node_execution_repository.py | 16 +- api/core/tools/tool_manager.py | 2 +- api/core/workflow/entities/__init__.py | 7 - .../workflow/entities/graph_runtime_state.py | 160 ------ api/core/workflow/entities/run_condition.py | 21 - api/core/workflow/enums.py | 2 + api/core/workflow/graph/__init__.py | 9 +- api/core/workflow/graph/graph.py | 99 ++++ .../command_channels/redis_channel.py | 10 +- .../command_processing/__init__.py | 3 +- .../command_processing/command_handlers.py | 24 +- .../graph_engine/domain/graph_execution.py | 26 +- .../graph_engine/entities/commands.py | 8 +- .../event_management/event_handlers.py | 17 +- .../event_management/event_manager.py | 4 + .../workflow/graph_engine/graph_engine.py | 143 +++--- api/core/workflow/graph_engine/layers/base.py | 2 +- .../graph_engine/layers/persistence.py | 410 +++++++++++++++ api/core/workflow/graph_engine/manager.py | 26 +- .../graph_engine/orchestration/dispatcher.py | 56 ++- .../orchestration/execution_coordinator.py | 43 +- .../response_coordinator/coordinator.py | 2 +- api/core/workflow/graph_events/__init__.py | 4 + api/core/workflow/graph_events/graph.py | 29 +- api/core/workflow/graph_events/node.py | 4 + api/core/workflow/node_events/__init__.py | 2 + api/core/workflow/node_events/node.py | 4 + api/core/workflow/nodes/agent/agent_node.py | 2 +- api/core/workflow/nodes/base/node.py | 15 +- .../nodes/datasource/datasource_node.py | 2 +- .../workflow/nodes/http_request/executor.py | 2 +- .../workflow/nodes/human_input/__init__.py | 3 + .../workflow/nodes/human_input/entities.py | 10 + .../nodes/human_input/human_input_node.py | 132 +++++ .../workflow/nodes/if_else/if_else_node.py | 2 +- .../nodes/iteration/iteration_node.py | 5 +- .../knowledge_index/knowledge_index_node.py | 2 +- .../knowledge_retrieval_node.py | 2 +- api/core/workflow/nodes/llm/llm_utils.py | 2 +- api/core/workflow/nodes/llm/node.py | 5 +- api/core/workflow/nodes/loop/loop_node.py | 3 +- api/core/workflow/nodes/node_factory.py | 3 +- api/core/workflow/nodes/node_mapping.py | 5 + .../parameter_extractor_node.py | 2 +- .../question_classifier_node.py | 2 +- api/core/workflow/nodes/tool/tool_node.py | 2 +- .../nodes/variable_assigner/v1/node.py | 2 +- api/core/workflow/runtime/__init__.py | 14 + .../workflow/runtime/graph_runtime_state.py | 393 +++++++++++++++ .../graph_runtime_state_protocol.py | 18 + .../read_only_wrappers.py} | 53 +- .../{entities => runtime}/variable_pool.py | 15 + .../workflow/utils/condition/processor.py | 2 +- api/core/workflow/variable_loader.py | 2 +- api/core/workflow/workflow_cycle_manager.py | 459 
----------------- api/core/workflow/workflow_entry.py | 3 +- api/services/rag_pipeline/rag_pipeline.py | 2 +- api/services/workflow_service.py | 3 +- .../workflow/nodes/test_code.py | 3 +- .../workflow/nodes/test_http.py | 5 +- .../workflow/nodes/test_llm.py | 3 +- .../nodes/test_parameter_extractor.py | 3 +- .../workflow/nodes/test_template_transform.py | 3 +- .../workflow/nodes/test_tool.py | 3 +- .../test_app_runner_conversation_variables.py | 6 + .../test_graph_runtime_state_support.py | 63 +++ ...orkflow_response_converter_process_data.py | 420 ++++++---------- .../unit_tests/core/variables/test_segment.py | 2 +- .../entities/test_graph_runtime_state.py | 144 +++++- .../workflow/entities/test_variable_pool.py | 2 +- .../core/workflow/graph/test_graph_builder.py | 59 +++ .../core/workflow/graph_engine/README.md | 46 -- .../event_management/test_event_handlers.py | 2 +- .../graph_engine/test_command_system.py | 63 ++- .../workflow/graph_engine/test_dispatcher.py | 7 +- .../test_execution_coordinator.py | 62 +++ .../test_human_input_pause_multi_branch.py | 341 +++++++++++++ .../test_human_input_pause_single_branch.py | 297 +++++++++++ .../graph_engine/test_if_else_streaming.py | 321 ++++++++++++ .../graph_engine/test_mock_factory.py | 3 +- .../test_mock_iteration_simple.py | 6 +- .../workflow/graph_engine/test_mock_nodes.py | 9 +- .../test_mock_nodes_template_code.py | 40 +- .../test_parallel_streaming_workflow.py | 3 +- .../test_redis_stop_integration.py | 68 ++- .../graph_engine/test_table_runner.py | 183 +++---- .../core/workflow/nodes/answer/test_answer.py | 3 +- .../test_http_request_executor.py | 2 +- .../core/workflow/nodes/llm/test_node.py | 3 +- .../core/workflow/nodes/test_if_else.py | 3 +- .../v1/test_variable_assigner_v1.py | 3 +- .../v2/test_variable_assigner_v2.py | 3 +- .../core/workflow/test_variable_pool.py | 2 +- .../workflow/test_workflow_cycle_manager.py | 476 ------------------ .../core/workflow/test_workflow_entry.py | 2 +- .../test_workflow_entry_redis_channel.py | 2 +- dev/start-worker | 5 +- 112 files changed, 3766 insertions(+), 2415 deletions(-) create mode 100644 api/core/app/apps/common/graph_runtime_state_support.py delete mode 100644 api/core/workflow/entities/graph_runtime_state.py delete mode 100644 api/core/workflow/entities/run_condition.py create mode 100644 api/core/workflow/graph_engine/layers/persistence.py create mode 100644 api/core/workflow/nodes/human_input/__init__.py create mode 100644 api/core/workflow/nodes/human_input/entities.py create mode 100644 api/core/workflow/nodes/human_input/human_input_node.py create mode 100644 api/core/workflow/runtime/__init__.py create mode 100644 api/core/workflow/runtime/graph_runtime_state.py rename api/core/workflow/{graph => runtime}/graph_runtime_state_protocol.py (76%) rename api/core/workflow/{graph/read_only_state_wrapper.py => runtime/read_only_wrappers.py} (54%) rename api/core/workflow/{entities => runtime}/variable_pool.py (95%) delete mode 100644 api/core/workflow/workflow_cycle_manager.py create mode 100644 api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py create mode 100644 api/tests/unit_tests/core/workflow/graph/test_graph_builder.py create mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py create mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py create mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py create mode 100644 
api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py delete mode 100644 api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index b6234491c5..feb0d3358c 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -447,6 +447,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): "message_id": message.id, "context": context, "variable_loader": variable_loader, + "workflow_execution_repository": workflow_execution_repository, + "workflow_node_execution_repository": workflow_node_execution_repository, }, ) @@ -466,8 +468,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation=conversation, message=message, user=user, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, stream=stream, draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user), ) @@ -483,6 +483,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): message_id: str, context: contextvars.Context, variable_loader: VariableLoader, + workflow_execution_repository: WorkflowExecutionRepository, + workflow_node_execution_repository: WorkflowNodeExecutionRepository, ): """ Generate worker in a new thread. @@ -538,6 +540,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): workflow=workflow, system_user_id=system_user_id, app=app, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, ) try: @@ -570,8 +574,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation: Conversation, message: Message, user: Union[Account, EndUser], - workflow_execution_repository: WorkflowExecutionRepository, - workflow_node_execution_repository: WorkflowNodeExecutionRepository, draft_var_saver_factory: DraftVariableSaverFactory, stream: bool = False, ) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: @@ -584,7 +586,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): :param message: message :param user: account or end user :param stream: is stream - :param workflow_node_execution_repository: optional repository for workflow node execution :return: """ # init generate task pipeline @@ -596,8 +597,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): message=message, user=user, dialogue_count=self._dialogue_count, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, stream=stream, draft_var_saver_factory=draft_var_saver_factory, ) diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 919b135ec9..587c663482 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -23,8 +23,12 @@ from core.app.features.annotation_reply.annotation_reply import AnnotationReplyF from core.moderation.base import ModerationError from core.moderation.input_moderation import InputModeration from core.variables.variables import VariableUnion -from core.workflow.entities import GraphRuntimeState, VariablePool +from core.workflow.enums import WorkflowType from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel +from 
core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry @@ -55,6 +59,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): workflow: Workflow, system_user_id: str, app: App, + workflow_execution_repository: WorkflowExecutionRepository, + workflow_node_execution_repository: WorkflowNodeExecutionRepository, ): super().__init__( queue_manager=queue_manager, @@ -68,11 +74,24 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): self._workflow = workflow self.system_user_id = system_user_id self._app = app + self._workflow_execution_repository = workflow_execution_repository + self._workflow_node_execution_repository = workflow_node_execution_repository def run(self): app_config = self.application_generate_entity.app_config app_config = cast(AdvancedChatAppConfig, app_config) + system_inputs = SystemVariable( + query=self.application_generate_entity.query, + files=self.application_generate_entity.files, + conversation_id=self.conversation.id, + user_id=self.system_user_id, + dialogue_count=self._dialogue_count, + app_id=app_config.app_id, + workflow_id=app_config.workflow_id, + workflow_execution_id=self.application_generate_entity.workflow_run_id, + ) + with Session(db.engine, expire_on_commit=False) as session: app_record = session.scalar(select(App).where(App.id == app_config.app_id)) @@ -89,7 +108,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): else: inputs = self.application_generate_entity.inputs query = self.application_generate_entity.query - files = self.application_generate_entity.files # moderation if self.handle_input_moderation( @@ -114,17 +132,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): conversation_variables = self._initialize_conversation_variables() # Create a variable pool. 
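Note on the hoist above: building system_inputs at the top of run() makes the SystemVariable the single source of truth shared by the variable pool, the graph runtime state attached to the queue manager, and the response converter, so the run id can later be recovered from the pool instead of from a persisted workflow_execution row; the block removed just below is the old late construction. A minimal sketch of that flow; the constructor arguments are simplified assumptions and the literal ids are illustrative, not taken from this patch:

    import time

    from core.workflow.runtime import GraphRuntimeState, VariablePool
    from core.workflow.system_variable import SystemVariable

    # Illustrative values; the real runner fills these from the application
    # generate entity and the app config.
    system_inputs = SystemVariable(
        user_id="user-1",
        app_id="app-1",
        workflow_id="workflow-1",
        workflow_execution_id="run-1",
    )
    variable_pool = VariablePool(system_variables=system_inputs, user_inputs={})
    state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

    # Downstream consumers (see GraphRuntimeStateSupport later in this patch)
    # recover the run id from the pool rather than from the database:
    assert str(state.variable_pool.system_variables.workflow_execution_id) == "run-1"
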
- system_inputs = SystemVariable( - query=query, - files=files, - conversation_id=self.conversation.id, - user_id=self.system_user_id, - dialogue_count=self._dialogue_count, - app_id=app_config.app_id, - workflow_id=app_config.workflow_id, - workflow_execution_id=self.application_generate_entity.workflow_run_id, - ) - # init variable pool variable_pool = VariablePool( system_variables=system_inputs, @@ -172,6 +179,23 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): command_channel=command_channel, ) + self._queue_manager.graph_runtime_state = graph_runtime_state + + persistence_layer = WorkflowPersistenceLayer( + application_generate_entity=self.application_generate_entity, + workflow_info=PersistenceWorkflowInfo( + workflow_id=self._workflow.id, + workflow_type=WorkflowType(self._workflow.type), + version=self._workflow.version, + graph_data=self._workflow.graph_dict, + ), + workflow_execution_repository=self._workflow_execution_repository, + workflow_node_execution_repository=self._workflow_node_execution_repository, + trace_manager=self.application_generate_entity.trace_manager, + ) + + workflow_entry.graph_engine.layer(persistence_layer) + generator = workflow_entry.run() for event in generator: diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index b5af6382e8..8c0102d9bd 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -11,6 +11,7 @@ from sqlalchemy.orm import Session from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, @@ -60,14 +61,11 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.model_runtime.entities.llm_entities import LLMUsage from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities import GraphRuntimeState -from core.workflow.enums import WorkflowExecutionStatus, WorkflowType +from core.workflow.enums import WorkflowExecutionStatus from core.workflow.nodes import NodeType from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository -from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.runtime import GraphRuntimeState from core.workflow.system_variable import SystemVariable -from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models import Account, Conversation, EndUser, Message, MessageFile @@ -77,7 +75,7 @@ from models.workflow import Workflow logger = logging.getLogger(__name__) -class AdvancedChatAppGenerateTaskPipeline: +class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): """ AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. 
""" @@ -92,8 +90,6 @@ class AdvancedChatAppGenerateTaskPipeline: user: Union[Account, EndUser], stream: bool, dialogue_count: int, - workflow_execution_repository: WorkflowExecutionRepository, - workflow_node_execution_repository: WorkflowNodeExecutionRepository, draft_var_saver_factory: DraftVariableSaverFactory, ): self._base_task_pipeline = BasedGenerateTaskPipeline( @@ -113,31 +109,20 @@ class AdvancedChatAppGenerateTaskPipeline: else: raise NotImplementedError(f"User type not supported: {type(user)}") - self._workflow_cycle_manager = WorkflowCycleManager( - application_generate_entity=application_generate_entity, - workflow_system_variables=SystemVariable( - query=message.query, - files=application_generate_entity.files, - conversation_id=conversation.id, - user_id=user_session_id, - dialogue_count=dialogue_count, - app_id=application_generate_entity.app_config.app_id, - workflow_id=workflow.id, - workflow_execution_id=application_generate_entity.workflow_run_id, - ), - workflow_info=CycleManagerWorkflowInfo( - workflow_id=workflow.id, - workflow_type=WorkflowType(workflow.type), - version=workflow.version, - graph_data=workflow.graph_dict, - ), - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, + self._workflow_system_variables = SystemVariable( + query=message.query, + files=application_generate_entity.files, + conversation_id=conversation.id, + user_id=user_session_id, + dialogue_count=dialogue_count, + app_id=application_generate_entity.app_config.app_id, + workflow_id=workflow.id, + workflow_execution_id=application_generate_entity.workflow_run_id, ) - self._workflow_response_converter = WorkflowResponseConverter( application_generate_entity=application_generate_entity, user=user, + system_variables=self._workflow_system_variables, ) self._task_state = WorkflowTaskState() @@ -156,6 +141,8 @@ class AdvancedChatAppGenerateTaskPipeline: self._recorded_files: list[Mapping[str, Any]] = [] self._workflow_run_id: str = "" self._draft_var_saver_factory = draft_var_saver_factory + self._graph_runtime_state: GraphRuntimeState | None = None + self._seed_graph_runtime_state_from_queue_manager() def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: """ @@ -288,12 +275,6 @@ class AdvancedChatAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - def _ensure_graph_runtime_initialized(self, graph_runtime_state: GraphRuntimeState | None) -> GraphRuntimeState: - """Fluent validation for graph runtime state.""" - if not graph_runtime_state: - raise ValueError("graph runtime state not initialized.") - return graph_runtime_state - def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]: """Handle ping events.""" yield self._base_task_pipeline.ping_stream_response() @@ -304,21 +285,28 @@ class AdvancedChatAppGenerateTaskPipeline: err = self._base_task_pipeline.handle_error(event=event, session=session, message_id=self._message_id) yield self._base_task_pipeline.error_to_stream_response(err) - def _handle_workflow_started_event(self, *args, **kwargs) -> Generator[StreamResponse, None, None]: + def _handle_workflow_started_event( + self, + event: QueueWorkflowStartedEvent, + **kwargs, + ) -> Generator[StreamResponse, None, None]: """Handle workflow started events.""" - with self._database_session() as session: - workflow_execution = 
self._workflow_cycle_manager.handle_workflow_run_start() - self._workflow_run_id = workflow_execution.id_ + runtime_state = self._resolve_graph_runtime_state() + run_id = self._extract_workflow_run_id(runtime_state) + self._workflow_run_id = run_id + with self._database_session() as session: message = self._get_message(session=session) if not message: raise ValueError(f"Message not found: {self._message_id}") - message.workflow_run_id = workflow_execution.id_ - workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution=workflow_execution, - ) + message.workflow_run_id = run_id + + workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run_id=run_id, + workflow_id=self._workflow_id, + ) yield workflow_start_resp @@ -326,13 +314,9 @@ class AdvancedChatAppGenerateTaskPipeline: """Handle node retry events.""" self._ensure_workflow_initialized() - workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried( - workflow_execution_id=self._workflow_run_id, event=event - ) node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, ) if node_retry_resp: @@ -344,14 +328,9 @@ class AdvancedChatAppGenerateTaskPipeline: """Handle node started events.""" self._ensure_workflow_initialized() - workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start( - workflow_execution_id=self._workflow_run_id, event=event - ) - node_start_resp = self._workflow_response_converter.workflow_node_start_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, ) if node_start_resp: @@ -367,14 +346,12 @@ class AdvancedChatAppGenerateTaskPipeline: self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {}) ) - workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event) node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, ) - self._save_output_for_event(event, workflow_node_execution.id) + self._save_output_for_event(event, event.node_execution_id) if node_finish_resp: yield node_finish_resp @@ -385,16 +362,13 @@ class AdvancedChatAppGenerateTaskPipeline: **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle various node failure events.""" - workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(event=event) - node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, ) if isinstance(event, QueueNodeExceptionEvent): - self._save_output_for_event(event, workflow_node_execution.id) + self._save_output_for_event(event, event.node_execution_id) if node_finish_resp: yield node_finish_resp @@ -504,29 +478,19 @@ class AdvancedChatAppGenerateTaskPipeline: self, event: QueueWorkflowSucceededEvent, *, - graph_runtime_state: GraphRuntimeState | None = None, trace_manager: TraceQueueManager | None = None, **kwargs, ) -> 
Generator[StreamResponse, None, None]: """Handle workflow succeeded events.""" + _ = trace_manager self._ensure_workflow_initialized() - validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state) - - with self._database_session() as session: - workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success( - workflow_run_id=self._workflow_run_id, - total_tokens=validated_state.total_tokens, - total_steps=validated_state.node_run_steps, - outputs=event.outputs, - conversation_id=self._conversation_id, - trace_manager=trace_manager, - external_trace_id=self._application_generate_entity.extras.get("external_trace_id"), - ) - workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_execution=workflow_execution, - ) + validated_state = self._ensure_graph_runtime_initialized() + workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_id=self._workflow_id, + status=WorkflowExecutionStatus.SUCCEEDED, + graph_runtime_state=validated_state, + ) yield workflow_finish_resp self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) @@ -535,30 +499,20 @@ class AdvancedChatAppGenerateTaskPipeline: self, event: QueueWorkflowPartialSuccessEvent, *, - graph_runtime_state: GraphRuntimeState | None = None, trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle workflow partial success events.""" + _ = trace_manager self._ensure_workflow_initialized() - validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state) - - with self._database_session() as session: - workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success( - workflow_run_id=self._workflow_run_id, - total_tokens=validated_state.total_tokens, - total_steps=validated_state.node_run_steps, - outputs=event.outputs, - exceptions_count=event.exceptions_count, - conversation_id=self._conversation_id, - trace_manager=trace_manager, - external_trace_id=self._application_generate_entity.extras.get("external_trace_id"), - ) - workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_execution=workflow_execution, - ) + validated_state = self._ensure_graph_runtime_initialized() + workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_id=self._workflow_id, + status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED, + graph_runtime_state=validated_state, + exceptions_count=event.exceptions_count, + ) yield workflow_finish_resp self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) @@ -567,32 +521,25 @@ class AdvancedChatAppGenerateTaskPipeline: self, event: QueueWorkflowFailedEvent, *, - graph_runtime_state: GraphRuntimeState | None = None, trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle workflow failed events.""" + _ = trace_manager self._ensure_workflow_initialized() - validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state) + validated_state = self._ensure_graph_runtime_initialized() + + workflow_finish_resp = 
self._workflow_response_converter.workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_id=self._workflow_id, + status=WorkflowExecutionStatus.FAILED, + graph_runtime_state=validated_state, + error=event.error, + exceptions_count=event.exceptions_count, + ) with self._database_session() as session: - workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed( - workflow_run_id=self._workflow_run_id, - total_tokens=validated_state.total_tokens, - total_steps=validated_state.node_run_steps, - status=WorkflowExecutionStatus.FAILED, - error_message=event.error, - conversation_id=self._conversation_id, - trace_manager=trace_manager, - exceptions_count=event.exceptions_count, - external_trace_id=self._application_generate_entity.extras.get("external_trace_id"), - ) - workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_execution=workflow_execution, - ) - err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}")) + err_event = QueueErrorEvent(error=ValueError(f"Run failed: {event.error}")) err = self._base_task_pipeline.handle_error(event=err_event, session=session, message_id=self._message_id) yield workflow_finish_resp @@ -607,25 +554,23 @@ class AdvancedChatAppGenerateTaskPipeline: **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle stop events.""" - if self._workflow_run_id and graph_runtime_state: + _ = trace_manager + resolved_state = None + if self._workflow_run_id: + resolved_state = self._resolve_graph_runtime_state(graph_runtime_state) + + if self._workflow_run_id and resolved_state: + workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_id=self._workflow_id, + status=WorkflowExecutionStatus.STOPPED, + graph_runtime_state=resolved_state, + error=event.get_stop_reason(), + ) + with self._database_session() as session: - workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed( - workflow_run_id=self._workflow_run_id, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - status=WorkflowExecutionStatus.STOPPED, - error_message=event.get_stop_reason(), - conversation_id=self._conversation_id, - trace_manager=trace_manager, - external_trace_id=self._application_generate_entity.extras.get("external_trace_id"), - ) - workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_execution=workflow_execution, - ) # Save message - self._save_message(session=session, graph_runtime_state=graph_runtime_state) + self._save_message(session=session, graph_runtime_state=resolved_state) yield workflow_finish_resp elif event.stopped_by in ( @@ -647,7 +592,7 @@ class AdvancedChatAppGenerateTaskPipeline: **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle advanced chat message end events.""" - self._ensure_graph_runtime_initialized(graph_runtime_state) + resolved_state = self._ensure_graph_runtime_initialized(graph_runtime_state) output_moderation_answer = self._base_task_pipeline.handle_output_moderation_when_task_finished( self._task_state.answer @@ -661,7 +606,7 @@ class AdvancedChatAppGenerateTaskPipeline: # Save message with self._database_session() as session: - 
self._save_message(session=session, graph_runtime_state=graph_runtime_state) + self._save_message(session=session, graph_runtime_state=resolved_state) yield self._message_end_to_stream_response() @@ -670,10 +615,6 @@ class AdvancedChatAppGenerateTaskPipeline: ) -> Generator[StreamResponse, None, None]: """Handle retriever resources events.""" self._message_cycle_manager.handle_retriever_resources(event) - - with self._database_session() as session: - message = self._get_message(session=session) - message.message_metadata = self._task_state.metadata.model_dump_json() return yield # Make this a generator @@ -682,10 +623,6 @@ class AdvancedChatAppGenerateTaskPipeline: ) -> Generator[StreamResponse, None, None]: """Handle annotation reply events.""" self._message_cycle_manager.handle_annotation_reply(event) - - with self._database_session() as session: - message = self._get_message(session=session) - message.message_metadata = self._task_state.metadata.model_dump_json() return yield # Make this a generator @@ -739,7 +676,6 @@ class AdvancedChatAppGenerateTaskPipeline: self, event: Any, *, - graph_runtime_state: GraphRuntimeState | None = None, tts_publisher: AppGeneratorTTSPublisher | None = None, trace_manager: TraceQueueManager | None = None, queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None, @@ -752,7 +688,6 @@ class AdvancedChatAppGenerateTaskPipeline: if handler := handlers.get(event_type): yield from handler( event, - graph_runtime_state=graph_runtime_state, tts_publisher=tts_publisher, trace_manager=trace_manager, queue_message=queue_message, @@ -769,7 +704,6 @@ class AdvancedChatAppGenerateTaskPipeline: ): yield from self._handle_node_failed_events( event, - graph_runtime_state=graph_runtime_state, tts_publisher=tts_publisher, trace_manager=trace_manager, queue_message=queue_message, @@ -788,15 +722,12 @@ class AdvancedChatAppGenerateTaskPipeline: Process stream response using elegant Fluent Python patterns. Maintains exact same functionality as original 57-if-statement version. 
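        Handlers now resolve the runtime state on demand via
        GraphRuntimeStateSupport instead of receiving it as a keyword argument.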
""" - # Initialize graph runtime state - graph_runtime_state: GraphRuntimeState | None = None - for queue_message in self._base_task_pipeline.queue_manager.listen(): event = queue_message.event match event: case QueueWorkflowStartedEvent(): - graph_runtime_state = event.graph_runtime_state + self._resolve_graph_runtime_state() yield from self._handle_workflow_started_event(event) case QueueErrorEvent(): @@ -804,15 +735,11 @@ class AdvancedChatAppGenerateTaskPipeline: break case QueueWorkflowFailedEvent(): - yield from self._handle_workflow_failed_event( - event, graph_runtime_state=graph_runtime_state, trace_manager=trace_manager - ) + yield from self._handle_workflow_failed_event(event, trace_manager=trace_manager) break case QueueStopEvent(): - yield from self._handle_stop_event( - event, graph_runtime_state=graph_runtime_state, trace_manager=trace_manager - ) + yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager) break # Handle all other events through elegant dispatch @@ -820,7 +747,6 @@ class AdvancedChatAppGenerateTaskPipeline: if responses := list( self._dispatch_event( event, - graph_runtime_state=graph_runtime_state, tts_publisher=tts_publisher, trace_manager=trace_manager, queue_message=queue_message, @@ -878,6 +804,12 @@ class AdvancedChatAppGenerateTaskPipeline: else: self._task_state.metadata.usage = LLMUsage.empty_usage() + def _seed_graph_runtime_state_from_queue_manager(self) -> None: + """Bootstrap the cached runtime state from the queue manager when present.""" + candidate = self._base_task_pipeline.queue_manager.graph_runtime_state + if candidate is not None: + self._graph_runtime_state = candidate + def _message_end_to_stream_response(self) -> MessageEndStreamResponse: """ Message end to stream response. 
diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 074555e31b..698eee9894 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -20,6 +20,7 @@ from core.app.entities.queue_entities import ( QueueStopEvent, WorkflowQueueMessage, ) +from core.workflow.runtime import GraphRuntimeState from extensions.ext_redis import redis_client logger = logging.getLogger(__name__) @@ -47,6 +48,7 @@ class AppQueueManager: q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue() self._q = q + self._graph_runtime_state: GraphRuntimeState | None = None self._stopped_cache: TTLCache[tuple, bool] = TTLCache(maxsize=1, ttl=1) self._cache_lock = threading.Lock() @@ -109,6 +111,16 @@ class AppQueueManager: """ self.publish(QueueErrorEvent(error=e), pub_from) + @property + def graph_runtime_state(self) -> GraphRuntimeState | None: + """Retrieve the attached graph runtime state, if available.""" + return self._graph_runtime_state + + @graph_runtime_state.setter + def graph_runtime_state(self, graph_runtime_state: GraphRuntimeState | None) -> None: + """Attach the live graph runtime state reference for downstream consumers.""" + self._graph_runtime_state = graph_runtime_state + def publish(self, event: AppQueueEvent, pub_from: PublishFrom): """ Publish event to queue diff --git a/api/core/app/apps/common/graph_runtime_state_support.py b/api/core/app/apps/common/graph_runtime_state_support.py new file mode 100644 index 0000000000..0b03149665 --- /dev/null +++ b/api/core/app/apps/common/graph_runtime_state_support.py @@ -0,0 +1,55 @@ +"""Shared helpers for managing GraphRuntimeState across task pipelines.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from core.workflow.runtime import GraphRuntimeState + +if TYPE_CHECKING: + from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline + + +class GraphRuntimeStateSupport: + """ + Mixin that centralises common GraphRuntimeState access patterns used by task pipelines. + + Subclasses are expected to provide: + * `_base_task_pipeline` – exposing the queue manager with an optional cached runtime state. + * `_graph_runtime_state` attribute used as the local cache for the runtime state. 
+ """ + + _base_task_pipeline: BasedGenerateTaskPipeline + _graph_runtime_state: GraphRuntimeState | None = None + + def _ensure_graph_runtime_initialized( + self, + graph_runtime_state: GraphRuntimeState | None = None, + ) -> GraphRuntimeState: + """Validate and return the active graph runtime state.""" + return self._resolve_graph_runtime_state(graph_runtime_state) + + def _extract_workflow_run_id(self, graph_runtime_state: GraphRuntimeState) -> str: + system_variables = graph_runtime_state.variable_pool.system_variables + if not system_variables or not system_variables.workflow_execution_id: + raise ValueError("workflow_execution_id missing from runtime state") + return str(system_variables.workflow_execution_id) + + def _resolve_graph_runtime_state( + self, + graph_runtime_state: GraphRuntimeState | None = None, + ) -> GraphRuntimeState: + """Return the cached runtime state or bootstrap it from the queue manager.""" + if graph_runtime_state is not None: + self._graph_runtime_state = graph_runtime_state + return graph_runtime_state + + if self._graph_runtime_state is None: + candidate = self._base_task_pipeline.queue_manager.graph_runtime_state + if candidate is not None: + self._graph_runtime_state = candidate + + if self._graph_runtime_state is None: + raise ValueError("graph runtime state not initialized.") + + return self._graph_runtime_state diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 7c7a4fd6ac..2c9ce5b56d 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -1,9 +1,8 @@ import time from collections.abc import Mapping, Sequence -from datetime import UTC, datetime -from typing import Any, Union - -from sqlalchemy.orm import Session +from dataclasses import dataclass +from datetime import datetime +from typing import Any, NewType, Union from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( @@ -39,16 +38,36 @@ from core.plugin.impl.datasource import PluginDatasourceManager from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager from core.variables.segments import ArrayFileSegment, FileSegment, Segment -from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus +from core.workflow.enums import ( + NodeType, + SystemVariableKey, + WorkflowExecutionStatus, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from core.workflow.runtime import GraphRuntimeState +from core.workflow.system_variable import SystemVariable +from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.datetime_utils import naive_utc_now -from models import ( - Account, - EndUser, -) +from models import Account, EndUser from services.variable_truncator import VariableTruncator +NodeExecutionId = NewType("NodeExecutionId", str) + + +@dataclass(slots=True) +class _NodeSnapshot: + """In-memory cache for node metadata between start and completion events.""" + + title: str + index: int + start_at: datetime + iteration_id: str = "" + """Empty string means the node is not executing inside an iteration.""" + loop_id: str = "" + """Empty string means the node is not executing inside a loop.""" + class 
WorkflowResponseConverter: def __init__( @@ -56,37 +75,151 @@ class WorkflowResponseConverter: *, application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], user: Union[Account, EndUser], + system_variables: SystemVariable, ): self._application_generate_entity = application_generate_entity self._user = user + self._system_variables = system_variables + self._workflow_inputs = self._prepare_workflow_inputs() self._truncator = VariableTruncator.default() + self._node_snapshots: dict[NodeExecutionId, _NodeSnapshot] = {} + self._workflow_execution_id: str | None = None + self._workflow_started_at: datetime | None = None + + # ------------------------------------------------------------------ + # Workflow lifecycle helpers + # ------------------------------------------------------------------ + def _prepare_workflow_inputs(self) -> Mapping[str, Any]: + inputs = dict(self._application_generate_entity.inputs) + for field_name, value in self._system_variables.to_dict().items(): + # TODO(@future-refactor): store system variables separately from user inputs so we don't + # need to flatten `sys.*` entries into the input payload just for rerun/export tooling. + if field_name == SystemVariableKey.CONVERSATION_ID: + # Conversation IDs are session-scoped; omitting them keeps workflow inputs + # reusable without pinning new runs to a prior conversation. + continue + inputs[f"sys.{field_name}"] = value + handled = WorkflowEntry.handle_special_values(inputs) + return dict(handled or {}) + + def _ensure_workflow_run_id(self, workflow_run_id: str | None = None) -> str: + """Return the memoized workflow run id, optionally seeding it during start events.""" + if workflow_run_id is not None: + self._workflow_execution_id = workflow_run_id + if not self._workflow_execution_id: + raise ValueError("workflow_run_id missing before streaming workflow events") + return self._workflow_execution_id + + # ------------------------------------------------------------------ + # Node snapshot helpers + # ------------------------------------------------------------------ + def _store_snapshot(self, event: QueueNodeStartedEvent) -> _NodeSnapshot: + snapshot = _NodeSnapshot( + title=event.node_title, + index=event.node_run_index, + start_at=event.start_at, + iteration_id=event.in_iteration_id or "", + loop_id=event.in_loop_id or "", + ) + node_execution_id = NodeExecutionId(event.node_execution_id) + self._node_snapshots[node_execution_id] = snapshot + return snapshot + + def _get_snapshot(self, node_execution_id: str) -> _NodeSnapshot | None: + return self._node_snapshots.get(NodeExecutionId(node_execution_id)) + + def _pop_snapshot(self, node_execution_id: str) -> _NodeSnapshot | None: + return self._node_snapshots.pop(NodeExecutionId(node_execution_id), None) + + @staticmethod + def _merge_metadata( + base_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None, + snapshot: _NodeSnapshot | None, + ) -> Mapping[WorkflowNodeExecutionMetadataKey, Any] | None: + if not base_metadata and not snapshot: + return base_metadata + + merged: dict[WorkflowNodeExecutionMetadataKey, Any] = {} + if base_metadata: + merged.update(base_metadata) + + if snapshot: + if snapshot.iteration_id: + merged[WorkflowNodeExecutionMetadataKey.ITERATION_ID] = snapshot.iteration_id + if snapshot.loop_id: + merged[WorkflowNodeExecutionMetadataKey.LOOP_ID] = snapshot.loop_id + + return merged or None + + def _truncate_mapping( + self, + mapping: Mapping[str, Any] | None, + ) -> tuple[Mapping[str, Any] | None, bool]: 
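        # The body below first normalizes special values (file segments and
        # similar) through WorkflowEntry.handle_special_values, then hands the
        # result to VariableTruncator.truncate_variable_mapping; the boolean it
        # returns drives the *_truncated fields on the node stream responses.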
+ if mapping is None: + return None, False + if not mapping: + return {}, False + + normalized = WorkflowEntry.handle_special_values(dict(mapping)) + if normalized is None: + return None, False + + truncated, is_truncated = self._truncator.truncate_variable_mapping(dict(normalized)) + return truncated, is_truncated + + @staticmethod + def _encode_outputs(outputs: Mapping[str, Any] | None) -> Mapping[str, Any] | None: + if outputs is None: + return None + converter = WorkflowRuntimeTypeConverter() + return converter.to_json_encodable(outputs) def workflow_start_to_stream_response( self, *, task_id: str, - workflow_execution: WorkflowExecution, + workflow_run_id: str, + workflow_id: str, ) -> WorkflowStartStreamResponse: + run_id = self._ensure_workflow_run_id(workflow_run_id) + started_at = naive_utc_now() + self._workflow_started_at = started_at + return WorkflowStartStreamResponse( task_id=task_id, - workflow_run_id=workflow_execution.id_, + workflow_run_id=run_id, data=WorkflowStartStreamResponse.Data( - id=workflow_execution.id_, - workflow_id=workflow_execution.workflow_id, - inputs=workflow_execution.inputs, - created_at=int(workflow_execution.started_at.timestamp()), + id=run_id, + workflow_id=workflow_id, + inputs=self._workflow_inputs, + created_at=int(started_at.timestamp()), ), ) def workflow_finish_to_stream_response( self, *, - session: Session, task_id: str, - workflow_execution: WorkflowExecution, + workflow_id: str, + status: WorkflowExecutionStatus, + graph_runtime_state: GraphRuntimeState, + error: str | None = None, + exceptions_count: int = 0, ) -> WorkflowFinishStreamResponse: - created_by = None + run_id = self._ensure_workflow_run_id() + started_at = self._workflow_started_at + if started_at is None: + raise ValueError( + "workflow_finish_to_stream_response called before workflow_start_to_stream_response", + ) + finished_at = naive_utc_now() + elapsed_time = (finished_at - started_at).total_seconds() + + outputs_mapping = graph_runtime_state.outputs or {} + encoded_outputs = WorkflowRuntimeTypeConverter().to_json_encodable(outputs_mapping) + + created_by: Mapping[str, object] | None user = self._user if isinstance(user, Account): created_by = { @@ -94,38 +227,29 @@ class WorkflowResponseConverter: "name": user.name, "email": user.email, } - elif isinstance(user, EndUser): + else: created_by = { "id": user.id, "user": user.session_id, } - else: - raise NotImplementedError(f"User type not supported: {type(user)}") - - # Handle the case where finished_at is None by using current time as default - finished_at_timestamp = ( - int(workflow_execution.finished_at.timestamp()) - if workflow_execution.finished_at - else int(datetime.now(UTC).timestamp()) - ) return WorkflowFinishStreamResponse( task_id=task_id, - workflow_run_id=workflow_execution.id_, + workflow_run_id=run_id, data=WorkflowFinishStreamResponse.Data( - id=workflow_execution.id_, - workflow_id=workflow_execution.workflow_id, - status=workflow_execution.status, - outputs=WorkflowRuntimeTypeConverter().to_json_encodable(workflow_execution.outputs), - error=workflow_execution.error_message, - elapsed_time=workflow_execution.elapsed_time, - total_tokens=workflow_execution.total_tokens, - total_steps=workflow_execution.total_steps, + id=run_id, + workflow_id=workflow_id, + status=status.value, + outputs=encoded_outputs, + error=error, + elapsed_time=elapsed_time, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, created_by=created_by, - 
created_at=int(workflow_execution.started_at.timestamp()), - finished_at=finished_at_timestamp, - files=self.fetch_files_from_node_outputs(workflow_execution.outputs), - exceptions_count=workflow_execution.exceptions_count, + created_at=int(started_at.timestamp()), + finished_at=int(finished_at.timestamp()), + files=self.fetch_files_from_node_outputs(outputs_mapping), + exceptions_count=exceptions_count, ), ) @@ -134,38 +258,28 @@ class WorkflowResponseConverter: *, event: QueueNodeStartedEvent, task_id: str, - workflow_node_execution: WorkflowNodeExecution, ) -> NodeStartStreamResponse | None: - if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: - return None - if not workflow_node_execution.workflow_execution_id: + if event.node_type in {NodeType.ITERATION, NodeType.LOOP}: return None + run_id = self._ensure_workflow_run_id() + snapshot = self._store_snapshot(event) response = NodeStartStreamResponse( task_id=task_id, - workflow_run_id=workflow_node_execution.workflow_execution_id, + workflow_run_id=run_id, data=NodeStartStreamResponse.Data( - id=workflow_node_execution.id, - node_id=workflow_node_execution.node_id, - node_type=workflow_node_execution.node_type, - title=workflow_node_execution.title, - index=workflow_node_execution.index, - predecessor_node_id=workflow_node_execution.predecessor_node_id, - inputs=workflow_node_execution.get_response_inputs(), - inputs_truncated=workflow_node_execution.inputs_truncated, - created_at=int(workflow_node_execution.created_at.timestamp()), - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, - parent_parallel_id=event.parent_parallel_id, - parent_parallel_start_node_id=event.parent_parallel_start_node_id, + id=event.node_execution_id, + node_id=event.node_id, + node_type=event.node_type, + title=snapshot.title, + index=snapshot.index, + created_at=int(snapshot.start_at.timestamp()), iteration_id=event.in_iteration_id, loop_id=event.in_loop_id, - parallel_run_id=event.parallel_mode_run_id, agent_strategy=event.agent_strategy, ), ) - # extras logic if event.node_type == NodeType.TOOL: response.data.extras["icon"] = ToolManager.get_tool_icon( tenant_id=self._application_generate_entity.app_config.tenant_id, @@ -189,41 +303,54 @@ class WorkflowResponseConverter: *, event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent, task_id: str, - workflow_node_execution: WorkflowNodeExecution, ) -> NodeFinishStreamResponse | None: - if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: - return None - if not workflow_node_execution.workflow_execution_id: - return None - if not workflow_node_execution.finished_at: + if event.node_type in {NodeType.ITERATION, NodeType.LOOP}: return None + run_id = self._ensure_workflow_run_id() + snapshot = self._pop_snapshot(event.node_execution_id) - json_converter = WorkflowRuntimeTypeConverter() + start_at = snapshot.start_at if snapshot else event.start_at + finished_at = naive_utc_now() + elapsed_time = (finished_at - start_at).total_seconds() + + inputs, inputs_truncated = self._truncate_mapping(event.inputs) + process_data, process_data_truncated = self._truncate_mapping(event.process_data) + encoded_outputs = self._encode_outputs(event.outputs) + outputs, outputs_truncated = self._truncate_mapping(encoded_outputs) + metadata = self._merge_metadata(event.execution_metadata, snapshot) + + if isinstance(event, QueueNodeSucceededEvent): + status = WorkflowNodeExecutionStatus.SUCCEEDED.value + error_message = 
event.error + elif isinstance(event, QueueNodeFailedEvent): + status = WorkflowNodeExecutionStatus.FAILED.value + error_message = event.error + else: + status = WorkflowNodeExecutionStatus.EXCEPTION.value + error_message = event.error return NodeFinishStreamResponse( task_id=task_id, - workflow_run_id=workflow_node_execution.workflow_execution_id, + workflow_run_id=run_id, data=NodeFinishStreamResponse.Data( - id=workflow_node_execution.id, - node_id=workflow_node_execution.node_id, - node_type=workflow_node_execution.node_type, - index=workflow_node_execution.index, - title=workflow_node_execution.title, - predecessor_node_id=workflow_node_execution.predecessor_node_id, - inputs=workflow_node_execution.get_response_inputs(), - inputs_truncated=workflow_node_execution.inputs_truncated, - process_data=workflow_node_execution.get_response_process_data(), - process_data_truncated=workflow_node_execution.process_data_truncated, - outputs=json_converter.to_json_encodable(workflow_node_execution.get_response_outputs()), - outputs_truncated=workflow_node_execution.outputs_truncated, - status=workflow_node_execution.status, - error=workflow_node_execution.error, - elapsed_time=workflow_node_execution.elapsed_time, - execution_metadata=workflow_node_execution.metadata, - created_at=int(workflow_node_execution.created_at.timestamp()), - finished_at=int(workflow_node_execution.finished_at.timestamp()), - files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}), - parallel_id=event.parallel_id, + id=event.node_execution_id, + node_id=event.node_id, + node_type=event.node_type, + index=snapshot.index if snapshot else 0, + title=snapshot.title if snapshot else "", + inputs=inputs, + inputs_truncated=inputs_truncated, + process_data=process_data, + process_data_truncated=process_data_truncated, + outputs=outputs, + outputs_truncated=outputs_truncated, + status=status, + error=error_message, + elapsed_time=elapsed_time, + execution_metadata=metadata, + created_at=int(start_at.timestamp()), + finished_at=int(finished_at.timestamp()), + files=self.fetch_files_from_node_outputs(event.outputs or {}), iteration_id=event.in_iteration_id, loop_id=event.in_loop_id, ), @@ -234,44 +361,45 @@ class WorkflowResponseConverter: *, event: QueueNodeRetryEvent, task_id: str, - workflow_node_execution: WorkflowNodeExecution, - ) -> Union[NodeRetryStreamResponse, NodeFinishStreamResponse] | None: - if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: - return None - if not workflow_node_execution.workflow_execution_id: - return None - if not workflow_node_execution.finished_at: + ) -> NodeRetryStreamResponse | None: + if event.node_type in {NodeType.ITERATION, NodeType.LOOP}: return None + run_id = self._ensure_workflow_run_id() - json_converter = WorkflowRuntimeTypeConverter() + snapshot = self._get_snapshot(event.node_execution_id) + if snapshot is None: + raise AssertionError("node retry event arrived without a stored snapshot") + finished_at = naive_utc_now() + elapsed_time = (finished_at - event.start_at).total_seconds() + + inputs, inputs_truncated = self._truncate_mapping(event.inputs) + process_data, process_data_truncated = self._truncate_mapping(event.process_data) + encoded_outputs = self._encode_outputs(event.outputs) + outputs, outputs_truncated = self._truncate_mapping(encoded_outputs) + metadata = self._merge_metadata(event.execution_metadata, snapshot) return NodeRetryStreamResponse( task_id=task_id, - 
workflow_run_id=workflow_node_execution.workflow_execution_id, + workflow_run_id=run_id, data=NodeRetryStreamResponse.Data( - id=workflow_node_execution.id, - node_id=workflow_node_execution.node_id, - node_type=workflow_node_execution.node_type, - index=workflow_node_execution.index, - title=workflow_node_execution.title, - predecessor_node_id=workflow_node_execution.predecessor_node_id, - inputs=workflow_node_execution.get_response_inputs(), - inputs_truncated=workflow_node_execution.inputs_truncated, - process_data=workflow_node_execution.get_response_process_data(), - process_data_truncated=workflow_node_execution.process_data_truncated, - outputs=json_converter.to_json_encodable(workflow_node_execution.get_response_outputs()), - outputs_truncated=workflow_node_execution.outputs_truncated, - status=workflow_node_execution.status, - error=workflow_node_execution.error, - elapsed_time=workflow_node_execution.elapsed_time, - execution_metadata=workflow_node_execution.metadata, - created_at=int(workflow_node_execution.created_at.timestamp()), - finished_at=int(workflow_node_execution.finished_at.timestamp()), - files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}), - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, - parent_parallel_id=event.parent_parallel_id, - parent_parallel_start_node_id=event.parent_parallel_start_node_id, + id=event.node_execution_id, + node_id=event.node_id, + node_type=event.node_type, + index=snapshot.index, + title=snapshot.title, + inputs=inputs, + inputs_truncated=inputs_truncated, + process_data=process_data, + process_data_truncated=process_data_truncated, + outputs=outputs, + outputs_truncated=outputs_truncated, + status=WorkflowNodeExecutionStatus.RETRY.value, + error=event.error, + elapsed_time=elapsed_time, + execution_metadata=metadata, + created_at=int(snapshot.start_at.timestamp()), + finished_at=int(finished_at.timestamp()), + files=self.fetch_files_from_node_outputs(event.outputs or {}), iteration_id=event.in_iteration_id, loop_id=event.in_loop_id, retry_index=event.retry_index, @@ -379,8 +507,6 @@ class WorkflowResponseConverter: inputs=new_inputs, inputs_truncated=truncated, metadata=event.metadata or {}, - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, ), ) @@ -405,9 +531,6 @@ class WorkflowResponseConverter: pre_loop_output={}, created_at=int(time.time()), extras={}, - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, - parallel_mode_run_id=event.parallel_mode_run_id, ), ) @@ -446,8 +569,6 @@ class WorkflowResponseConverter: execution_metadata=event.metadata, finished_at=int(time.time()), steps=event.steps, - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, ), ) diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index bd077c4cb8..1fb076b685 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -352,6 +352,8 @@ class PipelineGenerator(BaseAppGenerator): "application_generate_entity": application_generate_entity, "workflow_thread_pool_id": workflow_thread_pool_id, "variable_loader": variable_loader, + "workflow_execution_repository": workflow_execution_repository, + "workflow_node_execution_repository": workflow_node_execution_repository, }, ) @@ -367,8 +369,6 @@ class PipelineGenerator(BaseAppGenerator): workflow=workflow, 
queue_manager=queue_manager, user=user, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, stream=streaming, draft_var_saver_factory=draft_var_saver_factory, ) @@ -573,6 +573,8 @@ class PipelineGenerator(BaseAppGenerator): queue_manager: AppQueueManager, context: contextvars.Context, variable_loader: VariableLoader, + workflow_execution_repository: WorkflowExecutionRepository, + workflow_node_execution_repository: WorkflowNodeExecutionRepository, workflow_thread_pool_id: str | None = None, ) -> None: """ @@ -620,6 +622,8 @@ class PipelineGenerator(BaseAppGenerator): variable_loader=variable_loader, workflow=workflow, system_user_id=system_user_id, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, ) runner.run() @@ -648,8 +652,6 @@ class PipelineGenerator(BaseAppGenerator): workflow: Workflow, queue_manager: AppQueueManager, user: Union[Account, EndUser], - workflow_execution_repository: WorkflowExecutionRepository, - workflow_node_execution_repository: WorkflowNodeExecutionRepository, draft_var_saver_factory: DraftVariableSaverFactory, stream: bool = False, ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: @@ -660,7 +662,6 @@ class PipelineGenerator(BaseAppGenerator): :param queue_manager: queue manager :param user: account or end user :param stream: is stream - :param workflow_node_execution_repository: optional repository for workflow node execution :return: """ # init generate task pipeline @@ -670,8 +671,6 @@ class PipelineGenerator(BaseAppGenerator): queue_manager=queue_manager, user=user, stream=stream, - workflow_node_execution_repository=workflow_node_execution_repository, - workflow_execution_repository=workflow_execution_repository, draft_var_saver_factory=draft_var_saver_factory, ) diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index a8a7dde2b4..4be9e01fbf 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -11,11 +11,14 @@ from core.app.entities.app_invoke_entities import ( ) from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput from core.workflow.entities.graph_init_params import GraphInitParams -from core.workflow.entities.graph_runtime_state import GraphRuntimeState -from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import WorkflowType from core.workflow.graph import Graph +from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent from core.workflow.nodes.node_factory import DifyNodeFactory +from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry @@ -40,6 +43,8 @@ class PipelineRunner(WorkflowBasedAppRunner): variable_loader: VariableLoader, workflow: Workflow, system_user_id: str, + workflow_execution_repository: WorkflowExecutionRepository, + workflow_node_execution_repository: 
WorkflowNodeExecutionRepository, workflow_thread_pool_id: str | None = None, ) -> None: """ @@ -56,6 +61,8 @@ class PipelineRunner(WorkflowBasedAppRunner): self.workflow_thread_pool_id = workflow_thread_pool_id self._workflow = workflow self._sys_user_id = system_user_id + self._workflow_execution_repository = workflow_execution_repository + self._workflow_node_execution_repository = workflow_node_execution_repository def _get_app_id(self) -> str: return self.application_generate_entity.app_config.app_id @@ -163,6 +170,23 @@ class PipelineRunner(WorkflowBasedAppRunner): variable_pool=variable_pool, ) + self._queue_manager.graph_runtime_state = graph_runtime_state + + persistence_layer = WorkflowPersistenceLayer( + application_generate_entity=self.application_generate_entity, + workflow_info=PersistenceWorkflowInfo( + workflow_id=workflow.id, + workflow_type=WorkflowType(workflow.type), + version=workflow.version, + graph_data=workflow.graph_dict, + ), + workflow_execution_repository=self._workflow_execution_repository, + workflow_node_execution_repository=self._workflow_node_execution_repository, + trace_manager=self.application_generate_entity.trace_manager, + ) + + workflow_entry.graph_engine.layer(persistence_layer) + generator = workflow_entry.run() for event in generator: diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 45d047434b..f22ef5431e 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -231,6 +231,8 @@ class WorkflowAppGenerator(BaseAppGenerator): "queue_manager": queue_manager, "context": context, "variable_loader": variable_loader, + "workflow_execution_repository": workflow_execution_repository, + "workflow_node_execution_repository": workflow_node_execution_repository, }, ) @@ -244,8 +246,6 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow=workflow, queue_manager=queue_manager, user=user, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, draft_var_saver_factory=draft_var_saver_factory, stream=streaming, ) @@ -424,6 +424,8 @@ class WorkflowAppGenerator(BaseAppGenerator): queue_manager: AppQueueManager, context: contextvars.Context, variable_loader: VariableLoader, + workflow_execution_repository: WorkflowExecutionRepository, + workflow_node_execution_repository: WorkflowNodeExecutionRepository, ) -> None: """ Generate worker in a new thread. 
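The same wiring recurs in the advanced chat and pipeline runners above and in the workflow app runner just below: each receives both repositories, wraps them in a WorkflowPersistenceLayer, and attaches that layer to the graph engine so persistence happens inside the engine rather than in the task pipeline. A sketch of the shared pattern; the free parameters mirror the names each runner already binds:

    from core.workflow.enums import WorkflowType
    from core.workflow.graph_engine.layers.persistence import (
        PersistenceWorkflowInfo,
        WorkflowPersistenceLayer,
    )

    def attach_persistence_layer(workflow_entry, workflow, generate_entity,
                                 execution_repo, node_execution_repo) -> None:
        """Sketch of the wiring each runner performs before workflow_entry.run()."""
        layer = WorkflowPersistenceLayer(
            application_generate_entity=generate_entity,
            workflow_info=PersistenceWorkflowInfo(
                workflow_id=workflow.id,
                workflow_type=WorkflowType(workflow.type),
                version=workflow.version,
                graph_data=workflow.graph_dict,
            ),
            workflow_execution_repository=execution_repo,
            workflow_node_execution_repository=node_execution_repo,
            trace_manager=generate_entity.trace_manager,
        )
        # Persistence becomes a graph-engine layer, so the pipelines only
        # stream responses and no longer write workflow run rows themselves.
        workflow_entry.graph_engine.layer(layer)
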
@@ -465,6 +467,8 @@ class WorkflowAppGenerator(BaseAppGenerator): variable_loader=variable_loader, workflow=workflow, system_user_id=system_user_id, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, ) try: @@ -493,8 +497,6 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow: Workflow, queue_manager: AppQueueManager, user: Union[Account, EndUser], - workflow_execution_repository: WorkflowExecutionRepository, - workflow_node_execution_repository: WorkflowNodeExecutionRepository, draft_var_saver_factory: DraftVariableSaverFactory, stream: bool = False, ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: @@ -514,8 +516,6 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow=workflow, queue_manager=queue_manager, user=user, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, draft_var_saver_factory=draft_var_saver_factory, stream=stream, ) diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 943ae8ab4e..3c9bf176b5 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -5,12 +5,13 @@ from typing import cast from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfig from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner -from core.app.entities.app_invoke_entities import ( - InvokeFrom, - WorkflowAppGenerateEntity, -) -from core.workflow.entities import GraphRuntimeState, VariablePool +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.workflow.enums import WorkflowType from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel +from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry @@ -34,6 +35,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): variable_loader: VariableLoader, workflow: Workflow, system_user_id: str, + workflow_execution_repository: WorkflowExecutionRepository, + workflow_node_execution_repository: WorkflowNodeExecutionRepository, ): super().__init__( queue_manager=queue_manager, @@ -43,6 +46,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): self.application_generate_entity = application_generate_entity self._workflow = workflow self._sys_user_id = system_user_id + self._workflow_execution_repository = workflow_execution_repository + self._workflow_node_execution_repository = workflow_node_execution_repository def run(self): """ @@ -51,6 +56,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): app_config = self.application_generate_entity.app_config app_config = cast(WorkflowAppConfig, app_config) + system_inputs = SystemVariable( + files=self.application_generate_entity.files, + user_id=self._sys_user_id, + app_id=app_config.app_id, + workflow_id=app_config.workflow_id, + 
workflow_execution_id=self.application_generate_entity.workflow_execution_id, + ) + # if only single iteration or single loop run is requested if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run: graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution( @@ -60,18 +73,9 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): ) else: inputs = self.application_generate_entity.inputs - files = self.application_generate_entity.files # Create a variable pool. - system_inputs = SystemVariable( - files=files, - user_id=self._sys_user_id, - app_id=app_config.app_id, - workflow_id=app_config.workflow_id, - workflow_execution_id=self.application_generate_entity.workflow_execution_id, - ) - variable_pool = VariablePool( system_variables=system_inputs, user_inputs=inputs, @@ -96,6 +100,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): channel_key = f"workflow:{task_id}:commands" command_channel = RedisChannel(redis_client, channel_key) + self._queue_manager.graph_runtime_state = graph_runtime_state + workflow_entry = WorkflowEntry( tenant_id=self._workflow.tenant_id, app_id=self._workflow.app_id, @@ -115,6 +121,21 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): command_channel=command_channel, ) + persistence_layer = WorkflowPersistenceLayer( + application_generate_entity=self.application_generate_entity, + workflow_info=PersistenceWorkflowInfo( + workflow_id=self._workflow.id, + workflow_type=WorkflowType(self._workflow.type), + version=self._workflow.version, + graph_data=self._workflow.graph_dict, + ), + workflow_execution_repository=self._workflow_execution_repository, + workflow_node_execution_repository=self._workflow_node_execution_repository, + trace_manager=self.application_generate_entity.trace_manager, + ) + + workflow_entry.graph_engine.layer(persistence_layer) + generator = workflow_entry.run() for event in generator: diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index ec4dc87643..08e2fce48c 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -8,11 +8,9 @@ from sqlalchemy.orm import Session from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter -from core.app.entities.app_invoke_entities import ( - InvokeFrom, - WorkflowAppGenerateEntity, -) +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( AppQueueEvent, MessageQueueMessage, @@ -53,27 +51,20 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities import GraphRuntimeState, WorkflowExecution -from core.workflow.enums import WorkflowExecutionStatus, WorkflowType +from core.workflow.enums import WorkflowExecutionStatus from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository -from 
core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.runtime import GraphRuntimeState from core.workflow.system_variable import SystemVariable -from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager from extensions.ext_database import db from models import Account from models.enums import CreatorUserRole from models.model import EndUser -from models.workflow import ( - Workflow, - WorkflowAppLog, - WorkflowAppLogCreatedFrom, -) +from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom logger = logging.getLogger(__name__) -class WorkflowAppGenerateTaskPipeline: +class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): """ WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application. """ @@ -85,8 +76,6 @@ class WorkflowAppGenerateTaskPipeline: queue_manager: AppQueueManager, user: Union[Account, EndUser], stream: bool, - workflow_execution_repository: WorkflowExecutionRepository, - workflow_node_execution_repository: WorkflowNodeExecutionRepository, draft_var_saver_factory: DraftVariableSaverFactory, ): self._base_task_pipeline = BasedGenerateTaskPipeline( @@ -99,42 +88,30 @@ class WorkflowAppGenerateTaskPipeline: self._user_id = user.id user_session_id = user.session_id self._created_by_role = CreatorUserRole.END_USER - elif isinstance(user, Account): + else: self._user_id = user.id user_session_id = user.id self._created_by_role = CreatorUserRole.ACCOUNT - else: - raise ValueError(f"Invalid user type: {type(user)}") - - self._workflow_cycle_manager = WorkflowCycleManager( - application_generate_entity=application_generate_entity, - workflow_system_variables=SystemVariable( - files=application_generate_entity.files, - user_id=user_session_id, - app_id=application_generate_entity.app_config.app_id, - workflow_id=workflow.id, - workflow_execution_id=application_generate_entity.workflow_execution_id, - ), - workflow_info=CycleManagerWorkflowInfo( - workflow_id=workflow.id, - workflow_type=WorkflowType(workflow.type), - version=workflow.version, - graph_data=workflow.graph_dict, - ), - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, - ) - - self._workflow_response_converter = WorkflowResponseConverter( - application_generate_entity=application_generate_entity, - user=user, - ) self._application_generate_entity = application_generate_entity self._workflow_features_dict = workflow.features_dict - self._workflow_run_id = "" + self._workflow_execution_id = "" self._invoke_from = queue_manager.invoke_from self._draft_var_saver_factory = draft_var_saver_factory + self._workflow = workflow + self._workflow_system_variables = SystemVariable( + files=application_generate_entity.files, + user_id=user_session_id, + app_id=application_generate_entity.app_config.app_id, + workflow_id=workflow.id, + workflow_execution_id=application_generate_entity.workflow_execution_id, + ) + self._workflow_response_converter = WorkflowResponseConverter( + application_generate_entity=application_generate_entity, + user=user, + system_variables=self._workflow_system_variables, + ) + self._graph_runtime_state: GraphRuntimeState | None = self._base_task_pipeline.queue_manager.graph_runtime_state def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: """ @@ -261,15 +238,9 @@ class 
WorkflowAppGenerateTaskPipeline: def _ensure_workflow_initialized(self): """Fluent validation for workflow state.""" - if not self._workflow_run_id: + if not self._workflow_execution_id: raise ValueError("workflow run not initialized.") - def _ensure_graph_runtime_initialized(self, graph_runtime_state: GraphRuntimeState | None) -> GraphRuntimeState: - """Fluent validation for graph runtime state.""" - if not graph_runtime_state: - raise ValueError("graph runtime state not initialized.") - return graph_runtime_state - def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]: """Handle ping events.""" yield self._base_task_pipeline.ping_stream_response() @@ -283,12 +254,14 @@ class WorkflowAppGenerateTaskPipeline: self, event: QueueWorkflowStartedEvent, **kwargs ) -> Generator[StreamResponse, None, None]: """Handle workflow started events.""" - # init workflow run - workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start() - self._workflow_run_id = workflow_execution.id_ + runtime_state = self._resolve_graph_runtime_state() + + run_id = self._extract_workflow_run_id(runtime_state) + self._workflow_execution_id = run_id start_resp = self._workflow_response_converter.workflow_start_to_stream_response( task_id=self._application_generate_entity.task_id, - workflow_execution=workflow_execution, + workflow_run_id=run_id, + workflow_id=self._workflow.id, ) yield start_resp @@ -296,14 +269,9 @@ class WorkflowAppGenerateTaskPipeline: """Handle node retry events.""" self._ensure_workflow_initialized() - workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried( - workflow_execution_id=self._workflow_run_id, - event=event, - ) response = self._workflow_response_converter.workflow_node_retry_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, ) if response: @@ -315,13 +283,9 @@ class WorkflowAppGenerateTaskPipeline: """Handle node started events.""" self._ensure_workflow_initialized() - workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start( - workflow_execution_id=self._workflow_run_id, event=event - ) node_start_response = self._workflow_response_converter.workflow_node_start_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, ) if node_start_response: @@ -331,14 +295,12 @@ class WorkflowAppGenerateTaskPipeline: self, event: QueueNodeSucceededEvent, **kwargs ) -> Generator[StreamResponse, None, None]: """Handle node succeeded events.""" - workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event) node_success_response = self._workflow_response_converter.workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, ) - self._save_output_for_event(event, workflow_node_execution.id) + self._save_output_for_event(event, event.node_execution_id) if node_success_response: yield node_success_response @@ -349,17 +311,13 @@ class WorkflowAppGenerateTaskPipeline: **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle various node failure events.""" - workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed( - event=event, - ) node_failed_response = self._workflow_response_converter.workflow_node_finish_to_stream_response( 
event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, ) if isinstance(event, QueueNodeExceptionEvent): - self._save_output_for_event(event, workflow_node_execution.id) + self._save_output_for_event(event, event.node_execution_id) if node_failed_response: yield node_failed_response @@ -372,7 +330,7 @@ class WorkflowAppGenerateTaskPipeline: iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response( task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, + workflow_execution_id=self._workflow_execution_id, event=event, ) yield iter_start_resp @@ -385,7 +343,7 @@ class WorkflowAppGenerateTaskPipeline: iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response( task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, + workflow_execution_id=self._workflow_execution_id, event=event, ) yield iter_next_resp @@ -398,7 +356,7 @@ class WorkflowAppGenerateTaskPipeline: iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response( task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, + workflow_execution_id=self._workflow_execution_id, event=event, ) yield iter_finish_resp @@ -409,7 +367,7 @@ class WorkflowAppGenerateTaskPipeline: loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response( task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, + workflow_execution_id=self._workflow_execution_id, event=event, ) yield loop_start_resp @@ -420,7 +378,7 @@ class WorkflowAppGenerateTaskPipeline: loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response( task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, + workflow_execution_id=self._workflow_execution_id, event=event, ) yield loop_next_resp @@ -433,7 +391,7 @@ class WorkflowAppGenerateTaskPipeline: loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response( task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, + workflow_execution_id=self._workflow_execution_id, event=event, ) yield loop_finish_resp @@ -442,33 +400,22 @@ class WorkflowAppGenerateTaskPipeline: self, event: QueueWorkflowSucceededEvent, *, - graph_runtime_state: GraphRuntimeState | None = None, trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle workflow succeeded events.""" + _ = trace_manager self._ensure_workflow_initialized() - validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state) + validated_state = self._ensure_graph_runtime_initialized() + workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_id=self._workflow.id, + status=WorkflowExecutionStatus.SUCCEEDED, + graph_runtime_state=validated_state, + ) with self._database_session() as session: - workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success( - workflow_run_id=self._workflow_run_id, - total_tokens=validated_state.total_tokens, - total_steps=validated_state.node_run_steps, - outputs=event.outputs, - conversation_id=None, - trace_manager=trace_manager, - 
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"), - ) - - # save workflow app log - self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) - - workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_execution=workflow_execution, - ) + self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id) yield workflow_finish_resp @@ -476,34 +423,23 @@ class WorkflowAppGenerateTaskPipeline: self, event: QueueWorkflowPartialSuccessEvent, *, - graph_runtime_state: GraphRuntimeState | None = None, trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle workflow partial success events.""" + _ = trace_manager self._ensure_workflow_initialized() - validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state) + validated_state = self._ensure_graph_runtime_initialized() + workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_id=self._workflow.id, + status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED, + graph_runtime_state=validated_state, + exceptions_count=event.exceptions_count, + ) with self._database_session() as session: - workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success( - workflow_run_id=self._workflow_run_id, - total_tokens=validated_state.total_tokens, - total_steps=validated_state.node_run_steps, - outputs=event.outputs, - exceptions_count=event.exceptions_count, - conversation_id=None, - trace_manager=trace_manager, - external_trace_id=self._application_generate_entity.extras.get("external_trace_id"), - ) - - # save workflow app log - self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) - - workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_execution=workflow_execution, - ) + self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id) yield workflow_finish_resp @@ -511,37 +447,33 @@ class WorkflowAppGenerateTaskPipeline: self, event: Union[QueueWorkflowFailedEvent, QueueStopEvent], *, - graph_runtime_state: GraphRuntimeState | None = None, trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle workflow failed and stop events.""" + _ = trace_manager self._ensure_workflow_initialized() - validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state) + validated_state = self._ensure_graph_runtime_initialized() + + if isinstance(event, QueueWorkflowFailedEvent): + status = WorkflowExecutionStatus.FAILED + error = event.error + exceptions_count = event.exceptions_count + else: + status = WorkflowExecutionStatus.STOPPED + error = event.get_stop_reason() + exceptions_count = 0 + workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_id=self._workflow.id, + status=status, + graph_runtime_state=validated_state, + error=error, + exceptions_count=exceptions_count, + ) with self._database_session() as session: - workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed( - workflow_run_id=self._workflow_run_id, - 
total_tokens=validated_state.total_tokens, - total_steps=validated_state.node_run_steps, - status=WorkflowExecutionStatus.FAILED - if isinstance(event, QueueWorkflowFailedEvent) - else WorkflowExecutionStatus.STOPPED, - error_message=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), - conversation_id=None, - trace_manager=trace_manager, - exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0, - external_trace_id=self._application_generate_entity.extras.get("external_trace_id"), - ) - - # save workflow app log - self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) - - workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_execution=workflow_execution, - ) + self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id) yield workflow_finish_resp @@ -601,7 +533,6 @@ class WorkflowAppGenerateTaskPipeline: self, event: AppQueueEvent, *, - graph_runtime_state: GraphRuntimeState | None = None, tts_publisher: AppGeneratorTTSPublisher | None = None, trace_manager: TraceQueueManager | None = None, queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None, @@ -614,7 +545,6 @@ class WorkflowAppGenerateTaskPipeline: if handler := handlers.get(event_type): yield from handler( event, - graph_runtime_state=graph_runtime_state, tts_publisher=tts_publisher, trace_manager=trace_manager, queue_message=queue_message, @@ -631,7 +561,6 @@ class WorkflowAppGenerateTaskPipeline: ): yield from self._handle_node_failed_events( event, - graph_runtime_state=graph_runtime_state, tts_publisher=tts_publisher, trace_manager=trace_manager, queue_message=queue_message, @@ -642,7 +571,6 @@ class WorkflowAppGenerateTaskPipeline: if isinstance(event, (QueueWorkflowFailedEvent, QueueStopEvent)): yield from self._handle_workflow_failed_and_stop_events( event, - graph_runtime_state=graph_runtime_state, tts_publisher=tts_publisher, trace_manager=trace_manager, queue_message=queue_message, @@ -661,15 +589,12 @@ class WorkflowAppGenerateTaskPipeline: Process stream response using elegant Fluent Python patterns. Maintains exact same functionality as original 44-if-statement version. 
""" - # Initialize graph runtime state - graph_runtime_state = None - for queue_message in self._base_task_pipeline.queue_manager.listen(): event = queue_message.event match event: case QueueWorkflowStartedEvent(): - graph_runtime_state = event.graph_runtime_state + self._resolve_graph_runtime_state() yield from self._handle_workflow_started_event(event) case QueueTextChunkEvent(): @@ -681,12 +606,19 @@ class WorkflowAppGenerateTaskPipeline: yield from self._handle_error_event(event) break + case QueueWorkflowFailedEvent(): + yield from self._handle_workflow_failed_and_stop_events(event) + break + + case QueueStopEvent(): + yield from self._handle_workflow_failed_and_stop_events(event) + break + # Handle all other events through elegant dispatch case _: if responses := list( self._dispatch_event( event, - graph_runtime_state=graph_runtime_state, tts_publisher=tts_publisher, trace_manager=trace_manager, queue_message=queue_message, @@ -697,7 +629,7 @@ class WorkflowAppGenerateTaskPipeline: if tts_publisher: tts_publisher.publish(None) - def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution): + def _save_workflow_app_log(self, *, session: Session, workflow_run_id: str | None): invoke_from = self._application_generate_entity.invoke_from if invoke_from == InvokeFrom.SERVICE_API: created_from = WorkflowAppLogCreatedFrom.SERVICE_API @@ -709,11 +641,14 @@ class WorkflowAppGenerateTaskPipeline: # not save log for debugging return + if not workflow_run_id: + return + workflow_app_log = WorkflowAppLog() workflow_app_log.tenant_id = self._application_generate_entity.app_config.tenant_id workflow_app_log.app_id = self._application_generate_entity.app_config.app_id - workflow_app_log.workflow_id = workflow_execution.workflow_id - workflow_app_log.workflow_run_id = workflow_execution.id_ + workflow_app_log.workflow_id = self._workflow.id + workflow_app_log.workflow_run_id = workflow_run_id workflow_app_log.created_from = created_from.value workflow_app_log.created_by_role = self._created_by_role workflow_app_log.created_by = self._user_id diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 68eb455d26..5e2bd17f8c 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -25,7 +25,7 @@ from core.app.entities.queue_entities import ( QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) -from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool +from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph from core.workflow.graph_events import ( GraphEngineEvent, @@ -54,6 +54,7 @@ from core.workflow.graph_events.graph import GraphRunAbortedEvent from core.workflow.nodes import NodeType from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from core.workflow.workflow_entry import WorkflowEntry @@ -346,9 +347,7 @@ class WorkflowBasedAppRunner: :param event: event """ if isinstance(event, GraphRunStartedEvent): - self._publish_event( - QueueWorkflowStartedEvent(graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state) - ) + self._publish_event(QueueWorkflowStartedEvent()) elif isinstance(event, 
GraphRunSucceededEvent): self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs)) elif isinstance(event, GraphRunPartialSucceededEvent): @@ -372,7 +371,6 @@ class WorkflowBasedAppRunner: node_title=event.node_title, node_type=event.node_type, start_at=event.start_at, - predecessor_node_id=event.predecessor_node_id, in_iteration_id=event.in_iteration_id, in_loop_id=event.in_loop_id, inputs=inputs, @@ -393,7 +391,6 @@ class WorkflowBasedAppRunner: node_title=event.node_title, node_type=event.node_type, start_at=event.start_at, - predecessor_node_id=event.predecessor_node_id, in_iteration_id=event.in_iteration_id, in_loop_id=event.in_loop_id, agent_strategy=event.agent_strategy, @@ -494,7 +491,6 @@ class WorkflowBasedAppRunner: start_at=event.start_at, node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, inputs=event.inputs, - predecessor_node_id=event.predecessor_node_id, metadata=event.metadata, ) ) @@ -536,7 +532,6 @@ class WorkflowBasedAppRunner: start_at=event.start_at, node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, inputs=event.inputs, - predecessor_node_id=event.predecessor_node_id, metadata=event.metadata, ) ) diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 76d22d8ac3..77d6bf03b4 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -3,11 +3,11 @@ from datetime import datetime from enum import StrEnum, auto from typing import Any -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from core.workflow.entities import AgentNodeStrategyInit, GraphRuntimeState +from core.workflow.entities import AgentNodeStrategyInit from core.workflow.enums import WorkflowNodeExecutionMetadataKey from core.workflow.nodes import NodeType @@ -54,6 +54,7 @@ class AppQueueEvent(BaseModel): """ event: QueueEvent + model_config = ConfigDict(arbitrary_types_allowed=True) class QueueLLMChunkEvent(AppQueueEvent): @@ -80,7 +81,6 @@ class QueueIterationStartEvent(AppQueueEvent): node_run_index: int inputs: Mapping[str, object] = Field(default_factory=dict) - predecessor_node_id: str | None = None metadata: Mapping[str, object] = Field(default_factory=dict) @@ -132,19 +132,10 @@ class QueueLoopStartEvent(AppQueueEvent): node_id: str node_type: NodeType node_title: str - parallel_id: str | None = None - """parallel id if node is in parallel""" - parallel_start_node_id: str | None = None - """parallel start node id if node is in parallel""" - parent_parallel_id: str | None = None - """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: str | None = None - """parent parallel start node id if node is in parallel""" start_at: datetime node_run_index: int inputs: Mapping[str, object] = Field(default_factory=dict) - predecessor_node_id: str | None = None metadata: Mapping[str, object] = Field(default_factory=dict) @@ -160,16 +151,6 @@ class QueueLoopNextEvent(AppQueueEvent): node_id: str node_type: NodeType node_title: str - parallel_id: str | None = None - """parallel id if node is in parallel""" - parallel_start_node_id: str | None = None - """parallel start node id if node is in parallel""" - parent_parallel_id: str | None = None - """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: str | None = None - 
"""parent parallel start node id if node is in parallel""" - parallel_mode_run_id: str | None = None - """iteration run in parallel mode run id""" node_run_index: int output: Any = None # output for the current loop @@ -185,14 +166,6 @@ class QueueLoopCompletedEvent(AppQueueEvent): node_id: str node_type: NodeType node_title: str - parallel_id: str | None = None - """parallel id if node is in parallel""" - parallel_start_node_id: str | None = None - """parallel start node id if node is in parallel""" - parent_parallel_id: str | None = None - """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: str | None = None - """parent parallel start node id if node is in parallel""" start_at: datetime node_run_index: int @@ -285,12 +258,9 @@ class QueueAdvancedChatMessageEndEvent(AppQueueEvent): class QueueWorkflowStartedEvent(AppQueueEvent): - """ - QueueWorkflowStartedEvent entity - """ + """QueueWorkflowStartedEvent entity.""" event: QueueEvent = QueueEvent.WORKFLOW_STARTED - graph_runtime_state: GraphRuntimeState class QueueWorkflowSucceededEvent(AppQueueEvent): @@ -334,15 +304,9 @@ class QueueNodeStartedEvent(AppQueueEvent): node_title: str node_type: NodeType node_run_index: int = 1 # FIXME(-LAN-): may not used - predecessor_node_id: str | None = None - parallel_id: str | None = None - parallel_start_node_id: str | None = None - parent_parallel_id: str | None = None - parent_parallel_start_node_id: str | None = None in_iteration_id: str | None = None in_loop_id: str | None = None start_at: datetime - parallel_mode_run_id: str | None = None agent_strategy: AgentNodeStrategyInit | None = None # FIXME(-LAN-): only for ToolNode, need to refactor @@ -360,14 +324,6 @@ class QueueNodeSucceededEvent(AppQueueEvent): node_execution_id: str node_id: str node_type: NodeType - parallel_id: str | None = None - """parallel id if node is in parallel""" - parallel_start_node_id: str | None = None - """parallel start node id if node is in parallel""" - parent_parallel_id: str | None = None - """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: str | None = None - """parent parallel start node id if node is in parallel""" in_iteration_id: str | None = None """iteration id if node is in iteration""" in_loop_id: str | None = None @@ -423,14 +379,6 @@ class QueueNodeExceptionEvent(AppQueueEvent): node_execution_id: str node_id: str node_type: NodeType - parallel_id: str | None = None - """parallel id if node is in parallel""" - parallel_start_node_id: str | None = None - """parallel start node id if node is in parallel""" - parent_parallel_id: str | None = None - """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: str | None = None - """parent parallel start node id if node is in parallel""" in_iteration_id: str | None = None """iteration id if node is in iteration""" in_loop_id: str | None = None @@ -455,7 +403,6 @@ class QueueNodeFailedEvent(AppQueueEvent): node_execution_id: str node_id: str node_type: NodeType - parallel_id: str | None = None in_iteration_id: str | None = None """iteration id if node is in iteration""" in_loop_id: str | None = None diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 31dc1eea89..72a92add04 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -257,13 +257,8 @@ class NodeStartStreamResponse(StreamResponse): inputs_truncated: bool = False created_at: int extras: dict[str, object] = 
Field(default_factory=dict) - parallel_id: str | None = None - parallel_start_node_id: str | None = None - parent_parallel_id: str | None = None - parent_parallel_start_node_id: str | None = None iteration_id: str | None = None loop_id: str | None = None - parallel_run_id: str | None = None agent_strategy: AgentNodeStrategyInit | None = None event: StreamEvent = StreamEvent.NODE_STARTED @@ -285,10 +280,6 @@ class NodeStartStreamResponse(StreamResponse): "inputs": None, "created_at": self.data.created_at, "extras": {}, - "parallel_id": self.data.parallel_id, - "parallel_start_node_id": self.data.parallel_start_node_id, - "parent_parallel_id": self.data.parent_parallel_id, - "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id, "iteration_id": self.data.iteration_id, "loop_id": self.data.loop_id, }, @@ -324,10 +315,6 @@ class NodeFinishStreamResponse(StreamResponse): created_at: int finished_at: int files: Sequence[Mapping[str, Any]] | None = [] - parallel_id: str | None = None - parallel_start_node_id: str | None = None - parent_parallel_id: str | None = None - parent_parallel_start_node_id: str | None = None iteration_id: str | None = None loop_id: str | None = None @@ -357,10 +344,6 @@ class NodeFinishStreamResponse(StreamResponse): "created_at": self.data.created_at, "finished_at": self.data.finished_at, "files": [], - "parallel_id": self.data.parallel_id, - "parallel_start_node_id": self.data.parallel_start_node_id, - "parent_parallel_id": self.data.parent_parallel_id, - "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id, "iteration_id": self.data.iteration_id, "loop_id": self.data.loop_id, }, @@ -396,10 +379,6 @@ class NodeRetryStreamResponse(StreamResponse): created_at: int finished_at: int files: Sequence[Mapping[str, Any]] | None = [] - parallel_id: str | None = None - parallel_start_node_id: str | None = None - parent_parallel_id: str | None = None - parent_parallel_start_node_id: str | None = None iteration_id: str | None = None loop_id: str | None = None retry_index: int = 0 @@ -430,10 +409,6 @@ class NodeRetryStreamResponse(StreamResponse): "created_at": self.data.created_at, "finished_at": self.data.finished_at, "files": [], - "parallel_id": self.data.parallel_id, - "parallel_start_node_id": self.data.parallel_start_node_id, - "parent_parallel_id": self.data.parent_parallel_id, - "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id, "iteration_id": self.data.iteration_id, "loop_id": self.data.loop_id, "retry_index": self.data.retry_index, @@ -541,8 +516,6 @@ class LoopNodeStartStreamResponse(StreamResponse): metadata: Mapping = {} inputs: Mapping = {} inputs_truncated: bool = False - parallel_id: str | None = None - parallel_start_node_id: str | None = None event: StreamEvent = StreamEvent.LOOP_STARTED workflow_run_id: str @@ -567,9 +540,6 @@ class LoopNodeNextStreamResponse(StreamResponse): created_at: int pre_loop_output: Any = None extras: Mapping[str, object] = Field(default_factory=dict) - parallel_id: str | None = None - parallel_start_node_id: str | None = None - parallel_mode_run_id: str | None = None event: StreamEvent = StreamEvent.LOOP_NEXT workflow_run_id: str @@ -603,8 +573,6 @@ class LoopNodeCompletedStreamResponse(StreamResponse): execution_metadata: Mapping[str, object] = Field(default_factory=dict) finished_at: int steps: int - parallel_id: str | None = None - parallel_start_node_id: str | None = None event: StreamEvent = StreamEvent.LOOP_COMPLETED workflow_run_id: str diff --git 
a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 5f2ffefd94..d74b2bddf5 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -18,7 +18,7 @@ from core.model_runtime.entities.message_entities import ImagePromptMessageConte from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.prompt_transform import PromptTransform from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.workflow.entities.variable_pool import VariablePool +from core.workflow.runtime import VariablePool class AdvancedPromptTransform(PromptTransform): diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 4399ec01cc..4436773d25 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -104,7 +104,6 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER # Initialize in-memory cache for node executions - # Key: node_execution_id, Value: WorkflowNodeExecution (DB model) self._node_execution_cache: dict[str, WorkflowNodeExecutionModel] = {} # Initialize FileService for handling offloaded data @@ -332,17 +331,10 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) Args: execution: The NodeExecution domain entity to persist """ - # NOTE: As per the implementation of `WorkflowCycleManager`, - # the `save` method is invoked multiple times during the node's execution lifecycle, including: - # - # - When the node starts execution - # - When the node retries execution - # - When the node completes execution (either successfully or with failure) - # - # Only the final invocation will have `inputs` and `outputs` populated. - # - # This simplifies the logic for saving offloaded variables but introduces a tight coupling - # between this module and `WorkflowCycleManager`. + # NOTE: The workflow engine triggers `save` multiple times for a single node execution: + # when the node starts, any time it retries, and once more when it reaches a terminal state. + # Only the final call contains the complete inputs and outputs payloads, so earlier invocations + # must tolerate missing data without attempting to offload variables. 
# Convert domain model to database model using tenant context and other attributes db_model = self._to_db_model(execution) diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index af68971ca7..006cf856d5 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -58,8 +58,8 @@ from services.tools.mcp_tools_manage_service import MCPToolManageService from services.tools.tools_transform_service import ToolTransformService if TYPE_CHECKING: - from core.workflow.entities import VariablePool from core.workflow.nodes.tool.entities import ToolEntity + from core.workflow.runtime import VariablePool logger = logging.getLogger(__name__) diff --git a/api/core/workflow/entities/__init__.py b/api/core/workflow/entities/__init__.py index 007bf42aa6..be70e467a0 100644 --- a/api/core/workflow/entities/__init__.py +++ b/api/core/workflow/entities/__init__.py @@ -1,18 +1,11 @@ from .agent import AgentNodeStrategyInit from .graph_init_params import GraphInitParams -from .graph_runtime_state import GraphRuntimeState -from .run_condition import RunCondition -from .variable_pool import VariablePool, VariableValue from .workflow_execution import WorkflowExecution from .workflow_node_execution import WorkflowNodeExecution __all__ = [ "AgentNodeStrategyInit", "GraphInitParams", - "GraphRuntimeState", - "RunCondition", - "VariablePool", - "VariableValue", "WorkflowExecution", "WorkflowNodeExecution", ] diff --git a/api/core/workflow/entities/graph_runtime_state.py b/api/core/workflow/entities/graph_runtime_state.py deleted file mode 100644 index 6362f291ea..0000000000 --- a/api/core/workflow/entities/graph_runtime_state.py +++ /dev/null @@ -1,160 +0,0 @@ -from copy import deepcopy - -from pydantic import BaseModel, PrivateAttr - -from core.model_runtime.entities.llm_entities import LLMUsage - -from .variable_pool import VariablePool - - -class GraphRuntimeState(BaseModel): - # Private attributes to prevent direct modification - _variable_pool: VariablePool = PrivateAttr() - _start_at: float = PrivateAttr() - _total_tokens: int = PrivateAttr(default=0) - _llm_usage: LLMUsage = PrivateAttr(default_factory=LLMUsage.empty_usage) - _outputs: dict[str, object] = PrivateAttr(default_factory=dict[str, object]) - _node_run_steps: int = PrivateAttr(default=0) - _ready_queue_json: str = PrivateAttr() - _graph_execution_json: str = PrivateAttr() - _response_coordinator_json: str = PrivateAttr() - - def __init__( - self, - *, - variable_pool: VariablePool, - start_at: float, - total_tokens: int = 0, - llm_usage: LLMUsage | None = None, - outputs: dict[str, object] | None = None, - node_run_steps: int = 0, - ready_queue_json: str = "", - graph_execution_json: str = "", - response_coordinator_json: str = "", - **kwargs: object, - ): - """Initialize the GraphRuntimeState with validation.""" - super().__init__(**kwargs) - - # Initialize private attributes with validation - self._variable_pool = variable_pool - - self._start_at = start_at - - if total_tokens < 0: - raise ValueError("total_tokens must be non-negative") - self._total_tokens = total_tokens - - if llm_usage is None: - llm_usage = LLMUsage.empty_usage() - self._llm_usage = llm_usage - - if outputs is None: - outputs = {} - self._outputs = deepcopy(outputs) - - if node_run_steps < 0: - raise ValueError("node_run_steps must be non-negative") - self._node_run_steps = node_run_steps - - self._ready_queue_json = ready_queue_json - self._graph_execution_json = graph_execution_json - self._response_coordinator_json = 
response_coordinator_json - - @property - def variable_pool(self) -> VariablePool: - """Get the variable pool.""" - return self._variable_pool - - @property - def start_at(self) -> float: - """Get the start time.""" - return self._start_at - - @start_at.setter - def start_at(self, value: float) -> None: - """Set the start time.""" - self._start_at = value - - @property - def total_tokens(self) -> int: - """Get the total tokens count.""" - return self._total_tokens - - @total_tokens.setter - def total_tokens(self, value: int): - """Set the total tokens count.""" - if value < 0: - raise ValueError("total_tokens must be non-negative") - self._total_tokens = value - - @property - def llm_usage(self) -> LLMUsage: - """Get the LLM usage info.""" - # Return a copy to prevent external modification - return self._llm_usage.model_copy() - - @llm_usage.setter - def llm_usage(self, value: LLMUsage): - """Set the LLM usage info.""" - self._llm_usage = value.model_copy() - - @property - def outputs(self) -> dict[str, object]: - """Get a copy of the outputs dictionary.""" - return deepcopy(self._outputs) - - @outputs.setter - def outputs(self, value: dict[str, object]) -> None: - """Set the outputs dictionary.""" - self._outputs = deepcopy(value) - - def set_output(self, key: str, value: object) -> None: - """Set a single output value.""" - self._outputs[key] = deepcopy(value) - - def get_output(self, key: str, default: object = None) -> object: - """Get a single output value.""" - return deepcopy(self._outputs.get(key, default)) - - def update_outputs(self, updates: dict[str, object]) -> None: - """Update multiple output values.""" - for key, value in updates.items(): - self._outputs[key] = deepcopy(value) - - @property - def node_run_steps(self) -> int: - """Get the node run steps count.""" - return self._node_run_steps - - @node_run_steps.setter - def node_run_steps(self, value: int) -> None: - """Set the node run steps count.""" - if value < 0: - raise ValueError("node_run_steps must be non-negative") - self._node_run_steps = value - - def increment_node_run_steps(self) -> None: - """Increment the node run steps by 1.""" - self._node_run_steps += 1 - - def add_tokens(self, tokens: int) -> None: - """Add tokens to the total count.""" - if tokens < 0: - raise ValueError("tokens must be non-negative") - self._total_tokens += tokens - - @property - def ready_queue_json(self) -> str: - """Get a copy of the ready queue state.""" - return self._ready_queue_json - - @property - def graph_execution_json(self) -> str: - """Get a copy of the serialized graph execution state.""" - return self._graph_execution_json - - @property - def response_coordinator_json(self) -> str: - """Get a copy of the serialized response coordinator state.""" - return self._response_coordinator_json diff --git a/api/core/workflow/entities/run_condition.py b/api/core/workflow/entities/run_condition.py deleted file mode 100644 index 7b9a379215..0000000000 --- a/api/core/workflow/entities/run_condition.py +++ /dev/null @@ -1,21 +0,0 @@ -import hashlib -from typing import Literal - -from pydantic import BaseModel - -from core.workflow.utils.condition.entities import Condition - - -class RunCondition(BaseModel): - type: Literal["branch_identify", "condition"] - """condition type""" - - branch_identify: str | None = None - """branch identify like: sourceHandle, required when type is branch_identify""" - - conditions: list[Condition] | None = None - """conditions to run the node, required when type is condition""" - - @property - def 
hash(self) -> str: - return hashlib.sha256(self.model_dump_json().encode()).hexdigest() diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py index eb88bb67ee..83b9281e51 100644 --- a/api/core/workflow/enums.py +++ b/api/core/workflow/enums.py @@ -58,6 +58,7 @@ class NodeType(StrEnum): DOCUMENT_EXTRACTOR = "document-extractor" LIST_OPERATOR = "list-operator" AGENT = "agent" + HUMAN_INPUT = "human-input" class NodeExecutionType(StrEnum): @@ -96,6 +97,7 @@ class WorkflowExecutionStatus(StrEnum): FAILED = "failed" STOPPED = "stopped" PARTIAL_SUCCEEDED = "partial-succeeded" + PAUSED = "paused" class WorkflowNodeExecutionMetadataKey(StrEnum): diff --git a/api/core/workflow/graph/__init__.py b/api/core/workflow/graph/__init__.py index 31a81d494e..4830ea83d3 100644 --- a/api/core/workflow/graph/__init__.py +++ b/api/core/workflow/graph/__init__.py @@ -1,16 +1,11 @@ from .edge import Edge -from .graph import Graph, NodeFactory -from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool +from .graph import Graph, GraphBuilder, NodeFactory from .graph_template import GraphTemplate -from .read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper __all__ = [ "Edge", "Graph", + "GraphBuilder", "GraphTemplate", "NodeFactory", - "ReadOnlyGraphRuntimeState", - "ReadOnlyGraphRuntimeStateWrapper", - "ReadOnlyVariablePool", - "ReadOnlyVariablePoolWrapper", ] diff --git a/api/core/workflow/graph/graph.py b/api/core/workflow/graph/graph.py index 330e14de81..20b5193875 100644 --- a/api/core/workflow/graph/graph.py +++ b/api/core/workflow/graph/graph.py @@ -195,6 +195,12 @@ class Graph: return nodes + @classmethod + def new(cls) -> "GraphBuilder": + """Create a fluent builder for assembling a graph programmatically.""" + + return GraphBuilder(graph_cls=cls) + @classmethod def _mark_inactive_root_branches( cls, @@ -344,3 +350,96 @@ class Graph: """ edge_ids = self.in_edges.get(node_id, []) return [self.edges[eid] for eid in edge_ids if eid in self.edges] + + +@final +class GraphBuilder: + """Fluent helper for constructing simple graphs, primarily for tests.""" + + def __init__(self, *, graph_cls: type[Graph]): + self._graph_cls = graph_cls + self._nodes: list[Node] = [] + self._nodes_by_id: dict[str, Node] = {} + self._edges: list[Edge] = [] + self._edge_counter = 0 + + def add_root(self, node: Node) -> "GraphBuilder": + """Register the root node. 
Must be called exactly once.""" + + if self._nodes: + raise ValueError("Root node has already been added") + self._register_node(node) + self._nodes.append(node) + return self + + def add_node( + self, + node: Node, + *, + from_node_id: str | None = None, + source_handle: str = "source", + ) -> "GraphBuilder": + """Append a node and connect it from the specified predecessor.""" + + if not self._nodes: + raise ValueError("Root node must be added before adding other nodes") + + predecessor_id = from_node_id or self._nodes[-1].id + if predecessor_id not in self._nodes_by_id: + raise ValueError(f"Predecessor node '{predecessor_id}' not found") + + predecessor = self._nodes_by_id[predecessor_id] + self._register_node(node) + self._nodes.append(node) + + edge_id = f"edge_{self._edge_counter}" + self._edge_counter += 1 + edge = Edge(id=edge_id, tail=predecessor.id, head=node.id, source_handle=source_handle) + self._edges.append(edge) + + return self + + def connect(self, *, tail: str, head: str, source_handle: str = "source") -> "GraphBuilder": + """Connect two existing nodes without adding a new node.""" + + if tail not in self._nodes_by_id: + raise ValueError(f"Tail node '{tail}' not found") + if head not in self._nodes_by_id: + raise ValueError(f"Head node '{head}' not found") + + edge_id = f"edge_{self._edge_counter}" + self._edge_counter += 1 + edge = Edge(id=edge_id, tail=tail, head=head, source_handle=source_handle) + self._edges.append(edge) + + return self + + def build(self) -> Graph: + """Materialize the graph instance from the accumulated nodes and edges.""" + + if not self._nodes: + raise ValueError("Cannot build an empty graph") + + nodes = {node.id: node for node in self._nodes} + edges = {edge.id: edge for edge in self._edges} + in_edges: dict[str, list[str]] = defaultdict(list) + out_edges: dict[str, list[str]] = defaultdict(list) + + for edge in self._edges: + out_edges[edge.tail].append(edge.id) + in_edges[edge.head].append(edge.id) + + return self._graph_cls( + nodes=nodes, + edges=edges, + in_edges=dict(in_edges), + out_edges=dict(out_edges), + root_node=self._nodes[0], + ) + + def _register_node(self, node: Node) -> None: + if not node.id: + raise ValueError("Node must have a non-empty id") + if node.id in self._nodes_by_id: + raise ValueError(f"Duplicate node id detected: {node.id}") + self._nodes_by_id[node.id] = node diff --git a/api/core/workflow/graph_engine/command_channels/redis_channel.py b/api/core/workflow/graph_engine/command_channels/redis_channel.py index 527647ae3b..4be3adb8f8 100644 --- a/api/core/workflow/graph_engine/command_channels/redis_channel.py +++ b/api/core/workflow/graph_engine/command_channels/redis_channel.py @@ -9,7 +9,7 @@ Each instance uses a unique key for its command queue. 
import json from typing import TYPE_CHECKING, Any, final -from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand +from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand if TYPE_CHECKING: from extensions.ext_redis import RedisClientWrapper @@ -111,9 +111,11 @@ class RedisChannel: if command_type == CommandType.ABORT: return AbortCommand.model_validate(data) - else: - # For other command types, use base class - return GraphEngineCommand.model_validate(data) + if command_type == CommandType.PAUSE: + return PauseCommand.model_validate(data) + + # For other command types, use base class + return GraphEngineCommand.model_validate(data) except (ValueError, TypeError): return None diff --git a/api/core/workflow/graph_engine/command_processing/__init__.py b/api/core/workflow/graph_engine/command_processing/__init__.py index 3460b52226..837f5e55fd 100644 --- a/api/core/workflow/graph_engine/command_processing/__init__.py +++ b/api/core/workflow/graph_engine/command_processing/__init__.py @@ -5,10 +5,11 @@ This package handles external commands sent to the engine during execution. """ -from .command_handlers import AbortCommandHandler +from .command_handlers import AbortCommandHandler, PauseCommandHandler from .command_processor import CommandProcessor __all__ = [ "AbortCommandHandler", "CommandProcessor", + "PauseCommandHandler", ] diff --git a/api/core/workflow/graph_engine/command_processing/command_handlers.py b/api/core/workflow/graph_engine/command_processing/command_handlers.py index 3c51de99f3..c26c98c496 100644 --- a/api/core/workflow/graph_engine/command_processing/command_handlers.py +++ b/api/core/workflow/graph_engine/command_processing/command_handlers.py @@ -1,14 +1,10 @@ -""" -Command handler implementations. -""" - import logging from typing import final from typing_extensions import override from ..domain.graph_execution import GraphExecution -from ..entities.commands import AbortCommand, GraphEngineCommand +from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand from .command_processor import CommandHandler logger = logging.getLogger(__name__) @@ -16,17 +12,17 @@ logger = logging.getLogger(__name__) @final class AbortCommandHandler(CommandHandler): - """Handles abort commands.""" - @override def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: - """ - Handle an abort command. 
- - Args: - command: The abort command - execution: Graph execution to abort - """ assert isinstance(command, AbortCommand) logger.debug("Aborting workflow %s: %s", execution.workflow_id, command.reason) execution.abort(command.reason or "User requested abort") + + +@final +class PauseCommandHandler(CommandHandler): + @override + def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: + assert isinstance(command, PauseCommand) + logger.debug("Pausing workflow %s: %s", execution.workflow_id, command.reason) + execution.pause(command.reason) diff --git a/api/core/workflow/graph_engine/domain/graph_execution.py b/api/core/workflow/graph_engine/domain/graph_execution.py index b273ee9969..6482c927d6 100644 --- a/api/core/workflow/graph_engine/domain/graph_execution.py +++ b/api/core/workflow/graph_engine/domain/graph_execution.py @@ -40,6 +40,8 @@ class GraphExecutionState(BaseModel): started: bool = Field(default=False) completed: bool = Field(default=False) aborted: bool = Field(default=False) + paused: bool = Field(default=False) + pause_reason: str | None = Field(default=None) error: GraphExecutionErrorState | None = Field(default=None) exceptions_count: int = Field(default=0) node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState]) @@ -103,6 +105,8 @@ class GraphExecution: started: bool = False completed: bool = False aborted: bool = False + paused: bool = False + pause_reason: str | None = None error: Exception | None = None node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution]) exceptions_count: int = 0 @@ -126,6 +130,17 @@ class GraphExecution: self.aborted = True self.error = RuntimeError(f"Aborted: {reason}") + def pause(self, reason: str | None = None) -> None: + """Pause the graph execution without marking it complete.""" + if self.completed: + raise RuntimeError("Cannot pause execution that has completed") + if self.aborted: + raise RuntimeError("Cannot pause execution that has been aborted") + if self.paused: + return + self.paused = True + self.pause_reason = reason + def fail(self, error: Exception) -> None: """Mark the graph execution as failed.""" self.error = error @@ -140,7 +155,12 @@ class GraphExecution: @property def is_running(self) -> bool: """Check if the execution is currently running.""" - return self.started and not self.completed and not self.aborted + return self.started and not self.completed and not self.aborted and not self.paused + + @property + def is_paused(self) -> bool: + """Check if the execution is currently paused.""" + return self.paused @property def has_error(self) -> bool: @@ -173,6 +193,8 @@ class GraphExecution: started=self.started, completed=self.completed, aborted=self.aborted, + paused=self.paused, + pause_reason=self.pause_reason, error=_serialize_error(self.error), exceptions_count=self.exceptions_count, node_executions=node_states, @@ -197,6 +219,8 @@ class GraphExecution: self.started = state.started self.completed = state.completed self.aborted = state.aborted + self.paused = state.paused + self.pause_reason = state.pause_reason self.error = _deserialize_error(state.error) self.exceptions_count = state.exceptions_count self.node_executions = { diff --git a/api/core/workflow/graph_engine/entities/commands.py b/api/core/workflow/graph_engine/entities/commands.py index 123ef3d449..6070ed8812 100644 --- a/api/core/workflow/graph_engine/entities/commands.py +++ b/api/core/workflow/graph_engine/entities/commands.py @@ -16,7 +16,6 @@ class 
CommandType(StrEnum): ABORT = "abort" PAUSE = "pause" - RESUME = "resume" class GraphEngineCommand(BaseModel): @@ -31,3 +30,10 @@ class AbortCommand(GraphEngineCommand): command_type: CommandType = Field(default=CommandType.ABORT, description="Type of command") reason: str | None = Field(default=None, description="Optional reason for abort") + + +class PauseCommand(GraphEngineCommand): + """Command to pause a running workflow execution.""" + + command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command") + reason: str | None = Field(default=None, description="Optional reason for pause") diff --git a/api/core/workflow/graph_engine/event_management/event_handlers.py b/api/core/workflow/graph_engine/event_management/event_handlers.py index 1cb5851ab1..fe99d3ad50 100644 --- a/api/core/workflow/graph_engine/event_management/event_handlers.py +++ b/api/core/workflow/graph_engine/event_management/event_handlers.py @@ -8,8 +8,7 @@ from functools import singledispatchmethod from typing import TYPE_CHECKING, final from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.entities import GraphRuntimeState -from core.workflow.enums import ErrorStrategy, NodeExecutionType +from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState from core.workflow.graph import Graph from core.workflow.graph_events import ( GraphNodeEventBase, @@ -24,11 +23,13 @@ from core.workflow.graph_events import ( NodeRunLoopNextEvent, NodeRunLoopStartedEvent, NodeRunLoopSucceededEvent, + NodeRunPauseRequestedEvent, NodeRunRetryEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) +from core.workflow.runtime import GraphRuntimeState from ..domain.graph_execution import GraphExecution from ..response_coordinator import ResponseStreamCoordinator @@ -203,6 +204,18 @@ class EventHandler: # Collect the event self._event_collector.collect(event) + @_dispatch.register + def _(self, event: NodeRunPauseRequestedEvent) -> None: + """Handle pause requests emitted by nodes.""" + + pause_reason = event.reason or "Awaiting human input" + self._graph_execution.pause(pause_reason) + self._state_manager.finish_execution(event.node_id) + if event.node_id in self._graph.nodes: + self._graph.nodes[event.node_id].state = NodeState.UNKNOWN + self._graph_runtime_state.register_paused_node(event.node_id) + self._event_collector.collect(event) + @_dispatch.register def _(self, event: NodeRunFailedEvent) -> None: """ diff --git a/api/core/workflow/graph_engine/event_management/event_manager.py b/api/core/workflow/graph_engine/event_management/event_manager.py index 751a2a4352..689cf53cf0 100644 --- a/api/core/workflow/graph_engine/event_management/event_manager.py +++ b/api/core/workflow/graph_engine/event_management/event_manager.py @@ -97,6 +97,10 @@ class EventManager: """ self._layers = layers + def notify_layers(self, event: GraphEngineEvent) -> None: + """Notify registered layers about an event without buffering it.""" + self._notify_layers(event) + def collect(self, event: GraphEngineEvent) -> None: """ Thread-safe method to collect an event. 
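
A minimal sketch of the pause flow these hunks introduce, using only the classes added in this patch and the import paths visible in the hunks above (`core.workflow.graph_engine.command_processing`, `.domain.graph_execution`, `.entities.commands`); the workflow id is a made-up placeholder:

    from core.workflow.graph_engine.command_processing import PauseCommandHandler
    from core.workflow.graph_engine.domain.graph_execution import GraphExecution
    from core.workflow.graph_engine.entities.commands import PauseCommand

    # Track an execution the same way GraphEngine does.
    execution = GraphExecution(workflow_id="wf-demo")  # hypothetical id
    execution.start()

    # CommandProcessor routes PAUSE commands to this handler; here we
    # invoke it directly to show the state transition.
    PauseCommandHandler().handle(PauseCommand(reason="Awaiting human input"), execution)

    assert execution.is_paused
    assert not execution.is_running  # a paused execution no longer counts as running

On resume, the engine below re-enters run() with graph_execution.started already True, clears the paused flag, and re-enqueues the nodes recorded via register_paused_node instead of the root node.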
diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index a21fb7c022..dd2ca3f93b 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -9,28 +9,29 @@ import contextvars import logging import queue from collections.abc import Generator -from typing import final +from typing import TYPE_CHECKING, cast, final from flask import Flask, current_app -from core.workflow.entities import GraphRuntimeState from core.workflow.enums import NodeExecutionType from core.workflow.graph import Graph -from core.workflow.graph.read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper -from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue from core.workflow.graph_events import ( GraphEngineEvent, GraphNodeEventBase, GraphRunAbortedEvent, GraphRunFailedEvent, GraphRunPartialSucceededEvent, + GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, ) +from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper -from .command_processing import AbortCommandHandler, CommandProcessor -from .domain import GraphExecution -from .entities.commands import AbortCommand +if TYPE_CHECKING: # pragma: no cover - used only for static analysis + from core.workflow.runtime.graph_runtime_state import GraphProtocol + +from .command_processing import AbortCommandHandler, CommandProcessor, PauseCommandHandler +from .entities.commands import AbortCommand, PauseCommand from .error_handler import ErrorHandler from .event_management import EventHandler, EventManager from .graph_state_manager import GraphStateManager @@ -38,10 +39,13 @@ from .graph_traversal import EdgeProcessor, SkipPropagator from .layers.base import GraphEngineLayer from .orchestration import Dispatcher, ExecutionCoordinator from .protocols.command_channel import CommandChannel -from .ready_queue import ReadyQueue, ReadyQueueState, create_ready_queue_from_state -from .response_coordinator import ResponseStreamCoordinator +from .ready_queue import ReadyQueue from .worker_management import WorkerPool +if TYPE_CHECKING: + from core.workflow.graph_engine.domain.graph_execution import GraphExecution + from core.workflow.graph_engine.response_coordinator import ResponseStreamCoordinator + logger = logging.getLogger(__name__) @@ -67,17 +71,16 @@ class GraphEngine: ) -> None: """Initialize the graph engine with all subsystems and dependencies.""" - # Graph execution tracks the overall execution state - self._graph_execution = GraphExecution(workflow_id=workflow_id) - if graph_runtime_state.graph_execution_json != "": - self._graph_execution.loads(graph_runtime_state.graph_execution_json) - - # === Core Dependencies === - # Graph structure and configuration + # Bind runtime state to current workflow context self._graph = graph self._graph_runtime_state = graph_runtime_state + self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph)) self._command_channel = command_channel + # Graph execution tracks the overall execution state + self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution) + self._graph_execution.workflow_id = workflow_id + # === Worker Management Parameters === # Parameters for dynamic worker pool scaling self._min_workers = min_workers @@ -86,13 +89,7 @@ class GraphEngine: self._scale_down_idle_time = scale_down_idle_time # === Execution Queues === - # Create ready queue from saved state or initialize new one - self._ready_queue: ReadyQueue - 
if self._graph_runtime_state.ready_queue_json == "": - self._ready_queue = InMemoryReadyQueue() - else: - ready_queue_state = ReadyQueueState.model_validate_json(self._graph_runtime_state.ready_queue_json) - self._ready_queue = create_ready_queue_from_state(ready_queue_state) + self._ready_queue = cast(ReadyQueue, self._graph_runtime_state.ready_queue) # Queue for events generated during execution self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue() @@ -103,11 +100,7 @@ class GraphEngine: # === Response Coordination === # Coordinates response streaming from response nodes - self._response_coordinator = ResponseStreamCoordinator( - variable_pool=self._graph_runtime_state.variable_pool, graph=self._graph - ) - if graph_runtime_state.response_coordinator_json != "": - self._response_coordinator.loads(graph_runtime_state.response_coordinator_json) + self._response_coordinator = cast("ResponseStreamCoordinator", self._graph_runtime_state.response_coordinator) # === Event Management === # Event manager handles both collection and emission of events @@ -133,19 +126,6 @@ class GraphEngine: skip_propagator=self._skip_propagator, ) - # === Event Handler Registry === - # Central registry for handling all node execution events - self._event_handler_registry = EventHandler( - graph=self._graph, - graph_runtime_state=self._graph_runtime_state, - graph_execution=self._graph_execution, - response_coordinator=self._response_coordinator, - event_collector=self._event_manager, - edge_processor=self._edge_processor, - state_manager=self._state_manager, - error_handler=self._error_handler, - ) - # === Command Processing === # Processes external commands (e.g., abort requests) self._command_processor = CommandProcessor( @@ -153,12 +133,12 @@ class GraphEngine: graph_execution=self._graph_execution, ) - # Register abort command handler + # Register command handlers abort_handler = AbortCommandHandler() - self._command_processor.register_handler( - AbortCommand, - abort_handler, - ) + self._command_processor.register_handler(AbortCommand, abort_handler) + + pause_handler = PauseCommandHandler() + self._command_processor.register_handler(PauseCommand, pause_handler) # === Worker Pool Setup === # Capture Flask app context for worker threads @@ -191,12 +171,23 @@ class GraphEngine: self._execution_coordinator = ExecutionCoordinator( graph_execution=self._graph_execution, state_manager=self._state_manager, - event_handler=self._event_handler_registry, - event_collector=self._event_manager, command_processor=self._command_processor, worker_pool=self._worker_pool, ) + # === Event Handler Registry === + # Central registry for handling all node execution events + self._event_handler_registry = EventHandler( + graph=self._graph, + graph_runtime_state=self._graph_runtime_state, + graph_execution=self._graph_execution, + response_coordinator=self._response_coordinator, + event_collector=self._event_manager, + edge_processor=self._edge_processor, + state_manager=self._state_manager, + error_handler=self._error_handler, + ) + # Dispatches events and manages execution flow self._dispatcher = Dispatcher( event_queue=self._event_queue, @@ -237,26 +228,41 @@ class GraphEngine: # Initialize layers self._initialize_layers() - # Start execution - self._graph_execution.start() + is_resume = self._graph_execution.started + if not is_resume: + self._graph_execution.start() + else: + self._graph_execution.paused = False + self._graph_execution.pause_reason = None + start_event = GraphRunStartedEvent() + 
self._event_manager.notify_layers(start_event) yield start_event # Start subsystems - self._start_execution() + self._start_execution(resume=is_resume) # Yield events as they occur yield from self._event_manager.emit_events() # Handle completion - if self._graph_execution.aborted: + if self._graph_execution.is_paused: + paused_event = GraphRunPausedEvent( + reason=self._graph_execution.pause_reason, + outputs=self._graph_runtime_state.outputs, + ) + self._event_manager.notify_layers(paused_event) + yield paused_event + elif self._graph_execution.aborted: abort_reason = "Workflow execution aborted by user command" if self._graph_execution.error: abort_reason = str(self._graph_execution.error) - yield GraphRunAbortedEvent( + aborted_event = GraphRunAbortedEvent( reason=abort_reason, outputs=self._graph_runtime_state.outputs, ) + self._event_manager.notify_layers(aborted_event) + yield aborted_event elif self._graph_execution.has_error: if self._graph_execution.error: raise self._graph_execution.error @@ -264,20 +270,26 @@ class GraphEngine: outputs = self._graph_runtime_state.outputs exceptions_count = self._graph_execution.exceptions_count if exceptions_count > 0: - yield GraphRunPartialSucceededEvent( + partial_event = GraphRunPartialSucceededEvent( exceptions_count=exceptions_count, outputs=outputs, ) + self._event_manager.notify_layers(partial_event) + yield partial_event else: - yield GraphRunSucceededEvent( + succeeded_event = GraphRunSucceededEvent( outputs=outputs, ) + self._event_manager.notify_layers(succeeded_event) + yield succeeded_event except Exception as e: - yield GraphRunFailedEvent( + failed_event = GraphRunFailedEvent( error=str(e), exceptions_count=self._graph_execution.exceptions_count, ) + self._event_manager.notify_layers(failed_event) + yield failed_event raise finally: @@ -299,8 +311,12 @@ class GraphEngine: except Exception as e: logger.warning("Layer %s failed on_graph_start: %s", layer.__class__.__name__, e) - def _start_execution(self) -> None: + def _start_execution(self, *, resume: bool = False) -> None: """Start execution subsystems.""" + paused_nodes: list[str] = [] + if resume: + paused_nodes = self._graph_runtime_state.consume_paused_nodes() + # Start worker pool (it calculates initial workers internally) self._worker_pool.start() @@ -309,10 +325,15 @@ class GraphEngine: if node.execution_type == NodeExecutionType.RESPONSE: self._response_coordinator.register(node.id) - # Enqueue root node - root_node = self._graph.root_node - self._state_manager.enqueue_node(root_node.id) - self._state_manager.start_execution(root_node.id) + if not resume: + # Enqueue root node + root_node = self._graph.root_node + self._state_manager.enqueue_node(root_node.id) + self._state_manager.start_execution(root_node.id) + else: + for node_id in paused_nodes: + self._state_manager.enqueue_node(node_id) + self._state_manager.start_execution(node_id) # Start dispatcher self._dispatcher.start() diff --git a/api/core/workflow/graph_engine/layers/base.py b/api/core/workflow/graph_engine/layers/base.py index dfac49e11a..24c12c2934 100644 --- a/api/core/workflow/graph_engine/layers/base.py +++ b/api/core/workflow/graph_engine/layers/base.py @@ -7,9 +7,9 @@ intercept and respond to GraphEngine events. 
from abc import ABC, abstractmethod -from core.workflow.graph.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState from core.workflow.graph_engine.protocols.command_channel import CommandChannel from core.workflow.graph_events import GraphEngineEvent +from core.workflow.runtime import ReadOnlyGraphRuntimeState class GraphEngineLayer(ABC): diff --git a/api/core/workflow/graph_engine/layers/persistence.py b/api/core/workflow/graph_engine/layers/persistence.py new file mode 100644 index 0000000000..ecd8e12ca5 --- /dev/null +++ b/api/core/workflow/graph_engine/layers/persistence.py @@ -0,0 +1,410 @@ +"""Workflow persistence layer for GraphEngine. + +This layer mirrors the former ``WorkflowCycleManager`` responsibilities by +listening to ``GraphEngineEvent`` instances directly and persisting workflow +and node execution state via the injected repositories. + +The design keeps domain persistence concerns inside the engine thread, while +allowing presentation layers to remain read-only observers of repository +state. +""" + +from collections.abc import Mapping +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Union + +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity +from core.ops.entities.trace_entity import TraceTaskName +from core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID +from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution +from core.workflow.enums import ( + SystemVariableKey, + WorkflowExecutionStatus, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, + WorkflowType, +) +from core.workflow.graph_engine.layers.base import GraphEngineLayer +from core.workflow.graph_events import ( + GraphEngineEvent, + GraphRunAbortedEvent, + GraphRunFailedEvent, + GraphRunPartialSucceededEvent, + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunExceptionEvent, + NodeRunFailedEvent, + NodeRunPauseRequestedEvent, + NodeRunRetryEvent, + NodeRunStartedEvent, + NodeRunSucceededEvent, +) +from core.workflow.node_events import NodeRunResult +from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.workflow_entry import WorkflowEntry +from libs.datetime_utils import naive_utc_now + + +@dataclass(slots=True) +class PersistenceWorkflowInfo: + """Static workflow metadata required for persistence.""" + + workflow_id: str + workflow_type: WorkflowType + version: str + graph_data: Mapping[str, Any] + + +@dataclass(slots=True) +class _NodeRuntimeSnapshot: + """Lightweight cache to keep node metadata across event phases.""" + + node_id: str + title: str + predecessor_node_id: str | None + iteration_id: str | None + loop_id: str | None + created_at: datetime + + +class WorkflowPersistenceLayer(GraphEngineLayer): + """GraphEngine layer that persists workflow and node execution state.""" + + def __init__( + self, + *, + application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], + workflow_info: PersistenceWorkflowInfo, + workflow_execution_repository: WorkflowExecutionRepository, + workflow_node_execution_repository: WorkflowNodeExecutionRepository, + trace_manager: TraceQueueManager | None = None, + ) -> None: + super().__init__() + self._application_generate_entity = 
application_generate_entity + self._workflow_info = workflow_info + self._workflow_execution_repository = workflow_execution_repository + self._workflow_node_execution_repository = workflow_node_execution_repository + self._trace_manager = trace_manager + + self._workflow_execution: WorkflowExecution | None = None + self._node_execution_cache: dict[str, WorkflowNodeExecution] = {} + self._node_snapshots: dict[str, _NodeRuntimeSnapshot] = {} + self._node_sequence: int = 0 + + # ------------------------------------------------------------------ + # GraphEngineLayer lifecycle + # ------------------------------------------------------------------ + def on_graph_start(self) -> None: + self._workflow_execution = None + self._node_execution_cache.clear() + self._node_snapshots.clear() + self._node_sequence = 0 + + def on_event(self, event: GraphEngineEvent) -> None: + if isinstance(event, GraphRunStartedEvent): + self._handle_graph_run_started() + return + + if isinstance(event, GraphRunSucceededEvent): + self._handle_graph_run_succeeded(event) + return + + if isinstance(event, GraphRunPartialSucceededEvent): + self._handle_graph_run_partial_succeeded(event) + return + + if isinstance(event, GraphRunFailedEvent): + self._handle_graph_run_failed(event) + return + + if isinstance(event, GraphRunAbortedEvent): + self._handle_graph_run_aborted(event) + return + + if isinstance(event, GraphRunPausedEvent): + self._handle_graph_run_paused(event) + return + + if isinstance(event, NodeRunStartedEvent): + self._handle_node_started(event) + return + + if isinstance(event, NodeRunRetryEvent): + self._handle_node_retry(event) + return + + if isinstance(event, NodeRunSucceededEvent): + self._handle_node_succeeded(event) + return + + if isinstance(event, NodeRunFailedEvent): + self._handle_node_failed(event) + return + + if isinstance(event, NodeRunExceptionEvent): + self._handle_node_exception(event) + return + + if isinstance(event, NodeRunPauseRequestedEvent): + self._handle_node_pause_requested(event) + + def on_graph_end(self, error: Exception | None) -> None: + return + + # ------------------------------------------------------------------ + # Graph-level handlers + # ------------------------------------------------------------------ + def _handle_graph_run_started(self) -> None: + execution_id = self._get_execution_id() + workflow_execution = WorkflowExecution.new( + id_=execution_id, + workflow_id=self._workflow_info.workflow_id, + workflow_type=self._workflow_info.workflow_type, + workflow_version=self._workflow_info.version, + graph=self._workflow_info.graph_data, + inputs=self._prepare_workflow_inputs(), + started_at=naive_utc_now(), + ) + + self._workflow_execution_repository.save(workflow_execution) + self._workflow_execution = workflow_execution + + def _handle_graph_run_succeeded(self, event: GraphRunSucceededEvent) -> None: + execution = self._get_workflow_execution() + execution.outputs = event.outputs + execution.status = WorkflowExecutionStatus.SUCCEEDED + self._populate_completion_statistics(execution) + + self._workflow_execution_repository.save(execution) + self._enqueue_trace_task(execution) + + def _handle_graph_run_partial_succeeded(self, event: GraphRunPartialSucceededEvent) -> None: + execution = self._get_workflow_execution() + execution.outputs = event.outputs + execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED + execution.exceptions_count = event.exceptions_count + self._populate_completion_statistics(execution) + + self._workflow_execution_repository.save(execution) + 
self._enqueue_trace_task(execution) + + def _handle_graph_run_failed(self, event: GraphRunFailedEvent) -> None: + execution = self._get_workflow_execution() + execution.status = WorkflowExecutionStatus.FAILED + execution.error_message = event.error + execution.exceptions_count = event.exceptions_count + self._populate_completion_statistics(execution) + + self._fail_running_node_executions(error_message=event.error) + self._workflow_execution_repository.save(execution) + self._enqueue_trace_task(execution) + + def _handle_graph_run_aborted(self, event: GraphRunAbortedEvent) -> None: + execution = self._get_workflow_execution() + execution.status = WorkflowExecutionStatus.STOPPED + execution.error_message = event.reason or "Workflow execution aborted" + self._populate_completion_statistics(execution) + + self._fail_running_node_executions(error_message=execution.error_message or "") + self._workflow_execution_repository.save(execution) + self._enqueue_trace_task(execution) + + def _handle_graph_run_paused(self, event: GraphRunPausedEvent) -> None: + execution = self._get_workflow_execution() + execution.status = WorkflowExecutionStatus.PAUSED + execution.error_message = event.reason or "Workflow execution paused" + execution.outputs = event.outputs + self._populate_completion_statistics(execution, update_finished=False) + + self._workflow_execution_repository.save(execution) + + # ------------------------------------------------------------------ + # Node-level handlers + # ------------------------------------------------------------------ + def _handle_node_started(self, event: NodeRunStartedEvent) -> None: + execution = self._get_workflow_execution() + + metadata = { + WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id, + WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id, + } + + domain_execution = WorkflowNodeExecution( + id=event.id, + node_execution_id=event.id, + workflow_id=execution.workflow_id, + workflow_execution_id=execution.id_, + predecessor_node_id=event.predecessor_node_id, + index=self._next_node_sequence(), + node_id=event.node_id, + node_type=event.node_type, + title=event.node_title, + status=WorkflowNodeExecutionStatus.RUNNING, + metadata=metadata, + created_at=event.start_at, + ) + + self._node_execution_cache[event.id] = domain_execution + self._workflow_node_execution_repository.save(domain_execution) + + snapshot = _NodeRuntimeSnapshot( + node_id=event.node_id, + title=event.node_title, + predecessor_node_id=event.predecessor_node_id, + iteration_id=event.in_iteration_id, + loop_id=event.in_loop_id, + created_at=event.start_at, + ) + self._node_snapshots[event.id] = snapshot + + def _handle_node_retry(self, event: NodeRunRetryEvent) -> None: + domain_execution = self._get_node_execution(event.id) + domain_execution.status = WorkflowNodeExecutionStatus.RETRY + domain_execution.error = event.error + self._workflow_node_execution_repository.save(domain_execution) + self._workflow_node_execution_repository.save_execution_data(domain_execution) + + def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None: + domain_execution = self._get_node_execution(event.id) + self._update_node_execution(domain_execution, event.node_run_result, WorkflowNodeExecutionStatus.SUCCEEDED) + + def _handle_node_failed(self, event: NodeRunFailedEvent) -> None: + domain_execution = self._get_node_execution(event.id) + self._update_node_execution( + domain_execution, + event.node_run_result, + WorkflowNodeExecutionStatus.FAILED, + error=event.error, + ) + + 
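For orientation, constructing this layer needs only the generate entity, the static workflow metadata, and the two repositories. A minimal sketch — `app_entity`, `workflow`, and the repository instances are assumed to come from the caller, and how the layer is attached to GraphEngine is not shown in this patch:

    from core.workflow.enums import WorkflowType
    from core.workflow.graph_engine.layers.persistence import (
        PersistenceWorkflowInfo,
        WorkflowPersistenceLayer,
    )

    # Every right-hand-side name below is a caller-supplied assumption.
    layer = WorkflowPersistenceLayer(
        application_generate_entity=app_entity,
        workflow_info=PersistenceWorkflowInfo(
            workflow_id=workflow.id,
            workflow_type=WorkflowType.WORKFLOW,
            version=workflow.version,
            graph_data=workflow.graph_dict,
        ),
        workflow_execution_repository=execution_repo,
        workflow_node_execution_repository=node_execution_repo,
    )
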
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None: + domain_execution = self._get_node_execution(event.id) + self._update_node_execution( + domain_execution, + event.node_run_result, + WorkflowNodeExecutionStatus.EXCEPTION, + error=event.error, + ) + + def _handle_node_pause_requested(self, event: NodeRunPauseRequestedEvent) -> None: + domain_execution = self._get_node_execution(event.id) + self._update_node_execution( + domain_execution, + event.node_run_result, + WorkflowNodeExecutionStatus.PAUSED, + error=event.reason, + update_outputs=False, + ) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + def _get_execution_id(self) -> str: + workflow_execution_id = self._system_variables().get(SystemVariableKey.WORKFLOW_EXECUTION_ID) + if not workflow_execution_id: + raise ValueError("workflow_execution_id must be provided in system variables for pause/resume flows") + return str(workflow_execution_id) + + def _prepare_workflow_inputs(self) -> Mapping[str, Any]: + inputs = {**self._application_generate_entity.inputs} + for field_name, value in self._system_variables().items(): + if field_name == SystemVariableKey.CONVERSATION_ID.value: + # Conversation IDs are tied to the current session; omit them so persisted + # workflow inputs stay reusable without binding future runs to this conversation. + continue + inputs[f"sys.{field_name}"] = value + handled = WorkflowEntry.handle_special_values(inputs) + return handled or {} + + def _get_workflow_execution(self) -> WorkflowExecution: + if self._workflow_execution is None: + raise ValueError("workflow execution not initialized") + return self._workflow_execution + + def _get_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution: + if node_execution_id not in self._node_execution_cache: + raise ValueError(f"Node execution not found for id={node_execution_id}") + return self._node_execution_cache[node_execution_id] + + def _next_node_sequence(self) -> int: + self._node_sequence += 1 + return self._node_sequence + + def _populate_completion_statistics(self, execution: WorkflowExecution, *, update_finished: bool = True) -> None: + if update_finished: + execution.finished_at = naive_utc_now() + runtime_state = self.graph_runtime_state + if runtime_state is None: + return + execution.total_tokens = runtime_state.total_tokens + execution.total_steps = runtime_state.node_run_steps + execution.outputs = execution.outputs or runtime_state.outputs + execution.exceptions_count = runtime_state.exceptions_count + + def _update_node_execution( + self, + domain_execution: WorkflowNodeExecution, + node_result: NodeRunResult, + status: WorkflowNodeExecutionStatus, + *, + error: str | None = None, + update_outputs: bool = True, + ) -> None: + finished_at = naive_utc_now() + snapshot = self._node_snapshots.get(domain_execution.id) + start_at = snapshot.created_at if snapshot else domain_execution.created_at + domain_execution.status = status + domain_execution.finished_at = finished_at + domain_execution.elapsed_time = max((finished_at - start_at).total_seconds(), 0.0) + + if error: + domain_execution.error = error + + if update_outputs: + domain_execution.update_from_mapping( + inputs=node_result.inputs, + process_data=node_result.process_data, + outputs=node_result.outputs, + metadata=node_result.metadata, + ) + + self._workflow_node_execution_repository.save(domain_execution) + 
self._workflow_node_execution_repository.save_execution_data(domain_execution) + + def _fail_running_node_executions(self, *, error_message: str) -> None: + now = naive_utc_now() + for execution in self._node_execution_cache.values(): + if execution.status == WorkflowNodeExecutionStatus.RUNNING: + execution.status = WorkflowNodeExecutionStatus.FAILED + execution.error = error_message + execution.finished_at = now + execution.elapsed_time = max((now - execution.created_at).total_seconds(), 0.0) + self._workflow_node_execution_repository.save(execution) + + def _enqueue_trace_task(self, execution: WorkflowExecution) -> None: + if not self._trace_manager: + return + + conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value) + external_trace_id = None + if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)): + external_trace_id = self._application_generate_entity.extras.get("external_trace_id") + + trace_task = TraceTask( + TraceTaskName.WORKFLOW_TRACE, + workflow_execution=execution, + conversation_id=conversation_id, + user_id=self._trace_manager.user_id, + external_trace_id=external_trace_id, + ) + self._trace_manager.add_trace_task(trace_task) + + def _system_variables(self) -> Mapping[str, Any]: + runtime_state = self.graph_runtime_state + if runtime_state is None: + return {} + return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID) diff --git a/api/core/workflow/graph_engine/manager.py b/api/core/workflow/graph_engine/manager.py index ed62209acb..f05d43d8ad 100644 --- a/api/core/workflow/graph_engine/manager.py +++ b/api/core/workflow/graph_engine/manager.py @@ -9,7 +9,7 @@ Supports stop, pause, and resume operations. from typing import final from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel -from core.workflow.graph_engine.entities.commands import AbortCommand +from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand from extensions.ext_redis import redis_client @@ -20,7 +20,7 @@ class GraphEngineManager: This class provides a simple interface for controlling workflow executions by sending commands through Redis channels, without user validation. - Supports stop, pause, and resume operations. + Supports stop and pause operations. 
""" @staticmethod @@ -32,19 +32,29 @@ class GraphEngineManager: task_id: The task ID of the workflow to stop reason: Optional reason for stopping (defaults to "User requested stop") """ + abort_command = AbortCommand(reason=reason or "User requested stop") + GraphEngineManager._send_command(task_id, abort_command) + + @staticmethod + def send_pause_command(task_id: str, reason: str | None = None) -> None: + """Send a pause command to a running workflow.""" + + pause_command = PauseCommand(reason=reason or "User requested pause") + GraphEngineManager._send_command(task_id, pause_command) + + @staticmethod + def _send_command(task_id: str, command: GraphEngineCommand) -> None: + """Send a command to the workflow-specific Redis channel.""" + if not task_id: return - # Create Redis channel for this task channel_key = f"workflow:{task_id}:commands" channel = RedisChannel(redis_client, channel_key) - # Create and send abort command - abort_command = AbortCommand(reason=reason or "User requested stop") - try: - channel.send_command(abort_command) + channel.send_command(command) except Exception: # Silently fail if Redis is unavailable - # The legacy stop flag mechanism will still work + # The legacy control mechanisms will still work pass diff --git a/api/core/workflow/graph_engine/orchestration/dispatcher.py b/api/core/workflow/graph_engine/orchestration/dispatcher.py index f3570855ce..4097cead9c 100644 --- a/api/core/workflow/graph_engine/orchestration/dispatcher.py +++ b/api/core/workflow/graph_engine/orchestration/dispatcher.py @@ -33,6 +33,12 @@ class Dispatcher: with timeout and completion detection. """ + _COMMAND_TRIGGER_EVENTS = ( + NodeRunSucceededEvent, + NodeRunFailedEvent, + NodeRunExceptionEvent, + ) + def __init__( self, event_queue: queue.Queue[GraphNodeEventBase], @@ -77,33 +83,41 @@ class Dispatcher: if self._thread and self._thread.is_alive(): self._thread.join(timeout=10.0) - _COMMAND_TRIGGER_EVENTS = ( - NodeRunSucceededEvent, - NodeRunFailedEvent, - NodeRunExceptionEvent, - ) - def _dispatcher_loop(self) -> None: """Main dispatcher loop.""" try: while not self._stop_event.is_set(): - # Check for scaling - self._execution_coordinator.check_scaling() + commands_checked = False + should_check_commands = False + should_break = False - # Process events - try: - event = self._event_queue.get(timeout=0.1) - # Route to the event handler - self._event_handler.dispatch(event) - if self._should_check_commands(event): - self._execution_coordinator.check_commands() - self._event_queue.task_done() - except queue.Empty: - # Process commands even when no new events arrive so abort requests are not missed + if self._execution_coordinator.is_execution_complete(): + should_check_commands = True + should_break = True + else: + # Check for scaling + self._execution_coordinator.check_scaling() + + # Process events + try: + event = self._event_queue.get(timeout=0.1) + # Route to the event handler + self._event_handler.dispatch(event) + should_check_commands = self._should_check_commands(event) + self._event_queue.task_done() + except queue.Empty: + # Process commands even when no new events arrive so abort requests are not missed + should_check_commands = True + time.sleep(0.1) + + if should_check_commands and not commands_checked: self._execution_coordinator.check_commands() - # Check if execution is complete - if self._execution_coordinator.is_execution_complete(): - break + commands_checked = True + + if should_break: + if not commands_checked: + self._execution_coordinator.check_commands() + 
break except Exception as e: logger.exception("Dispatcher error") diff --git a/api/core/workflow/graph_engine/orchestration/execution_coordinator.py b/api/core/workflow/graph_engine/orchestration/execution_coordinator.py index b35e8bb6d8..a3162de244 100644 --- a/api/core/workflow/graph_engine/orchestration/execution_coordinator.py +++ b/api/core/workflow/graph_engine/orchestration/execution_coordinator.py @@ -2,17 +2,13 @@ Execution coordinator for managing overall workflow execution. """ -from typing import TYPE_CHECKING, final +from typing import final from ..command_processing import CommandProcessor from ..domain import GraphExecution -from ..event_management import EventManager from ..graph_state_manager import GraphStateManager from ..worker_management import WorkerPool -if TYPE_CHECKING: - from ..event_management import EventHandler - @final class ExecutionCoordinator: @@ -27,8 +23,6 @@ class ExecutionCoordinator: self, graph_execution: GraphExecution, state_manager: GraphStateManager, - event_handler: "EventHandler", - event_collector: EventManager, command_processor: CommandProcessor, worker_pool: WorkerPool, ) -> None: @@ -38,15 +32,11 @@ class ExecutionCoordinator: Args: graph_execution: Graph execution aggregate state_manager: Unified state manager - event_handler: Event handler registry for processing events - event_collector: Event manager for collecting events command_processor: Processor for commands worker_pool: Pool of workers """ self._graph_execution = graph_execution self._state_manager = state_manager - self._event_handler = event_handler - self._event_collector = event_collector self._command_processor = command_processor self._worker_pool = worker_pool @@ -65,15 +55,24 @@ class ExecutionCoordinator: Returns: True if execution is complete """ - # Check if aborted or failed + # Treat paused, aborted, or failed executions as terminal states + if self._graph_execution.is_paused: + return True + if self._graph_execution.aborted or self._graph_execution.has_error: return True - # Complete if no work remains return self._state_manager.is_execution_complete() + @property + def is_paused(self) -> bool: + """Expose whether the underlying graph execution is paused.""" + return self._graph_execution.is_paused + def mark_complete(self) -> None: """Mark execution as complete.""" + if self._graph_execution.is_paused: + return if not self._graph_execution.completed: self._graph_execution.complete() @@ -85,3 +84,21 @@ class ExecutionCoordinator: error: The error that caused failure """ self._graph_execution.fail(error) + + def handle_pause_if_needed(self) -> None: + """If the execution has been paused, stop workers immediately.""" + + if not self._graph_execution.is_paused: + return + + self._worker_pool.stop() + self._state_manager.clear_executing() + + def handle_abort_if_needed(self) -> None: + """If the execution has been aborted, stop workers immediately.""" + + if not self._graph_execution.aborted: + return + + self._worker_pool.stop() + self._state_manager.clear_executing() diff --git a/api/core/workflow/graph_engine/response_coordinator/coordinator.py b/api/core/workflow/graph_engine/response_coordinator/coordinator.py index 3db40c545e..98e0ea91ef 100644 --- a/api/core/workflow/graph_engine/response_coordinator/coordinator.py +++ b/api/core/workflow/graph_engine/response_coordinator/coordinator.py @@ -14,11 +14,11 @@ from uuid import uuid4 from pydantic import BaseModel, Field -from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import 
NodeExecutionType, NodeState from core.workflow.graph import Graph from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent from core.workflow.nodes.base.template import TextSegment, VariableSegment +from core.workflow.runtime import VariablePool from .path import Path from .session import ResponseSession diff --git a/api/core/workflow/graph_events/__init__.py b/api/core/workflow/graph_events/__init__.py index 42a376d4ad..7a5edbb331 100644 --- a/api/core/workflow/graph_events/__init__.py +++ b/api/core/workflow/graph_events/__init__.py @@ -13,6 +13,7 @@ from .graph import ( GraphRunAbortedEvent, GraphRunFailedEvent, GraphRunPartialSucceededEvent, + GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, ) @@ -37,6 +38,7 @@ from .loop import ( from .node import ( NodeRunExceptionEvent, NodeRunFailedEvent, + NodeRunPauseRequestedEvent, NodeRunRetrieverResourceEvent, NodeRunRetryEvent, NodeRunStartedEvent, @@ -51,6 +53,7 @@ __all__ = [ "GraphRunAbortedEvent", "GraphRunFailedEvent", "GraphRunPartialSucceededEvent", + "GraphRunPausedEvent", "GraphRunStartedEvent", "GraphRunSucceededEvent", "NodeRunAgentLogEvent", @@ -64,6 +67,7 @@ __all__ = [ "NodeRunLoopNextEvent", "NodeRunLoopStartedEvent", "NodeRunLoopSucceededEvent", + "NodeRunPauseRequestedEvent", "NodeRunRetrieverResourceEvent", "NodeRunRetryEvent", "NodeRunStartedEvent", diff --git a/api/core/workflow/graph_events/graph.py b/api/core/workflow/graph_events/graph.py index 5d13833faa..0da962aa1c 100644 --- a/api/core/workflow/graph_events/graph.py +++ b/api/core/workflow/graph_events/graph.py @@ -8,7 +8,12 @@ class GraphRunStartedEvent(BaseGraphEvent): class GraphRunSucceededEvent(BaseGraphEvent): - outputs: dict[str, object] = Field(default_factory=dict) + """Event emitted when a run completes successfully with final outputs.""" + + outputs: dict[str, object] = Field( + default_factory=dict, + description="Final workflow outputs keyed by output selector.", + ) class GraphRunFailedEvent(BaseGraphEvent): @@ -17,12 +22,30 @@ class GraphRunFailedEvent(BaseGraphEvent): class GraphRunPartialSucceededEvent(BaseGraphEvent): + """Event emitted when a run finishes with partial success and failures.""" + exceptions_count: int = Field(..., description="exception count") - outputs: dict[str, object] = Field(default_factory=dict) + outputs: dict[str, object] = Field( + default_factory=dict, + description="Outputs that were materialised before failures occurred.", + ) class GraphRunAbortedEvent(BaseGraphEvent): """Event emitted when a graph run is aborted by user command.""" reason: str | None = Field(default=None, description="reason for abort") - outputs: dict[str, object] = Field(default_factory=dict, description="partial outputs if any") + outputs: dict[str, object] = Field( + default_factory=dict, + description="Outputs produced before the abort was requested.", + ) + + +class GraphRunPausedEvent(BaseGraphEvent): + """Event emitted when a graph run is paused by user command.""" + + reason: str | None = Field(default=None, description="reason for pause") + outputs: dict[str, object] = Field( + default_factory=dict, + description="Outputs available to the client while the run is paused.", + ) diff --git a/api/core/workflow/graph_events/node.py b/api/core/workflow/graph_events/node.py index 1d35a69c4a..b880df60d1 100644 --- a/api/core/workflow/graph_events/node.py +++ b/api/core/workflow/graph_events/node.py @@ -51,3 +51,7 @@ class NodeRunExceptionEvent(GraphNodeEventBase): class 
NodeRunRetryEvent(NodeRunStartedEvent): error: str = Field(..., description="error") retry_index: int = Field(..., description="which retry attempt is about to be performed") + + +class NodeRunPauseRequestedEvent(GraphNodeEventBase): + reason: str | None = Field(default=None, description="Optional pause reason") diff --git a/api/core/workflow/node_events/__init__.py b/api/core/workflow/node_events/__init__.py index c3bcda0483..f14a594c85 100644 --- a/api/core/workflow/node_events/__init__.py +++ b/api/core/workflow/node_events/__init__.py @@ -14,6 +14,7 @@ from .loop import ( ) from .node import ( ModelInvokeCompletedEvent, + PauseRequestedEvent, RunRetrieverResourceEvent, RunRetryEvent, StreamChunkEvent, @@ -33,6 +34,7 @@ __all__ = [ "ModelInvokeCompletedEvent", "NodeEventBase", "NodeRunResult", + "PauseRequestedEvent", "RunRetrieverResourceEvent", "RunRetryEvent", "StreamChunkEvent", diff --git a/api/core/workflow/node_events/node.py b/api/core/workflow/node_events/node.py index 93dfefb679..4fd5684436 100644 --- a/api/core/workflow/node_events/node.py +++ b/api/core/workflow/node_events/node.py @@ -40,3 +40,7 @@ class StreamChunkEvent(NodeEventBase): class StreamCompletedEvent(NodeEventBase): node_run_result: NodeRunResult = Field(..., description="run result") + + +class PauseRequestedEvent(NodeEventBase): + reason: str | None = Field(default=None, description="Optional pause reason") diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 4a24b18465..626ef1df7b 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -25,7 +25,6 @@ from core.tools.entities.tool_entities import ( from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.variables.segments import ArrayFileSegment, StringSegment -from core.workflow.entities import VariablePool from core.workflow.enums import ( ErrorStrategy, NodeType, @@ -44,6 +43,7 @@ from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionMod from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser +from core.workflow.runtime import VariablePool from extensions.ext_database import db from factories import file_factory from factories.agent_factory import get_plugin_agent_strategy diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 41212abb0e..7f8c1eddff 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -6,7 +6,7 @@ from typing import Any, ClassVar from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams, GraphRuntimeState +from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus from core.workflow.graph_events import ( GraphNodeEventBase, @@ -20,6 +20,7 @@ from core.workflow.graph_events import ( NodeRunLoopNextEvent, NodeRunLoopStartedEvent, NodeRunLoopSucceededEvent, + NodeRunPauseRequestedEvent, NodeRunRetrieverResourceEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, @@ -37,10 +38,12 @@ from core.workflow.node_events import ( LoopSucceededEvent, NodeEventBase, NodeRunResult, + 
PauseRequestedEvent, RunRetrieverResourceEvent, StreamChunkEvent, StreamCompletedEvent, ) +from core.workflow.runtime import GraphRuntimeState from libs.datetime_utils import naive_utc_now from models.enums import UserFrom @@ -385,6 +388,16 @@ class Node: f"Node {self._node_id} does not support status {event.node_run_result.status}" ) + @_dispatch.register + def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent: + return NodeRunPauseRequestedEvent( + id=self._node_execution_id, + node_id=self._node_id, + node_type=self.node_type, + node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED), + reason=event.reason, + ) + @_dispatch.register def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent: return NodeRunAgentLogEvent( diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index e392cb5f5c..34c1db9468 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -19,7 +19,6 @@ from core.file.enums import FileTransferMethod, FileType from core.plugin.impl.exc import PluginDaemonClientSideError from core.variables.segments import ArrayAnySegment from core.variables.variables import ArrayAnyVariable -from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent @@ -27,6 +26,7 @@ from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser from core.workflow.nodes.tool.exc import ToolFileError +from core.workflow.runtime import VariablePool from extensions.ext_database import db from factories import file_factory from models.model import UploadFile diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index d3d3571b44..7b5b9c9e86 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -15,7 +15,7 @@ from core.file import file_manager from core.file.enums import FileTransferMethod from core.helper import ssrf_proxy from core.variables.segments import ArrayFileSegment, FileSegment -from core.workflow.entities import VariablePool +from core.workflow.runtime import VariablePool from .entities import ( HttpRequestNodeAuthorization, diff --git a/api/core/workflow/nodes/human_input/__init__.py b/api/core/workflow/nodes/human_input/__init__.py new file mode 100644 index 0000000000..379440557c --- /dev/null +++ b/api/core/workflow/nodes/human_input/__init__.py @@ -0,0 +1,3 @@ +from .human_input_node import HumanInputNode + +__all__ = ["HumanInputNode"] diff --git a/api/core/workflow/nodes/human_input/entities.py b/api/core/workflow/nodes/human_input/entities.py new file mode 100644 index 0000000000..02913d93c3 --- /dev/null +++ b/api/core/workflow/nodes/human_input/entities.py @@ -0,0 +1,10 @@ +from pydantic import Field + +from core.workflow.nodes.base import BaseNodeData + + +class HumanInputNodeData(BaseNodeData): + """Configuration schema for the HumanInput node.""" + + required_variables: list[str] = Field(default_factory=list) + pause_reason: str | None = Field(default=None) diff --git 
a/api/core/workflow/nodes/human_input/human_input_node.py b/api/core/workflow/nodes/human_input/human_input_node.py new file mode 100644 index 0000000000..e49f9a8c81 --- /dev/null +++ b/api/core/workflow/nodes/human_input/human_input_node.py @@ -0,0 +1,132 @@ +from collections.abc import Mapping +from typing import Any + +from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus +from core.workflow.node_events import NodeRunResult, PauseRequestedEvent +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.base.node import Node + +from .entities import HumanInputNodeData + + +class HumanInputNode(Node): + node_type = NodeType.HUMAN_INPUT + execution_type = NodeExecutionType.BRANCH + + _BRANCH_SELECTION_KEYS: tuple[str, ...] = ( + "edge_source_handle", + "edgeSourceHandle", + "source_handle", + "selected_branch", + "selectedBranch", + "branch", + "branch_id", + "branchId", + "handle", + ) + + _node_data: HumanInputNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = HumanInputNodeData(**data) + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + + @classmethod + def version(cls) -> str: + return "1" + + def _get_error_strategy(self) -> ErrorStrategy | None: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> str | None: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def _run(self): # type: ignore[override] + if self._is_completion_ready(): + branch_handle = self._resolve_branch_selection() + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={}, + edge_source_handle=branch_handle or "source", + ) + + return self._pause_generator() + + def _pause_generator(self): + yield PauseRequestedEvent(reason=self._node_data.pause_reason) + + def _is_completion_ready(self) -> bool: + """Determine whether all required inputs are satisfied.""" + + if not self._node_data.required_variables: + return False + + variable_pool = self.graph_runtime_state.variable_pool + + for selector_str in self._node_data.required_variables: + parts = selector_str.split(".") + if len(parts) != 2: + return False + segment = variable_pool.get(parts) + if segment is None: + return False + + return True + + def _resolve_branch_selection(self) -> str | None: + """Determine the branch handle selected by human input if available.""" + + variable_pool = self.graph_runtime_state.variable_pool + + for key in self._BRANCH_SELECTION_KEYS: + handle = self._extract_branch_handle(variable_pool.get((self.id, key))) + if handle: + return handle + + default_values = self._node_data.default_value_dict + for key in self._BRANCH_SELECTION_KEYS: + handle = self._normalize_branch_value(default_values.get(key)) + if handle: + return handle + + return None + + @staticmethod + def _extract_branch_handle(segment: Any) -> str | None: + if segment is None: + return None + + candidate = getattr(segment, "to_object", None) + raw_value = candidate() if callable(candidate) else getattr(segment, "value", None) + if raw_value is None: + return None + + return HumanInputNode._normalize_branch_value(raw_value) + + @staticmethod + def _normalize_branch_value(value: Any) -> str | None: + if value is None: + return None + + if 
isinstance(value, str): + stripped = value.strip() + return stripped or None + + if isinstance(value, Mapping): + for key in ("handle", "edge_source_handle", "edgeSourceHandle", "branch", "id", "value"): + candidate = value.get(key) + if isinstance(candidate, str) and candidate: + return candidate + + return None diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index 7e3b6ecc1a..165e529714 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -3,12 +3,12 @@ from typing import Any, Literal from typing_extensions import deprecated -from core.workflow.entities import VariablePool from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.if_else.entities import IfElseNodeData +from core.workflow.runtime import VariablePool from core.workflow.utils.condition.entities import Condition from core.workflow.utils.condition.processor import ConditionProcessor diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index c089a68bd4..41060bd569 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -12,7 +12,6 @@ from core.variables import IntegerVariable, NoneSegment from core.variables.segments import ArrayAnySegment, ArraySegment from core.variables.variables import VariableUnion from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID -from core.workflow.entities import VariablePool from core.workflow.enums import ( ErrorStrategy, NodeExecutionType, @@ -38,6 +37,7 @@ from core.workflow.node_events import ( from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData +from core.workflow.runtime import VariablePool from libs.datetime_utils import naive_utc_now from libs.flask_utils import preserve_flask_contexts @@ -557,11 +557,12 @@ class IterationNode(Node): def _create_graph_engine(self, index: int, item: object): # Import dependencies - from core.workflow.entities import GraphInitParams, GraphRuntimeState + from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.nodes.node_factory import DifyNodeFactory + from core.workflow.runtime import GraphRuntimeState # Create GraphInitParams from node attributes graph_init_params = GraphInitParams( diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 2751f24048..2ba1e5e1c5 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -9,13 +9,13 @@ from sqlalchemy import func, select from core.app.entities.app_invoke_entities import InvokeFrom from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.retrieval.retrieval_methods import RetrievalMethod -from core.workflow.entities.variable_pool import VariablePool from 
core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.template import Template +from core.workflow.runtime import VariablePool from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 7091b62463..2dc3cb9320 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -67,7 +67,7 @@ from .exc import ( if TYPE_CHECKING: from core.file.models import File - from core.workflow.entities import GraphRuntimeState + from core.workflow.runtime import GraphRuntimeState logger = logging.getLogger(__name__) diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py index aff84433b2..0c545469bc 100644 --- a/api/core/workflow/nodes/llm/llm_utils.py +++ b/api/core/workflow/nodes/llm/llm_utils.py @@ -15,9 +15,9 @@ from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment -from core.workflow.entities import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.nodes.llm.entities import ModelConfig +from core.workflow.runtime import VariablePool from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.model import Conversation diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 13f6d904e6..e4637e6e95 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -52,7 +52,7 @@ from core.variables import ( StringSegment, ) from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID -from core.workflow.entities import GraphInitParams, VariablePool +from core.workflow.entities import GraphInitParams from core.workflow.enums import ( ErrorStrategy, NodeType, @@ -71,6 +71,7 @@ from core.workflow.node_events import ( from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser +from core.workflow.runtime import VariablePool from . 
import llm_utils from .entities import ( @@ -93,7 +94,7 @@ from .file_saver import FileSaverImpl, LLMFileSaver if TYPE_CHECKING: from core.file.models import File - from core.workflow.entities import GraphRuntimeState + from core.workflow.runtime import GraphRuntimeState logger = logging.getLogger(__name__) diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 790975d556..b51790c0a2 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -406,11 +406,12 @@ class LoopNode(Node): def _create_graph_engine(self, start_at: datetime, root_node_id: str): # Import dependencies - from core.workflow.entities import GraphInitParams, GraphRuntimeState + from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.nodes.node_factory import DifyNodeFactory + from core.workflow.runtime import GraphRuntimeState # Create GraphInitParams from node attributes graph_init_params = GraphInitParams( diff --git a/api/core/workflow/nodes/node_factory.py b/api/core/workflow/nodes/node_factory.py index df1d685909..87d1b8c435 100644 --- a/api/core/workflow/nodes/node_factory.py +++ b/api/core/workflow/nodes/node_factory.py @@ -10,7 +10,8 @@ from libs.typing import is_str, is_str_dict from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING if TYPE_CHECKING: - from core.workflow.entities import GraphInitParams, GraphRuntimeState + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState @final diff --git a/api/core/workflow/nodes/node_mapping.py b/api/core/workflow/nodes/node_mapping.py index 3d3a1bec98..3ee28802f1 100644 --- a/api/core/workflow/nodes/node_mapping.py +++ b/api/core/workflow/nodes/node_mapping.py @@ -9,6 +9,7 @@ from core.workflow.nodes.datasource.datasource_node import DatasourceNode from core.workflow.nodes.document_extractor import DocumentExtractorNode from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.http_request import HttpRequestNode +from core.workflow.nodes.human_input import HumanInputNode from core.workflow.nodes.if_else import IfElseNode from core.workflow.nodes.iteration import IterationNode, IterationStartNode from core.workflow.nodes.knowledge_index import KnowledgeIndexNode @@ -134,6 +135,10 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = { "2": AgentNode, "1": AgentNode, }, + NodeType.HUMAN_INPUT: { + LATEST_VERSION: HumanInputNode, + "1": HumanInputNode, + }, NodeType.DATASOURCE: { LATEST_VERSION: DatasourceNode, "1": DatasourceNode, diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 875a0598e0..2b65cc30b6 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -27,13 +27,13 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.variables.types import ArrayValidation, SegmentType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import ErrorStrategy, NodeType, 
WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base import variable_template_parser from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.llm import ModelConfig, llm_utils +from core.workflow.runtime import VariablePool from factories.variable_factory import build_segment_with_type from .entities import ParameterExtractorNodeData diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 31b1cd4966..3f37fc481b 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -41,7 +41,7 @@ from .template_prompts import ( if TYPE_CHECKING: from core.file.models import File - from core.workflow.entities import GraphRuntimeState + from core.workflow.runtime import GraphRuntimeState class QuestionClassifierNode(Node): diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index cd0094f531..2e2c32ac93 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -36,7 +36,7 @@ from .exc import ( ) if TYPE_CHECKING: - from core.workflow.entities import VariablePool + from core.workflow.runtime import VariablePool class ToolNode(Node): diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index c2a9ecd7fb..8cd267c4a7 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -18,7 +18,7 @@ from ..common.impl import conversation_variable_updater_factory from .node_data import VariableAssignerData, WriteMode if TYPE_CHECKING: - from core.workflow.entities import GraphRuntimeState + from core.workflow.runtime import GraphRuntimeState _CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater] diff --git a/api/core/workflow/runtime/__init__.py b/api/core/workflow/runtime/__init__.py new file mode 100644 index 0000000000..10014c7182 --- /dev/null +++ b/api/core/workflow/runtime/__init__.py @@ -0,0 +1,14 @@ +from .graph_runtime_state import GraphRuntimeState +from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool +from .read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper +from .variable_pool import VariablePool, VariableValue + +__all__ = [ + "GraphRuntimeState", + "ReadOnlyGraphRuntimeState", + "ReadOnlyGraphRuntimeStateWrapper", + "ReadOnlyVariablePool", + "ReadOnlyVariablePoolWrapper", + "VariablePool", + "VariableValue", +] diff --git a/api/core/workflow/runtime/graph_runtime_state.py b/api/core/workflow/runtime/graph_runtime_state.py new file mode 100644 index 0000000000..486718dc62 --- /dev/null +++ b/api/core/workflow/runtime/graph_runtime_state.py @@ -0,0 +1,393 @@ +from __future__ import annotations + +import importlib +import json +from collections.abc import Mapping, Sequence +from collections.abc import Mapping as TypingMapping +from copy import deepcopy +from typing import Any, Protocol + +from pydantic.json import pydantic_encoder + +from core.model_runtime.entities.llm_entities import LLMUsage +from core.workflow.runtime.variable_pool import VariablePool + + +class ReadyQueueProtocol(Protocol): + 
"""Structural interface required from ready queue implementations.""" + + def put(self, item: str) -> None: + """Enqueue the identifier of a node that is ready to run.""" + ... + + def get(self, timeout: float | None = None) -> str: + """Return the next node identifier, blocking until available or timeout expires.""" + ... + + def task_done(self) -> None: + """Signal that the most recently dequeued node has completed processing.""" + ... + + def empty(self) -> bool: + """Return True when the queue contains no pending nodes.""" + ... + + def qsize(self) -> int: + """Approximate the number of pending nodes awaiting execution.""" + ... + + def dumps(self) -> str: + """Serialize the queue contents for persistence.""" + ... + + def loads(self, data: str) -> None: + """Restore the queue contents from a serialized payload.""" + ... + + +class GraphExecutionProtocol(Protocol): + """Structural interface for graph execution aggregate.""" + + workflow_id: str + started: bool + completed: bool + aborted: bool + error: Exception | None + exceptions_count: int + + def start(self) -> None: + """Transition execution into the running state.""" + ... + + def complete(self) -> None: + """Mark execution as successfully completed.""" + ... + + def abort(self, reason: str) -> None: + """Abort execution in response to an external stop request.""" + ... + + def fail(self, error: Exception) -> None: + """Record an unrecoverable error and end execution.""" + ... + + def dumps(self) -> str: + """Serialize execution state into a JSON payload.""" + ... + + def loads(self, data: str) -> None: + """Restore execution state from a previously serialized payload.""" + ... + + +class ResponseStreamCoordinatorProtocol(Protocol): + """Structural interface for response stream coordinator.""" + + def register(self, response_node_id: str) -> None: + """Register a response node so its outputs can be streamed.""" + ... + + def loads(self, data: str) -> None: + """Restore coordinator state from a serialized payload.""" + ... + + def dumps(self) -> str: + """Serialize coordinator state for persistence.""" + ... + + +class GraphProtocol(Protocol): + """Structural interface required from graph instances attached to the runtime state.""" + + nodes: TypingMapping[str, object] + edges: TypingMapping[str, object] + root_node: object + + def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ... 
+ + +class GraphRuntimeState: + """Mutable runtime state shared across graph execution components.""" + + def __init__( + self, + *, + variable_pool: VariablePool, + start_at: float, + total_tokens: int = 0, + llm_usage: LLMUsage | None = None, + outputs: dict[str, object] | None = None, + node_run_steps: int = 0, + ready_queue: ReadyQueueProtocol | None = None, + graph_execution: GraphExecutionProtocol | None = None, + response_coordinator: ResponseStreamCoordinatorProtocol | None = None, + graph: GraphProtocol | None = None, + ) -> None: + self._variable_pool = variable_pool + self._start_at = start_at + + if total_tokens < 0: + raise ValueError("total_tokens must be non-negative") + self._total_tokens = total_tokens + + self._llm_usage = (llm_usage or LLMUsage.empty_usage()).model_copy() + self._outputs = deepcopy(outputs) if outputs is not None else {} + + if node_run_steps < 0: + raise ValueError("node_run_steps must be non-negative") + self._node_run_steps = node_run_steps + + self._graph: GraphProtocol | None = None + + self._ready_queue = ready_queue + self._graph_execution = graph_execution + self._response_coordinator = response_coordinator + self._pending_response_coordinator_dump: str | None = None + self._pending_graph_execution_workflow_id: str | None = None + self._paused_nodes: set[str] = set() + + if graph is not None: + self.attach_graph(graph) + + # ------------------------------------------------------------------ + # Context binding helpers + # ------------------------------------------------------------------ + def attach_graph(self, graph: GraphProtocol) -> None: + """Attach the materialized graph to the runtime state.""" + if self._graph is not None and self._graph is not graph: + raise ValueError("GraphRuntimeState already attached to a different graph instance") + + self._graph = graph + + if self._response_coordinator is None: + self._response_coordinator = self._build_response_coordinator(graph) + + if self._pending_response_coordinator_dump is not None and self._response_coordinator is not None: + self._response_coordinator.loads(self._pending_response_coordinator_dump) + self._pending_response_coordinator_dump = None + + def configure(self, *, graph: GraphProtocol | None = None) -> None: + """Ensure core collaborators are initialized with the provided context.""" + if graph is not None: + self.attach_graph(graph) + + # Ensure collaborators are instantiated + _ = self.ready_queue + _ = self.graph_execution + if self._graph is not None: + _ = self.response_coordinator + + # ------------------------------------------------------------------ + # Primary collaborators + # ------------------------------------------------------------------ + @property + def variable_pool(self) -> VariablePool: + return self._variable_pool + + @property + def ready_queue(self) -> ReadyQueueProtocol: + if self._ready_queue is None: + self._ready_queue = self._build_ready_queue() + return self._ready_queue + + @property + def graph_execution(self) -> GraphExecutionProtocol: + if self._graph_execution is None: + self._graph_execution = self._build_graph_execution() + return self._graph_execution + + @property + def response_coordinator(self) -> ResponseStreamCoordinatorProtocol: + if self._response_coordinator is None: + if self._graph is None: + raise ValueError("Graph must be attached before accessing response coordinator") + self._response_coordinator = self._build_response_coordinator(self._graph) + return self._response_coordinator + + # 
------------------------------------------------------------------ + # Scalar state + # ------------------------------------------------------------------ + @property + def start_at(self) -> float: + return self._start_at + + @start_at.setter + def start_at(self, value: float) -> None: + self._start_at = value + + @property + def total_tokens(self) -> int: + return self._total_tokens + + @total_tokens.setter + def total_tokens(self, value: int) -> None: + if value < 0: + raise ValueError("total_tokens must be non-negative") + self._total_tokens = value + + @property + def llm_usage(self) -> LLMUsage: + return self._llm_usage.model_copy() + + @llm_usage.setter + def llm_usage(self, value: LLMUsage) -> None: + self._llm_usage = value.model_copy() + + @property + def outputs(self) -> dict[str, Any]: + return deepcopy(self._outputs) + + @outputs.setter + def outputs(self, value: dict[str, Any]) -> None: + self._outputs = deepcopy(value) + + def set_output(self, key: str, value: object) -> None: + self._outputs[key] = deepcopy(value) + + def get_output(self, key: str, default: object = None) -> object: + return deepcopy(self._outputs.get(key, default)) + + def update_outputs(self, updates: dict[str, object]) -> None: + for key, value in updates.items(): + self._outputs[key] = deepcopy(value) + + @property + def node_run_steps(self) -> int: + return self._node_run_steps + + @node_run_steps.setter + def node_run_steps(self, value: int) -> None: + if value < 0: + raise ValueError("node_run_steps must be non-negative") + self._node_run_steps = value + + def increment_node_run_steps(self) -> None: + self._node_run_steps += 1 + + def add_tokens(self, tokens: int) -> None: + if tokens < 0: + raise ValueError("tokens must be non-negative") + self._total_tokens += tokens + + # ------------------------------------------------------------------ + # Serialization + # ------------------------------------------------------------------ + def dumps(self) -> str: + """Serialize runtime state into a JSON string.""" + + snapshot: dict[str, Any] = { + "version": "1.0", + "start_at": self._start_at, + "total_tokens": self._total_tokens, + "node_run_steps": self._node_run_steps, + "llm_usage": self._llm_usage.model_dump(mode="json"), + "outputs": self.outputs, + "variable_pool": self.variable_pool.model_dump(mode="json"), + "ready_queue": self.ready_queue.dumps(), + "graph_execution": self.graph_execution.dumps(), + "paused_nodes": list(self._paused_nodes), + } + + if self._response_coordinator is not None and self._graph is not None: + snapshot["response_coordinator"] = self._response_coordinator.dumps() + + return json.dumps(snapshot, default=pydantic_encoder) + + def loads(self, data: str | Mapping[str, Any]) -> None: + """Restore runtime state from a serialized snapshot.""" + + payload: dict[str, Any] + if isinstance(data, str): + payload = json.loads(data) + else: + payload = dict(data) + + version = payload.get("version") + if version != "1.0": + raise ValueError(f"Unsupported GraphRuntimeState snapshot version: {version}") + + self._start_at = float(payload.get("start_at", 0.0)) + total_tokens = int(payload.get("total_tokens", 0)) + if total_tokens < 0: + raise ValueError("total_tokens must be non-negative") + self._total_tokens = total_tokens + + node_run_steps = int(payload.get("node_run_steps", 0)) + if node_run_steps < 0: + raise ValueError("node_run_steps must be non-negative") + self._node_run_steps = node_run_steps + + llm_usage_payload = payload.get("llm_usage", {}) + self._llm_usage = 
LLMUsage.model_validate(llm_usage_payload) + + self._outputs = deepcopy(payload.get("outputs", {})) + + variable_pool_payload = payload.get("variable_pool") + if variable_pool_payload is not None: + self._variable_pool = VariablePool.model_validate(variable_pool_payload) + + ready_queue_payload = payload.get("ready_queue") + if ready_queue_payload is not None: + self._ready_queue = self._build_ready_queue() + self._ready_queue.loads(ready_queue_payload) + else: + self._ready_queue = None + + graph_execution_payload = payload.get("graph_execution") + self._graph_execution = None + self._pending_graph_execution_workflow_id = None + if graph_execution_payload is not None: + try: + execution_payload = json.loads(graph_execution_payload) + self._pending_graph_execution_workflow_id = execution_payload.get("workflow_id") + except (json.JSONDecodeError, TypeError, AttributeError): + self._pending_graph_execution_workflow_id = None + self.graph_execution.loads(graph_execution_payload) + + response_payload = payload.get("response_coordinator") + if response_payload is not None: + if self._graph is not None: + self.response_coordinator.loads(response_payload) + else: + self._pending_response_coordinator_dump = response_payload + else: + self._pending_response_coordinator_dump = None + self._response_coordinator = None + + paused_nodes_payload = payload.get("paused_nodes", []) + self._paused_nodes = set(map(str, paused_nodes_payload)) + + def register_paused_node(self, node_id: str) -> None: + """Record a node that should resume when execution is continued.""" + + self._paused_nodes.add(node_id) + + def consume_paused_nodes(self) -> list[str]: + """Retrieve and clear the list of paused nodes awaiting resume.""" + + nodes = list(self._paused_nodes) + self._paused_nodes.clear() + return nodes + + # ------------------------------------------------------------------ + # Builders + # ------------------------------------------------------------------ + def _build_ready_queue(self) -> ReadyQueueProtocol: + # Import lazily to avoid breaching architecture boundaries enforced by import-linter. + module = importlib.import_module("core.workflow.graph_engine.ready_queue") + in_memory_cls = module.InMemoryReadyQueue + return in_memory_cls() + + def _build_graph_execution(self) -> GraphExecutionProtocol: + # Lazily import to keep the runtime domain decoupled from graph_engine modules. + module = importlib.import_module("core.workflow.graph_engine.domain.graph_execution") + graph_execution_cls = module.GraphExecution + workflow_id = self._pending_graph_execution_workflow_id or "" + self._pending_graph_execution_workflow_id = None + return graph_execution_cls(workflow_id=workflow_id) + + def _build_response_coordinator(self, graph: GraphProtocol) -> ResponseStreamCoordinatorProtocol: + # Lazily import to keep the runtime domain decoupled from graph_engine modules. 
+ module = importlib.import_module("core.workflow.graph_engine.response_coordinator") + coordinator_cls = module.ResponseStreamCoordinator + return coordinator_cls(variable_pool=self.variable_pool, graph=graph) diff --git a/api/core/workflow/graph/graph_runtime_state_protocol.py b/api/core/workflow/runtime/graph_runtime_state_protocol.py similarity index 76% rename from api/core/workflow/graph/graph_runtime_state_protocol.py rename to api/core/workflow/runtime/graph_runtime_state_protocol.py index d7961405ca..40835a936f 100644 --- a/api/core/workflow/graph/graph_runtime_state_protocol.py +++ b/api/core/workflow/runtime/graph_runtime_state_protocol.py @@ -16,6 +16,10 @@ class ReadOnlyVariablePool(Protocol): """Get all variables for a node (read-only).""" ... + def get_by_prefix(self, prefix: str) -> Mapping[str, object]: + """Get all variables stored under a given node prefix (read-only).""" + ... + class ReadOnlyGraphRuntimeState(Protocol): """ @@ -56,6 +60,20 @@ class ReadOnlyGraphRuntimeState(Protocol): """Get the node run steps count (read-only).""" ... + @property + def ready_queue_size(self) -> int: + """Get the number of nodes currently in the ready queue.""" + ... + + @property + def exceptions_count(self) -> int: + """Get the number of node execution exceptions recorded.""" + ... + def get_output(self, key: str, default: Any = None) -> Any: """Get a single output value (returns a copy).""" ... + + def dumps(self) -> str: + """Serialize the runtime state into a JSON snapshot (read-only).""" + ... diff --git a/api/core/workflow/graph/read_only_state_wrapper.py b/api/core/workflow/runtime/read_only_wrappers.py similarity index 54% rename from api/core/workflow/graph/read_only_state_wrapper.py rename to api/core/workflow/runtime/read_only_wrappers.py index 255bb5adee..664c365295 100644 --- a/api/core/workflow/graph/read_only_state_wrapper.py +++ b/api/core/workflow/runtime/read_only_wrappers.py @@ -1,77 +1,82 @@ +from __future__ import annotations + from collections.abc import Mapping from copy import deepcopy from typing import Any from core.model_runtime.entities.llm_entities import LLMUsage from core.variables.segments import Segment -from core.workflow.entities.graph_runtime_state import GraphRuntimeState -from core.workflow.entities.variable_pool import VariablePool + +from .graph_runtime_state import GraphRuntimeState +from .variable_pool import VariablePool class ReadOnlyVariablePoolWrapper: - """Wrapper that provides read-only access to VariablePool.""" + """Provide defensive, read-only access to ``VariablePool``.""" - def __init__(self, variable_pool: VariablePool): + def __init__(self, variable_pool: VariablePool) -> None: self._variable_pool = variable_pool def get(self, node_id: str, variable_key: str) -> Segment | None: - """Get a variable value (returns a defensive copy).""" + """Return a copy of a variable value if present.""" value = self._variable_pool.get([node_id, variable_key]) return deepcopy(value) if value is not None else None def get_all_by_node(self, node_id: str) -> Mapping[str, object]: - """Get all variables for a node (returns defensive copies).""" + """Return a copy of all variables for the specified node.""" variables: dict[str, object] = {} if node_id in self._variable_pool.variable_dictionary: - for key, var in self._variable_pool.variable_dictionary[node_id].items(): - # Variables have a value property that contains the actual data - variables[key] = deepcopy(var.value) + for key, variable in self._variable_pool.variable_dictionary[node_id].items(): + 
variables[key] = deepcopy(variable.value) return variables + def get_by_prefix(self, prefix: str) -> Mapping[str, object]: + """Return a copy of all variables stored under the given prefix.""" + return self._variable_pool.get_by_prefix(prefix) + class ReadOnlyGraphRuntimeStateWrapper: - """ - Wrapper that provides read-only access to GraphRuntimeState. + """Expose a defensive, read-only view of ``GraphRuntimeState``.""" - This wrapper ensures that layers can observe the state without - modifying it. All returned values are defensive copies. - """ - - def __init__(self, state: GraphRuntimeState): + def __init__(self, state: GraphRuntimeState) -> None: self._state = state self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool) @property def variable_pool(self) -> ReadOnlyVariablePoolWrapper: - """Get read-only access to the variable pool.""" return self._variable_pool_wrapper @property def start_at(self) -> float: - """Get the start time (read-only).""" return self._state.start_at @property def total_tokens(self) -> int: - """Get the total tokens count (read-only).""" return self._state.total_tokens @property def llm_usage(self) -> LLMUsage: - """Get a copy of LLM usage info (read-only).""" - # Return a copy to prevent modification return self._state.llm_usage.model_copy() @property def outputs(self) -> dict[str, Any]: - """Get a defensive copy of outputs (read-only).""" return deepcopy(self._state.outputs) @property def node_run_steps(self) -> int: - """Get the node run steps count (read-only).""" return self._state.node_run_steps + @property + def ready_queue_size(self) -> int: + return self._state.ready_queue.qsize() + + @property + def exceptions_count(self) -> int: + return self._state.graph_execution.exceptions_count + def get_output(self, key: str, default: Any = None) -> Any: - """Get a single output value (returns a copy).""" return self._state.get_output(key, default) + + def dumps(self) -> str: + """Serialize the underlying runtime state for external persistence.""" + return self._state.dumps() diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/runtime/variable_pool.py similarity index 95% rename from api/core/workflow/entities/variable_pool.py rename to api/core/workflow/runtime/variable_pool.py index 2dc00fd70b..5fd6e894f1 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/runtime/variable_pool.py @@ -1,6 +1,7 @@ import re from collections import defaultdict from collections.abc import Mapping, Sequence +from copy import deepcopy from typing import Annotated, Any, Union, cast from pydantic import BaseModel, Field @@ -235,6 +236,20 @@ class VariablePool(BaseModel): return segment return None + def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]: + """Return a copy of all variables stored under the given node prefix.""" + + nodes = self.variable_dictionary.get(prefix) + if not nodes: + return {} + + result: dict[str, object] = {} + for key, variable in nodes.items(): + value = variable.value + result[key] = deepcopy(value) + + return result + def _add_system_variables(self, system_variable: SystemVariable): sys_var_mapping = system_variable.to_dict() for key, value in sys_var_mapping.items(): diff --git a/api/core/workflow/utils/condition/processor.py b/api/core/workflow/utils/condition/processor.py index f4bbe9c3c3..650a44c681 100644 --- a/api/core/workflow/utils/condition/processor.py +++ b/api/core/workflow/utils/condition/processor.py @@ -5,7 +5,7 @@ from typing import Literal, 
NamedTuple from core.file import FileAttribute, file_manager from core.variables import ArrayFileSegment from core.variables.segments import ArrayBooleanSegment, BooleanSegment -from core.workflow.entities import VariablePool +from core.workflow.runtime import VariablePool from .entities import Condition, SubCondition, SupportedComparisonOperator diff --git a/api/core/workflow/variable_loader.py b/api/core/workflow/variable_loader.py index 1b31022495..ea0bdc3537 100644 --- a/api/core/workflow/variable_loader.py +++ b/api/core/workflow/variable_loader.py @@ -4,7 +4,7 @@ from typing import Any, Protocol from core.variables import Variable from core.variables.consts import SELECTORS_LENGTH -from core.workflow.entities.variable_pool import VariablePool +from core.workflow.runtime import VariablePool class VariableLoader(Protocol): diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py deleted file mode 100644 index a88f350a9e..0000000000 --- a/api/core/workflow/workflow_cycle_manager.py +++ /dev/null @@ -1,459 +0,0 @@ -from collections.abc import Mapping -from dataclasses import dataclass -from datetime import datetime -from typing import Any, Union - -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity -from core.app.entities.queue_entities import ( - QueueNodeExceptionEvent, - QueueNodeFailedEvent, - QueueNodeRetryEvent, - QueueNodeStartedEvent, - QueueNodeSucceededEvent, -) -from core.app.task_pipeline.exc import WorkflowRunNotFoundError -from core.ops.entities.trace_entity import TraceTaskName -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask -from core.workflow.entities import ( - WorkflowExecution, - WorkflowNodeExecution, -) -from core.workflow.enums import ( - SystemVariableKey, - WorkflowExecutionStatus, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, - WorkflowType, -) -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository -from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.system_variable import SystemVariable -from core.workflow.workflow_entry import WorkflowEntry -from libs.datetime_utils import naive_utc_now -from libs.uuid_utils import uuidv7 - - -@dataclass -class CycleManagerWorkflowInfo: - workflow_id: str - workflow_type: WorkflowType - version: str - graph_data: Mapping[str, Any] - - -class WorkflowCycleManager: - def __init__( - self, - *, - application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], - workflow_system_variables: SystemVariable, - workflow_info: CycleManagerWorkflowInfo, - workflow_execution_repository: WorkflowExecutionRepository, - workflow_node_execution_repository: WorkflowNodeExecutionRepository, - ): - self._application_generate_entity = application_generate_entity - self._workflow_system_variables = workflow_system_variables - self._workflow_info = workflow_info - self._workflow_execution_repository = workflow_execution_repository - self._workflow_node_execution_repository = workflow_node_execution_repository - - # Initialize caches for workflow execution cycle - # These caches avoid redundant repository calls during a single workflow execution - self._workflow_execution_cache: dict[str, WorkflowExecution] = {} - self._node_execution_cache: dict[str, WorkflowNodeExecution] = {} - - def handle_workflow_run_start(self) -> WorkflowExecution: - inputs = 
self._prepare_workflow_inputs() - execution_id = self._get_or_generate_execution_id() - - execution = WorkflowExecution.new( - id_=execution_id, - workflow_id=self._workflow_info.workflow_id, - workflow_type=self._workflow_info.workflow_type, - workflow_version=self._workflow_info.version, - graph=self._workflow_info.graph_data, - inputs=inputs, - started_at=naive_utc_now(), - ) - - return self._save_and_cache_workflow_execution(execution) - - def handle_workflow_run_success( - self, - *, - workflow_run_id: str, - total_tokens: int, - total_steps: int, - outputs: Mapping[str, Any] | None = None, - conversation_id: str | None = None, - trace_manager: TraceQueueManager | None = None, - external_trace_id: str | None = None, - ) -> WorkflowExecution: - workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id) - - self._update_workflow_execution_completion( - workflow_execution, - status=WorkflowExecutionStatus.SUCCEEDED, - outputs=outputs, - total_tokens=total_tokens, - total_steps=total_steps, - ) - - self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id, external_trace_id) - - self._workflow_execution_repository.save(workflow_execution) - return workflow_execution - - def handle_workflow_run_partial_success( - self, - *, - workflow_run_id: str, - total_tokens: int, - total_steps: int, - outputs: Mapping[str, Any] | None = None, - exceptions_count: int = 0, - conversation_id: str | None = None, - trace_manager: TraceQueueManager | None = None, - external_trace_id: str | None = None, - ) -> WorkflowExecution: - execution = self._get_workflow_execution_or_raise_error(workflow_run_id) - - self._update_workflow_execution_completion( - execution, - status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED, - outputs=outputs, - total_tokens=total_tokens, - total_steps=total_steps, - exceptions_count=exceptions_count, - ) - - self._add_trace_task_if_needed(trace_manager, execution, conversation_id, external_trace_id) - - self._workflow_execution_repository.save(execution) - return execution - - def handle_workflow_run_failed( - self, - *, - workflow_run_id: str, - total_tokens: int, - total_steps: int, - status: WorkflowExecutionStatus, - error_message: str, - conversation_id: str | None = None, - trace_manager: TraceQueueManager | None = None, - exceptions_count: int = 0, - external_trace_id: str | None = None, - ) -> WorkflowExecution: - workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id) - now = naive_utc_now() - - self._update_workflow_execution_completion( - workflow_execution, - status=status, - total_tokens=total_tokens, - total_steps=total_steps, - error_message=error_message, - exceptions_count=exceptions_count, - finished_at=now, - ) - - self._fail_running_node_executions(workflow_execution.id_, error_message, now) - self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id, external_trace_id) - - self._workflow_execution_repository.save(workflow_execution) - return workflow_execution - - def handle_node_execution_start( - self, - *, - workflow_execution_id: str, - event: QueueNodeStartedEvent, - ) -> WorkflowNodeExecution: - workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id) - - domain_execution = self._create_node_execution_from_event( - workflow_execution=workflow_execution, - event=event, - status=WorkflowNodeExecutionStatus.RUNNING, - ) - - return self._save_and_cache_node_execution(domain_execution) - - def handle_workflow_node_execution_success(self, 
*, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: - domain_execution = self._get_node_execution_from_cache(event.node_execution_id) - - self._update_node_execution_completion( - domain_execution, - event=event, - status=WorkflowNodeExecutionStatus.SUCCEEDED, - ) - - self._workflow_node_execution_repository.save(domain_execution) - self._workflow_node_execution_repository.save_execution_data(domain_execution) - return domain_execution - - def handle_workflow_node_execution_failed( - self, - *, - event: QueueNodeFailedEvent | QueueNodeExceptionEvent, - ) -> WorkflowNodeExecution: - """ - Workflow node execution failed - :param event: queue node failed event - :return: - """ - domain_execution = self._get_node_execution_from_cache(event.node_execution_id) - - status = ( - WorkflowNodeExecutionStatus.EXCEPTION - if isinstance(event, QueueNodeExceptionEvent) - else WorkflowNodeExecutionStatus.FAILED - ) - - self._update_node_execution_completion( - domain_execution, - event=event, - status=status, - error=event.error, - handle_special_values=True, - ) - - self._workflow_node_execution_repository.save(domain_execution) - self._workflow_node_execution_repository.save_execution_data(domain_execution) - return domain_execution - - def handle_workflow_node_execution_retried( - self, *, workflow_execution_id: str, event: QueueNodeRetryEvent - ) -> WorkflowNodeExecution: - workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id) - - domain_execution = self._create_node_execution_from_event( - workflow_execution=workflow_execution, - event=event, - status=WorkflowNodeExecutionStatus.RETRY, - error=event.error, - created_at=event.start_at, - ) - - # Handle inputs and outputs - inputs = WorkflowEntry.handle_special_values(event.inputs) - outputs = event.outputs - metadata = self._merge_event_metadata(event) - - domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=metadata) - - execution = self._save_and_cache_node_execution(domain_execution) - self._workflow_node_execution_repository.save_execution_data(execution) - return execution - - def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution: - # Check cache first - if id in self._workflow_execution_cache: - return self._workflow_execution_cache[id] - - raise WorkflowRunNotFoundError(id) - - def _prepare_workflow_inputs(self) -> dict[str, Any]: - """Prepare workflow inputs by merging application inputs with system variables.""" - inputs = {**self._application_generate_entity.inputs} - - if self._workflow_system_variables: - for field_name, value in self._workflow_system_variables.to_dict().items(): - if field_name != SystemVariableKey.CONVERSATION_ID: - inputs[f"sys.{field_name}"] = value - - return dict(WorkflowEntry.handle_special_values(inputs) or {}) - - def _get_or_generate_execution_id(self) -> str: - """Get execution ID from system variables or generate a new one.""" - if self._workflow_system_variables and self._workflow_system_variables.workflow_execution_id: - return str(self._workflow_system_variables.workflow_execution_id) - return str(uuidv7()) - - def _save_and_cache_workflow_execution(self, execution: WorkflowExecution) -> WorkflowExecution: - """Save workflow execution to repository and cache it.""" - self._workflow_execution_repository.save(execution) - self._workflow_execution_cache[execution.id_] = execution - return execution - - def _save_and_cache_node_execution(self, execution: WorkflowNodeExecution) -> WorkflowNodeExecution: - """Save node 
execution to repository and cache it if it has an ID. - - This does not persist the `inputs` / `process_data` / `outputs` fields of the execution model. - """ - self._workflow_node_execution_repository.save(execution) - if execution.node_execution_id: - self._node_execution_cache[execution.node_execution_id] = execution - return execution - - def _get_node_execution_from_cache(self, node_execution_id: str) -> WorkflowNodeExecution: - """Get node execution from cache or raise error if not found.""" - domain_execution = self._node_execution_cache.get(node_execution_id) - if not domain_execution: - raise ValueError(f"Domain node execution not found: {node_execution_id}") - return domain_execution - - def _update_workflow_execution_completion( - self, - execution: WorkflowExecution, - *, - status: WorkflowExecutionStatus, - total_tokens: int, - total_steps: int, - outputs: Mapping[str, Any] | None = None, - error_message: str | None = None, - exceptions_count: int = 0, - finished_at: datetime | None = None, - ): - """Update workflow execution with completion data.""" - execution.status = status - execution.outputs = outputs or {} - execution.total_tokens = total_tokens - execution.total_steps = total_steps - execution.finished_at = finished_at or naive_utc_now() - execution.exceptions_count = exceptions_count - if error_message: - execution.error_message = error_message - - def _add_trace_task_if_needed( - self, - trace_manager: TraceQueueManager | None, - workflow_execution: WorkflowExecution, - conversation_id: str | None, - external_trace_id: str | None, - ): - """Add trace task if trace manager is provided.""" - if trace_manager: - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.WORKFLOW_TRACE, - workflow_execution=workflow_execution, - conversation_id=conversation_id, - user_id=trace_manager.user_id, - external_trace_id=external_trace_id, - ) - ) - - def _fail_running_node_executions( - self, - workflow_execution_id: str, - error_message: str, - now: datetime, - ): - """Fail all running node executions for a workflow.""" - running_node_executions = [ - node_exec - for node_exec in self._node_execution_cache.values() - if node_exec.workflow_execution_id == workflow_execution_id - and node_exec.status == WorkflowNodeExecutionStatus.RUNNING - ] - - for node_execution in running_node_executions: - if node_execution.node_execution_id: - node_execution.status = WorkflowNodeExecutionStatus.FAILED - node_execution.error = error_message - node_execution.finished_at = now - node_execution.elapsed_time = (now - node_execution.created_at).total_seconds() - self._workflow_node_execution_repository.save(node_execution) - - def _create_node_execution_from_event( - self, - *, - workflow_execution: WorkflowExecution, - event: QueueNodeStartedEvent, - status: WorkflowNodeExecutionStatus, - error: str | None = None, - created_at: datetime | None = None, - ) -> WorkflowNodeExecution: - """Create a node execution from an event.""" - now = naive_utc_now() - created_at = created_at or now - - metadata = { - WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, - WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id, - WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id, - } - - domain_execution = WorkflowNodeExecution( - id=event.node_execution_id, - workflow_id=workflow_execution.workflow_id, - workflow_execution_id=workflow_execution.id_, - predecessor_node_id=event.predecessor_node_id, - index=event.node_run_index, - 
node_execution_id=event.node_execution_id, - node_id=event.node_id, - node_type=event.node_type, - title=event.node_title, - status=status, - metadata=metadata, - created_at=created_at, - error=error, - ) - - if status == WorkflowNodeExecutionStatus.RETRY: - domain_execution.finished_at = now - domain_execution.elapsed_time = (now - created_at).total_seconds() - - return domain_execution - - def _update_node_execution_completion( - self, - domain_execution: WorkflowNodeExecution, - *, - event: Union[ - QueueNodeSucceededEvent, - QueueNodeFailedEvent, - QueueNodeExceptionEvent, - ], - status: WorkflowNodeExecutionStatus, - error: str | None = None, - handle_special_values: bool = False, - ): - """Update node execution with completion data.""" - finished_at = naive_utc_now() - elapsed_time = (finished_at - event.start_at).total_seconds() - - # Process data - if handle_special_values: - inputs = WorkflowEntry.handle_special_values(event.inputs) - process_data = WorkflowEntry.handle_special_values(event.process_data) - else: - inputs = event.inputs - process_data = event.process_data - - outputs = event.outputs - - # Convert metadata - execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, Any] = {} - if event.execution_metadata: - execution_metadata_dict.update(event.execution_metadata) - - # Update domain model - domain_execution.status = status - domain_execution.update_from_mapping( - inputs=inputs, - process_data=process_data, - outputs=outputs, - metadata=execution_metadata_dict, - ) - domain_execution.finished_at = finished_at - domain_execution.elapsed_time = elapsed_time - - if error: - domain_execution.error = error - - def _merge_event_metadata(self, event: QueueNodeRetryEvent) -> dict[WorkflowNodeExecutionMetadataKey, str | None]: - """Merge event metadata with origin metadata.""" - origin_metadata = { - WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id, - WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, - WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id, - } - - execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, str | None] = {} - if event.execution_metadata: - execution_metadata_dict.update(event.execution_metadata) - - return {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 4cd885cfa5..742c42ec2b 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -9,7 +9,7 @@ from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.file.models import File from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID -from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool +from core.workflow.entities import GraphInitParams from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph import Graph from core.workflow.graph_engine import GraphEngine @@ -20,6 +20,7 @@ from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, Gra from core.workflow.nodes import NodeType from core.workflow.nodes.base.node import Node from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, 
load_into_variable_pool from factories import file_factory diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index d2ba462a37..f6dddd75a3 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -37,7 +37,6 @@ from core.rag.entities.event import ( from core.repositories.factory import DifyCoreRepositoryFactory from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.variables.variables import Variable -from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, @@ -50,6 +49,7 @@ from core.workflow.node_events.base import NodeRunResult from core.workflow.nodes.base.node import Node from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.repositories.workflow_node_execution_repository import OrderConfig +from core.workflow.runtime import VariablePool from core.workflow.system_variable import SystemVariable from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index f765c229ab..2f69e46074 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -14,7 +14,7 @@ from core.file import File from core.repositories import DifyCoreRepositoryFactory from core.variables import Variable from core.variables.variables import VariableUnion -from core.workflow.entities import VariablePool, WorkflowNodeExecution +from core.workflow.entities import WorkflowNodeExecution from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent @@ -23,6 +23,7 @@ from core.workflow.nodes import NodeType from core.workflow.nodes.base.node import Node from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.runtime import VariablePool from core.workflow.system_variable import SystemVariable from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index b62d8aa544..78878cdeef 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -5,12 +5,13 @@ import pytest from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool +from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.graph import Graph from core.workflow.node_events import NodeRunResult from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.node_factory import DifyNodeFactory +from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from models.enums import UserFrom from 
tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index ea99beacaa..2367990d3e 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -5,10 +5,11 @@ from urllib.parse import urlencode import pytest from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool +from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph from core.workflow.nodes.http_request.node import HttpRequestNode from core.workflow.nodes.node_factory import DifyNodeFactory +from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from models.enums import UserFrom from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock @@ -174,13 +175,13 @@ def test_custom_authorization_header(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock): """Test: In custom authentication mode, when the api_key is empty, no header should be set.""" - from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.http_request.entities import ( HttpRequestNodeAuthorization, HttpRequestNodeData, HttpRequestNodeTimeout, ) from core.workflow.nodes.http_request.executor import Executor + from core.workflow.runtime import VariablePool from core.workflow.system_variable import SystemVariable # Create variable pool diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 31281cd8ad..3b16c3920b 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -6,12 +6,13 @@ from unittest.mock import MagicMock, patch from core.app.entities.app_invoke_entities import InvokeFrom from core.llm_generator.output_parser.structured_output import _parse_structured_output -from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool +from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.graph import Graph from core.workflow.node_events import StreamCompletedEvent from core.workflow.nodes.llm.node import LLMNode from core.workflow.nodes.node_factory import DifyNodeFactory +from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from extensions.ext_database import db from models.enums import UserFrom diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 76918f689f..9d9102cee2 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -5,11 +5,12 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom from core.model_runtime.entities import AssistantPromptMessage -from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool +from core.workflow.entities import GraphInitParams from core.workflow.enums import 
WorkflowNodeExecutionStatus from core.workflow.graph import Graph from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from extensions.ext_database import db from models.enums import UserFrom diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 53252c7f2e..285387b817 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -4,11 +4,12 @@ import uuid import pytest from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool +from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.graph import Graph from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode +from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from models.enums import UserFrom from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 16d44d1eaf..8dd8150b1c 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -4,12 +4,13 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.utils.configuration import ToolParameterConfigurationManager -from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool +from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.graph import Graph from core.workflow.node_events import StreamCompletedEvent from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.nodes.tool.tool_node import ToolNode +from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from models.enums import UserFrom diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py index bb1d5e2f67..3a4fdc3cd8 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -99,6 +99,8 @@ class TestAdvancedChatAppRunnerConversationVariables: workflow=mock_workflow, system_user_id=str(uuid4()), app=MagicMock(), + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), ) # Mock database session @@ -237,6 +239,8 @@ class TestAdvancedChatAppRunnerConversationVariables: workflow=mock_workflow, system_user_id=str(uuid4()), app=MagicMock(), + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), ) # Mock database session @@ -390,6 +394,8 @@ class 
TestAdvancedChatAppRunnerConversationVariables: workflow=mock_workflow, system_user_id=str(uuid4()), app=MagicMock(), + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), ) # Mock database session diff --git a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py new file mode 100644 index 0000000000..cd5ea8986a --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py @@ -0,0 +1,63 @@ +from types import SimpleNamespace + +import pytest + +from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport +from core.workflow.runtime import GraphRuntimeState +from core.workflow.runtime.variable_pool import VariablePool +from core.workflow.system_variable import SystemVariable + + +def _make_state(workflow_run_id: str | None) -> GraphRuntimeState: + variable_pool = VariablePool(system_variables=SystemVariable(workflow_execution_id=workflow_run_id)) + return GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) + + +class _StubPipeline(GraphRuntimeStateSupport): + def __init__(self, *, cached_state: GraphRuntimeState | None, queue_state: GraphRuntimeState | None): + self._graph_runtime_state = cached_state + self._base_task_pipeline = SimpleNamespace(queue_manager=SimpleNamespace(graph_runtime_state=queue_state)) + + +def test_ensure_graph_runtime_initialized_caches_explicit_state(): + explicit_state = _make_state("run-explicit") + pipeline = _StubPipeline(cached_state=None, queue_state=None) + + resolved = pipeline._ensure_graph_runtime_initialized(explicit_state) + + assert resolved is explicit_state + assert pipeline._graph_runtime_state is explicit_state + + +def test_resolve_graph_runtime_state_reads_from_queue_when_cache_empty(): + queued_state = _make_state("run-queue") + pipeline = _StubPipeline(cached_state=None, queue_state=queued_state) + + resolved = pipeline._resolve_graph_runtime_state() + + assert resolved is queued_state + assert pipeline._graph_runtime_state is queued_state + + +def test_resolve_graph_runtime_state_raises_when_no_state_available(): + pipeline = _StubPipeline(cached_state=None, queue_state=None) + + with pytest.raises(ValueError): + pipeline._resolve_graph_runtime_state() + + +def test_extract_workflow_run_id_returns_value(): + state = _make_state("run-identifier") + pipeline = _StubPipeline(cached_state=state, queue_state=None) + + run_id = pipeline._extract_workflow_run_id(state) + + assert run_id == "run-identifier" + + +def test_extract_workflow_run_id_raises_when_missing(): + state = _make_state(None) + pipeline = _StubPipeline(cached_state=state, queue_state=None) + + with pytest.raises(ValueError): + pipeline._extract_workflow_run_id(state) diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_process_data.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_process_data.py index 3366666a47..abe09fb8a4 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_process_data.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_process_data.py @@ -3,8 +3,7 @@ Unit tests for WorkflowResponseConverter focusing on process_data truncation fun """ import uuid -from dataclasses import dataclass -from datetime import datetime +from collections.abc import Mapping from typing import Any from unittest.mock import Mock @@ -12,24 +11,17 @@ 
import pytest from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity -from core.app.entities.queue_entities import QueueNodeRetryEvent, QueueNodeSucceededEvent -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus +from core.app.entities.queue_entities import ( + QueueNodeRetryEvent, + QueueNodeStartedEvent, + QueueNodeSucceededEvent, +) from core.workflow.enums import NodeType +from core.workflow.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now from models import Account -@dataclass -class ProcessDataResponseScenario: - """Test scenario for process_data in responses.""" - - name: str - original_process_data: dict[str, Any] | None - truncated_process_data: dict[str, Any] | None - expected_response_data: dict[str, Any] | None - expected_truncated_flag: bool - - class TestWorkflowResponseConverterCenarios: """Test process_data truncation in WorkflowResponseConverter.""" @@ -39,6 +31,7 @@ class TestWorkflowResponseConverterCenarios: mock_app_config = Mock() mock_app_config.tenant_id = "test-tenant-id" mock_entity.app_config = mock_app_config + mock_entity.inputs = {} return mock_entity def create_workflow_response_converter(self) -> WorkflowResponseConverter: @@ -50,54 +43,59 @@ class TestWorkflowResponseConverterCenarios: mock_user.name = "Test User" mock_user.email = "test@example.com" - return WorkflowResponseConverter(application_generate_entity=mock_entity, user=mock_user) - - def create_workflow_node_execution( - self, - process_data: dict[str, Any] | None = None, - truncated_process_data: dict[str, Any] | None = None, - execution_id: str = "test-execution-id", - ) -> WorkflowNodeExecution: - """Create a WorkflowNodeExecution for testing.""" - execution = WorkflowNodeExecution( - id=execution_id, - workflow_id="test-workflow-id", - workflow_execution_id="test-run-id", - index=1, - node_id="test-node-id", - node_type=NodeType.LLM, - title="Test Node", - process_data=process_data, - status=WorkflowNodeExecutionStatus.SUCCEEDED, - created_at=datetime.now(), - finished_at=datetime.now(), + system_variables = SystemVariable(workflow_id="wf-id", workflow_execution_id="initial-run-id") + return WorkflowResponseConverter( + application_generate_entity=mock_entity, + user=mock_user, + system_variables=system_variables, ) - if truncated_process_data is not None: - execution.set_truncated_process_data(truncated_process_data) + def create_node_started_event(self, *, node_execution_id: str | None = None) -> QueueNodeStartedEvent: + """Create a QueueNodeStartedEvent for testing.""" + return QueueNodeStartedEvent( + node_execution_id=node_execution_id or str(uuid.uuid4()), + node_id="test-node-id", + node_title="Test Node", + node_type=NodeType.CODE, + start_at=naive_utc_now(), + predecessor_node_id=None, + in_iteration_id=None, + in_loop_id=None, + provider_type="built-in", + provider_id="code", + ) - return execution - - def create_node_succeeded_event(self) -> QueueNodeSucceededEvent: + def create_node_succeeded_event( + self, + *, + node_execution_id: str, + process_data: Mapping[str, Any] | None = None, + ) -> QueueNodeSucceededEvent: """Create a QueueNodeSucceededEvent for testing.""" return QueueNodeSucceededEvent( node_id="test-node-id", node_type=NodeType.CODE, - node_execution_id=str(uuid.uuid4()), + node_execution_id=node_execution_id, start_at=naive_utc_now(), - parallel_id=None, - 
parallel_start_node_id=None, - parent_parallel_id=None, - parent_parallel_start_node_id=None, in_iteration_id=None, in_loop_id=None, + inputs={}, + process_data=process_data or {}, + outputs={}, + execution_metadata={}, ) - def create_node_retry_event(self) -> QueueNodeRetryEvent: + def create_node_retry_event( + self, + *, + node_execution_id: str, + process_data: Mapping[str, Any] | None = None, + ) -> QueueNodeRetryEvent: """Create a QueueNodeRetryEvent for testing.""" return QueueNodeRetryEvent( inputs={"data": "inputs"}, outputs={"data": "outputs"}, + process_data=process_data or {}, error="oops", retry_index=1, node_id="test-node-id", @@ -105,12 +103,8 @@ class TestWorkflowResponseConverterCenarios: node_title="test code", provider_type="built-in", provider_id="code", - node_execution_id=str(uuid.uuid4()), + node_execution_id=node_execution_id, start_at=naive_utc_now(), - parallel_id=None, - parallel_start_node_id=None, - parent_parallel_id=None, - parent_parallel_start_node_id=None, in_iteration_id=None, in_loop_id=None, ) @@ -122,15 +116,28 @@ class TestWorkflowResponseConverterCenarios: original_data = {"large_field": "x" * 10000, "metadata": "info"} truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"} - execution = self.create_workflow_node_execution( - process_data=original_data, truncated_process_data=truncated_data + converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id") + start_event = self.create_node_started_event() + converter.workflow_node_start_to_stream_response( + event=start_event, + task_id="test-task-id", ) - event = self.create_node_succeeded_event() + + event = self.create_node_succeeded_event( + node_execution_id=start_event.node_execution_id, + process_data=original_data, + ) + + def fake_truncate(mapping): + if mapping == dict(original_data): + return truncated_data, True + return mapping, False + + converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment] response = converter.workflow_node_finish_to_stream_response( event=event, task_id="test-task-id", - workflow_node_execution=execution, ) # Response should use truncated data, not original @@ -145,13 +152,26 @@ class TestWorkflowResponseConverterCenarios: original_data = {"small": "data"} - execution = self.create_workflow_node_execution(process_data=original_data) - event = self.create_node_succeeded_event() + converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id") + start_event = self.create_node_started_event() + converter.workflow_node_start_to_stream_response( + event=start_event, + task_id="test-task-id", + ) + + event = self.create_node_succeeded_event( + node_execution_id=start_event.node_execution_id, + process_data=original_data, + ) + + def fake_truncate(mapping): + return mapping, False + + converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment] response = converter.workflow_node_finish_to_stream_response( event=event, task_id="test-task-id", - workflow_node_execution=execution, ) # Response should use original data @@ -163,18 +183,31 @@ class TestWorkflowResponseConverterCenarios: """Test node finish response when process_data is None.""" converter = self.create_workflow_response_converter() - execution = self.create_workflow_node_execution(process_data=None) - event = self.create_node_succeeded_event() + converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", 
workflow_id="wf-id") + start_event = self.create_node_started_event() + converter.workflow_node_start_to_stream_response( + event=start_event, + task_id="test-task-id", + ) + + event = self.create_node_succeeded_event( + node_execution_id=start_event.node_execution_id, + process_data=None, + ) + + def fake_truncate(mapping): + return mapping, False + + converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment] response = converter.workflow_node_finish_to_stream_response( event=event, task_id="test-task-id", - workflow_node_execution=execution, ) - # Response should have None process_data + # Response should normalize missing process_data to an empty mapping assert response is not None - assert response.data.process_data is None + assert response.data.process_data == {} assert response.data.process_data_truncated is False def test_workflow_node_retry_response_uses_truncated_process_data(self): @@ -184,15 +217,28 @@ class TestWorkflowResponseConverterCenarios: original_data = {"large_field": "x" * 10000, "metadata": "info"} truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"} - execution = self.create_workflow_node_execution( - process_data=original_data, truncated_process_data=truncated_data + converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id") + start_event = self.create_node_started_event() + converter.workflow_node_start_to_stream_response( + event=start_event, + task_id="test-task-id", ) - event = self.create_node_retry_event() + + event = self.create_node_retry_event( + node_execution_id=start_event.node_execution_id, + process_data=original_data, + ) + + def fake_truncate(mapping): + if mapping == dict(original_data): + return truncated_data, True + return mapping, False + + converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment] response = converter.workflow_node_retry_to_stream_response( event=event, task_id="test-task-id", - workflow_node_execution=execution, ) # Response should use truncated data, not original @@ -207,224 +253,72 @@ class TestWorkflowResponseConverterCenarios: original_data = {"small": "data"} - execution = self.create_workflow_node_execution(process_data=original_data) - event = self.create_node_retry_event() + converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id") + start_event = self.create_node_started_event() + converter.workflow_node_start_to_stream_response( + event=start_event, + task_id="test-task-id", + ) + + event = self.create_node_retry_event( + node_execution_id=start_event.node_execution_id, + process_data=original_data, + ) + + def fake_truncate(mapping): + return mapping, False + + converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment] response = converter.workflow_node_retry_to_stream_response( event=event, task_id="test-task-id", - workflow_node_execution=execution, ) - # Response should use original data assert response is not None assert response.data.process_data == original_data assert response.data.process_data_truncated is False def test_iteration_and_loop_nodes_return_none(self): - """Test that iteration and loop nodes return None (no change from existing behavior).""" + """Test that iteration and loop nodes return None (no streaming events).""" converter = self.create_workflow_response_converter() - # Test iteration node - iteration_execution = self.create_workflow_node_execution(process_data={"test": "data"}) 
- iteration_execution.node_type = NodeType.ITERATION - - event = self.create_node_succeeded_event() - - response = converter.workflow_node_finish_to_stream_response( - event=event, - task_id="test-task-id", - workflow_node_execution=iteration_execution, - ) - - # Should return None for iteration nodes - assert response is None - - # Test loop node - loop_execution = self.create_workflow_node_execution(process_data={"test": "data"}) - loop_execution.node_type = NodeType.LOOP - - response = converter.workflow_node_finish_to_stream_response( - event=event, - task_id="test-task-id", - workflow_node_execution=loop_execution, - ) - - # Should return None for loop nodes - assert response is None - - def test_execution_without_workflow_execution_id_returns_none(self): - """Test that executions without workflow_execution_id return None.""" - converter = self.create_workflow_response_converter() - - execution = self.create_workflow_node_execution(process_data={"test": "data"}) - execution.workflow_execution_id = None # Single-step debugging - - event = self.create_node_succeeded_event() - - response = converter.workflow_node_finish_to_stream_response( - event=event, - task_id="test-task-id", - workflow_node_execution=execution, - ) - - # Should return None for single-step debugging - assert response is None - - @staticmethod - def get_process_data_response_scenarios() -> list[ProcessDataResponseScenario]: - """Create test scenarios for process_data responses.""" - return [ - ProcessDataResponseScenario( - name="none_process_data", - original_process_data=None, - truncated_process_data=None, - expected_response_data=None, - expected_truncated_flag=False, - ), - ProcessDataResponseScenario( - name="small_process_data_no_truncation", - original_process_data={"small": "data"}, - truncated_process_data=None, - expected_response_data={"small": "data"}, - expected_truncated_flag=False, - ), - ProcessDataResponseScenario( - name="large_process_data_with_truncation", - original_process_data={"large": "x" * 10000, "metadata": "info"}, - truncated_process_data={"large": "[TRUNCATED]", "metadata": "info"}, - expected_response_data={"large": "[TRUNCATED]", "metadata": "info"}, - expected_truncated_flag=True, - ), - ProcessDataResponseScenario( - name="empty_process_data", - original_process_data={}, - truncated_process_data=None, - expected_response_data={}, - expected_truncated_flag=False, - ), - ProcessDataResponseScenario( - name="complex_data_with_truncation", - original_process_data={ - "logs": ["entry"] * 1000, # Large array - "config": {"setting": "value"}, - "status": "processing", - }, - truncated_process_data={ - "logs": "[TRUNCATED: 1000 items]", - "config": {"setting": "value"}, - "status": "processing", - }, - expected_response_data={ - "logs": "[TRUNCATED: 1000 items]", - "config": {"setting": "value"}, - "status": "processing", - }, - expected_truncated_flag=True, - ), - ] - - @pytest.mark.parametrize( - "scenario", - get_process_data_response_scenarios(), - ids=[scenario.name for scenario in get_process_data_response_scenarios()], - ) - def test_node_finish_response_scenarios(self, scenario: ProcessDataResponseScenario): - """Test various scenarios for node finish responses.""" - - mock_user = Mock(spec=Account) - mock_user.id = "test-user-id" - mock_user.name = "Test User" - mock_user.email = "test@example.com" - - converter = WorkflowResponseConverter( - application_generate_entity=Mock(spec=WorkflowAppGenerateEntity, app_config=Mock(tenant_id="test-tenant")), - user=mock_user, - ) - - 
execution = WorkflowNodeExecution( - id="test-execution-id", - workflow_id="test-workflow-id", - workflow_execution_id="test-run-id", - index=1, - node_id="test-node-id", - node_type=NodeType.LLM, - title="Test Node", - process_data=scenario.original_process_data, - status=WorkflowNodeExecutionStatus.SUCCEEDED, - created_at=datetime.now(), - finished_at=datetime.now(), - ) - - if scenario.truncated_process_data is not None: - execution.set_truncated_process_data(scenario.truncated_process_data) - - event = QueueNodeSucceededEvent( - node_id="test-node-id", - node_type=NodeType.CODE, + iteration_event = QueueNodeSucceededEvent( + node_id="iteration-node", + node_type=NodeType.ITERATION, node_execution_id=str(uuid.uuid4()), start_at=naive_utc_now(), - parallel_id=None, - parallel_start_node_id=None, - parent_parallel_id=None, - parent_parallel_start_node_id=None, in_iteration_id=None, in_loop_id=None, + inputs={}, + process_data={}, + outputs={}, + execution_metadata={}, ) response = converter.workflow_node_finish_to_stream_response( - event=event, + event=iteration_event, task_id="test-task-id", - workflow_node_execution=execution, ) + assert response is None - assert response is not None - assert response.data.process_data == scenario.expected_response_data - assert response.data.process_data_truncated == scenario.expected_truncated_flag - - @pytest.mark.parametrize( - "scenario", - get_process_data_response_scenarios(), - ids=[scenario.name for scenario in get_process_data_response_scenarios()], - ) - def test_node_retry_response_scenarios(self, scenario: ProcessDataResponseScenario): - """Test various scenarios for node retry responses.""" - - mock_user = Mock(spec=Account) - mock_user.id = "test-user-id" - mock_user.name = "Test User" - mock_user.email = "test@example.com" - - converter = WorkflowResponseConverter( - application_generate_entity=Mock(spec=WorkflowAppGenerateEntity, app_config=Mock(tenant_id="test-tenant")), - user=mock_user, - ) - - execution = WorkflowNodeExecution( - id="test-execution-id", - workflow_id="test-workflow-id", - workflow_execution_id="test-run-id", - index=1, - node_id="test-node-id", - node_type=NodeType.LLM, - title="Test Node", - process_data=scenario.original_process_data, - status=WorkflowNodeExecutionStatus.FAILED, # Retry scenario - created_at=datetime.now(), - finished_at=datetime.now(), - ) - - if scenario.truncated_process_data is not None: - execution.set_truncated_process_data(scenario.truncated_process_data) - - event = self.create_node_retry_event() - - response = converter.workflow_node_retry_to_stream_response( - event=event, + loop_event = iteration_event.model_copy(update={"node_type": NodeType.LOOP}) + response = converter.workflow_node_finish_to_stream_response( + event=loop_event, task_id="test-task-id", - workflow_node_execution=execution, + ) + assert response is None + + def test_finish_without_start_raises(self): + """Ensure finish responses require a prior workflow start.""" + converter = self.create_workflow_response_converter() + event = self.create_node_succeeded_event( + node_execution_id=str(uuid.uuid4()), + process_data={}, ) - assert response is not None - assert response.data.process_data == scenario.expected_response_data - assert response.data.process_data_truncated == scenario.expected_truncated_flag + with pytest.raises(ValueError): + converter.workflow_node_finish_to_stream_response( + event=event, + task_id="test-task-id", + ) diff --git a/api/tests/unit_tests/core/variables/test_segment.py 
b/api/tests/unit_tests/core/variables/test_segment.py index 5cd595088a..af4f96ba23 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -37,7 +37,7 @@ from core.variables.variables import ( Variable, VariableUnion, ) -from core.workflow.entities import VariablePool +from core.workflow.runtime import VariablePool from core.workflow.system_variable import SystemVariable diff --git a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py index 2614424dc7..5ecaeb60ac 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py +++ b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py @@ -1,9 +1,11 @@ +import json from time import time +from unittest.mock import MagicMock, patch import pytest -from core.workflow.entities.graph_runtime_state import GraphRuntimeState -from core.workflow.entities.variable_pool import VariablePool +from core.model_runtime.entities.llm_entities import LLMUsage +from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool class TestGraphRuntimeState: @@ -95,3 +97,141 @@ class TestGraphRuntimeState: # Test add_tokens validation with pytest.raises(ValueError): state.add_tokens(-1) + + def test_ready_queue_default_instantiation(self): + state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) + + queue = state.ready_queue + + from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue + + assert isinstance(queue, InMemoryReadyQueue) + assert state.ready_queue is queue + + def test_graph_execution_lazy_instantiation(self): + state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) + + execution = state.graph_execution + + from core.workflow.graph_engine.domain.graph_execution import GraphExecution + + assert isinstance(execution, GraphExecution) + assert execution.workflow_id == "" + assert state.graph_execution is execution + + def test_response_coordinator_configuration(self): + variable_pool = VariablePool() + state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) + + with pytest.raises(ValueError): + _ = state.response_coordinator + + mock_graph = MagicMock() + with patch("core.workflow.graph_engine.response_coordinator.ResponseStreamCoordinator") as coordinator_cls: + coordinator_instance = MagicMock() + coordinator_cls.return_value = coordinator_instance + + state.configure(graph=mock_graph) + + assert state.response_coordinator is coordinator_instance + coordinator_cls.assert_called_once_with(variable_pool=variable_pool, graph=mock_graph) + + # Configure again with same graph should be idempotent + state.configure(graph=mock_graph) + + other_graph = MagicMock() + with pytest.raises(ValueError): + state.attach_graph(other_graph) + + def test_read_only_wrapper_exposes_additional_state(self): + state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) + state.configure() + + wrapper = ReadOnlyGraphRuntimeStateWrapper(state) + + assert wrapper.ready_queue_size == 0 + assert wrapper.exceptions_count == 0 + + def test_read_only_wrapper_serializes_runtime_state(self): + state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) + state.total_tokens = 5 + state.set_output("result", {"success": True}) + state.ready_queue.put("node-1") + + wrapper = ReadOnlyGraphRuntimeStateWrapper(state) + + wrapper_snapshot = json.loads(wrapper.dumps()) + 
state_snapshot = json.loads(state.dumps()) + + assert wrapper_snapshot == state_snapshot + + def test_dumps_and_loads_roundtrip_with_response_coordinator(self): + variable_pool = VariablePool() + variable_pool.add(("node1", "value"), "payload") + + state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) + state.total_tokens = 10 + state.node_run_steps = 3 + state.set_output("final", {"result": True}) + usage = LLMUsage.from_metadata( + { + "prompt_tokens": 2, + "completion_tokens": 3, + "total_tokens": 5, + "total_price": "1.23", + "currency": "USD", + "latency": 0.5, + } + ) + state.llm_usage = usage + state.ready_queue.put("node-A") + + graph_execution = state.graph_execution + graph_execution.workflow_id = "wf-123" + graph_execution.exceptions_count = 4 + graph_execution.started = True + + class StubCoordinator: + def __init__(self) -> None: + self.state = "initial" + + def dumps(self) -> str: + return json.dumps({"state": self.state}) + + def loads(self, data: str) -> None: + payload = json.loads(data) + self.state = payload["state"] + + mock_graph = MagicMock() + stub = StubCoordinator() + with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=stub): + state.attach_graph(mock_graph) + + stub.state = "configured" + + snapshot = state.dumps() + + restored = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0) + restored.loads(snapshot) + + assert restored.total_tokens == 10 + assert restored.node_run_steps == 3 + assert restored.get_output("final") == {"result": True} + assert restored.llm_usage.total_tokens == usage.total_tokens + assert restored.ready_queue.qsize() == 1 + assert restored.ready_queue.get(timeout=0.01) == "node-A" + + restored_segment = restored.variable_pool.get(("node1", "value")) + assert restored_segment is not None + assert restored_segment.value == "payload" + + restored_execution = restored.graph_execution + assert restored_execution.workflow_id == "wf-123" + assert restored_execution.exceptions_count == 4 + assert restored_execution.started is True + + new_stub = StubCoordinator() + with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub): + restored.attach_graph(mock_graph) + + assert new_stub.state == "configured" diff --git a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py index 68fe82d05e..f9de456b19 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py @@ -4,7 +4,7 @@ from core.variables.segments import ( NoneSegment, StringSegment, ) -from core.workflow.entities.variable_pool import VariablePool +from core.workflow.runtime import VariablePool class TestVariablePoolGetAndNestedAttribute: diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py b/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py new file mode 100644 index 0000000000..15d1dcb48d --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py @@ -0,0 +1,59 @@ +from unittest.mock import MagicMock + +import pytest + +from core.workflow.enums import NodeType +from core.workflow.graph import Graph +from core.workflow.nodes.base.node import Node + + +def _make_node(node_id: str, node_type: NodeType = NodeType.START) -> Node: + node = MagicMock(spec=Node) + node.id = node_id + node.node_type = node_type + node.execution_type = None # attribute not used in builder path + return node + + 
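+# Illustrative sketch (an assumption-flagged helper, not used by the tests
+# below): the fluent builder exercised in this module chains
+# add_root()/add_node()/build(), auto-wiring an edge from the previously
+# added node unless from_node_id names a different predecessor.
+def _example_linear_build() -> Graph:
+    start = _make_node("start", NodeType.START)
+    llm = _make_node("llm", NodeType.LLM)
+    end = _make_node("end", NodeType.END)
+    # start -> llm -> end, each add_node linking from the prior node
+    return Graph.new().add_root(start).add_node(llm).add_node(end).build()
+
+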
+def test_graph_builder_creates_linear_graph(): + builder = Graph.new() + root = _make_node("root", NodeType.START) + mid = _make_node("mid", NodeType.LLM) + end = _make_node("end", NodeType.END) + + graph = builder.add_root(root).add_node(mid).add_node(end).build() + + assert graph.root_node is root + assert graph.nodes == {"root": root, "mid": mid, "end": end} + assert len(graph.edges) == 2 + first_edge = next(iter(graph.edges.values())) + assert first_edge.tail == "root" + assert first_edge.head == "mid" + assert graph.out_edges["mid"] == [edge_id for edge_id, edge in graph.edges.items() if edge.tail == "mid"] + + +def test_graph_builder_supports_custom_predecessor(): + builder = Graph.new() + root = _make_node("root") + branch = _make_node("branch") + other = _make_node("other") + + graph = builder.add_root(root).add_node(branch).add_node(other, from_node_id="root").build() + + outgoing_root = graph.out_edges["root"] + assert len(outgoing_root) == 2 + edge_targets = {graph.edges[eid].head for eid in outgoing_root} + assert edge_targets == {"branch", "other"} + + +def test_graph_builder_validates_usage(): + builder = Graph.new() + node = _make_node("node") + + with pytest.raises(ValueError, match="Root node"): + builder.add_node(node) + + builder.add_root(node) + duplicate = _make_node("node") + with pytest.raises(ValueError, match="Duplicate"): + builder.add_node(duplicate) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/README.md b/api/tests/unit_tests/core/workflow/graph_engine/README.md index bff82b3ac4..3fff4cf6a9 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/README.md +++ b/api/tests/unit_tests/core/workflow/graph_engine/README.md @@ -20,9 +20,6 @@ The TableTestRunner (`test_table_runner.py`) provides a robust table-driven test - **Mock configuration** - Seamless integration with the auto-mock system - **Performance metrics** - Track execution times and bottlenecks - **Detailed error reporting** - Comprehensive failure diagnostics -- **Test tagging** - Organize and filter tests by tags -- **Retry mechanism** - Handle flaky tests gracefully -- **Custom validators** - Define custom validation logic ### Basic Usage @@ -68,49 +65,6 @@ suite_result = runner.run_table_tests( print(f"Success rate: {suite_result.success_rate:.1f}%") ``` -#### Test Tagging and Filtering - -```python -test_case = WorkflowTestCase( - fixture_path="workflow", - inputs={}, - expected_outputs={}, - tags=["smoke", "critical"], -) - -# Run only tests with specific tags -suite_result = runner.run_table_tests( - test_cases, - tags_filter=["smoke"] -) -``` - -#### Retry Mechanism - -```python -test_case = WorkflowTestCase( - fixture_path="flaky_workflow", - inputs={}, - expected_outputs={}, - retry_count=2, # Retry up to 2 times on failure -) -``` - -#### Custom Validators - -```python -def custom_validator(outputs: dict) -> bool: - # Custom validation logic - return "error" not in outputs.get("status", "") - -test_case = WorkflowTestCase( - fixture_path="workflow", - inputs={}, - expected_outputs={"status": "success"}, - custom_validator=custom_validator, -) -``` - #### Event Sequence Validation ```python diff --git a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py index d556bb138e..2b8f04979d 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py +++ 
b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py @@ -4,7 +4,6 @@ from __future__ import annotations from datetime import datetime -from core.workflow.entities import GraphRuntimeState, VariablePool from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus from core.workflow.graph import Graph from core.workflow.graph_engine.domain.graph_execution import GraphExecution @@ -16,6 +15,7 @@ from core.workflow.graph_engine.response_coordinator.coordinator import Response from core.workflow.graph_events import NodeRunRetryEvent, NodeRunStartedEvent from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.entities import RetryConfig +from core.workflow.runtime import GraphRuntimeState, VariablePool class _StubEdgeProcessor: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py index 9fec855a93..d451e7e608 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py @@ -3,12 +3,12 @@ import time from unittest.mock import MagicMock -from core.workflow.entities import GraphRuntimeState, VariablePool from core.workflow.graph import Graph from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_engine.entities.commands import AbortCommand -from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunStartedEvent +from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand +from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent +from core.workflow.runtime import GraphRuntimeState, VariablePool def test_abort_command(): @@ -100,8 +100,57 @@ def test_redis_channel_serialization(): assert command_data["command_type"] == "abort" assert command_data["reason"] == "Test abort" + # Test pause command serialization + pause_command = PauseCommand(reason="User requested pause") + channel.send_command(pause_command) -if __name__ == "__main__": - test_abort_command() - test_redis_channel_serialization() - print("All tests passed!") + assert len(mock_pipeline.rpush.call_args_list) == 2 + second_call_args = mock_pipeline.rpush.call_args_list[1] + pause_command_json = second_call_args[0][1] + pause_command_data = json.loads(pause_command_json) + assert pause_command_data["command_type"] == CommandType.PAUSE.value + assert pause_command_data["reason"] == "User requested pause" + + +def test_pause_command(): + """Test that GraphEngine properly handles pause commands.""" + + shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + mock_graph.root_node.id = "start" + + mock_start_node = MagicMock() + mock_start_node.state = None + mock_start_node.id = "start" + mock_start_node.graph_runtime_state = shared_runtime_state + mock_graph.nodes["start"] = mock_start_node + + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + command_channel = InMemoryChannel() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=shared_runtime_state, + 
command_channel=command_channel, + ) + + pause_command = PauseCommand(reason="User requested pause") + command_channel.send_command(pause_command) + + events = list(engine.run()) + + assert any(isinstance(e, GraphRunStartedEvent) for e in events) + pause_events = [e for e in events if isinstance(e, GraphRunPausedEvent)] + assert len(pause_events) == 1 + assert pause_events[0].reason == "User requested pause" + + graph_execution = engine.graph_runtime_state.graph_execution + assert graph_execution.is_paused + assert graph_execution.pause_reason == "User requested pause" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher.py b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher.py index 0d612e054f..3fe4ce3400 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher.py @@ -21,6 +21,7 @@ class _StubExecutionCoordinator: self._execution_complete = False self.mark_complete_called = False self.failed = False + self._paused = False def check_commands(self) -> None: self.command_checks += 1 @@ -28,6 +29,10 @@ class _StubExecutionCoordinator: def check_scaling(self) -> None: self.scaling_checks += 1 + @property + def is_paused(self) -> bool: + return self._paused + def is_execution_complete(self) -> bool: return self._execution_complete @@ -96,7 +101,7 @@ def _make_succeeded_event() -> NodeRunSucceededEvent: def test_dispatcher_checks_commands_during_idle_and_on_completion() -> None: - """Dispatcher polls commands when idle and re-checks after completion events.""" + """Dispatcher polls commands when idle and after completion events.""" started_checks = _run_dispatcher_for_event(_make_started_event()) succeeded_checks = _run_dispatcher_for_event(_make_succeeded_event()) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py new file mode 100644 index 0000000000..025393e435 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py @@ -0,0 +1,62 @@ +"""Unit tests for the execution coordinator orchestration logic.""" + +from unittest.mock import MagicMock + +from core.workflow.graph_engine.command_processing.command_processor import CommandProcessor +from core.workflow.graph_engine.domain.graph_execution import GraphExecution +from core.workflow.graph_engine.graph_state_manager import GraphStateManager +from core.workflow.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator +from core.workflow.graph_engine.worker_management.worker_pool import WorkerPool + + +def _build_coordinator(graph_execution: GraphExecution) -> tuple[ExecutionCoordinator, MagicMock, MagicMock]: + command_processor = MagicMock(spec=CommandProcessor) + state_manager = MagicMock(spec=GraphStateManager) + worker_pool = MagicMock(spec=WorkerPool) + + coordinator = ExecutionCoordinator( + graph_execution=graph_execution, + state_manager=state_manager, + command_processor=command_processor, + worker_pool=worker_pool, + ) + return coordinator, state_manager, worker_pool + + +def test_handle_pause_stops_workers_and_clears_state() -> None: + """Paused execution should stop workers and clear executing state.""" + graph_execution = GraphExecution(workflow_id="workflow") + graph_execution.start() + graph_execution.pause("Awaiting human input") + + coordinator, state_manager, worker_pool = _build_coordinator(graph_execution) + + 
coordinator.handle_pause_if_needed() + + worker_pool.stop.assert_called_once_with() + state_manager.clear_executing.assert_called_once_with() + + +def test_handle_pause_noop_when_execution_running() -> None: + """Running execution should not trigger pause handling.""" + graph_execution = GraphExecution(workflow_id="workflow") + graph_execution.start() + + coordinator, state_manager, worker_pool = _build_coordinator(graph_execution) + + coordinator.handle_pause_if_needed() + + worker_pool.stop.assert_not_called() + state_manager.clear_executing.assert_not_called() + + +def test_is_execution_complete_when_paused() -> None: + """Paused execution should be treated as complete.""" + graph_execution = GraphExecution(workflow_id="workflow") + graph_execution.start() + graph_execution.pause("Awaiting input") + + coordinator, state_manager, _worker_pool = _build_coordinator(graph_execution) + state_manager.is_execution_complete.return_value = False + + assert coordinator.is_execution_complete() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py new file mode 100644 index 0000000000..c9e7e31e52 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py @@ -0,0 +1,341 @@ +import time +from collections.abc import Iterable + +from core.model_runtime.entities.llm_entities import LLMMode +from core.model_runtime.entities.message_entities import PromptMessageRole +from core.workflow.entities import GraphInitParams +from core.workflow.graph import Graph +from core.workflow.graph_events import ( + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunPauseRequestedEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) +from core.workflow.nodes.base.entities import VariableSelector +from core.workflow.nodes.end.end_node import EndNode +from core.workflow.nodes.end.entities import EndNodeData +from core.workflow.nodes.human_input import HumanInputNode +from core.workflow.nodes.human_input.entities import HumanInputNodeData +from core.workflow.nodes.llm.entities import ( + ContextConfig, + LLMNodeChatModelMessage, + LLMNodeData, + ModelConfig, + VisionConfig, +) +from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable + +from .test_mock_config import MockConfig +from .test_mock_nodes import MockLLMNode +from .test_table_runner import TableTestRunner, WorkflowTestCase + + +def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]: + graph_config: dict[str, object] = {"nodes": [], "edges": []} + graph_init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config=graph_config, + user_id="user", + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + + variable_pool = VariablePool( + system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), + user_inputs={}, + conversation_variables=[], + ) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} + start_node = StartNode( + id=start_config["id"], + config=start_config, + 
graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + start_node.init_node_data(start_config["data"]) + + def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode: + llm_data = LLMNodeData( + title=title, + model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), + prompt_template=[ + LLMNodeChatModelMessage( + text=prompt_text, + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + context=ContextConfig(enabled=False, variable_selector=None), + vision=VisionConfig(enabled=False), + reasoning_format="tagged", + ) + llm_config = {"id": node_id, "data": llm_data.model_dump()} + llm_node = MockLLMNode( + id=node_id, + config=llm_config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + mock_config=mock_config, + ) + llm_node.init_node_data(llm_config["data"]) + return llm_node + + llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream") + + human_data = HumanInputNodeData( + title="Human Input", + required_variables=["human.input_ready"], + pause_reason="Awaiting human input", + ) + human_config = {"id": "human", "data": human_data.model_dump()} + human_node = HumanInputNode( + id=human_config["id"], + config=human_config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + human_node.init_node_data(human_config["data"]) + + llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output") + llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary") + + end_primary_data = EndNodeData( + title="End Primary", + outputs=[ + VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]), + VariableSelector(variable="primary_text", value_selector=["llm_primary", "text"]), + ], + desc=None, + ) + end_primary_config = {"id": "end_primary", "data": end_primary_data.model_dump()} + end_primary = EndNode( + id=end_primary_config["id"], + config=end_primary_config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + end_primary.init_node_data(end_primary_config["data"]) + + end_secondary_data = EndNodeData( + title="End Secondary", + outputs=[ + VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]), + VariableSelector(variable="secondary_text", value_selector=["llm_secondary", "text"]), + ], + desc=None, + ) + end_secondary_config = {"id": "end_secondary", "data": end_secondary_data.model_dump()} + end_secondary = EndNode( + id=end_secondary_config["id"], + config=end_secondary_config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + end_secondary.init_node_data(end_secondary_config["data"]) + + graph = ( + Graph.new() + .add_root(start_node) + .add_node(llm_initial) + .add_node(human_node) + .add_node(llm_primary, from_node_id="human", source_handle="primary") + .add_node(end_primary, from_node_id="llm_primary") + .add_node(llm_secondary, from_node_id="human", source_handle="secondary") + .add_node(end_secondary, from_node_id="llm_secondary") + .build() + ) + return graph, graph_runtime_state + + +def _expected_mock_llm_chunks(text: str) -> list[str]: + chunks: list[str] = [] + for index, word in enumerate(text.split(" ")): + chunk = word if index == 0 else f" {word}" + chunks.append(chunk) + chunks.append("") + return chunks + + +def _assert_stream_chunk_sequence( + chunk_events: Iterable[NodeRunStreamChunkEvent], + expected_nodes: list[str], 
+ expected_chunks: list[str], +) -> None: + actual_nodes = [event.node_id for event in chunk_events] + actual_chunks = [event.chunk for event in chunk_events] + assert actual_nodes == expected_nodes + assert actual_chunks == expected_chunks + + +def test_human_input_llm_streaming_across_multiple_branches() -> None: + mock_config = MockConfig() + mock_config.set_node_outputs("llm_initial", {"text": "Initial stream"}) + mock_config.set_node_outputs("llm_primary", {"text": "Primary stream output"}) + mock_config.set_node_outputs("llm_secondary", {"text": "Secondary"}) + + branch_scenarios = [ + { + "handle": "primary", + "resume_llm": "llm_primary", + "end_node": "end_primary", + "expected_pre_chunks": [ + ("llm_initial", _expected_mock_llm_chunks("Initial stream")), # cached output before branch completes + ("end_primary", ["\n"]), # literal segment emitted when end_primary session activates + ], + "expected_post_chunks": [ + ("llm_primary", _expected_mock_llm_chunks("Primary stream output")), # live stream from chosen branch + ], + }, + { + "handle": "secondary", + "resume_llm": "llm_secondary", + "end_node": "end_secondary", + "expected_pre_chunks": [ + ("llm_initial", _expected_mock_llm_chunks("Initial stream")), # cached output before branch completes + ("end_secondary", ["\n"]), # literal segment emitted when end_secondary session activates + ], + "expected_post_chunks": [ + ("llm_secondary", _expected_mock_llm_chunks("Secondary")), # live stream from chosen branch + ], + }, + ] + + for scenario in branch_scenarios: + runner = TableTestRunner() + + def initial_graph_factory() -> tuple[Graph, GraphRuntimeState]: + return _build_branching_graph(mock_config) + + initial_case = WorkflowTestCase( + description="HumanInput pause before branching decision", + graph_factory=initial_graph_factory, + expected_event_sequence=[ + GraphRunStartedEvent, # initial run: graph execution starts + NodeRunStartedEvent, # start node begins execution + NodeRunSucceededEvent, # start node completes + NodeRunStartedEvent, # llm_initial starts streaming + NodeRunSucceededEvent, # llm_initial completes streaming + NodeRunStartedEvent, # human node begins and issues pause + NodeRunPauseRequestedEvent, # human node requests pause awaiting input + GraphRunPausedEvent, # graph run pauses awaiting resume + ], + ) + + initial_result = runner.run_test_case(initial_case) + + assert initial_result.success, initial_result.event_mismatch_details + assert not any(isinstance(event, NodeRunStreamChunkEvent) for event in initial_result.events) + + graph_runtime_state = initial_result.graph_runtime_state + graph = initial_result.graph + assert graph_runtime_state is not None + assert graph is not None + + graph_runtime_state.variable_pool.add(("human", "input_ready"), True) + graph_runtime_state.variable_pool.add(("human", "edge_source_handle"), scenario["handle"]) + graph_runtime_state.graph_execution.pause_reason = None + + pre_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_pre_chunks"]) + post_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_post_chunks"]) + + expected_resume_sequence: list[type] = ( + [ + GraphRunStartedEvent, + NodeRunStartedEvent, + ] + + [NodeRunStreamChunkEvent] * pre_chunk_count + + [ + NodeRunSucceededEvent, + NodeRunStartedEvent, + ] + + [NodeRunStreamChunkEvent] * post_chunk_count + + [ + NodeRunSucceededEvent, + NodeRunStartedEvent, + NodeRunSucceededEvent, + GraphRunSucceededEvent, + ] + ) + + def resume_graph_factory( + graph_snapshot: Graph = graph, + 
state_snapshot: GraphRuntimeState = graph_runtime_state, + ) -> tuple[Graph, GraphRuntimeState]: + return graph_snapshot, state_snapshot + + resume_case = WorkflowTestCase( + description=f"HumanInput resumes via {scenario['handle']} branch", + graph_factory=resume_graph_factory, + expected_event_sequence=expected_resume_sequence, + ) + + resume_result = runner.run_test_case(resume_case) + + assert resume_result.success, resume_result.event_mismatch_details + + resume_events = resume_result.events + + chunk_events = [event for event in resume_events if isinstance(event, NodeRunStreamChunkEvent)] + assert len(chunk_events) == pre_chunk_count + post_chunk_count + + pre_chunk_events = chunk_events[:pre_chunk_count] + post_chunk_events = chunk_events[pre_chunk_count:] + + expected_pre_nodes: list[str] = [] + expected_pre_chunks: list[str] = [] + for node_id, chunks in scenario["expected_pre_chunks"]: + expected_pre_nodes.extend([node_id] * len(chunks)) + expected_pre_chunks.extend(chunks) + _assert_stream_chunk_sequence(pre_chunk_events, expected_pre_nodes, expected_pre_chunks) + + expected_post_nodes: list[str] = [] + expected_post_chunks: list[str] = [] + for node_id, chunks in scenario["expected_post_chunks"]: + expected_post_nodes.extend([node_id] * len(chunks)) + expected_post_chunks.extend(chunks) + _assert_stream_chunk_sequence(post_chunk_events, expected_post_nodes, expected_post_chunks) + + human_success_index = next( + index + for index, event in enumerate(resume_events) + if isinstance(event, NodeRunSucceededEvent) and event.node_id == "human" + ) + pre_indices = [ + index + for index, event in enumerate(resume_events) + if isinstance(event, NodeRunStreamChunkEvent) and index < human_success_index + ] + assert pre_indices == list(range(2, 2 + pre_chunk_count)) + + resume_chunk_indices = [ + index + for index, event in enumerate(resume_events) + if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == scenario["resume_llm"] + ] + assert resume_chunk_indices, "Expected streaming output from the selected branch" + resume_start_index = next( + index + for index, event in enumerate(resume_events) + if isinstance(event, NodeRunStartedEvent) and event.node_id == scenario["resume_llm"] + ) + resume_success_index = next( + index + for index, event in enumerate(resume_events) + if isinstance(event, NodeRunSucceededEvent) and event.node_id == scenario["resume_llm"] + ) + assert resume_start_index < min(resume_chunk_indices) + assert max(resume_chunk_indices) < resume_success_index + + started_nodes = [event.node_id for event in resume_events if isinstance(event, NodeRunStartedEvent)] + assert started_nodes == ["human", scenario["resume_llm"], scenario["end_node"]] diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py new file mode 100644 index 0000000000..27d264365d --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py @@ -0,0 +1,297 @@ +import time + +from core.model_runtime.entities.llm_entities import LLMMode +from core.model_runtime.entities.message_entities import PromptMessageRole +from core.workflow.entities import GraphInitParams +from core.workflow.graph import Graph +from core.workflow.graph_events import ( + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunPauseRequestedEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) 
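+
+# Pause/resume flow exercised below (a sketch of the assumed contract, not an
+# extra assertion): run 1 streams llm_initial, then the HumanInputNode emits
+# NodeRunPauseRequestedEvent and the engine finishes with GraphRunPausedEvent;
+# the test then resumes the *same* Graph and GraphRuntimeState after supplying
+# the awaited input, e.g.
+#     graph_runtime_state.variable_pool.add(("human", "input_ready"), True)
+#     graph_runtime_state.graph_execution.pause_reason = None
+# and run 2 replays the buffered llm_initial chunks before llm_resume streams.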
+from core.workflow.nodes.base.entities import VariableSelector +from core.workflow.nodes.end.end_node import EndNode +from core.workflow.nodes.end.entities import EndNodeData +from core.workflow.nodes.human_input import HumanInputNode +from core.workflow.nodes.human_input.entities import HumanInputNodeData +from core.workflow.nodes.llm.entities import ( + ContextConfig, + LLMNodeChatModelMessage, + LLMNodeData, + ModelConfig, + VisionConfig, +) +from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable + +from .test_mock_config import MockConfig +from .test_mock_nodes import MockLLMNode +from .test_table_runner import TableTestRunner, WorkflowTestCase + + +def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]: + graph_config: dict[str, object] = {"nodes": [], "edges": []} + graph_init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config=graph_config, + user_id="user", + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + + variable_pool = VariablePool( + system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), + user_inputs={}, + conversation_variables=[], + ) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} + start_node = StartNode( + id=start_config["id"], + config=start_config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + start_node.init_node_data(start_config["data"]) + + def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode: + llm_data = LLMNodeData( + title=title, + model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), + prompt_template=[ + LLMNodeChatModelMessage( + text=prompt_text, + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + context=ContextConfig(enabled=False, variable_selector=None), + vision=VisionConfig(enabled=False), + reasoning_format="tagged", + ) + llm_config = {"id": node_id, "data": llm_data.model_dump()} + llm_node = MockLLMNode( + id=node_id, + config=llm_config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + mock_config=mock_config, + ) + llm_node.init_node_data(llm_config["data"]) + return llm_node + + llm_first = _create_llm_node("llm_initial", "Initial LLM", "Initial prompt") + + human_data = HumanInputNodeData( + title="Human Input", + required_variables=["human.input_ready"], + pause_reason="Awaiting human input", + ) + human_config = {"id": "human", "data": human_data.model_dump()} + human_node = HumanInputNode( + id=human_config["id"], + config=human_config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + human_node.init_node_data(human_config["data"]) + + llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt") + + end_data = EndNodeData( + title="End", + outputs=[ + VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]), + VariableSelector(variable="resume_text", value_selector=["llm_resume", "text"]), + ], + desc=None, + ) + end_config = {"id": "end", "data": end_data.model_dump()} + end_node = EndNode( + id=end_config["id"], + 
config=end_config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + end_node.init_node_data(end_config["data"]) + + graph = ( + Graph.new() + .add_root(start_node) + .add_node(llm_first) + .add_node(human_node) + .add_node(llm_second) + .add_node(end_node) + .build() + ) + return graph, graph_runtime_state + + +def _expected_mock_llm_chunks(text: str) -> list[str]: + chunks: list[str] = [] + for index, word in enumerate(text.split(" ")): + chunk = word if index == 0 else f" {word}" + chunks.append(chunk) + chunks.append("") + return chunks + + +def test_human_input_llm_streaming_order_across_pause() -> None: + runner = TableTestRunner() + + initial_text = "Hello, pause" + resume_text = "Welcome back!" + + mock_config = MockConfig() + mock_config.set_node_outputs("llm_initial", {"text": initial_text}) + mock_config.set_node_outputs("llm_resume", {"text": resume_text}) + + expected_initial_sequence: list[type] = [ + GraphRunStartedEvent, # graph run begins + NodeRunStartedEvent, # start node begins + NodeRunSucceededEvent, # start node completes + NodeRunStartedEvent, # llm_initial begins streaming + NodeRunSucceededEvent, # llm_initial completes streaming + NodeRunStartedEvent, # human node begins and requests pause + NodeRunPauseRequestedEvent, # human node pause requested + GraphRunPausedEvent, # graph run pauses awaiting resume + ] + + def graph_factory() -> tuple[Graph, GraphRuntimeState]: + return _build_llm_human_llm_graph(mock_config) + + initial_case = WorkflowTestCase( + description="HumanInput pause preserves LLM streaming order", + graph_factory=graph_factory, + expected_event_sequence=expected_initial_sequence, + ) + + initial_result = runner.run_test_case(initial_case) + + assert initial_result.success, initial_result.event_mismatch_details + + initial_events = initial_result.events + initial_chunks = _expected_mock_llm_chunks(initial_text) + + initial_stream_chunk_events = [event for event in initial_events if isinstance(event, NodeRunStreamChunkEvent)] + assert initial_stream_chunk_events == [] + + pause_index = next(i for i, event in enumerate(initial_events) if isinstance(event, GraphRunPausedEvent)) + llm_succeeded_index = next( + i + for i, event in enumerate(initial_events) + if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_initial" + ) + assert llm_succeeded_index < pause_index + + graph_runtime_state = initial_result.graph_runtime_state + graph = initial_result.graph + assert graph_runtime_state is not None + assert graph is not None + + coordinator = graph_runtime_state.response_coordinator + stream_buffers = coordinator._stream_buffers # Tests may access internals for assertions + assert ("llm_initial", "text") in stream_buffers + initial_stream_chunks = [event.chunk for event in stream_buffers[("llm_initial", "text")]] + assert initial_stream_chunks == initial_chunks + assert ("llm_resume", "text") not in stream_buffers + + resume_chunks = _expected_mock_llm_chunks(resume_text) + expected_resume_sequence: list[type] = [ + GraphRunStartedEvent, # resumed graph run begins + NodeRunStartedEvent, # human node restarts + NodeRunStreamChunkEvent, # cached llm_initial chunk 1 + NodeRunStreamChunkEvent, # cached llm_initial chunk 2 + NodeRunStreamChunkEvent, # cached llm_initial final chunk + NodeRunStreamChunkEvent, # end node emits combined template separator + NodeRunSucceededEvent, # human node finishes instantly after input + NodeRunStartedEvent, # llm_resume begins streaming + NodeRunStreamChunkEvent, # 
llm_resume chunk 1 + NodeRunStreamChunkEvent, # llm_resume chunk 2 + NodeRunStreamChunkEvent, # llm_resume final chunk + NodeRunSucceededEvent, # llm_resume completes streaming + NodeRunStartedEvent, # end node starts + NodeRunSucceededEvent, # end node finishes + GraphRunSucceededEvent, # graph run succeeds after resume + ] + + def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]: + assert graph_runtime_state is not None + assert graph is not None + graph_runtime_state.variable_pool.add(("human", "input_ready"), True) + graph_runtime_state.graph_execution.pause_reason = None + return graph, graph_runtime_state + + resume_case = WorkflowTestCase( + description="HumanInput resume continues LLM streaming order", + graph_factory=resume_graph_factory, + expected_event_sequence=expected_resume_sequence, + ) + + resume_result = runner.run_test_case(resume_case) + + assert resume_result.success, resume_result.event_mismatch_details + + resume_events = resume_result.events + + success_index = next(i for i, event in enumerate(resume_events) if isinstance(event, GraphRunSucceededEvent)) + llm_resume_succeeded_index = next( + i + for i, event in enumerate(resume_events) + if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_resume" + ) + assert llm_resume_succeeded_index < success_index + + resume_chunk_events = [event for event in resume_events if isinstance(event, NodeRunStreamChunkEvent)] + assert [event.node_id for event in resume_chunk_events[:3]] == ["llm_initial"] * 3 + assert [event.chunk for event in resume_chunk_events[:3]] == initial_chunks + assert resume_chunk_events[3].node_id == "end" + assert resume_chunk_events[3].chunk == "\n" + assert [event.node_id for event in resume_chunk_events[4:]] == ["llm_resume"] * 3 + assert [event.chunk for event in resume_chunk_events[4:]] == resume_chunks + + human_success_index = next( + i + for i, event in enumerate(resume_events) + if isinstance(event, NodeRunSucceededEvent) and event.node_id == "human" + ) + cached_chunk_indices = [ + i + for i, event in enumerate(resume_events) + if isinstance(event, NodeRunStreamChunkEvent) and event.node_id in {"llm_initial", "end"} + ] + assert all(index < human_success_index for index in cached_chunk_indices) + + llm_resume_start_index = next( + i + for i, event in enumerate(resume_events) + if isinstance(event, NodeRunStartedEvent) and event.node_id == "llm_resume" + ) + llm_resume_success_index = next( + i + for i, event in enumerate(resume_events) + if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_resume" + ) + llm_resume_chunk_indices = [ + i + for i, event in enumerate(resume_events) + if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == "llm_resume" + ] + assert llm_resume_chunk_indices + first_resume_chunk_index = min(llm_resume_chunk_indices) + last_resume_chunk_index = max(llm_resume_chunk_indices) + assert llm_resume_start_index < first_resume_chunk_index + assert last_resume_chunk_index < llm_resume_success_index + + started_nodes = [event.node_id for event in resume_events if isinstance(event, NodeRunStartedEvent)] + assert started_nodes == ["human", "llm_resume", "end"] diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py new file mode 100644 index 0000000000..dfd33f135f --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py @@ -0,0 +1,321 @@ +import time + +from 
core.model_runtime.entities.llm_entities import LLMMode +from core.model_runtime.entities.message_entities import PromptMessageRole +from core.workflow.entities import GraphInitParams +from core.workflow.graph import Graph +from core.workflow.graph_events import ( + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) +from core.workflow.nodes.base.entities import VariableSelector +from core.workflow.nodes.end.end_node import EndNode +from core.workflow.nodes.end.entities import EndNodeData +from core.workflow.nodes.if_else.entities import IfElseNodeData +from core.workflow.nodes.if_else.if_else_node import IfElseNode +from core.workflow.nodes.llm.entities import ( + ContextConfig, + LLMNodeChatModelMessage, + LLMNodeData, + ModelConfig, + VisionConfig, +) +from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable +from core.workflow.utils.condition.entities import Condition + +from .test_mock_config import MockConfig +from .test_mock_nodes import MockLLMNode +from .test_table_runner import TableTestRunner, WorkflowTestCase + + +def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]: + graph_config: dict[str, object] = {"nodes": [], "edges": []} + graph_init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config=graph_config, + user_id="user", + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + + variable_pool = VariablePool( + system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), + user_inputs={}, + conversation_variables=[], + ) + variable_pool.add(("branch", "value"), branch_value) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} + start_node = StartNode( + id=start_config["id"], + config=start_config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + start_node.init_node_data(start_config["data"]) + + def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode: + llm_data = LLMNodeData( + title=title, + model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), + prompt_template=[ + LLMNodeChatModelMessage( + text=prompt_text, + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + context=ContextConfig(enabled=False, variable_selector=None), + vision=VisionConfig(enabled=False), + reasoning_format="tagged", + ) + llm_config = {"id": node_id, "data": llm_data.model_dump()} + llm_node = MockLLMNode( + id=node_id, + config=llm_config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + mock_config=mock_config, + ) + llm_node.init_node_data(llm_config["data"]) + return llm_node + + llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream") + + if_else_data = IfElseNodeData( + title="IfElse", + cases=[ + IfElseNodeData.Case( + case_id="primary", + logical_operator="and", + conditions=[ + Condition(variable_selector=["branch", "value"], comparison_operator="is", value="primary") + ], + ), + IfElseNodeData.Case( + case_id="secondary", + logical_operator="and", + conditions=[ + 
Condition(variable_selector=["branch", "value"], comparison_operator="is", value="secondary") + ], + ), + ], + ) + if_else_config = {"id": "if_else", "data": if_else_data.model_dump()} + if_else_node = IfElseNode( + id=if_else_config["id"], + config=if_else_config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + if_else_node.init_node_data(if_else_config["data"]) + + llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output") + llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary") + + end_primary_data = EndNodeData( + title="End Primary", + outputs=[ + VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]), + VariableSelector(variable="primary_text", value_selector=["llm_primary", "text"]), + ], + desc=None, + ) + end_primary_config = {"id": "end_primary", "data": end_primary_data.model_dump()} + end_primary = EndNode( + id=end_primary_config["id"], + config=end_primary_config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + end_primary.init_node_data(end_primary_config["data"]) + + end_secondary_data = EndNodeData( + title="End Secondary", + outputs=[ + VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]), + VariableSelector(variable="secondary_text", value_selector=["llm_secondary", "text"]), + ], + desc=None, + ) + end_secondary_config = {"id": "end_secondary", "data": end_secondary_data.model_dump()} + end_secondary = EndNode( + id=end_secondary_config["id"], + config=end_secondary_config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + end_secondary.init_node_data(end_secondary_config["data"]) + + graph = ( + Graph.new() + .add_root(start_node) + .add_node(llm_initial) + .add_node(if_else_node) + .add_node(llm_primary, from_node_id="if_else", source_handle="primary") + .add_node(end_primary, from_node_id="llm_primary") + .add_node(llm_secondary, from_node_id="if_else", source_handle="secondary") + .add_node(end_secondary, from_node_id="llm_secondary") + .build() + ) + return graph, graph_runtime_state + + +def _expected_mock_llm_chunks(text: str) -> list[str]: + chunks: list[str] = [] + for index, word in enumerate(text.split(" ")): + chunk = word if index == 0 else f" {word}" + chunks.append(chunk) + chunks.append("") + return chunks + + +def test_if_else_llm_streaming_order() -> None: + mock_config = MockConfig() + mock_config.set_node_outputs("llm_initial", {"text": "Initial stream"}) + mock_config.set_node_outputs("llm_primary", {"text": "Primary stream output"}) + mock_config.set_node_outputs("llm_secondary", {"text": "Secondary"}) + + scenarios = [ + { + "branch": "primary", + "resume_llm": "llm_primary", + "end_node": "end_primary", + "expected_sequence": [ + GraphRunStartedEvent, # graph run begins + NodeRunStartedEvent, # start node begins execution + NodeRunSucceededEvent, # start node completes + NodeRunStartedEvent, # llm_initial starts and streams + NodeRunSucceededEvent, # llm_initial completes streaming + NodeRunStartedEvent, # if_else evaluates conditions + NodeRunStreamChunkEvent, # cached llm_initial chunk 1 flushed + NodeRunStreamChunkEvent, # cached llm_initial chunk 2 flushed + NodeRunStreamChunkEvent, # cached llm_initial final chunk flushed + NodeRunStreamChunkEvent, # template literal newline emitted + NodeRunSucceededEvent, # if_else completes branch selection + NodeRunStartedEvent, # llm_primary begins streaming + 
NodeRunStreamChunkEvent, # llm_primary chunk 1 + NodeRunStreamChunkEvent, # llm_primary chunk 2 + NodeRunStreamChunkEvent, # llm_primary chunk 3 + NodeRunStreamChunkEvent, # llm_primary final chunk + NodeRunSucceededEvent, # llm_primary completes streaming + NodeRunStartedEvent, # end_primary node starts + NodeRunSucceededEvent, # end_primary finishes aggregation + GraphRunSucceededEvent, # graph run succeeds + ], + "expected_chunks": [ + ("llm_initial", _expected_mock_llm_chunks("Initial stream")), + ("end_primary", ["\n"]), + ("llm_primary", _expected_mock_llm_chunks("Primary stream output")), + ], + }, + { + "branch": "secondary", + "resume_llm": "llm_secondary", + "end_node": "end_secondary", + "expected_sequence": [ + GraphRunStartedEvent, # graph run begins + NodeRunStartedEvent, # start node begins execution + NodeRunSucceededEvent, # start node completes + NodeRunStartedEvent, # llm_initial starts and streams + NodeRunSucceededEvent, # llm_initial completes streaming + NodeRunStartedEvent, # if_else evaluates conditions + NodeRunStreamChunkEvent, # cached llm_initial chunk 1 flushed + NodeRunStreamChunkEvent, # cached llm_initial chunk 2 flushed + NodeRunStreamChunkEvent, # cached llm_initial final chunk flushed + NodeRunStreamChunkEvent, # template literal newline emitted + NodeRunSucceededEvent, # if_else completes branch selection + NodeRunStartedEvent, # llm_secondary begins streaming + NodeRunStreamChunkEvent, # llm_secondary chunk 1 + NodeRunStreamChunkEvent, # llm_secondary final chunk + NodeRunSucceededEvent, # llm_secondary completes + NodeRunStartedEvent, # end_secondary node starts + NodeRunSucceededEvent, # end_secondary finishes aggregation + GraphRunSucceededEvent, # graph run succeeds + ], + "expected_chunks": [ + ("llm_initial", _expected_mock_llm_chunks("Initial stream")), + ("end_secondary", ["\n"]), + ("llm_secondary", _expected_mock_llm_chunks("Secondary")), + ], + }, + ] + + for scenario in scenarios: + runner = TableTestRunner() + + def graph_factory( + branch_value: str = scenario["branch"], + cfg: MockConfig = mock_config, + ) -> tuple[Graph, GraphRuntimeState]: + return _build_if_else_graph(branch_value, cfg) + + test_case = WorkflowTestCase( + description=f"IfElse streaming via {scenario['branch']} branch", + graph_factory=graph_factory, + expected_event_sequence=scenario["expected_sequence"], + ) + + result = runner.run_test_case(test_case) + + assert result.success, result.event_mismatch_details + + chunk_events = [event for event in result.events if isinstance(event, NodeRunStreamChunkEvent)] + expected_nodes: list[str] = [] + expected_chunks: list[str] = [] + for node_id, chunks in scenario["expected_chunks"]: + expected_nodes.extend([node_id] * len(chunks)) + expected_chunks.extend(chunks) + assert [event.node_id for event in chunk_events] == expected_nodes + assert [event.chunk for event in chunk_events] == expected_chunks + + branch_node_index = next( + index + for index, event in enumerate(result.events) + if isinstance(event, NodeRunStartedEvent) and event.node_id == "if_else" + ) + branch_success_index = next( + index + for index, event in enumerate(result.events) + if isinstance(event, NodeRunSucceededEvent) and event.node_id == "if_else" + ) + pre_branch_chunk_indices = [ + index + for index, event in enumerate(result.events) + if isinstance(event, NodeRunStreamChunkEvent) and index < branch_success_index + ] + assert len(pre_branch_chunk_indices) == len(_expected_mock_llm_chunks("Initial stream")) + 1 + assert min(pre_branch_chunk_indices) == 
branch_node_index + 1 + assert max(pre_branch_chunk_indices) < branch_success_index + + resume_chunk_indices = [ + index + for index, event in enumerate(result.events) + if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == scenario["resume_llm"] + ] + assert resume_chunk_indices + resume_start_index = next( + index + for index, event in enumerate(result.events) + if isinstance(event, NodeRunStartedEvent) and event.node_id == scenario["resume_llm"] + ) + resume_success_index = next( + index + for index, event in enumerate(result.events) + if isinstance(event, NodeRunSucceededEvent) and event.node_id == scenario["resume_llm"] + ) + assert resume_start_index < min(resume_chunk_indices) + assert max(resume_chunk_indices) < resume_success_index + + started_nodes = [event.node_id for event in result.events if isinstance(event, NodeRunStartedEvent)] + assert started_nodes == ["start", "llm_initial", "if_else", scenario["resume_llm"], scenario["end_node"]] diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index 7f802effa6..03de984bd1 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -27,7 +27,8 @@ from .test_mock_nodes import ( ) if TYPE_CHECKING: - from core.workflow.entities import GraphInitParams, GraphRuntimeState + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState from .test_mock_config import MockConfig diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py index c39c12925f..48fa00f105 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py @@ -42,7 +42,8 @@ def test_mock_iteration_node_preserves_config(): """Test that MockIterationNode preserves mock configuration.""" from core.app.entities.app_invoke_entities import InvokeFrom - from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState, VariablePool from models.enums import UserFrom from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockIterationNode @@ -103,7 +104,8 @@ def test_mock_loop_node_preserves_config(): """Test that MockLoopNode preserves mock configuration.""" from core.app.entities.app_invoke_entities import InvokeFrom - from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState, VariablePool from models.enums import UserFrom from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockLoopNode diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index e5ae32bbff..68f57ee9fb 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -24,7 +24,8 @@ from core.workflow.nodes.template_transform import TemplateTransformNode from core.workflow.nodes.tool import ToolNode if TYPE_CHECKING: - from core.workflow.entities import GraphInitParams, 
GraphRuntimeState + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState from .test_mock_config import MockConfig @@ -561,10 +562,11 @@ class MockIterationNode(MockNodeMixin, IterationNode): def _create_graph_engine(self, index: int, item: Any): """Create a graph engine with MockNodeFactory instead of DifyNodeFactory.""" # Import dependencies - from core.workflow.entities import GraphInitParams, GraphRuntimeState + from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine.command_channels import InMemoryChannel + from core.workflow.runtime import GraphRuntimeState # Import our MockNodeFactory instead of DifyNodeFactory from .test_mock_factory import MockNodeFactory @@ -635,10 +637,11 @@ class MockLoopNode(MockNodeMixin, LoopNode): def _create_graph_engine(self, start_at, root_node_id: str): """Create a graph engine with MockNodeFactory instead of DifyNodeFactory.""" # Import dependencies - from core.workflow.entities import GraphInitParams, GraphRuntimeState + from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine.command_channels import InMemoryChannel + from core.workflow.runtime import GraphRuntimeState # Import our MockNodeFactory instead of DifyNodeFactory from .test_mock_factory import MockNodeFactory diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py index 394addd5c2..23274f5981 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py @@ -16,8 +16,8 @@ class TestMockTemplateTransformNode: def test_mock_template_transform_node_default_output(self): """Test that MockTemplateTransformNode processes templates with Jinja2.""" - from core.workflow.entities import GraphInitParams, GraphRuntimeState - from core.workflow.entities.variable_pool import VariablePool + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( @@ -76,8 +76,8 @@ class TestMockTemplateTransformNode: def test_mock_template_transform_node_custom_output(self): """Test that MockTemplateTransformNode returns custom configured output.""" - from core.workflow.entities import GraphInitParams, GraphRuntimeState - from core.workflow.entities.variable_pool import VariablePool + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( @@ -137,8 +137,8 @@ class TestMockTemplateTransformNode: def test_mock_template_transform_node_error_simulation(self): """Test that MockTemplateTransformNode can simulate errors.""" - from core.workflow.entities import GraphInitParams, GraphRuntimeState - from core.workflow.entities.variable_pool import VariablePool + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( @@ -196,8 +196,8 @@ class TestMockTemplateTransformNode: def test_mock_template_transform_node_with_variables(self): """Test that 
MockTemplateTransformNode processes templates with variables.""" from core.variables import StringVariable - from core.workflow.entities import GraphInitParams, GraphRuntimeState - from core.workflow.entities.variable_pool import VariablePool + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( @@ -262,8 +262,8 @@ class TestMockCodeNode: def test_mock_code_node_default_output(self): """Test that MockCodeNode returns default output.""" - from core.workflow.entities import GraphInitParams, GraphRuntimeState - from core.workflow.entities.variable_pool import VariablePool + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( @@ -323,8 +323,8 @@ class TestMockCodeNode: def test_mock_code_node_with_output_schema(self): """Test that MockCodeNode generates outputs based on schema.""" - from core.workflow.entities import GraphInitParams, GraphRuntimeState - from core.workflow.entities.variable_pool import VariablePool + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( @@ -392,8 +392,8 @@ class TestMockCodeNode: def test_mock_code_node_custom_output(self): """Test that MockCodeNode returns custom configured output.""" - from core.workflow.entities import GraphInitParams, GraphRuntimeState - from core.workflow.entities.variable_pool import VariablePool + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( @@ -463,8 +463,8 @@ class TestMockNodeFactory: def test_code_and_template_nodes_mocked_by_default(self): """Test that CODE and TEMPLATE_TRANSFORM nodes are mocked by default (they require SSRF proxy).""" - from core.workflow.entities import GraphInitParams, GraphRuntimeState - from core.workflow.entities.variable_pool import VariablePool + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( @@ -504,8 +504,8 @@ class TestMockNodeFactory: def test_factory_creates_mock_template_transform_node(self): """Test that MockNodeFactory creates MockTemplateTransformNode for template-transform type.""" - from core.workflow.entities import GraphInitParams, GraphRuntimeState - from core.workflow.entities.variable_pool import VariablePool + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( @@ -555,8 +555,8 @@ class TestMockNodeFactory: def test_factory_creates_mock_code_node(self): """Test that MockNodeFactory creates MockCodeNode for code type.""" - from core.workflow.entities import GraphInitParams, GraphRuntimeState - from core.workflow.entities.variable_pool import VariablePool + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py index 
d1f1f53b78..b76fe42fce 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py @@ -13,7 +13,7 @@ from unittest.mock import patch from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool +from core.workflow.entities import GraphInitParams from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.graph import Graph from core.workflow.graph_engine import GraphEngine @@ -27,6 +27,7 @@ from core.workflow.graph_events import ( from core.workflow.node_events import NodeRunResult, StreamCompletedEvent from core.workflow.nodes.llm.node import LLMNode from core.workflow.nodes.node_factory import DifyNodeFactory +from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from models.enums import UserFrom diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py b/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py index e191246bed..f1a495d20a 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py @@ -13,7 +13,7 @@ import redis from core.app.apps.base_app_queue_manager import AppQueueManager from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel -from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType +from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand from core.workflow.graph_engine.manager import GraphEngineManager @@ -52,6 +52,29 @@ class TestRedisStopIntegration: assert command_data["command_type"] == CommandType.ABORT assert command_data["reason"] == "Test stop" + def test_graph_engine_manager_sends_pause_command(self): + """Test that GraphEngineManager correctly sends pause command through Redis.""" + task_id = "test-task-pause-123" + expected_channel_key = f"workflow:{task_id}:commands" + + mock_redis = MagicMock() + mock_pipeline = MagicMock() + mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline) + mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None) + + with patch("core.workflow.graph_engine.manager.redis_client", mock_redis): + GraphEngineManager.send_pause_command(task_id, reason="Awaiting resources") + + mock_redis.pipeline.assert_called_once() + calls = mock_pipeline.rpush.call_args_list + assert len(calls) == 1 + assert calls[0][0][0] == expected_channel_key + + command_json = calls[0][0][1] + command_data = json.loads(command_json) + assert command_data["command_type"] == CommandType.PAUSE.value + assert command_data["reason"] == "Awaiting resources" + def test_graph_engine_manager_handles_redis_failure_gracefully(self): """Test that GraphEngineManager handles Redis failures without raising exceptions.""" task_id = "test-task-456" @@ -105,28 +128,37 @@ class TestRedisStopIntegration: channel_key = "workflow:test:commands" channel = RedisChannel(mock_redis, channel_key) - # Create abort command + # Create commands abort_command = AbortCommand(reason="User requested stop") + pause_command = PauseCommand(reason="User requested pause") # Execute channel.send_command(abort_command) + channel.send_command(pause_command) # Verify - 
mock_redis.pipeline.assert_called_once() + mock_redis.pipeline.assert_called() # Check rpush was called calls = mock_pipeline.rpush.call_args_list - assert len(calls) == 1 + assert len(calls) == 2 assert calls[0][0][0] == channel_key + assert calls[1][0][0] == channel_key - # Verify serialized command - command_json = calls[0][0][1] - command_data = json.loads(command_json) - assert command_data["command_type"] == CommandType.ABORT - assert command_data["reason"] == "User requested stop" + # Verify serialized commands + abort_command_json = calls[0][0][1] + abort_command_data = json.loads(abort_command_json) + assert abort_command_data["command_type"] == CommandType.ABORT.value + assert abort_command_data["reason"] == "User requested stop" - # Check expire was set - mock_pipeline.expire.assert_called_once_with(channel_key, 3600) + pause_command_json = calls[1][0][1] + pause_command_data = json.loads(pause_command_json) + assert pause_command_data["command_type"] == CommandType.PAUSE.value + assert pause_command_data["reason"] == "User requested pause" + + # Check expire was set for each + assert mock_pipeline.expire.call_count == 2 + mock_pipeline.expire.assert_any_call(channel_key, 3600) def test_redis_channel_fetch_commands(self): """Test RedisChannel correctly fetches and deserializes commands.""" @@ -143,12 +175,17 @@ class TestRedisStopIntegration: mock_redis.pipeline.side_effect = [pending_context, fetch_context] # Mock command data - abort_command_json = json.dumps({"command_type": CommandType.ABORT, "reason": "Test abort", "payload": None}) + abort_command_json = json.dumps( + {"command_type": CommandType.ABORT.value, "reason": "Test abort", "payload": None} + ) + pause_command_json = json.dumps( + {"command_type": CommandType.PAUSE.value, "reason": "Pause requested", "payload": None} + ) # Mock pipeline execute to return commands pending_pipe.execute.return_value = [b"1", 1] fetch_pipe.execute.return_value = [ - [abort_command_json.encode()], # lrange result + [abort_command_json.encode(), pause_command_json.encode()], # lrange result True, # delete result ] @@ -159,10 +196,13 @@ class TestRedisStopIntegration: commands = channel.fetch_commands() # Verify - assert len(commands) == 1 + assert len(commands) == 2 assert isinstance(commands[0], AbortCommand) assert commands[0].command_type == CommandType.ABORT assert commands[0].reason == "Test abort" + assert isinstance(commands[1], PauseCommand) + assert commands[1].command_type == CommandType.PAUSE + assert commands[1].reason == "Pause requested" # Verify Redis operations pending_pipe.get.assert_called_once_with(f"{channel_key}:pending") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index 0f3a142b1a..08f7b00a33 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -29,7 +29,6 @@ from core.variables import ( ObjectVariable, StringVariable, ) -from core.workflow.entities import GraphRuntimeState, VariablePool from core.workflow.entities.graph_init_params import GraphInitParams from core.workflow.graph import Graph from core.workflow.graph_engine import GraphEngine @@ -40,6 +39,7 @@ from core.workflow.graph_events import ( GraphRunSucceededEvent, ) from core.workflow.nodes.node_factory import DifyNodeFactory +from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import 
SystemVariable from .test_mock_config import MockConfig @@ -52,8 +52,8 @@ logger = logging.getLogger(__name__) class WorkflowTestCase: """Represents a single test case for table-driven testing.""" - fixture_path: str - expected_outputs: dict[str, Any] + fixture_path: str = "" + expected_outputs: dict[str, Any] = field(default_factory=dict) inputs: dict[str, Any] = field(default_factory=dict) query: str = "" description: str = "" @@ -61,11 +61,7 @@ class WorkflowTestCase: mock_config: MockConfig | None = None use_auto_mock: bool = False expected_event_sequence: Sequence[type[GraphEngineEvent]] | None = None - tags: list[str] = field(default_factory=list) - skip: bool = False - skip_reason: str = "" - retry_count: int = 0 - custom_validator: Callable[[dict[str, Any]], bool] | None = None + graph_factory: Callable[[], tuple[Graph, GraphRuntimeState]] | None = None @dataclass @@ -80,7 +76,8 @@ class WorkflowTestResult: event_sequence_match: bool | None = None event_mismatch_details: str | None = None events: list[GraphEngineEvent] = field(default_factory=list) - retry_attempts: int = 0 + graph: Graph | None = None + graph_runtime_state: GraphRuntimeState | None = None validation_details: str | None = None @@ -91,7 +88,6 @@ class TestSuiteResult: total_tests: int passed_tests: int failed_tests: int - skipped_tests: int total_execution_time: float results: list[WorkflowTestResult] @@ -106,10 +102,6 @@ class TestSuiteResult: """Get all failed test results.""" return [r for r in self.results if not r.success] - def get_results_by_tag(self, tag: str) -> list[WorkflowTestResult]: - """Get test results filtered by tag.""" - return [r for r in self.results if tag in r.test_case.tags] - class WorkflowRunner: """Core workflow execution engine for tests.""" @@ -286,90 +278,30 @@ class TableTestRunner: Returns: WorkflowTestResult with execution details """ - if test_case.skip: - self.logger.info("Skipping test: %s - %s", test_case.description, test_case.skip_reason) - return WorkflowTestResult( - test_case=test_case, - success=True, - execution_time=0.0, - validation_details=f"Skipped: {test_case.skip_reason}", - ) - - retry_attempts = 0 - last_result = None - last_error = None start_time = time.perf_counter() - for attempt in range(test_case.retry_count + 1): - start_time = time.perf_counter() - - try: - result = self._execute_test_case(test_case) - last_result = result # Save the last result - - if result.success: - result.retry_attempts = retry_attempts - self.logger.info("Test passed: %s", test_case.description) - return result - - last_error = result.error - retry_attempts += 1 - - if attempt < test_case.retry_count: - self.logger.warning( - "Test failed (attempt %d/%d): %s", - attempt + 1, - test_case.retry_count + 1, - test_case.description, - ) - time.sleep(0.5 * (attempt + 1)) # Exponential backoff - - except Exception as e: - last_error = e - retry_attempts += 1 - - if attempt < test_case.retry_count: - self.logger.warning( - "Test error (attempt %d/%d): %s - %s", - attempt + 1, - test_case.retry_count + 1, - test_case.description, - str(e), - ) - time.sleep(0.5 * (attempt + 1)) - - # All retries failed - return the last result if available - if last_result: - last_result.retry_attempts = retry_attempts - self.logger.error("Test failed after %d attempts: %s", retry_attempts, test_case.description) - return last_result - - # If no result available (all attempts threw exceptions), create a failure result - self.logger.error("Test failed after %d attempts: %s", retry_attempts, 
test_case.description) - return WorkflowTestResult( - test_case=test_case, - success=False, - error=last_error, - execution_time=time.perf_counter() - start_time, - retry_attempts=retry_attempts, - ) + try: + result = self._execute_test_case(test_case) + if result.success: + self.logger.info("Test passed: %s", test_case.description) + else: + self.logger.error("Test failed: %s", test_case.description) + return result + except Exception as exc: + self.logger.exception("Error executing test case: %s", test_case.description) + return WorkflowTestResult( + test_case=test_case, + success=False, + error=exc, + execution_time=time.perf_counter() - start_time, + ) def _execute_test_case(self, test_case: WorkflowTestCase) -> WorkflowTestResult: """Internal method to execute a single test case.""" start_time = time.perf_counter() try: - # Load fixture data - fixture_data = self.workflow_runner.load_fixture(test_case.fixture_path) - - # Create graph from fixture - graph, graph_runtime_state = self.workflow_runner.create_graph_from_fixture( - fixture_data=fixture_data, - inputs=test_case.inputs, - query=test_case.query, - use_mock_factory=test_case.use_auto_mock, - mock_config=test_case.mock_config, - ) + graph, graph_runtime_state = self._create_graph_runtime_state(test_case) # Create and run the engine with configured worker settings engine = GraphEngine( @@ -384,7 +316,7 @@ class TableTestRunner: ) # Execute and collect events - events = [] + events: list[GraphEngineEvent] = [] for event in engine.run(): events.append(event) @@ -416,6 +348,8 @@ class TableTestRunner: events=events, event_sequence_match=event_sequence_match, event_mismatch_details=event_mismatch_details, + graph=graph, + graph_runtime_state=graph_runtime_state, ) # Get actual outputs @@ -423,9 +357,7 @@ class TableTestRunner: actual_outputs = success_event.outputs or {} # Validate outputs - output_success, validation_details = self._validate_outputs( - test_case.expected_outputs, actual_outputs, test_case.custom_validator - ) + output_success, validation_details = self._validate_outputs(test_case.expected_outputs, actual_outputs) # Overall success requires both output and event sequence validation success = output_success and (event_sequence_match if event_sequence_match is not None else True) @@ -440,6 +372,8 @@ class TableTestRunner: events=events, validation_details=validation_details, error=None if success else Exception(validation_details or event_mismatch_details or "Test failed"), + graph=graph, + graph_runtime_state=graph_runtime_state, ) except Exception as e: @@ -449,13 +383,33 @@ class TableTestRunner: success=False, error=e, execution_time=time.perf_counter() - start_time, + graph=graph if "graph" in locals() else None, + graph_runtime_state=graph_runtime_state if "graph_runtime_state" in locals() else None, ) + def _create_graph_runtime_state(self, test_case: WorkflowTestCase) -> tuple[Graph, GraphRuntimeState]: + """Create or retrieve graph/runtime state according to test configuration.""" + + if test_case.graph_factory is not None: + return test_case.graph_factory() + + if not test_case.fixture_path: + raise ValueError("fixture_path must be provided when graph_factory is not specified") + + fixture_data = self.workflow_runner.load_fixture(test_case.fixture_path) + + return self.workflow_runner.create_graph_from_fixture( + fixture_data=fixture_data, + inputs=test_case.inputs, + query=test_case.query, + use_mock_factory=test_case.use_auto_mock, + mock_config=test_case.mock_config, + ) + def _validate_outputs( self, 
expected_outputs: dict[str, Any], actual_outputs: dict[str, Any], - custom_validator: Callable[[dict[str, Any]], bool] | None = None, ) -> tuple[bool, str | None]: """ Validate actual outputs against expected outputs. @@ -490,14 +444,6 @@ class TableTestRunner: f"Value mismatch for key '{key}':\n Expected: {expected_value}\n Actual: {actual_value}" ) - # Apply custom validator if provided - if custom_validator: - try: - if not custom_validator(actual_outputs): - validation_errors.append("Custom validator failed") - except Exception as e: - validation_errors.append(f"Custom validator error: {str(e)}") - if validation_errors: return False, "\n".join(validation_errors) @@ -537,7 +483,6 @@ class TableTestRunner: self, test_cases: list[WorkflowTestCase], parallel: bool = False, - tags_filter: list[str] | None = None, fail_fast: bool = False, ) -> TestSuiteResult: """ @@ -546,22 +491,16 @@ class TableTestRunner: Args: test_cases: List of test cases to execute parallel: Run tests in parallel - tags_filter: Only run tests with specified tags - fail_fast: Stop execution on first failure + fail_fast: Stop execution on first failure Returns: TestSuiteResult with aggregated results """ - # Filter by tags if specified - if tags_filter: - test_cases = [tc for tc in test_cases if any(tag in tc.tags for tag in tags_filter)] - if not test_cases: return TestSuiteResult( total_tests=0, passed_tests=0, failed_tests=0, - skipped_tests=0, total_execution_time=0.0, results=[], ) @@ -576,16 +515,14 @@ class TableTestRunner: # Calculate statistics total_tests = len(results) - passed_tests = sum(1 for r in results if r.success and not r.test_case.skip) - failed_tests = sum(1 for r in results if not r.success and not r.test_case.skip) - skipped_tests = sum(1 for r in results if r.test_case.skip) + passed_tests = sum(1 for r in results if r.success) + failed_tests = total_tests - passed_tests total_execution_time = time.perf_counter() - start_time return TestSuiteResult( total_tests=total_tests, passed_tests=passed_tests, failed_tests=failed_tests, - skipped_tests=skipped_tests, total_execution_time=total_execution_time, results=results, ) @@ -598,7 +535,7 @@ class TableTestRunner: result = self.run_test_case(test_case) results.append(result) - if fail_fast and not result.success and not result.test_case.skip: + if fail_fast and not result.success: self.logger.info("Fail-fast enabled: stopping execution") break @@ -618,11 +555,11 @@ class TableTestRunner: result = future.result() results.append(result) - if fail_fast and not result.success and not result.test_case.skip: + if fail_fast and not result.success: self.logger.info("Fail-fast enabled: cancelling remaining tests") - # Cancel remaining futures - for f in future_to_test: - f.cancel() + for remaining_future in future_to_test: + if not remaining_future.done(): + remaining_future.cancel() break except Exception as e: @@ -636,8 +573,9 @@ class TableTestRunner: ) if fail_fast: - for f in future_to_test: - f.cancel() + for remaining_future in future_to_test: + if not remaining_future.done(): + remaining_future.cancel() break return results @@ -663,7 +601,6 @@ class TableTestRunner: report.append(f" Total Tests: {suite_result.total_tests}") report.append(f" Passed: {suite_result.passed_tests}") report.append(f" Failed: {suite_result.failed_tests}") - report.append(f" Skipped: {suite_result.skipped_tests}") report.append(f" Success Rate: {suite_result.success_rate:.1f}%") report.append(f" Total Time: {suite_result.total_execution_time:.2f}s") report.append("") diff 
--git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index 79f3f45ce2..d151bbe015 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -3,11 +3,12 @@ import uuid from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool +from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.graph import Graph from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.node_factory import DifyNodeFactory +from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from extensions.ext_database import db from models.enums import UserFrom diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index b34f73be5f..f040a92b6f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -1,4 +1,3 @@ -from core.workflow.entities import VariablePool from core.workflow.nodes.http_request import ( BodyData, HttpRequestNodeAuthorization, @@ -7,6 +6,7 @@ from core.workflow.nodes.http_request import ( ) from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout from core.workflow.nodes.http_request.executor import Executor +from core.workflow.runtime import VariablePool from core.workflow.system_variable import SystemVariable diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 94c638bb0f..3ffb5c0fdf 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -20,7 +20,7 @@ from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment -from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool +from core.workflow.entities import GraphInitParams from core.workflow.nodes.llm import llm_utils from core.workflow.nodes.llm.entities import ( ContextConfig, @@ -32,6 +32,7 @@ from core.workflow.nodes.llm.entities import ( ) from core.workflow.nodes.llm.file_saver import LLMFileSaver from core.workflow.nodes.llm.node import LLMNode +from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from models.enums import UserFrom from models.provider import ProviderType diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 69e0052543..962e43a897 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -7,12 +7,13 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom from core.file import File, FileTransferMethod, FileType from 
core.variables import ArrayFileSegment -from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool +from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.graph import Graph from core.workflow.nodes.if_else.entities import IfElseNodeData from core.workflow.nodes.if_else.if_else_node import IfElseNode from core.workflow.nodes.node_factory import DifyNodeFactory +from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from core.workflow.utils.condition.entities import Condition, SubCondition, SubVariableCondition from extensions.ext_database import db diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index 6189febdf5..6af4777e0e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -6,11 +6,12 @@ from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import ArrayStringVariable, StringVariable from core.workflow.conversation_variable_updater import ConversationVariableUpdater -from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool +from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode +from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from models.enums import UserFrom diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py index b842dfdb58..80071c8616 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py @@ -4,11 +4,12 @@ from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import ArrayStringVariable -from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool +from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode from core.workflow.nodes.variable_assigner.v2.enums import InputType, Operation +from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from models.enums import UserFrom diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index 66d9d3fc14..9733bf60eb 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -27,7 +27,7 @@ from core.variables.variables import ( VariableUnion, ) from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID 
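+# Note: this series relocates the runtime types out of core.workflow.entities,
+# e.g. "from core.workflow.entities import VariablePool" becomes
+# "from core.workflow.runtime import VariablePool"; the hunk below applies
+# that move for this module.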
-from core.workflow.entities import VariablePool +from core.workflow.runtime import VariablePool from core.workflow.system_variable import SystemVariable from factories.variable_factory import build_segment, segment_to_variable diff --git a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py deleted file mode 100644 index 9f8f52015b..0000000000 --- a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py +++ /dev/null @@ -1,476 +0,0 @@ -import json -from unittest.mock import MagicMock - -import pytest -from sqlalchemy.orm import Session - -from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom -from core.app.entities.queue_entities import ( - QueueNodeFailedEvent, - QueueNodeStartedEvent, - QueueNodeSucceededEvent, -) -from core.workflow.entities import ( - WorkflowExecution, - WorkflowNodeExecution, -) -from core.workflow.enums import ( - WorkflowExecutionStatus, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, - WorkflowType, -) -from core.workflow.nodes import NodeType -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository -from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.system_variable import SystemVariable -from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager -from libs.datetime_utils import naive_utc_now -from models.enums import CreatorUserRole -from models.model import AppMode -from models.workflow import Workflow, WorkflowRun - - -@pytest.fixture -def real_app_generate_entity(): - additional_features = AppAdditionalFeatures( - file_upload=None, - opening_statement=None, - suggested_questions=[], - suggested_questions_after_answer=False, - show_retrieve_source=False, - more_like_this=False, - speech_to_text=False, - text_to_speech=None, - trace_config=None, - ) - - app_config = WorkflowUIBasedAppConfig( - tenant_id="test-tenant-id", - app_id="test-app-id", - app_mode=AppMode.WORKFLOW, - additional_features=additional_features, - workflow_id="test-workflow-id", - ) - - entity = AdvancedChatAppGenerateEntity( - task_id="test-task-id", - app_config=app_config, - inputs={"query": "test query"}, - files=[], - user_id="test-user-id", - stream=False, - invoke_from=InvokeFrom.WEB_APP, - query="test query", - conversation_id="test-conversation-id", - ) - - return entity - - -@pytest.fixture -def real_workflow_system_variables(): - return SystemVariable( - query="test query", - conversation_id="test-conversation-id", - user_id="test-user-id", - app_id="test-app-id", - workflow_id="test-workflow-id", - workflow_execution_id="test-workflow-run-id", - ) - - -@pytest.fixture -def mock_node_execution_repository(): - repo = MagicMock(spec=WorkflowNodeExecutionRepository) - return repo - - -@pytest.fixture -def mock_workflow_execution_repository(): - repo = MagicMock(spec=WorkflowExecutionRepository) - return repo - - -@pytest.fixture -def real_workflow_entity(): - return CycleManagerWorkflowInfo( - workflow_id="test-workflow-id", # Matches ID used in other fixtures - workflow_type=WorkflowType.WORKFLOW, - version="1.0.0", - graph_data={ - "nodes": [ - { - "id": "node1", - "type": "chat", # NodeType is a string enum - "name": "Chat Node", - "data": {"model": "gpt-3.5-turbo", "prompt": "test prompt"}, - } 
- ], - "edges": [], - }, - ) - - -@pytest.fixture -def workflow_cycle_manager( - real_app_generate_entity, - real_workflow_system_variables, - mock_workflow_execution_repository, - mock_node_execution_repository, - real_workflow_entity, -): - return WorkflowCycleManager( - application_generate_entity=real_app_generate_entity, - workflow_system_variables=real_workflow_system_variables, - workflow_info=real_workflow_entity, - workflow_execution_repository=mock_workflow_execution_repository, - workflow_node_execution_repository=mock_node_execution_repository, - ) - - -@pytest.fixture -def mock_session(): - session = MagicMock(spec=Session) - return session - - -@pytest.fixture -def real_workflow(): - workflow = Workflow() - workflow.id = "test-workflow-id" - workflow.tenant_id = "test-tenant-id" - workflow.app_id = "test-app-id" - workflow.type = "chat" - workflow.version = "1.0" - - graph_data = {"nodes": [], "edges": []} - workflow.graph = json.dumps(graph_data) - workflow.features = json.dumps({"file_upload": {"enabled": False}}) - workflow.created_by = "test-user-id" - workflow.created_at = naive_utc_now() - workflow.updated_at = naive_utc_now() - workflow._environment_variables = "{}" - workflow._conversation_variables = "{}" - - return workflow - - -@pytest.fixture -def real_workflow_run(): - workflow_run = WorkflowRun() - workflow_run.id = "test-workflow-run-id" - workflow_run.tenant_id = "test-tenant-id" - workflow_run.app_id = "test-app-id" - workflow_run.workflow_id = "test-workflow-id" - workflow_run.type = "chat" - workflow_run.triggered_from = "app-run" - workflow_run.version = "1.0" - workflow_run.graph = json.dumps({"nodes": [], "edges": []}) - workflow_run.inputs = json.dumps({"query": "test query"}) - workflow_run.status = WorkflowExecutionStatus.RUNNING - workflow_run.outputs = json.dumps({"answer": "test answer"}) - workflow_run.created_by_role = CreatorUserRole.ACCOUNT - workflow_run.created_by = "test-user-id" - workflow_run.created_at = naive_utc_now() - - return workflow_run - - -def test_init( - workflow_cycle_manager, - real_app_generate_entity, - real_workflow_system_variables, - mock_workflow_execution_repository, - mock_node_execution_repository, -): - """Test initialization of WorkflowCycleManager""" - assert workflow_cycle_manager._application_generate_entity == real_app_generate_entity - assert workflow_cycle_manager._workflow_system_variables == real_workflow_system_variables - assert workflow_cycle_manager._workflow_execution_repository == mock_workflow_execution_repository - assert workflow_cycle_manager._workflow_node_execution_repository == mock_node_execution_repository - - -def test_handle_workflow_run_start(workflow_cycle_manager): - """Test handle_workflow_run_start method""" - # Call the method - workflow_execution = workflow_cycle_manager.handle_workflow_run_start() - - # Verify the result - assert workflow_execution.workflow_id == "test-workflow-id" - - # Verify the workflow_execution_repository.save was called - workflow_cycle_manager._workflow_execution_repository.save.assert_called_once_with(workflow_execution) - - -def test_handle_workflow_run_success(workflow_cycle_manager, mock_workflow_execution_repository): - """Test handle_workflow_run_success method""" - # Create a real WorkflowExecution - - workflow_execution = WorkflowExecution( - id_="test-workflow-run-id", - workflow_id="test-workflow-id", - workflow_type=WorkflowType.WORKFLOW, - workflow_version="1.0", - graph={"nodes": [], "edges": []}, - inputs={"query": "test query"}, - 
started_at=naive_utc_now(), - ) - - # Pre-populate the cache with the workflow execution - workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution - - # Call the method - result = workflow_cycle_manager.handle_workflow_run_success( - workflow_run_id="test-workflow-run-id", - total_tokens=100, - total_steps=5, - outputs={"answer": "test answer"}, - ) - - # Verify the result - assert result == workflow_execution - assert result.status == WorkflowExecutionStatus.SUCCEEDED - assert result.outputs == {"answer": "test answer"} - assert result.total_tokens == 100 - assert result.total_steps == 5 - assert result.finished_at is not None - - -def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execution_repository): - """Test handle_workflow_run_failed method""" - # Create a real WorkflowExecution - - workflow_execution = WorkflowExecution( - id_="test-workflow-run-id", - workflow_id="test-workflow-id", - workflow_type=WorkflowType.WORKFLOW, - workflow_version="1.0", - graph={"nodes": [], "edges": []}, - inputs={"query": "test query"}, - started_at=naive_utc_now(), - ) - - # Pre-populate the cache with the workflow execution - workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution - - # No running node executions in cache (empty cache) - - # Call the method - result = workflow_cycle_manager.handle_workflow_run_failed( - workflow_run_id="test-workflow-run-id", - total_tokens=50, - total_steps=3, - status=WorkflowExecutionStatus.FAILED, - error_message="Test error message", - ) - - # Verify the result - assert result == workflow_execution - assert result.status == WorkflowExecutionStatus.FAILED - assert result.error_message == "Test error message" - assert result.total_tokens == 50 - assert result.total_steps == 3 - assert result.finished_at is not None - - -def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execution_repository): - """Test handle_node_execution_start method""" - # Create a real WorkflowExecution - - workflow_execution = WorkflowExecution( - id_="test-workflow-execution-id", - workflow_id="test-workflow-id", - workflow_type=WorkflowType.WORKFLOW, - workflow_version="1.0", - graph={"nodes": [], "edges": []}, - inputs={"query": "test query"}, - started_at=naive_utc_now(), - ) - - # Pre-populate the cache with the workflow execution - workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution - - # Create a mock event - event = MagicMock(spec=QueueNodeStartedEvent) - event.node_execution_id = "test-node-execution-id" - event.node_id = "test-node-id" - event.node_type = NodeType.LLM - event.node_title = "Test Node" - event.predecessor_node_id = "test-predecessor-node-id" - event.node_run_index = 1 - event.parallel_mode_run_id = "test-parallel-mode-run-id" - event.in_iteration_id = "test-iteration-id" - event.in_loop_id = "test-loop-id" - - # Call the method - result = workflow_cycle_manager.handle_node_execution_start( - workflow_execution_id=workflow_execution.id_, - event=event, - ) - - # Verify the result - assert result.workflow_id == workflow_execution.workflow_id - assert result.workflow_execution_id == workflow_execution.id_ - assert result.node_execution_id == event.node_execution_id - assert result.node_id == event.node_id - assert result.node_type == event.node_type - assert result.title == event.node_title - assert result.status == WorkflowNodeExecutionStatus.RUNNING - - # Verify save was called - 
workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(result) - - -def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_workflow_execution_repository): - """Test _get_workflow_execution_or_raise_error method""" - # Create a real WorkflowExecution - - workflow_execution = WorkflowExecution( - id_="test-workflow-run-id", - workflow_id="test-workflow-id", - workflow_type=WorkflowType.WORKFLOW, - workflow_version="1.0", - graph={"nodes": [], "edges": []}, - inputs={"query": "test query"}, - started_at=naive_utc_now(), - ) - - # Pre-populate the cache with the workflow execution - workflow_cycle_manager._workflow_execution_cache["test-workflow-run-id"] = workflow_execution - - # Call the method - result = workflow_cycle_manager._get_workflow_execution_or_raise_error("test-workflow-run-id") - - # Verify the result - assert result == workflow_execution - - # Test error case - clear cache - workflow_cycle_manager._workflow_execution_cache.clear() - - # Expect an error when execution is not found - from core.app.task_pipeline.exc import WorkflowRunNotFoundError - - with pytest.raises(WorkflowRunNotFoundError): - workflow_cycle_manager._get_workflow_execution_or_raise_error("non-existent-id") - - -def test_handle_workflow_node_execution_success(workflow_cycle_manager): - """Test handle_workflow_node_execution_success method""" - # Create a mock event - event = MagicMock(spec=QueueNodeSucceededEvent) - event.node_execution_id = "test-node-execution-id" - event.inputs = {"input": "test input"} - event.process_data = {"process": "test process"} - event.outputs = {"output": "test output"} - event.execution_metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100} - event.start_at = naive_utc_now() - - # Create a real node execution - - node_execution = WorkflowNodeExecution( - id="test-node-execution-record-id", - node_execution_id="test-node-execution-id", - workflow_id="test-workflow-id", - workflow_execution_id="test-workflow-run-id", - index=1, - node_id="test-node-id", - node_type=NodeType.LLM, - title="Test Node", - created_at=naive_utc_now(), - ) - - # Pre-populate the cache with the node execution - workflow_cycle_manager._node_execution_cache["test-node-execution-id"] = node_execution - - # Call the method - result = workflow_cycle_manager.handle_workflow_node_execution_success( - event=event, - ) - - # Verify the result - assert result == node_execution - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - - # Verify save was called - workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(node_execution) - - -def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_workflow_execution_repository): - """Test handle_workflow_run_partial_success method""" - # Create a real WorkflowExecution - - workflow_execution = WorkflowExecution( - id_="test-workflow-run-id", - workflow_id="test-workflow-id", - workflow_type=WorkflowType.WORKFLOW, - workflow_version="1.0", - graph={"nodes": [], "edges": []}, - inputs={"query": "test query"}, - started_at=naive_utc_now(), - ) - - # Pre-populate the cache with the workflow execution - workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution - - # Call the method - result = workflow_cycle_manager.handle_workflow_run_partial_success( - workflow_run_id="test-workflow-run-id", - total_tokens=75, - total_steps=4, - outputs={"partial_answer": "test partial answer"}, - exceptions_count=2, - ) - - # Verify the 
result - assert result == workflow_execution - assert result.status == WorkflowExecutionStatus.PARTIAL_SUCCEEDED - assert result.outputs == {"partial_answer": "test partial answer"} - assert result.total_tokens == 75 - assert result.total_steps == 4 - assert result.exceptions_count == 2 - assert result.finished_at is not None - - -def test_handle_workflow_node_execution_failed(workflow_cycle_manager): - """Test handle_workflow_node_execution_failed method""" - # Create a mock event - event = MagicMock(spec=QueueNodeFailedEvent) - event.node_execution_id = "test-node-execution-id" - event.inputs = {"input": "test input"} - event.process_data = {"process": "test process"} - event.outputs = {"output": "test output"} - event.execution_metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100} - event.start_at = naive_utc_now() - event.error = "Test error message" - - # Create a real node execution - - node_execution = WorkflowNodeExecution( - id="test-node-execution-record-id", - node_execution_id="test-node-execution-id", - workflow_id="test-workflow-id", - workflow_execution_id="test-workflow-run-id", - index=1, - node_id="test-node-id", - node_type=NodeType.LLM, - title="Test Node", - created_at=naive_utc_now(), - ) - - # Pre-populate the cache with the node execution - workflow_cycle_manager._node_execution_cache["test-node-execution-id"] = node_execution - - # Call the method - result = workflow_cycle_manager.handle_workflow_node_execution_failed( - event=event, - ) - - # Verify the result - assert result == node_execution - assert result.status == WorkflowNodeExecutionStatus.FAILED - assert result.error == "Test error message" - - # Verify save was called - workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(node_execution) diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry.py b/api/tests/unit_tests/core/workflow/test_workflow_entry.py index 324f58abf6..75de5c455f 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry.py @@ -7,7 +7,7 @@ from core.workflow.constants import ( CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, ) -from core.workflow.entities.variable_pool import VariablePool +from core.workflow.runtime import VariablePool from core.workflow.system_variable import SystemVariable from core.workflow.workflow_entry import WorkflowEntry diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py index c3d59aaf3f..bc55d3fccf 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py @@ -3,8 +3,8 @@ from unittest.mock import MagicMock, patch from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities import GraphRuntimeState, VariablePool from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel +from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.workflow_entry import WorkflowEntry from models.enums import UserFrom diff --git a/dev/start-worker b/dev/start-worker index a2af04c01c..a7f16b853f 100755 --- a/dev/start-worker +++ b/dev/start-worker @@ -5,7 +5,6 @@ set -x SCRIPT_DIR="$(dirname "$(realpath "$0")")" cd "$SCRIPT_DIR/.." 
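+# The queue list below additionally subscribes this worker to the
+# priority_pipeline and pipeline queues.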
- uv --directory api run \ - celery -A app.celery worker \ - -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation + celery -A app.celery worker \ + -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline From 3f9f02b9e7e9e6174ba71dae7150ef8696fe46e5 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 20 Oct 2025 09:36:41 +0800 Subject: [PATCH 39/46] docs: mention backend lint gate in AGENTS (#27102) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- AGENTS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/AGENTS.md b/AGENTS.md index 5859cd1bd9..2ef7931efc 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -14,7 +14,7 @@ The codebase is split into: - Run backend CLI commands through `uv run --project api `. -- Backend QA gate requires passing `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh` before review. +- Before submission, all backend modifications must pass local checks: `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`. - Use Makefile targets for linting and formatting; `make lint` and `make type-check` cover the required checks. From f87db2652b3c935bc4dab162be671aa4ea80add6 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 20 Oct 2025 09:37:02 +0800 Subject: [PATCH 40/46] chore(deps): bump @lexical/selection from 0.36.2 to 0.37.0 in /web (#27108) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- web/package.json | 2 +- web/pnpm-lock.yaml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/web/package.json b/web/package.json index 1721e54d73..4c42f03ce8 100644 --- a/web/package.json +++ b/web/package.json @@ -54,7 +54,7 @@ "@lexical/link": "^0.36.2", "@lexical/list": "^0.36.2", "@lexical/react": "^0.36.2", - "@lexical/selection": "^0.36.2", + "@lexical/selection": "^0.37.0", "@lexical/text": "^0.36.2", "@lexical/utils": "^0.37.0", "@monaco-editor/react": "^4.6.0", diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index 4f75b6e93e..f6efc17a26 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -80,8 +80,8 @@ importers: specifier: ^0.36.2 version: 0.36.2(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(yjs@13.6.27) '@lexical/selection': - specifier: ^0.36.2 - version: 0.36.2 + specifier: ^0.37.0 + version: 0.37.0 '@lexical/text': specifier: ^0.36.2 version: 0.36.2 From fe2ac66a52c4f2ed916e6fe22184cb08bac4f408 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 20 Oct 2025 09:37:10 +0800 Subject: [PATCH 41/46] chore(deps): bump html-to-image from 1.11.11 to 1.11.13 in /web (#27109) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- web/package.json | 2 +- web/pnpm-lock.yaml | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/web/package.json b/web/package.json index 4c42f03ce8..09b6630fa6 100644 --- a/web/package.json +++ b/web/package.json @@ -81,7 +81,7 @@ "elkjs": "^0.9.3", "emoji-mart": "^5.5.2", "fast-deep-equal": "^3.1.3", - "html-to-image": "1.11.11", + "html-to-image": "1.11.13", 
"i18next": "^23.16.4", "i18next-resources-to-backend": "^1.2.1", "immer": "^9.0.19", diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index f6efc17a26..ca5c8a744e 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -161,8 +161,8 @@ importers: specifier: ^3.1.3 version: 3.1.3 html-to-image: - specifier: 1.11.11 - version: 1.11.11 + specifier: 1.11.13 + version: 1.11.13 i18next: specifier: ^23.16.4 version: 23.16.8 @@ -5671,8 +5671,8 @@ packages: html-parse-stringify@3.0.1: resolution: {integrity: sha512-KknJ50kTInJ7qIScF3jeaFRpMpE8/lfiTdzf/twXyPBLAGrLRTmkz3AdTnKeh40X8k9L2fdYwEp/42WGXIRGcg==} - html-to-image@1.11.11: - resolution: {integrity: sha512-9gux8QhvjRO/erSnDPv28noDZcPZmYE7e1vFsBLKLlRlKDSqNJYebj6Qz1TGd5lsRV+X+xYyjCKjuZdABinWjA==} + html-to-image@1.11.13: + resolution: {integrity: sha512-cuOPoI7WApyhBElTTb9oqsawRvZ0rHhaHwghRLlTuffoD1B2aDemlCruLeZrUIIdvG7gs9xeELEPm6PhuASqrg==} html-url-attributes@3.0.1: resolution: {integrity: sha512-ol6UPyBWqsrO6EJySPz2O7ZSr856WDrEzM5zMqp+FJJLGMW35cLYmmZnl0vztAZxRUoNZJFTCohfjuIJ8I4QBQ==} @@ -14983,7 +14983,7 @@ snapshots: dependencies: void-elements: 3.1.0 - html-to-image@1.11.11: {} + html-to-image@1.11.13: {} html-url-attributes@3.0.1: {} From ab1059134d5d32d551524c013d86bf5f6538778c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 20 Oct 2025 10:12:16 +0800 Subject: [PATCH 42/46] chore(deps): bump pydantic-settings from 2.9.1 to 2.11.0 in /api (#27114) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- api/pyproject.toml | 2 +- api/uv.lock | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/api/pyproject.toml b/api/pyproject.toml index 74e6782d83..040d9658b3 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -63,7 +63,7 @@ dependencies = [ "pycryptodome==3.19.1", "pydantic~=2.11.4", "pydantic-extra-types~=2.10.3", - "pydantic-settings~=2.9.1", + "pydantic-settings~=2.11.0", "pyjwt~=2.10.1", "pypdfium2==4.30.0", "python-docx~=1.1.0", diff --git a/api/uv.lock b/api/uv.lock index 8f28fa36a8..e7e51acedf 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1545,7 +1545,7 @@ requires-dist = [ { name = "pycryptodome", specifier = "==3.19.1" }, { name = "pydantic", specifier = "~=2.11.4" }, { name = "pydantic-extra-types", specifier = "~=2.10.3" }, - { name = "pydantic-settings", specifier = "~=2.9.1" }, + { name = "pydantic-settings", specifier = "~=2.11.0" }, { name = "pyjwt", specifier = "~=2.10.1" }, { name = "pypdfium2", specifier = "==4.30.0" }, { name = "python-docx", specifier = "~=1.1.0" }, @@ -4754,16 +4754,16 @@ wheels = [ [[package]] name = "pydantic-settings" -version = "2.9.1" +version = "2.11.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pydantic" }, { name = "python-dotenv" }, { name = "typing-inspection" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/67/1d/42628a2c33e93f8e9acbde0d5d735fa0850f3e6a2f8cb1eb6c40b9a732ac/pydantic_settings-2.9.1.tar.gz", hash = "sha256:c509bf79d27563add44e8446233359004ed85066cd096d8b510f715e6ef5d268", size = 163234, upload-time = "2025-04-18T16:44:48.265Z" } +sdist = { url = "https://files.pythonhosted.org/packages/20/c5/dbbc27b814c71676593d1c3f718e6cd7d4f00652cefa24b75f7aa3efb25e/pydantic_settings-2.11.0.tar.gz", hash = "sha256:d0e87a1c7d33593beb7194adb8470fc426e95ba02af83a0f23474a04c9a08180", size = 188394, upload-time = "2025-09-24T14:19:11.764Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/b6/5f/d6d641b490fd3ec2c4c13b4244d68deea3a1b970a97be64f34fb5504ff72/pydantic_settings-2.9.1-py3-none-any.whl", hash = "sha256:59b4f431b1defb26fe620c71a7d3968a710d719f5f4cdbbdb7926edeb770f6ef", size = 44356, upload-time = "2025-04-18T16:44:46.617Z" }, + { url = "https://files.pythonhosted.org/packages/83/d6/887a1ff844e64aa823fb4905978d882a633cfe295c32eacad582b78a7d8b/pydantic_settings-2.11.0-py3-none-any.whl", hash = "sha256:fe2cea3413b9530d10f3a5875adffb17ada5c1e1bab0b2885546d7310415207c", size = 48608, upload-time = "2025-09-24T14:19:10.015Z" }, ] [[package]] From 5579521ffc3c1da88776405a9f15db6e0abe5084 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 20 Oct 2025 10:12:30 +0800 Subject: [PATCH 43/46] chore(deps-dev): bump cross-env from 7.0.3 to 10.1.0 in /web (#27112) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- web/package.json | 2 +- web/pnpm-lock.yaml | 18 ++++++++++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/web/package.json b/web/package.json index 09b6630fa6..cec4124d29 100644 --- a/web/package.json +++ b/web/package.json @@ -180,7 +180,7 @@ "babel-loader": "^10.0.0", "bing-translate-api": "^4.0.2", "code-inspector-plugin": "1.2.9", - "cross-env": "^7.0.3", + "cross-env": "^10.1.0", "eslint": "^9.35.0", "eslint-plugin-oxlint": "^1.6.0", "eslint-plugin-react-hooks": "^5.1.0", diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index ca5c8a744e..2105d34b89 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -453,8 +453,8 @@ importers: specifier: 1.2.9 version: 1.2.9 cross-env: - specifier: ^7.0.3 - version: 7.0.3 + specifier: ^10.1.0 + version: 10.1.0 eslint: specifier: ^9.35.0 version: 9.35.0(jiti@2.6.1) @@ -1339,6 +1339,9 @@ packages: '@emoji-mart/data@1.2.1': resolution: {integrity: sha512-no2pQMWiBy6gpBEiqGeU77/bFejDqUTRY7KX+0+iur13op3bqUsXdnwoZs6Xb1zbv0gAj5VvS1PWoUUckSr5Dw==} + '@epic-web/invariant@1.0.0': + resolution: {integrity: sha512-lrTPqgvfFQtR/eY/qkIzp98OGdNJu0m5ji3q/nJI8v3SXkRKEnWiOxMmbvcSoAIzv/cGiuvRy57k4suKQSAdwA==} + '@es-joy/jsdoccomment@0.50.2': resolution: {integrity: sha512-YAdE/IJSpwbOTiaURNCKECdAwqrJuFiZhylmesBcIRawtYKnBR2wxPhoIewMg+Yu+QuYvHfJNReWpoxGBKOChA==} engines: {node: '>=18'} @@ -4468,9 +4471,9 @@ packages: create-require@1.1.1: resolution: {integrity: sha512-dcKFX3jn0MpIaXjisoRvexIJVEKzaq7z2rZKxf+MSr9TkdmHmsU4m2lcLojrj/FHl8mk5VxMmYA+ftRkP/3oKQ==} - cross-env@7.0.3: - resolution: {integrity: sha512-+/HKd6EgcQCJGh2PSjZuUitQBQynKor4wrFbRg4DtAgS1aWO+gU52xpH7M9ScGgXSYmAVS9bIJ8EzuaGw0oNAw==} - engines: {node: '>=10.14', npm: '>=6', yarn: '>=1'} + cross-env@10.1.0: + resolution: {integrity: sha512-GsYosgnACZTADcmEyJctkJIoqAhHjttw7RsFrVoJNXbsWWqaq6Ym+7kZjq6mS45O0jij6vtiReppKQEtqWy6Dw==} + engines: {node: '>=20'} hasBin: true cross-spawn@7.0.6: @@ -9816,6 +9819,8 @@ snapshots: '@emoji-mart/data@1.2.1': {} + '@epic-web/invariant@1.0.0': {} + '@es-joy/jsdoccomment@0.50.2': dependencies: '@types/estree': 1.0.8 @@ -13454,8 +13459,9 @@ snapshots: create-require@1.1.1: optional: true - cross-env@7.0.3: + cross-env@10.1.0: dependencies: + '@epic-web/invariant': 1.0.0 cross-spawn: 7.0.6 cross-spawn@7.0.6: From 7e9be4d3d96c1b5a75519511661d038fc115cb60 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 20 Oct 2025 10:16:35 +0800 Subject: [PATCH 44/46] chore(deps): bump immer from 9.0.21 to 10.1.3 in 
/web (#27113) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- web/package.json | 2 +- web/pnpm-lock.yaml | 72 +++++++++++++++++++++++----------------------- 2 files changed, 37 insertions(+), 37 deletions(-) diff --git a/web/package.json b/web/package.json index cec4124d29..f2c963a3eb 100644 --- a/web/package.json +++ b/web/package.json @@ -84,7 +84,7 @@ "html-to-image": "1.11.13", "i18next": "^23.16.4", "i18next-resources-to-backend": "^1.2.1", - "immer": "^9.0.19", + "immer": "^10.1.3", "js-audio-recorder": "^1.0.7", "js-cookie": "^3.0.5", "jsonschema": "^1.5.0", diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index 2105d34b89..0f1495e474 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -170,8 +170,8 @@ importers: specifier: ^1.2.1 version: 1.2.1 immer: - specifier: ^9.0.19 - version: 9.0.21 + specifier: ^10.1.3 + version: 10.1.3 js-audio-recorder: specifier: ^1.0.7 version: 1.0.7 @@ -279,7 +279,7 @@ importers: version: 1.8.11(react-dom@19.1.1(react@19.1.1))(react@19.1.1) reactflow: specifier: ^11.11.3 - version: 11.11.4(@types/react@19.1.11)(immer@9.0.21)(react-dom@19.1.1(react@19.1.1))(react@19.1.1) + version: 11.11.4(@types/react@19.1.11)(immer@10.1.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1) rehype-katex: specifier: ^7.0.1 version: 7.0.1 @@ -327,10 +327,10 @@ importers: version: 3.25.76 zundo: specifier: ^2.1.0 - version: 2.3.0(zustand@4.5.7(@types/react@19.1.11)(immer@9.0.21)(react@19.1.1)) + version: 2.3.0(zustand@4.5.7(@types/react@19.1.11)(immer@10.1.3)(react@19.1.1)) zustand: specifier: ^4.5.2 - version: 4.5.7(@types/react@19.1.11)(immer@9.0.21)(react@19.1.1) + version: 4.5.7(@types/react@19.1.11)(immer@10.1.3)(react@19.1.1) devDependencies: '@antfu/eslint-config': specifier: ^5.0.0 @@ -5760,8 +5760,8 @@ packages: engines: {node: '>=16.x'} hasBin: true - immer@9.0.21: - resolution: {integrity: sha512-bc4NBHqOqSfRW7POMkHd51LvClaeMXpm8dx0e8oE2GORbq5aRK7Bxl4FyzVLdGtLmvLKL7BTDBG5ACQm4HWjTA==} + immer@10.1.3: + resolution: {integrity: sha512-tmjF/k8QDKydUlm3mZU+tjM6zeq9/fFpPqH9SzWmBnVVKsPBg/V66qsMwb3/Bo90cgUN+ghdVBess+hPsxUyRw==} immutable@5.1.3: resolution: {integrity: sha512-+chQdDfvscSF1SJqv2gn4SRO2ZyS3xL3r7IW/wWEEzrzLisnOlKiQu5ytC/BVNcS15C39WT2Hg/bjKjDMcu+zg==} @@ -11366,29 +11366,29 @@ snapshots: dependencies: react: 19.1.1 - '@reactflow/background@11.3.14(@types/react@19.1.11)(immer@9.0.21)(react-dom@19.1.1(react@19.1.1))(react@19.1.1)': + '@reactflow/background@11.3.14(@types/react@19.1.11)(immer@10.1.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1)': dependencies: - '@reactflow/core': 11.11.4(@types/react@19.1.11)(immer@9.0.21)(react-dom@19.1.1(react@19.1.1))(react@19.1.1) + '@reactflow/core': 11.11.4(@types/react@19.1.11)(immer@10.1.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1) classcat: 5.0.5 react: 19.1.1 react-dom: 19.1.1(react@19.1.1) - zustand: 4.5.7(@types/react@19.1.11)(immer@9.0.21)(react@19.1.1) + zustand: 4.5.7(@types/react@19.1.11)(immer@10.1.3)(react@19.1.1) transitivePeerDependencies: - '@types/react' - immer - '@reactflow/controls@11.2.14(@types/react@19.1.11)(immer@9.0.21)(react-dom@19.1.1(react@19.1.1))(react@19.1.1)': + '@reactflow/controls@11.2.14(@types/react@19.1.11)(immer@10.1.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1)': dependencies: - '@reactflow/core': 11.11.4(@types/react@19.1.11)(immer@9.0.21)(react-dom@19.1.1(react@19.1.1))(react@19.1.1) + '@reactflow/core': 11.11.4(@types/react@19.1.11)(immer@10.1.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1) 
classcat: 5.0.5 react: 19.1.1 react-dom: 19.1.1(react@19.1.1) - zustand: 4.5.7(@types/react@19.1.11)(immer@9.0.21)(react@19.1.1) + zustand: 4.5.7(@types/react@19.1.11)(immer@10.1.3)(react@19.1.1) transitivePeerDependencies: - '@types/react' - immer - '@reactflow/core@11.11.4(@types/react@19.1.11)(immer@9.0.21)(react-dom@19.1.1(react@19.1.1))(react@19.1.1)': + '@reactflow/core@11.11.4(@types/react@19.1.11)(immer@10.1.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1)': dependencies: '@types/d3': 7.4.3 '@types/d3-drag': 3.0.7 @@ -11400,14 +11400,14 @@ snapshots: d3-zoom: 3.0.0 react: 19.1.1 react-dom: 19.1.1(react@19.1.1) - zustand: 4.5.7(@types/react@19.1.11)(immer@9.0.21)(react@19.1.1) + zustand: 4.5.7(@types/react@19.1.11)(immer@10.1.3)(react@19.1.1) transitivePeerDependencies: - '@types/react' - immer - '@reactflow/minimap@11.7.14(@types/react@19.1.11)(immer@9.0.21)(react-dom@19.1.1(react@19.1.1))(react@19.1.1)': + '@reactflow/minimap@11.7.14(@types/react@19.1.11)(immer@10.1.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1)': dependencies: - '@reactflow/core': 11.11.4(@types/react@19.1.11)(immer@9.0.21)(react-dom@19.1.1(react@19.1.1))(react@19.1.1) + '@reactflow/core': 11.11.4(@types/react@19.1.11)(immer@10.1.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1) '@types/d3-selection': 3.0.11 '@types/d3-zoom': 3.0.8 classcat: 5.0.5 @@ -11415,31 +11415,31 @@ snapshots: d3-zoom: 3.0.0 react: 19.1.1 react-dom: 19.1.1(react@19.1.1) - zustand: 4.5.7(@types/react@19.1.11)(immer@9.0.21)(react@19.1.1) + zustand: 4.5.7(@types/react@19.1.11)(immer@10.1.3)(react@19.1.1) transitivePeerDependencies: - '@types/react' - immer - '@reactflow/node-resizer@2.2.14(@types/react@19.1.11)(immer@9.0.21)(react-dom@19.1.1(react@19.1.1))(react@19.1.1)': + '@reactflow/node-resizer@2.2.14(@types/react@19.1.11)(immer@10.1.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1)': dependencies: - '@reactflow/core': 11.11.4(@types/react@19.1.11)(immer@9.0.21)(react-dom@19.1.1(react@19.1.1))(react@19.1.1) + '@reactflow/core': 11.11.4(@types/react@19.1.11)(immer@10.1.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1) classcat: 5.0.5 d3-drag: 3.0.0 d3-selection: 3.0.0 react: 19.1.1 react-dom: 19.1.1(react@19.1.1) - zustand: 4.5.7(@types/react@19.1.11)(immer@9.0.21)(react@19.1.1) + zustand: 4.5.7(@types/react@19.1.11)(immer@10.1.3)(react@19.1.1) transitivePeerDependencies: - '@types/react' - immer - '@reactflow/node-toolbar@1.3.14(@types/react@19.1.11)(immer@9.0.21)(react-dom@19.1.1(react@19.1.1))(react@19.1.1)': + '@reactflow/node-toolbar@1.3.14(@types/react@19.1.11)(immer@10.1.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1)': dependencies: - '@reactflow/core': 11.11.4(@types/react@19.1.11)(immer@9.0.21)(react-dom@19.1.1(react@19.1.1))(react@19.1.1) + '@reactflow/core': 11.11.4(@types/react@19.1.11)(immer@10.1.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1) classcat: 5.0.5 react: 19.1.1 react-dom: 19.1.1(react@19.1.1) - zustand: 4.5.7(@types/react@19.1.11)(immer@9.0.21)(react@19.1.1) + zustand: 4.5.7(@types/react@19.1.11)(immer@10.1.3)(react@19.1.1) transitivePeerDependencies: - '@types/react' - immer @@ -15063,7 +15063,7 @@ snapshots: dependencies: queue: 6.0.2 - immer@9.0.21: {} + immer@10.1.3: {} immutable@5.1.3: {} @@ -17303,14 +17303,14 @@ snapshots: react@19.1.1: {} - reactflow@11.11.4(@types/react@19.1.11)(immer@9.0.21)(react-dom@19.1.1(react@19.1.1))(react@19.1.1): + reactflow@11.11.4(@types/react@19.1.11)(immer@10.1.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1): dependencies: - '@reactflow/background': 
11.3.14(@types/react@19.1.11)(immer@9.0.21)(react-dom@19.1.1(react@19.1.1))(react@19.1.1) - '@reactflow/controls': 11.2.14(@types/react@19.1.11)(immer@9.0.21)(react-dom@19.1.1(react@19.1.1))(react@19.1.1) - '@reactflow/core': 11.11.4(@types/react@19.1.11)(immer@9.0.21)(react-dom@19.1.1(react@19.1.1))(react@19.1.1) - '@reactflow/minimap': 11.7.14(@types/react@19.1.11)(immer@9.0.21)(react-dom@19.1.1(react@19.1.1))(react@19.1.1) - '@reactflow/node-resizer': 2.2.14(@types/react@19.1.11)(immer@9.0.21)(react-dom@19.1.1(react@19.1.1))(react@19.1.1) - '@reactflow/node-toolbar': 1.3.14(@types/react@19.1.11)(immer@9.0.21)(react-dom@19.1.1(react@19.1.1))(react@19.1.1) + '@reactflow/background': 11.3.14(@types/react@19.1.11)(immer@10.1.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1) + '@reactflow/controls': 11.2.14(@types/react@19.1.11)(immer@10.1.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1) + '@reactflow/core': 11.11.4(@types/react@19.1.11)(immer@10.1.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1) + '@reactflow/minimap': 11.7.14(@types/react@19.1.11)(immer@10.1.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1) + '@reactflow/node-resizer': 2.2.14(@types/react@19.1.11)(immer@10.1.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1) + '@reactflow/node-toolbar': 1.3.14(@types/react@19.1.11)(immer@10.1.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1) react: 19.1.1 react-dom: 19.1.1(react@19.1.1) transitivePeerDependencies: @@ -18742,16 +18742,16 @@ snapshots: dependencies: tslib: 2.3.0 - zundo@2.3.0(zustand@4.5.7(@types/react@19.1.11)(immer@9.0.21)(react@19.1.1)): + zundo@2.3.0(zustand@4.5.7(@types/react@19.1.11)(immer@10.1.3)(react@19.1.1)): dependencies: - zustand: 4.5.7(@types/react@19.1.11)(immer@9.0.21)(react@19.1.1) + zustand: 4.5.7(@types/react@19.1.11)(immer@10.1.3)(react@19.1.1) - zustand@4.5.7(@types/react@19.1.11)(immer@9.0.21)(react@19.1.1): + zustand@4.5.7(@types/react@19.1.11)(immer@10.1.3)(react@19.1.1): dependencies: use-sync-external-store: 1.5.0(react@19.1.1) optionalDependencies: '@types/react': 19.1.11 - immer: 9.0.21 + immer: 10.1.3 react: 19.1.1 zwitch@2.0.4: {} From dc1a38088850821ba4d063410d3caca5fd00ceb0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Mon, 20 Oct 2025 10:17:17 +0800 Subject: [PATCH 45/46] chore: improve storybook (#27111) --- web/.storybook/main.ts | 18 +- web/.storybook/preview.tsx | 29 +- .../components/base/button/index.stories.tsx | 9 +- .../base/chat/chat/answer/index.stories.tsx | 2 +- .../base/chat/chat/question.stories.tsx | 2 +- .../components/base/confirm/index.stories.tsx | 199 +++ web/package.json | 23 +- web/pnpm-lock.yaml | 1454 ++++++++--------- 8 files changed, 974 insertions(+), 762 deletions(-) create mode 100644 web/app/components/base/confirm/index.stories.tsx diff --git a/web/.storybook/main.ts b/web/.storybook/main.ts index fecf774e98..0605c71346 100644 --- a/web/.storybook/main.ts +++ b/web/.storybook/main.ts @@ -1,19 +1,29 @@ import type { StorybookConfig } from '@storybook/nextjs' const config: StorybookConfig = { - // stories: ['../stories/**/*.mdx', '../stories/**/*.stories.@(js|jsx|mjs|ts|tsx)'], stories: ['../app/components/**/*.stories.@(js|jsx|mjs|ts|tsx)'], addons: [ '@storybook/addon-onboarding', '@storybook/addon-links', - '@storybook/addon-essentials', + '@storybook/addon-docs', '@chromatic-com/storybook', - '@storybook/addon-interactions', ], framework: { name: '@storybook/nextjs', - options: {}, + options: { + builder: { + useSWC: true, + lazyCompilation: false, + }, + nextConfigPath: 
undefined,
+    },
   },
   staticDirs: ['../public'],
+  core: {
+    disableWhatsNewNotifications: true,
+  },
+  docs: {
+    defaultName: 'Documentation',
+  },
 }
 
 export default config
diff --git a/web/.storybook/preview.tsx b/web/.storybook/preview.tsx
index 55328602f9..1f5726de34 100644
--- a/web/.storybook/preview.tsx
+++ b/web/.storybook/preview.tsx
@@ -1,12 +1,21 @@
-import React from 'react'
 import type { Preview } from '@storybook/react'
 import { withThemeByDataAttribute } from '@storybook/addon-themes'
-import I18nServer from '../app/components/i18n-server'
+import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
+import I18N from '../app/components/i18n'
+import { ToastProvider } from '../app/components/base/toast'
 import '../app/styles/globals.css'
 import '../app/styles/markdown.scss'
 import './storybook.css'
 
+const queryClient = new QueryClient({
+  defaultOptions: {
+    queries: {
+      refetchOnWindowFocus: false,
+    },
+  },
+})
+
 export const decorators = [
   withThemeByDataAttribute({
     themes: {
@@ -17,9 +26,15 @@ export const decorators = [
     attributeName: 'data-theme',
   }),
   (Story) => {
-    return <I18nServer>
-      <Story />
-    </I18nServer>
+    return (
+      <QueryClientProvider client={queryClient}>
+        <I18N>
+          <ToastProvider>
+            <Story />
+          </ToastProvider>
+        </I18N>
+      </QueryClientProvider>
+    )
   },
 ]
 
@@ -31,7 +46,11 @@ const preview: Preview = {
       date: /Date$/i,
     },
   },
+    docs: {
+      toc: true,
+    },
   },
+  tags: ['autodocs'],
 }
 
 export default preview
diff --git a/web/app/components/base/button/index.stories.tsx b/web/app/components/base/button/index.stories.tsx
index c1b18f1e50..e51b928e5e 100644
--- a/web/app/components/base/button/index.stories.tsx
+++ b/web/app/components/base/button/index.stories.tsx
@@ -1,5 +1,4 @@
-import type { Meta, StoryObj } from '@storybook/react'
-import { fn } from '@storybook/test'
+import type { Meta, StoryObj } from '@storybook/nextjs'
 import { RocketLaunchIcon } from '@heroicons/react/20/solid'
 
 import { Button } from '.'
@@ -20,8 +19,7 @@ const meta = {
   },
   args: {
     variant: 'ghost',
-    onClick: fn(),
-    children: 'adsf',
+    children: 'Button',
   },
 } satisfies Meta<typeof Button>
 
 export default meta
 type Story = StoryObj<typeof meta>
 
 export const Default: Story = {
   args: {
     variant: 'primary',
     loading: false,
     children: 'Primary Button',
+    styleCss: {},
+    spinnerClassName: '',
+    destructive: false,
   },
 }
diff --git a/web/app/components/base/chat/chat/answer/index.stories.tsx b/web/app/components/base/chat/chat/answer/index.stories.tsx
index 18bc129994..1f45844ec4 100644
--- a/web/app/components/base/chat/chat/answer/index.stories.tsx
+++ b/web/app/components/base/chat/chat/answer/index.stories.tsx
@@ -1,4 +1,4 @@
-import type { Meta, StoryObj } from '@storybook/react'
+import type { Meta, StoryObj } from '@storybook/nextjs'
 import type { ChatItem } from '../../types'
 
 import { mockedWorkflowProcess } from './__mocks__/workflowProcess'
diff --git a/web/app/components/base/chat/chat/question.stories.tsx b/web/app/components/base/chat/chat/question.stories.tsx
index 9c0eb8cad8..6474add9df 100644
--- a/web/app/components/base/chat/chat/question.stories.tsx
+++ b/web/app/components/base/chat/chat/question.stories.tsx
@@ -1,4 +1,4 @@
-import type { Meta, StoryObj } from '@storybook/react'
+import type { Meta, StoryObj } from '@storybook/nextjs'
 import type { ChatItem } from '../types'
 
 import Question from './question'
diff --git a/web/app/components/base/confirm/index.stories.tsx b/web/app/components/base/confirm/index.stories.tsx
new file mode 100644
index 0000000000..dfbe00f293
--- /dev/null
+++ b/web/app/components/base/confirm/index.stories.tsx
@@ -0,0 +1,199 @@
+import type { Meta, StoryObj } from '@storybook/nextjs'
+import { useState } from 'react'
+import Confirm from '.'
+import Button from '../button'
+
+const meta = {
+  title: 'Base/Confirm',
+  component: Confirm,
+  parameters: {
+    layout: 'centered',
+    docs: {
+      description: {
+        component: 'Confirmation dialog component that supports warning and info types, with customizable button text and behavior.',
+      },
+    },
+  },
+  tags: ['autodocs'],
+  argTypes: {
+    type: {
+      control: 'select',
+      options: ['info', 'warning'],
+      description: 'Dialog type',
+    },
+    isShow: {
+      control: 'boolean',
+      description: 'Whether to show the dialog',
+    },
+    title: {
+      control: 'text',
+      description: 'Dialog title',
+    },
+    content: {
+      control: 'text',
+      description: 'Dialog content',
+    },
+    confirmText: {
+      control: 'text',
+      description: 'Confirm button text',
+    },
+    cancelText: {
+      control: 'text',
+      description: 'Cancel button text',
+    },
+    isLoading: {
+      control: 'boolean',
+      description: 'Confirm button loading state',
+    },
+    isDisabled: {
+      control: 'boolean',
+      description: 'Confirm button disabled state',
+    },
+    showConfirm: {
+      control: 'boolean',
+      description: 'Whether to show confirm button',
+    },
+    showCancel: {
+      control: 'boolean',
+      description: 'Whether to show cancel button',
+    },
+    maskClosable: {
+      control: 'boolean',
+      description: 'Whether clicking mask closes dialog',
+    },
+  },
+} satisfies Meta<typeof Confirm>
+
+export default meta
+type Story = StoryObj<typeof meta>
+
+// Interactive demo wrapper
+const ConfirmDemo = (args: any) => {
+  const [isShow, setIsShow] = useState(false)
+
+  return (
+    <div>
+      <Button onClick={() => setIsShow(true)}>
+        Open Dialog
+      </Button>
+      <Confirm
+        {...args}
+        isShow={isShow}
+        onConfirm={() => {
+          console.log('✅ User clicked confirm')
+          setIsShow(false)
+        }}
+        onCancel={() => {
+          console.log('❌ User clicked cancel')
+          setIsShow(false)
+        }}
+      />
+    </div>
+  )
+}
+
+// Basic warning dialog - Delete action
+export const WarningDialog: Story = {
+  render: args => <ConfirmDemo {...args} />,
+  args: {
+    type: 'warning',
+    title: 'Delete Confirmation',
+    content: 'Are you sure you want to delete this project? This action cannot be undone.',
+  },
+}
+
+// Info dialog
+export const InfoDialog: Story = {
+  render: args => <ConfirmDemo {...args} />,
+  args: {
+    type: 'info',
+    title: 'Notice',
+    content: 'Your changes have been saved. Do you want to proceed to the next step?',
+  },
+}
+
+// Custom button text
+export const CustomButtonText: Story = {
+  render: args => <ConfirmDemo {...args} />,
+  args: {
+    type: 'warning',
+    title: 'Exit Editor',
+    content: 'You have unsaved changes. Are you sure you want to exit?',
+    confirmText: 'Discard Changes',
+    cancelText: 'Continue Editing',
+  },
+}
+
+// Loading state
+export const LoadingState: Story = {
+  render: args => <ConfirmDemo {...args} />,
+  args: {
+    type: 'warning',
+    title: 'Deleting...',
+    content: 'Please wait while we delete the file...',
+    isLoading: true,
+  },
+}
+
+// Disabled state
+export const DisabledState: Story = {
+  render: args => <ConfirmDemo {...args} />,
+  args: {
+    type: 'info',
+    title: 'Verification Required',
+    content: 'Please complete email verification before proceeding.',
+    isDisabled: true,
+  },
+}
+
+// Alert style - Confirm button only
+export const AlertStyle: Story = {
+  render: args => <ConfirmDemo {...args} />,
+  args: {
+    type: 'info',
+    title: 'Success',
+    content: 'Your settings have been updated!',
+    showCancel: false,
+    confirmText: 'Got it',
+  },
+}
+
+// Dangerous action - Long content
+export const DangerousAction: Story = {
+  render: args => <ConfirmDemo {...args} />,
+  args: {
+    type: 'warning',
+    title: 'Permanently Delete Account',
+    content: 'This action will permanently delete your account and all associated data, including: all projects and files, collaboration history, and personal settings. This action cannot be reversed!',
+    confirmText: 'Delete My Account',
+    cancelText: 'Keep My Account',
+  },
+}
+
+// Non-closable mask
+export const NotMaskClosable: Story = {
+  render: args => <ConfirmDemo {...args} />,
+  args: {
+    type: 'warning',
+    title: 'Important Action',
+    content: 'This action requires your explicit choice.
Clicking outside will not close this dialog.', + maskClosable: false, + }, +} + +// Full feature demo - Playground +export const Playground: Story = { + render: args => , + args: { + type: 'warning', + title: 'This is a title', + content: 'This is the dialog content text...', + confirmText: undefined, + cancelText: undefined, + isLoading: false, + isDisabled: false, + showConfirm: true, + showCancel: true, + maskClosable: true, + }, +} diff --git a/web/package.json b/web/package.json index f2c963a3eb..25fc27807a 100644 --- a/web/package.json +++ b/web/package.json @@ -142,7 +142,7 @@ "devDependencies": { "@antfu/eslint-config": "^5.0.0", "@babel/core": "^7.28.3", - "@chromatic-com/storybook": "^3.1.0", + "@chromatic-com/storybook": "^4.1.1", "@eslint-react/eslint-plugin": "^1.15.0", "@happy-dom/jest-environment": "^20.0.2", "@mdx-js/loader": "^3.1.0", @@ -151,14 +151,12 @@ "@next/eslint-plugin-next": "15.5.4", "@next/mdx": "15.5.4", "@rgrove/parse-xml": "^4.1.0", - "@storybook/addon-essentials": "8.5.0", - "@storybook/addon-interactions": "8.5.0", - "@storybook/addon-links": "8.5.0", - "@storybook/addon-onboarding": "8.5.0", - "@storybook/addon-themes": "8.5.0", - "@storybook/nextjs": "8.5.0", - "@storybook/react": "8.5.0", - "@storybook/test": "8.5.0", + "@storybook/addon-docs": "9.1.13", + "@storybook/addon-links": "9.1.13", + "@storybook/addon-onboarding": "9.1.13", + "@storybook/addon-themes": "9.1.13", + "@storybook/nextjs": "9.1.13", + "@storybook/react": "9.1.13", "@testing-library/dom": "^10.4.0", "@testing-library/jest-dom": "^6.8.0", "@testing-library/react": "^16.0.1", @@ -186,7 +184,7 @@ "eslint-plugin-react-hooks": "^5.1.0", "eslint-plugin-react-refresh": "^0.4.19", "eslint-plugin-sonarjs": "^3.0.2", - "eslint-plugin-storybook": "^9.0.7", + "eslint-plugin-storybook": "^9.1.13", "eslint-plugin-tailwindcss": "^3.18.0", "globals": "^15.11.0", "husky": "^9.1.6", @@ -197,7 +195,7 @@ "magicast": "^0.3.4", "postcss": "^8.4.47", "sass": "^1.92.1", - "storybook": "8.5.0", + "storybook": "9.1.13", "tailwindcss": "^3.4.14", "typescript": "^5.8.3", "uglify-js": "^3.19.3" @@ -243,7 +241,8 @@ "object.fromentries": "npm:@nolyfill/object.fromentries@^1", "object.groupby": "npm:@nolyfill/object.groupby@^1", "object.values": "npm:@nolyfill/object.values@^1", - "safe-buffer": "npm:@nolyfill/safe-buffer@^1", + "safe-buffer": "^5.2.1", + "@nolyfill/safe-buffer": "npm:safe-buffer@^5.2.1", "safe-regex-test": "npm:@nolyfill/safe-regex-test@^1", "safer-buffer": "npm:@nolyfill/safer-buffer@^1", "side-channel": "npm:@nolyfill/side-channel@^1", diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index 0f1495e474..4ee8e8af1d 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -9,11 +9,12 @@ overrides: '@types/react-dom': 19.1.7 string-width: 4.2.3 '@eslint/plugin-kit@<0.3.4': 0.3.4 + brace-expansion@<2.0.2: 2.0.2 + devalue@<5.3.2: 5.3.2 esbuild@<0.25.0: 0.25.0 pbkdf2@<3.1.3: 3.1.3 - vite@<6.2.7: 6.2.7 prismjs@<1.30.0: 1.30.0 - brace-expansion@<2.0.2: 2.0.2 + vite@<6.2.7: 6.2.7 array-includes: npm:@nolyfill/array-includes@^1 array.prototype.findlast: npm:@nolyfill/array.prototype.findlast@^1 array.prototype.findlastindex: npm:@nolyfill/array.prototype.findlastindex@^1 @@ -33,7 +34,8 @@ overrides: object.fromentries: npm:@nolyfill/object.fromentries@^1 object.groupby: npm:@nolyfill/object.groupby@^1 object.values: npm:@nolyfill/object.values@^1 - safe-buffer: npm:@nolyfill/safe-buffer@^1 + safe-buffer: ^5.2.1 + '@nolyfill/safe-buffer': npm:safe-buffer@^5.2.1 safe-regex-test: 
npm:@nolyfill/safe-regex-test@^1 safer-buffer: npm:@nolyfill/safer-buffer@^1 side-channel: npm:@nolyfill/side-channel@^1 @@ -43,7 +45,6 @@ overrides: string.prototype.trimend: npm:@nolyfill/string.prototype.trimend@^1 typed-array-buffer: npm:@nolyfill/typed-array-buffer@^1 which-typed-array: npm:@nolyfill/which-typed-array@^1 - devalue@<5.3.2: 5.3.2 importers: @@ -216,7 +217,7 @@ importers: version: 15.5.4(@babel/core@7.28.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(sass@1.92.1) next-pwa: specifier: ^5.6.0 - version: 5.6.0(@babel/core@7.28.3)(@types/babel__core@7.20.5)(esbuild@0.25.0)(next@15.5.4(@babel/core@7.28.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(sass@1.92.1))(uglify-js@3.19.3)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)) + version: 5.6.0(@babel/core@7.28.3)(@types/babel__core@7.20.5)(esbuild@0.25.0)(next@15.5.4(@babel/core@7.28.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(sass@1.92.1))(uglify-js@3.19.3)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)) next-themes: specifier: ^0.4.3 version: 0.4.6(react-dom@19.1.1(react@19.1.1))(react@19.1.1) @@ -339,8 +340,8 @@ importers: specifier: ^7.28.3 version: 7.28.3 '@chromatic-com/storybook': - specifier: ^3.1.0 - version: 3.2.7(react@19.1.1)(storybook@8.5.0) + specifier: ^4.1.1 + version: 4.1.1(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0))) '@eslint-react/eslint-plugin': specifier: ^1.15.0 version: 1.52.3(eslint@9.35.0(jiti@2.6.1))(ts-api-utils@2.1.0(typescript@5.8.3))(typescript@5.8.3) @@ -349,7 +350,7 @@ importers: version: 20.0.4(@jest/environment@29.7.0)(@jest/fake-timers@29.7.0)(@jest/types@29.6.3)(jest-mock@29.7.0)(jest-util@29.7.0) '@mdx-js/loader': specifier: ^3.1.0 - version: 3.1.0(acorn@8.15.0)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)) + version: 3.1.0(acorn@8.15.0)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)) '@mdx-js/react': specifier: ^3.1.0 version: 3.1.0(@types/react@19.1.11)(react@19.1.1) @@ -361,34 +362,28 @@ importers: version: 15.5.4 '@next/mdx': specifier: 15.5.4 - version: 15.5.4(@mdx-js/loader@3.1.0(acorn@8.15.0)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)))(@mdx-js/react@3.1.0(@types/react@19.1.11)(react@19.1.1)) + version: 15.5.4(@mdx-js/loader@3.1.0(acorn@8.15.0)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)))(@mdx-js/react@3.1.0(@types/react@19.1.11)(react@19.1.1)) '@rgrove/parse-xml': specifier: ^4.1.0 version: 4.2.0 - '@storybook/addon-essentials': - specifier: 8.5.0 - version: 8.5.0(@types/react@19.1.11)(storybook@8.5.0) - '@storybook/addon-interactions': - specifier: 8.5.0 - version: 8.5.0(storybook@8.5.0) + '@storybook/addon-docs': + specifier: 9.1.13 + version: 9.1.13(@types/react@19.1.11)(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0))) '@storybook/addon-links': - specifier: 8.5.0 - version: 8.5.0(react@19.1.1)(storybook@8.5.0) + specifier: 9.1.13 + version: 9.1.13(react@19.1.1)(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0))) '@storybook/addon-onboarding': - specifier: 8.5.0 - version: 8.5.0(storybook@8.5.0) + specifier: 9.1.13 + version: 9.1.13(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0))) '@storybook/addon-themes': - specifier: 8.5.0 - version: 8.5.0(storybook@8.5.0) + specifier: 9.1.13 + version: 
9.1.13(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0))) '@storybook/nextjs': - specifier: 8.5.0 - version: 8.5.0(esbuild@0.25.0)(next@15.5.4(@babel/core@7.28.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(sass@1.92.1))(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(sass@1.92.1)(storybook@8.5.0)(type-fest@2.19.0)(typescript@5.8.3)(uglify-js@3.19.3)(webpack-hot-middleware@2.26.1)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)) + specifier: 9.1.13 + version: 9.1.13(esbuild@0.25.0)(next@15.5.4(@babel/core@7.28.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(sass@1.92.1))(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(sass@1.92.1)(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)))(type-fest@2.19.0)(typescript@5.8.3)(uglify-js@3.19.3)(webpack-hot-middleware@2.26.1)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)) '@storybook/react': - specifier: 8.5.0 - version: 8.5.0(@storybook/test@8.5.0(storybook@8.5.0))(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(storybook@8.5.0)(typescript@5.8.3) - '@storybook/test': - specifier: 8.5.0 - version: 8.5.0(storybook@8.5.0) + specifier: 9.1.13 + version: 9.1.13(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)))(typescript@5.8.3) '@testing-library/dom': specifier: ^10.4.0 version: 10.4.0 @@ -445,7 +440,7 @@ importers: version: 10.4.21(postcss@8.5.6) babel-loader: specifier: ^10.0.0 - version: 10.0.0(@babel/core@7.28.3)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)) + version: 10.0.0(@babel/core@7.28.3)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)) bing-translate-api: specifier: ^4.0.2 version: 4.1.0 @@ -471,8 +466,8 @@ importers: specifier: ^3.0.2 version: 3.0.4(eslint@9.35.0(jiti@2.6.1)) eslint-plugin-storybook: - specifier: ^9.0.7 - version: 9.0.7(eslint@9.35.0(jiti@2.6.1))(typescript@5.8.3) + specifier: ^9.1.13 + version: 9.1.13(eslint@9.35.0(jiti@2.6.1))(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)))(typescript@5.8.3) eslint-plugin-tailwindcss: specifier: ^3.18.0 version: 3.18.2(tailwindcss@3.4.17(ts-node@10.9.2(@types/node@18.15.0)(typescript@5.8.3))) @@ -504,8 +499,8 @@ importers: specifier: ^1.92.1 version: 1.92.1 storybook: - specifier: 8.5.0 - version: 8.5.0 + specifier: 9.1.13 + version: 9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)) tailwindcss: specifier: ^3.4.14 version: 3.4.17(ts-node@10.9.2(@types/node@18.15.0)(typescript@5.8.3)) @@ -1289,11 +1284,11 @@ packages: '@chevrotain/utils@11.0.3': resolution: {integrity: sha512-YslZMgtJUyuMbZ+aKvfF3x1f5liK4mWNxghFRv7jqRR9C3R3fAOGTTKvxXDa2Y1s9zSbcpuO0cAxDYsc9SrXoQ==} - '@chromatic-com/storybook@3.2.7': - resolution: {integrity: sha512-fCGhk4cd3VA8RNg55MZL5CScdHqljsQcL9g6Ss7YuobHpSo9yytEWNdgMd5QxAHSPBlLGFHjnSmliM3G/BeBqw==} - engines: {node: '>=16.0.0', yarn: '>=1.22.18'} + '@chromatic-com/storybook@4.1.1': + resolution: {integrity: sha512-+Ib4cHtEjKl/Do+4LyU0U1FhLPbIU2Q/zgbOKHBCV+dTC4T3/vGzPqiGsgkdnZyTsK/zXg96LMPSPC4jjOiapg==} + engines: {node: '>=20.0.0', yarn: '>=1.22.18'} peerDependencies: - storybook: ^8.2.0 || ^8.3.0-0 || ^8.4.0-0 || ^8.5.0-0 || ^8.6.0-0 + storybook: ^0.0.0-0 || ^9.0.0 || ^9.1.0-0 || ^9.2.0-0 || ^10.0.0-0 '@clack/core@0.5.0': 
resolution: {integrity: sha512-p3y0FIOwaYRUPRcMO7+dlmLh8PSRcrjuTndsiA0WAFbWES0mLZlrjVoBRZ9DzkPFJZG6KGkJmoEAY0ZcVWTkow==} @@ -1752,144 +1747,170 @@ packages: resolution: {integrity: sha512-9B+taZ8DlyyqzZQnoeIvDVR/2F4EbMepXMc/NdVbkzsJbzkUjhXv/70GQJ7tdLA4YJgNP25zukcxpX2/SueNrA==} cpu: [arm64] os: [linux] + libc: [glibc] '@img/sharp-libvips-linux-arm64@1.2.3': resolution: {integrity: sha512-I4RxkXU90cpufazhGPyVujYwfIm9Nk1QDEmiIsaPwdnm013F7RIceaCc87kAH+oUB1ezqEvC6ga4m7MSlqsJvQ==} cpu: [arm64] os: [linux] + libc: [glibc] '@img/sharp-libvips-linux-arm@1.0.5': resolution: {integrity: sha512-gvcC4ACAOPRNATg/ov8/MnbxFDJqf/pDePbBnuBDcjsI8PssmjoKMAz4LtLaVi+OnSb5FK/yIOamqDwGmXW32g==} cpu: [arm] os: [linux] + libc: [glibc] '@img/sharp-libvips-linux-arm@1.2.3': resolution: {integrity: sha512-x1uE93lyP6wEwGvgAIV0gP6zmaL/a0tGzJs/BIDDG0zeBhMnuUPm7ptxGhUbcGs4okDJrk4nxgrmxpib9g6HpA==} cpu: [arm] os: [linux] + libc: [glibc] '@img/sharp-libvips-linux-ppc64@1.2.3': resolution: {integrity: sha512-Y2T7IsQvJLMCBM+pmPbM3bKT/yYJvVtLJGfCs4Sp95SjvnFIjynbjzsa7dY1fRJX45FTSfDksbTp6AGWudiyCg==} cpu: [ppc64] os: [linux] + libc: [glibc] '@img/sharp-libvips-linux-s390x@1.0.4': resolution: {integrity: sha512-u7Wz6ntiSSgGSGcjZ55im6uvTrOxSIS8/dgoVMoiGE9I6JAfU50yH5BoDlYA1tcuGS7g/QNtetJnxA6QEsCVTA==} cpu: [s390x] os: [linux] + libc: [glibc] '@img/sharp-libvips-linux-s390x@1.2.3': resolution: {integrity: sha512-RgWrs/gVU7f+K7P+KeHFaBAJlNkD1nIZuVXdQv6S+fNA6syCcoboNjsV2Pou7zNlVdNQoQUpQTk8SWDHUA3y/w==} cpu: [s390x] os: [linux] + libc: [glibc] '@img/sharp-libvips-linux-x64@1.0.4': resolution: {integrity: sha512-MmWmQ3iPFZr0Iev+BAgVMb3ZyC4KeFc3jFxnNbEPas60e1cIfevbtuyf9nDGIzOaW9PdnDciJm+wFFaTlj5xYw==} cpu: [x64] os: [linux] + libc: [glibc] '@img/sharp-libvips-linux-x64@1.2.3': resolution: {integrity: sha512-3JU7LmR85K6bBiRzSUc/Ff9JBVIFVvq6bomKE0e63UXGeRw2HPVEjoJke1Yx+iU4rL7/7kUjES4dZ/81Qjhyxg==} cpu: [x64] os: [linux] + libc: [glibc] '@img/sharp-libvips-linuxmusl-arm64@1.0.4': resolution: {integrity: sha512-9Ti+BbTYDcsbp4wfYib8Ctm1ilkugkA/uscUn6UXK1ldpC1JjiXbLfFZtRlBhjPZ5o1NCLiDbg8fhUPKStHoTA==} cpu: [arm64] os: [linux] + libc: [musl] '@img/sharp-libvips-linuxmusl-arm64@1.2.3': resolution: {integrity: sha512-F9q83RZ8yaCwENw1GieztSfj5msz7GGykG/BA+MOUefvER69K/ubgFHNeSyUu64amHIYKGDs4sRCMzXVj8sEyw==} cpu: [arm64] os: [linux] + libc: [musl] '@img/sharp-libvips-linuxmusl-x64@1.0.4': resolution: {integrity: sha512-viYN1KX9m+/hGkJtvYYp+CCLgnJXwiQB39damAO7WMdKWlIhmYTfHjwSbQeUK/20vY154mwezd9HflVFM1wVSw==} cpu: [x64] os: [linux] + libc: [musl] '@img/sharp-libvips-linuxmusl-x64@1.2.3': resolution: {integrity: sha512-U5PUY5jbc45ANM6tSJpsgqmBF/VsL6LnxJmIf11kB7J5DctHgqm0SkuXzVWtIY90GnJxKnC/JT251TDnk1fu/g==} cpu: [x64] os: [linux] + libc: [musl] '@img/sharp-linux-arm64@0.33.5': resolution: {integrity: sha512-JMVv+AMRyGOHtO1RFBiJy/MBsgz0x4AWrT6QoEVVTyh1E39TrCUpTRI7mx9VksGX4awWASxqCYLCV4wBZHAYxA==} engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0} cpu: [arm64] os: [linux] + libc: [glibc] '@img/sharp-linux-arm64@0.34.4': resolution: {integrity: sha512-YXU1F/mN/Wu786tl72CyJjP/Ngl8mGHN1hST4BGl+hiW5jhCnV2uRVTNOcaYPs73NeT/H8Upm3y9582JVuZHrQ==} engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0} cpu: [arm64] os: [linux] + libc: [glibc] '@img/sharp-linux-arm@0.33.5': resolution: {integrity: sha512-JTS1eldqZbJxjvKaAkxhZmBqPRGmxgu+qFKSInv8moZ2AmT5Yib3EQ1c6gp493HvrvV8QgdOXdyaIBrhvFhBMQ==} engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0} cpu: [arm] os: [linux] + libc: [glibc] '@img/sharp-linux-arm@0.34.4': resolution: {integrity: 
sha512-Xyam4mlqM0KkTHYVSuc6wXRmM7LGN0P12li03jAnZ3EJWZqj83+hi8Y9UxZUbxsgsK1qOEwg7O0Bc0LjqQVtxA==} engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0} cpu: [arm] os: [linux] + libc: [glibc] '@img/sharp-linux-ppc64@0.34.4': resolution: {integrity: sha512-F4PDtF4Cy8L8hXA2p3TO6s4aDt93v+LKmpcYFLAVdkkD3hSxZzee0rh6/+94FpAynsuMpLX5h+LRsSG3rIciUQ==} engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0} cpu: [ppc64] os: [linux] + libc: [glibc] '@img/sharp-linux-s390x@0.33.5': resolution: {integrity: sha512-y/5PCd+mP4CA/sPDKl2961b+C9d+vPAveS33s6Z3zfASk2j5upL6fXVPZi7ztePZ5CuH+1kW8JtvxgbuXHRa4Q==} engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0} cpu: [s390x] os: [linux] + libc: [glibc] '@img/sharp-linux-s390x@0.34.4': resolution: {integrity: sha512-qVrZKE9Bsnzy+myf7lFKvng6bQzhNUAYcVORq2P7bDlvmF6u2sCmK2KyEQEBdYk+u3T01pVsPrkj943T1aJAsw==} engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0} cpu: [s390x] os: [linux] + libc: [glibc] '@img/sharp-linux-x64@0.33.5': resolution: {integrity: sha512-opC+Ok5pRNAzuvq1AG0ar+1owsu842/Ab+4qvU879ippJBHvyY5n2mxF1izXqkPYlGuP/M556uh53jRLJmzTWA==} engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0} cpu: [x64] os: [linux] + libc: [glibc] '@img/sharp-linux-x64@0.34.4': resolution: {integrity: sha512-ZfGtcp2xS51iG79c6Vhw9CWqQC8l2Ot8dygxoDoIQPTat/Ov3qAa8qpxSrtAEAJW+UjTXc4yxCjNfxm4h6Xm2A==} engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0} cpu: [x64] os: [linux] + libc: [glibc] '@img/sharp-linuxmusl-arm64@0.33.5': resolution: {integrity: sha512-XrHMZwGQGvJg2V/oRSUfSAfjfPxO+4DkiRh6p2AFjLQztWUuY/o8Mq0eMQVIY7HJ1CDQUJlxGGZRw1a5bqmd1g==} engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0} cpu: [arm64] os: [linux] + libc: [musl] '@img/sharp-linuxmusl-arm64@0.34.4': resolution: {integrity: sha512-8hDVvW9eu4yHWnjaOOR8kHVrew1iIX+MUgwxSuH2XyYeNRtLUe4VNioSqbNkB7ZYQJj9rUTT4PyRscyk2PXFKA==} engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0} cpu: [arm64] os: [linux] + libc: [musl] '@img/sharp-linuxmusl-x64@0.33.5': resolution: {integrity: sha512-WT+d/cgqKkkKySYmqoZ8y3pxx7lx9vVejxW/W4DOFMYVSkErR+w7mf2u8m/y4+xHe7yY9DAXQMWQhpnMuFfScw==} engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0} cpu: [x64] os: [linux] + libc: [musl] '@img/sharp-linuxmusl-x64@0.34.4': resolution: {integrity: sha512-lU0aA5L8QTlfKjpDCEFOZsTYGn3AEiO6db8W5aQDxj0nQkVrZWmN3ZP9sYKWJdtq3PWPhUNlqehWyXpYDcI9Sg==} engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0} cpu: [x64] os: [linux] + libc: [musl] '@img/sharp-wasm32@0.33.5': resolution: {integrity: sha512-ykUW4LVGaMcU9lu9thv85CbRMAwfeadCJHRsg2GmeRa/cJxsVY9Rbd57JcMxBkKHag5U/x7TSBpScF4U8ElVzg==} @@ -2171,6 +2192,9 @@ packages: '@napi-rs/wasm-runtime@1.0.7': resolution: {integrity: sha512-SeDnOO0Tk7Okiq6DbXmmBODgOAb9dp9gjlphokTUxmt8U3liIP1ZsozBahH69j/RJv+Rfs6IwUKHTgQYJ/HBAw==} + '@neoconfetti/react@1.0.0': + resolution: {integrity: sha512-klcSooChXXOzIm+SE5IISIAn3bYzYfPjbX7D7HoqZL84oAfgREeSg5vSIaSFH+DaGzzvImTyWe1OyrJ67vik4A==} + '@next/bundle-analyzer@15.5.4': resolution: {integrity: sha512-wMtpIjEHi+B/wC34ZbEcacGIPgQTwTFjjp0+F742s9TxC6QwT0MwB/O0QEgalMe8s3SH/K09DO0gmTvUSJrLRA==} @@ -2208,24 +2232,28 @@ packages: engines: {node: '>= 10'} cpu: [arm64] os: [linux] + libc: [glibc] '@next/swc-linux-arm64-musl@15.5.4': resolution: {integrity: sha512-TOK7iTxmXFc45UrtKqWdZ1shfxuL4tnVAOuuJK4S88rX3oyVV4ZkLjtMT85wQkfBrOOvU55aLty+MV8xmcJR8A==} engines: {node: '>= 10'} cpu: [arm64] os: [linux] + libc: [musl] '@next/swc-linux-x64-gnu@15.5.4': resolution: {integrity: sha512-7HKolaj+481FSW/5lL0BcTkA4Ueam9SPYWyN/ib/WGAFZf0DGAN8frNpNZYFHtM4ZstrHZS3LY3vrwlIQfsiMA==} engines: {node: '>= 10'} cpu: [x64] os: [linux] + libc: 
[glibc] '@next/swc-linux-x64-musl@15.5.4': resolution: {integrity: sha512-nlQQ6nfgN0nCO/KuyEUwwOdwQIGjOs4WNMjEUtpIQJPR2NUfmGpW2wkJln1d4nJ7oUzd1g4GivH5GoEPBgfsdw==} engines: {node: '>= 10'} cpu: [x64] os: [linux] + libc: [musl] '@next/swc-win32-arm64-msvc@15.5.4': resolution: {integrity: sha512-PcR2bN7FlM32XM6eumklmyWLLbu2vs+D7nJX8OAIoWy69Kef8mfiN4e8TUv2KohprwifdpFKPzIP1njuCjD0YA==} @@ -2303,10 +2331,6 @@ packages: resolution: {integrity: sha512-bwIpVzFMudUC0ofnvdSDB/OyGUizcU+r32ZZ0QTMbN03gUttMtdCFDekuSYT0XGFgufTQyZ4ONBnAeb3DFCPGQ==} engines: {node: '>=12.4.0'} - '@nolyfill/safe-buffer@1.0.44': - resolution: {integrity: sha512-SqlKXtlhNTDMeZKey9jnnuPhi8YTl1lJuEcY9zbm5i4Pqe79UJJ8IJ9oiD6DhgI8KjYc+HtLzpQJNRdNYqb/hw==} - engines: {node: '>=12.4.0'} - '@nolyfill/safer-buffer@1.0.44': resolution: {integrity: sha512-Ouw1fMwjAy1V4MpnDASfu1DCPgkP0nNFteiiWbFoEGSqa7Vnmkb6if2c522N2WcMk+RuaaabQbC1F1D4/kTXcg==} engines: {node: '>=12.4.0'} @@ -2402,41 +2426,49 @@ packages: resolution: {integrity: sha512-TWq+y2psMzbMtZB9USAq2bSA7NV1TMmh9lhAFbMGQ8Yp2YV4BRC/HilD6qF++efQl6shueGBFOv0LVe9BUXaIA==} cpu: [arm64] os: [linux] + libc: [glibc] '@oxc-resolver/binding-linux-arm64-musl@11.9.0': resolution: {integrity: sha512-8WwGLfXk7yttc6rD6g53+RnYfX5B8xOot1ffthLn8oCXzVRO4cdChlmeHStxwLD/MWx8z8BGeyfyINNrsh9N2w==} cpu: [arm64] os: [linux] + libc: [musl] '@oxc-resolver/binding-linux-ppc64-gnu@11.9.0': resolution: {integrity: sha512-ZWiAXfan6actlSzayaFS/kYO2zD6k1k0fmLb1opbujXYMKepEnjjVOvKdzCIYR/zKzudqI39dGc+ywqVdsPIpQ==} cpu: [ppc64] os: [linux] + libc: [glibc] '@oxc-resolver/binding-linux-riscv64-gnu@11.9.0': resolution: {integrity: sha512-p9mCSb+Bym+eycNo9k+81wQ5SAE31E+/rtfbDmF4/7krPotkEjPsEBSc3rqunRwO+FtsUn7H68JLY7hlai49eQ==} cpu: [riscv64] os: [linux] + libc: [glibc] '@oxc-resolver/binding-linux-riscv64-musl@11.9.0': resolution: {integrity: sha512-/SePuVxgFhLPciRwsJ8kLVltr+rxh0b6riGFuoPnFXBbHFclKnjNIt3TfqzUj0/vOnslXw3cVGPpmtkm2TgCgg==} cpu: [riscv64] os: [linux] + libc: [musl] '@oxc-resolver/binding-linux-s390x-gnu@11.9.0': resolution: {integrity: sha512-zLuEjlYIzfnr1Ei2UZYQBbCTa/9deh+BEjO9rh1ai8BfEq4uj6RupTtNpgHfgAsEYdqOBVExw9EU1S6SW3RCAw==} cpu: [s390x] os: [linux] + libc: [glibc] '@oxc-resolver/binding-linux-x64-gnu@11.9.0': resolution: {integrity: sha512-cxdg73WG+aVlPu/k4lEQPRVOhWunYOUglW6OSzclZLJJAXZU0tSZ5ymKaqPRkfTsyNSAafj1cA1XYd+P9UxBgw==} cpu: [x64] os: [linux] + libc: [glibc] '@oxc-resolver/binding-linux-x64-musl@11.9.0': resolution: {integrity: sha512-sy5nkVdMvNgqcx9sIY7G6U9TYZUZC4cmMGw/wKhJNuuD2/HFGtbje62ttXSwBAbVbmJ2GgZ4ZUo/S1OMyU+/OA==} cpu: [x64] os: [linux] + libc: [musl] '@oxc-resolver/binding-wasm32-wasi@11.9.0': resolution: {integrity: sha512-dfi/a0Xh6o6nOLbJdaYuy7txncEcwkRHp9DGGZaAP7zxDiepkBZ6ewSJODQrWwhjVmMteXo+XFzEOMjsC7WUtQ==} @@ -2487,36 +2519,42 @@ packages: engines: {node: '>= 10.0.0'} cpu: [arm] os: [linux] + libc: [glibc] '@parcel/watcher-linux-arm-musl@2.5.1': resolution: {integrity: sha512-6E+m/Mm1t1yhB8X412stiKFG3XykmgdIOqhjWj+VL8oHkKABfu/gjFj8DvLrYVHSBNC+/u5PeNrujiSQ1zwd1Q==} engines: {node: '>= 10.0.0'} cpu: [arm] os: [linux] + libc: [musl] '@parcel/watcher-linux-arm64-glibc@2.5.1': resolution: {integrity: sha512-LrGp+f02yU3BN9A+DGuY3v3bmnFUggAITBGriZHUREfNEzZh/GO06FF5u2kx8x+GBEUYfyTGamol4j3m9ANe8w==} engines: {node: '>= 10.0.0'} cpu: [arm64] os: [linux] + libc: [glibc] '@parcel/watcher-linux-arm64-musl@2.5.1': resolution: {integrity: sha512-cFOjABi92pMYRXS7AcQv9/M1YuKRw8SZniCDw0ssQb/noPkRzA+HBDkwmyOJYp5wXcsTrhxO0zq1U11cK9jsFg==} engines: {node: '>= 10.0.0'} cpu: [arm64] os: [linux] + 
libc: [musl] '@parcel/watcher-linux-x64-glibc@2.5.1': resolution: {integrity: sha512-GcESn8NZySmfwlTsIur+49yDqSny2IhPeZfXunQi48DMugKeZ7uy1FX83pO0X22sHntJ4Ub+9k34XQCX+oHt2A==} engines: {node: '>= 10.0.0'} cpu: [x64] os: [linux] + libc: [glibc] '@parcel/watcher-linux-x64-musl@2.5.1': resolution: {integrity: sha512-n0E2EQbatQ3bXhcH2D1XIAANAcTZkQICBPVaxMeaCVBtOpBZpWJuf7LwyWPSBDITb7In8mqQgJ7gH8CILCURXg==} engines: {node: '>= 10.0.0'} cpu: [x64] os: [linux] + libc: [musl] '@parcel/watcher-win32-arm64@2.5.1': resolution: {integrity: sha512-RFzklRvmc3PkjKjry3hLF9wD7ppR4AKcWNzH7kXR7GUe0Igb3Nz8fyPwtZCSquGrhU5HhUNDr/mKBqj7tqA2Vw==} @@ -2861,6 +2899,127 @@ packages: peerDependencies: rollup: ^1.20.0||^2.0.0 + '@rollup/rollup-android-arm-eabi@4.52.5': + resolution: {integrity: sha512-8c1vW4ocv3UOMp9K+gToY5zL2XiiVw3k7f1ksf4yO1FlDFQ1C2u72iACFnSOceJFsWskc2WZNqeRhFRPzv+wtQ==} + cpu: [arm] + os: [android] + + '@rollup/rollup-android-arm64@4.52.5': + resolution: {integrity: sha512-mQGfsIEFcu21mvqkEKKu2dYmtuSZOBMmAl5CFlPGLY94Vlcm+zWApK7F/eocsNzp8tKmbeBP8yXyAbx0XHsFNA==} + cpu: [arm64] + os: [android] + + '@rollup/rollup-darwin-arm64@4.52.5': + resolution: {integrity: sha512-takF3CR71mCAGA+v794QUZ0b6ZSrgJkArC+gUiG6LB6TQty9T0Mqh3m2ImRBOxS2IeYBo4lKWIieSvnEk2OQWA==} + cpu: [arm64] + os: [darwin] + + '@rollup/rollup-darwin-x64@4.52.5': + resolution: {integrity: sha512-W901Pla8Ya95WpxDn//VF9K9u2JbocwV/v75TE0YIHNTbhqUTv9w4VuQ9MaWlNOkkEfFwkdNhXgcLqPSmHy0fA==} + cpu: [x64] + os: [darwin] + + '@rollup/rollup-freebsd-arm64@4.52.5': + resolution: {integrity: sha512-QofO7i7JycsYOWxe0GFqhLmF6l1TqBswJMvICnRUjqCx8b47MTo46W8AoeQwiokAx3zVryVnxtBMcGcnX12LvA==} + cpu: [arm64] + os: [freebsd] + + '@rollup/rollup-freebsd-x64@4.52.5': + resolution: {integrity: sha512-jr21b/99ew8ujZubPo9skbrItHEIE50WdV86cdSoRkKtmWa+DDr6fu2c/xyRT0F/WazZpam6kk7IHBerSL7LDQ==} + cpu: [x64] + os: [freebsd] + + '@rollup/rollup-linux-arm-gnueabihf@4.52.5': + resolution: {integrity: sha512-PsNAbcyv9CcecAUagQefwX8fQn9LQ4nZkpDboBOttmyffnInRy8R8dSg6hxxl2Re5QhHBf6FYIDhIj5v982ATQ==} + cpu: [arm] + os: [linux] + libc: [glibc] + + '@rollup/rollup-linux-arm-musleabihf@4.52.5': + resolution: {integrity: sha512-Fw4tysRutyQc/wwkmcyoqFtJhh0u31K+Q6jYjeicsGJJ7bbEq8LwPWV/w0cnzOqR2m694/Af6hpFayLJZkG2VQ==} + cpu: [arm] + os: [linux] + libc: [musl] + + '@rollup/rollup-linux-arm64-gnu@4.52.5': + resolution: {integrity: sha512-a+3wVnAYdQClOTlyapKmyI6BLPAFYs0JM8HRpgYZQO02rMR09ZcV9LbQB+NL6sljzG38869YqThrRnfPMCDtZg==} + cpu: [arm64] + os: [linux] + libc: [glibc] + + '@rollup/rollup-linux-arm64-musl@4.52.5': + resolution: {integrity: sha512-AvttBOMwO9Pcuuf7m9PkC1PUIKsfaAJ4AYhy944qeTJgQOqJYJ9oVl2nYgY7Rk0mkbsuOpCAYSs6wLYB2Xiw0Q==} + cpu: [arm64] + os: [linux] + libc: [musl] + + '@rollup/rollup-linux-loong64-gnu@4.52.5': + resolution: {integrity: sha512-DkDk8pmXQV2wVrF6oq5tONK6UHLz/XcEVow4JTTerdeV1uqPeHxwcg7aFsfnSm9L+OO8WJsWotKM2JJPMWrQtA==} + cpu: [loong64] + os: [linux] + libc: [glibc] + + '@rollup/rollup-linux-ppc64-gnu@4.52.5': + resolution: {integrity: sha512-W/b9ZN/U9+hPQVvlGwjzi+Wy4xdoH2I8EjaCkMvzpI7wJUs8sWJ03Rq96jRnHkSrcHTpQe8h5Tg3ZzUPGauvAw==} + cpu: [ppc64] + os: [linux] + libc: [glibc] + + '@rollup/rollup-linux-riscv64-gnu@4.52.5': + resolution: {integrity: sha512-sjQLr9BW7R/ZiXnQiWPkErNfLMkkWIoCz7YMn27HldKsADEKa5WYdobaa1hmN6slu9oWQbB6/jFpJ+P2IkVrmw==} + cpu: [riscv64] + os: [linux] + libc: [glibc] + + '@rollup/rollup-linux-riscv64-musl@4.52.5': + resolution: {integrity: sha512-hq3jU/kGyjXWTvAh2awn8oHroCbrPm8JqM7RUpKjalIRWWXE01CQOf/tUNWNHjmbMHg/hmNCwc/Pz3k1T/j/Lg==} + cpu: 
[riscv64] + os: [linux] + libc: [musl] + + '@rollup/rollup-linux-s390x-gnu@4.52.5': + resolution: {integrity: sha512-gn8kHOrku8D4NGHMK1Y7NA7INQTRdVOntt1OCYypZPRt6skGbddska44K8iocdpxHTMMNui5oH4elPH4QOLrFQ==} + cpu: [s390x] + os: [linux] + libc: [glibc] + + '@rollup/rollup-linux-x64-gnu@4.52.5': + resolution: {integrity: sha512-hXGLYpdhiNElzN770+H2nlx+jRog8TyynpTVzdlc6bndktjKWyZyiCsuDAlpd+j+W+WNqfcyAWz9HxxIGfZm1Q==} + cpu: [x64] + os: [linux] + libc: [glibc] + + '@rollup/rollup-linux-x64-musl@4.52.5': + resolution: {integrity: sha512-arCGIcuNKjBoKAXD+y7XomR9gY6Mw7HnFBv5Rw7wQRvwYLR7gBAgV7Mb2QTyjXfTveBNFAtPt46/36vV9STLNg==} + cpu: [x64] + os: [linux] + libc: [musl] + + '@rollup/rollup-openharmony-arm64@4.52.5': + resolution: {integrity: sha512-QoFqB6+/9Rly/RiPjaomPLmR/13cgkIGfA40LHly9zcH1S0bN2HVFYk3a1eAyHQyjs3ZJYlXvIGtcCs5tko9Cw==} + cpu: [arm64] + os: [openharmony] + + '@rollup/rollup-win32-arm64-msvc@4.52.5': + resolution: {integrity: sha512-w0cDWVR6MlTstla1cIfOGyl8+qb93FlAVutcor14Gf5Md5ap5ySfQ7R9S/NjNaMLSFdUnKGEasmVnu3lCMqB7w==} + cpu: [arm64] + os: [win32] + + '@rollup/rollup-win32-ia32-msvc@4.52.5': + resolution: {integrity: sha512-Aufdpzp7DpOTULJCuvzqcItSGDH73pF3ko/f+ckJhxQyHtp67rHw3HMNxoIdDMUITJESNE6a8uh4Lo4SLouOUg==} + cpu: [ia32] + os: [win32] + + '@rollup/rollup-win32-x64-gnu@4.52.5': + resolution: {integrity: sha512-UGBUGPFp1vkj6p8wCRraqNhqwX/4kNQPS57BCFc8wYh0g94iVIW33wJtQAx3G7vrjjNtRaxiMUylM0ktp/TRSQ==} + cpu: [x64] + os: [win32] + + '@rollup/rollup-win32-x64-msvc@4.52.5': + resolution: {integrity: sha512-TAcgQh2sSkykPRWLrdyy2AiceMckNf5loITqXxFI5VuQjS5tSuw3WlwdN8qv8vzjLAUTvYaH/mVjSFpbkFbpTg==} + cpu: [x64] + os: [win32] + '@sentry-internal/browser-utils@8.55.0': resolution: {integrity: sha512-ROgqtQfpH/82AQIpESPqPQe0UyWywKJsmVIqi3c5Fh+zkds5LUxnssTj3yNd1x+kxaPDVB023jAP+3ibNgeNDw==} engines: {node: '>=14.18'} @@ -2904,158 +3063,67 @@ packages: '@sinonjs/fake-timers@10.3.0': resolution: {integrity: sha512-V4BG07kuYSUkTCSBHG8G8TNhM+F19jXFWnQtzj+we8DrkpSBCee9Z3Ms8yiGer/dlmhe35/Xdgyo3/0rQKg7YA==} - '@storybook/addon-actions@8.5.0': - resolution: {integrity: sha512-6CW9+17rk5eNx6I8EKqCxRKtsJFTR/lHL+xiJ6/iBWApIm8sg63vhXvUTJ58UixmIkT5oLh0+ESNPh+x10D8fw==} + '@storybook/addon-docs@9.1.13': + resolution: {integrity: sha512-V1nCo7bfC3kQ5VNVq0VDcHsIhQf507m+BxMA5SIYiwdJHljH2BXpW2fL3FFn9gv9Wp57AEEzhm+wh4zANaJgkg==} peerDependencies: - storybook: ^8.5.0 + storybook: ^9.1.13 - '@storybook/addon-backgrounds@8.5.0': - resolution: {integrity: sha512-lzyFLs7niNsqlhH5kdUrp7htLiMIcjY50VLWe0PaeJ6T6GZ7X9qhQzROAUV6cGqzyd8A6y/LzIUntDPMVEm/6g==} - peerDependencies: - storybook: ^8.5.0 - - '@storybook/addon-controls@8.5.0': - resolution: {integrity: sha512-1fivx77A/ahObrPl0L66o9i9MUNfqXxsrpekne5gjMNXw9XJFIRNUe/ddL4CMmwu7SgVbj2QV+q5E5mlnZNTJw==} - peerDependencies: - storybook: ^8.5.0 - - '@storybook/addon-docs@8.5.0': - resolution: {integrity: sha512-REwLSr1VgOVNJZwP3y3mldhOjBHlM5fqTvq/tC8NaYpAzx9O4rZdoUSZxW3tYtoNoYrHpB8kzRTeZl8WSdKllw==} - peerDependencies: - storybook: ^8.5.0 - - '@storybook/addon-essentials@8.5.0': - resolution: {integrity: sha512-RrHRdaw2j3ugZiYQ6OHt3Ff08ID4hwAvipqULEsbEnEw3VlXOaW/MT5e2M7kW3MHskQ3iJ6XAD1Y1rNm432Pzw==} - peerDependencies: - storybook: ^8.5.0 - - '@storybook/addon-highlight@8.5.0': - resolution: {integrity: sha512-/JxYzMK5aJSYs0K/0eAEFyER2dMoxqwM891MdnkNwLFdyrM58lzHee00F9oEX6zeQoRUNQPRepq0ui2PvbTMGw==} - peerDependencies: - storybook: ^8.5.0 - - '@storybook/addon-interactions@8.5.0': - resolution: {integrity: 
sha512-vX1a8qS7o/W3kEzfL/CqOj/Rr6UlGLT/n0KXMpfIhx63tzxe1a1qGpFLL0h0zqAVPHZIOu9humWMKri5Iny6oA==} - peerDependencies: - storybook: ^8.5.0 - - '@storybook/addon-links@8.5.0': - resolution: {integrity: sha512-Y11GIByAYqn0TibI/xsy0vCe+ZxJS9PVAAoHngLxkf9J4WodAXcJABr8ZPlWDNdaEhSS/FF7UQUmNag0UC2/pw==} + '@storybook/addon-links@9.1.13': + resolution: {integrity: sha512-wx33RA5PPRSepVAjR0hMFp2IXoPgjwNAHIP92aoi2QQFS3+NHlf1I4vXEPpHU6lc0WBwM43qvLSI0qTAyZd8Nw==} peerDependencies: react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0-beta - storybook: ^8.5.0 + storybook: ^9.1.13 peerDependenciesMeta: react: optional: true - '@storybook/addon-measure@8.5.0': - resolution: {integrity: sha512-e8pJy2sICyj0Ff0W1PFc6HPE6PqcjnnHtfuDaO3M9uSKJLYkpTWJ8i1VSP178f8seq44r5/PdQCHqs5q5l3zgw==} + '@storybook/addon-onboarding@9.1.13': + resolution: {integrity: sha512-WqyzBA2VIPkWw6yFbyZ6PLVJWf+H+R99gvKHchUj7oJWVEs8FHYoP2Lum+5LonUerBqgwGQZlS3UPrRKJ0avZw==} peerDependencies: - storybook: ^8.5.0 + storybook: ^9.1.13 - '@storybook/addon-onboarding@8.5.0': - resolution: {integrity: sha512-77ebcHkKR744ciPbT4ZgqW4W7KrLv1uAdSb3mX3gWukSl4oxP9D/HjmNiX5fBDYWUC4wsf6q5barOs4Hqn8ivw==} + '@storybook/addon-themes@9.1.13': + resolution: {integrity: sha512-0ewLnwpoeOzOxDYg4VBlcnWiJz2jXvbZgEsQnqDXcK6y+WwK5MdupRFzSSJb+h470h3MnINAQrskPgGMKmI44A==} peerDependencies: - storybook: ^8.5.0 + storybook: ^9.1.13 - '@storybook/addon-outline@8.5.0': - resolution: {integrity: sha512-r12sk1b38Ph6NroWAOTfjbJ/V+gDobm7tKQQlbSDf6fgX7cqyPHmKjfNDCOCQpXouZm/Jm+41zd758PW+Yt4ng==} + '@storybook/builder-webpack5@9.1.13': + resolution: {integrity: sha512-BoFXrTlc22ryLl6U5QwgV/gHVbHBcXeVSjYOyu6XZ9SPV5GGbw5T/G7NJYJAZcsz1ZxuMEYYSMFryfZ5qcjRsA==} peerDependencies: - storybook: ^8.5.0 - - '@storybook/addon-themes@8.5.0': - resolution: {integrity: sha512-pBNut4sLfcOeLBvWdNAJ3cxv/BfMSTmJcUtSzE4G+1pVNsBbGF+T2f/GM0IjaM0K8Ft03VDzeEAB64nluDS4RA==} - peerDependencies: - storybook: ^8.5.0 - - '@storybook/addon-toolbars@8.5.0': - resolution: {integrity: sha512-q3yYYO2WX8K2DYNM++FzixGDjzYaeREincgsl2WXYXrcuGb5hkOoOgRiAQL8Nz9NQ1Eo+B/yZxrhG/5VoVhUUQ==} - peerDependencies: - storybook: ^8.5.0 - - '@storybook/addon-viewport@8.5.0': - resolution: {integrity: sha512-MlhVELImk9YzjEgGR2ciLC8d5tUSGcO7my4kWIClN0VyTRcvG4ZfwrsEC+jN3/l52nrgjLmKrDX5UAGZm6w5mQ==} - peerDependencies: - storybook: ^8.5.0 - - '@storybook/blocks@8.5.0': - resolution: {integrity: sha512-2sTOgjH/JFOgWnpqkKjpKVvKAgUaC9ZBjH1gnCoA5dne/SDafYaCAYfv6yZn7g2Xm1sTxWCAmMIUkYSALeWr+w==} - peerDependencies: - react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0-beta - react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0-beta - storybook: ^8.5.0 - peerDependenciesMeta: - react: - optional: true - react-dom: - optional: true - - '@storybook/builder-webpack5@8.5.0': - resolution: {integrity: sha512-MyCj11cktyN2HeK8NsLv+L0Km36qAz2UGqu6j1VKJUgPelgpCCi4StCW/KaSBeOFAwGD52xjAdNu+c1h/vfiMg==} - peerDependencies: - storybook: ^8.5.0 + storybook: ^9.1.13 typescript: '*' peerDependenciesMeta: typescript: optional: true - '@storybook/components@8.5.0': - resolution: {integrity: sha512-DhaHtwfEcfWYj3ih/5RBSDHe3Idxyf+oHw2/DmaLKJX6MluhdK3ZqigjRcTmA9Gj/SbR4CkHEEtDzAvBlW0BYw==} + '@storybook/core-webpack@9.1.13': + resolution: {integrity: sha512-HtBZ+ZVgeqlhyMiT/Tdb/vpKrCSiZEi6p4s7y/qk04SaX8XIPSufEeqLI/ELSz2hOcuCy6smU/tE1JkqVz/4uA==} peerDependencies: - storybook: ^8.2.0 || ^8.3.0-0 || ^8.4.0-0 || ^8.5.0-0 || ^8.6.0-0 + storybook: ^9.1.13 - '@storybook/core-webpack@8.5.0': - resolution: {integrity: 
sha512-bJAcF9TwNO2qNa7Jef4h5U9ka4399HDiHiQec1AxdqUIy/2zfbetgV6+2Fr5mtejPqJgbs7kXNGErI+fFByLGg==} + '@storybook/csf-plugin@9.1.13': + resolution: {integrity: sha512-EMpzYuyt9FDcxxfBChWzfId50y8QMpdenviEQ8m+pa6c+ANx3pC5J6t7y0khD8TQu815sTy+nc6cc8PC45dPUA==} peerDependencies: - storybook: ^8.5.0 - - '@storybook/core@8.5.0': - resolution: {integrity: sha512-apborO6ynns7SeydBSqE9o0zT6JSU+VY4gLFPJROGcconvSW4bS5xtJCsgjlulceyWVxepFHGXl4jEZw+SktXA==} - peerDependencies: - prettier: ^2 || ^3 - peerDependenciesMeta: - prettier: - optional: true - - '@storybook/csf-plugin@8.5.0': - resolution: {integrity: sha512-cs6ogviNyLG1h9J8Sb47U3DqIrQmn2EHm4ta3fpCeV3ABbrMgbzYyxtmybz4g/AwlDgjAZAt6PPcXkfCJ6p2CQ==} - peerDependencies: - storybook: ^8.5.0 - - '@storybook/csf@0.1.12': - resolution: {integrity: sha512-9/exVhabisyIVL0VxTCxo01Tdm8wefIXKXfltAPTSr8cbLn5JAxGQ6QV3mjdecLGEOucfoVhAKtJfVHxEK1iqw==} - - '@storybook/csf@0.1.13': - resolution: {integrity: sha512-7xOOwCLGB3ebM87eemep89MYRFTko+D8qE7EdAAq74lgdqRR5cOUtYWJLjO2dLtP94nqoOdHJo6MdLLKzg412Q==} + storybook: ^9.1.13 '@storybook/global@5.0.0': resolution: {integrity: sha512-FcOqPAXACP0I3oJ/ws6/rrPT9WGhu915Cg8D02a9YxLo0DE9zI+a9A5gRGvmQ09fiWPukqI8ZAEoQEdWUKMQdQ==} - '@storybook/icons@1.4.0': - resolution: {integrity: sha512-Td73IeJxOyalzvjQL+JXx72jlIYHgs+REaHiREOqfpo3A2AYYG71AUbcv+lg7mEDIweKVCxsMQ0UKo634c8XeA==} + '@storybook/icons@1.6.0': + resolution: {integrity: sha512-hcFZIjW8yQz8O8//2WTIXylm5Xsgc+lW9ISLgUk1xGmptIJQRdlhVIXCpSyLrQaaRiyhQRaVg7l3BD9S216BHw==} engines: {node: '>=14.0.0'} peerDependencies: react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0-beta react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0-beta - '@storybook/instrumenter@8.5.0': - resolution: {integrity: sha512-eZ/UY6w4U2vay+wX7QVwKiRoyMzZscuv6v4k4r8BlmHPFWbhiZDO9S2GsG16UkyKnrQrYk432he70n7hn1Xvmg==} + '@storybook/nextjs@9.1.13': + resolution: {integrity: sha512-Vio6+sLkuAGB9C7wai/4wTutYbMylsMjWaDZzGSAra4/Fx3Qk40CK3YiyPzQ5fhkpcONA9amPZ8iM0vLUs1UcQ==} + engines: {node: '>=20.0.0'} peerDependencies: - storybook: ^8.5.0 - - '@storybook/manager-api@8.5.0': - resolution: {integrity: sha512-Ildriueo3eif4M+gMlMxu/mrBIbAnz8+oesmQJKdzZfe/U9eQTI9OUqJsxx/IVBmdzQ3ySsgNmzj5VweRkse4A==} - peerDependencies: - storybook: ^8.2.0 || ^8.3.0-0 || ^8.4.0-0 || ^8.5.0-0 || ^8.6.0-0 - - '@storybook/nextjs@8.5.0': - resolution: {integrity: sha512-zUU0wQd4F2p006gZX0XC3+Zsj0tB4DOz+7FjSlnyGbzf5cDE6cD74l0Azj6aZluR4Q2say7gWDIpHu05YvIJsg==} - engines: {node: '>=18.0.0'} - peerDependencies: - next: ^13.5.0 || ^14.0.0 || ^15.0.0 + next: ^14.1.0 || ^15.0.0 react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0-beta react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0-beta - storybook: ^8.5.0 + storybook: ^9.1.13 typescript: '*' webpack: ^5.0.0 peerDependenciesMeta: @@ -3064,61 +3132,43 @@ packages: webpack: optional: true - '@storybook/preset-react-webpack@8.5.0': - resolution: {integrity: sha512-KJwVcQVYQWuMT5QUF06be60UuBfazBIO+90erfoYoIx0UwOxKMVnQz0HfG2JMc4EIoNLIl0/cm5mb2k4BWyhbA==} - engines: {node: '>=18.0.0'} + '@storybook/preset-react-webpack@9.1.13': + resolution: {integrity: sha512-2bWRdGSYvXWaE1QnrKFeE7EbTj+/Y0D8DHZ/OlKCB3xtNM7koMDrTnnI27hVlMjXqcX8RvOwb/N31FGBRgkiNg==} + engines: {node: '>=20.0.0'} peerDependencies: react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0-beta react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0-beta - storybook: ^8.5.0 + storybook: ^9.1.13 typescript: '*' peerDependenciesMeta: typescript: optional: true - '@storybook/preview-api@8.5.0': - resolution: {integrity: 
sha512-g0XbD54zMUkl6bpuA7qEBCE9rW1QV6KKmwkO4bkxMOJcMke3x9l00JTaYn7Un8wItjXiS3BIG15B6mnfBG7fng==} - peerDependencies: - storybook: ^8.2.0 || ^8.3.0-0 || ^8.4.0-0 || ^8.5.0-0 || ^8.6.0-0 - '@storybook/react-docgen-typescript-plugin@1.0.6--canary.9.0c3f3b7.0': resolution: {integrity: sha512-KUqXC3oa9JuQ0kZJLBhVdS4lOneKTOopnNBK4tUAgoxWQ3u/IjzdueZjFr7gyBrXMoU6duutk3RQR9u8ZpYJ4Q==} peerDependencies: typescript: '>= 4.x' webpack: '>= 4' - '@storybook/react-dom-shim@8.5.0': - resolution: {integrity: sha512-7P8xg4FiuFpM6kQOzZynno+0zyLVs8NgsmRK58t3JRZXbda1tzlxTXzvqx4hUevvbPJGjmrB0F3xTFH+8Otnvw==} + '@storybook/react-dom-shim@9.1.13': + resolution: {integrity: sha512-/tMr9TmV3+98GEQO0S03k4gtKHGCpv9+k9Dmnv+TJK3TBz7QsaFEzMwe3gCgoTaebLACyVveDiZkWnCYAWB6NA==} peerDependencies: react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0-beta react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0-beta - storybook: ^8.5.0 + storybook: ^9.1.13 - '@storybook/react@8.5.0': - resolution: {integrity: sha512-/jbkmGGc95N7KduIennL/k8grNTP5ye/YBnkcS4TbF7uDWBtKy3/Wqvx5BIlFXq3qeUnZJ8YtZc0lPIYeCY8XQ==} - engines: {node: '>=18.0.0'} + '@storybook/react@9.1.13': + resolution: {integrity: sha512-B0UpYikKf29t8QGcdmumWojSQQ0phSDy/Ne2HYdrpNIxnUvHHUVOlGpq4lFcIDt52Ip5YG5GuAwJg3+eR4LCRg==} + engines: {node: '>=20.0.0'} peerDependencies: - '@storybook/test': 8.5.0 react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0-beta react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0-beta - storybook: ^8.5.0 - typescript: '>= 4.2.x' + storybook: ^9.1.13 + typescript: '>= 4.9.x' peerDependenciesMeta: - '@storybook/test': - optional: true typescript: optional: true - '@storybook/test@8.5.0': - resolution: {integrity: sha512-M/DdPlI6gwL7NGkK5o7GYjdEBp95AsFEUtW29zQfnVIAngYugzi3nIuM/XkQHunidVdAZCYjw2s2Yhhsx/m9sw==} - peerDependencies: - storybook: ^8.5.0 - - '@storybook/theming@8.5.0': - resolution: {integrity: sha512-591LbOj/HMmHYUfLgrMerxhF1A9mY61HWKxcRpB6xxalc1Xw1kRtQ49DcwuTXnUu9ktBB3nuOzPNPQPFSh/7PQ==} - peerDependencies: - storybook: ^8.2.0 || ^8.3.0-0 || ^8.4.0-0 || ^8.5.0-0 || ^8.6.0-0 - '@stylistic/eslint-plugin@5.2.2': resolution: {integrity: sha512-bE2DUjruqXlHYP3Q2Gpqiuj2bHq7/88FnuaS0FjeGGLCy+X6a07bGVuwtiOYnPSLHR6jmx5Bwdv+j7l8H+G97A==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} @@ -3200,10 +3250,6 @@ packages: resolution: {integrity: sha512-pemlzrSESWbdAloYml3bAJMEfNh1Z7EduzqPKprCH5S341frlpYnUEW0H72dLxa6IsYr+mPno20GiSm+h9dEdQ==} engines: {node: '>=18'} - '@testing-library/jest-dom@6.5.0': - resolution: {integrity: sha512-xGGHpBXYSHUUr6XsKBfs85TWlYKpTc37cSBBVrXcib2MkHLboWlkClhWF37JKlDb9KEq3dHs+f2xR7XJEWGBxA==} - engines: {node: '>=14', npm: '>=6', yarn: '>=1'} - '@testing-library/jest-dom@6.8.0': resolution: {integrity: sha512-WgXcWzVM6idy5JaftTVC8Vs83NKRmGJz4Hqs4oyOuO2J4r/y79vvKZsb+CaGyCSEbUPI6OsewfPd0G1A0/TUZQ==} engines: {node: '>=14', npm: '>=6', yarn: '>=1'} @@ -3223,8 +3269,8 @@ packages: '@types/react-dom': optional: true - '@testing-library/user-event@14.5.2': - resolution: {integrity: sha512-YAh82Wh4TIrxYLmfGcixwD18oIjyC1pFQC2Y01F2lzV2HTMiYrI0nze0FD0ocB//CKS/7jIUgae+adPqxK5yCQ==} + '@testing-library/user-event@14.6.1': + resolution: {integrity: sha512-vq7fv0rnt+QTXgPxr5Hjc210p6YKq2kmdziLgnsZGgLJ9e6VAShx1pACLuRjd/AS/sr7phAR58OIIpf0LlmQNw==} engines: {node: '>=12', npm: '>=6'} peerDependencies: '@testing-library/dom': '>=7.21.4' @@ -3262,6 +3308,9 @@ packages: '@types/cacheable-request@6.0.3': resolution: {integrity: sha512-IQ3EbTzGxIigb1I3qPZc1rWJnH0BmSKv5QYTalEwweFvyBDLSAe24zP0le/hyi7ecGfZVlIVAg4BZqb8WBwKqw==} + '@types/chai@5.2.2': + 
resolution: {integrity: sha512-8kB30R7Hwqf40JPiKhVzodJs2Qc1ZJ5zuT3uzw5Hq/dhNCl3G3l83jfpdI1e20BP348+fV7VIL/+FxaXkqBmWg==} + '@types/d3-array@3.2.1': resolution: {integrity: sha512-Y2Jn2idRrLzUfAKV2LyRImR+y4oa2AntrgID95SHJxuMUrkNXmanDSed71sRNZysveJVt1hLLemQZIady0FpEg==} @@ -3358,6 +3407,9 @@ packages: '@types/debug@4.1.12': resolution: {integrity: sha512-vIChWdVG3LG1SMxEvI/AK+FWJthlrqlTu7fbrlywTkkaONwk/UAGaULXRlf8vkzFBLVm0zkMdCquhL5aOjhXPQ==} + '@types/deep-eql@4.0.2': + resolution: {integrity: sha512-c9h9dVVMigMPc4bwTvC5dxqtqJZwQPePsWjPlpSOnojbor6pGqdk541lfA7AqFQr5pB1BRdq0juY9db81BwyFw==} + '@types/doctrine@0.0.9': resolution: {integrity: sha512-eOIHzCUSH7SMfonMG1LsC2f8vxBFtho6NGBznK41R84YzPuvSBzrhEps33IsQiOW9+VL6NQ9DbjQJznk/S4uRA==} @@ -3505,9 +3557,6 @@ packages: '@types/uuid@10.0.0': resolution: {integrity: sha512-7gqG38EyHgyP1S+7+xomFtL+ZNHcKv6DwNaCZmJmo1vgMugyF3TCnXVg4t1uk89mLNwnLtnY3TpOpCOyp1/xHQ==} - '@types/uuid@9.0.8': - resolution: {integrity: sha512-jg+97EGIcY9AGHJJRaaPVgetKDsrTgbRjQ5Msgjh/DQKEFl0DtyRr/VCOyD1T2R1MNeWPK/u7JoGhlDZnKBAfA==} - '@types/whatwg-mimetype@3.0.2': resolution: {integrity: sha512-c2AKvDT8ToxLIOUlN51gTiHXflsfIFisS4pO7pDPoKouJCESkhZnEy623gwP9laCy5lnLDAw1vAzu2vM2YLOrA==} @@ -3672,23 +3721,28 @@ packages: vitest: optional: true - '@vitest/expect@2.0.5': - resolution: {integrity: sha512-yHZtwuP7JZivj65Gxoi8upUN2OzHTi3zVfjwdpu2WrvCZPLwsJ2Ey5ILIPccoW23dd/zQBlJ4/dhi7DWNyXCpA==} + '@vitest/expect@3.2.4': + resolution: {integrity: sha512-Io0yyORnB6sikFlt8QW5K7slY4OjqNX9jmJQ02QDda8lyM6B5oNgVWoSoKPac8/kgnCUzuHQKrSLtu/uOqqrig==} - '@vitest/pretty-format@2.0.5': - resolution: {integrity: sha512-h8k+1oWHfwTkyTkb9egzwNMfJAEx4veaPSnMeKbVSjp4euqGSbQlm5+6VHwTr7u4FJslVVsUG5nopCaAYdOmSQ==} + '@vitest/mocker@3.2.4': + resolution: {integrity: sha512-46ryTE9RZO/rfDd7pEqFl7etuyzekzEhUbTW3BvmeO/BcCMEgq59BKhek3dXDWgAj4oMK6OZi+vRr1wPW6qjEQ==} + peerDependencies: + msw: ^2.4.9 + vite: 6.2.7 + peerDependenciesMeta: + msw: + optional: true + vite: + optional: true - '@vitest/pretty-format@2.1.9': - resolution: {integrity: sha512-KhRIdGV2U9HOUzxfiHmY8IFHTdqtOhIzCpd8WRdJiE7D/HUcZVD0EgQCVjm+Q9gkUXWgBvMmTtZgIG48wq7sOQ==} + '@vitest/pretty-format@3.2.4': + resolution: {integrity: sha512-IVNZik8IVRJRTr9fxlitMKeJeXFFFN0JaB9PHPGQ8NKQbGpfjlTx9zO4RefN8gp7eqjNy8nyK3NZmBzOPeIxtA==} - '@vitest/spy@2.0.5': - resolution: {integrity: sha512-c/jdthAhvJdpfVuaexSrnawxZz6pywlTPe84LUB2m/4t3rl2fTo9NFGBG4oWgaD+FTgDDV8hJ/nibT7IfH3JfA==} + '@vitest/spy@3.2.4': + resolution: {integrity: sha512-vAfasCOe6AIK70iP5UD11Ac4siNUNJ9i/9PZ3NKx07sG6sUxeag1LWdNrMWeKKYBLlzuK+Gn65Yd5nyL6ds+nw==} - '@vitest/utils@2.0.5': - resolution: {integrity: sha512-d8HKbqIcya+GR67mkZbrzhS5kKhtp8dQLcmRZLGTscGVg7yImT82cIrhtn2L8+VujWcy6KZweApgNmPsTAO/UQ==} - - '@vitest/utils@2.1.9': - resolution: {integrity: sha512-v0psaMSkNJ3A2NMrUEHFRzJtDPFn+/VWZ5WxImB21T9fjucJRmS7xCS3ppEnARb9y11OAzaD+P2Ps+b+BGX5iQ==} + '@vitest/utils@3.2.4': + resolution: {integrity: sha512-fB2V0JFrQSMsCo9HiSq3Ezpdv4iYaXRG1Sx8edX3MwxfyNn83mKiGzOcH+Fkxt4MHxr3y42fQi1oeAInqgX2QA==} '@vue/compiler-core@3.5.17': resolution: {integrity: sha512-Xe+AittLbAyV0pabcN7cP7/BenRBNcteM4aSDCtRvGw0d9OL+HG1u/XHLY/kt1q4fyMeZYXyIYrsHuPSiDPosA==} @@ -4020,6 +4074,10 @@ packages: base64-js@1.5.1: resolution: {integrity: sha512-AKpaYlHn8t4SVbOHCy+b5+KKgvR4vrsD8vbvrbiQJps7fKDTkjkDry6ji0rUJjC0kzbNePLwzxq8iypo41qeWA==} + baseline-browser-mapping@2.8.18: + resolution: {integrity: sha512-UYmTpOBwgPScZpS4A+YbapwWuBwasxvO/2IOHArSsAhL/+ZdmATBXTex3t+l2hXwLVYK382ibr/nKoY9GKe86w==} + hasBin: true + 
before-after-hook@3.0.2: resolution: {integrity: sha512-Nik3Sc0ncrMK4UUdXQmAnRtzmNQTAAXmXIopizwZ1W1t8QmfJj+zL4OA2I7XPTPW5z5TDqv4hRo/JzouDJnX3A==} @@ -4059,9 +4117,6 @@ packages: brorand@1.1.0: resolution: {integrity: sha512-cKV8tMCEpQs4hK/ik71d6LrPOnpkpGBR0wzxqr68g2m/LB2GxVYQroAjMJZRVM1Y4BCjCKc3vAamxSzOY2RP+w==} - browser-assert@1.2.1: - resolution: {integrity: sha512-nfulgvOR6S4gt9UKCeGJOuSGBPGiFT6oQ/2UBnvTY/5aQ1PnksW72fhZkM30DzoRRv2WpwZf1vHHEr3mtuXIWQ==} - browserify-aes@1.2.0: resolution: {integrity: sha512-+7CHXqGuspUn/Sl5aO7Ea0xWGAtETPXNSAjHo48JfLdPWcMng33Xe4znFvQweqc/uzk5zSOI3H52CYnjCfb5hA==} @@ -4087,6 +4142,11 @@ packages: engines: {node: ^6 || ^7 || ^8 || ^9 || ^10 || ^11 || ^12 || >=13.7} hasBin: true + browserslist@4.26.3: + resolution: {integrity: sha512-lAUU+02RFBuCKQPj/P6NgjlbCnLBMp4UtgTx7vNHd3XSIJF87s9a5rA3aH2yw3GS9DqZAUbOtZdCCiZeVRqt0w==} + engines: {node: ^6 || ^7 || ^8 || ^9 || ^10 || ^11 || ^12 || >=13.7} + hasBin: true + bser@2.1.1: resolution: {integrity: sha512-gQxTNE/GAfIIrmHLUE3oJyp5FO6HRBfhjnw4/wMmA63ZGDJnWBmgY/lyQBpnDUkGmAhbSe39tx2d/iTOAfglwQ==} @@ -4151,6 +4211,9 @@ packages: caniuse-lite@1.0.30001746: resolution: {integrity: sha512-eA7Ys/DGw+pnkWWSE/id29f2IcPHVoE8wxtvE5JdvD2V28VTDPy1yEeo11Guz0sJ4ZeGRcm3uaTcAqK1LXaphA==} + caniuse-lite@1.0.30001751: + resolution: {integrity: sha512-A0QJhug0Ly64Ii3eIqHu5X51ebln3k4yTUkY1j8drqpWHVreg/VLijN48cZ1bYPiqOQuqpkIKnzr/Ul8V+p6Cw==} + canvas@2.11.2: resolution: {integrity: sha512-ItanGBMrmRV7Py2Z+Xhs7cT+FNt5K0vPL4p9EZ/UX/Mu7hFbkxSjKF2KVtPwX7UYWp7dRKnrTvReflgrItJbdw==} engines: {node: '>=6'} @@ -4166,10 +4229,6 @@ packages: resolution: {integrity: sha512-5nFxhUrX0PqtyogoYOA8IPswy5sZFTOsBFl/9bNsmDLgsxYTzSZQJDPppDnZPTQbzSEm0hqGjWPzRemQCYbD6A==} engines: {node: '>=18'} - chalk@3.0.0: - resolution: {integrity: sha512-4D3B6Wf41KOYRFdszmDqMCGq5VV/uMAB273JILmO+3jAlh8X4qDtdtgCR3fxtbLEMzSx22QdhnDcJvu2u1fVwg==} - engines: {node: '>=8'} - chalk@4.1.1: resolution: {integrity: sha512-diHzdDKxcU+bAsUboHLPEDQiw0qEe0qd7SYUn3HgcFlWgbDcfLGswOHYeGrHKzG9z6UYf01d9VFMfZxPM1xZSg==} engines: {node: '>=10'} @@ -4234,8 +4293,8 @@ packages: resolution: {integrity: sha512-bIomtDF5KGpdogkLd9VspvFzk9KfpyyGlS8YFVZl7TGPBHL5snIOnxeshwVgPteQ9b4Eydl+pVbIyE1DcvCWgQ==} engines: {node: '>=10'} - chromatic@11.29.0: - resolution: {integrity: sha512-yisBlntp9hHVj19lIQdpTlcYIXuU9H/DbFuu6tyWHmj6hWT2EtukCCcxYXL78XdQt1vm2GfIrtgtKpj/Rzmo4A==} + chromatic@12.2.0: + resolution: {integrity: sha512-GswmBW9ZptAoTns1BMyjbm55Z7EsIJnUvYKdQqXIBZIKbGErmpA+p4c0BYA+nzw5B0M+rb3Iqp1IaH8TFwIQew==} hasBin: true peerDependencies: '@chromatic-com/cypress': ^0.*.* || ^1.0.0 @@ -4771,8 +4830,8 @@ packages: resolution: {integrity: sha512-vEtk+OcP7VBRtQZ1EJ3bdgzSfBjgnEalLTp5zjJrS+2Z1w2KZly4SBdac/WDU3hhsNAZ9E8SC96ME4Ey8MZ7cg==} engines: {node: '>=8'} - detect-libc@2.1.1: - resolution: {integrity: sha512-ecqj/sy1jcK1uWrwpR67UhYrIFQ+5WlGxth34WquCbamhFA6hkkwiu37o6J5xCHdo1oixJRfVRw+ywV+Hq/0Aw==} + detect-libc@2.1.2: + resolution: {integrity: sha512-Btj2BOOO83o3WyH59e8MgXsxEQVcarkUOpEYrubB0urwnN10yQ364rsiByU11nZlqWYZm05i/of7io4mzihBtQ==} engines: {node: '>=8'} detect-newline@3.1.0: @@ -4869,6 +4928,9 @@ packages: electron-to-chromium@1.5.186: resolution: {integrity: sha512-lur7L4BFklgepaJxj4DqPk7vKbTEl0pajNlg2QjE5shefmlmBLm2HvQ7PMf1R/GvlevT/581cop33/quQcfX3A==} + electron-to-chromium@1.5.237: + resolution: {integrity: sha512-icUt1NvfhGLar5lSWH3tHNzablaA5js3HVHacQimfP8ViEBOQv+L7DKEuHdbTZ0SKCO1ogTJTIL1Gwk9S6Qvcg==} + elkjs@0.9.3: resolution: {integrity: 
sha512-f/ZeWvW/BCXbhGEf1Ujp29EASo/lk1FDnETgNKwJrsVvGZhUWCZyg3xLJjAsxfOmt8KjswHmI5EwCQcPMpOYhQ==} @@ -5146,11 +5208,12 @@ packages: peerDependencies: eslint: ^8.0.0 || ^9.0.0 - eslint-plugin-storybook@9.0.7: - resolution: {integrity: sha512-da9oIFo2ww+/PWAsTrpeEPUmhel6Ej1++SwBvdf+SV0H6+rOPbzJGOh367hdOvkwKCbGdKRmw+JmXFCQfHCpqw==} - engines: {node: '>= 18'} + eslint-plugin-storybook@9.1.13: + resolution: {integrity: sha512-kPuhbtGDiJLB5OLZuwFZAxgzWakNDw64sJtXUPN8g0+VAeXfHyZEmsE28qIIETHxtal71lPKVm8QNnERaJHPJQ==} + engines: {node: '>=20.0.0'} peerDependencies: eslint: '>=8' + storybook: ^9.1.13 eslint-plugin-tailwindcss@3.18.2: resolution: {integrity: sha512-QbkMLDC/OkkjFQ1iz/5jkMdHfiMu/uwujUHLAJK5iwNHD8RTxVTlsUezE0toTZ6VhybNBsk+gYGPDq2agfeRNA==} @@ -5416,6 +5479,10 @@ packages: resolution: {integrity: sha512-v2ZsoEuVHYy8ZIlYqwPe/39Cy+cFDzp4dXPaxNvkEuouymu+2Jbz0PxpKarJHYJTmv2HWT3O382qY8l4jMWthw==} engines: {node: ^12.20.0 || ^14.13.1 || >=16.0.0} + find-up@7.0.0: + resolution: {integrity: sha512-YyZM99iHrqLKjmt4LJDj58KI+fYyufRLBSYcqycxf//KpBk9FoewoGX0450m9nB44qrZnovzC2oeP5hUibxc/g==} + engines: {node: '>=18'} + flat-cache@3.2.0: resolution: {integrity: sha512-CYcENa+FtcUKLmhhqyctpclsq7QF38pKjZHsGNiSQF5r4FtoKDWabFDl3hzaEQMvT1LHEysw5twgLvpYYb4vbw==} engines: {node: ^10.12.0 || >=12.0.0} @@ -5683,17 +5750,11 @@ packages: html-void-elements@3.0.0: resolution: {integrity: sha512-bEqo66MRXsUGxWHV5IP0PUiAWwoEjba4VCzg0LjFJBpchPaTfyfCKTG6bc5F8ucKec3q5y6qOdGyYTSBEvhCrg==} - html-webpack-plugin@5.6.3: - resolution: {integrity: sha512-QSf1yjtSAsmf7rYBV7XX86uua4W/vkhIt0xNXKbsi2foEeW7vjJQz4bhnpL3xH+l1ryl1680uNv968Z+X6jSYg==} + html-webpack-plugin@5.5.4: + resolution: {integrity: sha512-3wNSaVVxdxcu0jd4FpQFoICdqgxs4zIQQvj+2yQKFfBOnLETQ6X5CDWdeasuGlSsooFlMkEioWDTqBv1wvw5Iw==} engines: {node: '>=10.13.0'} peerDependencies: - '@rspack/core': 0.x || 1.x webpack: ^5.20.0 - peerDependenciesMeta: - '@rspack/core': - optional: true - webpack: - optional: true htmlparser2@6.1.0: resolution: {integrity: sha512-gyyPk6rgonLFEDGoeRgQNaEUvdJ4ktTmmUh/h2t7s+M8oPpIPxgNACWa+6ESR57kXstwqPiCut0V8NRpcwgU7A==} @@ -5755,8 +5816,8 @@ packages: resolution: {integrity: sha512-Hs59xBNfUIunMFgWAbGX5cq6893IbWg4KnrjbYwX3tx0ztorVgTDA6B2sxf8ejHJ4wz8BqGUMYlnzNBer5NvGg==} engines: {node: '>= 4'} - image-size@1.2.1: - resolution: {integrity: sha512-rH+46sQJ2dlwfjfhCyNx5thzrv+dtmBIhPHk0zgRUukHzZ/kRueTJXoYYsclBaKcSMBWuGbOFXtioLpzTb5euw==} + image-size@2.0.2: + resolution: {integrity: sha512-IRqXKlaXwgSMAMtpNzZa1ZAe8m+Sa1770Dhk8VkSsP9LS+iHD62Zd8FQKs8fbPiagBE7BzoFX23cxFnwshpV6w==} engines: {node: '>=16.x'} hasBin: true @@ -6359,9 +6420,6 @@ packages: magic-string@0.25.9: resolution: {integrity: sha512-RmF0AsMzgt25qzqqLc1+MbHmhdx0ojF2Fvs4XnOqz2ZOBXzzkEwc/dJQZCYHAn7v1jbVOjAZfK8msRn4BxO4VQ==} - magic-string@0.30.17: - resolution: {integrity: sha512-sNPKHvyjVf7gyjwS4xGTaW/mCnF8wnjtifKBEhxfZ7E/S8tQ0rssrwGNn6q8JH/ohItJfSQp9mBtQYuTlH5QnA==} - magic-string@0.30.19: resolution: {integrity: sha512-2N21sPY9Ws53PZvsEpVtNuSW+ScYbQdp4b9qUaL+9QkHUrGFKo56Lg9Emg5s9V/qrtNBmiR01sYhUOwu3H+VOw==} @@ -6382,9 +6440,6 @@ packages: makeerror@1.0.12: resolution: {integrity: sha512-JmqCvUhmt43madlpFzG4BQzG2Z3m6tvQDNKdClZnO3VbIudJYmxsT0FNJMeiB2+JTSlTQTSbU8QdesVmwJcmLg==} - map-or-similar@1.5.0: - resolution: {integrity: sha512-0aF7ZmVon1igznGI4VS30yugpduQW3y3GkcgGJOp7d8x8QrizhigUxjI/m2UojsXXto+jLAH3KSz+xOJTiORjg==} - markdown-extensions@2.0.0: resolution: {integrity: sha512-o5vL7aDWatOTX8LzaS1WMoaoxIiLRQJuIKKe2wAw6IeULDHaqbiqiggmx+pKvZDb1Sj+pE46Sn1T7lCqfFtg1Q==} engines: 
{node: '>=16'} @@ -6464,9 +6519,6 @@ packages: memoize-one@5.2.1: resolution: {integrity: sha512-zYiwtZUcYyXKo/np96AGZAckk+FWWsUdJ3cHGGmld7+AhvcWmQyGCYUh1hc4Q/pkOhb65dQR/pqCyK0cOaHz4Q==} - memoizerific@1.11.3: - resolution: {integrity: sha512-/EuHYwAPdLtXwAwSZkh/Gutery6pD2KYd44oQLhAvQp/50mpyduZh8Q7PYHXTCJ+wuXxt7oij2LXyIJOOYFPog==} - merge-stream@2.0.0: resolution: {integrity: sha512-abv/qOcuPfk3URPfDzmZU1LKmuw8kT+0nIHvKrKgFrwifol/doWcdA4ZqsWQ8ENrFKkd67Mfpo/LovbIUsbt3w==} @@ -6788,6 +6840,9 @@ packages: node-releases@2.0.19: resolution: {integrity: sha512-xxOWJsBKtzAq7DY0J+DTzuz58K8e7sJbdgwkbMWQe8UYB6ekmsQ45q0M/tJDsGaZmbC+l7n57UV8Hl5tHxO9uw==} + node-releases@2.0.25: + resolution: {integrity: sha512-4auku8B/vw5psvTiiN9j1dAOsXvMoGqJuKJcR+dTdqiXEK20mMTk1UEo3HS16LeGQsVG6+qKTPM9u/qQ2LqATA==} + nopt@5.0.0: resolution: {integrity: sha512-Tbj67rffqceeLpcRXrT7vKAN8CwfPeIBgM7E6iBkmKLV7bEMwpGgYLGv0jACUsECaa/vuxP0IjEont6umdMgtQ==} engines: {node: '>=6'} @@ -7071,10 +7126,6 @@ packages: resolution: {integrity: sha512-Nc3IT5yHzflTfbjgqWcCPpo7DaKy4FnpB0l/zCAW0Tc7jxAiuqSxHasntB3D7887LSrA93kDJ9IXovxJYxyLCA==} engines: {node: '>=4'} - pnp-webpack-plugin@1.7.0: - resolution: {integrity: sha512-2Rb3vm+EXble/sMXNSu6eoBx8e79gKqhNq9F5ZWW6ERNCTE/Q0wQNne5541tE5vKjfM8hpNCYL+LGc1YTfI0dg==} - engines: {node: '>=6'} - pnpm-workspace-yaml@1.1.0: resolution: {integrity: sha512-OWUzBxtitpyUV0fBYYwLAfWxn3mSzVbVB7cwgNaHvTTU9P0V2QHjyaY5i7f1hEiT9VeKsNH1Skfhe2E3lx/zhA==} @@ -7084,10 +7135,6 @@ packages: points-on-path@0.2.1: resolution: {integrity: sha512-25ClnWWuw7JbWZcgqY/gJ4FQWadKxGWk+3kR/7kD0tCaDtPPMj7oHu2ToLaVhfpnHrZzYby2w6tUA0eOIuUg8g==} - polished@4.3.1: - resolution: {integrity: sha512-OBatVyC/N7SCW/FaDHrSd+vn0o5cS855TOmYi4OkdWUMSJCET/xip//ch8xGUvtr3i44X9LVyWwQlRMTN3pwSA==} - engines: {node: '>=10'} - portfinder@1.0.37: resolution: {integrity: sha512-yuGIEjDAYnnOex9ddMnKZEMFE0CcGo6zbfzDklkmT1m5z734ss6JMzN9rNB3+RR7iS+F10D4/BVIaXOyh8PQKw==} engines: {node: '>= 10.12'} @@ -7266,9 +7313,6 @@ packages: queue-microtask@1.2.3: resolution: {integrity: sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==} - queue@6.0.2: - resolution: {integrity: sha512-iHZWu+q3IdFZFX36ro/lKBkSvfkztY5Y7HMiPlOUjhupPcG2JMfst2KKEpu5XndviX/3UhFbRngUPNKtgvtZiA==} - quick-lru@5.1.1: resolution: {integrity: sha512-WuyALRjWPDGtt/wzJiadO5AXY+8hZ80hVpe6MyivgraREW751X3SbhRvG3eLKOYN+8VEvqLcf3wdnt44Z4S4SA==} engines: {node: '>=10'} @@ -7294,12 +7338,6 @@ packages: peerDependencies: react: ^16.3.0 || ^17.0.0 || ^18.0.0 - react-confetti@6.4.0: - resolution: {integrity: sha512-5MdGUcqxrTU26I2EU7ltkWPwxvucQTuqMm8dUz72z2YMqTD6s9vMcDUysk7n9jnC+lXuCPeJJ7Knf98VEYE9Rg==} - engines: {node: '>=16'} - peerDependencies: - react: ^16.3.0 || ^17.0.1 || ^18.0.0 || ^19.0.0 - react-docgen-typescript@2.4.0: resolution: {integrity: sha512-ZtAp5XTO5HRzQctjPU0ybY0RRCQO19X/8fxn3w7y2VVTUbGHDKULPTL4ky3vB05euSgG5NpALhEhDPvQ56wvXg==} peerDependencies: @@ -7309,11 +7347,6 @@ packages: resolution: {integrity: sha512-hlSJDQ2synMPKFZOsKo9Hi8WWZTC7POR8EmWvTSjow+VDgKzkmjQvFm2fk0tmRw+f0vTOIYKlarR0iL4996pdg==} engines: {node: '>=16.14.0'} - react-dom@18.3.1: - resolution: {integrity: sha512-5m4nQKp+rZRb09LNH59GM4BxTh9251/ylbKIbpe7TpGxfJ+9kv6BLkLBXIjjspbgbnIBNqlI23tRnTWT0snUIw==} - peerDependencies: - react: ^18.3.1 - react-dom@19.1.1: resolution: {integrity: sha512-Dlq/5LAZgF0Gaz6yiqZCf6VCcZs1ghAJyrsu84Q/GT0gV+mCxbfmKNoGRKBYMJ8IEdGPqu49YWXD02GCknEDkw==} peerDependencies: @@ -7469,10 +7502,6 @@ packages: react: ^15.0.0 || ^16.0.0 || ^17.0.0 || 
^18.0.0 || ^19.0.0 react-dom: ^15.0.0 || ^16.0.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 - react@18.3.1: - resolution: {integrity: sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==} - engines: {node: '>=0.10.0'} - react@19.1.1: resolution: {integrity: sha512-w8nqGImo45dmMIfljjMwOGtbmC/mk4CMYhWIicdSflH91J9TyCyczcPFXJzrZ/ZXcgGRFeP6BU0BEJTw6tZdfQ==} engines: {node: '>=0.10.0'} @@ -7684,6 +7713,11 @@ packages: engines: {node: '>=10.0.0'} hasBin: true + rollup@4.52.5: + resolution: {integrity: sha512-3GuObel8h7Kqdjt0gxkEzaifHTqLVW56Y/bjN7PSQtkKr0w3V/QYSdt6QWYtd7A1xUtYQigtdUfgj1RvWVtorw==} + engines: {node: '>=18.0.0', npm: '>=8.0.0'} + hasBin: true + roughjs@4.6.6: resolution: {integrity: sha512-ZUz/69+SYpFN/g/lUlo2FXcIjRkSu3nDarreVdGGndHEBJ6cXPdKguS8JGxwj5HA5xIbVKSmLgr5b3AWxtRfvQ==} @@ -7693,8 +7727,11 @@ packages: rw@1.3.3: resolution: {integrity: sha512-PdhdWy89SiZogBLaw42zdeqtRJ//zFd2PgQavcICDUgJT5oW10QCRKbJ6bg4r0/UY2M6BWd5tkxuGFRvCkgfHQ==} - sass-loader@14.2.1: - resolution: {integrity: sha512-G0VcnMYU18a4N7VoNDegg2OuMjYtxnqzQWARVWCIVSZwJeiL9kg8QMsuIZOplsJgTzZLF6jGxI3AClj8I9nRdQ==} + safe-buffer@5.2.1: + resolution: {integrity: sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==} + + sass-loader@16.0.5: + resolution: {integrity: sha512-oL+CMBXrj6BZ/zOq4os+UECPL+bWqt6OAC6DWS8Ln8GZRcMDjlJ4JC3FBDuHJdYaFWIdKNIBYmtZtK2MaMkNIw==} engines: {node: '>= 18.12.0'} peerDependencies: '@rspack/core': 0.x || 1.x @@ -7719,9 +7756,6 @@ packages: engines: {node: '>=14.0.0'} hasBin: true - scheduler@0.23.2: - resolution: {integrity: sha512-UOShsPwz7NrMUqhR6t0hWjFduvOzbtv7toDH1/hIrfRNIDBnnBWd0CwJTGvTpngVlmwGCdP9/Zl/tVrDqcuYzQ==} - scheduler@0.26.0: resolution: {integrity: sha512-NlHwttCI/l5gCPR3D1nNXtWABUmBwvZpEQiD4IXSbIDq8BzLIK/7Ir5gTFSGZDUu37K5cMNp0hFtzO38sC7gWA==} @@ -7737,6 +7771,10 @@ packages: resolution: {integrity: sha512-Gn/JaSk/Mt9gYubxTtSn/QCV4em9mpAPiR1rqy/Ocu19u/G9J5WWdNoUT4SiV6mFC3y6cxyFcFwdzPM3FgxGAQ==} engines: {node: '>= 10.13.0'} + schema-utils@4.3.3: + resolution: {integrity: sha512-eflK8wEtyOE6+hsaRVPxvUKYCpRgzLqDTb8krvAsRIwOGlHoSgYLgBXoubGgLd2fT41/OUYdb48v4k4WWHQurA==} + engines: {node: '>= 10.13.0'} + screenfull@5.2.0: resolution: {integrity: sha512-9BakfsO2aUQN2K9Fdbj87RJIEZ82Q9IGim7FqM5OsebfoFC6ZHXgDq/KvniuLTPdeM8wY2o6Dj3WQ7KeQCj3cA==} engines: {node: '>=0.10.0'} @@ -7894,8 +7932,8 @@ packages: state-local@1.0.7: resolution: {integrity: sha512-HTEHMNieakEnoe33shBYcZ7NX83ACUjCu8c40iOGEZsngj9zRnkqS9j1pqQPXwobB0ZcVTk27REb7COQ0UR59w==} - storybook@8.5.0: - resolution: {integrity: sha512-cEx42OlCetManF+cONVJVYP7SYsnI2K922DfWKmZhebP0it0n6TUof4y5/XzJ8YUruwPgyclGLdX8TvdRuNSfw==} + storybook@9.1.13: + resolution: {integrity: sha512-G3KZ36EVzXyHds72B/qtWiJnhUpM0xOUeYlDcO9DSHL1bDTv15cW4+upBl+mcBZrDvU838cn7Bv4GpF+O5MCfw==} hasBin: true peerDependencies: prettier: ^2 || ^3 @@ -8067,6 +8105,10 @@ packages: resolution: {integrity: sha512-Re10+NauLTMCudc7T5WLFLAwDhQ0JWdrMK+9B2M8zR5hRExKmsRDCBA7/aV/pNJFltmBFO5BAMlQFi/vq3nKOg==} engines: {node: '>=6'} + tapable@2.3.0: + resolution: {integrity: sha512-g9ljZiwki/LfxmQADO3dEY1CbpmXT5Hm2fJ+QaGKwSXUylMybePR7/67YW7jOrrvjEgL1Fmz5kzyAjWVWLlucg==} + engines: {node: '>=6'} + tar@6.2.1: resolution: {integrity: sha512-DZ4yORTwrbTj/7MZYq2w+/ZFdI6OZ/f9SFHR+71gIVUZhOQPHzVCLpvRnPgyaMpfWxxk/4ONva3GQSyNIKRv6A==} engines: {node: '>=10'} @@ -8128,12 +8170,12 @@ packages: resolution: {integrity: sha512-tX5e7OM1HnYr2+a2C/4V0htOcSQcoSTH9KgJnVvNm5zm/cyEWKJ7j7YutsH9CxMdtOkkLFy2AHrMci9IM8IPZQ==} 
engines: {node: '>=12.0.0'} - tinyrainbow@1.2.0: - resolution: {integrity: sha512-weEDEq7Z5eTHPDh4xjX789+fHfF+P8boiFB+0vbWzpbnbsEr/GRaohi/uMKxg8RZMXnl1ItAi/IUHWMsjDV7kQ==} + tinyrainbow@2.0.0: + resolution: {integrity: sha512-op4nsTR47R6p0vMUUoYl/a+ljLFVtlfaXkLQmqfLR1qHma1h/ysYk4hEXZ880bf2CYgTskvTa/e196Vd5dDQXw==} engines: {node: '>=14.0.0'} - tinyspy@3.0.2: - resolution: {integrity: sha512-n1cw8k1k0x4pgA2+9XrOkFydTerNcJ1zWCO5Nn9scWHTD+5tp8dghT2x1uduQePZTZgd3Tupf+x9BxJjeJi77Q==} + tinyspy@4.0.4: + resolution: {integrity: sha512-azl+t0z7pw/z958Gy9svOTuzqIk6xq+NSheJzn5MMWtWTFywIacg2wUlzKFGtt3cthx0r2SxMK0yzJOR0IES7Q==} engines: {node: '>=14.0.0'} tldts-core@7.0.10: @@ -8215,15 +8257,6 @@ packages: ts-pattern@5.7.1: resolution: {integrity: sha512-EGs8PguQqAAUIcQfK4E9xdXxB6s2GK4sJfT/vcc9V1ELIvC4LH/zXu2t/5fajtv6oiRCxdv7BgtVK3vWgROxag==} - ts-pnp@1.2.0: - resolution: {integrity: sha512-csd+vJOb/gkzvcCHgTGSChYpy5f1/XKNsmvBGO4JXS+z1v2HobugDz4s1IeFXM3wZB44uczs+eazB5Q/ccdhQw==} - engines: {node: '>=6'} - peerDependencies: - typescript: '*' - peerDependenciesMeta: - typescript: - optional: true - tsconfig-paths-webpack-plugin@4.2.0: resolution: {integrity: sha512-zbem3rfRS8BgeNK50Zz5SIQgXzLafiHjOwUAvk/38/o1jHn/V5QAgVUcz884or7WYcPaH3N2CIfUc2u0ul7UcA==} engines: {node: '>=10.13.0'} @@ -8244,9 +8277,6 @@ packages: tty-browserify@0.0.1: resolution: {integrity: sha512-C3TaO7K81YvjCgQH9Q1S3R3P3BtN3RIM8n+OvX4il1K1zgE8ZhI0op7kClgkxtutIE8hQrcrHBXvIheqKUUCxw==} - tween-functions@1.2.0: - resolution: {integrity: sha512-PZBtLYcCLtEcjL14Fzb1gSxPBeL7nWvGhO5ZFPGqziCcr8uvHp0NDmdjBchp6KHL+tExcg0m3NISmKxhU394dA==} - type-check@0.4.0: resolution: {integrity: sha512-XleUoc9uwGXqjWwXaUTZAmzMcFZ5858QA2vvx1Ur5xIcixXIP+8LnFDgRplU30us6teqdlskFfu+ae4K79Ooew==} engines: {node: '>= 0.8.0'} @@ -8299,6 +8329,10 @@ packages: resolution: {integrity: sha512-6t3foTQI9qne+OZoVQB/8x8rk2k1eVy1gRXhV3oFQ5T6R1dqQ1xtin3XqSlx3+ATBkliTaR/hHyJBm+LVPNM8w==} engines: {node: '>=4'} + unicorn-magic@0.1.0: + resolution: {integrity: sha512-lRfVq8fE8gz6QMBuDM6a+LO3IAzTi05H6gCVaUpir2E1Rwpo4ZUog45KpNXKC/Mn3Yb9UDuHumeFTo9iV/D9FQ==} + engines: {node: '>=18'} + unified@11.0.5: resolution: {integrity: sha512-xKvGhPWw3k84Qjh8bI3ZeJjqnyadK+GEFtazSfZv/rKeTkTjOJho6mFqh2SM96iIcZokxiOpg78GazTSg8+KHA==} @@ -8436,10 +8470,6 @@ packages: resolution: {integrity: sha512-0/A9rDy9P7cJ+8w1c9WD9V//9Wj15Ce2MPz8Ri6032usz+NfePxx5AcN3bN+r6ZL6jEo066/yNYB3tn4pQEx+A==} hasBin: true - uuid@9.0.1: - resolution: {integrity: sha512-b+1eJOlsR9K8HJpow9Ok3fiWOWSIcIzXodvv0rQjVoOVNpWMpxf1wZNpt4y9h10odCNrqnYp1OBzRktckBe3sA==} - hasBin: true - v8-compile-cache-lib@3.0.1: resolution: {integrity: sha512-wa7YjyUGfNZngI/vtK0UHAN+lgDCxBPCylVXGp0zu59Fz5aiGtNXaq3DhIov063MorB+VfufLh3JlF2KdTK3xg==} @@ -8456,6 +8486,46 @@ packages: vfile@6.0.3: resolution: {integrity: sha512-KzIbH/9tXat2u30jf+smMwFCsno4wHVdNmzFyL+T/L3UGqqk6JKfVqOFOZEpZSHADH1k40ab6NUIXZq422ov3Q==} + vite@6.2.7: + resolution: {integrity: sha512-qg3LkeuinTrZoJHHF94coSaTfIPyBYoywp+ys4qu20oSJFbKMYoIJo0FWJT9q6Vp49l6z9IsJRbHdcGtiKbGoQ==} + engines: {node: ^18.0.0 || ^20.0.0 || >=22.0.0} + hasBin: true + peerDependencies: + '@types/node': ^18.0.0 || ^20.0.0 || >=22.0.0 + jiti: '>=1.21.0' + less: '*' + lightningcss: ^1.21.0 + sass: '*' + sass-embedded: '*' + stylus: '*' + sugarss: '*' + terser: ^5.16.0 + tsx: ^4.8.1 + yaml: ^2.4.2 + peerDependenciesMeta: + '@types/node': + optional: true + jiti: + optional: true + less: + optional: true + lightningcss: + optional: true + sass: + optional: true + sass-embedded: + optional: true + stylus: + 
optional: true + sugarss: + optional: true + terser: + optional: true + tsx: + optional: true + yaml: + optional: true + vm-browserify@1.1.2: resolution: {integrity: sha512-2ham8XPWTONajOR0ohOKOHXkm3+gaBmGut3SRuu75xLd/RRaY6vqgh8NBYYk7+RW3u5AtzPQZG8F10LHkl0lAQ==} @@ -8536,8 +8606,8 @@ packages: webpack-virtual-modules@0.6.2: resolution: {integrity: sha512-66/V2i5hQanC51vBQKPH4aI8NMAcBW59FVBs+rC7eGHupMyfn34q7rZIE+ETlJ+XTevqfUhVVBgSUNSW2flEUQ==} - webpack@5.100.2: - resolution: {integrity: sha512-QaNKAvGCDRh3wW1dsDjeMdDXwZm2vqq3zn6Pvq4rHOEOGSaUMgOOjG2Y9ZbIGzpfkJk9ZYTHpDqgDfeBDcnLaw==} + webpack@5.102.1: + resolution: {integrity: sha512-7h/weGm9d/ywQ6qzJ+Xy+r9n/3qgp/thalBbpOi5i223dPXKi04IBtqPN9nTd+jBc7QKfvDbaBnFipYp4sJAUQ==} engines: {node: '>=10.13.0'} hasBin: true peerDependencies: @@ -9203,7 +9273,7 @@ snapshots: '@babel/plugin-transform-class-properties@7.27.1(@babel/core@7.28.3)': dependencies: '@babel/core': 7.28.3 - '@babel/helper-create-class-features-plugin': 7.27.1(@babel/core@7.28.3) + '@babel/helper-create-class-features-plugin': 7.28.3(@babel/core@7.28.3) '@babel/helper-plugin-utils': 7.27.1 transitivePeerDependencies: - supports-color @@ -9381,7 +9451,7 @@ snapshots: '@babel/helper-plugin-utils': 7.27.1 '@babel/plugin-transform-destructuring': 7.28.0(@babel/core@7.28.3) '@babel/plugin-transform-parameters': 7.27.7(@babel/core@7.28.3) - '@babel/traverse': 7.28.0 + '@babel/traverse': 7.28.3 transitivePeerDependencies: - supports-color @@ -9452,7 +9522,7 @@ snapshots: '@babel/helper-module-imports': 7.27.1 '@babel/helper-plugin-utils': 7.27.1 '@babel/plugin-syntax-jsx': 7.27.1(@babel/core@7.28.3) - '@babel/types': 7.28.1 + '@babel/types': 7.28.4 transitivePeerDependencies: - supports-color @@ -9522,7 +9592,7 @@ snapshots: dependencies: '@babel/core': 7.28.3 '@babel/helper-annotate-as-pure': 7.27.3 - '@babel/helper-create-class-features-plugin': 7.27.1(@babel/core@7.28.3) + '@babel/helper-create-class-features-plugin': 7.28.3(@babel/core@7.28.3) '@babel/helper-plugin-utils': 7.27.1 '@babel/helper-skip-transparent-expression-wrappers': 7.27.1 '@babel/plugin-syntax-typescript': 7.27.1(@babel/core@7.28.3) @@ -9728,18 +9798,17 @@ snapshots: '@chevrotain/utils@11.0.3': {} - '@chromatic-com/storybook@3.2.7(react@19.1.1)(storybook@8.5.0)': + '@chromatic-com/storybook@4.1.1(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)))': dependencies: - chromatic: 11.29.0 + '@neoconfetti/react': 1.0.0 + chromatic: 12.2.0 filesize: 10.1.6 jsonfile: 6.1.0 - react-confetti: 6.4.0(react@19.1.1) - storybook: 8.5.0 + storybook: 9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)) strip-ansi: 7.1.0 transitivePeerDependencies: - '@chromatic-com/cypress' - '@chromatic-com/playwright' - - react '@clack/core@0.5.0': dependencies: @@ -10791,12 +10860,12 @@ snapshots: - supports-color optional: true - '@mdx-js/loader@3.1.0(acorn@8.15.0)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3))': + '@mdx-js/loader@3.1.0(acorn@8.15.0)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3))': dependencies: '@mdx-js/mdx': 3.1.0(acorn@8.15.0) source-map: 0.7.4 optionalDependencies: - webpack: 5.100.2(esbuild@0.25.0)(uglify-js@3.19.3) + webpack: 5.102.1(esbuild@0.25.0)(uglify-js@3.19.3) transitivePeerDependencies: - acorn - supports-color @@ -10831,12 +10900,6 @@ snapshots: - acorn - supports-color - '@mdx-js/react@3.1.0(@types/react@19.1.11)(react@18.3.1)': - dependencies: - 
'@types/mdx': 2.0.13 - '@types/react': 19.1.11 - react: 18.3.1 - '@mdx-js/react@3.1.0(@types/react@19.1.11)(react@19.1.1)': dependencies: '@types/mdx': 2.0.13 @@ -10865,6 +10928,8 @@ snapshots: '@tybys/wasm-util': 0.10.1 optional: true + '@neoconfetti/react@1.0.0': {} + '@next/bundle-analyzer@15.5.4': dependencies: webpack-bundle-analyzer: 4.10.1 @@ -10878,11 +10943,11 @@ snapshots: dependencies: fast-glob: 3.3.1 - '@next/mdx@15.5.4(@mdx-js/loader@3.1.0(acorn@8.15.0)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)))(@mdx-js/react@3.1.0(@types/react@19.1.11)(react@19.1.1))': + '@next/mdx@15.5.4(@mdx-js/loader@3.1.0(acorn@8.15.0)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)))(@mdx-js/react@3.1.0(@types/react@19.1.11)(react@19.1.1))': dependencies: source-map: 0.7.6 optionalDependencies: - '@mdx-js/loader': 3.1.0(acorn@8.15.0)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)) + '@mdx-js/loader': 3.1.0(acorn@8.15.0)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)) '@mdx-js/react': 3.1.0(@types/react@19.1.11)(react@19.1.1) '@next/swc-darwin-arm64@15.5.4': @@ -10967,8 +11032,6 @@ snapshots: dependencies: '@nolyfill/shared': 1.0.44 - '@nolyfill/safe-buffer@1.0.44': {} - '@nolyfill/safer-buffer@1.0.44': {} '@nolyfill/shared@1.0.24': {} @@ -11155,7 +11218,7 @@ snapshots: '@pkgr/core@0.2.7': {} - '@pmmmwh/react-refresh-webpack-plugin@0.5.17(react-refresh@0.14.2)(type-fest@2.19.0)(webpack-hot-middleware@2.26.1)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3))': + '@pmmmwh/react-refresh-webpack-plugin@0.5.17(react-refresh@0.14.2)(type-fest@2.19.0)(webpack-hot-middleware@2.26.1)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3))': dependencies: ansi-html: 0.0.9 core-js-pure: 3.44.0 @@ -11163,9 +11226,9 @@ snapshots: html-entities: 2.6.0 loader-utils: 2.0.4 react-refresh: 0.14.2 - schema-utils: 4.3.2 - source-map: 0.7.4 - webpack: 5.100.2(esbuild@0.25.0)(uglify-js@3.19.3) + schema-utils: 4.3.3 + source-map: 0.7.6 + webpack: 5.102.1(esbuild@0.25.0)(uglify-js@3.19.3) optionalDependencies: type-fest: 2.19.0 webpack-hot-middleware: 2.26.1 @@ -11484,6 +11547,72 @@ snapshots: picomatch: 2.3.1 rollup: 2.79.2 + '@rollup/rollup-android-arm-eabi@4.52.5': + optional: true + + '@rollup/rollup-android-arm64@4.52.5': + optional: true + + '@rollup/rollup-darwin-arm64@4.52.5': + optional: true + + '@rollup/rollup-darwin-x64@4.52.5': + optional: true + + '@rollup/rollup-freebsd-arm64@4.52.5': + optional: true + + '@rollup/rollup-freebsd-x64@4.52.5': + optional: true + + '@rollup/rollup-linux-arm-gnueabihf@4.52.5': + optional: true + + '@rollup/rollup-linux-arm-musleabihf@4.52.5': + optional: true + + '@rollup/rollup-linux-arm64-gnu@4.52.5': + optional: true + + '@rollup/rollup-linux-arm64-musl@4.52.5': + optional: true + + '@rollup/rollup-linux-loong64-gnu@4.52.5': + optional: true + + '@rollup/rollup-linux-ppc64-gnu@4.52.5': + optional: true + + '@rollup/rollup-linux-riscv64-gnu@4.52.5': + optional: true + + '@rollup/rollup-linux-riscv64-musl@4.52.5': + optional: true + + '@rollup/rollup-linux-s390x-gnu@4.52.5': + optional: true + + '@rollup/rollup-linux-x64-gnu@4.52.5': + optional: true + + '@rollup/rollup-linux-x64-musl@4.52.5': + optional: true + + '@rollup/rollup-openharmony-arm64@4.52.5': + optional: true + + '@rollup/rollup-win32-arm64-msvc@4.52.5': + optional: true + + '@rollup/rollup-win32-ia32-msvc@4.52.5': + optional: true + + '@rollup/rollup-win32-x64-gnu@4.52.5': + optional: true + + '@rollup/rollup-win32-x64-msvc@4.52.5': + optional: true + '@sentry-internal/browser-utils@8.55.0': dependencies: 
'@sentry/core': 8.55.0 @@ -11531,146 +11660,51 @@ snapshots: dependencies: '@sinonjs/commons': 3.0.1 - '@storybook/addon-actions@8.5.0(storybook@8.5.0)': + '@storybook/addon-docs@9.1.13(@types/react@19.1.11)(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)))': dependencies: - '@storybook/global': 5.0.0 - '@types/uuid': 9.0.8 - dequal: 2.0.3 - polished: 4.3.1 - storybook: 8.5.0 - uuid: 9.0.1 - - '@storybook/addon-backgrounds@8.5.0(storybook@8.5.0)': - dependencies: - '@storybook/global': 5.0.0 - memoizerific: 1.11.3 - storybook: 8.5.0 - ts-dedent: 2.2.0 - - '@storybook/addon-controls@8.5.0(storybook@8.5.0)': - dependencies: - '@storybook/global': 5.0.0 - dequal: 2.0.3 - storybook: 8.5.0 - ts-dedent: 2.2.0 - - '@storybook/addon-docs@8.5.0(@types/react@19.1.11)(storybook@8.5.0)': - dependencies: - '@mdx-js/react': 3.1.0(@types/react@19.1.11)(react@18.3.1) - '@storybook/blocks': 8.5.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@8.5.0) - '@storybook/csf-plugin': 8.5.0(storybook@8.5.0) - '@storybook/react-dom-shim': 8.5.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@8.5.0) - react: 18.3.1 - react-dom: 18.3.1(react@18.3.1) - storybook: 8.5.0 + '@mdx-js/react': 3.1.0(@types/react@19.1.11)(react@19.1.1) + '@storybook/csf-plugin': 9.1.13(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0))) + '@storybook/icons': 1.6.0(react-dom@19.1.1(react@19.1.1))(react@19.1.1) + '@storybook/react-dom-shim': 9.1.13(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0))) + react: 19.1.1 + react-dom: 19.1.1(react@19.1.1) + storybook: 9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)) ts-dedent: 2.2.0 transitivePeerDependencies: - '@types/react' - '@storybook/addon-essentials@8.5.0(@types/react@19.1.11)(storybook@8.5.0)': - dependencies: - '@storybook/addon-actions': 8.5.0(storybook@8.5.0) - '@storybook/addon-backgrounds': 8.5.0(storybook@8.5.0) - '@storybook/addon-controls': 8.5.0(storybook@8.5.0) - '@storybook/addon-docs': 8.5.0(@types/react@19.1.11)(storybook@8.5.0) - '@storybook/addon-highlight': 8.5.0(storybook@8.5.0) - '@storybook/addon-measure': 8.5.0(storybook@8.5.0) - '@storybook/addon-outline': 8.5.0(storybook@8.5.0) - '@storybook/addon-toolbars': 8.5.0(storybook@8.5.0) - '@storybook/addon-viewport': 8.5.0(storybook@8.5.0) - storybook: 8.5.0 - ts-dedent: 2.2.0 - transitivePeerDependencies: - - '@types/react' - - '@storybook/addon-highlight@8.5.0(storybook@8.5.0)': + '@storybook/addon-links@9.1.13(react@19.1.1)(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)))': dependencies: '@storybook/global': 5.0.0 - storybook: 8.5.0 - - '@storybook/addon-interactions@8.5.0(storybook@8.5.0)': - dependencies: - '@storybook/global': 5.0.0 - '@storybook/instrumenter': 8.5.0(storybook@8.5.0) - '@storybook/test': 8.5.0(storybook@8.5.0) - polished: 4.3.1 - storybook: 8.5.0 - ts-dedent: 2.2.0 - - '@storybook/addon-links@8.5.0(react@19.1.1)(storybook@8.5.0)': - dependencies: - '@storybook/csf': 0.1.12 - '@storybook/global': 5.0.0 - storybook: 8.5.0 - ts-dedent: 2.2.0 + storybook: 
9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)) optionalDependencies: react: 19.1.1 - '@storybook/addon-measure@8.5.0(storybook@8.5.0)': + '@storybook/addon-onboarding@9.1.13(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)))': dependencies: - '@storybook/global': 5.0.0 - storybook: 8.5.0 - tiny-invariant: 1.3.3 + storybook: 9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)) - '@storybook/addon-onboarding@8.5.0(storybook@8.5.0)': + '@storybook/addon-themes@9.1.13(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)))': dependencies: - storybook: 8.5.0 - - '@storybook/addon-outline@8.5.0(storybook@8.5.0)': - dependencies: - '@storybook/global': 5.0.0 - storybook: 8.5.0 + storybook: 9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)) ts-dedent: 2.2.0 - '@storybook/addon-themes@8.5.0(storybook@8.5.0)': + '@storybook/builder-webpack5@9.1.13(esbuild@0.25.0)(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)))(typescript@5.8.3)(uglify-js@3.19.3)': dependencies: - storybook: 8.5.0 - ts-dedent: 2.2.0 - - '@storybook/addon-toolbars@8.5.0(storybook@8.5.0)': - dependencies: - storybook: 8.5.0 - - '@storybook/addon-viewport@8.5.0(storybook@8.5.0)': - dependencies: - memoizerific: 1.11.3 - storybook: 8.5.0 - - '@storybook/blocks@8.5.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@8.5.0)': - dependencies: - '@storybook/csf': 0.1.12 - '@storybook/icons': 1.4.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1) - storybook: 8.5.0 - ts-dedent: 2.2.0 - optionalDependencies: - react: 18.3.1 - react-dom: 18.3.1(react@18.3.1) - - '@storybook/builder-webpack5@8.5.0(esbuild@0.25.0)(storybook@8.5.0)(typescript@5.8.3)(uglify-js@3.19.3)': - dependencies: - '@storybook/core-webpack': 8.5.0(storybook@8.5.0) - '@types/semver': 7.7.0 - browser-assert: 1.2.1 + '@storybook/core-webpack': 9.1.13(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0))) case-sensitive-paths-webpack-plugin: 2.4.0 cjs-module-lexer: 1.4.3 - constants-browserify: 1.0.0 - css-loader: 6.11.0(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)) + css-loader: 6.11.0(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)) es-module-lexer: 1.7.0 - fork-ts-checker-webpack-plugin: 8.0.0(typescript@5.8.3)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)) - html-webpack-plugin: 5.6.3(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)) - magic-string: 0.30.17 - path-browserify: 1.0.1 - process: 0.11.10 - semver: 7.7.2 - storybook: 8.5.0 - style-loader: 3.3.4(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)) - terser-webpack-plugin: 5.3.14(esbuild@0.25.0)(uglify-js@3.19.3)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)) + fork-ts-checker-webpack-plugin: 8.0.0(typescript@5.8.3)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)) + html-webpack-plugin: 5.5.4(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)) + magic-string: 0.30.19 + storybook: 9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)) + style-loader: 3.3.4(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)) + 
terser-webpack-plugin: 5.3.14(esbuild@0.25.0)(uglify-js@3.19.3)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)) ts-dedent: 2.2.0 - url: 0.11.4 - util: 0.12.5 - util-deprecate: 1.0.2 - webpack: 5.100.2(esbuild@0.25.0)(uglify-js@3.19.3) - webpack-dev-middleware: 6.1.3(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)) + webpack: 5.102.1(esbuild@0.25.0)(uglify-js@3.19.3) + webpack-dev-middleware: 6.1.3(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)) webpack-hot-middleware: 2.26.1 webpack-virtual-modules: 0.6.2 optionalDependencies: @@ -11682,64 +11716,24 @@ snapshots: - uglify-js - webpack-cli - '@storybook/components@8.5.0(storybook@8.5.0)': + '@storybook/core-webpack@9.1.13(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)))': dependencies: - storybook: 8.5.0 - - '@storybook/core-webpack@8.5.0(storybook@8.5.0)': - dependencies: - storybook: 8.5.0 + storybook: 9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)) ts-dedent: 2.2.0 - '@storybook/core@8.5.0': + '@storybook/csf-plugin@9.1.13(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)))': dependencies: - '@storybook/csf': 0.1.12 - better-opn: 3.0.2 - browser-assert: 1.2.1 - esbuild: 0.25.0 - esbuild-register: 3.6.0(esbuild@0.25.0) - jsdoc-type-pratt-parser: 4.1.0 - process: 0.11.10 - recast: 0.23.11 - semver: 7.7.2 - util: 0.12.5 - ws: 8.18.3 - transitivePeerDependencies: - - bufferutil - - supports-color - - utf-8-validate - - '@storybook/csf-plugin@8.5.0(storybook@8.5.0)': - dependencies: - storybook: 8.5.0 + storybook: 9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)) unplugin: 1.16.1 - '@storybook/csf@0.1.12': - dependencies: - type-fest: 2.19.0 - - '@storybook/csf@0.1.13': - dependencies: - type-fest: 2.19.0 - '@storybook/global@5.0.0': {} - '@storybook/icons@1.4.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + '@storybook/icons@1.6.0(react-dom@19.1.1(react@19.1.1))(react@19.1.1)': dependencies: - react: 18.3.1 - react-dom: 18.3.1(react@18.3.1) + react: 19.1.1 + react-dom: 19.1.1(react@19.1.1) - '@storybook/instrumenter@8.5.0(storybook@8.5.0)': - dependencies: - '@storybook/global': 5.0.0 - '@vitest/utils': 2.1.9 - storybook: 8.5.0 - - '@storybook/manager-api@8.5.0(storybook@8.5.0)': - dependencies: - storybook: 8.5.0 - - '@storybook/nextjs@8.5.0(esbuild@0.25.0)(next@15.5.4(@babel/core@7.28.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(sass@1.92.1))(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(sass@1.92.1)(storybook@8.5.0)(type-fest@2.19.0)(typescript@5.8.3)(uglify-js@3.19.3)(webpack-hot-middleware@2.26.1)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3))': + '@storybook/nextjs@9.1.13(esbuild@0.25.0)(next@15.5.4(@babel/core@7.28.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(sass@1.92.1))(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(sass@1.92.1)(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)))(type-fest@2.19.0)(typescript@5.8.3)(uglify-js@3.19.3)(webpack-hot-middleware@2.26.1)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3))': dependencies: '@babel/core': 7.28.3 '@babel/plugin-syntax-bigint': 7.8.3(@babel/core@7.28.3) @@ -11753,39 +11747,34 @@ snapshots: '@babel/preset-env': 7.28.3(@babel/core@7.28.3) '@babel/preset-react': 
7.27.1(@babel/core@7.28.3) '@babel/preset-typescript': 7.27.1(@babel/core@7.28.3) - '@babel/runtime': 7.27.6 - '@pmmmwh/react-refresh-webpack-plugin': 0.5.17(react-refresh@0.14.2)(type-fest@2.19.0)(webpack-hot-middleware@2.26.1)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)) - '@storybook/builder-webpack5': 8.5.0(esbuild@0.25.0)(storybook@8.5.0)(typescript@5.8.3)(uglify-js@3.19.3) - '@storybook/preset-react-webpack': 8.5.0(@storybook/test@8.5.0(storybook@8.5.0))(esbuild@0.25.0)(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(storybook@8.5.0)(typescript@5.8.3)(uglify-js@3.19.3) - '@storybook/react': 8.5.0(@storybook/test@8.5.0(storybook@8.5.0))(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(storybook@8.5.0)(typescript@5.8.3) - '@storybook/test': 8.5.0(storybook@8.5.0) + '@babel/runtime': 7.28.4 + '@pmmmwh/react-refresh-webpack-plugin': 0.5.17(react-refresh@0.14.2)(type-fest@2.19.0)(webpack-hot-middleware@2.26.1)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)) + '@storybook/builder-webpack5': 9.1.13(esbuild@0.25.0)(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)))(typescript@5.8.3)(uglify-js@3.19.3) + '@storybook/preset-react-webpack': 9.1.13(esbuild@0.25.0)(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)))(typescript@5.8.3)(uglify-js@3.19.3) + '@storybook/react': 9.1.13(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)))(typescript@5.8.3) '@types/semver': 7.7.0 - babel-loader: 9.2.1(@babel/core@7.28.3)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)) - css-loader: 6.11.0(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)) - find-up: 5.0.0 - image-size: 1.2.1 + babel-loader: 9.2.1(@babel/core@7.28.3)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)) + css-loader: 6.11.0(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)) + image-size: 2.0.2 loader-utils: 3.3.1 next: 15.5.4(@babel/core@7.28.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(sass@1.92.1) - node-polyfill-webpack-plugin: 2.0.1(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)) - pnp-webpack-plugin: 1.7.0(typescript@5.8.3) + node-polyfill-webpack-plugin: 2.0.1(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)) postcss: 8.5.6 - postcss-loader: 8.1.1(postcss@8.5.6)(typescript@5.8.3)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)) + postcss-loader: 8.1.1(postcss@8.5.6)(typescript@5.8.3)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)) react: 19.1.1 react-dom: 19.1.1(react@19.1.1) react-refresh: 0.14.2 resolve-url-loader: 5.0.0 - sass-loader: 14.2.1(sass@1.92.1)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)) + sass-loader: 16.0.5(sass@1.92.1)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)) semver: 7.7.2 - storybook: 8.5.0 - style-loader: 3.3.4(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)) + storybook: 9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)) + style-loader: 3.3.4(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)) styled-jsx: 5.1.7(@babel/core@7.28.3)(react@19.1.1) - ts-dedent: 2.2.0 tsconfig-paths: 4.2.0 tsconfig-paths-webpack-plugin: 4.2.0 optionalDependencies: - sharp: 0.33.5 typescript: 5.8.3 - webpack: 5.100.2(esbuild@0.25.0)(uglify-js@3.19.3) + webpack: 5.102.1(esbuild@0.25.0)(uglify-js@3.19.3) 
transitivePeerDependencies: - '@rspack/core' - '@swc/core' @@ -11804,39 +11793,33 @@ snapshots: - webpack-hot-middleware - webpack-plugin-serve - '@storybook/preset-react-webpack@8.5.0(@storybook/test@8.5.0(storybook@8.5.0))(esbuild@0.25.0)(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(storybook@8.5.0)(typescript@5.8.3)(uglify-js@3.19.3)': + '@storybook/preset-react-webpack@9.1.13(esbuild@0.25.0)(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)))(typescript@5.8.3)(uglify-js@3.19.3)': dependencies: - '@storybook/core-webpack': 8.5.0(storybook@8.5.0) - '@storybook/react': 8.5.0(@storybook/test@8.5.0(storybook@8.5.0))(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(storybook@8.5.0)(typescript@5.8.3) - '@storybook/react-docgen-typescript-plugin': 1.0.6--canary.9.0c3f3b7.0(typescript@5.8.3)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)) + '@storybook/core-webpack': 9.1.13(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0))) + '@storybook/react-docgen-typescript-plugin': 1.0.6--canary.9.0c3f3b7.0(typescript@5.8.3)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)) '@types/semver': 7.7.0 - find-up: 5.0.0 - magic-string: 0.30.17 + find-up: 7.0.0 + magic-string: 0.30.19 react: 19.1.1 react-docgen: 7.1.1 react-dom: 19.1.1(react@19.1.1) resolve: 1.22.10 semver: 7.7.2 - storybook: 8.5.0 + storybook: 9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)) tsconfig-paths: 4.2.0 - webpack: 5.100.2(esbuild@0.25.0)(uglify-js@3.19.3) + webpack: 5.102.1(esbuild@0.25.0)(uglify-js@3.19.3) optionalDependencies: typescript: 5.8.3 transitivePeerDependencies: - - '@storybook/test' - '@swc/core' - esbuild - supports-color - uglify-js - webpack-cli - '@storybook/preview-api@8.5.0(storybook@8.5.0)': + '@storybook/react-docgen-typescript-plugin@1.0.6--canary.9.0c3f3b7.0(typescript@5.8.3)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3))': dependencies: - storybook: 8.5.0 - - '@storybook/react-docgen-typescript-plugin@1.0.6--canary.9.0c3f3b7.0(typescript@5.8.3)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3))': - dependencies: - debug: 4.4.1 + debug: 4.4.3 endent: 2.1.0 find-cache-dir: 3.3.2 flat-cache: 3.2.0 @@ -11844,53 +11827,26 @@ snapshots: react-docgen-typescript: 2.4.0(typescript@5.8.3) tslib: 2.8.1 typescript: 5.8.3 - webpack: 5.100.2(esbuild@0.25.0)(uglify-js@3.19.3) + webpack: 5.102.1(esbuild@0.25.0)(uglify-js@3.19.3) transitivePeerDependencies: - supports-color - '@storybook/react-dom-shim@8.5.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@8.5.0)': - dependencies: - react: 18.3.1 - react-dom: 18.3.1(react@18.3.1) - storybook: 8.5.0 - - '@storybook/react-dom-shim@8.5.0(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(storybook@8.5.0)': + '@storybook/react-dom-shim@9.1.13(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)))': dependencies: react: 19.1.1 react-dom: 19.1.1(react@19.1.1) - storybook: 8.5.0 + storybook: 9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)) - '@storybook/react@8.5.0(@storybook/test@8.5.0(storybook@8.5.0))(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(storybook@8.5.0)(typescript@5.8.3)': + 
'@storybook/react@9.1.13(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)))(typescript@5.8.3)': dependencies: - '@storybook/components': 8.5.0(storybook@8.5.0) '@storybook/global': 5.0.0 - '@storybook/manager-api': 8.5.0(storybook@8.5.0) - '@storybook/preview-api': 8.5.0(storybook@8.5.0) - '@storybook/react-dom-shim': 8.5.0(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(storybook@8.5.0) - '@storybook/theming': 8.5.0(storybook@8.5.0) + '@storybook/react-dom-shim': 9.1.13(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0))) react: 19.1.1 react-dom: 19.1.1(react@19.1.1) - storybook: 8.5.0 + storybook: 9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)) optionalDependencies: - '@storybook/test': 8.5.0(storybook@8.5.0) typescript: 5.8.3 - '@storybook/test@8.5.0(storybook@8.5.0)': - dependencies: - '@storybook/csf': 0.1.12 - '@storybook/global': 5.0.0 - '@storybook/instrumenter': 8.5.0(storybook@8.5.0) - '@testing-library/dom': 10.4.0 - '@testing-library/jest-dom': 6.5.0 - '@testing-library/user-event': 14.5.2(@testing-library/dom@10.4.0) - '@vitest/expect': 2.0.5 - '@vitest/spy': 2.0.5 - storybook: 8.5.0 - - '@storybook/theming@8.5.0(storybook@8.5.0)': - dependencies: - storybook: 8.5.0 - '@stylistic/eslint-plugin@5.2.2(eslint@9.35.0(jiti@2.6.1))': dependencies: '@eslint-community/eslint-utils': 4.7.0(eslint@9.35.0(jiti@2.6.1)) @@ -11987,16 +11943,6 @@ snapshots: lz-string: 1.5.0 pretty-format: 27.5.1 - '@testing-library/jest-dom@6.5.0': - dependencies: - '@adobe/css-tools': 4.4.4 - aria-query: 5.3.2 - chalk: 3.0.0 - css.escape: 1.5.1 - dom-accessibility-api: 0.6.3 - lodash: 4.17.21 - redent: 3.0.0 - '@testing-library/jest-dom@6.8.0': dependencies: '@adobe/css-tools': 4.4.4 @@ -12016,7 +11962,7 @@ snapshots: '@types/react': 19.1.11 '@types/react-dom': 19.1.7(@types/react@19.1.11) - '@testing-library/user-event@14.5.2(@testing-library/dom@10.4.0)': + '@testing-library/user-event@14.6.1(@testing-library/dom@10.4.0)': dependencies: '@testing-library/dom': 10.4.0 @@ -12067,6 +12013,10 @@ snapshots: '@types/node': 18.15.0 '@types/responselike': 1.0.3 + '@types/chai@5.2.2': + dependencies: + '@types/deep-eql': 4.0.2 + '@types/d3-array@3.2.1': {} '@types/d3-axis@3.0.6': @@ -12188,6 +12138,8 @@ snapshots: dependencies: '@types/ms': 2.1.0 + '@types/deep-eql@4.0.2': {} + '@types/doctrine@0.0.9': {} '@types/eslint-scope@3.7.7': @@ -12334,8 +12286,6 @@ snapshots: '@types/uuid@10.0.0': {} - '@types/uuid@9.0.8': {} - '@types/whatwg-mimetype@3.0.2': {} '@types/yargs-parser@21.0.3': {} @@ -12564,37 +12514,35 @@ snapshots: transitivePeerDependencies: - supports-color - '@vitest/expect@2.0.5': + '@vitest/expect@3.2.4': dependencies: - '@vitest/spy': 2.0.5 - '@vitest/utils': 2.0.5 + '@types/chai': 5.2.2 + '@vitest/spy': 3.2.4 + '@vitest/utils': 3.2.4 chai: 5.2.1 - tinyrainbow: 1.2.0 + tinyrainbow: 2.0.0 - '@vitest/pretty-format@2.0.5': + '@vitest/mocker@3.2.4(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0))': dependencies: - tinyrainbow: 1.2.0 - - '@vitest/pretty-format@2.1.9': - dependencies: - tinyrainbow: 1.2.0 - - '@vitest/spy@2.0.5': - dependencies: - tinyspy: 3.0.2 - - '@vitest/utils@2.0.5': - dependencies: - '@vitest/pretty-format': 2.0.5 + '@vitest/spy': 
3.2.4
       estree-walker: 3.0.3
-      loupe: 3.1.4
-      tinyrainbow: 1.2.0
+      magic-string: 0.30.19
+    optionalDependencies:
+      vite: 6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)

-  '@vitest/utils@2.1.9':
+  '@vitest/pretty-format@3.2.4':
     dependencies:
-      '@vitest/pretty-format': 2.1.9
+      tinyrainbow: 2.0.0
+
+  '@vitest/spy@3.2.4':
+    dependencies:
+      tinyspy: 4.0.4
+
+  '@vitest/utils@3.2.4':
+    dependencies:
+      '@vitest/pretty-format': 3.2.4
       loupe: 3.1.4
-      tinyrainbow: 1.2.0
+      tinyrainbow: 2.0.0

   '@vue/compiler-core@3.5.17':
     dependencies:
@@ -12898,27 +12846,27 @@ snapshots:
     transitivePeerDependencies:
       - supports-color

-  babel-loader@10.0.0(@babel/core@7.28.3)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)):
+  babel-loader@10.0.0(@babel/core@7.28.3)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)):
     dependencies:
       '@babel/core': 7.28.3
       find-up: 5.0.0
-      webpack: 5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)
+      webpack: 5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)

-  babel-loader@8.4.1(@babel/core@7.28.3)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)):
+  babel-loader@8.4.1(@babel/core@7.28.3)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)):
     dependencies:
       '@babel/core': 7.28.3
       find-cache-dir: 3.3.2
       loader-utils: 2.0.4
       make-dir: 3.1.0
       schema-utils: 2.7.1
-      webpack: 5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)
+      webpack: 5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)

-  babel-loader@9.2.1(@babel/core@7.28.3)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)):
+  babel-loader@9.2.1(@babel/core@7.28.3)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)):
     dependencies:
       '@babel/core': 7.28.3
       find-cache-dir: 4.0.0
-      schema-utils: 4.3.2
-      webpack: 5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)
+      schema-utils: 4.3.3
+      webpack: 5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)

   babel-plugin-istanbul@6.1.1:
     dependencies:
@@ -12992,6 +12940,8 @@ snapshots:

   base64-js@1.5.1: {}

+  baseline-browser-mapping@2.8.18: {}
+
   before-after-hook@3.0.2: {}

   better-opn@3.0.2:
@@ -13024,8 +12974,6 @@ snapshots:

   brorand@1.1.0: {}

-  browser-assert@1.2.1: {}
-
   browserify-aes@1.2.0:
     dependencies:
       buffer-xor: 1.0.3
       cipher-base: 1.0.6
       create-hash: 1.2.0
       evp_bytestokey: 1.0.3
       inherits: 2.0.4
-      safe-buffer: '@nolyfill/safe-buffer@1.0.44'
+      safe-buffer: 5.2.1

   browserify-cipher@1.0.1:
     dependencies:
@@ -13046,13 +12994,13 @@ snapshots:
       cipher-base: 1.0.6
       des.js: 1.1.0
       inherits: 2.0.4
-      safe-buffer: '@nolyfill/safe-buffer@1.0.44'
+      safe-buffer: 5.2.1

   browserify-rsa@4.1.1:
     dependencies:
       bn.js: 5.2.2
       randombytes: 2.1.0
-      safe-buffer: '@nolyfill/safe-buffer@1.0.44'
+      safe-buffer: 5.2.1

   browserify-sign@4.2.3:
     dependencies:
@@ -13065,7 +13013,7 @@ snapshots:
       inherits: 2.0.4
       parse-asn1: 5.1.7
       readable-stream: 2.3.8
-      safe-buffer: '@nolyfill/safe-buffer@1.0.44'
+      safe-buffer: 5.2.1

   browserify-zlib@0.2.0:
     dependencies:
@@ -13078,6 +13026,14 @@ snapshots:
       node-releases: 2.0.19
       update-browserslist-db: 1.1.3(browserslist@4.25.1)

+  browserslist@4.26.3:
+    dependencies:
+      baseline-browser-mapping: 2.8.18
+      caniuse-lite: 1.0.30001746
+      electron-to-chromium: 1.5.237
+      node-releases: 2.0.25
+      update-browserslist-db: 1.1.3(browserslist@4.26.3)
+
   bser@2.1.1:
     dependencies:
       node-int64: 0.4.0
@@ -13130,6 +13086,8 @@ snapshots:

   caniuse-lite@1.0.30001746: {}

+  caniuse-lite@1.0.30001751: {}
+
   canvas@2.11.2:
     dependencies:
       '@mapbox/node-pre-gyp': 1.0.11
@@ -13152,11 +13110,6 @@ snapshots:
       loupe: 3.1.4
       pathval: 2.0.1

-  chalk@3.0.0:
-    dependencies:
-      ansi-styles: 4.3.0
-      supports-color: 7.2.0
-
   chalk@4.1.1:
     dependencies:
       ansi-styles: 4.3.0
@@ -13222,7 +13175,7 @@ snapshots:
   chownr@2.0.0:
     optional: true

-  chromatic@11.29.0: {}
+  chromatic@12.2.0: {}

   chrome-trace-event@1.0.4: {}

@@ -13233,7 +13186,7 @@ snapshots:
   cipher-base@1.0.6:
     dependencies:
       inherits: 2.0.4
-      safe-buffer: '@nolyfill/safe-buffer@1.0.44'
+      safe-buffer: 5.2.1

   cjs-module-lexer@1.4.3: {}

@@ -13255,10 +13208,10 @@ snapshots:
     dependencies:
       escape-string-regexp: 1.0.5

-  clean-webpack-plugin@4.0.0(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)):
+  clean-webpack-plugin@4.0.0(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)):
     dependencies:
       del: 4.1.1
-      webpack: 5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)
+      webpack: 5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)

   cli-cursor@5.0.0:
     dependencies:
@@ -13421,7 +13374,7 @@ snapshots:
     dependencies:
       cipher-base: 1.0.6
       inherits: 2.0.4
-      ripemd160: 2.0.1
+      ripemd160: 2.0.2
       sha.js: 2.4.12

   create-hash@1.2.0:
@@ -13438,7 +13391,7 @@ snapshots:
       create-hash: 1.2.0
       inherits: 2.0.4
       ripemd160: 2.0.2
-      safe-buffer: '@nolyfill/safe-buffer@1.0.44'
+      safe-buffer: 5.2.1
       sha.js: 2.4.12

   create-jest@29.7.0(@types/node@18.15.0)(ts-node@10.9.2(@types/node@18.15.0)(typescript@5.8.3)):
@@ -13487,7 +13440,7 @@ snapshots:

   crypto-random-string@2.0.0: {}

-  css-loader@6.11.0(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)):
+  css-loader@6.11.0(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)):
     dependencies:
       icss-utils: 5.1.0(postcss@8.5.6)
       postcss: 8.5.6
@@ -13498,7 +13451,7 @@ snapshots:
       postcss-value-parser: 4.2.0
       semver: 7.7.2
     optionalDependencies:
-      webpack: 5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)
+      webpack: 5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)

   css-select@4.3.0:
     dependencies:
@@ -13772,7 +13725,7 @@ snapshots:

   detect-libc@2.1.0: {}

-  detect-libc@2.1.1:
+  detect-libc@2.1.2:
     optional: true

   detect-newline@3.1.0: {}

@@ -13867,6 +13820,8 @@ snapshots:

   electron-to-chromium@1.5.186: {}

+  electron-to-chromium@1.5.237: {}
+
   elkjs@0.9.3: {}

   elliptic@6.6.1:
@@ -13938,7 +13893,7 @@ snapshots:

   esbuild-register@3.6.0(esbuild@0.25.0):
     dependencies:
-      debug: 4.4.1
+      debug: 4.4.3
       esbuild: 0.25.0
     transitivePeerDependencies:
       - supports-color
@@ -14260,12 +14215,11 @@ snapshots:
       semver: 7.7.2
       typescript: 5.8.3

-  eslint-plugin-storybook@9.0.7(eslint@9.35.0(jiti@2.6.1))(typescript@5.8.3):
+  eslint-plugin-storybook@9.1.13(eslint@9.35.0(jiti@2.6.1))(storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)))(typescript@5.8.3):
     dependencies:
-      '@storybook/csf': 0.1.13
       '@typescript-eslint/utils': 8.44.0(eslint@9.35.0(jiti@2.6.1))(typescript@5.8.3)
       eslint: 9.35.0(jiti@2.6.1)
-      ts-dedent: 2.2.0
+      storybook: 9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0))
     transitivePeerDependencies:
       - supports-color
       - typescript
@@ -14473,7 +14427,7 @@ snapshots:
   evp_bytestokey@1.0.3:
     dependencies:
       md5.js: 1.3.5
-      safe-buffer: '@nolyfill/safe-buffer@1.0.44'
+      safe-buffer: 5.2.1

   execa@5.1.1:
     dependencies:
@@ -14609,6 +14563,12 @@ snapshots:
       locate-path: 7.2.0
       path-exists: 5.0.0

+  find-up@7.0.0:
+    dependencies:
+      locate-path: 7.2.0
+      path-exists: 5.0.0
+      unicorn-magic: 0.1.0
+
   flat-cache@3.2.0:
     dependencies:
       flatted: 3.3.3
@@ -14627,7 +14587,7 @@ snapshots:
       cross-spawn: 7.0.6
       signal-exit: 4.1.0

-  fork-ts-checker-webpack-plugin@8.0.0(typescript@5.8.3)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)):
+  fork-ts-checker-webpack-plugin@8.0.0(typescript@5.8.3)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)):
     dependencies:
       '@babel/code-frame': 7.27.1
       chalk: 4.1.2
@@ -14640,9 +14600,9 @@ snapshots:
       node-abort-controller: 3.1.1
       schema-utils: 3.3.0
       semver: 7.7.2
-      tapable: 2.2.2
+      tapable: 2.3.0
       typescript: 5.8.3
-      webpack: 5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)
+      webpack: 5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)

   format@0.2.2: {}

@@ -14811,7 +14771,7 @@ snapshots:
   hash-base@3.0.5:
     dependencies:
       inherits: 2.0.4
-      safe-buffer: '@nolyfill/safe-buffer@1.0.44'
+      safe-buffer: 5.2.1

   hash.js@1.1.7:
     dependencies:
@@ -14995,15 +14955,14 @@ snapshots:

   html-void-elements@3.0.0: {}

-  html-webpack-plugin@5.6.3(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)):
+  html-webpack-plugin@5.5.4(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)):
     dependencies:
       '@types/html-minifier-terser': 6.1.0
       html-minifier-terser: 6.1.0
       lodash: 4.17.21
       pretty-error: 4.0.0
-      tapable: 2.2.2
-    optionalDependencies:
-      webpack: 5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)
+      tapable: 2.3.0
+      webpack: 5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)

   htmlparser2@6.1.0:
     dependencies:
@@ -15059,9 +15018,7 @@ snapshots:

   ignore@7.0.5: {}

-  image-size@1.2.1:
-    dependencies:
-      queue: 6.0.2
+  image-size@2.0.2: {}

   immer@10.1.3: {}

@@ -15819,10 +15776,6 @@ snapshots:
     dependencies:
       sourcemap-codec: 1.4.8

-  magic-string@0.30.17:
-    dependencies:
-      '@jridgewell/sourcemap-codec': 1.5.4
-
   magic-string@0.30.19:
     dependencies:
       '@jridgewell/sourcemap-codec': 1.5.5
@@ -15848,8 +15801,6 @@ snapshots:
     dependencies:
       tmpl: 1.0.5

-  map-or-similar@1.5.0: {}
-
   markdown-extensions@2.0.0: {}

   markdown-table@3.0.4: {}

@@ -15860,7 +15811,7 @@ snapshots:
     dependencies:
       hash-base: 3.0.5
       inherits: 2.0.4
-      safe-buffer: '@nolyfill/safe-buffer@1.0.44'
+      safe-buffer: 5.2.1

   mdast-util-find-and-replace@3.0.2:
     dependencies:
@@ -16059,10 +16010,6 @@ snapshots:

   memoize-one@5.2.1: {}

-  memoizerific@1.11.3:
-    dependencies:
-      map-or-similar: 1.5.0
-
   merge-stream@2.0.0: {}

   merge2@1.4.1: {}

@@ -16488,14 +16435,14 @@ snapshots:

   neo-async@2.6.2: {}

-  next-pwa@5.6.0(@babel/core@7.28.3)(@types/babel__core@7.20.5)(esbuild@0.25.0)(next@15.5.4(@babel/core@7.28.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(sass@1.92.1))(uglify-js@3.19.3)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)):
+  next-pwa@5.6.0(@babel/core@7.28.3)(@types/babel__core@7.20.5)(esbuild@0.25.0)(next@15.5.4(@babel/core@7.28.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(sass@1.92.1))(uglify-js@3.19.3)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)):
     dependencies:
-      babel-loader: 8.4.1(@babel/core@7.28.3)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3))
-      clean-webpack-plugin: 4.0.0(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3))
+      babel-loader: 8.4.1(@babel/core@7.28.3)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3))
+      clean-webpack-plugin: 4.0.0(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3))
       globby: 11.1.0
       next: 15.5.4(@babel/core@7.28.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(sass@1.92.1)
-      terser-webpack-plugin: 5.3.14(esbuild@0.25.0)(uglify-js@3.19.3)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3))
-      workbox-webpack-plugin: 6.6.0(@types/babel__core@7.20.5)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3))
+      terser-webpack-plugin: 5.3.14(esbuild@0.25.0)(uglify-js@3.19.3)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3))
+      workbox-webpack-plugin: 6.6.0(@types/babel__core@7.20.5)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3))
       workbox-window: 6.6.0
     transitivePeerDependencies:
       - '@babel/core'
@@ -16515,7 +16462,7 @@ snapshots:
     dependencies:
       '@next/env': 15.5.4
       '@swc/helpers': 0.5.15
-      caniuse-lite: 1.0.30001746
+      caniuse-lite: 1.0.30001751
      postcss: 8.4.31
       react: 19.1.1
       react-dom: 19.1.1(react@19.1.1)
@@ -16552,7 +16499,7 @@ snapshots:

   node-int64@0.4.0: {}

-  node-polyfill-webpack-plugin@2.0.1(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)):
+  node-polyfill-webpack-plugin@2.0.1(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)):
     dependencies:
       assert: '@nolyfill/assert@1.0.26'
       browserify-zlib: 0.2.0
@@ -16579,10 +16526,12 @@ snapshots:
       url: 0.11.4
       util: 0.12.5
       vm-browserify: 1.1.2
-      webpack: 5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)
+      webpack: 5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)

   node-releases@2.0.19: {}

+  node-releases@2.0.25: {}
+
   nopt@5.0.0:
     dependencies:
       abbrev: 1.1.1
@@ -16733,7 +16682,7 @@ snapshots:
       evp_bytestokey: 1.0.3
       hash-base: 3.0.5
       pbkdf2: 3.1.3
-      safe-buffer: '@nolyfill/safe-buffer@1.0.44'
+      safe-buffer: 5.2.1

   parse-entities@2.0.0:
     dependencies:
@@ -16815,7 +16764,7 @@ snapshots:
       create-hash: 1.1.3
       create-hmac: 1.1.7
       ripemd160: 2.0.1
-      safe-buffer: '@nolyfill/safe-buffer@1.0.44'
+      safe-buffer: 5.2.1
       sha.js: 2.4.12
       to-buffer: 1.2.1

@@ -16877,12 +16826,6 @@ snapshots:

   pluralize@8.0.0: {}

-  pnp-webpack-plugin@1.7.0(typescript@5.8.3):
-    dependencies:
-      ts-pnp: 1.2.0(typescript@5.8.3)
-    transitivePeerDependencies:
-      - typescript
-
   pnpm-workspace-yaml@1.1.0:
     dependencies:
       yaml: 2.8.0
@@ -16894,10 +16837,6 @@ snapshots:
       path-data-parser: 0.1.0
       points-on-curve: 0.2.0

-  polished@4.3.1:
-    dependencies:
-      '@babel/runtime': 7.27.6
-
   portfinder@1.0.37:
     dependencies:
       async: 3.2.6
@@ -16925,14 +16864,14 @@ snapshots:
       postcss: 8.5.6
       ts-node: 10.9.2(@types/node@18.15.0)(typescript@5.8.3)

-  postcss-loader@8.1.1(postcss@8.5.6)(typescript@5.8.3)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)):
+  postcss-loader@8.1.1(postcss@8.5.6)(typescript@5.8.3)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)):
     dependencies:
       cosmiconfig: 9.0.0(typescript@5.8.3)
       jiti: 1.21.7
       postcss: 8.5.6
       semver: 7.7.2
     optionalDependencies:
-      webpack: 5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)
+      webpack: 5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)
     transitivePeerDependencies:
       - typescript

@@ -17044,7 +16983,7 @@ snapshots:
       create-hash: 1.2.0
       parse-asn1: 5.1.7
       randombytes: 2.1.0
-      safe-buffer: '@nolyfill/safe-buffer@1.0.44'
+      safe-buffer: 5.2.1

   pump@3.0.3:
     dependencies:
@@ -17073,20 +17012,16 @@ snapshots:

   queue-microtask@1.2.3: {}

-  queue@6.0.2:
-    dependencies:
-      inherits: 2.0.4
-
   quick-lru@5.1.1: {}

   randombytes@2.1.0:
     dependencies:
-      safe-buffer: '@nolyfill/safe-buffer@1.0.44'
+      safe-buffer: 5.2.1

   randomfill@1.0.4:
     dependencies:
       randombytes: 2.1.0
-      safe-buffer: '@nolyfill/safe-buffer@1.0.44'
+      safe-buffer: 5.2.1

   range-parser@1.2.1: {}

@@ -17100,11 +17035,6 @@ snapshots:
       prop-types: 15.8.1
       react: 19.1.1

-  react-confetti@6.4.0(react@19.1.1):
-    dependencies:
-      react: 19.1.1
-      tween-functions: 1.2.0
-
   react-docgen-typescript@2.4.0(typescript@5.8.3):
     dependencies:
       typescript: 5.8.3

   react-docgen@7.1.1:
     dependencies:
       '@babel/core': 7.28.3
-      '@babel/traverse': 7.28.0
-      '@babel/types': 7.28.1
+      '@babel/traverse': 7.28.3
+      '@babel/types': 7.28.4
       '@types/babel__core': 7.20.5
       '@types/babel__traverse': 7.20.7
       '@types/doctrine': 0.0.9
@@ -17124,12 +17054,6 @@ snapshots:
     transitivePeerDependencies:
       - supports-color

-  react-dom@18.3.1(react@18.3.1):
-    dependencies:
-      loose-envify: 1.4.0
-      react: 18.3.1
-      scheduler: 0.23.2
-
   react-dom@19.1.1(react@19.1.1):
     dependencies:
       react: 19.1.1
@@ -17297,10 +17221,6 @@ snapshots:
       react: 19.1.1
       react-dom: 19.1.1(react@19.1.1)

-  react@18.3.1:
-    dependencies:
-      loose-envify: 1.4.0
-
   react@19.1.1: {}
   reactflow@11.11.4(@types/react@19.1.11)(immer@10.1.3)(react-dom@19.1.1(react@19.1.1))(react@19.1.1):
@@ -17327,7 +17247,7 @@ snapshots:
       inherits: 2.0.4
       isarray: '@nolyfill/isarray@1.0.44'
       process-nextick-args: 2.0.1
-      safe-buffer: '@nolyfill/safe-buffer@1.0.44'
+      safe-buffer: 5.2.1
       string_decoder: 1.1.1
       util-deprecate: 1.0.2

@@ -17603,6 +17523,35 @@ snapshots:
     optionalDependencies:
       fsevents: 2.3.3

+  rollup@4.52.5:
+    dependencies:
+      '@types/estree': 1.0.8
+    optionalDependencies:
+      '@rollup/rollup-android-arm-eabi': 4.52.5
+      '@rollup/rollup-android-arm64': 4.52.5
+      '@rollup/rollup-darwin-arm64': 4.52.5
+      '@rollup/rollup-darwin-x64': 4.52.5
+      '@rollup/rollup-freebsd-arm64': 4.52.5
+      '@rollup/rollup-freebsd-x64': 4.52.5
+      '@rollup/rollup-linux-arm-gnueabihf': 4.52.5
+      '@rollup/rollup-linux-arm-musleabihf': 4.52.5
+      '@rollup/rollup-linux-arm64-gnu': 4.52.5
+      '@rollup/rollup-linux-arm64-musl': 4.52.5
+      '@rollup/rollup-linux-loong64-gnu': 4.52.5
+      '@rollup/rollup-linux-ppc64-gnu': 4.52.5
+      '@rollup/rollup-linux-riscv64-gnu': 4.52.5
+      '@rollup/rollup-linux-riscv64-musl': 4.52.5
+      '@rollup/rollup-linux-s390x-gnu': 4.52.5
+      '@rollup/rollup-linux-x64-gnu': 4.52.5
+      '@rollup/rollup-linux-x64-musl': 4.52.5
+      '@rollup/rollup-openharmony-arm64': 4.52.5
+      '@rollup/rollup-win32-arm64-msvc': 4.52.5
+      '@rollup/rollup-win32-ia32-msvc': 4.52.5
+      '@rollup/rollup-win32-x64-gnu': 4.52.5
+      '@rollup/rollup-win32-x64-msvc': 4.52.5
+      fsevents: 2.3.3
+    optional: true
+
   roughjs@4.6.6:
     dependencies:
       hachure-fill: 0.5.2
@@ -17616,12 +17565,14 @@ snapshots:

   rw@1.3.3: {}

-  sass-loader@14.2.1(sass@1.92.1)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)):
+  safe-buffer@5.2.1: {}
+
+  sass-loader@16.0.5(sass@1.92.1)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)):
     dependencies:
       neo-async: 2.6.2
     optionalDependencies:
       sass: 1.92.1
-      webpack: 5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)
+      webpack: 5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)

   sass@1.92.1:
     dependencies:
@@ -17631,10 +17582,6 @@ snapshots:
     optionalDependencies:
       '@parcel/watcher': 2.5.1

-  scheduler@0.23.2:
-    dependencies:
-      loose-envify: 1.4.0
-
   scheduler@0.26.0: {}

   schema-utils@2.7.1:
@@ -17656,6 +17603,13 @@ snapshots:
       ajv-formats: 2.1.1(ajv@8.17.1)
       ajv-keywords: 5.1.0(ajv@8.17.1)

+  schema-utils@4.3.3:
+    dependencies:
+      '@types/json-schema': 7.0.15
+      ajv: 8.17.1
+      ajv-formats: 2.1.1(ajv@8.17.1)
+      ajv-keywords: 5.1.0(ajv@8.17.1)
+
   screenfull@5.2.0: {}

   scslre@0.3.0:
@@ -17684,7 +17638,7 @@ snapshots:
   sha.js@2.4.12:
     dependencies:
       inherits: 2.0.4
-      safe-buffer: '@nolyfill/safe-buffer@1.0.44'
+      safe-buffer: 5.2.1
       to-buffer: 1.2.1

   sharp@0.33.5:
@@ -17716,7 +17670,7 @@ snapshots:
   sharp@0.34.4:
     dependencies:
       '@img/colour': 1.0.0
-      detect-libc: 2.1.1
+      detect-libc: 2.1.2
       semver: 7.7.2
     optionalDependencies:
       '@img/sharp-darwin-arm64': 0.34.4
@@ -17842,13 +17796,27 @@ snapshots:

   state-local@1.0.7: {}

-  storybook@8.5.0:
+  storybook@9.1.13(@testing-library/dom@10.4.0)(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0)):
     dependencies:
-      '@storybook/core': 8.5.0
+      '@storybook/global': 5.0.0
+      '@testing-library/jest-dom': 6.8.0
+      '@testing-library/user-event': 14.6.1(@testing-library/dom@10.4.0)
+      '@vitest/expect': 3.2.4
+      '@vitest/mocker': 3.2.4(vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0))
+      '@vitest/spy': 3.2.4
+      better-opn: 3.0.2
+      esbuild: 0.25.0
+      esbuild-register: 3.6.0(esbuild@0.25.0)
+      recast: 0.23.11
+      semver: 7.7.2
+      ws: 8.18.3
     transitivePeerDependencies:
+      - '@testing-library/dom'
       - bufferutil
+      - msw
       - supports-color
       - utf-8-validate
+      - vite

   stream-browserify@3.0.0:
     dependencies:
@@ -17879,11 +17847,11 @@ snapshots:

   string_decoder@1.1.1:
     dependencies:
-      safe-buffer: '@nolyfill/safe-buffer@1.0.44'
+      safe-buffer: 5.2.1

   string_decoder@1.3.0:
     dependencies:
-      safe-buffer: '@nolyfill/safe-buffer@1.0.44'
+      safe-buffer: 5.2.1

   stringify-entities@4.0.4:
     dependencies:
@@ -17928,9 +17896,9 @@ snapshots:

   strip-json-comments@5.0.2: {}

-  style-loader@3.3.4(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)):
+  style-loader@3.3.4(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)):
     dependencies:
-      webpack: 5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)
+      webpack: 5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)

   style-to-js@1.1.17:
     dependencies:
@@ -18019,6 +17987,8 @@ snapshots:

   tapable@2.2.2: {}

+  tapable@2.3.0: {}
+
   tar@6.2.1:
     dependencies:
       chownr: 2.0.0
@@ -18038,14 +18008,14 @@ snapshots:
       type-fest: 0.16.0
       unique-string: 2.0.0

-  terser-webpack-plugin@5.3.14(esbuild@0.25.0)(uglify-js@3.19.3)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)):
+  terser-webpack-plugin@5.3.14(esbuild@0.25.0)(uglify-js@3.19.3)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)):
     dependencies:
       '@jridgewell/trace-mapping': 0.3.29
       jest-worker: 27.5.1
       schema-utils: 4.3.2
       serialize-javascript: 6.0.2
       terser: 5.43.1
-      webpack: 5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)
+      webpack: 5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)
     optionalDependencies:
       esbuild: 0.25.0
       uglify-js: 3.19.3
@@ -18086,9 +18056,9 @@ snapshots:
       fdir: 6.4.6(picomatch@4.0.3)
       picomatch: 4.0.3

-  tinyrainbow@1.2.0: {}
+  tinyrainbow@2.0.0: {}

-  tinyspy@3.0.2: {}
+  tinyspy@4.0.4: {}

   tldts-core@7.0.10: {}

@@ -18101,7 +18071,7 @@ snapshots:
   to-buffer@1.2.1:
     dependencies:
       isarray: '@nolyfill/isarray@1.0.44'
-      safe-buffer: '@nolyfill/safe-buffer@1.0.44'
+      safe-buffer: 5.2.1
       typed-array-buffer: '@nolyfill/typed-array-buffer@1.0.44'

   to-regex-range@5.0.1:
@@ -18163,15 +18133,11 @@ snapshots:

   ts-pattern@5.7.1: {}

-  ts-pnp@1.2.0(typescript@5.8.3):
-    optionalDependencies:
-      typescript: 5.8.3
-
   tsconfig-paths-webpack-plugin@4.2.0:
     dependencies:
       chalk: 4.1.2
       enhanced-resolve: 5.18.2
-      tapable: 2.2.2
+      tapable: 2.3.0
       tsconfig-paths: 4.2.0

   tsconfig-paths@4.2.0:
@@ -18188,8 +18154,6 @@ snapshots:

   tty-browserify@0.0.1: {}

-  tween-functions@1.2.0: {}
-
   type-check@0.4.0:
     dependencies:
       prelude-ls: 1.2.1
@@ -18221,6 +18185,8 @@ snapshots:

   unicode-property-aliases-ecmascript@2.1.0: {}

+  unicorn-magic@0.1.0: {}
+
   unified@11.0.5:
     dependencies:
       '@types/unist': 3.0.3
@@ -18289,6 +18255,12 @@ snapshots:
       escalade: 3.2.0
       picocolors: 1.1.1

+  update-browserslist-db@1.1.3(browserslist@4.26.3):
+    dependencies:
+      browserslist: 4.26.3
+      escalade: 3.2.0
+      picocolors: 1.1.1
+
   uri-js@4.4.1:
     dependencies:
       punycode: 2.3.1
@@ -18359,8 +18331,6 @@ snapshots:

   uuid@11.1.0: {}

-  uuid@9.0.1: {}
-
   v8-compile-cache-lib@3.0.1:
     optional: true

@@ -18385,6 +18355,20 @@ snapshots:
       '@types/unist': 3.0.3
       vfile-message: 4.0.2

+  vite@6.2.7(@types/node@18.15.0)(jiti@2.6.1)(sass@1.92.1)(terser@5.43.1)(yaml@2.8.0):
+    dependencies:
+      esbuild: 0.25.0
+      postcss: 8.5.6
+      rollup: 4.52.5
+    optionalDependencies:
+      '@types/node': 18.15.0
+      fsevents: 2.3.3
+      jiti: 2.6.1
+      sass: 1.92.1
+      terser: 5.43.1
+      yaml: 2.8.0
+    optional: true
+
   vm-browserify@1.1.2: {}

   void-elements@3.1.0: {}

@@ -18455,15 +18439,15 @@ snapshots:
       - bufferutil
       - utf-8-validate

-  webpack-dev-middleware@6.1.3(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)):
+  webpack-dev-middleware@6.1.3(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)):
     dependencies:
       colorette: 2.0.20
memfs: 3.5.3 mime-types: 2.1.35 range-parser: 1.2.1 - schema-utils: 4.3.2 + schema-utils: 4.3.3 optionalDependencies: - webpack: 5.100.2(esbuild@0.25.0)(uglify-js@3.19.3) + webpack: 5.102.1(esbuild@0.25.0)(uglify-js@3.19.3) webpack-hot-middleware@2.26.1: dependencies: @@ -18480,7 +18464,7 @@ snapshots: webpack-virtual-modules@0.6.2: {} - webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3): + webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3): dependencies: '@types/eslint-scope': 3.7.7 '@types/estree': 1.0.8 @@ -18490,7 +18474,7 @@ snapshots: '@webassemblyjs/wasm-parser': 1.14.1 acorn: 8.15.0 acorn-import-phases: 1.0.4(acorn@8.15.0) - browserslist: 4.25.1 + browserslist: 4.26.3 chrome-trace-event: 1.0.4 enhanced-resolve: 5.18.2 es-module-lexer: 1.7.0 @@ -18502,9 +18486,9 @@ snapshots: loader-runner: 4.3.0 mime-types: 2.1.35 neo-async: 2.6.2 - schema-utils: 4.3.2 - tapable: 2.2.2 - terser-webpack-plugin: 5.3.14(esbuild@0.25.0)(uglify-js@3.19.3)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)) + schema-utils: 4.3.3 + tapable: 2.3.0 + terser-webpack-plugin: 5.3.14(esbuild@0.25.0)(uglify-js@3.19.3)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)) watchpack: 2.4.4 webpack-sources: 3.3.3 transitivePeerDependencies: @@ -18645,12 +18629,12 @@ snapshots: workbox-sw@6.6.0: {} - workbox-webpack-plugin@6.6.0(@types/babel__core@7.20.5)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)): + workbox-webpack-plugin@6.6.0(@types/babel__core@7.20.5)(webpack@5.102.1(esbuild@0.25.0)(uglify-js@3.19.3)): dependencies: fast-json-stable-stringify: 2.1.0 pretty-bytes: 5.6.0 upath: 1.2.0 - webpack: 5.100.2(esbuild@0.25.0)(uglify-js@3.19.3) + webpack: 5.102.1(esbuild@0.25.0)(uglify-js@3.19.3) webpack-sources: 1.4.3 workbox-build: 6.6.0(@types/babel__core@7.20.5) transitivePeerDependencies: From 8cf4a0d3ad7ab763ad46a3e6e5185ceb11aeb9ba Mon Sep 17 00:00:00 2001 From: Novice Date: Mon, 20 Oct 2025 10:54:02 +0800 Subject: [PATCH 46/46] chore: handle merge conflict --- api/core/tools/tool_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 5ab32ecad4..0641fa01fe 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -24,7 +24,7 @@ from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.tool import PluginTool from core.tools.utils.uuid_utils import is_valid_uuid from core.tools.workflow_as_tool.provider import WorkflowToolProviderController -from core.workflow.entities.variable_pool import VariablePool +from core.workflow.runtime.variable_pool import VariablePool from extensions.ext_database import db from models.provider_ids import ToolProviderID from services.enterprise.plugin_manager_service import PluginCredentialType