From 631f999f6597f16e82a4d459ce385fd41d92a81d Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Mon, 5 Jan 2026 15:48:31 +0800 Subject: [PATCH 01/17] refactor: use contains_any instead of Chaining where = where | f (#30559) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../rag/datasource/vdb/weaviate/weaviate_vector.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 84d1e26b34..b48dd93f04 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -66,6 +66,8 @@ class WeaviateVector(BaseVector): in a Weaviate collection. """ + _DOCUMENT_ID_PROPERTY = "document_id" + def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list): """ Initializes the Weaviate vector store. @@ -353,15 +355,12 @@ class WeaviateVector(BaseVector): return [] col = self._client.collections.use(self._collection_name) - props = list({*self._attributes, "document_id", Field.TEXT_KEY.value}) + props = list({*self._attributes, self._DOCUMENT_ID_PROPERTY, Field.TEXT_KEY.value}) where = None doc_ids = kwargs.get("document_ids_filter") or [] if doc_ids: - ors = [Filter.by_property("document_id").equal(x) for x in doc_ids] - where = ors[0] - for f in ors[1:]: - where = where | f + where = Filter.by_property(self._DOCUMENT_ID_PROPERTY).contains_any(doc_ids) top_k = int(kwargs.get("top_k", 4)) score_threshold = float(kwargs.get("score_threshold") or 0.0) @@ -408,10 +407,7 @@ class WeaviateVector(BaseVector): where = None doc_ids = kwargs.get("document_ids_filter") or [] if doc_ids: - ors = [Filter.by_property("document_id").equal(x) for x in doc_ids] - where = ors[0] - for f in ors[1:]: - where = where | f + where = Filter.by_property(self._DOCUMENT_ID_PROPERTY).contains_any(doc_ids) top_k = int(kwargs.get("top_k", 4)) From 52149c0d9b6b1d6bd21ea3aac0c7a0f9ae26e7e3 Mon Sep 17 00:00:00 2001 From: Stephen Zhou <38493346+hyoban@users.noreply.github.com> Date: Mon, 5 Jan 2026 15:49:31 +0800 Subject: [PATCH 02/17] chore(web): add ESLint rules for i18n JSON validation (#30491) --- web/eslint-rules/index.js | 4 ++ web/eslint-rules/rules/no-extra-keys.js | 70 +++++++++++++++++++++++ web/eslint-rules/rules/valid-i18n-keys.js | 61 ++++++++++++++++++++ web/eslint-rules/utils.js | 10 ++++ web/eslint.config.mjs | 24 +++++--- 5 files changed, 160 insertions(+), 9 deletions(-) create mode 100644 web/eslint-rules/rules/no-extra-keys.js create mode 100644 web/eslint-rules/rules/valid-i18n-keys.js create mode 100644 web/eslint-rules/utils.js diff --git a/web/eslint-rules/index.js b/web/eslint-rules/index.js index edb6b96ba4..66c2034625 100644 --- a/web/eslint-rules/index.js +++ b/web/eslint-rules/index.js @@ -1,6 +1,8 @@ import noAsAnyInT from './rules/no-as-any-in-t.js' +import noExtraKeys from './rules/no-extra-keys.js' import noLegacyNamespacePrefix from './rules/no-legacy-namespace-prefix.js' import requireNsOption from './rules/require-ns-option.js' +import validI18nKeys from './rules/valid-i18n-keys.js' /** @type {import('eslint').ESLint.Plugin} */ const plugin = { @@ -10,8 +12,10 @@ const plugin = { }, rules: { 'no-as-any-in-t': noAsAnyInT, + 'no-extra-keys': noExtraKeys, 'no-legacy-namespace-prefix': noLegacyNamespacePrefix, 'require-ns-option': requireNsOption, + 'valid-i18n-keys': validI18nKeys, }, } diff --git a/web/eslint-rules/rules/no-extra-keys.js b/web/eslint-rules/rules/no-extra-keys.js new file mode 100644 index 0000000000..eb47f60934 --- /dev/null +++ b/web/eslint-rules/rules/no-extra-keys.js @@ -0,0 +1,70 @@ +import fs from 'node:fs' +import path, { normalize, sep } from 'node:path' + +/** @type {import('eslint').Rule.RuleModule} */ +export default { + meta: { + type: 'problem', + docs: { + description: 'Ensure non-English JSON files don\'t have extra keys not present in en-US', + }, + fixable: 'code', + }, + create(context) { + return { + Program(node) { + const { filename, sourceCode } = context + + if (!filename.endsWith('.json')) + return + + const parts = normalize(filename).split(sep) + // e.g., i18n/ar-TN/common.json -> jsonFile = common.json, lang = ar-TN + const jsonFile = parts.at(-1) + const lang = parts.at(-2) + + // Skip English files + if (lang === 'en-US') + return + + let currentJson = {} + let englishJson = {} + + try { + currentJson = JSON.parse(sourceCode.text) + // Look for the same filename in en-US folder + // e.g., i18n/ar-TN/common.json -> i18n/en-US/common.json + const englishFilePath = path.join(path.dirname(filename), '..', 'en-US', jsonFile ?? '') + englishJson = JSON.parse(fs.readFileSync(englishFilePath, 'utf8')) + } + catch (error) { + context.report({ + node, + message: `Error parsing JSON: ${error instanceof Error ? error.message : String(error)}`, + }) + return + } + + const extraKeys = Object.keys(currentJson).filter( + key => !Object.prototype.hasOwnProperty.call(englishJson, key), + ) + + for (const key of extraKeys) { + context.report({ + node, + message: `Key "${key}" is present in ${lang}/${jsonFile} but not in en-US/${jsonFile}`, + fix(fixer) { + const newJson = Object.fromEntries( + Object.entries(currentJson).filter(([k]) => !extraKeys.includes(k)), + ) + + const newText = `${JSON.stringify(newJson, null, 2)}\n` + + return fixer.replaceText(node, newText) + }, + }) + } + }, + } + }, +} diff --git a/web/eslint-rules/rules/valid-i18n-keys.js b/web/eslint-rules/rules/valid-i18n-keys.js new file mode 100644 index 0000000000..08d863a19a --- /dev/null +++ b/web/eslint-rules/rules/valid-i18n-keys.js @@ -0,0 +1,61 @@ +import { cleanJsonText } from '../utils.js' + +/** @type {import('eslint').Rule.RuleModule} */ +export default { + meta: { + type: 'problem', + docs: { + description: 'Ensure i18n JSON keys are flat and valid as object paths', + }, + }, + create(context) { + return { + Program(node) { + const { filename, sourceCode } = context + + if (!filename.endsWith('.json')) + return + + let json + try { + json = JSON.parse(cleanJsonText(sourceCode.text)) + } + catch { + context.report({ + node, + message: 'Invalid JSON format', + }) + return + } + + const keys = Object.keys(json) + const keyPrefixes = new Set() + + for (const key of keys) { + if (key.includes('.')) { + const parts = key.split('.') + for (let i = 1; i < parts.length; i++) { + const prefix = parts.slice(0, i).join('.') + if (keys.includes(prefix)) { + context.report({ + node, + message: `Invalid key structure: '${key}' conflicts with '${prefix}'`, + }) + } + keyPrefixes.add(prefix) + } + } + } + + for (const key of keys) { + if (keyPrefixes.has(key)) { + context.report({ + node, + message: `Invalid key structure: '${key}' is a prefix of another key`, + }) + } + } + }, + } + }, +} diff --git a/web/eslint-rules/utils.js b/web/eslint-rules/utils.js new file mode 100644 index 0000000000..2030c96b5a --- /dev/null +++ b/web/eslint-rules/utils.js @@ -0,0 +1,10 @@ +export const cleanJsonText = (text) => { + const cleaned = text.replaceAll(/,\s*\}/g, '}') + try { + JSON.parse(cleaned) + return cleaned + } + catch { + return text + } +} diff --git a/web/eslint.config.mjs b/web/eslint.config.mjs index 3cdd3efedb..2cfe2e5e13 100644 --- a/web/eslint.config.mjs +++ b/web/eslint.config.mjs @@ -130,15 +130,6 @@ export default antfu( sonarjs: sonar, }, }, - // allow generated i18n files (like i18n/*/workflow.ts) to exceed max-lines - { - files: ['i18n/**'], - rules: { - 'sonarjs/max-lines': 'off', - 'max-lines': 'off', - 'jsonc/sort-keys': 'error', - }, - }, tailwind.configs['flat/recommended'], { settings: { @@ -191,4 +182,19 @@ export default antfu( 'dify-i18n/require-ns-option': 'error', }, }, + // i18n JSON validation rules + { + files: ['i18n/**/*.json'], + plugins: { + 'dify-i18n': difyI18n, + }, + rules: { + 'sonarjs/max-lines': 'off', + 'max-lines': 'off', + 'jsonc/sort-keys': 'error', + + 'dify-i18n/valid-i18n-keys': 'error', + 'dify-i18n/no-extra-keys': 'error', + }, + }, ) From a99ac3fe0df489cdcbeff9796a6153bebab1f192 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 5 Jan 2026 15:50:03 +0800 Subject: [PATCH 03/17] refactor(models): Add mapped type hints to MessageAnnotation (#27751) --- api/commands.py | 2 +- .../features/annotation_reply/annotation_reply.py | 2 +- api/models/model.py | 15 ++++++++++----- api/services/annotation_service.py | 4 ++-- .../annotation/enable_annotation_reply_task.py | 2 +- 5 files changed, 15 insertions(+), 10 deletions(-) diff --git a/api/commands.py b/api/commands.py index 44f7b42825..7ebf5b4874 100644 --- a/api/commands.py +++ b/api/commands.py @@ -235,7 +235,7 @@ def migrate_annotation_vector_database(): if annotations: for annotation in annotations: document = Document( - page_content=annotation.question, + page_content=annotation.question_text, metadata={"annotation_id": annotation.id, "app_id": app.id, "doc_id": annotation.id}, ) documents.append(document) diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 79fbafe39e..3f9f3da9b2 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -75,7 +75,7 @@ class AnnotationReplyFeature: AppAnnotationService.add_annotation_history( annotation.id, app_record.id, - annotation.question, + annotation.question_text, annotation.content, query, user_id, diff --git a/api/models/model.py b/api/models/model.py index 88cb945b3f..6cfcc2859d 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1419,15 +1419,20 @@ class MessageAnnotation(Base): app_id: Mapped[str] = mapped_column(StringUUID) conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id")) message_id: Mapped[str | None] = mapped_column(StringUUID) - question = mapped_column(LongText, nullable=True) - content = mapped_column(LongText, nullable=False) + question: Mapped[str | None] = mapped_column(LongText, nullable=True) + content: Mapped[str] = mapped_column(LongText, nullable=False) hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) - account_id = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column( + account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column( sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) + @property + def question_text(self) -> str: + """Return a non-null question string, falling back to the answer content.""" + return self.question or self.content + @property def account(self): account = db.session.query(Account).where(Account.id == self.account_id).first() diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index d03cbddceb..7f44fe05a6 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -77,7 +77,7 @@ class AppAnnotationService: if annotation_setting: add_annotation_to_index_task.delay( annotation.id, - annotation.question, + question, current_tenant_id, app_id, annotation_setting.collection_binding_id, @@ -253,7 +253,7 @@ class AppAnnotationService: if app_annotation_setting: update_annotation_to_index_task.delay( annotation.id, - annotation.question, + annotation.question_text, current_tenant_id, app_id, app_annotation_setting.collection_binding_id, diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index cdc07c77a8..be1de3cdd2 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -98,7 +98,7 @@ def enable_annotation_reply_task( if annotations: for annotation in annotations: document = Document( - page_content=annotation.question, + page_content=annotation.question_text, metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id}, ) documents.append(document) From 34f3b288a720687020b3d1da59f7897a936fb407 Mon Sep 17 00:00:00 2001 From: Lework Date: Mon, 5 Jan 2026 15:50:33 +0800 Subject: [PATCH 04/17] chore(docker): update nltk data download process to include unstructured download_nltk_packages (#28876) --- api/Dockerfile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/api/Dockerfile b/api/Dockerfile index 02df91bfc1..e800e60322 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -79,7 +79,8 @@ COPY --from=packages --chown=dify:dify ${VIRTUAL_ENV} ${VIRTUAL_ENV} ENV PATH="${VIRTUAL_ENV}/bin:${PATH}" # Download nltk data -RUN mkdir -p /usr/local/share/nltk_data && NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords')" \ +RUN mkdir -p /usr/local/share/nltk_data \ + && NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; from unstructured.nlp.tokenize import download_nltk_packages; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords'); download_nltk_packages()" \ && chmod -R 755 /usr/local/share/nltk_data ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache From a72044aa86185fdce3e89b4dff605559c34a417a Mon Sep 17 00:00:00 2001 From: Stephen Zhou <38493346+hyoban@users.noreply.github.com> Date: Mon, 5 Jan 2026 16:12:12 +0800 Subject: [PATCH 05/17] chore: fix lint in i18n (#30571) --- web/i18n/ar-TN/common.json | 4 ---- web/i18n/de-DE/common.json | 4 ---- web/i18n/es-ES/common.json | 4 ---- web/i18n/fa-IR/common.json | 4 ---- web/i18n/fr-FR/common.json | 4 ---- web/i18n/hi-IN/common.json | 4 ---- web/i18n/id-ID/common.json | 4 ---- web/i18n/it-IT/common.json | 4 ---- web/i18n/ko-KR/common.json | 4 ---- web/i18n/pl-PL/common.json | 4 ---- web/i18n/pt-BR/common.json | 4 ---- web/i18n/ro-RO/common.json | 4 ---- web/i18n/ru-RU/common.json | 4 ---- web/i18n/sl-SI/common.json | 4 ---- web/i18n/th-TH/common.json | 4 ---- web/i18n/tr-TR/common.json | 4 ---- web/i18n/uk-UA/common.json | 4 ---- web/i18n/vi-VN/common.json | 4 ---- web/i18n/zh-Hant/common.json | 4 ---- 19 files changed, 76 deletions(-) diff --git a/web/i18n/ar-TN/common.json b/web/i18n/ar-TN/common.json index beda6bb4c7..d015f1ae0b 100644 --- a/web/i18n/ar-TN/common.json +++ b/web/i18n/ar-TN/common.json @@ -339,9 +339,6 @@ "modelProvider.callTimes": "أوقات الاتصال", "modelProvider.card.buyQuota": "شراء حصة", "modelProvider.card.callTimes": "أوقات الاتصال", - "modelProvider.card.modelAPI": "النماذج {{modelName}} تستخدم مفتاح واجهة برمجة التطبيقات.", - "modelProvider.card.modelNotSupported": "النماذج {{modelName}} غير مثبتة.", - "modelProvider.card.modelSupported": "النماذج {{modelName}} تستخدم هذا الحصة.", "modelProvider.card.onTrial": "في التجربة", "modelProvider.card.paid": "مدفوع", "modelProvider.card.priorityUse": "أولوية الاستخدام", @@ -397,7 +394,6 @@ "modelProvider.quotaTip": "الرموز المجانية المتاحة المتبقية", "modelProvider.rerankModel.key": "نموذج إعادة الترتيب", "modelProvider.rerankModel.tip": "سيعيد نموذج إعادة الترتيب ترتيب قائمة المستندات المرشحة بناءً على المطابقة الدلالية مع استعلام المستخدم، مما يحسن نتائج الترتيب الدلالي", - "modelProvider.resetDate": "إعادة الضبط على {{date}}", "modelProvider.searchModel": "نموذج البحث", "modelProvider.selectModel": "اختر نموذجك", "modelProvider.selector.emptySetting": "يرجى الانتقال إلى الإعدادات للتكوين", diff --git a/web/i18n/de-DE/common.json b/web/i18n/de-DE/common.json index 1792c9b7ca..f54f6a939f 100644 --- a/web/i18n/de-DE/common.json +++ b/web/i18n/de-DE/common.json @@ -339,9 +339,6 @@ "modelProvider.callTimes": "Anrufzeiten", "modelProvider.card.buyQuota": "Kontingent kaufen", "modelProvider.card.callTimes": "Anrufzeiten", - "modelProvider.card.modelAPI": "{{modelName}}-Modelle verwenden den API-Schlüssel.", - "modelProvider.card.modelNotSupported": "{{modelName}}-Modelle sind nicht installiert.", - "modelProvider.card.modelSupported": "{{modelName}}-Modelle verwenden dieses Kontingent.", "modelProvider.card.onTrial": "In Probe", "modelProvider.card.paid": "Bezahlt", "modelProvider.card.priorityUse": "Priorisierte Nutzung", @@ -397,7 +394,6 @@ "modelProvider.quotaTip": "Verbleibende verfügbare kostenlose Token", "modelProvider.rerankModel.key": "Rerank-Modell", "modelProvider.rerankModel.tip": "Rerank-Modell wird die Kandidatendokumentenliste basierend auf der semantischen Übereinstimmung mit der Benutzeranfrage neu ordnen und die Ergebnisse der semantischen Rangordnung verbessern", - "modelProvider.resetDate": "Zurücksetzen bei {{date}}", "modelProvider.searchModel": "Suchmodell", "modelProvider.selectModel": "Wählen Sie Ihr Modell", "modelProvider.selector.emptySetting": "Bitte gehen Sie zu den Einstellungen, um zu konfigurieren", diff --git a/web/i18n/es-ES/common.json b/web/i18n/es-ES/common.json index d99c36d9dd..ec08f11ed7 100644 --- a/web/i18n/es-ES/common.json +++ b/web/i18n/es-ES/common.json @@ -339,9 +339,6 @@ "modelProvider.callTimes": "Tiempos de llamada", "modelProvider.card.buyQuota": "Comprar Cuota", "modelProvider.card.callTimes": "Tiempos de llamada", - "modelProvider.card.modelAPI": "Los modelos {{modelName}} están usando la clave de API.", - "modelProvider.card.modelNotSupported": "Los modelos {{modelName}} no están instalados.", - "modelProvider.card.modelSupported": "Los modelos {{modelName}} están utilizando esta cuota.", "modelProvider.card.onTrial": "En prueba", "modelProvider.card.paid": "Pagado", "modelProvider.card.priorityUse": "Uso prioritario", @@ -397,7 +394,6 @@ "modelProvider.quotaTip": "Tokens gratuitos restantes disponibles", "modelProvider.rerankModel.key": "Modelo de Reordenar", "modelProvider.rerankModel.tip": "El modelo de reordenar reordenará la lista de documentos candidatos basada en la coincidencia semántica con la consulta del usuario, mejorando los resultados de clasificación semántica", - "modelProvider.resetDate": "Reiniciar en {{date}}", "modelProvider.searchModel": "Modelo de búsqueda", "modelProvider.selectModel": "Selecciona tu modelo", "modelProvider.selector.emptySetting": "Por favor ve a configuraciones para configurar", diff --git a/web/i18n/fa-IR/common.json b/web/i18n/fa-IR/common.json index 588b37ee43..78f9b9e388 100644 --- a/web/i18n/fa-IR/common.json +++ b/web/i18n/fa-IR/common.json @@ -339,9 +339,6 @@ "modelProvider.callTimes": "تعداد فراخوانی", "modelProvider.card.buyQuota": "خرید سهمیه", "modelProvider.card.callTimes": "تعداد فراخوانی", - "modelProvider.card.modelAPI": "مدل‌های {{modelName}} در حال استفاده از کلید API هستند.", - "modelProvider.card.modelNotSupported": "مدل‌های {{modelName}} نصب نشده‌اند.", - "modelProvider.card.modelSupported": "مدل‌های {{modelName}} از این سهمیه استفاده می‌کنند.", "modelProvider.card.onTrial": "در حال آزمایش", "modelProvider.card.paid": "پرداخت شده", "modelProvider.card.priorityUse": "استفاده با اولویت", @@ -397,7 +394,6 @@ "modelProvider.quotaTip": "توکن‌های رایگان باقی‌مانده در دسترس", "modelProvider.rerankModel.key": "مدل رتبه‌بندی مجدد", "modelProvider.rerankModel.tip": "مدل رتبه‌بندی مجدد، لیست اسناد کاندید را بر اساس تطابق معنایی با پرسش کاربر مرتب می‌کند و نتایج رتبه‌بندی معنایی را بهبود می‌بخشد", - "modelProvider.resetDate": "بازنشانی در {{date}}", "modelProvider.searchModel": "جستجوی مدل", "modelProvider.selectModel": "مدل خود را انتخاب کنید", "modelProvider.selector.emptySetting": "لطفاً به تنظیمات بروید تا پیکربندی کنید", diff --git a/web/i18n/fr-FR/common.json b/web/i18n/fr-FR/common.json index 7df7d86272..7cc1af2d80 100644 --- a/web/i18n/fr-FR/common.json +++ b/web/i18n/fr-FR/common.json @@ -339,9 +339,6 @@ "modelProvider.callTimes": "Temps d'appel", "modelProvider.card.buyQuota": "Acheter Quota", "modelProvider.card.callTimes": "Temps d'appel", - "modelProvider.card.modelAPI": "Les modèles {{modelName}} utilisent la clé API.", - "modelProvider.card.modelNotSupported": "Les modèles {{modelName}} ne sont pas installés.", - "modelProvider.card.modelSupported": "Les modèles {{modelName}} utilisent ce quota.", "modelProvider.card.onTrial": "En Essai", "modelProvider.card.paid": "Payé", "modelProvider.card.priorityUse": "Utilisation prioritaire", @@ -397,7 +394,6 @@ "modelProvider.quotaTip": "Tokens gratuits restants disponibles", "modelProvider.rerankModel.key": "Modèle de Réorganisation", "modelProvider.rerankModel.tip": "Le modèle de réorganisation réorganisera la liste des documents candidats en fonction de la correspondance sémantique avec la requête de l'utilisateur, améliorant ainsi les résultats du classement sémantique.", - "modelProvider.resetDate": "Réinitialiser sur {{date}}", "modelProvider.searchModel": "Modèle de recherche", "modelProvider.selectModel": "Sélectionnez votre modèle", "modelProvider.selector.emptySetting": "Veuillez aller dans les paramètres pour configurer", diff --git a/web/i18n/hi-IN/common.json b/web/i18n/hi-IN/common.json index 996880d51c..4670d5a545 100644 --- a/web/i18n/hi-IN/common.json +++ b/web/i18n/hi-IN/common.json @@ -339,9 +339,6 @@ "modelProvider.callTimes": "कॉल समय", "modelProvider.card.buyQuota": "कोटा खरीदें", "modelProvider.card.callTimes": "कॉल समय", - "modelProvider.card.modelAPI": "{{modelName}} मॉडल एपीआई कुंजी का उपयोग कर रहे हैं।", - "modelProvider.card.modelNotSupported": "{{modelName}} मॉडल इंस्टॉल नहीं हैं।", - "modelProvider.card.modelSupported": "{{modelName}} मॉडल इस कोटा का उपयोग कर रहे हैं।", "modelProvider.card.onTrial": "परीक्षण पर", "modelProvider.card.paid": "भुगतान किया हुआ", "modelProvider.card.priorityUse": "प्राथमिकता उपयोग", @@ -397,7 +394,6 @@ "modelProvider.quotaTip": "बचे हुए उपलब्ध मुफ्त टोकन", "modelProvider.rerankModel.key": "रीरैंक मॉडल", "modelProvider.rerankModel.tip": "रीरैंक मॉडल उपयोगकर्ता प्रश्न के साथ सांविधिक मेल के आधार पर उम्मीदवार दस्तावेज़ सूची को पुनः क्रमित करेगा, सांविधिक रैंकिंग के परिणामों में सुधार करेगा।", - "modelProvider.resetDate": "{{date}} पर रीसेट करें", "modelProvider.searchModel": "खोज मॉडल", "modelProvider.selectModel": "अपने मॉडल का चयन करें", "modelProvider.selector.emptySetting": "कॉन्फ़िगर करने के लिए कृपया सेटिंग्स पर जाएं", diff --git a/web/i18n/id-ID/common.json b/web/i18n/id-ID/common.json index 710136c7e2..ede4d3ae44 100644 --- a/web/i18n/id-ID/common.json +++ b/web/i18n/id-ID/common.json @@ -339,9 +339,6 @@ "modelProvider.callTimes": "Waktu panggilan", "modelProvider.card.buyQuota": "Beli Kuota", "modelProvider.card.callTimes": "Waktu panggilan", - "modelProvider.card.modelAPI": "Model {{modelName}} sedang menggunakan API Key.", - "modelProvider.card.modelNotSupported": "Model {{modelName}} tidak terpasang.", - "modelProvider.card.modelSupported": "Model {{modelName}} sedang menggunakan kuota ini.", "modelProvider.card.onTrial": "Sedang Diadili", "modelProvider.card.paid": "Dibayar", "modelProvider.card.priorityUse": "Penggunaan prioritas", @@ -397,7 +394,6 @@ "modelProvider.quotaTip": "Token gratis yang masih tersedia", "modelProvider.rerankModel.key": "Peringkat ulang Model", "modelProvider.rerankModel.tip": "Model rerank akan menyusun ulang daftar dokumen kandidat berdasarkan kecocokan semantik dengan kueri pengguna, meningkatkan hasil peringkat semantik", - "modelProvider.resetDate": "Atur ulang pada {{date}}", "modelProvider.searchModel": "Model pencarian", "modelProvider.selectModel": "Pilih model Anda", "modelProvider.selector.emptySetting": "Silakan buka pengaturan untuk mengonfigurasi", diff --git a/web/i18n/it-IT/common.json b/web/i18n/it-IT/common.json index 64bf2e3d1d..737ef923b1 100644 --- a/web/i18n/it-IT/common.json +++ b/web/i18n/it-IT/common.json @@ -339,9 +339,6 @@ "modelProvider.callTimes": "Numero di chiamate", "modelProvider.card.buyQuota": "Acquista Quota", "modelProvider.card.callTimes": "Numero di chiamate", - "modelProvider.card.modelAPI": "I modelli {{modelName}} stanno utilizzando la chiave API.", - "modelProvider.card.modelNotSupported": "I modelli {{modelName}} non sono installati.", - "modelProvider.card.modelSupported": "I modelli {{modelName}} stanno utilizzando questa quota.", "modelProvider.card.onTrial": "In Prova", "modelProvider.card.paid": "Pagato", "modelProvider.card.priorityUse": "Uso prioritario", @@ -397,7 +394,6 @@ "modelProvider.quotaTip": "Token gratuiti rimanenti disponibili", "modelProvider.rerankModel.key": "Modello di Rerank", "modelProvider.rerankModel.tip": "Il modello di rerank riordinerà la lista dei documenti candidati basandosi sulla corrispondenza semantica con la query dell'utente, migliorando i risultati del ranking semantico", - "modelProvider.resetDate": "Reimposta su {{date}}", "modelProvider.searchModel": "Modello di ricerca", "modelProvider.selectModel": "Seleziona il tuo modello", "modelProvider.selector.emptySetting": "Per favore vai alle impostazioni per configurare", diff --git a/web/i18n/ko-KR/common.json b/web/i18n/ko-KR/common.json index fa3736e561..5640cb353d 100644 --- a/web/i18n/ko-KR/common.json +++ b/web/i18n/ko-KR/common.json @@ -339,9 +339,6 @@ "modelProvider.callTimes": "호출 횟수", "modelProvider.card.buyQuota": "Buy Quota", "modelProvider.card.callTimes": "호출 횟수", - "modelProvider.card.modelAPI": "{{modelName}} 모델이 API 키를 사용하고 있습니다.", - "modelProvider.card.modelNotSupported": "{{modelName}} 모델이 설치되지 않았습니다.", - "modelProvider.card.modelSupported": "{{modelName}} 모델이 이 할당량을 사용하고 있습니다.", "modelProvider.card.onTrial": "트라이얼 중", "modelProvider.card.paid": "유료", "modelProvider.card.priorityUse": "우선 사용", @@ -397,7 +394,6 @@ "modelProvider.quotaTip": "남은 무료 토큰 사용 가능", "modelProvider.rerankModel.key": "재랭크 모델", "modelProvider.rerankModel.tip": "재랭크 모델은 사용자 쿼리와의 의미적 일치를 기반으로 후보 문서 목록을 재배열하여 의미적 순위를 향상시킵니다.", - "modelProvider.resetDate": "{{date}}에서 재설정", "modelProvider.searchModel": "검색 모델", "modelProvider.selectModel": "모델 선택", "modelProvider.selector.emptySetting": "설정으로 이동하여 구성하세요", diff --git a/web/i18n/pl-PL/common.json b/web/i18n/pl-PL/common.json index 1a83dc517e..ae654e04ac 100644 --- a/web/i18n/pl-PL/common.json +++ b/web/i18n/pl-PL/common.json @@ -339,9 +339,6 @@ "modelProvider.callTimes": "Czasy wywołań", "modelProvider.card.buyQuota": "Kup limit", "modelProvider.card.callTimes": "Czasy wywołań", - "modelProvider.card.modelAPI": "Modele {{modelName}} używają klucza API.", - "modelProvider.card.modelNotSupported": "Modele {{modelName}} nie są zainstalowane.", - "modelProvider.card.modelSupported": "{{modelName}} modeli korzysta z tej kwoty.", "modelProvider.card.onTrial": "Na próbę", "modelProvider.card.paid": "Płatny", "modelProvider.card.priorityUse": "Używanie z priorytetem", @@ -397,7 +394,6 @@ "modelProvider.quotaTip": "Pozostałe dostępne darmowe tokeny", "modelProvider.rerankModel.key": "Model ponownego rankingu", "modelProvider.rerankModel.tip": "Model ponownego rankingu zmieni kolejność listy dokumentów kandydatów na podstawie semantycznego dopasowania z zapytaniem użytkownika, poprawiając wyniki rankingu semantycznego", - "modelProvider.resetDate": "Reset na {{date}}", "modelProvider.searchModel": "Model wyszukiwania", "modelProvider.selectModel": "Wybierz swój model", "modelProvider.selector.emptySetting": "Przejdź do ustawień, aby skonfigurować", diff --git a/web/i18n/pt-BR/common.json b/web/i18n/pt-BR/common.json index e97e4364ad..2e7f49de7e 100644 --- a/web/i18n/pt-BR/common.json +++ b/web/i18n/pt-BR/common.json @@ -339,9 +339,6 @@ "modelProvider.callTimes": "Chamadas", "modelProvider.card.buyQuota": "Comprar Quota", "modelProvider.card.callTimes": "Chamadas", - "modelProvider.card.modelAPI": "Os modelos {{modelName}} estão usando a Chave de API.", - "modelProvider.card.modelNotSupported": "Modelos {{modelName}} não estão instalados.", - "modelProvider.card.modelSupported": "Modelos {{modelName}} estão usando esta cota.", "modelProvider.card.onTrial": "Em Teste", "modelProvider.card.paid": "Pago", "modelProvider.card.priorityUse": "Uso prioritário", @@ -397,7 +394,6 @@ "modelProvider.quotaTip": "Tokens gratuitos disponíveis restantes", "modelProvider.rerankModel.key": "Modelo de Reordenação", "modelProvider.rerankModel.tip": "O modelo de reordenaenação reorganizará a lista de documentos candidatos com base na correspondência semântica com a consulta do usuário, melhorando os resultados da classificação semântica", - "modelProvider.resetDate": "Redefinir em {{date}}", "modelProvider.searchModel": "Modelo de pesquisa", "modelProvider.selectModel": "Selecione seu modelo", "modelProvider.selector.emptySetting": "Por favor, vá para configurações para configurar", diff --git a/web/i18n/ro-RO/common.json b/web/i18n/ro-RO/common.json index 785050d8ec..c21e755b3c 100644 --- a/web/i18n/ro-RO/common.json +++ b/web/i18n/ro-RO/common.json @@ -339,9 +339,6 @@ "modelProvider.callTimes": "Apeluri", "modelProvider.card.buyQuota": "Cumpără cotă", "modelProvider.card.callTimes": "Apeluri", - "modelProvider.card.modelAPI": "Modelele {{modelName}} folosesc cheia API.", - "modelProvider.card.modelNotSupported": "Modelele {{modelName}} nu sunt instalate.", - "modelProvider.card.modelSupported": "{{modelName}} modele utilizează această cotă.", "modelProvider.card.onTrial": "În probă", "modelProvider.card.paid": "Plătit", "modelProvider.card.priorityUse": "Utilizare prioritară", @@ -397,7 +394,6 @@ "modelProvider.quotaTip": "Jetoane gratuite disponibile rămase", "modelProvider.rerankModel.key": "Model de reordonare", "modelProvider.rerankModel.tip": "Modelul de reordonare va reordona lista de documente candidate pe baza potrivirii semantice cu interogarea utilizatorului, îmbunătățind rezultatele clasificării semantice", - "modelProvider.resetDate": "Resetați la {{date}}", "modelProvider.searchModel": "Model de căutare", "modelProvider.selectModel": "Selectați modelul dvs.", "modelProvider.selector.emptySetting": "Vă rugăm să mergeți la setări pentru a configura", diff --git a/web/i18n/ru-RU/common.json b/web/i18n/ru-RU/common.json index 63f3758185..e763a7ec2a 100644 --- a/web/i18n/ru-RU/common.json +++ b/web/i18n/ru-RU/common.json @@ -339,9 +339,6 @@ "modelProvider.callTimes": "Количество вызовов", "modelProvider.card.buyQuota": "Купить квоту", "modelProvider.card.callTimes": "Количество вызовов", - "modelProvider.card.modelAPI": "{{modelName}} модели используют ключ API.", - "modelProvider.card.modelNotSupported": "Модели {{modelName}} не установлены.", - "modelProvider.card.modelSupported": "Эту квоту используют модели {{modelName}}.", "modelProvider.card.onTrial": "Пробная версия", "modelProvider.card.paid": "Платный", "modelProvider.card.priorityUse": "Приоритетное использование", @@ -397,7 +394,6 @@ "modelProvider.quotaTip": "Оставшиеся доступные бесплатные токены", "modelProvider.rerankModel.key": "Модель повторного ранжирования", "modelProvider.rerankModel.tip": "Модель повторного ранжирования изменит порядок списка документов-кандидатов на основе семантического соответствия запросу пользователя, улучшая результаты семантического ранжирования", - "modelProvider.resetDate": "Сброс на {{date}}", "modelProvider.searchModel": "Поиск модели", "modelProvider.selectModel": "Выберите свою модель", "modelProvider.selector.emptySetting": "Пожалуйста, перейдите в настройки для настройки", diff --git a/web/i18n/sl-SI/common.json b/web/i18n/sl-SI/common.json index be5a4e5320..d092fe10c8 100644 --- a/web/i18n/sl-SI/common.json +++ b/web/i18n/sl-SI/common.json @@ -339,9 +339,6 @@ "modelProvider.callTimes": "Število klicev", "modelProvider.card.buyQuota": "Kupi kvoto", "modelProvider.card.callTimes": "Časi klicev", - "modelProvider.card.modelAPI": "{{modelName}} modeli uporabljajo API ključ.", - "modelProvider.card.modelNotSupported": "{{modelName}} modeli niso nameščeni.", - "modelProvider.card.modelSupported": "{{modelName}} modeli uporabljajo to kvoto.", "modelProvider.card.onTrial": "Na preizkusu", "modelProvider.card.paid": "Plačano", "modelProvider.card.priorityUse": "Prednostna uporaba", @@ -397,7 +394,6 @@ "modelProvider.quotaTip": "Preostali razpoložljivi brezplačni žetoni", "modelProvider.rerankModel.key": "Model za prerazvrstitev", "modelProvider.rerankModel.tip": "Model za prerazvrstitev bo prerazporedil seznam kandidatskih dokumentov na podlagi semantične ujemanja z uporabniško poizvedbo, s čimer se izboljšajo rezultati semantičnega razvrščanja.", - "modelProvider.resetDate": "Ponastavi na {{date}}", "modelProvider.searchModel": "Model iskanja", "modelProvider.selectModel": "Izberite svoj model", "modelProvider.selector.emptySetting": "Prosimo, pojdite v nastavitve za konfiguracijo", diff --git a/web/i18n/th-TH/common.json b/web/i18n/th-TH/common.json index c6dd9c9259..9a38f7f683 100644 --- a/web/i18n/th-TH/common.json +++ b/web/i18n/th-TH/common.json @@ -339,9 +339,6 @@ "modelProvider.callTimes": "เวลาโทร", "modelProvider.card.buyQuota": "ซื้อโควต้า", "modelProvider.card.callTimes": "เวลาโทร", - "modelProvider.card.modelAPI": "{{modelName}} โมเดลกำลังใช้คีย์ API", - "modelProvider.card.modelNotSupported": "โมเดล {{modelName}} ยังไม่ได้ติดตั้ง", - "modelProvider.card.modelSupported": "โมเดล {{modelName}} กำลังใช้โควต้านี้อยู่", "modelProvider.card.onTrial": "ทดลองใช้", "modelProvider.card.paid": "จ่าย", "modelProvider.card.priorityUse": "ลําดับความสําคัญในการใช้งาน", @@ -397,7 +394,6 @@ "modelProvider.quotaTip": "โทเค็นฟรีที่เหลืออยู่", "modelProvider.rerankModel.key": "จัดอันดับโมเดลใหม่", "modelProvider.rerankModel.tip": "โมเดล Rerank จะจัดลําดับรายการเอกสารผู้สมัครใหม่ตามการจับคู่ความหมายกับการสืบค้นของผู้ใช้ ซึ่งช่วยปรับปรุงผลลัพธ์ของการจัดอันดับความหมาย", - "modelProvider.resetDate": "รีเซ็ตเมื่อ {{date}}", "modelProvider.searchModel": "ค้นหารุ่น", "modelProvider.selectModel": "เลือกรุ่นของคุณ", "modelProvider.selector.emptySetting": "โปรดไปที่การตั้งค่าเพื่อกําหนดค่า", diff --git a/web/i18n/tr-TR/common.json b/web/i18n/tr-TR/common.json index 68d4358281..0ee51e161c 100644 --- a/web/i18n/tr-TR/common.json +++ b/web/i18n/tr-TR/common.json @@ -339,9 +339,6 @@ "modelProvider.callTimes": "Çağrı Süreleri", "modelProvider.card.buyQuota": "Kota Satın Al", "modelProvider.card.callTimes": "Çağrı Süreleri", - "modelProvider.card.modelAPI": "{{modelName}} modelleri API Anahtarını kullanıyor.", - "modelProvider.card.modelNotSupported": "{{modelName}} modelleri yüklü değil.", - "modelProvider.card.modelSupported": "{{modelName}} modelleri bu kotayı kullanıyor.", "modelProvider.card.onTrial": "Deneme Sürümünde", "modelProvider.card.paid": "Ücretli", "modelProvider.card.priorityUse": "Öncelikli Kullan", @@ -397,7 +394,6 @@ "modelProvider.quotaTip": "Kalan kullanılabilir ücretsiz tokenler", "modelProvider.rerankModel.key": "Yeniden Sıralama Modeli", "modelProvider.rerankModel.tip": "Yeniden sıralama modeli, kullanıcı sorgusuyla anlam eşleştirmesine dayalı olarak aday belge listesini yeniden sıralayacak ve anlam sıralama sonuçlarını iyileştirecektir.", - "modelProvider.resetDate": "{{date}} üzerine sıfırlama", "modelProvider.searchModel": "Model ara", "modelProvider.selectModel": "Modelinizi seçin", "modelProvider.selector.emptySetting": "Lütfen ayarlara gidip yapılandırın", diff --git a/web/i18n/uk-UA/common.json b/web/i18n/uk-UA/common.json index 7e3b7fe05f..ddec8637e1 100644 --- a/web/i18n/uk-UA/common.json +++ b/web/i18n/uk-UA/common.json @@ -339,9 +339,6 @@ "modelProvider.callTimes": "Кількість викликів", "modelProvider.card.buyQuota": "Придбати квоту", "modelProvider.card.callTimes": "Кількість викликів", - "modelProvider.card.modelAPI": "Моделі {{modelName}} використовують API-ключ.", - "modelProvider.card.modelNotSupported": "Моделі {{modelName}} не встановлені.", - "modelProvider.card.modelSupported": "Моделі {{modelName}} використовують цю квоту.", "modelProvider.card.onTrial": "У пробному періоді", "modelProvider.card.paid": "Оплачено", "modelProvider.card.priorityUse": "Пріоритетне використання", @@ -397,7 +394,6 @@ "modelProvider.quotaTip": "Залишилося доступних безкоштовних токенів", "modelProvider.rerankModel.key": "Модель повторного ранжування", "modelProvider.rerankModel.tip": "Модель повторного ранжування змінить порядок списку документів-кандидатів на основі семантичної відповідності запиту користувача, покращуючи результати семантичного ранжування.", - "modelProvider.resetDate": "Скинути на {{date}}", "modelProvider.searchModel": "Пошукова модель", "modelProvider.selectModel": "Виберіть свою модель", "modelProvider.selector.emptySetting": "Перейдіть до налаштувань, щоб налаштувати", diff --git a/web/i18n/vi-VN/common.json b/web/i18n/vi-VN/common.json index 20364c3ee9..f8fa9c07d5 100644 --- a/web/i18n/vi-VN/common.json +++ b/web/i18n/vi-VN/common.json @@ -339,9 +339,6 @@ "modelProvider.callTimes": "Số lần gọi", "modelProvider.card.buyQuota": "Mua Quota", "modelProvider.card.callTimes": "Số lần gọi", - "modelProvider.card.modelAPI": "Các mô hình {{modelName}} đang sử dụng Khóa API.", - "modelProvider.card.modelNotSupported": "Các mô hình {{modelName}} chưa được cài đặt.", - "modelProvider.card.modelSupported": "{{modelName}} mô hình đang sử dụng hạn mức này.", "modelProvider.card.onTrial": "Thử nghiệm", "modelProvider.card.paid": "Đã thanh toán", "modelProvider.card.priorityUse": "Ưu tiên sử dụng", @@ -397,7 +394,6 @@ "modelProvider.quotaTip": "Số lượng mã thông báo miễn phí còn lại", "modelProvider.rerankModel.key": "Mô hình Sắp xếp lại", "modelProvider.rerankModel.tip": "Mô hình sắp xếp lại sẽ sắp xếp lại danh sách tài liệu ứng cử viên dựa trên sự phù hợp ngữ nghĩa với truy vấn của người dùng, cải thiện kết quả của việc xếp hạng ngữ nghĩa", - "modelProvider.resetDate": "Đặt lại vào {{date}}", "modelProvider.searchModel": "Mô hình tìm kiếm", "modelProvider.selectModel": "Chọn mô hình của bạn", "modelProvider.selector.emptySetting": "Vui lòng vào cài đặt để cấu hình", diff --git a/web/i18n/zh-Hant/common.json b/web/i18n/zh-Hant/common.json index 11846c9772..8fe3e5bd07 100644 --- a/web/i18n/zh-Hant/common.json +++ b/web/i18n/zh-Hant/common.json @@ -339,9 +339,6 @@ "modelProvider.callTimes": "呼叫次數", "modelProvider.card.buyQuota": "購買額度", "modelProvider.card.callTimes": "呼叫次數", - "modelProvider.card.modelAPI": "{{modelName}} 模型正在使用 API 金鑰。", - "modelProvider.card.modelNotSupported": "{{modelName}} 模型未安裝。", - "modelProvider.card.modelSupported": "{{modelName}} 模型正在使用這個配額。", "modelProvider.card.onTrial": "試用中", "modelProvider.card.paid": "已購買", "modelProvider.card.priorityUse": "優先使用", @@ -397,7 +394,6 @@ "modelProvider.quotaTip": "剩餘免費額度", "modelProvider.rerankModel.key": "Rerank 模型", "modelProvider.rerankModel.tip": "重排序模型將根據候選文件列表與使用者問題語義匹配度進行重新排序,從而改進語義排序的結果", - "modelProvider.resetDate": "在 {{date}} 重置", "modelProvider.searchModel": "搜尋模型", "modelProvider.selectModel": "選擇您的模型", "modelProvider.selector.emptySetting": "請前往設定進行配置", From 591ca05c844b9e415b4c4dd21083d68d2e5e46de Mon Sep 17 00:00:00 2001 From: scdeng Date: Mon, 5 Jan 2026 16:12:41 +0800 Subject: [PATCH 06/17] =?UTF-8?q?feat(logstore):=20make=20`graph`=20field?= =?UTF-8?q?=20optional=20via=20env=20variable=20LOGSTORE=E2=80=A6=20(#3055?= =?UTF-8?q?4)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 阿永 --- api/.env.example | 4 ++++ .../repositories/logstore_workflow_execution_repository.py | 7 ++++++- docker/.env.example | 4 ++++ docker/docker-compose.yaml | 1 + docker/middleware.env.example | 6 +++++- 5 files changed, 20 insertions(+), 2 deletions(-) diff --git a/api/.env.example b/api/.env.example index 88611e016e..44d770ed70 100644 --- a/api/.env.example +++ b/api/.env.example @@ -575,6 +575,10 @@ LOGSTORE_DUAL_WRITE_ENABLED=false # Enable dual-read fallback to SQL database when LogStore returns no results (default: true) # Useful for migration scenarios where historical data exists only in SQL database LOGSTORE_DUAL_READ_ENABLED=true +# Control flag for whether to write the `graph` field to LogStore. +# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field; +# otherwise write an empty {} instead. Defaults to writing the `graph` field. +LOGSTORE_ENABLE_PUT_GRAPH_FIELD=true # Celery beat configuration CELERY_BEAT_SCHEDULER_TIME=1 diff --git a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py index a6b3706f42..1119534d52 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py @@ -81,6 +81,11 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository): # Set to True to enable dual-write for safe migration, False to use LogStore only self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true" + # Control flag for whether to write the `graph` field to LogStore. + # If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field; + # otherwise write an empty {} instead. Defaults to writing the `graph` field. + self._enable_put_graph_field = os.environ.get("LOGSTORE_ENABLE_PUT_GRAPH_FIELD", "true").lower() == "true" + def _to_logstore_model(self, domain_model: WorkflowExecution) -> list[tuple[str, str]]: """ Convert a domain model to a logstore model (List[Tuple[str, str]]). @@ -123,7 +128,7 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository): ( "graph", json.dumps(domain_model.graph, ensure_ascii=False, default=to_serializable) - if domain_model.graph + if domain_model.graph and self._enable_put_graph_field else "{}", ), ( diff --git a/docker/.env.example b/docker/.env.example index ecb003cd70..66d937e8e7 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -1077,6 +1077,10 @@ LOGSTORE_DUAL_WRITE_ENABLED=false # Enable dual-read fallback to SQL database when LogStore returns no results (default: true) # Useful for migration scenarios where historical data exists only in SQL database LOGSTORE_DUAL_READ_ENABLED=true +# Control flag for whether to write the `graph` field to LogStore. +# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field; +# otherwise write an empty {} instead. Defaults to writing the `graph` field. +LOGSTORE_ENABLE_PUT_GRAPH_FIELD=true # HTTP request node in workflow configuration HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index a67141ce05..54b9e744f8 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -475,6 +475,7 @@ x-shared-env: &shared-api-worker-env ALIYUN_SLS_LOGSTORE_TTL: ${ALIYUN_SLS_LOGSTORE_TTL:-365} LOGSTORE_DUAL_WRITE_ENABLED: ${LOGSTORE_DUAL_WRITE_ENABLED:-false} LOGSTORE_DUAL_READ_ENABLED: ${LOGSTORE_DUAL_READ_ENABLED:-true} + LOGSTORE_ENABLE_PUT_GRAPH_FIELD: ${LOGSTORE_ENABLE_PUT_GRAPH_FIELD:-true} HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760} HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576} HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True} diff --git a/docker/middleware.env.example b/docker/middleware.env.example index f7e0252a6f..c88dbe5511 100644 --- a/docker/middleware.env.example +++ b/docker/middleware.env.example @@ -233,4 +233,8 @@ ALIYUN_SLS_LOGSTORE_TTL=365 LOGSTORE_DUAL_WRITE_ENABLED=true # Enable dual-read fallback to SQL database when LogStore returns no results (default: true) # Useful for migration scenarios where historical data exists only in SQL database -LOGSTORE_DUAL_READ_ENABLED=true \ No newline at end of file +LOGSTORE_DUAL_READ_ENABLED=true +# Control flag for whether to write the `graph` field to LogStore. +# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field; +# otherwise write an empty {} instead. Defaults to writing the `graph` field. +LOGSTORE_ENABLE_PUT_GRAPH_FIELD=true \ No newline at end of file From 6f8bd58e19ddef53d604d76058a9f99863017d52 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 5 Jan 2026 16:43:42 +0800 Subject: [PATCH 07/17] feat(graph-engine): make layer runtime state non-null and bound early (#30552) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../app/layers/pause_state_persist_layer.py | 3 +- api/core/app/layers/trigger_post_layer.py | 5 +- api/core/workflow/README.md | 3 + .../workflow/graph_engine/graph_engine.py | 14 ++--- .../workflow/graph_engine/layers/README.md | 5 +- api/core/workflow/graph_engine/layers/base.py | 25 +++++++-- .../graph_engine/layers/debug_logging.py | 11 ++-- .../graph_engine/layers/persistence.py | 4 -- .../layers/test_pause_state_persist_layer.py | 7 ++- .../layers/test_pause_state_persist_layer.py | 10 ++-- .../layers/test_layer_initialization.py | 56 +++++++++++++++++++ 11 files changed, 105 insertions(+), 38 deletions(-) create mode 100644 api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py index 61a3e1baca..bf76ae8178 100644 --- a/api/core/app/layers/pause_state_persist_layer.py +++ b/api/core/app/layers/pause_state_persist_layer.py @@ -66,6 +66,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer): """ if isinstance(session_factory, Engine): session_factory = sessionmaker(session_factory) + super().__init__() self._session_maker = session_factory self._state_owner_user_id = state_owner_user_id self._generate_entity = generate_entity @@ -98,8 +99,6 @@ class PauseStatePersistenceLayer(GraphEngineLayer): if not isinstance(event, GraphRunPausedEvent): return - assert self.graph_runtime_state is not None - entity_wrapper: _GenerateEntityUnion if isinstance(self._generate_entity, WorkflowAppGenerateEntity): entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity) diff --git a/api/core/app/layers/trigger_post_layer.py b/api/core/app/layers/trigger_post_layer.py index fe1a46a945..225b758fcb 100644 --- a/api/core/app/layers/trigger_post_layer.py +++ b/api/core/app/layers/trigger_post_layer.py @@ -33,6 +33,7 @@ class TriggerPostLayer(GraphEngineLayer): trigger_log_id: str, session_maker: sessionmaker[Session], ): + super().__init__() self.trigger_log_id = trigger_log_id self.start_time = start_time self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity @@ -57,10 +58,6 @@ class TriggerPostLayer(GraphEngineLayer): elapsed_time = (datetime.now(UTC) - self.start_time).total_seconds() # Extract relevant data from result - if not self.graph_runtime_state: - logger.exception("Graph runtime state is not set") - return - outputs = self.graph_runtime_state.outputs # BASICLY, workflow_execution_id is the same as workflow_run_id diff --git a/api/core/workflow/README.md b/api/core/workflow/README.md index 72f5dbe1e2..9a39f976a6 100644 --- a/api/core/workflow/README.md +++ b/api/core/workflow/README.md @@ -64,6 +64,9 @@ engine.layer(DebugLoggingLayer(level="INFO")) engine.layer(ExecutionLimitsLayer(max_nodes=100)) ``` +`engine.layer()` binds the read-only runtime state before execution, so layer hooks +can assume `graph_runtime_state` is available. + ### Event-Driven Architecture All node executions emit events for monitoring and integration: diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 2e8b8f345f..e4816afaf8 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -212,9 +212,16 @@ class GraphEngine: if id(node.graph_runtime_state) != expected_state_id: raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance") + def _bind_layer_context( + self, + layer: GraphEngineLayer, + ) -> None: + layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel) + def layer(self, layer: GraphEngineLayer) -> "GraphEngine": """Add a layer for extending functionality.""" self._layers.append(layer) + self._bind_layer_context(layer) return self def run(self) -> Generator[GraphEngineEvent, None, None]: @@ -301,14 +308,7 @@ class GraphEngine: def _initialize_layers(self) -> None: """Initialize layers with context.""" self._event_manager.set_layers(self._layers) - # Create a read-only wrapper for the runtime state - read_only_state = ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state) for layer in self._layers: - try: - layer.initialize(read_only_state, self._command_channel) - except Exception as e: - logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e) - try: layer.on_graph_start() except Exception as e: diff --git a/api/core/workflow/graph_engine/layers/README.md b/api/core/workflow/graph_engine/layers/README.md index 17845ee1f0..b0f295037c 100644 --- a/api/core/workflow/graph_engine/layers/README.md +++ b/api/core/workflow/graph_engine/layers/README.md @@ -8,7 +8,7 @@ Pluggable middleware for engine extensions. Abstract base class for layers. -- `initialize()` - Receive runtime context +- `initialize()` - Receive runtime context (runtime state is bound here and always available to hooks) - `on_graph_start()` - Execution start hook - `on_event()` - Process all events - `on_graph_end()` - Execution end hook @@ -34,6 +34,9 @@ engine.layer(debug_layer) engine.run() ``` +`engine.layer()` binds the read-only runtime state before execution, so +`graph_runtime_state` is always available inside layer hooks. + ## Custom Layers ```python diff --git a/api/core/workflow/graph_engine/layers/base.py b/api/core/workflow/graph_engine/layers/base.py index 780f92a0f4..89293b9b30 100644 --- a/api/core/workflow/graph_engine/layers/base.py +++ b/api/core/workflow/graph_engine/layers/base.py @@ -13,6 +13,14 @@ from core.workflow.nodes.base.node import Node from core.workflow.runtime import ReadOnlyGraphRuntimeState +class GraphEngineLayerNotInitializedError(Exception): + """Raised when a layer's runtime state is accessed before initialization.""" + + def __init__(self, layer_name: str | None = None) -> None: + name = layer_name or "GraphEngineLayer" + super().__init__(f"{name} runtime state is not initialized. Bind the layer to a GraphEngine before access.") + + class GraphEngineLayer(ABC): """ Abstract base class for GraphEngine layers. @@ -28,22 +36,27 @@ class GraphEngineLayer(ABC): def __init__(self) -> None: """Initialize the layer. Subclasses can override with custom parameters.""" - self.graph_runtime_state: ReadOnlyGraphRuntimeState | None = None + self._graph_runtime_state: ReadOnlyGraphRuntimeState | None = None self.command_channel: CommandChannel | None = None + @property + def graph_runtime_state(self) -> ReadOnlyGraphRuntimeState: + if self._graph_runtime_state is None: + raise GraphEngineLayerNotInitializedError(type(self).__name__) + return self._graph_runtime_state + def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None: """ Initialize the layer with engine dependencies. - Called by GraphEngine before execution starts to inject the read-only runtime state - and command channel. This allows layers to observe engine context and send - commands, but prevents direct state modification. - + Called by GraphEngine to inject the read-only runtime state and command channel. + This is invoked when the layer is registered with a `GraphEngine` instance. + Implementations should be idempotent. Args: graph_runtime_state: Read-only view of the runtime state command_channel: Channel for sending commands to the engine """ - self.graph_runtime_state = graph_runtime_state + self._graph_runtime_state = graph_runtime_state self.command_channel = command_channel @abstractmethod diff --git a/api/core/workflow/graph_engine/layers/debug_logging.py b/api/core/workflow/graph_engine/layers/debug_logging.py index 034ebcf54f..e0402cd09c 100644 --- a/api/core/workflow/graph_engine/layers/debug_logging.py +++ b/api/core/workflow/graph_engine/layers/debug_logging.py @@ -109,10 +109,8 @@ class DebugLoggingLayer(GraphEngineLayer): self.logger.info("=" * 80) self.logger.info("🚀 GRAPH EXECUTION STARTED") self.logger.info("=" * 80) - - if self.graph_runtime_state: - # Log initial state - self.logger.info("Initial State:") + # Log initial state + self.logger.info("Initial State:") @override def on_event(self, event: GraphEngineEvent) -> None: @@ -243,8 +241,7 @@ class DebugLoggingLayer(GraphEngineLayer): self.logger.info(" Node retries: %s", self.retry_count) # Log final state if available - if self.graph_runtime_state and self.include_outputs: - if self.graph_runtime_state.outputs: - self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs)) + if self.include_outputs and self.graph_runtime_state.outputs: + self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs)) self.logger.info("=" * 80) diff --git a/api/core/workflow/graph_engine/layers/persistence.py b/api/core/workflow/graph_engine/layers/persistence.py index b70f36ec9e..e81df4f3b7 100644 --- a/api/core/workflow/graph_engine/layers/persistence.py +++ b/api/core/workflow/graph_engine/layers/persistence.py @@ -337,8 +337,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer): if update_finished: execution.finished_at = naive_utc_now() runtime_state = self.graph_runtime_state - if runtime_state is None: - return execution.total_tokens = runtime_state.total_tokens execution.total_steps = runtime_state.node_run_steps execution.outputs = execution.outputs or runtime_state.outputs @@ -404,6 +402,4 @@ class WorkflowPersistenceLayer(GraphEngineLayer): def _system_variables(self) -> Mapping[str, Any]: runtime_state = self.graph_runtime_state - if runtime_state is None: - return {} return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID) diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py index 72469ad646..dcf31aeca7 100644 --- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -35,6 +35,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage from core.workflow.entities.pause_reason import SchedulingPause from core.workflow.enums import WorkflowExecutionStatus from core.workflow.graph_engine.entities.commands import GraphEngineCommand +from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError from core.workflow.graph_events.graph import GraphRunPausedEvent from core.workflow.runtime.graph_runtime_state import GraphRuntimeState from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState @@ -569,10 +570,10 @@ class TestPauseStatePersistenceLayerTestContainers: """Test that layer requires proper initialization before handling events.""" # Arrange layer = self._create_pause_state_persistence_layer() - # Don't initialize - graph_runtime_state should not be set + # Don't initialize - graph_runtime_state should be uninitialized event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")]) - # Act & Assert - Should raise AttributeError - with pytest.raises(AttributeError): + # Act & Assert - Should raise GraphEngineLayerNotInitializedError + with pytest.raises(GraphEngineLayerNotInitializedError): layer.on_event(event) diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py index 534420f21e..5a5386ee57 100644 --- a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py @@ -15,6 +15,7 @@ from core.app.layers.pause_state_persist_layer import ( from core.variables.segments import Segment from core.workflow.entities.pause_reason import SchedulingPause from core.workflow.graph_engine.entities.commands import GraphEngineCommand +from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError from core.workflow.graph_events.graph import ( GraphRunFailedEvent, GraphRunPausedEvent, @@ -209,8 +210,9 @@ class TestPauseStatePersistenceLayer: assert layer._session_maker is session_factory assert layer._state_owner_user_id == state_owner_user_id - assert not hasattr(layer, "graph_runtime_state") - assert not hasattr(layer, "command_channel") + with pytest.raises(GraphEngineLayerNotInitializedError): + _ = layer.graph_runtime_state + assert layer.command_channel is None def test_initialize_sets_dependencies(self): session_factory = Mock(name="session_factory") @@ -295,7 +297,7 @@ class TestPauseStatePersistenceLayer: mock_factory.assert_not_called() mock_repo.create_workflow_pause.assert_not_called() - def test_on_event_raises_attribute_error_when_graph_runtime_state_is_none(self): + def test_on_event_raises_when_graph_runtime_state_is_uninitialized(self): session_factory = Mock(name="session_factory") layer = PauseStatePersistenceLayer( session_factory=session_factory, @@ -305,7 +307,7 @@ class TestPauseStatePersistenceLayer: event = TestDataFactory.create_graph_run_paused_event() - with pytest.raises(AttributeError): + with pytest.raises(GraphEngineLayerNotInitializedError): layer.on_event(event) def test_on_event_asserts_when_workflow_execution_id_missing(self, monkeypatch: pytest.MonkeyPatch): diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py new file mode 100644 index 0000000000..d6ba61c50c --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import pytest + +from core.workflow.graph_engine import GraphEngine +from core.workflow.graph_engine.command_channels import InMemoryChannel +from core.workflow.graph_engine.layers.base import ( + GraphEngineLayer, + GraphEngineLayerNotInitializedError, +) +from core.workflow.graph_events import GraphEngineEvent + +from ..test_table_runner import WorkflowRunner + + +class LayerForTest(GraphEngineLayer): + def on_graph_start(self) -> None: + pass + + def on_event(self, event: GraphEngineEvent) -> None: + pass + + def on_graph_end(self, error: Exception | None) -> None: + pass + + +def test_layer_runtime_state_raises_when_uninitialized() -> None: + layer = LayerForTest() + + with pytest.raises(GraphEngineLayerNotInitializedError): + _ = layer.graph_runtime_state + + +def test_layer_runtime_state_available_after_engine_layer() -> None: + runner = WorkflowRunner() + fixture_data = runner.load_fixture("simple_passthrough_workflow") + graph, graph_runtime_state = runner.create_graph_from_fixture( + fixture_data, + inputs={"query": "test layer state"}, + ) + engine = GraphEngine( + workflow_id="test_workflow", + graph=graph, + graph_runtime_state=graph_runtime_state, + command_channel=InMemoryChannel(), + ) + + layer = LayerForTest() + engine.layer(layer) + + outputs = layer.graph_runtime_state.outputs + ready_queue_size = layer.graph_runtime_state.ready_queue_size + + assert outputs == {} + assert isinstance(ready_queue_size, int) + assert ready_queue_size >= 0 From a9e2c05a100413426392d71a5d243950d7a50f8c Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 5 Jan 2026 16:47:34 +0800 Subject: [PATCH 08/17] feat(graph-engine): add command to update variables at runtime (#30563) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../command_channels/redis_channel.py | 4 +- .../command_processing/__init__.py | 3 +- .../command_processing/command_handlers.py | 25 ++++++- .../graph_engine/entities/commands.py | 23 +++++- .../workflow/graph_engine/graph_engine.py | 12 ++- api/core/workflow/graph_engine/manager.py | 21 +++++- .../command_channels/test_redis_channel.py | 46 +++++++++++- .../graph_engine/test_command_system.py | 73 ++++++++++++++++++- 8 files changed, 194 insertions(+), 13 deletions(-) diff --git a/api/core/workflow/graph_engine/command_channels/redis_channel.py b/api/core/workflow/graph_engine/command_channels/redis_channel.py index 4be3adb8f8..0fccd4a0fd 100644 --- a/api/core/workflow/graph_engine/command_channels/redis_channel.py +++ b/api/core/workflow/graph_engine/command_channels/redis_channel.py @@ -9,7 +9,7 @@ Each instance uses a unique key for its command queue. import json from typing import TYPE_CHECKING, Any, final -from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand +from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand if TYPE_CHECKING: from extensions.ext_redis import RedisClientWrapper @@ -113,6 +113,8 @@ class RedisChannel: return AbortCommand.model_validate(data) if command_type == CommandType.PAUSE: return PauseCommand.model_validate(data) + if command_type == CommandType.UPDATE_VARIABLES: + return UpdateVariablesCommand.model_validate(data) # For other command types, use base class return GraphEngineCommand.model_validate(data) diff --git a/api/core/workflow/graph_engine/command_processing/__init__.py b/api/core/workflow/graph_engine/command_processing/__init__.py index 837f5e55fd..7b4f0dfff7 100644 --- a/api/core/workflow/graph_engine/command_processing/__init__.py +++ b/api/core/workflow/graph_engine/command_processing/__init__.py @@ -5,11 +5,12 @@ This package handles external commands sent to the engine during execution. """ -from .command_handlers import AbortCommandHandler, PauseCommandHandler +from .command_handlers import AbortCommandHandler, PauseCommandHandler, UpdateVariablesCommandHandler from .command_processor import CommandProcessor __all__ = [ "AbortCommandHandler", "CommandProcessor", "PauseCommandHandler", + "UpdateVariablesCommandHandler", ] diff --git a/api/core/workflow/graph_engine/command_processing/command_handlers.py b/api/core/workflow/graph_engine/command_processing/command_handlers.py index e9f109c88c..cfe856d9e8 100644 --- a/api/core/workflow/graph_engine/command_processing/command_handlers.py +++ b/api/core/workflow/graph_engine/command_processing/command_handlers.py @@ -4,9 +4,10 @@ from typing import final from typing_extensions import override from core.workflow.entities.pause_reason import SchedulingPause +from core.workflow.runtime import VariablePool from ..domain.graph_execution import GraphExecution -from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand +from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand from .command_processor import CommandHandler logger = logging.getLogger(__name__) @@ -31,3 +32,25 @@ class PauseCommandHandler(CommandHandler): reason = command.reason pause_reason = SchedulingPause(message=reason) execution.pause(pause_reason) + + +@final +class UpdateVariablesCommandHandler(CommandHandler): + def __init__(self, variable_pool: VariablePool) -> None: + self._variable_pool = variable_pool + + @override + def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: + assert isinstance(command, UpdateVariablesCommand) + for update in command.updates: + try: + variable = update.value + self._variable_pool.add(variable.selector, variable) + logger.debug("Updated variable %s for workflow %s", variable.selector, execution.workflow_id) + except ValueError as exc: + logger.warning( + "Skipping invalid variable selector %s for workflow %s: %s", + getattr(update.value, "selector", None), + execution.workflow_id, + exc, + ) diff --git a/api/core/workflow/graph_engine/entities/commands.py b/api/core/workflow/graph_engine/entities/commands.py index 0d51b2b716..6dce03c94d 100644 --- a/api/core/workflow/graph_engine/entities/commands.py +++ b/api/core/workflow/graph_engine/entities/commands.py @@ -5,17 +5,21 @@ This module defines command types that can be sent to a running GraphEngine instance to control its execution flow. """ -from enum import StrEnum +from collections.abc import Sequence +from enum import StrEnum, auto from typing import Any from pydantic import BaseModel, Field +from core.variables.variables import VariableUnion + class CommandType(StrEnum): """Types of commands that can be sent to GraphEngine.""" - ABORT = "abort" - PAUSE = "pause" + ABORT = auto() + PAUSE = auto() + UPDATE_VARIABLES = auto() class GraphEngineCommand(BaseModel): @@ -37,3 +41,16 @@ class PauseCommand(GraphEngineCommand): command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command") reason: str = Field(default="unknown reason", description="reason for pause") + + +class VariableUpdate(BaseModel): + """Represents a single variable update instruction.""" + + value: VariableUnion = Field(description="New variable value") + + +class UpdateVariablesCommand(GraphEngineCommand): + """Command to update a group of variables in the variable pool.""" + + command_type: CommandType = Field(default=CommandType.UPDATE_VARIABLES, description="Type of command") + updates: Sequence[VariableUpdate] = Field(default_factory=list, description="Variable updates") diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index e4816afaf8..88d6e5cac1 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -30,8 +30,13 @@ from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWr if TYPE_CHECKING: # pragma: no cover - used only for static analysis from core.workflow.runtime.graph_runtime_state import GraphProtocol -from .command_processing import AbortCommandHandler, CommandProcessor, PauseCommandHandler -from .entities.commands import AbortCommand, PauseCommand +from .command_processing import ( + AbortCommandHandler, + CommandProcessor, + PauseCommandHandler, + UpdateVariablesCommandHandler, +) +from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand from .error_handler import ErrorHandler from .event_management import EventHandler, EventManager from .graph_state_manager import GraphStateManager @@ -140,6 +145,9 @@ class GraphEngine: pause_handler = PauseCommandHandler() self._command_processor.register_handler(PauseCommand, pause_handler) + update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool) + self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler) + # === Extensibility === # Layers allow plugins to extend engine functionality self._layers: list[GraphEngineLayer] = [] diff --git a/api/core/workflow/graph_engine/manager.py b/api/core/workflow/graph_engine/manager.py index 0577ba8f02..d2cfa755d9 100644 --- a/api/core/workflow/graph_engine/manager.py +++ b/api/core/workflow/graph_engine/manager.py @@ -3,14 +3,20 @@ GraphEngine Manager for sending control commands via Redis channel. This module provides a simplified interface for controlling workflow executions using the new Redis command channel, without requiring user permission checks. -Supports stop, pause, and resume operations. """ import logging +from collections.abc import Sequence from typing import final from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel -from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand +from core.workflow.graph_engine.entities.commands import ( + AbortCommand, + GraphEngineCommand, + PauseCommand, + UpdateVariablesCommand, + VariableUpdate, +) from extensions.ext_redis import redis_client logger = logging.getLogger(__name__) @@ -23,7 +29,6 @@ class GraphEngineManager: This class provides a simple interface for controlling workflow executions by sending commands through Redis channels, without user validation. - Supports stop and pause operations. """ @staticmethod @@ -45,6 +50,16 @@ class GraphEngineManager: pause_command = PauseCommand(reason=reason or "User requested pause") GraphEngineManager._send_command(task_id, pause_command) + @staticmethod + def send_update_variables_command(task_id: str, updates: Sequence[VariableUpdate]) -> None: + """Send a command to update variables in a running workflow.""" + + if not updates: + return + + update_command = UpdateVariablesCommand(updates=updates) + GraphEngineManager._send_command(task_id, update_command) + @staticmethod def _send_command(task_id: str, command: GraphEngineCommand) -> None: """Send a command to the workflow-specific Redis channel.""" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py index 8677325d4e..f33fd0deeb 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py @@ -3,8 +3,15 @@ import json from unittest.mock import MagicMock +from core.variables import IntegerVariable, StringVariable from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel -from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, GraphEngineCommand +from core.workflow.graph_engine.entities.commands import ( + AbortCommand, + CommandType, + GraphEngineCommand, + UpdateVariablesCommand, + VariableUpdate, +) class TestRedisChannel: @@ -148,6 +155,43 @@ class TestRedisChannel: assert commands[0].command_type == CommandType.ABORT assert isinstance(commands[1], AbortCommand) + def test_fetch_commands_with_update_variables_command(self): + """Test fetching update variables command from Redis.""" + mock_redis = MagicMock() + pending_pipe = MagicMock() + fetch_pipe = MagicMock() + pending_context = MagicMock() + fetch_context = MagicMock() + pending_context.__enter__.return_value = pending_pipe + pending_context.__exit__.return_value = None + fetch_context.__enter__.return_value = fetch_pipe + fetch_context.__exit__.return_value = None + mock_redis.pipeline.side_effect = [pending_context, fetch_context] + + update_command = UpdateVariablesCommand( + updates=[ + VariableUpdate( + value=StringVariable(name="foo", value="bar", selector=["node1", "foo"]), + ), + VariableUpdate( + value=IntegerVariable(name="baz", value=123, selector=["node2", "baz"]), + ), + ] + ) + command_json = json.dumps(update_command.model_dump()) + + pending_pipe.execute.return_value = [b"1", 1] + fetch_pipe.execute.return_value = [[command_json.encode()], 1] + + channel = RedisChannel(mock_redis, "test:key") + commands = channel.fetch_commands() + + assert len(commands) == 1 + assert isinstance(commands[0], UpdateVariablesCommand) + assert isinstance(commands[0].updates[0].value, StringVariable) + assert list(commands[0].updates[0].value.selector) == ["node1", "foo"] + assert commands[0].updates[0].value.value == "bar" + def test_fetch_commands_skips_invalid_json(self): """Test that invalid JSON commands are skipped.""" mock_redis = MagicMock() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py index b074a11be9..d826f7a900 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py @@ -4,12 +4,19 @@ import time from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom +from core.variables import IntegerVariable, StringVariable from core.workflow.entities.graph_init_params import GraphInitParams from core.workflow.entities.pause_reason import SchedulingPause from core.workflow.graph import Graph from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand +from core.workflow.graph_engine.entities.commands import ( + AbortCommand, + CommandType, + PauseCommand, + UpdateVariablesCommand, + VariableUpdate, +) from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent from core.workflow.nodes.start.start_node import StartNode from core.workflow.runtime import GraphRuntimeState, VariablePool @@ -180,3 +187,67 @@ def test_pause_command(): graph_execution = engine.graph_runtime_state.graph_execution assert graph_execution.pause_reasons == [SchedulingPause(message="User requested pause")] + + +def test_update_variables_command_updates_pool(): + """Test that GraphEngine updates variable pool via update variables command.""" + + shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + shared_runtime_state.variable_pool.add(("node1", "foo"), "old value") + + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + mock_graph.root_node.id = "start" + + start_node = StartNode( + id="start", + config={"id": "start", "data": {"title": "start", "variables": []}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=shared_runtime_state, + ) + mock_graph.nodes["start"] = start_node + + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + command_channel = InMemoryChannel() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=shared_runtime_state, + command_channel=command_channel, + ) + + update_command = UpdateVariablesCommand( + updates=[ + VariableUpdate( + value=StringVariable(name="foo", value="new value", selector=["node1", "foo"]), + ), + VariableUpdate( + value=IntegerVariable(name="bar", value=123, selector=["node2", "bar"]), + ), + ] + ) + command_channel.send_command(update_command) + + list(engine.run()) + + updated_existing = shared_runtime_state.variable_pool.get(["node1", "foo"]) + added_new = shared_runtime_state.variable_pool.get(["node2", "bar"]) + + assert updated_existing is not None + assert updated_existing.value == "new value" + assert added_new is not None + assert added_new.value == 123 From de6262784c9a64101bcbfda4e2f4783211e94c08 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 5 Jan 2026 20:19:26 +0800 Subject: [PATCH 09/17] chore: Harden API image Node.js runtime install (#30497) --- api/Dockerfile | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/api/Dockerfile b/api/Dockerfile index e800e60322..a08d4e3aab 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -50,16 +50,33 @@ WORKDIR /app/api # Create non-root user ARG dify_uid=1001 +ARG NODE_MAJOR=22 +ARG NODE_PACKAGE_VERSION=22.21.0-1nodesource1 +ARG NODESOURCE_KEY_FPR=6F71F525282841EEDAF851B42F59B5F99B1BE0B4 RUN groupadd -r -g ${dify_uid} dify && \ useradd -r -u ${dify_uid} -g ${dify_uid} -s /bin/bash dify && \ chown -R dify:dify /app RUN \ apt-get update \ + && apt-get install -y --no-install-recommends \ + ca-certificates \ + curl \ + gnupg \ + && mkdir -p /etc/apt/keyrings \ + && curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key -o /tmp/nodesource.gpg \ + && gpg --show-keys --with-colons /tmp/nodesource.gpg \ + | awk -F: '/^fpr:/ {print $10}' \ + | grep -Fx "${NODESOURCE_KEY_FPR}" \ + && gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg /tmp/nodesource.gpg \ + && rm -f /tmp/nodesource.gpg \ + && echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_${NODE_MAJOR}.x nodistro main" \ + > /etc/apt/sources.list.d/nodesource.list \ + && apt-get update \ # Install dependencies && apt-get install -y --no-install-recommends \ # basic environment - curl nodejs \ + nodejs=${NODE_PACKAGE_VERSION} \ # for gmpy2 \ libgmp-dev libmpfr-dev libmpc-dev \ # For Security From 615c313f80513b96430d62f0e553a4edf706336a Mon Sep 17 00:00:00 2001 From: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com> Date: Tue, 6 Jan 2026 09:56:30 +0800 Subject: [PATCH 10/17] fix(api): refactors the SQL LIKE pattern escaping logic to use a centralized utility function, ensuring consistent and secure handling of special characters across all database queries. (#30450) Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/app/conversation.py | 22 ++- .../console/datasets/datasets_segments.py | 9 +- .../vdb/clickzetta/clickzetta_vector.py | 8 +- .../rag/datasource/vdb/iris/iris_vector.py | 8 +- api/core/rag/retrieval/dataset_retrieval.py | 14 +- api/libs/helper.py | 32 ++++ api/services/annotation_service.py | 7 +- api/services/app_service.py | 5 +- api/services/conversation_service.py | 4 +- api/services/dataset_service.py | 9 +- api/services/external_knowledge_service.py | 5 +- api/services/tag_service.py | 5 +- api/services/workflow_app_service.py | 15 +- .../services/test_annotation_service.py | 72 +++++++ .../services/test_app_service.py | 179 +++++++++++++++++- .../services/test_tag_service.py | 80 ++++++++ .../services/test_workflow_app_service.py | 160 +++++++++++++++- api/tests/unit_tests/libs/test_helper.py | 50 ++++- 18 files changed, 648 insertions(+), 36 deletions(-) diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index ef2f86d4be..56816dd462 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -348,10 +348,13 @@ class CompletionConversationApi(Resource): ) if args.keyword: + from libs.helper import escape_like_pattern + + escaped_keyword = escape_like_pattern(args.keyword) query = query.join(Message, Message.conversation_id == Conversation.id).where( or_( - Message.query.ilike(f"%{args.keyword}%"), - Message.answer.ilike(f"%{args.keyword}%"), + Message.query.ilike(f"%{escaped_keyword}%", escape="\\"), + Message.answer.ilike(f"%{escaped_keyword}%", escape="\\"), ) ) @@ -460,7 +463,10 @@ class ChatConversationApi(Resource): query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False)) if args.keyword: - keyword_filter = f"%{args.keyword}%" + from libs.helper import escape_like_pattern + + escaped_keyword = escape_like_pattern(args.keyword) + keyword_filter = f"%{escaped_keyword}%" query = ( query.join( Message, @@ -469,11 +475,11 @@ class ChatConversationApi(Resource): .join(subquery, subquery.c.conversation_id == Conversation.id) .where( or_( - Message.query.ilike(keyword_filter), - Message.answer.ilike(keyword_filter), - Conversation.name.ilike(keyword_filter), - Conversation.introduction.ilike(keyword_filter), - subquery.c.from_end_user_session_id.ilike(keyword_filter), + Message.query.ilike(keyword_filter, escape="\\"), + Message.answer.ilike(keyword_filter, escape="\\"), + Conversation.name.ilike(keyword_filter, escape="\\"), + Conversation.introduction.ilike(keyword_filter, escape="\\"), + subquery.c.from_end_user_session_id.ilike(keyword_filter, escape="\\"), ), ) .group_by(Conversation.id) diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 5a536af6d2..16fecb41c6 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -30,6 +30,7 @@ from core.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.segment_fields import child_chunk_fields, segment_fields +from libs.helper import escape_like_pattern from libs.login import current_account_with_tenant, login_required from models.dataset import ChildChunk, DocumentSegment from models.model import UploadFile @@ -145,6 +146,8 @@ class DatasetDocumentSegmentListApi(Resource): query = query.where(DocumentSegment.hit_count >= hit_count_gte) if keyword: + # Escape special characters in keyword to prevent SQL injection via LIKE wildcards + escaped_keyword = escape_like_pattern(keyword) # Search in both content and keywords fields # Use database-specific methods for JSON array search if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql": @@ -156,15 +159,15 @@ class DatasetDocumentSegmentListApi(Resource): .scalar_subquery() ), ",", - ).ilike(f"%{keyword}%") + ).ilike(f"%{escaped_keyword}%", escape="\\") else: # MySQL: Cast JSON to string for pattern matching # MySQL stores Chinese text directly in JSON without Unicode escaping - keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{keyword}%") + keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{escaped_keyword}%", escape="\\") query = query.where( or_( - DocumentSegment.content.ilike(f"%{keyword}%"), + DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\"), keywords_condition, ) ) diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py index a306f9ba0c..e05b70ba22 100644 --- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py @@ -984,9 +984,11 @@ class ClickzettaVector(BaseVector): # No need for dataset_id filter since each dataset has its own table - # Use simple quote escaping for LIKE clause - escaped_query = query.replace("'", "''") - filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%'") + # Escape special characters for LIKE clause to prevent SQL injection + from libs.helper import escape_like_pattern + + escaped_query = escape_like_pattern(query).replace("'", "''") + filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%' ESCAPE '\\\\'") where_clause = " AND ".join(filter_clauses) search_sql = f""" diff --git a/api/core/rag/datasource/vdb/iris/iris_vector.py b/api/core/rag/datasource/vdb/iris/iris_vector.py index b1bfabb76e..5bdb0af0b3 100644 --- a/api/core/rag/datasource/vdb/iris/iris_vector.py +++ b/api/core/rag/datasource/vdb/iris/iris_vector.py @@ -287,11 +287,15 @@ class IrisVector(BaseVector): cursor.execute(sql, (query,)) else: # Fallback to LIKE search (inefficient for large datasets) - query_pattern = f"%{query}%" + # Escape special characters for LIKE clause to prevent SQL injection + from libs.helper import escape_like_pattern + + escaped_query = escape_like_pattern(query) + query_pattern = f"%{escaped_query}%" sql = f""" SELECT TOP {top_k} id, text, meta FROM {self.schema}.{self.table_name} - WHERE text LIKE ? + WHERE text LIKE ? ESCAPE '\\' """ cursor.execute(sql, (query_pattern,)) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index c6339aa3ba..f8f85d141a 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -1198,18 +1198,24 @@ class DatasetRetrieval: json_field = DatasetDocument.doc_metadata[metadata_name].as_string() + from libs.helper import escape_like_pattern + match condition: case "contains": - filters.append(json_field.like(f"%{value}%")) + escaped_value = escape_like_pattern(str(value)) + filters.append(json_field.like(f"%{escaped_value}%", escape="\\")) case "not contains": - filters.append(json_field.notlike(f"%{value}%")) + escaped_value = escape_like_pattern(str(value)) + filters.append(json_field.notlike(f"%{escaped_value}%", escape="\\")) case "start with": - filters.append(json_field.like(f"{value}%")) + escaped_value = escape_like_pattern(str(value)) + filters.append(json_field.like(f"{escaped_value}%", escape="\\")) case "end with": - filters.append(json_field.like(f"%{value}")) + escaped_value = escape_like_pattern(str(value)) + filters.append(json_field.like(f"%{escaped_value}", escape="\\")) case "is" | "=": if isinstance(value, str): diff --git a/api/libs/helper.py b/api/libs/helper.py index 74e1808e49..07c4823727 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -32,6 +32,38 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def escape_like_pattern(pattern: str) -> str: + """ + Escape special characters in a string for safe use in SQL LIKE patterns. + + This function escapes the special characters used in SQL LIKE patterns: + - Backslash (\\) -> \\ + - Percent (%) -> \\% + - Underscore (_) -> \\_ + + The escaped pattern can then be safely used in SQL LIKE queries with the + ESCAPE '\\' clause to prevent SQL injection via LIKE wildcards. + + Args: + pattern: The string pattern to escape + + Returns: + Escaped string safe for use in SQL LIKE queries + + Examples: + >>> escape_like_pattern("50% discount") + '50\\% discount' + >>> escape_like_pattern("test_data") + 'test\\_data' + >>> escape_like_pattern("path\\to\\file") + 'path\\\\to\\\\file' + """ + if not pattern: + return pattern + # Escape backslash first, then percent and underscore + return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + + def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None: """ Extract tenant_id from Account or EndUser object. diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 7f44fe05a6..b73302508a 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -137,13 +137,16 @@ class AppAnnotationService: if not app: raise NotFound("App not found") if keyword: + from libs.helper import escape_like_pattern + + escaped_keyword = escape_like_pattern(keyword) stmt = ( select(MessageAnnotation) .where(MessageAnnotation.app_id == app_id) .where( or_( - MessageAnnotation.question.ilike(f"%{keyword}%"), - MessageAnnotation.content.ilike(f"%{keyword}%"), + MessageAnnotation.question.ilike(f"%{escaped_keyword}%", escape="\\"), + MessageAnnotation.content.ilike(f"%{escaped_keyword}%", escape="\\"), ) ) .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) diff --git a/api/services/app_service.py b/api/services/app_service.py index ef89a4fd10..02ebfbace0 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -55,8 +55,11 @@ class AppService: if args.get("is_created_by_me", False): filters.append(App.created_by == user_id) if args.get("name"): + from libs.helper import escape_like_pattern + name = args["name"][:30] - filters.append(App.name.ilike(f"%{name}%")) + escaped_name = escape_like_pattern(name) + filters.append(App.name.ilike(f"%{escaped_name}%", escape="\\")) # Check if tag_ids is not empty to avoid WHERE false condition if args.get("tag_ids") and len(args["tag_ids"]) > 0: target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"]) diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 659e7406fb..f56de36ef7 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -218,7 +218,9 @@ class ConversationService: # Apply variable_name filter if provided if variable_name: # Filter using JSON extraction to match variable names case-insensitively - escaped_variable_name = variable_name.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + from libs.helper import escape_like_pattern + + escaped_variable_name = escape_like_pattern(variable_name) # Filter using JSON extraction to match variable names case-insensitively if dify_config.DB_TYPE in ["mysql", "oceanbase", "seekdb"]: stmt = stmt.where( diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index ac4b25c5dc..18e5613438 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -144,7 +144,8 @@ class DatasetService: query = query.where(Dataset.permission == DatasetPermissionEnum.ALL_TEAM) if search: - query = query.where(Dataset.name.ilike(f"%{search}%")) + escaped_search = helper.escape_like_pattern(search) + query = query.where(Dataset.name.ilike(f"%{escaped_search}%", escape="\\")) # Check if tag_ids is not empty to avoid WHERE false condition if tag_ids and len(tag_ids) > 0: @@ -3423,7 +3424,8 @@ class SegmentService: .order_by(ChildChunk.position.asc()) ) if keyword: - query = query.where(ChildChunk.content.ilike(f"%{keyword}%")) + escaped_keyword = helper.escape_like_pattern(keyword) + query = query.where(ChildChunk.content.ilike(f"%{escaped_keyword}%", escape="\\")) return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) @classmethod @@ -3456,7 +3458,8 @@ class SegmentService: query = query.where(DocumentSegment.status.in_(status_list)) if keyword: - query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) + escaped_keyword = helper.escape_like_pattern(keyword) + query = query.where(DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\")) query = query.order_by(DocumentSegment.position.asc(), DocumentSegment.id.asc()) paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 40faa85b9a..65dd41af43 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -35,7 +35,10 @@ class ExternalDatasetService: .order_by(ExternalKnowledgeApis.created_at.desc()) ) if search: - query = query.where(ExternalKnowledgeApis.name.ilike(f"%{search}%")) + from libs.helper import escape_like_pattern + + escaped_search = escape_like_pattern(search) + query = query.where(ExternalKnowledgeApis.name.ilike(f"%{escaped_search}%", escape="\\")) external_knowledge_apis = db.paginate( select=query, page=page, per_page=per_page, max_per_page=100, error_out=False diff --git a/api/services/tag_service.py b/api/services/tag_service.py index 937e6593fe..bd3585acf4 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -19,7 +19,10 @@ class TagService: .where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id) ) if keyword: - query = query.where(sa.and_(Tag.name.ilike(f"%{keyword}%"))) + from libs.helper import escape_like_pattern + + escaped_keyword = escape_like_pattern(keyword) + query = query.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\"))) query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at) results: list = query.order_by(Tag.created_at.desc()).all() return results diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index 01f0c7a55a..8574d30255 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -86,12 +86,19 @@ class WorkflowAppService: # Join to workflow run for filtering when needed. if keyword: - keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u") + from libs.helper import escape_like_pattern + + # Escape special characters in keyword to prevent SQL injection via LIKE wildcards + escaped_keyword = escape_like_pattern(keyword[:30]) + keyword_like_val = f"%{escaped_keyword}%" keyword_conditions = [ - WorkflowRun.inputs.ilike(keyword_like_val), - WorkflowRun.outputs.ilike(keyword_like_val), + WorkflowRun.inputs.ilike(keyword_like_val, escape="\\"), + WorkflowRun.outputs.ilike(keyword_like_val, escape="\\"), # filter keyword by end user session id if created by end user role - and_(WorkflowRun.created_by_role == "end_user", EndUser.session_id.ilike(keyword_like_val)), + and_( + WorkflowRun.created_by_role == "end_user", + EndUser.session_id.ilike(keyword_like_val, escape="\\"), + ), ] # filter keyword by workflow run id diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py index da73122cd7..5555400ca6 100644 --- a/api/tests/test_containers_integration_tests/services/test_annotation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -444,6 +444,78 @@ class TestAnnotationService: assert total == 1 assert unique_keyword in annotation_list[0].question or unique_keyword in annotation_list[0].content + def test_get_annotation_list_by_app_id_with_special_characters_in_keyword( + self, db_session_with_containers, mock_external_service_dependencies + ): + r""" + Test retrieval of annotation list with special characters in keyword to verify SQL injection prevention. + + This test verifies: + - Special characters (%, _, \) in keyword are properly escaped + - Search treats special characters as literal characters, not wildcards + - SQL injection via LIKE wildcards is prevented + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotations with special characters in content + annotation_with_percent = { + "question": "Question with 50% discount", + "answer": "Answer about 50% discount offer", + } + AppAnnotationService.insert_app_annotation_directly(annotation_with_percent, app.id) + + annotation_with_underscore = { + "question": "Question with test_data", + "answer": "Answer about test_data value", + } + AppAnnotationService.insert_app_annotation_directly(annotation_with_underscore, app.id) + + annotation_with_backslash = { + "question": "Question with path\\to\\file", + "answer": "Answer about path\\to\\file location", + } + AppAnnotationService.insert_app_annotation_directly(annotation_with_backslash, app.id) + + # Create annotation that should NOT match (contains % but as part of different text) + annotation_no_match = { + "question": "Question with 100% different", + "answer": "Answer about 100% different content", + } + AppAnnotationService.insert_app_annotation_directly(annotation_no_match, app.id) + + # Test 1: Search with % character - should find exact match only + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app.id, page=1, limit=10, keyword="50%" + ) + assert total == 1 + assert len(annotation_list) == 1 + assert "50%" in annotation_list[0].question or "50%" in annotation_list[0].content + + # Test 2: Search with _ character - should find exact match only + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app.id, page=1, limit=10, keyword="test_data" + ) + assert total == 1 + assert len(annotation_list) == 1 + assert "test_data" in annotation_list[0].question or "test_data" in annotation_list[0].content + + # Test 3: Search with \ character - should find exact match only + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app.id, page=1, limit=10, keyword="path\\to\\file" + ) + assert total == 1 + assert len(annotation_list) == 1 + assert "path\\to\\file" in annotation_list[0].question or "path\\to\\file" in annotation_list[0].content + + # Test 4: Search with % should NOT match 100% (verifies escaping works) + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app.id, page=1, limit=10, keyword="50%" + ) + # Should only find the 50% annotation, not the 100% one + assert total == 1 + assert all("50%" in (item.question or "") or "50%" in (item.content or "") for item in annotation_list) + def test_get_annotation_list_by_app_id_app_not_found( self, db_session_with_containers, mock_external_service_dependencies ): diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index e53392bcef..745d6c97b0 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -7,7 +7,9 @@ from constants.model_template import default_app_templates from models import Account from models.model import App, Site from services.account_service import AccountService, TenantService -from services.app_service import AppService + +# Delay import of AppService to avoid circular dependency +# from services.app_service import AppService class TestAppService: @@ -71,6 +73,9 @@ class TestAppService: } # Create app + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -109,6 +114,9 @@ class TestAppService: TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() # Test different app modes @@ -159,6 +167,9 @@ class TestAppService: "icon_background": "#45B7D1", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() created_app = app_service.create_app(tenant.id, app_args, account) @@ -194,6 +205,9 @@ class TestAppService: TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() # Create multiple apps @@ -245,6 +259,9 @@ class TestAppService: TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() # Create apps with different modes @@ -315,6 +332,9 @@ class TestAppService: TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() # Create an app @@ -392,6 +412,9 @@ class TestAppService: "icon_background": "#45B7D1", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -458,6 +481,9 @@ class TestAppService: "icon_background": "#45B7D1", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -508,6 +534,9 @@ class TestAppService: "icon_background": "#45B7D1", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -562,6 +591,9 @@ class TestAppService: "icon_background": "#74B9FF", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -617,6 +649,9 @@ class TestAppService: "icon_background": "#A29BFE", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -672,6 +707,9 @@ class TestAppService: "icon_background": "#FD79A8", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -720,6 +758,9 @@ class TestAppService: "icon_background": "#E17055", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -768,6 +809,9 @@ class TestAppService: "icon_background": "#00B894", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -826,6 +870,9 @@ class TestAppService: "icon_background": "#6C5CE7", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -862,6 +909,9 @@ class TestAppService: "icon_background": "#FDCB6E", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -899,6 +949,9 @@ class TestAppService: "icon_background": "#E84393", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -947,8 +1000,132 @@ class TestAppService: "icon_background": "#D63031", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() # Attempt to create app with invalid mode with pytest.raises(ValueError, match="invalid mode value"): app_service.create_app(tenant.id, app_args, account) + + def test_get_apps_with_special_characters_in_name( + self, db_session_with_containers, mock_external_service_dependencies + ): + r""" + Test app retrieval with special characters in name search to verify SQL injection prevention. + + This test verifies: + - Special characters (%, _, \) in name search are properly escaped + - Search treats special characters as literal characters, not wildcards + - SQL injection via LIKE wildcards is prevented + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Import here to avoid circular dependency + from services.app_service import AppService + + app_service = AppService() + + # Create apps with special characters in names + app_with_percent = app_service.create_app( + tenant.id, + { + "name": "App with 50% discount", + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + }, + account, + ) + + app_with_underscore = app_service.create_app( + tenant.id, + { + "name": "test_data_app", + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + }, + account, + ) + + app_with_backslash = app_service.create_app( + tenant.id, + { + "name": "path\\to\\app", + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + }, + account, + ) + + # Create app that should NOT match + app_no_match = app_service.create_app( + tenant.id, + { + "name": "100% different", + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + }, + account, + ) + + # Test 1: Search with % character + args = {"name": "50%", "mode": "chat", "page": 1, "limit": 10} + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + assert paginated_apps is not None + assert paginated_apps.total == 1 + assert len(paginated_apps.items) == 1 + assert paginated_apps.items[0].name == "App with 50% discount" + + # Test 2: Search with _ character + args = {"name": "test_data", "mode": "chat", "page": 1, "limit": 10} + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + assert paginated_apps is not None + assert paginated_apps.total == 1 + assert len(paginated_apps.items) == 1 + assert paginated_apps.items[0].name == "test_data_app" + + # Test 3: Search with \ character + args = {"name": "path\\to\\app", "mode": "chat", "page": 1, "limit": 10} + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + assert paginated_apps is not None + assert paginated_apps.total == 1 + assert len(paginated_apps.items) == 1 + assert paginated_apps.items[0].name == "path\\to\\app" + + # Test 4: Search with % should NOT match 100% (verifies escaping works) + args = {"name": "50%", "mode": "chat", "page": 1, "limit": 10} + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + assert paginated_apps is not None + assert paginated_apps.total == 1 + assert all("50%" in app.name for app in paginated_apps.items) diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index 6732b8d558..e8c7f17e0b 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -1,3 +1,4 @@ +import uuid from unittest.mock import create_autospec, patch import pytest @@ -312,6 +313,85 @@ class TestTagService: result_no_match = TagService.get_tags("app", tenant.id, keyword="nonexistent") assert len(result_no_match) == 0 + def test_get_tags_with_special_characters_in_keyword( + self, db_session_with_containers, mock_external_service_dependencies + ): + r""" + Test tag retrieval with special characters in keyword to verify SQL injection prevention. + + This test verifies: + - Special characters (%, _, \) in keyword are properly escaped + - Search treats special characters as literal characters, not wildcards + - SQL injection via LIKE wildcards is prevented + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + from extensions.ext_database import db + + # Create tags with special characters in names + tag_with_percent = Tag( + name="50% discount", + type="app", + tenant_id=tenant.id, + created_by=account.id, + ) + tag_with_percent.id = str(uuid.uuid4()) + db.session.add(tag_with_percent) + + tag_with_underscore = Tag( + name="test_data_tag", + type="app", + tenant_id=tenant.id, + created_by=account.id, + ) + tag_with_underscore.id = str(uuid.uuid4()) + db.session.add(tag_with_underscore) + + tag_with_backslash = Tag( + name="path\\to\\tag", + type="app", + tenant_id=tenant.id, + created_by=account.id, + ) + tag_with_backslash.id = str(uuid.uuid4()) + db.session.add(tag_with_backslash) + + # Create tag that should NOT match + tag_no_match = Tag( + name="100% different", + type="app", + tenant_id=tenant.id, + created_by=account.id, + ) + tag_no_match.id = str(uuid.uuid4()) + db.session.add(tag_no_match) + + db.session.commit() + + # Act & Assert: Test 1 - Search with % character + result = TagService.get_tags("app", tenant.id, keyword="50%") + assert len(result) == 1 + assert result[0].name == "50% discount" + + # Test 2 - Search with _ character + result = TagService.get_tags("app", tenant.id, keyword="test_data") + assert len(result) == 1 + assert result[0].name == "test_data_tag" + + # Test 3 - Search with \ character + result = TagService.get_tags("app", tenant.id, keyword="path\\to\\tag") + assert len(result) == 1 + assert result[0].name == "path\\to\\tag" + + # Test 4 - Search with % should NOT match 100% (verifies escaping works) + result = TagService.get_tags("app", tenant.id, keyword="50%") + assert len(result) == 1 + assert all("50%" in item.name for item in result) + def test_get_tags_empty_result(self, db_session_with_containers, mock_external_service_dependencies): """ Test tag retrieval when no tags exist. diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 7b95944bbe..040fb826e1 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -10,7 +10,9 @@ from core.workflow.entities.workflow_execution import WorkflowExecutionStatus from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun from models.enums import CreatorUserRole from services.account_service import AccountService, TenantService -from services.app_service import AppService + +# Delay import of AppService to avoid circular dependency +# from services.app_service import AppService from services.workflow_app_service import WorkflowAppService @@ -86,6 +88,9 @@ class TestWorkflowAppService: "api_rpm": 10, } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -147,6 +152,9 @@ class TestWorkflowAppService: "api_rpm": 10, } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -308,6 +316,156 @@ class TestWorkflowAppService: assert result_no_match["total"] == 0 assert len(result_no_match["data"]) == 0 + def test_get_paginate_workflow_app_logs_with_special_characters_in_keyword( + self, db_session_with_containers, mock_external_service_dependencies + ): + r""" + Test workflow app logs pagination with special characters in keyword to verify SQL injection prevention. + + This test verifies: + - Special characters (%, _) in keyword are properly escaped + - Search treats special characters as literal characters, not wildcards + - SQL injection via LIKE wildcards is prevented + """ + # Arrange: Create test data + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + workflow, _, _ = self._create_test_workflow_data(db_session_with_containers, app, account) + + from extensions.ext_database import db + + service = WorkflowAppService() + + # Test 1: Search with % character + workflow_run_1 = WorkflowRun( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type="workflow", + triggered_from="app-run", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + status="succeeded", + inputs=json.dumps({"search_term": "50% discount", "input2": "other_value"}), + outputs=json.dumps({"result": "50% discount applied", "status": "success"}), + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + created_at=datetime.now(UTC), + ) + db.session.add(workflow_run_1) + db.session.flush() + + workflow_app_log_1 = WorkflowAppLog( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + workflow_run_id=workflow_run_1.id, + created_from="service-api", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + ) + workflow_app_log_1.id = str(uuid.uuid4()) + workflow_app_log_1.created_at = datetime.now(UTC) + db.session.add(workflow_app_log_1) + db.session.commit() + + result = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20 + ) + # Should find the workflow_run_1 entry + assert result["total"] >= 1 + assert len(result["data"]) >= 1 + assert any(log.workflow_run_id == workflow_run_1.id for log in result["data"]) + + # Test 2: Search with _ character + workflow_run_2 = WorkflowRun( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type="workflow", + triggered_from="app-run", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + status="succeeded", + inputs=json.dumps({"search_term": "test_data_value", "input2": "other_value"}), + outputs=json.dumps({"result": "test_data_value found", "status": "success"}), + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + created_at=datetime.now(UTC), + ) + db.session.add(workflow_run_2) + db.session.flush() + + workflow_app_log_2 = WorkflowAppLog( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + workflow_run_id=workflow_run_2.id, + created_from="service-api", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + ) + workflow_app_log_2.id = str(uuid.uuid4()) + workflow_app_log_2.created_at = datetime.now(UTC) + db.session.add(workflow_app_log_2) + db.session.commit() + + result = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, keyword="test_data", page=1, limit=20 + ) + # Should find the workflow_run_2 entry + assert result["total"] >= 1 + assert len(result["data"]) >= 1 + assert any(log.workflow_run_id == workflow_run_2.id for log in result["data"]) + + # Test 3: Search with % should NOT match 100% (verifies escaping works correctly) + workflow_run_4 = WorkflowRun( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type="workflow", + triggered_from="app-run", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + status="succeeded", + inputs=json.dumps({"search_term": "100% different", "input2": "other_value"}), + outputs=json.dumps({"result": "100% different result", "status": "success"}), + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + created_at=datetime.now(UTC), + ) + db.session.add(workflow_run_4) + db.session.flush() + + workflow_app_log_4 = WorkflowAppLog( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + workflow_run_id=workflow_run_4.id, + created_from="service-api", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + ) + workflow_app_log_4.id = str(uuid.uuid4()) + workflow_app_log_4.created_at = datetime.now(UTC) + db.session.add(workflow_app_log_4) + db.session.commit() + + result = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20 + ) + # Should only find the 50% entry (workflow_run_1), not the 100% entry (workflow_run_4) + # This verifies that escaping works correctly - 50% should not match 100% + assert result["total"] >= 1 + assert len(result["data"]) >= 1 + # Verify that we found workflow_run_1 (50% discount) but not workflow_run_4 (100% different) + found_run_ids = [log.workflow_run_id for log in result["data"]] + assert workflow_run_1.id in found_run_ids + assert workflow_run_4.id not in found_run_ids + def test_get_paginate_workflow_app_logs_with_status_filter( self, db_session_with_containers, mock_external_service_dependencies ): diff --git a/api/tests/unit_tests/libs/test_helper.py b/api/tests/unit_tests/libs/test_helper.py index 85789bfa7e..de74eff82f 100644 --- a/api/tests/unit_tests/libs/test_helper.py +++ b/api/tests/unit_tests/libs/test_helper.py @@ -1,6 +1,6 @@ import pytest -from libs.helper import extract_tenant_id +from libs.helper import escape_like_pattern, extract_tenant_id from models.account import Account from models.model import EndUser @@ -63,3 +63,51 @@ class TestExtractTenantId: with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"): extract_tenant_id(dict_user) + + +class TestEscapeLikePattern: + """Test cases for the escape_like_pattern utility function.""" + + def test_escape_percent_character(self): + """Test escaping percent character.""" + result = escape_like_pattern("50% discount") + assert result == "50\\% discount" + + def test_escape_underscore_character(self): + """Test escaping underscore character.""" + result = escape_like_pattern("test_data") + assert result == "test\\_data" + + def test_escape_backslash_character(self): + """Test escaping backslash character.""" + result = escape_like_pattern("path\\to\\file") + assert result == "path\\\\to\\\\file" + + def test_escape_combined_special_characters(self): + """Test escaping multiple special characters together.""" + result = escape_like_pattern("file_50%\\path") + assert result == "file\\_50\\%\\\\path" + + def test_escape_empty_string(self): + """Test escaping empty string returns empty string.""" + result = escape_like_pattern("") + assert result == "" + + def test_escape_none_handling(self): + """Test escaping None returns None (falsy check handles it).""" + # The function checks `if not pattern`, so None is falsy and returns as-is + result = escape_like_pattern(None) + assert result is None + + def test_escape_normal_string_no_change(self): + """Test that normal strings without special characters are unchanged.""" + result = escape_like_pattern("normal text") + assert result == "normal text" + + def test_escape_order_matters(self): + """Test that backslash is escaped first to prevent double escaping.""" + # If we escape % first, then escape \, we might get wrong results + # This test ensures the order is correct: \ first, then % and _ + result = escape_like_pattern("test\\%_value") + # Should be: test\\\%\_value + assert result == "test\\\\\\%\\_value" From ce87371befdc5d1f318205cdc95dee8ee566abd9 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Tue, 6 Jan 2026 11:06:21 +0900 Subject: [PATCH 11/17] refactor: split changes for api/controllers/web/saved_message.py (#30583) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/web/saved_message.py | 35 +++++++++++++++++----------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index 4e20690e9e..29993100f6 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -1,18 +1,30 @@ -from flask_restx import reqparse -from flask_restx.inputs import int_range -from pydantic import TypeAdapter +from flask import request +from pydantic import BaseModel, Field, TypeAdapter from werkzeug.exceptions import NotFound +from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.error import NotCompletionAppError from controllers.web.wraps import WebApiResource from fields.conversation_fields import ResultResponse from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem -from libs.helper import uuid_value +from libs.helper import UUIDStrOrEmpty from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService +class SavedMessageListQuery(BaseModel): + last_id: UUIDStrOrEmpty | None = None + limit: int = Field(default=20, ge=1, le=100) + + +class SavedMessageCreatePayload(BaseModel): + message_id: UUIDStrOrEmpty + + +register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload) + + @web_ns.route("/saved-messages") class SavedMessageListApi(WebApiResource): @web_ns.doc("Get Saved Messages") @@ -42,14 +54,10 @@ class SavedMessageListApi(WebApiResource): if app_model.mode != "completion": raise NotCompletionAppError() - parser = ( - reqparse.RequestParser() - .add_argument("last_id", type=uuid_value, location="args") - .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - ) - args = parser.parse_args() + raw_args = request.args.to_dict() + query = SavedMessageListQuery.model_validate(raw_args) - pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"]) + pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, query.last_id, query.limit) adapter = TypeAdapter(SavedMessageItem) items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data] return SavedMessageInfiniteScrollPagination( @@ -79,11 +87,10 @@ class SavedMessageListApi(WebApiResource): if app_model.mode != "completion": raise NotCompletionAppError() - parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json") - args = parser.parse_args() + payload = SavedMessageCreatePayload.model_validate(web_ns.payload or {}) try: - SavedMessageService.save(app_model, end_user, args["message_id"]) + SavedMessageService.save(app_model, end_user, payload.message_id) except MessageNotExistsError: raise NotFound("Message Not Exists.") From c0331b23a9acd8cddf2613af34054ba80f92dd87 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Tue, 6 Jan 2026 11:06:48 +0900 Subject: [PATCH 12/17] refactor: split changes for api/controllers/web/conversation.py (#30582) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/web/conversation.py | 75 ++++++++++++++++------------- 1 file changed, 42 insertions(+), 33 deletions(-) diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 527eef6094..e76649495a 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -1,9 +1,11 @@ -from flask_restx import reqparse -from flask_restx.inputs import int_range -from pydantic import TypeAdapter +from typing import Literal + +from flask import request +from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound +from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.error import NotChatAppError from controllers.web.wraps import WebApiResource @@ -21,6 +23,35 @@ from services.errors.conversation import ConversationNotExistsError, LastConvers from services.web_conversation_service import WebConversationService +class ConversationListQuery(BaseModel): + last_id: str | None = None + limit: int = Field(default=20, ge=1, le=100) + pinned: bool | None = None + sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = "-updated_at" + + @field_validator("last_id") + @classmethod + def validate_last_id(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +class ConversationRenamePayload(BaseModel): + name: str | None = None + auto_generate: bool = False + + @model_validator(mode="after") + def validate_name_requirement(self): + if not self.auto_generate: + if self.name is None or not self.name.strip(): + raise ValueError("name is required when auto_generate is false") + return self + + +register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload) + + @web_ns.route("/conversations") class ConversationListApi(WebApiResource): @web_ns.doc("Get Conversation List") @@ -64,25 +95,8 @@ class ConversationListApi(WebApiResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = ( - reqparse.RequestParser() - .add_argument("last_id", type=uuid_value, location="args") - .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - .add_argument("pinned", type=str, choices=["true", "false", None], location="args") - .add_argument( - "sort_by", - type=str, - choices=["created_at", "-created_at", "updated_at", "-updated_at"], - required=False, - default="-updated_at", - location="args", - ) - ) - args = parser.parse_args() - - pinned = None - if "pinned" in args and args["pinned"] is not None: - pinned = args["pinned"] == "true" + raw_args = request.args.to_dict() + query = ConversationListQuery.model_validate(raw_args) try: with Session(db.engine) as session: @@ -90,11 +104,11 @@ class ConversationListApi(WebApiResource): session=session, app_model=app_model, user=end_user, - last_id=args["last_id"], - limit=args["limit"], + last_id=query.last_id, + limit=query.limit, invoke_from=InvokeFrom.WEB_APP, - pinned=pinned, - sort_by=args["sort_by"], + pinned=query.pinned, + sort_by=query.sort_by, ) adapter = TypeAdapter(SimpleConversation) conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data] @@ -168,16 +182,11 @@ class ConversationRenameApi(WebApiResource): conversation_id = str(c_id) - parser = ( - reqparse.RequestParser() - .add_argument("name", type=str, required=False, location="json") - .add_argument("auto_generate", type=bool, required=False, default=False, location="json") - ) - args = parser.parse_args() + payload = ConversationRenamePayload.model_validate(web_ns.payload or {}) try: conversation = ConversationService.rename( - app_model, conversation_id, end_user, args["name"], args["auto_generate"] + app_model, conversation_id, end_user, payload.name, payload.auto_generate ) return ( TypeAdapter(SimpleConversation) From f14c3ce15e7568e9e8816b3d9e1a0479d67774ca Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Tue, 6 Jan 2026 10:07:42 +0800 Subject: [PATCH 13/17] fix: system model selector loading state flash (#30572) --- .../model-provider-page/index.tsx | 18 ++++++++++++------ .../system-model-selector/index.tsx | 11 ++++++++--- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/web/app/components/header/account-setting/model-provider-page/index.tsx b/web/app/components/header/account-setting/model-provider-page/index.tsx index f456bcaaa6..d3daaee859 100644 --- a/web/app/components/header/account-setting/model-provider-page/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/index.tsx @@ -31,14 +31,19 @@ const FixedModelProvider = ['langgenius/openai/openai', 'langgenius/anthropic/an const ModelProviderPage = ({ searchText }: Props) => { const debouncedSearchText = useDebounce(searchText, { wait: 500 }) const { t } = useTranslation() - const { data: textGenerationDefaultModel } = useDefaultModel(ModelTypeEnum.textGeneration) - const { data: embeddingsDefaultModel } = useDefaultModel(ModelTypeEnum.textEmbedding) - const { data: rerankDefaultModel } = useDefaultModel(ModelTypeEnum.rerank) - const { data: speech2textDefaultModel } = useDefaultModel(ModelTypeEnum.speech2text) - const { data: ttsDefaultModel } = useDefaultModel(ModelTypeEnum.tts) + const { data: textGenerationDefaultModel, isLoading: isTextGenerationDefaultModelLoading } = useDefaultModel(ModelTypeEnum.textGeneration) + const { data: embeddingsDefaultModel, isLoading: isEmbeddingsDefaultModelLoading } = useDefaultModel(ModelTypeEnum.textEmbedding) + const { data: rerankDefaultModel, isLoading: isRerankDefaultModelLoading } = useDefaultModel(ModelTypeEnum.rerank) + const { data: speech2textDefaultModel, isLoading: isSpeech2textDefaultModelLoading } = useDefaultModel(ModelTypeEnum.speech2text) + const { data: ttsDefaultModel, isLoading: isTTSDefaultModelLoading } = useDefaultModel(ModelTypeEnum.tts) const { modelProviders: providers } = useProviderContext() const { enable_marketplace } = useGlobalPublicStore(s => s.systemFeatures) - const defaultModelNotConfigured = !textGenerationDefaultModel && !embeddingsDefaultModel && !speech2textDefaultModel && !rerankDefaultModel && !ttsDefaultModel + const isDefaultModelLoading = isTextGenerationDefaultModelLoading + || isEmbeddingsDefaultModelLoading + || isRerankDefaultModelLoading + || isSpeech2textDefaultModelLoading + || isTTSDefaultModelLoading + const defaultModelNotConfigured = !isDefaultModelLoading && !textGenerationDefaultModel && !embeddingsDefaultModel && !speech2textDefaultModel && !rerankDefaultModel && !ttsDefaultModel const [configuredProviders, notConfiguredProviders] = useMemo(() => { const configuredProviders: ModelProvider[] = [] const notConfiguredProviders: ModelProvider[] = [] @@ -106,6 +111,7 @@ const ModelProviderPage = ({ searchText }: Props) => { rerankDefaultModel={rerankDefaultModel} speech2textDefaultModel={speech2textDefaultModel} ttsDefaultModel={ttsDefaultModel} + isLoading={isDefaultModelLoading} /> diff --git a/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx b/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx index 74222ed56d..29c71e04fc 100644 --- a/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx @@ -3,7 +3,7 @@ import type { DefaultModel, DefaultModelResponse, } from '../declarations' -import { RiEqualizer2Line } from '@remixicon/react' +import { RiEqualizer2Line, RiLoader2Line } from '@remixicon/react' import { useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' @@ -32,6 +32,7 @@ type SystemModelSelectorProps = { speech2textDefaultModel: DefaultModelResponse | undefined ttsDefaultModel: DefaultModelResponse | undefined notConfigured: boolean + isLoading?: boolean } const SystemModel: FC = ({ textGenerationDefaultModel, @@ -40,6 +41,7 @@ const SystemModel: FC = ({ speech2textDefaultModel, ttsDefaultModel, notConfigured, + isLoading, }) => { const { t } = useTranslation() const { notify } = useToastContext() @@ -129,13 +131,16 @@ const SystemModel: FC = ({ crossAxis: 8, }} > - setOpen(v => !v)}> + setOpen(v => !v)}> From 7e3bfb9250a64d92142d0d8ff5fbe6674c5194f4 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Tue, 6 Jan 2026 11:08:09 +0900 Subject: [PATCH 14/17] =?UTF-8?q?refactor:=20split=20changes=20for=20api/c?= =?UTF-8?q?ontrollers/console/datasets/hit=5Ftest=E2=80=A6=20(#30581)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../console/datasets/hit_testing_base.py | 15 +++++---------- .../service_api/dataset/hit_testing.py | 2 +- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index db7c50f422..db1a874437 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -1,7 +1,7 @@ import logging from typing import Any -from flask_restx import marshal, reqparse +from flask_restx import marshal from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -56,15 +56,10 @@ class DatasetsHitTestingBase: HitTestingService.hit_testing_args_check(args) @staticmethod - def parse_args(): - parser = ( - reqparse.RequestParser() - .add_argument("query", type=str, required=False, location="json") - .add_argument("attachment_ids", type=list, required=False, location="json") - .add_argument("retrieval_model", type=dict, required=False, location="json") - .add_argument("external_retrieval_model", type=dict, required=False, location="json") - ) - return parser.parse_args() + def parse_args(payload: dict[str, Any]) -> dict[str, Any]: + """Validate and return hit-testing arguments from an incoming payload.""" + hit_testing_payload = HitTestingPayload.model_validate(payload or {}) + return hit_testing_payload.model_dump(exclude_none=True) @staticmethod def perform_hit_testing(dataset, args): diff --git a/api/controllers/service_api/dataset/hit_testing.py b/api/controllers/service_api/dataset/hit_testing.py index d81287d56f..8dbb690901 100644 --- a/api/controllers/service_api/dataset/hit_testing.py +++ b/api/controllers/service_api/dataset/hit_testing.py @@ -24,7 +24,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): dataset_id_str = str(dataset_id) dataset = self.get_and_validate_dataset(dataset_id_str) - args = self.parse_args() + args = self.parse_args(service_api_ns.payload) self.hit_testing_args_check(args) return self.perform_hit_testing(dataset, args) From eccf79a710e0f97e456897f82c833527c5e58b25 Mon Sep 17 00:00:00 2001 From: ga_o <45528199+GKTOKWR@users.noreply.github.com> Date: Tue, 6 Jan 2026 11:10:06 +0900 Subject: [PATCH 15/17] chore: remove unused link icon type (#30469) --- api/models/model.py | 3 ++- api/services/app_dsl_service.py | 6 +++--- api/tests/unit_tests/models/test_app_models.py | 2 +- web/types/app.ts | 2 +- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/api/models/model.py b/api/models/model.py index 6cfcc2859d..b6f2751a72 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -70,6 +70,7 @@ class AppMode(StrEnum): class IconType(StrEnum): IMAGE = auto() EMOJI = auto() + LINK = auto() class App(Base): @@ -81,7 +82,7 @@ class App(Base): name: Mapped[str] = mapped_column(String(255)) description: Mapped[str] = mapped_column(LongText, default=sa.text("''")) mode: Mapped[str] = mapped_column(String(255)) - icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji + icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji, link icon = mapped_column(String(255)) icon_background: Mapped[str | None] = mapped_column(String(255)) app_model_config_id = mapped_column(StringUUID, nullable=True) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index deba0b79e8..acd2a25a86 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -32,7 +32,7 @@ from extensions.ext_redis import redis_client from factories import variable_factory from libs.datetime_utils import naive_utc_now from models import Account, App, AppMode -from models.model import AppModelConfig +from models.model import AppModelConfig, IconType from models.workflow import Workflow from services.plugin.dependencies_analysis import DependenciesAnalysisService from services.workflow_draft_variable_service import WorkflowDraftVariableService @@ -428,10 +428,10 @@ class AppDslService: # Set icon type icon_type_value = icon_type or app_data.get("icon_type") - if icon_type_value in ["emoji", "link", "image"]: + if icon_type_value in [IconType.EMOJI.value, IconType.IMAGE.value, IconType.LINK.value]: icon_type = icon_type_value else: - icon_type = "emoji" + icon_type = IconType.EMOJI.value icon = icon or str(app_data.get("icon", "")) if app: diff --git a/api/tests/unit_tests/models/test_app_models.py b/api/tests/unit_tests/models/test_app_models.py index e35788660d..8be2eea121 100644 --- a/api/tests/unit_tests/models/test_app_models.py +++ b/api/tests/unit_tests/models/test_app_models.py @@ -114,7 +114,7 @@ class TestAppModelValidation: def test_icon_type_validation(self): """Test icon type enum values.""" # Assert - assert {t.value for t in IconType} == {"image", "emoji"} + assert {t.value for t in IconType} == {"image", "emoji", "link"} def test_app_desc_or_prompt_with_description(self): """Test desc_or_prompt property when description exists.""" diff --git a/web/types/app.ts b/web/types/app.ts index eb1b29bb60..c8c409a22c 100644 --- a/web/types/app.ts +++ b/web/types/app.ts @@ -316,7 +316,7 @@ export type SiteConfig = { use_icon_as_answer_icon: boolean } -export type AppIconType = 'image' | 'emoji' +export type AppIconType = 'image' | 'emoji' | 'link' /** * App From 061d552928a72b0f0ffa813a1b8348355f9a02e9 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Tue, 6 Jan 2026 10:12:05 +0800 Subject: [PATCH 16/17] feat: unified management stop event (#30479) --- .../workflow/graph_engine/graph_engine.py | 10 +- .../graph_engine/orchestration/dispatcher.py | 7 +- api/core/workflow/graph_engine/worker.py | 10 +- .../worker_management/worker_pool.py | 5 +- api/core/workflow/nodes/base/node.py | 19 + .../workflow/runtime/graph_runtime_state.py | 2 + .../orchestration/test_dispatcher.py | 4 + .../workflow/graph_engine/test_stop_event.py | 539 ++++++++++++++++++ 8 files changed, 586 insertions(+), 10 deletions(-) create mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_stop_event.py diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 88d6e5cac1..500ba4487b 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -8,6 +8,7 @@ Domain-Driven Design principles for improved maintainability and testability. import contextvars import logging import queue +import threading from collections.abc import Generator from typing import TYPE_CHECKING, cast, final @@ -75,10 +76,13 @@ class GraphEngine: scale_down_idle_time: float | None = None, ) -> None: """Initialize the graph engine with all subsystems and dependencies.""" + # stop event + self._stop_event = threading.Event() # Bind runtime state to current workflow context self._graph = graph self._graph_runtime_state = graph_runtime_state + self._graph_runtime_state.stop_event = self._stop_event self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph)) self._command_channel = command_channel @@ -177,6 +181,7 @@ class GraphEngine: max_workers=self._max_workers, scale_up_threshold=self._scale_up_threshold, scale_down_idle_time=self._scale_down_idle_time, + stop_event=self._stop_event, ) # === Orchestration === @@ -207,6 +212,7 @@ class GraphEngine: event_handler=self._event_handler_registry, execution_coordinator=self._execution_coordinator, event_emitter=self._event_manager, + stop_event=self._stop_event, ) # === Validation === @@ -324,6 +330,7 @@ class GraphEngine: def _start_execution(self, *, resume: bool = False) -> None: """Start execution subsystems.""" + self._stop_event.clear() paused_nodes: list[str] = [] if resume: paused_nodes = self._graph_runtime_state.consume_paused_nodes() @@ -351,13 +358,12 @@ class GraphEngine: def _stop_execution(self) -> None: """Stop execution subsystems.""" + self._stop_event.set() self._dispatcher.stop() self._worker_pool.stop() # Don't mark complete here as the dispatcher already does it # Notify layers - logger = logging.getLogger(__name__) - for layer in self._layers: try: layer.on_graph_end(self._graph_execution.error) diff --git a/api/core/workflow/graph_engine/orchestration/dispatcher.py b/api/core/workflow/graph_engine/orchestration/dispatcher.py index 334a3f77bf..27439a2412 100644 --- a/api/core/workflow/graph_engine/orchestration/dispatcher.py +++ b/api/core/workflow/graph_engine/orchestration/dispatcher.py @@ -44,6 +44,7 @@ class Dispatcher: event_queue: queue.Queue[GraphNodeEventBase], event_handler: "EventHandler", execution_coordinator: ExecutionCoordinator, + stop_event: threading.Event, event_emitter: EventManager | None = None, ) -> None: """ @@ -61,7 +62,7 @@ class Dispatcher: self._event_emitter = event_emitter self._thread: threading.Thread | None = None - self._stop_event = threading.Event() + self._stop_event = stop_event self._start_time: float | None = None def start(self) -> None: @@ -69,16 +70,14 @@ class Dispatcher: if self._thread and self._thread.is_alive(): return - self._stop_event.clear() self._start_time = time.time() self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True) self._thread.start() def stop(self) -> None: """Stop the dispatcher thread.""" - self._stop_event.set() if self._thread and self._thread.is_alive(): - self._thread.join(timeout=10.0) + self._thread.join(timeout=2.0) def _dispatcher_loop(self) -> None: """Main dispatcher loop.""" diff --git a/api/core/workflow/graph_engine/worker.py b/api/core/workflow/graph_engine/worker.py index e37a08ae47..83419830b6 100644 --- a/api/core/workflow/graph_engine/worker.py +++ b/api/core/workflow/graph_engine/worker.py @@ -42,6 +42,7 @@ class Worker(threading.Thread): event_queue: queue.Queue[GraphNodeEventBase], graph: Graph, layers: Sequence[GraphEngineLayer], + stop_event: threading.Event, worker_id: int = 0, flask_app: Flask | None = None, context_vars: contextvars.Context | None = None, @@ -65,13 +66,16 @@ class Worker(threading.Thread): self._worker_id = worker_id self._flask_app = flask_app self._context_vars = context_vars - self._stop_event = threading.Event() self._last_task_time = time.time() + self._stop_event = stop_event self._layers = layers if layers is not None else [] def stop(self) -> None: - """Signal the worker to stop processing.""" - self._stop_event.set() + """Worker is controlled via shared stop_event from GraphEngine. + + This method is a no-op retained for backward compatibility. + """ + pass @property def is_idle(self) -> bool: diff --git a/api/core/workflow/graph_engine/worker_management/worker_pool.py b/api/core/workflow/graph_engine/worker_management/worker_pool.py index 5b9234586b..df76ebe882 100644 --- a/api/core/workflow/graph_engine/worker_management/worker_pool.py +++ b/api/core/workflow/graph_engine/worker_management/worker_pool.py @@ -41,6 +41,7 @@ class WorkerPool: event_queue: queue.Queue[GraphNodeEventBase], graph: Graph, layers: list[GraphEngineLayer], + stop_event: threading.Event, flask_app: "Flask | None" = None, context_vars: "Context | None" = None, min_workers: int | None = None, @@ -81,6 +82,7 @@ class WorkerPool: self._worker_counter = 0 self._lock = threading.RLock() self._running = False + self._stop_event = stop_event # No longer tracking worker states with callbacks to avoid lock contention @@ -135,7 +137,7 @@ class WorkerPool: # Wait for workers to finish for worker in self._workers: if worker.is_alive(): - worker.join(timeout=10.0) + worker.join(timeout=2.0) self._workers.clear() @@ -152,6 +154,7 @@ class WorkerPool: worker_id=worker_id, flask_app=self._flask_app, context_vars=self._context_vars, + stop_event=self._stop_event, ) worker.start() diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 8ebba3659c..e7282313b6 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -264,6 +264,10 @@ class Node(Generic[NodeDataT]): """ raise NotImplementedError + def _should_stop(self) -> bool: + """Check if execution should be stopped.""" + return self.graph_runtime_state.stop_event.is_set() + def run(self) -> Generator[GraphNodeEventBase, None, None]: execution_id = self.ensure_execution_id() self._start_at = naive_utc_now() @@ -332,6 +336,21 @@ class Node(Generic[NodeDataT]): yield event else: yield event + + if self._should_stop(): + error_message = "Execution cancelled" + yield NodeRunFailedEvent( + id=self.execution_id, + node_id=self._node_id, + node_type=self.node_type, + start_at=self._start_at, + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=error_message, + ), + error=error_message, + ) + return except Exception as e: logger.exception("Node %s failed to run", self._node_id) result = NodeRunResult( diff --git a/api/core/workflow/runtime/graph_runtime_state.py b/api/core/workflow/runtime/graph_runtime_state.py index 1561b789df..401cecc162 100644 --- a/api/core/workflow/runtime/graph_runtime_state.py +++ b/api/core/workflow/runtime/graph_runtime_state.py @@ -2,6 +2,7 @@ from __future__ import annotations import importlib import json +import threading from collections.abc import Mapping, Sequence from copy import deepcopy from dataclasses import dataclass @@ -168,6 +169,7 @@ class GraphRuntimeState: self._pending_response_coordinator_dump: str | None = None self._pending_graph_execution_workflow_id: str | None = None self._paused_nodes: set[str] = set() + self.stop_event: threading.Event = threading.Event() if graph is not None: self.attach_graph(graph) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py index c1fc4acd73..fe3ea576c1 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py @@ -3,6 +3,7 @@ from __future__ import annotations import queue +import threading from unittest import mock from core.workflow.entities.pause_reason import SchedulingPause @@ -36,6 +37,7 @@ def test_dispatcher_should_consume_remains_events_after_pause(): event_queue=event_queue, event_handler=event_handler, execution_coordinator=execution_coordinator, + stop_event=threading.Event(), ) dispatcher._dispatcher_loop() assert event_queue.empty() @@ -96,6 +98,7 @@ def _run_dispatcher_for_event(event) -> int: event_queue=event_queue, event_handler=event_handler, execution_coordinator=coordinator, + stop_event=threading.Event(), ) dispatcher._dispatcher_loop() @@ -181,6 +184,7 @@ def test_dispatcher_drain_event_queue(): event_queue=event_queue, event_handler=event_handler, execution_coordinator=coordinator, + stop_event=threading.Event(), ) dispatcher._dispatcher_loop() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_stop_event.py b/api/tests/unit_tests/core/workflow/graph_engine/test_stop_event.py new file mode 100644 index 0000000000..ea8d3a977f --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_stop_event.py @@ -0,0 +1,539 @@ +""" +Unit tests for stop_event functionality in GraphEngine. + +Tests the unified stop_event management by GraphEngine and its propagation +to WorkerPool, Worker, Dispatcher, and Nodes. +""" + +import threading +import time +from unittest.mock import MagicMock, Mock, patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.graph_init_params import GraphInitParams +from core.workflow.graph import Graph +from core.workflow.graph_engine import GraphEngine +from core.workflow.graph_engine.command_channels import InMemoryChannel +from core.workflow.graph_events import ( + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunStartedEvent, +) +from core.workflow.nodes.answer.answer_node import AnswerNode +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.runtime import GraphRuntimeState, VariablePool +from models.enums import UserFrom + + +class TestStopEventPropagation: + """Test suite for stop_event propagation through GraphEngine components.""" + + def test_graph_engine_creates_stop_event(self): + """Test that GraphEngine creates a stop_event on initialization.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Verify stop_event was created + assert engine._stop_event is not None + assert isinstance(engine._stop_event, threading.Event) + + # Verify it was set in graph_runtime_state + assert runtime_state.stop_event is not None + assert runtime_state.stop_event is engine._stop_event + + def test_stop_event_cleared_on_start(self): + """Test that stop_event is cleared when execution starts.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + mock_graph.root_node.id = "start" # Set proper id + + start_node = StartNode( + id="start", + config={"id": "start", "data": {"title": "start", "variables": []}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + mock_graph.nodes["start"] = start_node + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Set the stop_event before running + engine._stop_event.set() + assert engine._stop_event.is_set() + + # Run the engine (should clear the stop_event) + events = list(engine.run()) + + # After running, stop_event should be set again (by _stop_execution) + # But during start it was cleared + assert any(isinstance(e, GraphRunStartedEvent) for e in events) + assert any(isinstance(e, GraphRunSucceededEvent) for e in events) + + def test_stop_event_set_on_stop(self): + """Test that stop_event is set when execution stops.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + mock_graph.root_node.id = "start" # Set proper id + + start_node = StartNode( + id="start", + config={"id": "start", "data": {"title": "start", "variables": []}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + mock_graph.nodes["start"] = start_node + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Initially not set + assert not engine._stop_event.is_set() + + # Run the engine + list(engine.run()) + + # After execution completes, stop_event should be set + assert engine._stop_event.is_set() + + def test_stop_event_passed_to_worker_pool(self): + """Test that stop_event is passed to WorkerPool.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Verify WorkerPool has the stop_event + assert engine._worker_pool._stop_event is not None + assert engine._worker_pool._stop_event is engine._stop_event + + def test_stop_event_passed_to_dispatcher(self): + """Test that stop_event is passed to Dispatcher.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Verify Dispatcher has the stop_event + assert engine._dispatcher._stop_event is not None + assert engine._dispatcher._stop_event is engine._stop_event + + +class TestNodeStopCheck: + """Test suite for Node._should_stop() functionality.""" + + def test_node_should_stop_checks_runtime_state(self): + """Test that Node._should_stop() checks GraphRuntimeState.stop_event.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + + answer_node = AnswerNode( + id="answer", + config={"id": "answer", "data": {"title": "answer", "answer": "{{#start.result#}}"}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + + # Initially stop_event is not set + assert not answer_node._should_stop() + + # Set the stop_event + runtime_state.stop_event.set() + + # Now _should_stop should return True + assert answer_node._should_stop() + + def test_node_run_checks_stop_event_between_yields(self): + """Test that Node.run() checks stop_event between yielding events.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + + # Create a simple node + answer_node = AnswerNode( + id="answer", + config={"id": "answer", "data": {"title": "answer", "answer": "hello"}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + + # Set stop_event BEFORE running the node + runtime_state.stop_event.set() + + # Run the node - should yield start event then detect stop + # The node should check stop_event before processing + assert answer_node._should_stop(), "stop_event should be set" + + # Run and collect events + events = list(answer_node.run()) + + # Since stop_event is set at the start, we should get: + # 1. NodeRunStartedEvent (always yielded first) + # 2. Either NodeRunFailedEvent (if detected early) or NodeRunSucceededEvent (if too fast) + assert len(events) >= 2 + assert isinstance(events[0], NodeRunStartedEvent) + + # Note: AnswerNode is very simple and might complete before stop check + # The important thing is that _should_stop() returns True when stop_event is set + assert answer_node._should_stop() + + +class TestStopEventIntegration: + """Integration tests for stop_event in workflow execution.""" + + def test_simple_workflow_respects_stop_event(self): + """Test that a simple workflow respects stop_event.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + mock_graph.root_node.id = "start" + + # Create start and answer nodes + start_node = StartNode( + id="start", + config={"id": "start", "data": {"title": "start", "variables": []}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + + answer_node = AnswerNode( + id="answer", + config={"id": "answer", "data": {"title": "answer", "answer": "hello"}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + + mock_graph.nodes["start"] = start_node + mock_graph.nodes["answer"] = answer_node + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Set stop_event before running + runtime_state.stop_event.set() + + # Run the engine + events = list(engine.run()) + + # Should get started event but not succeeded (due to stop) + assert any(isinstance(e, GraphRunStartedEvent) for e in events) + # The workflow should still complete (start node runs quickly) + # but answer node might be cancelled depending on timing + + def test_stop_event_with_concurrent_nodes(self): + """Test stop_event behavior with multiple concurrent nodes.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + # Create multiple nodes + for i in range(3): + answer_node = AnswerNode( + id=f"answer_{i}", + config={"id": f"answer_{i}", "data": {"title": f"answer_{i}", "answer": f"test{i}"}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + mock_graph.nodes[f"answer_{i}"] = answer_node + + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # All nodes should share the same stop_event + for node in mock_graph.nodes.values(): + assert node.graph_runtime_state.stop_event is runtime_state.stop_event + assert node.graph_runtime_state.stop_event is engine._stop_event + + +class TestStopEventTimeoutBehavior: + """Test stop_event behavior with join timeouts.""" + + @patch("core.workflow.graph_engine.orchestration.dispatcher.threading.Thread") + def test_dispatcher_uses_shorter_timeout(self, mock_thread_cls: MagicMock): + """Test that Dispatcher uses 2s timeout instead of 10s.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + dispatcher = engine._dispatcher + dispatcher.start() # This will create and start the mocked thread + + mock_thread_instance = mock_thread_cls.return_value + mock_thread_instance.is_alive.return_value = True + + dispatcher.stop() + + mock_thread_instance.join.assert_called_once_with(timeout=2.0) + + @patch("core.workflow.graph_engine.worker_management.worker_pool.Worker") + def test_worker_pool_uses_shorter_timeout(self, mock_worker_cls: MagicMock): + """Test that WorkerPool uses 2s timeout instead of 10s.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + worker_pool = engine._worker_pool + worker_pool.start(initial_count=1) # Start with one worker + + mock_worker_instance = mock_worker_cls.return_value + mock_worker_instance.is_alive.return_value = True + + worker_pool.stop() + + mock_worker_instance.join.assert_called_once_with(timeout=2.0) + + +class TestStopEventResumeBehavior: + """Test stop_event behavior during workflow resume.""" + + def test_stop_event_cleared_on_resume(self): + """Test that stop_event is cleared when resuming a paused workflow.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + mock_graph.root_node.id = "start" # Set proper id + + start_node = StartNode( + id="start", + config={"id": "start", "data": {"title": "start", "variables": []}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + mock_graph.nodes["start"] = start_node + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Simulate a previous execution that set stop_event + engine._stop_event.set() + assert engine._stop_event.is_set() + + # Run the engine (should clear stop_event in _start_execution) + events = list(engine.run()) + + # Execution should complete successfully + assert any(isinstance(e, GraphRunStartedEvent) for e in events) + assert any(isinstance(e, GraphRunSucceededEvent) for e in events) + + +class TestWorkerStopBehavior: + """Test Worker behavior with shared stop_event.""" + + def test_worker_uses_shared_stop_event(self): + """Test that Worker uses shared stop_event from GraphEngine.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Get the worker pool and check workers + worker_pool = engine._worker_pool + + # Start the worker pool to create workers + worker_pool.start() + + # Check that at least one worker was created + assert len(worker_pool._workers) > 0 + + # Verify workers use the shared stop_event + for worker in worker_pool._workers: + assert worker._stop_event is engine._stop_event + + # Clean up + worker_pool.stop() + + def test_worker_stop_is_noop(self): + """Test that Worker.stop() is now a no-op.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + + # Create a mock worker + from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue + from core.workflow.graph_engine.worker import Worker + + ready_queue = InMemoryReadyQueue() + event_queue = MagicMock() + + # Create a proper mock graph with real dict + mock_graph = Mock(spec=Graph) + mock_graph.nodes = {} # Use real dict + + stop_event = threading.Event() + + worker = Worker( + ready_queue=ready_queue, + event_queue=event_queue, + graph=mock_graph, + layers=[], + stop_event=stop_event, + ) + + # Calling stop() should do nothing (no-op) + # and should NOT set the stop_event + worker.stop() + assert not stop_event.is_set() From f320fd5f95de26ed8818f4ba31bb16c77f880d11 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Tue, 6 Jan 2026 11:12:52 +0900 Subject: [PATCH 17/17] refactor: port controllers/console/app/app.py (#30522) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/app/app.py | 487 ++++++++++++------ .../console/app/test_app_response_models.py | 285 ++++++++++ 2 files changed, 606 insertions(+), 166 deletions(-) create mode 100644 api/tests/unit_tests/controllers/console/app/test_app_response_models.py diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 44cf89d6a9..d66bb7063f 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,14 +1,16 @@ import re import uuid -from typing import Literal +from datetime import datetime +from typing import Any, Literal, TypeAlias from flask import request -from flask_restx import Resource, fields, marshal, marshal_with -from pydantic import BaseModel, Field, field_validator +from flask_restx import Resource +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import ( @@ -19,27 +21,19 @@ from controllers.console.wraps import ( is_admin_or_owner_required, setup_required, ) +from core.file import helpers as file_helpers from core.ops.ops_trace_manager import OpsTraceManager from core.workflow.enums import NodeType from extensions.ext_database import db -from fields.app_fields import ( - deleted_tool_fields, - model_config_fields, - model_config_partial_fields, - site_fields, - tag_fields, -) -from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict -from libs.helper import AppIconUrlField, TimestampField from libs.login import current_account_with_tenant, login_required from models import App, Workflow +from models.model import IconType from services.app_dsl_service import AppDslService, ImportMode from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class AppListQuery(BaseModel): @@ -192,124 +186,292 @@ class AppTracePayload(BaseModel): return value -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +JSONValue: TypeAlias = Any -reg(AppListQuery) -reg(CreateAppPayload) -reg(UpdateAppPayload) -reg(CopyAppPayload) -reg(AppExportQuery) -reg(AppNamePayload) -reg(AppIconPayload) -reg(AppSiteStatusPayload) -reg(AppApiStatusPayload) -reg(AppTracePayload) +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) -# Register models for flask_restx to avoid dict type issues in Swagger -# Register base models first -tag_model = console_ns.model("Tag", tag_fields) -workflow_partial_model = console_ns.model("WorkflowPartial", _workflow_partial_fields_dict) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -model_config_model = console_ns.model("ModelConfig", model_config_fields) -model_config_partial_model = console_ns.model("ModelConfigPartial", model_config_partial_fields) +def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None: + if icon is None or icon_type is None: + return None + icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type) + if icon_type_value.lower() != IconType.IMAGE.value: + return None + return file_helpers.get_signed_file_url(icon) -deleted_tool_model = console_ns.model("DeletedTool", deleted_tool_fields) -site_model = console_ns.model("Site", site_fields) +class Tag(ResponseModel): + id: str + name: str + type: str -app_partial_model = console_ns.model( - "AppPartial", - { - "id": fields.String, - "name": fields.String, - "max_active_requests": fields.Raw(), - "description": fields.String(attribute="desc_or_prompt"), - "mode": fields.String(attribute="mode_compatible_with_agent"), - "icon_type": fields.String, - "icon": fields.String, - "icon_background": fields.String, - "icon_url": AppIconUrlField, - "model_config": fields.Nested(model_config_partial_model, attribute="app_model_config", allow_null=True), - "workflow": fields.Nested(workflow_partial_model, allow_null=True), - "use_icon_as_answer_icon": fields.Boolean, - "created_by": fields.String, - "created_at": TimestampField, - "updated_by": fields.String, - "updated_at": TimestampField, - "tags": fields.List(fields.Nested(tag_model)), - "access_mode": fields.String, - "create_user_name": fields.String, - "author_name": fields.String, - "has_draft_trigger": fields.Boolean, - }, -) -app_detail_model = console_ns.model( - "AppDetail", - { - "id": fields.String, - "name": fields.String, - "description": fields.String, - "mode": fields.String(attribute="mode_compatible_with_agent"), - "icon": fields.String, - "icon_background": fields.String, - "enable_site": fields.Boolean, - "enable_api": fields.Boolean, - "model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True), - "workflow": fields.Nested(workflow_partial_model, allow_null=True), - "tracing": fields.Raw, - "use_icon_as_answer_icon": fields.Boolean, - "created_by": fields.String, - "created_at": TimestampField, - "updated_by": fields.String, - "updated_at": TimestampField, - "access_mode": fields.String, - "tags": fields.List(fields.Nested(tag_model)), - }, -) +class WorkflowPartial(ResponseModel): + id: str + created_by: str | None = None + created_at: int | None = None + updated_by: str | None = None + updated_at: int | None = None -app_detail_with_site_model = console_ns.model( - "AppDetailWithSite", - { - "id": fields.String, - "name": fields.String, - "description": fields.String, - "mode": fields.String(attribute="mode_compatible_with_agent"), - "icon_type": fields.String, - "icon": fields.String, - "icon_background": fields.String, - "icon_url": AppIconUrlField, - "enable_site": fields.Boolean, - "enable_api": fields.Boolean, - "model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True), - "workflow": fields.Nested(workflow_partial_model, allow_null=True), - "api_base_url": fields.String, - "use_icon_as_answer_icon": fields.Boolean, - "max_active_requests": fields.Integer, - "created_by": fields.String, - "created_at": TimestampField, - "updated_by": fields.String, - "updated_at": TimestampField, - "deleted_tools": fields.List(fields.Nested(deleted_tool_model)), - "access_mode": fields.String, - "tags": fields.List(fields.Nested(tag_model)), - "site": fields.Nested(site_model), - }, -) + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) -app_pagination_model = console_ns.model( - "AppPagination", - { - "page": fields.Integer, - "limit": fields.Integer(attribute="per_page"), - "total": fields.Integer, - "has_more": fields.Boolean(attribute="has_next"), - "data": fields.List(fields.Nested(app_partial_model), attribute="items"), - }, + +class ModelConfigPartial(ResponseModel): + model: JSONValue | None = Field(default=None, validation_alias=AliasChoices("model_dict", "model")) + pre_prompt: str | None = None + created_by: str | None = None + created_at: int | None = None + updated_by: str | None = None + updated_at: int | None = None + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class ModelConfig(ResponseModel): + opening_statement: str | None = None + suggested_questions: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("suggested_questions_list", "suggested_questions") + ) + suggested_questions_after_answer: JSONValue | None = Field( + default=None, + validation_alias=AliasChoices("suggested_questions_after_answer_dict", "suggested_questions_after_answer"), + ) + speech_to_text: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("speech_to_text_dict", "speech_to_text") + ) + text_to_speech: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("text_to_speech_dict", "text_to_speech") + ) + retriever_resource: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("retriever_resource_dict", "retriever_resource") + ) + annotation_reply: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("annotation_reply_dict", "annotation_reply") + ) + more_like_this: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("more_like_this_dict", "more_like_this") + ) + sensitive_word_avoidance: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("sensitive_word_avoidance_dict", "sensitive_word_avoidance") + ) + external_data_tools: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("external_data_tools_list", "external_data_tools") + ) + model: JSONValue | None = Field(default=None, validation_alias=AliasChoices("model_dict", "model")) + user_input_form: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("user_input_form_list", "user_input_form") + ) + dataset_query_variable: str | None = None + pre_prompt: str | None = None + agent_mode: JSONValue | None = Field(default=None, validation_alias=AliasChoices("agent_mode_dict", "agent_mode")) + prompt_type: str | None = None + chat_prompt_config: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("chat_prompt_config_dict", "chat_prompt_config") + ) + completion_prompt_config: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("completion_prompt_config_dict", "completion_prompt_config") + ) + dataset_configs: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("dataset_configs_dict", "dataset_configs") + ) + file_upload: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("file_upload_dict", "file_upload") + ) + created_by: str | None = None + created_at: int | None = None + updated_by: str | None = None + updated_at: int | None = None + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class Site(ResponseModel): + access_token: str | None = Field(default=None, validation_alias="code") + code: str | None = None + title: str | None = None + icon_type: str | IconType | None = None + icon: str | None = None + icon_background: str | None = None + description: str | None = None + default_language: str | None = None + chat_color_theme: str | None = None + chat_color_theme_inverted: bool | None = None + customize_domain: str | None = None + copyright: str | None = None + privacy_policy: str | None = None + custom_disclaimer: str | None = None + customize_token_strategy: str | None = None + prompt_public: bool | None = None + app_base_url: str | None = None + show_workflow_steps: bool | None = None + use_icon_as_answer_icon: bool | None = None + created_by: str | None = None + created_at: int | None = None + updated_by: str | None = None + updated_at: int | None = None + + @computed_field(return_type=str | None) # type: ignore + @property + def icon_url(self) -> str | None: + return _build_icon_url(self.icon_type, self.icon) + + @field_validator("icon_type", mode="before") + @classmethod + def _normalize_icon_type(cls, value: str | IconType | None) -> str | None: + if isinstance(value, IconType): + return value.value + return value + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class DeletedTool(ResponseModel): + type: str + tool_name: str + provider_id: str + + +class AppPartial(ResponseModel): + id: str + name: str + max_active_requests: int | None = None + description: str | None = Field(default=None, validation_alias=AliasChoices("desc_or_prompt", "description")) + mode: str = Field(validation_alias="mode_compatible_with_agent") + icon_type: str | None = None + icon: str | None = None + icon_background: str | None = None + model_config_: ModelConfigPartial | None = Field( + default=None, + validation_alias=AliasChoices("app_model_config", "model_config"), + alias="model_config", + ) + workflow: WorkflowPartial | None = None + use_icon_as_answer_icon: bool | None = None + created_by: str | None = None + created_at: int | None = None + updated_by: str | None = None + updated_at: int | None = None + tags: list[Tag] = Field(default_factory=list) + access_mode: str | None = None + create_user_name: str | None = None + author_name: str | None = None + has_draft_trigger: bool | None = None + + @computed_field(return_type=str | None) # type: ignore + @property + def icon_url(self) -> str | None: + return _build_icon_url(self.icon_type, self.icon) + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AppDetail(ResponseModel): + id: str + name: str + description: str | None = None + mode: str = Field(validation_alias="mode_compatible_with_agent") + icon: str | None = None + icon_background: str | None = None + enable_site: bool + enable_api: bool + model_config_: ModelConfig | None = Field( + default=None, + validation_alias=AliasChoices("app_model_config", "model_config"), + alias="model_config", + ) + workflow: WorkflowPartial | None = None + tracing: JSONValue | None = None + use_icon_as_answer_icon: bool | None = None + created_by: str | None = None + created_at: int | None = None + updated_by: str | None = None + updated_at: int | None = None + access_mode: str | None = None + tags: list[Tag] = Field(default_factory=list) + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AppDetailWithSite(AppDetail): + icon_type: str | None = None + api_base_url: str | None = None + max_active_requests: int | None = None + deleted_tools: list[DeletedTool] = Field(default_factory=list) + site: Site | None = None + + @computed_field(return_type=str | None) # type: ignore + @property + def icon_url(self) -> str | None: + return _build_icon_url(self.icon_type, self.icon) + + +class AppPagination(ResponseModel): + page: int + limit: int = Field(validation_alias=AliasChoices("per_page", "limit")) + total: int + has_more: bool = Field(validation_alias=AliasChoices("has_next", "has_more")) + data: list[AppPartial] = Field(validation_alias=AliasChoices("items", "data")) + + +class AppExportResponse(ResponseModel): + data: str + + +register_schema_models( + console_ns, + AppListQuery, + CreateAppPayload, + UpdateAppPayload, + CopyAppPayload, + AppExportQuery, + AppNamePayload, + AppIconPayload, + AppSiteStatusPayload, + AppApiStatusPayload, + AppTracePayload, + Tag, + WorkflowPartial, + ModelConfigPartial, + ModelConfig, + Site, + DeletedTool, + AppPartial, + AppDetail, + AppDetailWithSite, + AppPagination, + AppExportResponse, ) @@ -318,7 +480,7 @@ class AppListApi(Resource): @console_ns.doc("list_apps") @console_ns.doc(description="Get list of applications with pagination and filtering") @console_ns.expect(console_ns.models[AppListQuery.__name__]) - @console_ns.response(200, "Success", app_pagination_model) + @console_ns.response(200, "Success", console_ns.models[AppPagination.__name__]) @setup_required @login_required @account_initialization_required @@ -334,7 +496,8 @@ class AppListApi(Resource): app_service = AppService() app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict) if not app_pagination: - return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} + empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[]) + return empty.model_dump(mode="json"), 200 if FeatureService.get_system_features().webapp_auth.enabled: app_ids = [str(app.id) for app in app_pagination.items] @@ -378,18 +541,18 @@ class AppListApi(Resource): for app in app_pagination.items: app.has_draft_trigger = str(app.id) in draft_trigger_app_ids - return marshal(app_pagination, app_pagination_model), 200 + pagination_model = AppPagination.model_validate(app_pagination, from_attributes=True) + return pagination_model.model_dump(mode="json"), 200 @console_ns.doc("create_app") @console_ns.doc(description="Create a new application") @console_ns.expect(console_ns.models[CreateAppPayload.__name__]) - @console_ns.response(201, "App created successfully", app_detail_model) + @console_ns.response(201, "App created successfully", console_ns.models[AppDetail.__name__]) @console_ns.response(403, "Insufficient permissions") @console_ns.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required - @marshal_with(app_detail_model) @cloud_edition_billing_resource_check("apps") @edit_permission_required def post(self): @@ -399,8 +562,8 @@ class AppListApi(Resource): app_service = AppService() app = app_service.create_app(current_tenant_id, args.model_dump(), current_user) - - return app, 201 + app_detail = AppDetail.model_validate(app, from_attributes=True) + return app_detail.model_dump(mode="json"), 201 @console_ns.route("/apps/") @@ -408,13 +571,12 @@ class AppApi(Resource): @console_ns.doc("get_app_detail") @console_ns.doc(description="Get application details") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.response(200, "Success", app_detail_with_site_model) + @console_ns.response(200, "Success", console_ns.models[AppDetailWithSite.__name__]) @setup_required @login_required @account_initialization_required @enterprise_license_required - @get_app_model - @marshal_with(app_detail_with_site_model) + @get_app_model(mode=None) def get(self, app_model): """Get app detail""" app_service = AppService() @@ -425,21 +587,21 @@ class AppApi(Resource): app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id)) app_model.access_mode = app_setting.access_mode - return app_model + response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.doc("update_app") @console_ns.doc(description="Update application details") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[UpdateAppPayload.__name__]) - @console_ns.response(200, "App updated successfully", app_detail_with_site_model) + @console_ns.response(200, "App updated successfully", console_ns.models[AppDetailWithSite.__name__]) @console_ns.response(403, "Insufficient permissions") @console_ns.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required - @get_app_model + @get_app_model(mode=None) @edit_permission_required - @marshal_with(app_detail_with_site_model) def put(self, app_model): """Update app""" args = UpdateAppPayload.model_validate(console_ns.payload) @@ -456,8 +618,8 @@ class AppApi(Resource): "max_active_requests": args.max_active_requests or 0, } app_model = app_service.update_app(app_model, args_dict) - - return app_model + response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.doc("delete_app") @console_ns.doc(description="Delete application") @@ -483,14 +645,13 @@ class AppCopyApi(Resource): @console_ns.doc(description="Create a copy of an existing application") @console_ns.doc(params={"app_id": "Application ID to copy"}) @console_ns.expect(console_ns.models[CopyAppPayload.__name__]) - @console_ns.response(201, "App copied successfully", app_detail_with_site_model) + @console_ns.response(201, "App copied successfully", console_ns.models[AppDetailWithSite.__name__]) @console_ns.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required - @get_app_model + @get_app_model(mode=None) @edit_permission_required - @marshal_with(app_detail_with_site_model) def post(self, app_model): """Copy app""" # The role of the current user in the ta table must be admin, owner, or editor @@ -516,7 +677,8 @@ class AppCopyApi(Resource): stmt = select(App).where(App.id == result.app_id) app = session.scalar(stmt) - return app, 201 + response_model = AppDetailWithSite.model_validate(app, from_attributes=True) + return response_model.model_dump(mode="json"), 201 @console_ns.route("/apps//export") @@ -525,11 +687,7 @@ class AppExportApi(Resource): @console_ns.doc(description="Export application configuration as DSL") @console_ns.doc(params={"app_id": "Application ID to export"}) @console_ns.expect(console_ns.models[AppExportQuery.__name__]) - @console_ns.response( - 200, - "App exported successfully", - console_ns.model("AppExportResponse", {"data": fields.String(description="DSL export data")}), - ) + @console_ns.response(200, "App exported successfully", console_ns.models[AppExportResponse.__name__]) @console_ns.response(403, "Insufficient permissions") @get_app_model @setup_required @@ -540,13 +698,14 @@ class AppExportApi(Resource): """Export app""" args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - return { - "data": AppDslService.export_dsl( + payload = AppExportResponse( + data=AppDslService.export_dsl( app_model=app_model, include_secret=args.include_secret, workflow_id=args.workflow_id, ) - } + ) + return payload.model_dump(mode="json") @console_ns.route("/apps//name") @@ -555,20 +714,19 @@ class AppNameApi(Resource): @console_ns.doc(description="Check if app name is available") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[AppNamePayload.__name__]) - @console_ns.response(200, "Name availability checked") + @console_ns.response(200, "Name availability checked", console_ns.models[AppDetail.__name__]) @setup_required @login_required @account_initialization_required - @get_app_model - @marshal_with(app_detail_model) + @get_app_model(mode=None) @edit_permission_required def post(self, app_model): args = AppNamePayload.model_validate(console_ns.payload) app_service = AppService() app_model = app_service.update_app_name(app_model, args.name) - - return app_model + response_model = AppDetail.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.route("/apps//icon") @@ -582,16 +740,15 @@ class AppIconApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model - @marshal_with(app_detail_model) + @get_app_model(mode=None) @edit_permission_required def post(self, app_model): args = AppIconPayload.model_validate(console_ns.payload or {}) app_service = AppService() app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "") - - return app_model + response_model = AppDetail.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.route("/apps//site-enable") @@ -600,21 +757,20 @@ class AppSiteStatus(Resource): @console_ns.doc(description="Enable or disable app site") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[AppSiteStatusPayload.__name__]) - @console_ns.response(200, "Site status updated successfully", app_detail_model) + @console_ns.response(200, "Site status updated successfully", console_ns.models[AppDetail.__name__]) @console_ns.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required - @get_app_model - @marshal_with(app_detail_model) + @get_app_model(mode=None) @edit_permission_required def post(self, app_model): args = AppSiteStatusPayload.model_validate(console_ns.payload) app_service = AppService() app_model = app_service.update_app_site_status(app_model, args.enable_site) - - return app_model + response_model = AppDetail.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.route("/apps//api-enable") @@ -623,21 +779,20 @@ class AppApiStatus(Resource): @console_ns.doc(description="Enable or disable app API") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[AppApiStatusPayload.__name__]) - @console_ns.response(200, "API status updated successfully", app_detail_model) + @console_ns.response(200, "API status updated successfully", console_ns.models[AppDetail.__name__]) @console_ns.response(403, "Insufficient permissions") @setup_required @login_required @is_admin_or_owner_required @account_initialization_required - @get_app_model - @marshal_with(app_detail_model) + @get_app_model(mode=None) def post(self, app_model): args = AppApiStatusPayload.model_validate(console_ns.payload) app_service = AppService() app_model = app_service.update_app_api_status(app_model, args.enable_api) - - return app_model + response_model = AppDetail.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.route("/apps//trace") diff --git a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py new file mode 100644 index 0000000000..40eb59a8f4 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py @@ -0,0 +1,285 @@ +from __future__ import annotations + +import builtins +import sys +from datetime import datetime +from importlib import util +from pathlib import Path +from types import ModuleType, SimpleNamespace +from typing import Any + +import pytest +from flask.views import MethodView + +# kombu references MethodView as a global when importing celery/kombu pools. +if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: ignore[attr-defined] + + +def _load_app_module(): + module_name = "controllers.console.app.app" + if module_name in sys.modules: + return sys.modules[module_name] + + root = Path(__file__).resolve().parents[5] + module_path = root / "controllers" / "console" / "app" / "app.py" + + class _StubNamespace: + def __init__(self): + self.models: dict[str, Any] = {} + self.payload = None + + def schema_model(self, name, schema): + self.models[name] = schema + + def _decorator(self, obj): + return obj + + def doc(self, *args, **kwargs): + return self._decorator + + def expect(self, *args, **kwargs): + return self._decorator + + def response(self, *args, **kwargs): + return self._decorator + + def route(self, *args, **kwargs): + def decorator(obj): + return obj + + return decorator + + stub_namespace = _StubNamespace() + + original_console = sys.modules.get("controllers.console") + original_app_pkg = sys.modules.get("controllers.console.app") + stubbed_modules: list[tuple[str, ModuleType | None]] = [] + + console_module = ModuleType("controllers.console") + console_module.__path__ = [str(root / "controllers" / "console")] + console_module.console_ns = stub_namespace + console_module.api = None + console_module.bp = None + sys.modules["controllers.console"] = console_module + + app_package = ModuleType("controllers.console.app") + app_package.__path__ = [str(root / "controllers" / "console" / "app")] + sys.modules["controllers.console.app"] = app_package + console_module.app = app_package + + def _stub_module(name: str, attrs: dict[str, Any]): + original = sys.modules.get(name) + module = ModuleType(name) + for key, value in attrs.items(): + setattr(module, key, value) + sys.modules[name] = module + stubbed_modules.append((name, original)) + + class _OpsTraceManager: + @staticmethod + def get_app_tracing_config(app_id: str) -> dict[str, Any]: + return {} + + @staticmethod + def update_app_tracing_config(app_id: str, **kwargs) -> None: + return None + + _stub_module( + "core.ops.ops_trace_manager", + { + "OpsTraceManager": _OpsTraceManager, + "TraceQueueManager": object, + "TraceTask": object, + }, + ) + + spec = util.spec_from_file_location(module_name, module_path) + module = util.module_from_spec(spec) + sys.modules[module_name] = module + + try: + assert spec.loader is not None + spec.loader.exec_module(module) + finally: + for name, original in reversed(stubbed_modules): + if original is not None: + sys.modules[name] = original + else: + sys.modules.pop(name, None) + if original_console is not None: + sys.modules["controllers.console"] = original_console + else: + sys.modules.pop("controllers.console", None) + if original_app_pkg is not None: + sys.modules["controllers.console.app"] = original_app_pkg + else: + sys.modules.pop("controllers.console.app", None) + + return module + + +_app_module = _load_app_module() +AppDetailWithSite = _app_module.AppDetailWithSite +AppPagination = _app_module.AppPagination +AppPartial = _app_module.AppPartial + + +@pytest.fixture(autouse=True) +def patch_signed_url(monkeypatch): + """Ensure icon URL generation uses a deterministic helper for tests.""" + + def _fake_signed_url(key: str | None) -> str | None: + if not key: + return None + return f"signed:{key}" + + monkeypatch.setattr(_app_module.file_helpers, "get_signed_file_url", _fake_signed_url) + + +def _ts(hour: int = 12) -> datetime: + return datetime(2024, 1, 1, hour, 0, 0) + + +def _dummy_model_config(): + return SimpleNamespace( + model_dict={"provider": "openai", "name": "gpt-4o"}, + pre_prompt="hello", + created_by="config-author", + created_at=_ts(9), + updated_by="config-editor", + updated_at=_ts(10), + ) + + +def _dummy_workflow(): + return SimpleNamespace( + id="wf-1", + created_by="workflow-author", + created_at=_ts(8), + updated_by="workflow-editor", + updated_at=_ts(9), + ) + + +def test_app_partial_serialization_uses_aliases(): + created_at = _ts() + app_obj = SimpleNamespace( + id="app-1", + name="My App", + desc_or_prompt="Prompt snippet", + mode_compatible_with_agent="chat", + icon_type="image", + icon="icon-key", + icon_background="#fff", + app_model_config=_dummy_model_config(), + workflow=_dummy_workflow(), + created_by="creator", + created_at=created_at, + updated_by="editor", + updated_at=created_at, + tags=[SimpleNamespace(id="tag-1", name="Utilities", type="app")], + access_mode="private", + create_user_name="Creator", + author_name="Author", + has_draft_trigger=True, + ) + + serialized = AppPartial.model_validate(app_obj, from_attributes=True).model_dump(mode="json") + + assert serialized["description"] == "Prompt snippet" + assert serialized["mode"] == "chat" + assert serialized["icon_url"] == "signed:icon-key" + assert serialized["created_at"] == int(created_at.timestamp()) + assert serialized["updated_at"] == int(created_at.timestamp()) + assert serialized["model_config"]["model"] == {"provider": "openai", "name": "gpt-4o"} + assert serialized["workflow"]["id"] == "wf-1" + assert serialized["tags"][0]["name"] == "Utilities" + + +def test_app_detail_with_site_includes_nested_serialization(): + timestamp = _ts(14) + site = SimpleNamespace( + code="site-code", + title="Public Site", + icon_type="image", + icon="site-icon", + created_at=timestamp, + updated_at=timestamp, + ) + app_obj = SimpleNamespace( + id="app-2", + name="Detailed App", + description="Desc", + mode_compatible_with_agent="advanced-chat", + icon_type="image", + icon="detail-icon", + icon_background="#123456", + enable_site=True, + enable_api=True, + app_model_config={ + "opening_statement": "hi", + "model": {"provider": "openai", "name": "gpt-4o"}, + "retriever_resource": {"enabled": True}, + }, + workflow=_dummy_workflow(), + tracing={"enabled": True}, + use_icon_as_answer_icon=True, + created_by="creator", + created_at=timestamp, + updated_by="editor", + updated_at=timestamp, + access_mode="public", + tags=[SimpleNamespace(id="tag-2", name="Prod", type="app")], + api_base_url="https://api.example.com/v1", + max_active_requests=5, + deleted_tools=[{"type": "api", "tool_name": "search", "provider_id": "prov"}], + site=site, + ) + + serialized = AppDetailWithSite.model_validate(app_obj, from_attributes=True).model_dump(mode="json") + + assert serialized["icon_url"] == "signed:detail-icon" + assert serialized["model_config"]["retriever_resource"] == {"enabled": True} + assert serialized["deleted_tools"][0]["tool_name"] == "search" + assert serialized["site"]["icon_url"] == "signed:site-icon" + assert serialized["site"]["created_at"] == int(timestamp.timestamp()) + + +def test_app_pagination_aliases_per_page_and_has_next(): + item_one = SimpleNamespace( + id="app-10", + name="Paginated One", + desc_or_prompt="Summary", + mode_compatible_with_agent="chat", + icon_type="image", + icon="first-icon", + created_at=_ts(15), + updated_at=_ts(15), + ) + item_two = SimpleNamespace( + id="app-11", + name="Paginated Two", + desc_or_prompt="Summary", + mode_compatible_with_agent="agent-chat", + icon_type="emoji", + icon="🙂", + created_at=_ts(16), + updated_at=_ts(16), + ) + pagination = SimpleNamespace( + page=2, + per_page=10, + total=50, + has_next=True, + items=[item_one, item_two], + ) + + serialized = AppPagination.model_validate(pagination, from_attributes=True).model_dump(mode="json") + + assert serialized["page"] == 2 + assert serialized["limit"] == 10 + assert serialized["has_more"] is True + assert len(serialized["data"]) == 2 + assert serialized["data"][0]["icon_url"] == "signed:first-icon" + assert serialized["data"][1]["icon_url"] is None