mirror of https://github.com/langgenius/dify.git
merge main
This commit is contained in:
commit
25fef5d757
|
|
@ -83,9 +83,15 @@ jobs:
|
|||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
services: |
|
||||
db
|
||||
redis
|
||||
sandbox
|
||||
ssrf_proxy
|
||||
|
||||
- name: setup test config
|
||||
run: |
|
||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||
|
||||
- name: Run Workflow
|
||||
run: uv run --project api bash dev/pytest/pytest_workflow.sh
|
||||
|
||||
|
|
|
|||
|
|
@ -84,6 +84,12 @@ jobs:
|
|||
elasticsearch
|
||||
oceanbase
|
||||
|
||||
- name: setup test config
|
||||
run: |
|
||||
echo $(pwd)
|
||||
ls -lah .
|
||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||
|
||||
- name: Check VDB Ready (TiDB)
|
||||
run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
|
||||
|
||||
|
|
|
|||
|
|
@ -230,6 +230,10 @@ Deploy Dify to AWS with [CDK](https://aws.amazon.com/cdk/)
|
|||
|
||||
Quickly deploy Dify to Alibaba cloud with [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
|
||||
|
||||
#### Using Alibaba Cloud Data Management
|
||||
|
||||
One-Click deploy Dify to Alibaba Cloud with [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
|
||||
|
||||
|
||||
## Contributing
|
||||
|
||||
|
|
|
|||
|
|
@ -211,6 +211,11 @@ docker compose up -d
|
|||
|
||||
#### استخدام Alibaba Cloud للنشر
|
||||
[بسرعة نشر Dify إلى سحابة علي بابا مع عش الحوسبة السحابية علي بابا](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
|
||||
|
||||
#### استخدام Alibaba Cloud Data Management للنشر
|
||||
|
||||
انشر Dify على علي بابا كلاود بنقرة واحدة باستخدام [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
|
||||
|
||||
|
||||
## المساهمة
|
||||
|
||||
|
|
|
|||
|
|
@ -229,6 +229,10 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন
|
|||
|
||||
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
|
||||
|
||||
#### Alibaba Cloud Data Management ব্যবহার করে ডিপ্লয়
|
||||
|
||||
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
|
||||
|
||||
|
||||
## Contributing
|
||||
|
||||
|
|
|
|||
|
|
@ -225,6 +225,10 @@ docker compose up -d
|
|||
|
||||
使用 [阿里云计算巢](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) 将 Dify 一键部署到 阿里云
|
||||
|
||||
#### 使用 阿里云数据管理DMS 部署
|
||||
|
||||
使用 [阿里云数据管理DMS](https://help.aliyun.com/zh/dms/dify-in-invitational-preview) 将 Dify 一键部署到 阿里云
|
||||
|
||||
|
||||
## Star History
|
||||
|
||||
|
|
|
|||
|
|
@ -225,6 +225,10 @@ Bereitstellung von Dify auf AWS mit [CDK](https://aws.amazon.com/cdk/)
|
|||
|
||||
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
|
||||
|
||||
#### Alibaba Cloud Data Management
|
||||
|
||||
Ein-Klick-Bereitstellung von Dify in der Alibaba Cloud mit [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
|
||||
|
||||
|
||||
## Contributing
|
||||
|
||||
|
|
|
|||
|
|
@ -225,6 +225,11 @@ Despliegue Dify en AWS usando [CDK](https://aws.amazon.com/cdk/)
|
|||
|
||||
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
|
||||
|
||||
#### Alibaba Cloud Data Management
|
||||
|
||||
Despliega Dify en Alibaba Cloud con un solo clic con [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
|
||||
|
||||
|
||||
## Contribuir
|
||||
|
||||
Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
|
|
|
|||
|
|
@ -223,6 +223,10 @@ Déployez Dify sur AWS en utilisant [CDK](https://aws.amazon.com/cdk/)
|
|||
|
||||
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
|
||||
|
||||
#### Alibaba Cloud Data Management
|
||||
|
||||
Déployez Dify en un clic sur Alibaba Cloud avec [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
|
||||
|
||||
|
||||
## Contribuer
|
||||
|
||||
|
|
|
|||
|
|
@ -155,7 +155,7 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ
|
|||
[こちら](https://dify.ai)のDify Cloudサービスを利用して、セットアップ不要で試すことができます。サンドボックスプランには、200回のGPT-4呼び出しが無料で含まれています。
|
||||
|
||||
- **Dify Community Editionのセルフホスティング</br>**
|
||||
この[スタートガイド](#quick-start)を使用して、ローカル環境でDifyを簡単に実行できます。
|
||||
この[スタートガイド](#クイックスタート)を使用して、ローカル環境でDifyを簡単に実行できます。
|
||||
詳しくは[ドキュメント](https://docs.dify.ai)をご覧ください。
|
||||
|
||||
- **企業/組織向けのDify</br>**
|
||||
|
|
@ -223,6 +223,9 @@ docker compose up -d
|
|||
#### Alibaba Cloud
|
||||
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
|
||||
|
||||
#### Alibaba Cloud Data Management
|
||||
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) を利用して、DifyをAlibaba Cloudへワンクリックでデプロイできます
|
||||
|
||||
|
||||
## 貢献
|
||||
|
||||
|
|
|
|||
|
|
@ -223,6 +223,10 @@ wa'logh nIqHom neH ghun deployment toy'wI' [CDK](https://aws.amazon.com/cdk/) lo
|
|||
|
||||
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
|
||||
|
||||
#### Alibaba Cloud Data Management
|
||||
|
||||
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
|
||||
|
||||
|
||||
## Contributing
|
||||
|
||||
|
|
|
|||
|
|
@ -217,6 +217,10 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했
|
|||
|
||||
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
|
||||
|
||||
#### Alibaba Cloud Data Management
|
||||
|
||||
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)를 통해 원클릭으로 Dify를 Alibaba Cloud에 배포할 수 있습니다
|
||||
|
||||
|
||||
## 기여
|
||||
|
||||
|
|
|
|||
|
|
@ -222,6 +222,10 @@ Implante o Dify na AWS usando [CDK](https://aws.amazon.com/cdk/)
|
|||
|
||||
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
|
||||
|
||||
#### Alibaba Cloud Data Management
|
||||
|
||||
Implante o Dify na Alibaba Cloud com um clique usando o [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
|
||||
|
||||
|
||||
## Contribuindo
|
||||
|
||||
|
|
|
|||
|
|
@ -223,6 +223,10 @@ Uvedite Dify v AWS z uporabo [CDK](https://aws.amazon.com/cdk/)
|
|||
|
||||
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
|
||||
|
||||
#### Alibaba Cloud Data Management
|
||||
|
||||
Z enim klikom namestite Dify na Alibaba Cloud z [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
|
||||
|
||||
|
||||
## Prispevam
|
||||
|
||||
|
|
|
|||
|
|
@ -216,6 +216,10 @@ Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.ter
|
|||
|
||||
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
|
||||
|
||||
#### Alibaba Cloud Data Management
|
||||
|
||||
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) kullanarak Dify'ı tek tıkla Alibaba Cloud'a dağıtın
|
||||
|
||||
|
||||
## Katkıda Bulunma
|
||||
|
||||
|
|
|
|||
|
|
@ -228,6 +228,10 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify
|
|||
|
||||
[阿里云](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
|
||||
|
||||
#### 使用 阿里雲數據管理DMS 進行部署
|
||||
|
||||
透過 [阿里雲數據管理DMS](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/),一鍵將 Dify 部署至阿里雲
|
||||
|
||||
|
||||
## 貢獻
|
||||
|
||||
|
|
|
|||
|
|
@ -219,6 +219,10 @@ Triển khai Dify trên AWS bằng [CDK](https://aws.amazon.com/cdk/)
|
|||
|
||||
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
|
||||
|
||||
#### Alibaba Cloud Data Management
|
||||
|
||||
Triển khai Dify lên Alibaba Cloud chỉ với một cú nhấp chuột bằng [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
|
||||
|
||||
|
||||
## Đóng góp
|
||||
|
||||
|
|
|
|||
103
api/.ruff.toml
103
api/.ruff.toml
|
|
@ -1,6 +1,4 @@
|
|||
exclude = [
|
||||
"migrations/*",
|
||||
]
|
||||
exclude = ["migrations/*"]
|
||||
line-length = 120
|
||||
|
||||
[format]
|
||||
|
|
@ -9,14 +7,14 @@ quote-style = "double"
|
|||
[lint]
|
||||
preview = false
|
||||
select = [
|
||||
"B", # flake8-bugbear rules
|
||||
"C4", # flake8-comprehensions
|
||||
"E", # pycodestyle E rules
|
||||
"F", # pyflakes rules
|
||||
"FURB", # refurb rules
|
||||
"I", # isort rules
|
||||
"N", # pep8-naming
|
||||
"PT", # flake8-pytest-style rules
|
||||
"B", # flake8-bugbear rules
|
||||
"C4", # flake8-comprehensions
|
||||
"E", # pycodestyle E rules
|
||||
"F", # pyflakes rules
|
||||
"FURB", # refurb rules
|
||||
"I", # isort rules
|
||||
"N", # pep8-naming
|
||||
"PT", # flake8-pytest-style rules
|
||||
"PLC0208", # iteration-over-set
|
||||
"PLC0414", # useless-import-alias
|
||||
"PLE0604", # invalid-all-object
|
||||
|
|
@ -24,19 +22,19 @@ select = [
|
|||
"PLR0402", # manual-from-import
|
||||
"PLR1711", # useless-return
|
||||
"PLR1714", # repeated-equality-comparison
|
||||
"RUF013", # implicit-optional
|
||||
"RUF019", # unnecessary-key-check
|
||||
"RUF100", # unused-noqa
|
||||
"RUF101", # redirected-noqa
|
||||
"RUF200", # invalid-pyproject-toml
|
||||
"RUF022", # unsorted-dunder-all
|
||||
"S506", # unsafe-yaml-load
|
||||
"SIM", # flake8-simplify rules
|
||||
"TRY400", # error-instead-of-exception
|
||||
"TRY401", # verbose-log-message
|
||||
"UP", # pyupgrade rules
|
||||
"W191", # tab-indentation
|
||||
"W605", # invalid-escape-sequence
|
||||
"RUF013", # implicit-optional
|
||||
"RUF019", # unnecessary-key-check
|
||||
"RUF100", # unused-noqa
|
||||
"RUF101", # redirected-noqa
|
||||
"RUF200", # invalid-pyproject-toml
|
||||
"RUF022", # unsorted-dunder-all
|
||||
"S506", # unsafe-yaml-load
|
||||
"SIM", # flake8-simplify rules
|
||||
"TRY400", # error-instead-of-exception
|
||||
"TRY401", # verbose-log-message
|
||||
"UP", # pyupgrade rules
|
||||
"W191", # tab-indentation
|
||||
"W605", # invalid-escape-sequence
|
||||
# security related linting rules
|
||||
# RCE proctection (sort of)
|
||||
"S102", # exec-builtin, disallow use of `exec`
|
||||
|
|
@ -47,36 +45,37 @@ select = [
|
|||
]
|
||||
|
||||
ignore = [
|
||||
"E402", # module-import-not-at-top-of-file
|
||||
"E711", # none-comparison
|
||||
"E712", # true-false-comparison
|
||||
"E721", # type-comparison
|
||||
"E722", # bare-except
|
||||
"F821", # undefined-name
|
||||
"F841", # unused-variable
|
||||
"E402", # module-import-not-at-top-of-file
|
||||
"E711", # none-comparison
|
||||
"E712", # true-false-comparison
|
||||
"E721", # type-comparison
|
||||
"E722", # bare-except
|
||||
"F821", # undefined-name
|
||||
"F841", # unused-variable
|
||||
"FURB113", # repeated-append
|
||||
"FURB152", # math-constant
|
||||
"UP007", # non-pep604-annotation
|
||||
"UP032", # f-string
|
||||
"UP045", # non-pep604-annotation-optional
|
||||
"B005", # strip-with-multi-characters
|
||||
"B006", # mutable-argument-default
|
||||
"B007", # unused-loop-control-variable
|
||||
"B026", # star-arg-unpacking-after-keyword-arg
|
||||
"B903", # class-as-data-structure
|
||||
"B904", # raise-without-from-inside-except
|
||||
"B905", # zip-without-explicit-strict
|
||||
"N806", # non-lowercase-variable-in-function
|
||||
"N815", # mixed-case-variable-in-class-scope
|
||||
"PT011", # pytest-raises-too-broad
|
||||
"SIM102", # collapsible-if
|
||||
"SIM103", # needless-bool
|
||||
"SIM105", # suppressible-exception
|
||||
"SIM107", # return-in-try-except-finally
|
||||
"SIM108", # if-else-block-instead-of-if-exp
|
||||
"SIM113", # enumerate-for-loop
|
||||
"SIM117", # multiple-with-statements
|
||||
"SIM210", # if-expr-with-true-false
|
||||
"UP007", # non-pep604-annotation
|
||||
"UP032", # f-string
|
||||
"UP045", # non-pep604-annotation-optional
|
||||
"B005", # strip-with-multi-characters
|
||||
"B006", # mutable-argument-default
|
||||
"B007", # unused-loop-control-variable
|
||||
"B026", # star-arg-unpacking-after-keyword-arg
|
||||
"B903", # class-as-data-structure
|
||||
"B904", # raise-without-from-inside-except
|
||||
"B905", # zip-without-explicit-strict
|
||||
"N806", # non-lowercase-variable-in-function
|
||||
"N815", # mixed-case-variable-in-class-scope
|
||||
"PT011", # pytest-raises-too-broad
|
||||
"SIM102", # collapsible-if
|
||||
"SIM103", # needless-bool
|
||||
"SIM105", # suppressible-exception
|
||||
"SIM107", # return-in-try-except-finally
|
||||
"SIM108", # if-else-block-instead-of-if-exp
|
||||
"SIM113", # enumerate-for-loop
|
||||
"SIM117", # multiple-with-statements
|
||||
"SIM210", # if-expr-with-true-false
|
||||
"UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/
|
||||
]
|
||||
|
||||
[lint.per-file-ignores]
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
|
|||
|
||||
CURRENT_VERSION: str = Field(
|
||||
description="Dify version",
|
||||
default="1.4.3",
|
||||
default="1.5.0",
|
||||
)
|
||||
|
||||
COMMIT_SHA: str = Field(
|
||||
|
|
|
|||
|
|
@ -63,6 +63,7 @@ from .app import (
|
|||
statistic,
|
||||
workflow,
|
||||
workflow_app_log,
|
||||
workflow_draft_variable,
|
||||
workflow_run,
|
||||
workflow_statistic,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import json
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
from flask import abort, request
|
||||
|
|
@ -18,10 +19,12 @@ from controllers.console.app.error import (
|
|||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.models import File
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from factories import file_factory, variable_factory
|
||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
from fields.workflow_run_fields import workflow_run_node_execution_fields
|
||||
from libs import helper
|
||||
|
|
@ -30,6 +33,7 @@ from libs.login import current_user, login_required
|
|||
from models import App
|
||||
from models.account import Account
|
||||
from models.model import AppMode
|
||||
from models.workflow import Workflow
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
|
@ -38,6 +42,24 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
|
||||
# at the controller level rather than in the workflow logic. This would improve separation
|
||||
# of concerns and make the code more maintainable.
|
||||
def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence[File]:
|
||||
files = files or []
|
||||
|
||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||
file_objs: Sequence[File] = []
|
||||
if file_extra_config is None:
|
||||
return file_objs
|
||||
file_objs = file_factory.build_from_mappings(
|
||||
mappings=files,
|
||||
tenant_id=workflow.tenant_id,
|
||||
config=file_extra_config,
|
||||
)
|
||||
return file_objs
|
||||
|
||||
|
||||
class DraftWorkflowApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -402,15 +424,30 @@ class DraftWorkflowNodeRunApi(Resource):
|
|||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("query", type=str, required=False, location="json", default="")
|
||||
parser.add_argument("files", type=list, location="json", default=[])
|
||||
args = parser.parse_args()
|
||||
|
||||
inputs = args.get("inputs")
|
||||
if inputs == None:
|
||||
user_inputs = args.get("inputs")
|
||||
if user_inputs is None:
|
||||
raise ValueError("missing inputs")
|
||||
|
||||
workflow_srv = WorkflowService()
|
||||
# fetch draft workflow by app_model
|
||||
draft_workflow = workflow_srv.get_draft_workflow(app_model=app_model)
|
||||
if not draft_workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
files = _parse_file(draft_workflow, args.get("files"))
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
workflow_node_execution = workflow_service.run_draft_workflow_node(
|
||||
app_model=app_model, node_id=node_id, user_inputs=inputs, account=current_user
|
||||
app_model=app_model,
|
||||
draft_workflow=draft_workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
account=current_user,
|
||||
query=args.get("query", ""),
|
||||
files=files,
|
||||
)
|
||||
|
||||
return workflow_node_execution
|
||||
|
|
@ -731,6 +768,27 @@ class WorkflowByIdApi(Resource):
|
|||
return None, 204
|
||||
|
||||
|
||||
class DraftWorkflowNodeLastRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_node_execution_fields)
|
||||
def get(self, app_model: App, node_id: str):
|
||||
srv = WorkflowService()
|
||||
workflow = srv.get_draft_workflow(app_model)
|
||||
if not workflow:
|
||||
raise NotFound("Workflow not found")
|
||||
node_exec = srv.get_node_last_run(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
)
|
||||
if node_exec is None:
|
||||
raise NotFound("last run not found")
|
||||
return node_exec
|
||||
|
||||
|
||||
api.add_resource(
|
||||
DraftWorkflowApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft",
|
||||
|
|
@ -795,3 +853,7 @@ api.add_resource(
|
|||
WorkflowByIdApi,
|
||||
"/apps/<uuid:app_id>/workflows/<string:workflow_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
DraftWorkflowNodeLastRunApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/last-run",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,421 @@
|
|||
import logging
|
||||
from typing import Any, NoReturn
|
||||
|
||||
from flask import Response
|
||||
from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (
|
||||
DraftWorkflowNotExist,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.variables.segment_group import SegmentGroup
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import current_user, login_required
|
||||
from models import App, AppMode, db
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
|
||||
if isinstance(value, FileSegment):
|
||||
return value.value.model_dump()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
return [i.model_dump() for i in value.value]
|
||||
elif isinstance(value, SegmentGroup):
|
||||
return [_convert_values_to_json_serializable_object(i) for i in value.value]
|
||||
else:
|
||||
return value.value
|
||||
|
||||
|
||||
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
|
||||
value = variable.get_value()
|
||||
# create a copy of the value to avoid affecting the model cache.
|
||||
value = value.model_copy(deep=True)
|
||||
# Refresh the url signature before returning it to client.
|
||||
if isinstance(value, FileSegment):
|
||||
file = value.value
|
||||
file.remote_url = file.generate_url()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
files = value.value
|
||||
for file in files:
|
||||
file.remote_url = file.generate_url()
|
||||
return _convert_values_to_json_serializable_object(value)
|
||||
|
||||
|
||||
def _create_pagination_parser():
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"page",
|
||||
type=inputs.int_range(1, 100_000),
|
||||
required=False,
|
||||
default=1,
|
||||
location="args",
|
||||
help="the page of data requested",
|
||||
)
|
||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
return parser
|
||||
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
|
||||
"id": fields.String,
|
||||
"type": fields.String(attribute=lambda model: model.get_variable_type()),
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
|
||||
"value_type": fields.String,
|
||||
"edited": fields.Boolean(attribute=lambda model: model.edited),
|
||||
"visible": fields.Boolean,
|
||||
}
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
|
||||
value=fields.Raw(attribute=_serialize_var_value),
|
||||
)
|
||||
|
||||
_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
|
||||
"id": fields.String,
|
||||
"type": fields.String(attribute=lambda _: "env"),
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
|
||||
"value_type": fields.String,
|
||||
"edited": fields.Boolean(attribute=lambda model: model.edited),
|
||||
"visible": fields.Boolean,
|
||||
}
|
||||
|
||||
_WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS = {
|
||||
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)),
|
||||
}
|
||||
|
||||
|
||||
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
|
||||
return var_list.variables
|
||||
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
|
||||
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
|
||||
"total": fields.Raw(),
|
||||
}
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
|
||||
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
|
||||
}
|
||||
|
||||
|
||||
def _api_prerequisite(f):
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
It ensures the following conditions are satisfied:
|
||||
|
||||
- Dify has been property setup.
|
||||
- The request user has logged in and initialized.
|
||||
- The requested app is a workflow or a chat flow.
|
||||
- The request user has the edit permission for the app.
|
||||
"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def wrapper(*args, **kwargs):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class WorkflowVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
parser = _create_pagination_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
workflow_service = WorkflowService()
|
||||
workflow_exist = workflow_service.is_workflow_exist(app_model=app_model)
|
||||
if not workflow_exist:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
workflow_vars = draft_var_srv.list_variables_without_values(
|
||||
app_id=app_model.id,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
)
|
||||
|
||||
return workflow_vars
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
draft_var_srv.delete_workflow_variables(app_model.id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
def validate_node_id(node_id: str) -> NoReturn | None:
|
||||
if node_id in [
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
]:
|
||||
# NOTE(QuantumGhost): While we store the system and conversation variables as node variables
|
||||
# with specific `node_id` in database, we still want to make the API separated. By disallowing
|
||||
# accessing system and conversation variables in `WorkflowDraftNodeVariableListApi`,
|
||||
# we mitigate the risk that user of the API depending on the implementation detail of the API.
|
||||
#
|
||||
# ref: [Hyrum's Law](https://www.hyrumslaw.com/)
|
||||
|
||||
raise InvalidArgumentError(
|
||||
f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
class NodeVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
def get(self, app_model: App, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
node_vars = draft_var_srv.list_node_variables(app_model.id, node_id)
|
||||
|
||||
return node_vars
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
srv = WorkflowDraftVariableService(db.session())
|
||||
srv.delete_node_variables(app_model.id, node_id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
class VariableApi(Resource):
|
||||
_PATCH_NAME_FIELD = "name"
|
||||
_PATCH_VALUE_FIELD = "value"
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
def get(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
return variable
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
def patch(self, app_model: App, variable_id: str):
|
||||
# Request payload for file types:
|
||||
#
|
||||
# Local File:
|
||||
#
|
||||
# {
|
||||
# "type": "image",
|
||||
# "transfer_method": "local_file",
|
||||
# "url": "",
|
||||
# "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190"
|
||||
# }
|
||||
#
|
||||
# Remote File:
|
||||
#
|
||||
#
|
||||
# {
|
||||
# "type": "image",
|
||||
# "transfer_method": "remote_url",
|
||||
# "url": "http://127.0.0.1:5001/files/1602650a-4fe4-423c-85a2-af76c083e3c4/file-preview?timestamp=1750041099&nonce=...&sign=...=",
|
||||
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||
# }
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||
# Parse 'value' field as-is to maintain its original data structure
|
||||
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
args = parser.parse_args(strict=True)
|
||||
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
new_name = args.get(self._PATCH_NAME_FIELD, None)
|
||||
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
|
||||
if new_name is None and raw_value is None:
|
||||
return variable
|
||||
|
||||
new_value = None
|
||||
if raw_value is not None:
|
||||
if variable.value_type == SegmentType.FILE:
|
||||
if not isinstance(raw_value, dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
|
||||
raw_value = build_from_mapping(mapping=raw_value, tenant_id=app_model.tenant_id)
|
||||
elif variable.value_type == SegmentType.ARRAY_FILE:
|
||||
if not isinstance(raw_value, list):
|
||||
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
|
||||
new_value = build_segment_with_type(variable.value_type, raw_value)
|
||||
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||
db.session.commit()
|
||||
return variable
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
draft_var_srv.delete_variable(variable)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
class VariableResetApi(Resource):
|
||||
@_api_prerequisite
|
||||
def put(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
|
||||
workflow_srv = WorkflowService()
|
||||
draft_workflow = workflow_srv.get_draft_workflow(app_model)
|
||||
if draft_workflow is None:
|
||||
raise NotFoundError(
|
||||
f"Draft workflow not found, app_id={app_model.id}",
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
|
||||
db.session.commit()
|
||||
if resetted is None:
|
||||
return Response("", 204)
|
||||
else:
|
||||
return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
|
||||
|
||||
def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
if node_id == CONVERSATION_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_conversation_variables(app_model.id)
|
||||
elif node_id == SYSTEM_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_system_variables(app_model.id)
|
||||
else:
|
||||
draft_vars = draft_var_srv.list_node_variables(app_id=app_model.id, node_id=node_id)
|
||||
return draft_vars
|
||||
|
||||
|
||||
class ConversationVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
def get(self, app_model: App):
|
||||
# NOTE(QuantumGhost): Prefill conversation variables into the draft variables table
|
||||
# so their IDs can be returned to the caller.
|
||||
workflow_srv = WorkflowService()
|
||||
draft_workflow = workflow_srv.get_draft_workflow(app_model)
|
||||
if draft_workflow is None:
|
||||
raise NotFoundError(description=f"draft workflow not found, id={app_model.id}")
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow)
|
||||
db.session.commit()
|
||||
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
|
||||
|
||||
|
||||
class SystemVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
def get(self, app_model: App):
|
||||
return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID)
|
||||
|
||||
|
||||
class EnvironmentVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
# fetch draft workflow by app_model
|
||||
workflow_service = WorkflowService()
|
||||
workflow = workflow_service.get_draft_workflow(app_model=app_model)
|
||||
if workflow is None:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
env_vars = workflow.environment_variables
|
||||
env_vars_list = []
|
||||
for v in env_vars:
|
||||
env_vars_list.append(
|
||||
{
|
||||
"id": v.id,
|
||||
"type": "env",
|
||||
"name": v.name,
|
||||
"description": v.description,
|
||||
"selector": v.selector,
|
||||
"value_type": v.value_type.value,
|
||||
"value": v.value,
|
||||
# Do not track edited for env vars.
|
||||
"edited": False,
|
||||
"visible": True,
|
||||
"editable": True,
|
||||
}
|
||||
)
|
||||
|
||||
return {"items": env_vars_list}
|
||||
|
||||
|
||||
api.add_resource(
|
||||
WorkflowVariableCollectionApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/variables",
|
||||
)
|
||||
api.add_resource(NodeVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||
api.add_resource(VariableApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>")
|
||||
api.add_resource(VariableResetApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset")
|
||||
|
||||
api.add_resource(ConversationVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/conversation-variables")
|
||||
api.add_resource(SystemVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/system-variables")
|
||||
api.add_resource(EnvironmentVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/environment-variables")
|
||||
|
|
@ -8,6 +8,15 @@ from libs.login import current_user
|
|||
from models import App, AppMode
|
||||
|
||||
|
||||
def _load_app_model(app_id: str) -> Optional[App]:
|
||||
app_model = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
return app_model
|
||||
|
||||
|
||||
def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
||||
def decorator(view_func):
|
||||
@wraps(view_func)
|
||||
|
|
@ -20,11 +29,7 @@ def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[
|
|||
|
||||
del kwargs["app_id"]
|
||||
|
||||
app_model = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
app_model = _load_app_model(app_id)
|
||||
|
||||
if not app_model:
|
||||
raise AppNotFoundError()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from typing import cast
|
|||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
|
||||
from flask_restful import Resource, marshal, marshal_with, reqparse
|
||||
from sqlalchemy import asc, desc, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
|
|
@ -239,12 +239,10 @@ class DatasetDocumentListApi(Resource):
|
|||
|
||||
return response
|
||||
|
||||
documents_and_batch_fields = {"documents": fields.List(fields.Nested(document_fields)), "batch": fields.String}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(documents_and_batch_fields)
|
||||
@marshal_with(dataset_and_document_fields)
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id):
|
||||
|
|
@ -290,6 +288,8 @@ class DatasetDocumentListApi(Resource):
|
|||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
|
|
@ -297,7 +297,7 @@ class DatasetDocumentListApi(Resource):
|
|||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
|
||||
return {"documents": documents, "batch": batch}
|
||||
return {"dataset": dataset, "documents": documents, "batch": batch}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
|
|||
|
|
@ -85,6 +85,7 @@ class MemberInviteEmailApi(Resource):
|
|||
return {
|
||||
"result": "success",
|
||||
"invitation_results": invitation_results,
|
||||
"tenant_id": str(current_user.current_tenant.id),
|
||||
}, 201
|
||||
|
||||
|
||||
|
|
@ -110,7 +111,7 @@ class MemberCancelInviteApi(Resource):
|
|||
except Exception as e:
|
||||
raise ValueError(str(e))
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return {"result": "success", "tenant_id": str(current_user.current_tenant.id)}, 200
|
||||
|
||||
|
||||
class MemberUpdateRoleApi(Resource):
|
||||
|
|
|
|||
|
|
@ -139,3 +139,13 @@ class InvokeRateLimitError(BaseHTTPException):
|
|||
error_code = "rate_limit_error"
|
||||
description = "Rate Limit Error"
|
||||
code = 429
|
||||
|
||||
|
||||
class NotFoundError(BaseHTTPException):
|
||||
error_code = "not_found"
|
||||
code = 404
|
||||
|
||||
|
||||
class InvalidArgumentError(BaseHTTPException):
|
||||
error_code = "invalid_param"
|
||||
code = 400
|
||||
|
|
|
|||
|
|
@ -104,6 +104,7 @@ class VariableEntity(BaseModel):
|
|||
Variable Entity.
|
||||
"""
|
||||
|
||||
# `variable` records the name of the variable in user inputs.
|
||||
variable: str
|
||||
label: str
|
||||
description: str = ""
|
||||
|
|
|
|||
|
|
@ -29,13 +29,14 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
|||
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -116,6 +117,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
)
|
||||
|
||||
# parse files
|
||||
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
|
||||
# for better separation of concerns.
|
||||
#
|
||||
# For implementation reference, see the `_parse_file` function and
|
||||
# `DraftWorkflowNodeRunApi` class which handle this properly.
|
||||
files = args["files"] if args.get("files") else []
|
||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||
if file_extra_config:
|
||||
|
|
@ -261,6 +267,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
var_loader = DraftVarLoader(
|
||||
engine=db.engine,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
)
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow)
|
||||
|
||||
return self._generate(
|
||||
workflow=workflow,
|
||||
|
|
@ -271,6 +284,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
conversation=None,
|
||||
stream=streaming,
|
||||
variable_loader=var_loader,
|
||||
)
|
||||
|
||||
def single_loop_generate(
|
||||
|
|
@ -336,6 +350,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
var_loader = DraftVarLoader(
|
||||
engine=db.engine,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
)
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow)
|
||||
|
||||
return self._generate(
|
||||
workflow=workflow,
|
||||
|
|
@ -346,6 +367,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
conversation=None,
|
||||
stream=streaming,
|
||||
variable_loader=var_loader,
|
||||
)
|
||||
|
||||
def _generate(
|
||||
|
|
@ -359,6 +381,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
conversation: Optional[Conversation] = None,
|
||||
stream: bool = True,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
|
@ -410,6 +433,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
"conversation_id": conversation.id,
|
||||
"message_id": message.id,
|
||||
"context": context,
|
||||
"variable_loader": variable_loader,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -438,6 +462,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
conversation_id: str,
|
||||
message_id: str,
|
||||
context: contextvars.Context,
|
||||
variable_loader: VariableLoader,
|
||||
) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
|
|
@ -454,8 +479,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
# get conversation and message
|
||||
conversation = self._get_conversation(conversation_id)
|
||||
message = self._get_message(message_id)
|
||||
if message is None:
|
||||
raise MessageNotExistsError("Message not exists")
|
||||
|
||||
# chatbot app
|
||||
runner = AdvancedChatAppRunner(
|
||||
|
|
@ -464,6 +487,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
conversation=conversation,
|
||||
message=message,
|
||||
dialogue_count=self._dialogue_count,
|
||||
variable_loader=variable_loader,
|
||||
)
|
||||
|
||||
runner.run()
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from core.moderation.base import ModerationError
|
|||
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
|
|
@ -40,14 +41,17 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
conversation: Conversation,
|
||||
message: Message,
|
||||
dialogue_count: int,
|
||||
variable_loader: VariableLoader,
|
||||
) -> None:
|
||||
super().__init__(queue_manager)
|
||||
|
||||
super().__init__(queue_manager, variable_loader)
|
||||
self.application_generate_entity = application_generate_entity
|
||||
self.conversation = conversation
|
||||
self.message = message
|
||||
self._dialogue_count = dialogue_count
|
||||
|
||||
def _get_app_id(self) -> str:
|
||||
return self.application_generate_entity.app_config.app_id
|
||||
|
||||
def run(self) -> None:
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(AdvancedChatAppConfig, app_config)
|
||||
|
|
|
|||
|
|
@ -26,7 +26,6 @@ from factories import file_factory
|
|||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, App, EndUser
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -124,6 +123,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||
override_model_config_dict["retriever_resource"] = {"enabled": True}
|
||||
|
||||
# parse files
|
||||
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
|
||||
# for better separation of concerns.
|
||||
#
|
||||
# For implementation reference, see the `_parse_file` function and
|
||||
# `DraftWorkflowNodeRunApi` class which handle this properly.
|
||||
files = args.get("files") or []
|
||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||
if file_extra_config:
|
||||
|
|
@ -233,8 +237,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||
# get conversation and message
|
||||
conversation = self._get_conversation(conversation_id)
|
||||
message = self._get_message(message_id)
|
||||
if message is None:
|
||||
raise MessageNotExistsError("Message not exists")
|
||||
|
||||
# chatbot app
|
||||
runner = AgentChatAppRunner()
|
||||
|
|
|
|||
|
|
@ -25,7 +25,6 @@ from factories import file_factory
|
|||
from models.account import Account
|
||||
from models.model import App, EndUser
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -115,6 +114,11 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||
override_model_config_dict["retriever_resource"] = {"enabled": True}
|
||||
|
||||
# parse files
|
||||
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
|
||||
# for better separation of concerns.
|
||||
#
|
||||
# For implementation reference, see the `_parse_file` function and
|
||||
# `DraftWorkflowNodeRunApi` class which handle this properly.
|
||||
files = args["files"] if args.get("files") else []
|
||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||
if file_extra_config:
|
||||
|
|
@ -219,8 +223,6 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||
# get conversation and message
|
||||
conversation = self._get_conversation(conversation_id)
|
||||
message = self._get_message(message_id)
|
||||
if message is None:
|
||||
raise MessageNotExistsError("Message not exists")
|
||||
|
||||
# chatbot app
|
||||
runner = ChatAppRunner()
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ from core.workflow.entities.workflow_execution import WorkflowExecution
|
|||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from models import (
|
||||
Account,
|
||||
CreatorUserRole,
|
||||
|
|
@ -125,7 +126,7 @@ class WorkflowResponseConverter:
|
|||
id=workflow_execution.id_,
|
||||
workflow_id=workflow_execution.workflow_id,
|
||||
status=workflow_execution.status,
|
||||
outputs=workflow_execution.outputs,
|
||||
outputs=WorkflowRuntimeTypeConverter().to_json_encodable(workflow_execution.outputs),
|
||||
error=workflow_execution.error_message,
|
||||
elapsed_time=workflow_execution.elapsed_time,
|
||||
total_tokens=workflow_execution.total_tokens,
|
||||
|
|
@ -202,6 +203,8 @@ class WorkflowResponseConverter:
|
|||
if not workflow_node_execution.finished_at:
|
||||
return None
|
||||
|
||||
json_converter = WorkflowRuntimeTypeConverter()
|
||||
|
||||
return NodeFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_execution_id,
|
||||
|
|
@ -214,7 +217,7 @@ class WorkflowResponseConverter:
|
|||
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
||||
inputs=workflow_node_execution.inputs,
|
||||
process_data=workflow_node_execution.process_data,
|
||||
outputs=workflow_node_execution.outputs,
|
||||
outputs=json_converter.to_json_encodable(workflow_node_execution.outputs),
|
||||
status=workflow_node_execution.status,
|
||||
error=workflow_node_execution.error,
|
||||
elapsed_time=workflow_node_execution.elapsed_time,
|
||||
|
|
@ -245,6 +248,8 @@ class WorkflowResponseConverter:
|
|||
if not workflow_node_execution.finished_at:
|
||||
return None
|
||||
|
||||
json_converter = WorkflowRuntimeTypeConverter()
|
||||
|
||||
return NodeRetryStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_execution_id,
|
||||
|
|
@ -257,7 +262,7 @@ class WorkflowResponseConverter:
|
|||
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
||||
inputs=workflow_node_execution.inputs,
|
||||
process_data=workflow_node_execution.process_data,
|
||||
outputs=workflow_node_execution.outputs,
|
||||
outputs=json_converter.to_json_encodable(workflow_node_execution.outputs),
|
||||
status=workflow_node_execution.status,
|
||||
error=workflow_node_execution.error,
|
||||
elapsed_time=workflow_node_execution.elapsed_time,
|
||||
|
|
@ -376,6 +381,7 @@ class WorkflowResponseConverter:
|
|||
workflow_execution_id: str,
|
||||
event: QueueIterationCompletedEvent,
|
||||
) -> IterationNodeCompletedStreamResponse:
|
||||
json_converter = WorkflowRuntimeTypeConverter()
|
||||
return IterationNodeCompletedStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_execution_id,
|
||||
|
|
@ -384,7 +390,7 @@ class WorkflowResponseConverter:
|
|||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
outputs=event.outputs,
|
||||
outputs=json_converter.to_json_encodable(event.outputs),
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs or {},
|
||||
|
|
@ -463,7 +469,7 @@ class WorkflowResponseConverter:
|
|||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
outputs=event.outputs,
|
||||
outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs),
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs or {},
|
||||
|
|
|
|||
|
|
@ -101,6 +101,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
)
|
||||
|
||||
# parse files
|
||||
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
|
||||
# for better separation of concerns.
|
||||
#
|
||||
# For implementation reference, see the `_parse_file` function and
|
||||
# `DraftWorkflowNodeRunApi` class which handle this properly.
|
||||
files = args["files"] if args.get("files") else []
|
||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||
if file_extra_config:
|
||||
|
|
@ -196,8 +201,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
try:
|
||||
# get message
|
||||
message = self._get_message(message_id)
|
||||
if message is None:
|
||||
raise MessageNotExistsError()
|
||||
|
||||
# chatbot app
|
||||
runner = CompletionAppRunner()
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from models.enums import CreatorUserRole
|
|||
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
|
||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -251,7 +252,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||
|
||||
return introduction or ""
|
||||
|
||||
def _get_conversation(self, conversation_id: str):
|
||||
def _get_conversation(self, conversation_id: str) -> Conversation:
|
||||
"""
|
||||
Get conversation by conversation id
|
||||
:param conversation_id: conversation id
|
||||
|
|
@ -260,11 +261,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||
conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
|
||||
|
||||
if not conversation:
|
||||
raise ConversationNotExistsError()
|
||||
raise ConversationNotExistsError("Conversation not exists")
|
||||
|
||||
return conversation
|
||||
|
||||
def _get_message(self, message_id: str) -> Optional[Message]:
|
||||
def _get_message(self, message_id: str) -> Message:
|
||||
"""
|
||||
Get message by message id
|
||||
:param message_id: message id
|
||||
|
|
@ -272,4 +273,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||
"""
|
||||
message = db.session.query(Message).filter(Message.id == message_id).first()
|
||||
|
||||
if message is None:
|
||||
raise MessageNotExistsError("Message not exists")
|
||||
|
||||
return message
|
||||
|
|
|
|||
|
|
@ -27,11 +27,13 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
|||
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -94,6 +96,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||
|
||||
# parse files
|
||||
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
|
||||
# for better separation of concerns.
|
||||
#
|
||||
# For implementation reference, see the `_parse_file` function and
|
||||
# `DraftWorkflowNodeRunApi` class which handle this properly.
|
||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||
system_files = file_factory.build_from_mappings(
|
||||
mappings=files,
|
||||
|
|
@ -186,6 +193,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
streaming: bool = True,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
|
@ -219,6 +227,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
"queue_manager": queue_manager,
|
||||
"context": context,
|
||||
"workflow_thread_pool_id": workflow_thread_pool_id,
|
||||
"variable_loader": variable_loader,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -303,6 +312,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow)
|
||||
var_loader = DraftVarLoader(
|
||||
engine=db.engine,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
app_model=app_model,
|
||||
|
|
@ -313,6 +329,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
variable_loader=var_loader,
|
||||
)
|
||||
|
||||
def single_loop_generate(
|
||||
|
|
@ -379,7 +396,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow)
|
||||
var_loader = DraftVarLoader(
|
||||
engine=db.engine,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
)
|
||||
return self._generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
|
|
@ -389,6 +412,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
variable_loader=var_loader,
|
||||
)
|
||||
|
||||
def _generate_worker(
|
||||
|
|
@ -397,6 +421,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
context: contextvars.Context,
|
||||
variable_loader: VariableLoader,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
|
|
@ -415,6 +440,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
variable_loader=variable_loader,
|
||||
)
|
||||
|
||||
runner.run()
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from core.app.entities.app_invoke_entities import (
|
|||
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
|
|
@ -30,6 +31,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||
self,
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
variable_loader: VariableLoader,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
|
|
@ -37,10 +39,13 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||
:param queue_manager: application queue manager
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
"""
|
||||
super().__init__(queue_manager, variable_loader)
|
||||
self.application_generate_entity = application_generate_entity
|
||||
self.queue_manager = queue_manager
|
||||
self.workflow_thread_pool_id = workflow_thread_pool_id
|
||||
|
||||
def _get_app_id(self) -> str:
|
||||
return self.application_generate_entity.app_config.app_id
|
||||
|
||||
def run(self) -> None:
|
||||
"""
|
||||
Run application
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.entities.queue_entities import (
|
||||
|
|
@ -33,6 +35,7 @@ from core.workflow.entities.variable_pool import VariablePool
|
|||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
AgentLogEvent,
|
||||
BaseNodeEvent,
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
|
|
@ -62,15 +65,23 @@ from core.workflow.graph_engine.entities.event import (
|
|||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from models.model import App
|
||||
from models.workflow import Workflow
|
||||
from services.workflow_draft_variable_service import (
|
||||
DraftVariableSaver,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowBasedAppRunner(AppRunner):
|
||||
def __init__(self, queue_manager: AppQueueManager):
|
||||
def __init__(self, queue_manager: AppQueueManager, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER) -> None:
|
||||
self.queue_manager = queue_manager
|
||||
self._variable_loader = variable_loader
|
||||
|
||||
def _get_app_id(self) -> str:
|
||||
raise NotImplementedError("not implemented")
|
||||
|
||||
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
|
||||
"""
|
||||
|
|
@ -173,6 +184,13 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
|
||||
load_into_variable_pool(
|
||||
variable_loader=self._variable_loader,
|
||||
variable_pool=variable_pool,
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
)
|
||||
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
|
|
@ -262,6 +280,12 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
)
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
load_into_variable_pool(
|
||||
self._variable_loader,
|
||||
variable_pool=variable_pool,
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
)
|
||||
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
|
|
@ -376,6 +400,8 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
self._save_draft_var_for_event(event)
|
||||
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeFailedEvent(
|
||||
|
|
@ -438,6 +464,8 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
self._save_draft_var_for_event(event)
|
||||
|
||||
elif isinstance(event, NodeInIterationFailedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeInIterationFailedEvent(
|
||||
|
|
@ -690,3 +718,30 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
|
||||
def _publish_event(self, event: AppQueueEvent) -> None:
|
||||
self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
def _save_draft_var_for_event(self, event: BaseNodeEvent):
|
||||
run_result = event.route_node_state.node_run_result
|
||||
if run_result is None:
|
||||
return
|
||||
process_data = run_result.process_data
|
||||
outputs = run_result.outputs
|
||||
with Session(bind=db.engine) as session, session.begin():
|
||||
draft_var_saver = DraftVariableSaver(
|
||||
session=session,
|
||||
app_id=self._get_app_id(),
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
# FIXME(QuantumGhost): rely on private state of queue_manager is not ideal.
|
||||
invoke_from=self.queue_manager._invoke_from,
|
||||
node_execution_id=event.id,
|
||||
enclosing_node_id=event.in_loop_id or event.in_iteration_id or None,
|
||||
)
|
||||
draft_var_saver.save(process_data=process_data, outputs=outputs)
|
||||
|
||||
|
||||
def _remove_first_element_from_variable_string(key: str) -> str:
|
||||
"""
|
||||
Remove the first element from the prefix.
|
||||
"""
|
||||
prefix, remaining = key.split(".", maxsplit=1)
|
||||
return remaining
|
||||
|
|
|
|||
|
|
@ -17,9 +17,24 @@ class InvokeFrom(Enum):
|
|||
Invoke From.
|
||||
"""
|
||||
|
||||
# SERVICE_API indicates that this invocation is from an API call to Dify app.
|
||||
#
|
||||
# Description of service api in Dify docs:
|
||||
# https://docs.dify.ai/en/guides/application-publishing/developing-with-apis
|
||||
SERVICE_API = "service-api"
|
||||
|
||||
# WEB_APP indicates that this invocation is from
|
||||
# the web app of the workflow (or chatflow).
|
||||
#
|
||||
# Description of web app in Dify docs:
|
||||
# https://docs.dify.ai/en/guides/application-publishing/launch-your-webapp-quickly/README
|
||||
WEB_APP = "web-app"
|
||||
|
||||
# EXPLORE indicates that this invocation is from
|
||||
# the workflow (or chatflow) explore page.
|
||||
EXPLORE = "explore"
|
||||
# DEBUGGER indicates that this invocation is from
|
||||
# the workflow (or chatflow) edit page.
|
||||
DEBUGGER = "debugger"
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -1 +1,11 @@
|
|||
from typing import Any
|
||||
|
||||
# TODO(QuantumGhost): Refactor variable type identification. Instead of directly
|
||||
# comparing `dify_model_identity` with constants throughout the codebase, extract
|
||||
# this logic into a dedicated function. This would encapsulate the implementation
|
||||
# details of how different variable types are identified.
|
||||
FILE_MODEL_IDENTITY = "__dify__file__"
|
||||
|
||||
|
||||
def maybe_file_object(o: Any) -> bool:
|
||||
return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY
|
||||
|
|
|
|||
|
|
@ -534,7 +534,7 @@ class IndexingRunner:
|
|||
# chunk nodes by chunk size
|
||||
indexing_start_at = time.perf_counter()
|
||||
tokens = 0
|
||||
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
|
||||
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
|
||||
# create keyword index
|
||||
create_keyword_thread = threading.Thread(
|
||||
target=self._process_keyword_index,
|
||||
|
|
@ -572,7 +572,7 @@ class IndexingRunner:
|
|||
|
||||
for future in futures:
|
||||
tokens += future.result()
|
||||
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
|
||||
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
|
||||
create_keyword_thread.join()
|
||||
indexing_end_at = time.perf_counter()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import binascii
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -16,7 +17,7 @@ class OAuthHandler(BasePluginClient):
|
|||
provider: str,
|
||||
system_credentials: Mapping[str, Any],
|
||||
) -> PluginOAuthAuthorizationUrlResponse:
|
||||
return self._request_with_plugin_daemon_response(
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url",
|
||||
PluginOAuthAuthorizationUrlResponse,
|
||||
|
|
@ -32,6 +33,9 @@ class OAuthHandler(BasePluginClient):
|
|||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
for resp in response:
|
||||
return resp
|
||||
raise ValueError("No response received from plugin daemon for authorization URL request.")
|
||||
|
||||
def get_credentials(
|
||||
self,
|
||||
|
|
@ -49,7 +53,7 @@ class OAuthHandler(BasePluginClient):
|
|||
# encode request to raw http request
|
||||
raw_request_bytes = self._convert_request_to_raw_data(request)
|
||||
|
||||
return self._request_with_plugin_daemon_response(
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/oauth/get_credentials",
|
||||
PluginOAuthCredentialsResponse,
|
||||
|
|
@ -58,7 +62,8 @@ class OAuthHandler(BasePluginClient):
|
|||
"data": {
|
||||
"provider": provider,
|
||||
"system_credentials": system_credentials,
|
||||
"raw_request_bytes": raw_request_bytes,
|
||||
# for json serialization
|
||||
"raw_http_request": binascii.hexlify(raw_request_bytes).decode(),
|
||||
},
|
||||
},
|
||||
headers={
|
||||
|
|
@ -66,6 +71,9 @@ class OAuthHandler(BasePluginClient):
|
|||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
for resp in response:
|
||||
return resp
|
||||
raise ValueError("No response received from plugin daemon for authorization URL request.")
|
||||
|
||||
def _convert_request_to_raw_data(self, request: Request) -> bytes:
|
||||
"""
|
||||
|
|
@ -79,7 +87,7 @@ class OAuthHandler(BasePluginClient):
|
|||
"""
|
||||
# Start with the request line
|
||||
method = request.method
|
||||
path = request.path
|
||||
path = request.full_path
|
||||
protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1")
|
||||
raw_data = f"{method} {path} {protocol}\r\n".encode()
|
||||
|
||||
|
|
|
|||
|
|
@ -76,6 +76,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
|||
if dataset.indexing_technique == "high_quality":
|
||||
vector = Vector(dataset)
|
||||
vector.create(documents)
|
||||
with_keywords = False
|
||||
if with_keywords:
|
||||
keywords_list = kwargs.get("keywords_list")
|
||||
keyword = Keyword(dataset)
|
||||
|
|
@ -91,6 +92,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
|||
vector.delete_by_ids(node_ids)
|
||||
else:
|
||||
vector.delete()
|
||||
with_keywords = False
|
||||
if with_keywords:
|
||||
keyword = Keyword(dataset)
|
||||
if node_ids:
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from core.workflow.entities.workflow_execution import (
|
|||
WorkflowType,
|
||||
)
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from models import (
|
||||
Account,
|
||||
CreatorUserRole,
|
||||
|
|
@ -152,7 +153,11 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
|
|||
db_model.version = domain_model.workflow_version
|
||||
db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None
|
||||
db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None
|
||||
db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None
|
||||
db_model.outputs = (
|
||||
json.dumps(WorkflowRuntimeTypeConverter().to_json_encodable(domain_model.outputs))
|
||||
if domain_model.outputs
|
||||
else None
|
||||
)
|
||||
db_model.status = domain_model.status
|
||||
db_model.error = domain_model.error_message if domain_model.error_message else None
|
||||
db_model.total_tokens = domain_model.total_tokens
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from core.workflow.entities.workflow_node_execution import (
|
|||
)
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from models import (
|
||||
Account,
|
||||
CreatorUserRole,
|
||||
|
|
@ -146,6 +147,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
if not self._creator_user_role:
|
||||
raise ValueError("created_by_role is required in repository constructor")
|
||||
|
||||
json_converter = WorkflowRuntimeTypeConverter()
|
||||
db_model = WorkflowNodeExecutionModel()
|
||||
db_model.id = domain_model.id
|
||||
db_model.tenant_id = self._tenant_id
|
||||
|
|
@ -160,9 +162,17 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
db_model.node_id = domain_model.node_id
|
||||
db_model.node_type = domain_model.node_type
|
||||
db_model.title = domain_model.title
|
||||
db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None
|
||||
db_model.process_data = json.dumps(domain_model.process_data) if domain_model.process_data else None
|
||||
db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None
|
||||
db_model.inputs = (
|
||||
json.dumps(json_converter.to_json_encodable(domain_model.inputs)) if domain_model.inputs else None
|
||||
)
|
||||
db_model.process_data = (
|
||||
json.dumps(json_converter.to_json_encodable(domain_model.process_data))
|
||||
if domain_model.process_data
|
||||
else None
|
||||
)
|
||||
db_model.outputs = (
|
||||
json.dumps(json_converter.to_json_encodable(domain_model.outputs)) if domain_model.outputs else None
|
||||
)
|
||||
db_model.status = domain_model.status
|
||||
db_model.error = domain_model.error
|
||||
db_model.elapsed_time = domain_model.elapsed_time
|
||||
|
|
|
|||
|
|
@ -75,6 +75,20 @@ class StringSegment(Segment):
|
|||
class FloatSegment(Segment):
|
||||
value_type: SegmentType = SegmentType.NUMBER
|
||||
value: float
|
||||
# NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
|
||||
# The following tests cannot pass.
|
||||
#
|
||||
# def test_float_segment_and_nan():
|
||||
# nan = float("nan")
|
||||
# assert nan != nan
|
||||
#
|
||||
# f1 = FloatSegment(value=float("nan"))
|
||||
# f2 = FloatSegment(value=float("nan"))
|
||||
# assert f1 != f2
|
||||
#
|
||||
# f3 = FloatSegment(value=nan)
|
||||
# f4 = FloatSegment(value=nan)
|
||||
# assert f3 != f4
|
||||
|
||||
|
||||
class IntegerSegment(Segment):
|
||||
|
|
|
|||
|
|
@ -18,3 +18,17 @@ class SegmentType(StrEnum):
|
|||
NONE = "none"
|
||||
|
||||
GROUP = "group"
|
||||
|
||||
def is_array_type(self):
|
||||
return self in _ARRAY_TYPES
|
||||
|
||||
|
||||
_ARRAY_TYPES = frozenset(
|
||||
[
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_FILE,
|
||||
]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,26 @@
|
|||
import json
|
||||
from collections.abc import Iterable, Sequence
|
||||
|
||||
from .segment_group import SegmentGroup
|
||||
from .segments import ArrayFileSegment, FileSegment, Segment
|
||||
|
||||
|
||||
def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]:
|
||||
selectors = [node_id, name]
|
||||
if paths:
|
||||
selectors.extend(paths)
|
||||
return selectors
|
||||
|
||||
|
||||
class SegmentJSONEncoder(json.JSONEncoder):
|
||||
def default(self, o):
|
||||
if isinstance(o, ArrayFileSegment):
|
||||
return [v.model_dump() for v in o.value]
|
||||
elif isinstance(o, FileSegment):
|
||||
return o.value.model_dump()
|
||||
elif isinstance(o, SegmentGroup):
|
||||
return [self.default(seg) for seg in o.value]
|
||||
elif isinstance(o, Segment):
|
||||
return o.value
|
||||
else:
|
||||
super().default(o)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,39 @@
|
|||
import abc
|
||||
from typing import Protocol
|
||||
|
||||
from core.variables import Variable
|
||||
|
||||
|
||||
class ConversationVariableUpdater(Protocol):
|
||||
"""
|
||||
ConversationVariableUpdater defines an abstraction for updating conversation variable values.
|
||||
|
||||
It is intended for use by `v1.VariableAssignerNode` and `v2.VariableAssignerNode` when updating
|
||||
conversation variables.
|
||||
|
||||
Implementations may choose to batch updates. If batching is used, the `flush` method
|
||||
should be implemented to persist buffered changes, and `update`
|
||||
should handle buffering accordingly.
|
||||
|
||||
Note: Since implementations may buffer updates, instances of ConversationVariableUpdater
|
||||
are not thread-safe. Each VariableAssignerNode should create its own instance during execution.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def update(self, conversation_id: str, variable: "Variable") -> None:
|
||||
"""
|
||||
Updates the value of the specified conversation variable in the underlying storage.
|
||||
|
||||
:param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`.
|
||||
:param variable: The `Variable` instance containing the updated value.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def flush(self):
|
||||
"""
|
||||
Flushes all pending updates to the underlying storage system.
|
||||
|
||||
If the implementation does not buffer updates, this method can be a no-op.
|
||||
"""
|
||||
pass
|
||||
|
|
@ -7,12 +7,12 @@ from pydantic import BaseModel, Field
|
|||
|
||||
from core.file import File, FileAttribute, file_manager
|
||||
from core.variables import Segment, SegmentGroup, Variable
|
||||
from core.variables.consts import MIN_SELECTORS_LENGTH
|
||||
from core.variables.segments import FileSegment, NoneSegment
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from factories import variable_factory
|
||||
|
||||
from ..constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from ..enums import SystemVariableKey
|
||||
|
||||
VariableValue = Union[str, int, float, dict, list, File]
|
||||
|
||||
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
|
||||
|
|
@ -30,9 +30,11 @@ class VariablePool(BaseModel):
|
|||
# TODO: This user inputs is not used for pool.
|
||||
user_inputs: Mapping[str, Any] = Field(
|
||||
description="User inputs",
|
||||
default_factory=dict,
|
||||
)
|
||||
system_variables: Mapping[SystemVariableKey, Any] = Field(
|
||||
description="System variables",
|
||||
default_factory=dict,
|
||||
)
|
||||
environment_variables: Sequence[Variable] = Field(
|
||||
description="Environment variables.",
|
||||
|
|
@ -43,28 +45,7 @@ class VariablePool(BaseModel):
|
|||
default_factory=list,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
system_variables: Mapping[SystemVariableKey, Any] | None = None,
|
||||
user_inputs: Mapping[str, Any] | None = None,
|
||||
environment_variables: Sequence[Variable] | None = None,
|
||||
conversation_variables: Sequence[Variable] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
environment_variables = environment_variables or []
|
||||
conversation_variables = conversation_variables or []
|
||||
user_inputs = user_inputs or {}
|
||||
system_variables = system_variables or {}
|
||||
|
||||
super().__init__(
|
||||
system_variables=system_variables,
|
||||
user_inputs=user_inputs,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def model_post_init(self, context: Any, /) -> None:
|
||||
for key, value in self.system_variables.items():
|
||||
self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
|
||||
# Add environment variables to the variable pool
|
||||
|
|
@ -91,12 +72,12 @@ class VariablePool(BaseModel):
|
|||
Returns:
|
||||
None
|
||||
"""
|
||||
if len(selector) < 2:
|
||||
if len(selector) < MIN_SELECTORS_LENGTH:
|
||||
raise ValueError("Invalid selector")
|
||||
|
||||
if isinstance(value, Variable):
|
||||
variable = value
|
||||
if isinstance(value, Segment):
|
||||
elif isinstance(value, Segment):
|
||||
variable = variable_factory.segment_to_variable(segment=value, selector=selector)
|
||||
else:
|
||||
segment = variable_factory.build_segment(value)
|
||||
|
|
@ -118,7 +99,7 @@ class VariablePool(BaseModel):
|
|||
Raises:
|
||||
ValueError: If the selector is invalid.
|
||||
"""
|
||||
if len(selector) < 2:
|
||||
if len(selector) < MIN_SELECTORS_LENGTH:
|
||||
return None
|
||||
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
|
|
|
|||
|
|
@ -66,6 +66,8 @@ class BaseNodeEvent(GraphEngineEvent):
|
|||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
"""loop id if node is in loop"""
|
||||
# The version of the node, or "1" if not specified.
|
||||
node_version: str = "1"
|
||||
|
||||
|
||||
class NodeRunStartedEvent(BaseNodeEvent):
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
|
|||
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.utils import variable_utils
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
|
@ -314,6 +315,7 @@ class GraphEngine:
|
|||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
node_version=node_instance.version(),
|
||||
)
|
||||
raise e
|
||||
|
||||
|
|
@ -627,6 +629,7 @@ class GraphEngine:
|
|||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
agent_strategy=agent_strategy,
|
||||
node_version=node_instance.version(),
|
||||
)
|
||||
|
||||
max_retries = node_instance.node_data.retry_config.max_retries
|
||||
|
|
@ -677,6 +680,7 @@ class GraphEngine:
|
|||
error=run_result.error or "Unknown error",
|
||||
retry_index=retries,
|
||||
start_at=retry_start_at,
|
||||
node_version=node_instance.version(),
|
||||
)
|
||||
time.sleep(retry_interval)
|
||||
break
|
||||
|
|
@ -712,6 +716,7 @@ class GraphEngine:
|
|||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
node_version=node_instance.version(),
|
||||
)
|
||||
should_continue_retry = False
|
||||
else:
|
||||
|
|
@ -726,6 +731,7 @@ class GraphEngine:
|
|||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
node_version=node_instance.version(),
|
||||
)
|
||||
should_continue_retry = False
|
||||
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
|
|
@ -786,6 +792,7 @@ class GraphEngine:
|
|||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
node_version=node_instance.version(),
|
||||
)
|
||||
should_continue_retry = False
|
||||
|
||||
|
|
@ -803,6 +810,7 @@ class GraphEngine:
|
|||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
node_version=node_instance.version(),
|
||||
)
|
||||
elif isinstance(event, RunRetrieverResourceEvent):
|
||||
yield NodeRunRetrieverResourceEvent(
|
||||
|
|
@ -817,6 +825,7 @@ class GraphEngine:
|
|||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
node_version=node_instance.version(),
|
||||
)
|
||||
except GenerateTaskStoppedError:
|
||||
# trigger node run failed event
|
||||
|
|
@ -833,6 +842,7 @@ class GraphEngine:
|
|||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
node_version=node_instance.version(),
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
|
|
@ -847,16 +857,12 @@ class GraphEngine:
|
|||
:param variable_value: variable value
|
||||
:return:
|
||||
"""
|
||||
self.graph_runtime_state.variable_pool.add([node_id] + variable_key_list, variable_value)
|
||||
|
||||
# if variable_value is a dict, then recursively append variables
|
||||
if isinstance(variable_value, dict):
|
||||
for key, value in variable_value.items():
|
||||
# construct new key list
|
||||
new_key_list = variable_key_list + [key]
|
||||
self._append_variables_recursively(
|
||||
node_id=node_id, variable_key_list=new_key_list, variable_value=value
|
||||
)
|
||||
variable_utils.append_variables_recursively(
|
||||
self.graph_runtime_state.variable_pool,
|
||||
node_id,
|
||||
variable_key_list,
|
||||
variable_value,
|
||||
)
|
||||
|
||||
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -39,6 +39,10 @@ class AgentNode(ToolNode):
|
|||
_node_data_cls = AgentNodeData # type: ignore
|
||||
_node_type = NodeType.AGENT
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""
|
||||
Run the agent node
|
||||
|
|
|
|||
|
|
@ -18,7 +18,11 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
|||
|
||||
class AnswerNode(BaseNode[AnswerNodeData]):
|
||||
_node_data_cls = AnswerNodeData
|
||||
_node_type: NodeType = NodeType.ANSWER
|
||||
_node_type = NodeType.ANSWER
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
|
|
@ -45,7 +49,10 @@ class AnswerNode(BaseNode[AnswerNodeData]):
|
|||
part = cast(TextGenerateRouteChunk, part)
|
||||
answer += part.text
|
||||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer, "files": files})
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"answer": answer, "files": ArrayFileSegment(value=files)},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
|
|
|
|||
|
|
@ -109,6 +109,7 @@ class AnswerStreamProcessor(StreamProcessor):
|
|||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
from_variable_selector=[answer_node_id, "answer"],
|
||||
node_version=event.node_version,
|
||||
)
|
||||
else:
|
||||
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
|
||||
|
|
@ -134,6 +135,7 @@ class AnswerStreamProcessor(StreamProcessor):
|
|||
route_node_state=event.route_node_state,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
node_version=event.node_version,
|
||||
)
|
||||
|
||||
self.route_position[answer_node_id] += 1
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import logging
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, cast
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
|
|
@ -23,7 +23,7 @@ GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)
|
|||
|
||||
class BaseNode(Generic[GenericNodeData]):
|
||||
_node_data_cls: type[GenericNodeData]
|
||||
_node_type: NodeType
|
||||
_node_type: ClassVar[NodeType]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -90,8 +90,38 @@ class BaseNode(Generic[GenericNodeData]):
|
|||
graph_config: Mapping[str, Any],
|
||||
config: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
"""Extracts references variable selectors from node configuration.
|
||||
|
||||
The `config` parameter represents the configuration for a specific node type and corresponds
|
||||
to the `data` field in the node definition object.
|
||||
|
||||
The returned mapping has the following structure:
|
||||
|
||||
{'1747829548239.#1747829667553.result#': ['1747829667553', 'result']}
|
||||
|
||||
For loop and iteration nodes, the mapping may look like this:
|
||||
|
||||
{
|
||||
"1748332301644.input_selector": ["1748332363630", "result"],
|
||||
"1748332325079.1748332325079.#sys.workflow_id#": ["sys", "workflow_id"],
|
||||
}
|
||||
|
||||
where `1748332301644` is the ID of the loop / iteration node,
|
||||
and `1748332325079` is the ID of the node inside the loop or iteration node.
|
||||
|
||||
Here, the key consists of two parts: the current node ID (provided as the `node_id`
|
||||
parameter to `_extract_variable_selector_to_variable_mapping`) and the variable selector,
|
||||
enclosed in `#` symbols. These two parts are separated by a dot (`.`).
|
||||
|
||||
The value is a list of string representing the variable selector, where the first element is the node ID
|
||||
of the referenced variable, and the second element is the variable name within that node.
|
||||
|
||||
The meaning of the above response is:
|
||||
|
||||
The node with ID `1747829548239` references the variable `result` from the node with
|
||||
ID `1747829667553`. For example, if `1747829548239` is a LLM node, its prompt may contain a
|
||||
reference to the `result` output variable of node `1747829667553`.
|
||||
|
||||
:param graph_config: graph config
|
||||
:param config: node config
|
||||
:return:
|
||||
|
|
@ -101,9 +131,10 @@ class BaseNode(Generic[GenericNodeData]):
|
|||
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
|
||||
|
||||
node_data = cls._node_data_cls(**config.get("data", {}))
|
||||
return cls._extract_variable_selector_to_variable_mapping(
|
||||
data = cls._extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data)
|
||||
)
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
|
|
@ -139,6 +170,16 @@ class BaseNode(Generic[GenericNodeData]):
|
|||
"""
|
||||
return self._node_type
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def version(cls) -> str:
|
||||
"""`node_version` returns the version of current node type."""
|
||||
# NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`.
|
||||
#
|
||||
# If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING`
|
||||
# in `api/core/workflow/nodes/__init__.py`.
|
||||
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
|
||||
|
||||
@property
|
||||
def should_continue_on_error(self) -> bool:
|
||||
"""judge if should continue on error
|
||||
|
|
|
|||
|
|
@ -40,6 +40,10 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
|
||||
return code_provider.get_default_config()
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
# Get code language
|
||||
code_language = self.node_data.code_language
|
||||
|
|
@ -126,6 +130,9 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
prefix: str = "",
|
||||
depth: int = 1,
|
||||
):
|
||||
# TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes.
|
||||
# Note that `_transform_result` may produce lists containing `None` values,
|
||||
# which don't conform to the type requirements of `Array*Segment` classes.
|
||||
if depth > dify_config.CODE_MAX_DEPTH:
|
||||
raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")
|
||||
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ from configs import dify_config
|
|||
from core.file import File, FileTransferMethod, file_manager
|
||||
from core.helper import ssrf_proxy
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.variables.segments import FileSegment
|
||||
from core.variables.segments import ArrayStringSegment, FileSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
|
|
@ -45,6 +45,10 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
|
|||
_node_data_cls = DocumentExtractorNodeData
|
||||
_node_type = NodeType.DOCUMENT_EXTRACTOR
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self):
|
||||
variable_selector = self.node_data.variable_selector
|
||||
variable = self.graph_runtime_state.variable_pool.get(variable_selector)
|
||||
|
|
@ -67,7 +71,7 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
|
|||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs={"text": extracted_text_list},
|
||||
outputs={"text": ArrayStringSegment(value=extracted_text_list)},
|
||||
)
|
||||
elif isinstance(value, File):
|
||||
extracted_text = _extract_text_from_file(value)
|
||||
|
|
@ -447,7 +451,7 @@ def _extract_text_from_excel(file_content: bytes) -> str:
|
|||
df = df.applymap(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x) # type: ignore
|
||||
|
||||
# Combine multi-line text in column names into a single line
|
||||
df.columns = pd.Index([" ".join(col.splitlines()) for col in df.columns])
|
||||
df.columns = pd.Index([" ".join(str(col).splitlines()) for col in df.columns])
|
||||
|
||||
# Manually construct the Markdown table
|
||||
markdown_table += _construct_markdown_table(df) + "\n\n"
|
||||
|
|
|
|||
|
|
@ -9,6 +9,10 @@ class EndNode(BaseNode[EndNodeData]):
|
|||
_node_data_cls = EndNodeData
|
||||
_node_type = NodeType.END
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run node
|
||||
|
|
|
|||
|
|
@ -139,6 +139,7 @@ class EndStreamProcessor(StreamProcessor):
|
|||
route_node_state=event.route_node_state,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
node_version=event.node_version,
|
||||
)
|
||||
|
||||
self.route_position[end_node_id] += 1
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from typing import Any, Optional
|
|||
from configs import dify_config
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.variables.segments import ArrayFileSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
|
|
@ -60,6 +61,10 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
|||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
process_data = {}
|
||||
try:
|
||||
|
|
@ -92,7 +97,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
|||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
"status_code": response.status_code,
|
||||
"body": response.text if not files else "",
|
||||
"body": response.text if not files.value else "",
|
||||
"headers": response.headers,
|
||||
"files": files,
|
||||
},
|
||||
|
|
@ -166,7 +171,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
|||
|
||||
return mapping
|
||||
|
||||
def extract_files(self, url: str, response: Response) -> list[File]:
|
||||
def extract_files(self, url: str, response: Response) -> ArrayFileSegment:
|
||||
"""
|
||||
Extract files from response by checking both Content-Type header and URL
|
||||
"""
|
||||
|
|
@ -178,7 +183,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
|||
content_disposition_type = None
|
||||
|
||||
if not is_file:
|
||||
return files
|
||||
return ArrayFileSegment(value=[])
|
||||
|
||||
if parsed_content_disposition:
|
||||
content_disposition_filename = parsed_content_disposition.get_filename()
|
||||
|
|
@ -211,4 +216,4 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
|||
)
|
||||
files.append(file)
|
||||
|
||||
return files
|
||||
return ArrayFileSegment(value=files)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from typing import Literal
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
|
|
@ -16,6 +17,10 @@ class IfElseNode(BaseNode[IfElseNodeData]):
|
|||
_node_data_cls = IfElseNodeData
|
||||
_node_type = NodeType.IF_ELSE
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run node
|
||||
|
|
@ -87,6 +92,22 @@ class IfElseNode(BaseNode[IfElseNodeData]):
|
|||
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: IfElseNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
var_mapping: dict[str, list[str]] = {}
|
||||
for case in node_data.cases or []:
|
||||
for condition in case.conditions:
|
||||
key = "{}.#{}#".format(node_id, ".".join(condition.variable_selector))
|
||||
var_mapping[key] = condition.variable_selector
|
||||
|
||||
return var_mapping
|
||||
|
||||
|
||||
@deprecated("This function is deprecated. You should use the new cases structure.")
|
||||
def _should_not_use_old_function(
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from flask import Flask, current_app
|
|||
|
||||
from configs import dify_config
|
||||
from core.variables import ArrayVariable, IntegerVariable, NoneVariable
|
||||
from core.variables.segments import ArrayAnySegment, ArraySegment
|
||||
from core.workflow.entities.node_entities import (
|
||||
NodeRunResult,
|
||||
)
|
||||
|
|
@ -37,6 +38,7 @@ from core.workflow.nodes.base import BaseNode
|
|||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
||||
from factories.variable_factory import build_segment
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
|
||||
from .exc import (
|
||||
|
|
@ -72,6 +74,10 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
"""
|
||||
Run the node.
|
||||
|
|
@ -85,10 +91,17 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|||
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
|
||||
|
||||
if isinstance(variable, NoneVariable) or len(variable.value) == 0:
|
||||
# Try our best to preserve the type informat.
|
||||
if isinstance(variable, ArraySegment):
|
||||
output = variable.model_copy(update={"value": []})
|
||||
else:
|
||||
output = ArrayAnySegment(value=[])
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"output": []},
|
||||
# TODO(QuantumGhost): is it possible to compute the type of `output`
|
||||
# from graph definition?
|
||||
outputs={"output": output},
|
||||
)
|
||||
)
|
||||
return
|
||||
|
|
@ -231,6 +244,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|||
# Flatten the list of lists
|
||||
if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs):
|
||||
outputs = [item for sublist in outputs for item in sublist]
|
||||
output_segment = build_segment(outputs)
|
||||
|
||||
yield IterationRunSucceededEvent(
|
||||
iteration_id=self.id,
|
||||
|
|
@ -247,7 +261,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"output": outputs},
|
||||
outputs={"output": output_segment},
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
|
|
|
|||
|
|
@ -13,6 +13,10 @@ class IterationStartNode(BaseNode[IterationStartNodeData]):
|
|||
_node_data_cls = IterationStartNodeData
|
||||
_node_type = NodeType.ITERATION_START
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run the node.
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from core.rag.entities.metadata_entities import Condition, MetadataCondition
|
|||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.variables import StringSegment
|
||||
from core.variables.segments import ArrayObjectSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
|
|
@ -70,6 +71,10 @@ class KnowledgeRetrievalNode(LLMNode):
|
|||
_node_data_cls = KnowledgeRetrievalNodeData # type: ignore
|
||||
_node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
||||
|
||||
@classmethod
|
||||
def version(cls):
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult: # type: ignore
|
||||
node_data = cast(KnowledgeRetrievalNodeData, self.node_data)
|
||||
# extract variables
|
||||
|
|
@ -115,9 +120,12 @@ class KnowledgeRetrievalNode(LLMNode):
|
|||
# retrieve knowledge
|
||||
try:
|
||||
results = self._fetch_dataset_retriever(node_data=node_data, query=query)
|
||||
outputs = {"result": results}
|
||||
outputs = {"result": ArrayObjectSegment(value=results)}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variables,
|
||||
process_data=None,
|
||||
outputs=outputs, # type: ignore
|
||||
)
|
||||
|
||||
except KnowledgeRetrievalNodeError as e:
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from typing import Any, Literal, Union
|
|||
|
||||
from core.file import File
|
||||
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
|
||||
from core.variables.segments import ArrayAnySegment, ArraySegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
|
|
@ -16,6 +17,10 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
|
|||
_node_data_cls = ListOperatorNodeData
|
||||
_node_type = NodeType.LIST_OPERATOR
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self):
|
||||
inputs: dict[str, list] = {}
|
||||
process_data: dict[str, list] = {}
|
||||
|
|
@ -30,7 +35,11 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
|
|||
if not variable.value:
|
||||
inputs = {"variable": []}
|
||||
process_data = {"variable": []}
|
||||
outputs = {"result": [], "first_record": None, "last_record": None}
|
||||
if isinstance(variable, ArraySegment):
|
||||
result = variable.model_copy(update={"value": []})
|
||||
else:
|
||||
result = ArrayAnySegment(value=[])
|
||||
outputs = {"result": result, "first_record": None, "last_record": None}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
|
|
@ -71,7 +80,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
|
|||
variable = self._apply_slice(variable)
|
||||
|
||||
outputs = {
|
||||
"result": variable.value,
|
||||
"result": variable,
|
||||
"first_record": variable.value[0] if variable.value else None,
|
||||
"last_record": variable.value[-1] if variable.value else None,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -119,9 +119,6 @@ class FileSaverImpl(LLMFileSaver):
|
|||
size=len(data),
|
||||
related_id=tool_file.id,
|
||||
url=url,
|
||||
# TODO(QuantumGhost): how should I set the following key?
|
||||
# What's the difference between `remote_url` and `url`?
|
||||
# What's the purpose of `storage_key` and `dify_model_identity`?
|
||||
storage_key=tool_file.file_key,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -138,6 +138,10 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
def process_structured_output(text: str) -> Optional[dict[str, Any]]:
|
||||
"""Process structured output if enabled"""
|
||||
|
|
@ -255,7 +259,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
if structured_output:
|
||||
outputs["structured_output"] = structured_output
|
||||
if self._file_outputs is not None:
|
||||
outputs["files"] = self._file_outputs
|
||||
outputs["files"] = ArrayFileSegment(value=self._file_outputs)
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
|
|
|
|||
|
|
@ -13,6 +13,10 @@ class LoopEndNode(BaseNode[LoopEndNodeData]):
|
|||
_node_data_cls = LoopEndNodeData
|
||||
_node_type = NodeType.LOOP_END
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run the node.
|
||||
|
|
|
|||
|
|
@ -54,6 +54,10 @@ class LoopNode(BaseNode[LoopNodeData]):
|
|||
_node_data_cls = LoopNodeData
|
||||
_node_type = NodeType.LOOP
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
"""Run the node."""
|
||||
# Get inputs
|
||||
|
|
@ -482,6 +486,13 @@ class LoopNode(BaseNode[LoopNodeData]):
|
|||
|
||||
variable_mapping.update(sub_node_variable_mapping)
|
||||
|
||||
for loop_variable in node_data.loop_variables or []:
|
||||
if loop_variable.value_type == "variable":
|
||||
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
|
||||
# add loop variable to variable mapping
|
||||
selector = loop_variable.value
|
||||
variable_mapping[f"{node_id}.{loop_variable.label}"] = selector
|
||||
|
||||
# remove variable out from loop
|
||||
variable_mapping = {
|
||||
key: value for key, value in variable_mapping.items() if value[0] not in loop_graph.node_ids
|
||||
|
|
|
|||
|
|
@ -13,6 +13,10 @@ class LoopStartNode(BaseNode[LoopStartNodeData]):
|
|||
_node_data_cls = LoopStartNodeData
|
||||
_node_type = NodeType.LOOP_START
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run the node.
|
||||
|
|
|
|||
|
|
@ -25,6 +25,11 @@ from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as Var
|
|||
|
||||
LATEST_VERSION = "latest"
|
||||
|
||||
# NOTE(QuantumGhost): This should be in sync with subclasses of BaseNode.
|
||||
# Specifically, if you have introduced new node types, you should add them here.
|
||||
#
|
||||
# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__`
|
||||
# hook. Try to avoid duplication of node information.
|
||||
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
|
||||
NodeType.START: {
|
||||
LATEST_VERSION: StartNode,
|
||||
|
|
|
|||
|
|
@ -7,6 +7,10 @@ from core.workflow.nodes.base import BaseNodeData
|
|||
from core.workflow.nodes.llm import ModelConfig, VisionConfig
|
||||
|
||||
|
||||
class _ParameterConfigError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ParameterConfig(BaseModel):
|
||||
"""
|
||||
Parameter Config.
|
||||
|
|
@ -27,6 +31,19 @@ class ParameterConfig(BaseModel):
|
|||
raise ValueError("Invalid parameter name, __reason and __is_success are reserved")
|
||||
return str(value)
|
||||
|
||||
def is_array_type(self) -> bool:
|
||||
return self.type in ("array[string]", "array[number]", "array[object]")
|
||||
|
||||
def element_type(self) -> Literal["string", "number", "object"]:
|
||||
if self.type == "array[number]":
|
||||
return "number"
|
||||
elif self.type == "array[string]":
|
||||
return "string"
|
||||
elif self.type == "array[object]":
|
||||
return "object"
|
||||
else:
|
||||
raise _ParameterConfigError(f"{self.type} is not array type.")
|
||||
|
||||
|
||||
class ParameterExtractorNodeData(BaseNodeData):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
|||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
|
|
@ -32,6 +33,7 @@ from core.workflow.nodes.base.node import BaseNode
|
|||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.llm import ModelConfig, llm_utils
|
||||
from core.workflow.utils import variable_template_parser
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
|
||||
from .entities import ParameterExtractorNodeData
|
||||
from .exc import (
|
||||
|
|
@ -109,6 +111,10 @@ class ParameterExtractorNode(BaseNode):
|
|||
}
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self):
|
||||
"""
|
||||
Run the node.
|
||||
|
|
@ -584,28 +590,30 @@ class ParameterExtractorNode(BaseNode):
|
|||
elif parameter.type in {"string", "select"}:
|
||||
if isinstance(result[parameter.name], str):
|
||||
transformed_result[parameter.name] = result[parameter.name]
|
||||
elif parameter.type.startswith("array"):
|
||||
elif parameter.is_array_type():
|
||||
if isinstance(result[parameter.name], list):
|
||||
nested_type = parameter.type[6:-1]
|
||||
transformed_result[parameter.name] = []
|
||||
nested_type = parameter.element_type()
|
||||
assert nested_type is not None
|
||||
segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[])
|
||||
transformed_result[parameter.name] = segment_value
|
||||
for item in result[parameter.name]:
|
||||
if nested_type == "number":
|
||||
if isinstance(item, int | float):
|
||||
transformed_result[parameter.name].append(item)
|
||||
segment_value.value.append(item)
|
||||
elif isinstance(item, str):
|
||||
try:
|
||||
if "." in item:
|
||||
transformed_result[parameter.name].append(float(item))
|
||||
segment_value.value.append(float(item))
|
||||
else:
|
||||
transformed_result[parameter.name].append(int(item))
|
||||
segment_value.value.append(int(item))
|
||||
except ValueError:
|
||||
pass
|
||||
elif nested_type == "string":
|
||||
if isinstance(item, str):
|
||||
transformed_result[parameter.name].append(item)
|
||||
segment_value.value.append(item)
|
||||
elif nested_type == "object":
|
||||
if isinstance(item, dict):
|
||||
transformed_result[parameter.name].append(item)
|
||||
segment_value.value.append(item)
|
||||
|
||||
if parameter.name not in transformed_result:
|
||||
if parameter.type == "number":
|
||||
|
|
@ -615,7 +623,9 @@ class ParameterExtractorNode(BaseNode):
|
|||
elif parameter.type in {"string", "select"}:
|
||||
transformed_result[parameter.name] = ""
|
||||
elif parameter.type.startswith("array"):
|
||||
transformed_result[parameter.name] = []
|
||||
transformed_result[parameter.name] = build_segment_with_type(
|
||||
segment_type=SegmentType(parameter.type), value=[]
|
||||
)
|
||||
|
||||
return transformed_result
|
||||
|
||||
|
|
|
|||
|
|
@ -40,6 +40,10 @@ class QuestionClassifierNode(LLMNode):
|
|||
_node_data_cls = QuestionClassifierNodeData # type: ignore
|
||||
_node_type = NodeType.QUESTION_CLASSIFIER
|
||||
|
||||
@classmethod
|
||||
def version(cls):
|
||||
return "1"
|
||||
|
||||
def _run(self):
|
||||
node_data = cast(QuestionClassifierNodeData, self.node_data)
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
|
|
|||
|
|
@ -10,6 +10,10 @@ class StartNode(BaseNode[StartNodeData]):
|
|||
_node_data_cls = StartNodeData
|
||||
_node_type = NodeType.START
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
|
||||
system_inputs = self.graph_runtime_state.variable_pool.system_variables
|
||||
|
|
@ -18,5 +22,6 @@ class StartNode(BaseNode[StartNodeData]):
|
|||
# Set system variables as node outputs.
|
||||
for var in system_inputs:
|
||||
node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
|
||||
outputs = dict(node_inputs)
|
||||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=node_inputs)
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs)
|
||||
|
|
|
|||
|
|
@ -28,6 +28,10 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
|
|||
"config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
# Get variables
|
||||
variables = {}
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
|||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.variables.segments import ArrayAnySegment
|
||||
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
|
||||
from core.variables.variables import ArrayAnyVariable
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
|
@ -44,6 +44,10 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
_node_data_cls = ToolNodeData
|
||||
_node_type = NodeType.TOOL
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""
|
||||
Run the tool node
|
||||
|
|
@ -300,6 +304,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
variables[variable_name] = variable_value
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
assert isinstance(message.meta, File)
|
||||
files.append(message.meta["file"])
|
||||
elif message.type == ToolInvokeMessage.MessageType.LOG:
|
||||
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
|
||||
|
|
@ -363,7 +368,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"text": text, "files": files, "json": json, **variables},
|
||||
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json, **variables},
|
||||
metadata={
|
||||
**agent_execution_metadata,
|
||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
||||
|
|
|
|||
|
|
@ -1,3 +1,6 @@
|
|||
from collections.abc import Mapping
|
||||
|
||||
from core.variables.segments import Segment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
|
|
@ -9,16 +12,20 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
|
|||
_node_data_cls = VariableAssignerNodeData
|
||||
_node_type = NodeType.VARIABLE_AGGREGATOR
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
# Get variables
|
||||
outputs = {}
|
||||
outputs: dict[str, Segment | Mapping[str, Segment]] = {}
|
||||
inputs = {}
|
||||
|
||||
if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled:
|
||||
for selector in self.node_data.variables:
|
||||
variable = self.graph_runtime_state.variable_pool.get(selector)
|
||||
if variable is not None:
|
||||
outputs = {"output": variable.to_object()}
|
||||
outputs = {"output": variable}
|
||||
|
||||
inputs = {".".join(selector[1:]): variable.to_object()}
|
||||
break
|
||||
|
|
@ -28,7 +35,7 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
|
|||
variable = self.graph_runtime_state.variable_pool.get(selector)
|
||||
|
||||
if variable is not None:
|
||||
outputs[group.group_name] = {"output": variable.to_object()}
|
||||
outputs[group.group_name] = {"output": variable}
|
||||
inputs[".".join(selector[1:])] = variable.to_object()
|
||||
break
|
||||
|
||||
|
|
|
|||
|
|
@ -1,19 +1,55 @@
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from collections.abc import Mapping, MutableMapping, Sequence
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from core.variables import Variable
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from extensions.ext_database import db
|
||||
from models import ConversationVariable
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.variables import Segment
|
||||
from core.variables.consts import MIN_SELECTORS_LENGTH
|
||||
from core.variables.types import SegmentType
|
||||
|
||||
# Use double underscore (`__`) prefix for internal variables
|
||||
# to minimize risk of collision with user-defined variable names.
|
||||
_UPDATED_VARIABLES_KEY = "__updated_variables"
|
||||
|
||||
|
||||
def update_conversation_variable(conversation_id: str, variable: Variable):
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
||||
class UpdatedVariable(BaseModel):
|
||||
name: str
|
||||
selector: Sequence[str]
|
||||
value_type: SegmentType
|
||||
new_value: Any
|
||||
|
||||
|
||||
_T = TypeVar("_T", bound=MutableMapping[str, Any])
|
||||
|
||||
|
||||
def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable:
|
||||
if len(selector) < MIN_SELECTORS_LENGTH:
|
||||
raise Exception("selector too short")
|
||||
node_id, var_name = selector[:2]
|
||||
return UpdatedVariable(
|
||||
name=var_name,
|
||||
selector=list(selector[:2]),
|
||||
value_type=seg.value_type,
|
||||
new_value=seg.value,
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
row = session.scalar(stmt)
|
||||
if not row:
|
||||
raise VariableOperatorNodeError("conversation variable not found in the database")
|
||||
row.data = variable.model_dump_json()
|
||||
session.commit()
|
||||
|
||||
|
||||
def set_updated_variables(m: _T, updates: Sequence[UpdatedVariable]) -> _T:
|
||||
m[_UPDATED_VARIABLES_KEY] = updates
|
||||
return m
|
||||
|
||||
|
||||
def get_updated_variables(m: Mapping[str, Any]) -> Sequence[UpdatedVariable] | None:
|
||||
updated_values = m.get(_UPDATED_VARIABLES_KEY, None)
|
||||
if updated_values is None:
|
||||
return None
|
||||
result = []
|
||||
for items in updated_values:
|
||||
if isinstance(items, UpdatedVariable):
|
||||
result.append(items)
|
||||
elif isinstance(items, dict):
|
||||
items = UpdatedVariable.model_validate(items)
|
||||
result.append(items)
|
||||
else:
|
||||
raise TypeError(f"Invalid updated variable: {items}, type={type(items)}")
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -0,0 +1,38 @@
|
|||
from sqlalchemy import Engine, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.variables.variables import Variable
|
||||
from models.engine import db
|
||||
from models.workflow import ConversationVariable
|
||||
|
||||
from .exc import VariableOperatorNodeError
|
||||
|
||||
|
||||
class ConversationVariableUpdaterImpl:
|
||||
_engine: Engine | None
|
||||
|
||||
def __init__(self, engine: Engine | None = None) -> None:
|
||||
self._engine = engine
|
||||
|
||||
def _get_engine(self) -> Engine:
|
||||
if self._engine:
|
||||
return self._engine
|
||||
return db.engine
|
||||
|
||||
def update(self, conversation_id: str, variable: Variable):
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
||||
)
|
||||
with Session(self._get_engine()) as session:
|
||||
row = session.scalar(stmt)
|
||||
if not row:
|
||||
raise VariableOperatorNodeError("conversation variable not found in the database")
|
||||
row.data = variable.model_dump_json()
|
||||
session.commit()
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
|
||||
def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl:
|
||||
return ConversationVariableUpdaterImpl()
|
||||
|
|
@ -1,4 +1,9 @@
|
|||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypeAlias
|
||||
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
|
|
@ -7,16 +12,71 @@ from core.workflow.nodes.variable_assigner.common import helpers as common_helpe
|
|||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from factories import variable_factory
|
||||
|
||||
from ..common.impl import conversation_variable_updater_factory
|
||||
from .node_data import VariableAssignerData, WriteMode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||
|
||||
|
||||
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
|
||||
|
||||
|
||||
class VariableAssignerNode(BaseNode[VariableAssignerData]):
|
||||
_node_data_cls = VariableAssignerData
|
||||
_node_type = NodeType.VARIABLE_ASSIGNER
|
||||
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph: "Graph",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
previous_node_id: Optional[str] = None,
|
||||
thread_pool_id: Optional[str] = None,
|
||||
conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
previous_node_id=previous_node_id,
|
||||
thread_pool_id=thread_pool_id,
|
||||
)
|
||||
self._conv_var_updater_factory = conv_var_updater_factory
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: VariableAssignerData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
mapping = {}
|
||||
assigned_variable_node_id = node_data.assigned_variable_selector[0]
|
||||
if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID:
|
||||
selector_key = ".".join(node_data.assigned_variable_selector)
|
||||
key = f"{node_id}.#{selector_key}#"
|
||||
mapping[key] = node_data.assigned_variable_selector
|
||||
|
||||
selector_key = ".".join(node_data.input_variable_selector)
|
||||
key = f"{node_id}.#{selector_key}#"
|
||||
mapping[key] = node_data.input_variable_selector
|
||||
return mapping
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
assigned_variable_selector = self.node_data.assigned_variable_selector
|
||||
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
|
||||
original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector)
|
||||
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
|
||||
if not isinstance(original_variable, Variable):
|
||||
raise VariableOperatorNodeError("assigned variable not found")
|
||||
|
||||
|
|
@ -44,20 +104,28 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
|
|||
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
|
||||
|
||||
# Over write the variable.
|
||||
self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable)
|
||||
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
|
||||
|
||||
# TODO: Move database operation to the pipeline.
|
||||
# Update conversation variable.
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
|
||||
if not conversation_id:
|
||||
raise VariableOperatorNodeError("conversation_id not found")
|
||||
common_helpers.update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
|
||||
conv_var_updater = self._conv_var_updater_factory()
|
||||
conv_var_updater.update(conversation_id=conversation_id.text, variable=updated_variable)
|
||||
conv_var_updater.flush()
|
||||
updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)]
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={
|
||||
"value": income_value.to_object(),
|
||||
},
|
||||
# NOTE(QuantumGhost): although only one variable is updated in `v1.VariableAssignerNode`,
|
||||
# we still set `output_variables` as a list to ensure the schema of output is
|
||||
# compatible with `v2.VariableAssignerNode`.
|
||||
process_data=common_helpers.set_updated_variables({}, updated_variables),
|
||||
outputs={},
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,12 @@ class VariableOperationItem(BaseModel):
|
|||
variable_selector: Sequence[str]
|
||||
input_type: InputType
|
||||
operation: Operation
|
||||
# NOTE(QuantumGhost): The `value` field serves multiple purposes depending on context:
|
||||
#
|
||||
# 1. For CONSTANT input_type: Contains the literal value to be used in the operation.
|
||||
# 2. For VARIABLE input_type: Initially contains the selector of the source variable.
|
||||
# 3. During the variable updating procedure: The `value` field is reassigned to hold
|
||||
# the resolved actual value that will be applied to the target variable.
|
||||
value: Any | None = None
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -29,3 +29,8 @@ class InvalidInputValueError(VariableOperatorNodeError):
|
|||
class ConversationIDNotFoundError(VariableOperatorNodeError):
|
||||
def __init__(self):
|
||||
super().__init__("conversation_id not found")
|
||||
|
||||
|
||||
class InvalidDataError(VariableOperatorNodeError):
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(message)
|
||||
|
|
|
|||
|
|
@ -1,34 +1,84 @@
|
|||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, cast
|
||||
from collections.abc import Callable, Mapping, MutableMapping, Sequence
|
||||
from typing import Any, TypeAlias, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.variables.consts import MIN_SELECTORS_LENGTH
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
|
||||
|
||||
from . import helpers
|
||||
from .constants import EMPTY_VALUE_MAPPING
|
||||
from .entities import VariableAssignerNodeData
|
||||
from .entities import VariableAssignerNodeData, VariableOperationItem
|
||||
from .enums import InputType, Operation
|
||||
from .exc import (
|
||||
ConversationIDNotFoundError,
|
||||
InputTypeNotSupportedError,
|
||||
InvalidDataError,
|
||||
InvalidInputValueError,
|
||||
OperationNotSupportedError,
|
||||
VariableNotFoundError,
|
||||
)
|
||||
|
||||
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
|
||||
|
||||
|
||||
def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem):
|
||||
selector_node_id = item.variable_selector[0]
|
||||
if selector_node_id != CONVERSATION_VARIABLE_NODE_ID:
|
||||
return
|
||||
selector_str = ".".join(item.variable_selector)
|
||||
key = f"{node_id}.#{selector_str}#"
|
||||
mapping[key] = item.variable_selector
|
||||
|
||||
|
||||
def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem):
|
||||
# Keep this in sync with the logic in _run methods...
|
||||
if item.input_type != InputType.VARIABLE:
|
||||
return
|
||||
selector = item.value
|
||||
if not isinstance(selector, list):
|
||||
raise InvalidDataError(f"selector is not a list, {node_id=}, {item=}")
|
||||
if len(selector) < MIN_SELECTORS_LENGTH:
|
||||
raise InvalidDataError(f"selector too short, {node_id=}, {item=}")
|
||||
selector_str = ".".join(selector)
|
||||
key = f"{node_id}.#{selector_str}#"
|
||||
mapping[key] = selector
|
||||
|
||||
|
||||
class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
|
||||
_node_data_cls = VariableAssignerNodeData
|
||||
_node_type = NodeType.VARIABLE_ASSIGNER
|
||||
|
||||
def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
|
||||
return conversation_variable_updater_factory()
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "2"
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: VariableAssignerNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
var_mapping: dict[str, Sequence[str]] = {}
|
||||
for item in node_data.items:
|
||||
_target_mapping_from_item(var_mapping, node_id, item)
|
||||
_source_mapping_from_item(var_mapping, node_id, item)
|
||||
return var_mapping
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
inputs = self.node_data.model_dump()
|
||||
process_data: dict[str, Any] = {}
|
||||
|
|
@ -114,6 +164,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
|
|||
# remove the duplicated items first.
|
||||
updated_variable_selectors = list(set(map(tuple, updated_variable_selectors)))
|
||||
|
||||
conv_var_updater = self._conv_var_updater_factory()
|
||||
# Update variables
|
||||
for selector in updated_variable_selectors:
|
||||
variable = self.graph_runtime_state.variable_pool.get(selector)
|
||||
|
|
@ -128,15 +179,23 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
|
|||
raise ConversationIDNotFoundError
|
||||
else:
|
||||
conversation_id = conversation_id.value
|
||||
common_helpers.update_conversation_variable(
|
||||
conv_var_updater.update(
|
||||
conversation_id=cast(str, conversation_id),
|
||||
variable=variable,
|
||||
)
|
||||
conv_var_updater.flush()
|
||||
updated_variables = [
|
||||
common_helpers.variable_to_processed_data(selector, seg)
|
||||
for selector in updated_variable_selectors
|
||||
if (seg := self.graph_runtime_state.variable_pool.get(selector)) is not None
|
||||
]
|
||||
|
||||
process_data = common_helpers.set_updated_variables(process_data, updated_variables)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs={},
|
||||
)
|
||||
|
||||
def _handle_item(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,29 @@
|
|||
from core.variables.segments import ObjectSegment, Segment
|
||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||
|
||||
|
||||
def append_variables_recursively(
|
||||
pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue | Segment
|
||||
):
|
||||
"""
|
||||
Append variables recursively
|
||||
:param pool: variable pool to append variables to
|
||||
:param node_id: node id
|
||||
:param variable_key_list: variable key list
|
||||
:param variable_value: variable value
|
||||
:return:
|
||||
"""
|
||||
pool.add([node_id] + variable_key_list, variable_value)
|
||||
|
||||
# if variable_value is a dict, then recursively append variables
|
||||
if isinstance(variable_value, ObjectSegment):
|
||||
variable_dict = variable_value.value
|
||||
elif isinstance(variable_value, dict):
|
||||
variable_dict = variable_value
|
||||
else:
|
||||
return
|
||||
|
||||
for key, value in variable_dict.items():
|
||||
# construct new key list
|
||||
new_key_list = variable_key_list + [key]
|
||||
append_variables_recursively(pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value)
|
||||
|
|
@ -0,0 +1,84 @@
|
|||
import abc
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.variables import Variable
|
||||
from core.variables.consts import MIN_SELECTORS_LENGTH
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.utils import variable_utils
|
||||
|
||||
|
||||
class VariableLoader(Protocol):
|
||||
"""Interface for loading variables based on selectors.
|
||||
|
||||
A `VariableLoader` is responsible for retrieving additional variables required during the execution
|
||||
of a single node, which are not provided as user inputs.
|
||||
|
||||
NOTE(QuantumGhost): Typically, all variables loaded by a `VariableLoader` should belong to the same
|
||||
application and share the same `app_id`. However, this interface does not enforce that constraint,
|
||||
and the `app_id` parameter is intentionally omitted from `load_variables` to achieve separation of
|
||||
concern and allow for flexible implementations.
|
||||
|
||||
Implementations of `VariableLoader` should almost always have an `app_id` parameter in
|
||||
their constructor.
|
||||
|
||||
TODO(QuantumGhost): this is a temporally workaround. If we can move the creation of node instance into
|
||||
`WorkflowService.single_step_run`, we may get rid of this interface.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
|
||||
"""Load variables based on the provided selectors. If the selectors are empty,
|
||||
this method should return an empty list.
|
||||
|
||||
The order of the returned variables is not guaranteed. If the caller wants to ensure
|
||||
a specific order, they should sort the returned list themselves.
|
||||
|
||||
:param: selectors: a list of string list, each inner list should have at least two elements:
|
||||
- the first element is the node ID,
|
||||
- the second element is the variable name.
|
||||
:return: a list of Variable objects that match the provided selectors.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class _DummyVariableLoader(VariableLoader):
|
||||
"""A dummy implementation of VariableLoader that does not load any variables.
|
||||
Serves as a placeholder when no variable loading is needed.
|
||||
"""
|
||||
|
||||
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
|
||||
return []
|
||||
|
||||
|
||||
DUMMY_VARIABLE_LOADER = _DummyVariableLoader()
|
||||
|
||||
|
||||
def load_into_variable_pool(
|
||||
variable_loader: VariableLoader,
|
||||
variable_pool: VariablePool,
|
||||
variable_mapping: Mapping[str, Sequence[str]],
|
||||
user_inputs: Mapping[str, Any],
|
||||
):
|
||||
# Loading missing variable from draft var here, and set it into
|
||||
# variable_pool.
|
||||
variables_to_load: list[list[str]] = []
|
||||
for key, selector in variable_mapping.items():
|
||||
# NOTE(QuantumGhost): this logic needs to be in sync with
|
||||
# `WorkflowEntry.mapping_user_inputs_to_variable_pool`.
|
||||
node_variable_list = key.split(".")
|
||||
if len(node_variable_list) < 1:
|
||||
raise ValueError(f"Invalid variable key: {key}. It should have at least one element.")
|
||||
if key in user_inputs:
|
||||
continue
|
||||
node_variable_key = ".".join(node_variable_list[1:])
|
||||
if node_variable_key in user_inputs:
|
||||
continue
|
||||
if variable_pool.get(selector) is None:
|
||||
variables_to_load.append(list(selector))
|
||||
loaded = variable_loader.load_variables(variables_to_load)
|
||||
for var in loaded:
|
||||
assert len(var.selector) >= MIN_SELECTORS_LENGTH, f"Invalid variable {var}"
|
||||
variable_utils.append_variables_recursively(
|
||||
variable_pool, node_id=var.selector[0], variable_key_list=list(var.selector[1:]), variable_value=var
|
||||
)
|
||||
|
|
@ -92,7 +92,7 @@ class WorkflowCycleManager:
|
|||
) -> WorkflowExecution:
|
||||
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
|
||||
|
||||
outputs = WorkflowEntry.handle_special_values(outputs)
|
||||
# outputs = WorkflowEntry.handle_special_values(outputs)
|
||||
|
||||
workflow_execution.status = WorkflowExecutionStatus.SUCCEEDED
|
||||
workflow_execution.outputs = outputs or {}
|
||||
|
|
@ -125,7 +125,7 @@ class WorkflowCycleManager:
|
|||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> WorkflowExecution:
|
||||
execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
|
||||
outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
|
||||
# outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
|
||||
|
||||
execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED
|
||||
execution.outputs = outputs or {}
|
||||
|
|
@ -242,9 +242,9 @@ class WorkflowCycleManager:
|
|||
raise ValueError(f"Domain node execution not found: {event.node_execution_id}")
|
||||
|
||||
# Process data
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
inputs = event.inputs
|
||||
process_data = event.process_data
|
||||
outputs = event.outputs
|
||||
|
||||
# Convert metadata keys to strings
|
||||
execution_metadata_dict = {}
|
||||
|
|
@ -289,7 +289,7 @@ class WorkflowCycleManager:
|
|||
# Process data
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
outputs = event.outputs
|
||||
|
||||
# Convert metadata keys to strings
|
||||
execution_metadata_dict = {}
|
||||
|
|
@ -326,7 +326,7 @@ class WorkflowCycleManager:
|
|||
finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
elapsed_time = (finished_at - created_at).total_seconds()
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
outputs = event.outputs
|
||||
|
||||
# Convert metadata keys to strings
|
||||
origin_metadata = {
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from core.workflow.nodes import NodeType
|
|||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.event import NodeEvent
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||
from factories import file_factory
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import (
|
||||
|
|
@ -119,7 +120,9 @@ class WorkflowEntry:
|
|||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_id: str,
|
||||
user_inputs: dict,
|
||||
user_inputs: Mapping[str, Any],
|
||||
variable_pool: VariablePool,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]:
|
||||
"""
|
||||
Single step run workflow node
|
||||
|
|
@ -129,29 +132,14 @@ class WorkflowEntry:
|
|||
:param user_inputs: user inputs
|
||||
:return:
|
||||
"""
|
||||
# fetch node info from workflow graph
|
||||
workflow_graph = workflow.graph_dict
|
||||
if not workflow_graph:
|
||||
raise ValueError("workflow graph not found")
|
||||
|
||||
nodes = workflow_graph.get("nodes")
|
||||
if not nodes:
|
||||
raise ValueError("nodes not found in workflow graph")
|
||||
|
||||
# fetch node config from node id
|
||||
try:
|
||||
node_config = next(filter(lambda node: node["id"] == node_id, nodes))
|
||||
except StopIteration:
|
||||
raise ValueError("node id not found in workflow graph")
|
||||
node_config = workflow.get_node_config_by_id(node_id)
|
||||
node_config_data = node_config.get("data", {})
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType(node_config.get("data", {}).get("type"))
|
||||
node_version = node_config.get("data", {}).get("version", "1")
|
||||
node_type = NodeType(node_config_data.get("type"))
|
||||
node_version = node_config_data.get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(environment_variables=workflow.environment_variables)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=workflow.graph_dict)
|
||||
|
||||
|
|
@ -182,16 +170,33 @@ class WorkflowEntry:
|
|||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
|
||||
# Loading missing variable from draft var here, and set it into
|
||||
# variable_pool.
|
||||
load_into_variable_pool(
|
||||
variable_loader=variable_loader,
|
||||
variable_pool=variable_pool,
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
)
|
||||
|
||||
cls.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id=workflow.tenant_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# run node
|
||||
generator = node_instance.run()
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s",
|
||||
workflow.id,
|
||||
node_instance.id,
|
||||
node_instance.node_type,
|
||||
node_instance.version(),
|
||||
)
|
||||
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
|
||||
return node_instance, generator
|
||||
|
||||
|
|
@ -294,10 +299,20 @@ class WorkflowEntry:
|
|||
|
||||
return node_instance, generator
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"error while running node_instance, node_id=%s, type=%s, version=%s",
|
||||
node_instance.id,
|
||||
node_instance.node_type,
|
||||
node_instance.version(),
|
||||
)
|
||||
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
|
||||
|
||||
@staticmethod
|
||||
def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:
|
||||
# NOTE(QuantumGhost): Avoid using this function in new code.
|
||||
# Keep values structured as long as possible and only convert to dict
|
||||
# immediately before serialization (e.g., JSON serialization) to maintain
|
||||
# data integrity and type information.
|
||||
result = WorkflowEntry._handle_special_values(value)
|
||||
return result if isinstance(result, Mapping) or result is None else dict(result)
|
||||
|
||||
|
|
@ -324,10 +339,17 @@ class WorkflowEntry:
|
|||
cls,
|
||||
*,
|
||||
variable_mapping: Mapping[str, Sequence[str]],
|
||||
user_inputs: dict,
|
||||
user_inputs: Mapping[str, Any],
|
||||
variable_pool: VariablePool,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
# NOTE(QuantumGhost): This logic should remain synchronized with
|
||||
# the implementation of `load_into_variable_pool`, specifically the logic about
|
||||
# variable existence checking.
|
||||
|
||||
# WARNING(QuantumGhost): The semantics of this method are not clearly defined,
|
||||
# and multiple parts of the codebase depend on its current behavior.
|
||||
# Modify with caution.
|
||||
for node_variable, variable_selector in variable_mapping.items():
|
||||
# fetch node id and variable key from node_variable
|
||||
node_variable_list = node_variable.split(".")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,49 @@
|
|||
import json
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.file.models import File
|
||||
from core.variables import Segment
|
||||
|
||||
|
||||
class WorkflowRuntimeTypeEncoder(json.JSONEncoder):
|
||||
def default(self, o: Any):
|
||||
if isinstance(o, Segment):
|
||||
return o.value
|
||||
elif isinstance(o, File):
|
||||
return o.to_dict()
|
||||
elif isinstance(o, BaseModel):
|
||||
return o.model_dump(mode="json")
|
||||
else:
|
||||
return super().default(o)
|
||||
|
||||
|
||||
class WorkflowRuntimeTypeConverter:
|
||||
def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
|
||||
result = self._to_json_encodable_recursive(value)
|
||||
return result if isinstance(result, Mapping) or result is None else dict(result)
|
||||
|
||||
def _to_json_encodable_recursive(self, value: Any) -> Any:
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, (bool, int, str, float)):
|
||||
return value
|
||||
if isinstance(value, Segment):
|
||||
return self._to_json_encodable_recursive(value.value)
|
||||
if isinstance(value, File):
|
||||
return value.to_dict()
|
||||
if isinstance(value, BaseModel):
|
||||
return value.model_dump(mode="json")
|
||||
if isinstance(value, dict):
|
||||
res = {}
|
||||
for k, v in value.items():
|
||||
res[k] = self._to_json_encodable_recursive(v)
|
||||
return res
|
||||
if isinstance(value, list):
|
||||
res_list = []
|
||||
for item in value:
|
||||
res_list.append(self._to_json_encodable_recursive(item))
|
||||
return res_list
|
||||
return value
|
||||
|
|
@ -3,8 +3,10 @@ from .clean_when_document_deleted import handle
|
|||
from .create_document_index import handle
|
||||
from .create_installed_app_when_app_created import handle
|
||||
from .create_site_record_when_app_created import handle
|
||||
from .deduct_quota_when_message_created import handle
|
||||
from .delete_tool_parameters_cache_when_sync_draft_workflow import handle
|
||||
from .update_app_dataset_join_when_app_model_config_updated import handle
|
||||
from .update_app_dataset_join_when_app_published_workflow_updated import handle
|
||||
from .update_provider_last_used_at_when_message_created import handle
|
||||
|
||||
# Consolidated handler replaces both deduct_quota_when_message_created and
|
||||
# update_provider_last_used_at_when_message_created
|
||||
from .update_provider_when_message_created import handle
|
||||
|
|
|
|||
|
|
@ -1,65 +0,0 @@
|
|||
from datetime import UTC, datetime
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
|
||||
from core.entities.provider_entities import QuotaUnit
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from models.provider import Provider, ProviderType
|
||||
|
||||
|
||||
@message_was_created.connect
|
||||
def handle(sender, **kwargs):
|
||||
message = sender
|
||||
application_generate_entity = kwargs.get("application_generate_entity")
|
||||
|
||||
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
|
||||
return
|
||||
|
||||
model_config = application_generate_entity.model_conf
|
||||
provider_model_bundle = model_config.provider_model_bundle
|
||||
provider_configuration = provider_model_bundle.configuration
|
||||
|
||||
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
|
||||
return
|
||||
|
||||
system_configuration = provider_configuration.system_configuration
|
||||
|
||||
if not system_configuration.current_quota_type:
|
||||
return
|
||||
|
||||
quota_unit = None
|
||||
for quota_configuration in system_configuration.quota_configurations:
|
||||
if quota_configuration.quota_type == system_configuration.current_quota_type:
|
||||
quota_unit = quota_configuration.quota_unit
|
||||
|
||||
if quota_configuration.quota_limit == -1:
|
||||
return
|
||||
|
||||
break
|
||||
|
||||
used_quota = None
|
||||
if quota_unit:
|
||||
if quota_unit == QuotaUnit.TOKENS:
|
||||
used_quota = message.message_tokens + message.answer_tokens
|
||||
elif quota_unit == QuotaUnit.CREDITS:
|
||||
used_quota = dify_config.get_model_credits(model_config.model)
|
||||
else:
|
||||
used_quota = 1
|
||||
|
||||
if used_quota is not None and system_configuration.current_quota_type is not None:
|
||||
db.session.query(Provider).filter(
|
||||
Provider.tenant_id == application_generate_entity.app_config.tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_config.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
).update(
|
||||
{
|
||||
"quota_used": Provider.quota_used + used_quota,
|
||||
"last_used": datetime.now(tz=UTC).replace(tzinfo=None),
|
||||
}
|
||||
)
|
||||
db.session.commit()
|
||||
|
|
@ -1,20 +0,0 @@
|
|||
from datetime import UTC, datetime
|
||||
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from models.provider import Provider
|
||||
|
||||
|
||||
@message_was_created.connect
|
||||
def handle(sender, **kwargs):
|
||||
application_generate_entity = kwargs.get("application_generate_entity")
|
||||
|
||||
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
|
||||
return
|
||||
|
||||
db.session.query(Provider).filter(
|
||||
Provider.tenant_id == application_generate_entity.app_config.tenant_id,
|
||||
Provider.provider_name == application_generate_entity.model_conf.provider,
|
||||
).update({"last_used": datetime.now(UTC).replace(tzinfo=None)})
|
||||
db.session.commit()
|
||||
|
|
@ -0,0 +1,234 @@
|
|||
import logging
|
||||
import time as time_module
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
|
||||
from core.entities.provider_entities import QuotaUnit, SystemConfiguration
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs import datetime_utils
|
||||
from models.model import Message
|
||||
from models.provider import Provider, ProviderType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _ProviderUpdateFilters(BaseModel):
|
||||
"""Filters for identifying Provider records to update."""
|
||||
|
||||
tenant_id: str
|
||||
provider_name: str
|
||||
provider_type: Optional[str] = None
|
||||
quota_type: Optional[str] = None
|
||||
|
||||
|
||||
class _ProviderUpdateAdditionalFilters(BaseModel):
|
||||
"""Additional filters for Provider updates."""
|
||||
|
||||
quota_limit_check: bool = False
|
||||
|
||||
|
||||
class _ProviderUpdateValues(BaseModel):
|
||||
"""Values to update in Provider records."""
|
||||
|
||||
last_used: Optional[datetime] = None
|
||||
quota_used: Optional[Any] = None # Can be Provider.quota_used + int expression
|
||||
|
||||
|
||||
class _ProviderUpdateOperation(BaseModel):
|
||||
"""A single Provider update operation."""
|
||||
|
||||
filters: _ProviderUpdateFilters
|
||||
values: _ProviderUpdateValues
|
||||
additional_filters: _ProviderUpdateAdditionalFilters = _ProviderUpdateAdditionalFilters()
|
||||
description: str = "unknown"
|
||||
|
||||
|
||||
@message_was_created.connect
|
||||
def handle(sender: Message, **kwargs):
|
||||
"""
|
||||
Consolidated handler for Provider updates when a message is created.
|
||||
|
||||
This handler replaces both:
|
||||
- update_provider_last_used_at_when_message_created
|
||||
- deduct_quota_when_message_created
|
||||
|
||||
By performing all Provider updates in a single transaction, we ensure
|
||||
consistency and efficiency when updating Provider records.
|
||||
"""
|
||||
message = sender
|
||||
application_generate_entity = kwargs.get("application_generate_entity")
|
||||
|
||||
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
|
||||
return
|
||||
|
||||
tenant_id = application_generate_entity.app_config.tenant_id
|
||||
provider_name = application_generate_entity.model_conf.provider
|
||||
current_time = datetime_utils.naive_utc_now()
|
||||
|
||||
# Prepare updates for both scenarios
|
||||
updates_to_perform: list[_ProviderUpdateOperation] = []
|
||||
|
||||
# 1. Always update last_used for the provider
|
||||
basic_update = _ProviderUpdateOperation(
|
||||
filters=_ProviderUpdateFilters(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_name,
|
||||
),
|
||||
values=_ProviderUpdateValues(last_used=current_time),
|
||||
description="basic_last_used_update",
|
||||
)
|
||||
updates_to_perform.append(basic_update)
|
||||
|
||||
# 2. Check if we need to deduct quota (system provider only)
|
||||
model_config = application_generate_entity.model_conf
|
||||
provider_model_bundle = model_config.provider_model_bundle
|
||||
provider_configuration = provider_model_bundle.configuration
|
||||
|
||||
if (
|
||||
provider_configuration.using_provider_type == ProviderType.SYSTEM
|
||||
and provider_configuration.system_configuration
|
||||
and provider_configuration.system_configuration.current_quota_type is not None
|
||||
):
|
||||
system_configuration = provider_configuration.system_configuration
|
||||
|
||||
# Calculate quota usage
|
||||
used_quota = _calculate_quota_usage(
|
||||
message=message,
|
||||
system_configuration=system_configuration,
|
||||
model_name=model_config.model,
|
||||
)
|
||||
|
||||
if used_quota is not None:
|
||||
quota_update = _ProviderUpdateOperation(
|
||||
filters=_ProviderUpdateFilters(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=ModelProviderID(model_config.provider).provider_name,
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
quota_type=provider_configuration.system_configuration.current_quota_type.value,
|
||||
),
|
||||
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
|
||||
additional_filters=_ProviderUpdateAdditionalFilters(
|
||||
quota_limit_check=True # Provider.quota_limit > Provider.quota_used
|
||||
),
|
||||
description="quota_deduction_update",
|
||||
)
|
||||
updates_to_perform.append(quota_update)
|
||||
|
||||
# Execute all updates
|
||||
start_time = time_module.perf_counter()
|
||||
try:
|
||||
_execute_provider_updates(updates_to_perform)
|
||||
|
||||
# Log successful completion with timing
|
||||
duration = time_module.perf_counter() - start_time
|
||||
|
||||
logger.info(
|
||||
f"Provider updates completed successfully. "
|
||||
f"Updates: {len(updates_to_perform)}, Duration: {duration:.3f}s, "
|
||||
f"Tenant: {tenant_id}, Provider: {provider_name}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Log failure with timing and context
|
||||
duration = time_module.perf_counter() - start_time
|
||||
|
||||
logger.exception(
|
||||
f"Provider updates failed after {duration:.3f}s. "
|
||||
f"Updates: {len(updates_to_perform)}, Tenant: {tenant_id}, "
|
||||
f"Provider: {provider_name}"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def _calculate_quota_usage(
|
||||
*, message: Message, system_configuration: SystemConfiguration, model_name: str
|
||||
) -> Optional[int]:
|
||||
"""Calculate quota usage based on message tokens and quota type."""
|
||||
quota_unit = None
|
||||
for quota_configuration in system_configuration.quota_configurations:
|
||||
if quota_configuration.quota_type == system_configuration.current_quota_type:
|
||||
quota_unit = quota_configuration.quota_unit
|
||||
if quota_configuration.quota_limit == -1:
|
||||
return None
|
||||
break
|
||||
if quota_unit is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
if quota_unit == QuotaUnit.TOKENS:
|
||||
tokens = message.message_tokens + message.answer_tokens
|
||||
return tokens
|
||||
if quota_unit == QuotaUnit.CREDITS:
|
||||
tokens = dify_config.get_model_credits(model_name)
|
||||
return tokens
|
||||
elif quota_unit == QuotaUnit.TIMES:
|
||||
return 1
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception("Failed to calculate quota usage")
|
||||
return None
|
||||
|
||||
|
||||
def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]):
|
||||
"""Execute all Provider updates in a single transaction."""
|
||||
if not updates_to_perform:
|
||||
return
|
||||
|
||||
# Use SQLAlchemy's context manager for transaction management
|
||||
# This automatically handles commit/rollback
|
||||
with Session(db.engine) as session:
|
||||
# Use a single transaction for all updates
|
||||
for update_operation in updates_to_perform:
|
||||
filters = update_operation.filters
|
||||
values = update_operation.values
|
||||
additional_filters = update_operation.additional_filters
|
||||
description = update_operation.description
|
||||
|
||||
# Build the where conditions
|
||||
where_conditions = [
|
||||
Provider.tenant_id == filters.tenant_id,
|
||||
Provider.provider_name == filters.provider_name,
|
||||
]
|
||||
|
||||
# Add additional filters if specified
|
||||
if filters.provider_type is not None:
|
||||
where_conditions.append(Provider.provider_type == filters.provider_type)
|
||||
if filters.quota_type is not None:
|
||||
where_conditions.append(Provider.quota_type == filters.quota_type)
|
||||
if additional_filters.quota_limit_check:
|
||||
where_conditions.append(Provider.quota_limit > Provider.quota_used)
|
||||
|
||||
# Prepare values dict for SQLAlchemy update
|
||||
update_values = {}
|
||||
if values.last_used is not None:
|
||||
update_values["last_used"] = values.last_used
|
||||
if values.quota_used is not None:
|
||||
update_values["quota_used"] = values.quota_used
|
||||
|
||||
# Build and execute the update statement
|
||||
stmt = update(Provider).where(*where_conditions).values(**update_values)
|
||||
result = session.execute(stmt)
|
||||
rows_affected = result.rowcount
|
||||
|
||||
logger.debug(
|
||||
f"Provider update ({description}): {rows_affected} rows affected. "
|
||||
f"Filters: {filters.model_dump()}, Values: {update_values}"
|
||||
)
|
||||
|
||||
# If no rows were affected for quota updates, log a warning
|
||||
if rows_affected == 0 and description == "quota_deduction_update":
|
||||
logger.warning(
|
||||
f"No Provider rows updated for quota deduction. "
|
||||
f"This may indicate quota limit exceeded or provider not found. "
|
||||
f"Filters: {filters.model_dump()}"
|
||||
)
|
||||
|
||||
logger.debug(f"Successfully processed {len(updates_to_perform)} Provider updates")
|
||||
|
|
@ -5,6 +5,7 @@ from typing import Any, cast
|
|||
|
||||
import httpx
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
|
||||
from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers
|
||||
|
|
@ -91,6 +92,8 @@ def build_from_mappings(
|
|||
tenant_id: str,
|
||||
strict_type_validation: bool = False,
|
||||
) -> Sequence[File]:
|
||||
# TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query.
|
||||
# Implement batch processing to reduce database load when handling multiple files.
|
||||
files = [
|
||||
build_from_mapping(
|
||||
mapping=mapping,
|
||||
|
|
@ -377,3 +380,75 @@ def _get_file_type_by_mimetype(mime_type: str) -> FileType | None:
|
|||
|
||||
def get_file_type_by_mime_type(mime_type: str) -> FileType:
|
||||
return _get_file_type_by_mimetype(mime_type) or FileType.CUSTOM
|
||||
|
||||
|
||||
class StorageKeyLoader:
|
||||
"""FileKeyLoader load the storage key from database for a list of files.
|
||||
This loader is batched, the database query count is constant regardless of the input size.
|
||||
"""
|
||||
|
||||
def __init__(self, session: Session, tenant_id: str) -> None:
|
||||
self._session = session
|
||||
self._tenant_id = tenant_id
|
||||
|
||||
def _load_upload_files(self, upload_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, UploadFile]:
|
||||
stmt = select(UploadFile).where(
|
||||
UploadFile.id.in_(upload_file_ids),
|
||||
UploadFile.tenant_id == self._tenant_id,
|
||||
)
|
||||
|
||||
return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)}
|
||||
|
||||
def _load_tool_files(self, tool_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, ToolFile]:
|
||||
stmt = select(ToolFile).where(
|
||||
ToolFile.id.in_(tool_file_ids),
|
||||
ToolFile.tenant_id == self._tenant_id,
|
||||
)
|
||||
return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)}
|
||||
|
||||
def load_storage_keys(self, files: Sequence[File]):
|
||||
"""Loads storage keys for a sequence of files by retrieving the corresponding
|
||||
`UploadFile` or `ToolFile` records from the database based on their transfer method.
|
||||
|
||||
This method doesn't modify the input sequence structure but updates the `_storage_key`
|
||||
property of each file object by extracting the relevant key from its database record.
|
||||
|
||||
Performance note: This is a batched operation where database query count remains constant
|
||||
regardless of input size. However, for optimal performance, input sequences should contain
|
||||
fewer than 1000 files. For larger collections, split into smaller batches and process each
|
||||
batch separately.
|
||||
"""
|
||||
|
||||
upload_file_ids: list[uuid.UUID] = []
|
||||
tool_file_ids: list[uuid.UUID] = []
|
||||
for file in files:
|
||||
related_model_id = file.related_id
|
||||
if file.related_id is None:
|
||||
raise ValueError("file id should not be None.")
|
||||
if file.tenant_id != self._tenant_id:
|
||||
err_msg = (
|
||||
f"invalid file, expected tenant_id={self._tenant_id}, "
|
||||
f"got tenant_id={file.tenant_id}, file_id={file.id}, related_model_id={related_model_id}"
|
||||
)
|
||||
raise ValueError(err_msg)
|
||||
model_id = uuid.UUID(related_model_id)
|
||||
|
||||
if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL):
|
||||
upload_file_ids.append(model_id)
|
||||
elif file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
tool_file_ids.append(model_id)
|
||||
|
||||
tool_files = self._load_tool_files(tool_file_ids)
|
||||
upload_files = self._load_upload_files(upload_file_ids)
|
||||
for file in files:
|
||||
model_id = uuid.UUID(file.related_id)
|
||||
if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL):
|
||||
upload_file_row = upload_files.get(model_id)
|
||||
if upload_file_row is None:
|
||||
raise ValueError(f"Upload file not found for id: {model_id}")
|
||||
file._storage_key = upload_file_row.key
|
||||
elif file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
tool_file_row = tool_files.get(model_id)
|
||||
if tool_file_row is None:
|
||||
raise ValueError(f"Tool file not found for id: {model_id}")
|
||||
file._storage_key = tool_file_row.file_key
|
||||
|
|
|
|||
|
|
@ -43,6 +43,10 @@ class UnsupportedSegmentTypeError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
class TypeMismatchError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# Define the constant
|
||||
SEGMENT_TO_VARIABLE_MAP = {
|
||||
StringSegment: StringVariable,
|
||||
|
|
@ -110,6 +114,10 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
|
|||
return cast(Variable, result)
|
||||
|
||||
|
||||
def infer_segment_type_from_value(value: Any, /) -> SegmentType:
|
||||
return build_segment(value).value_type
|
||||
|
||||
|
||||
def build_segment(value: Any, /) -> Segment:
|
||||
if value is None:
|
||||
return NoneSegment()
|
||||
|
|
@ -140,10 +148,80 @@ def build_segment(value: Any, /) -> Segment:
|
|||
case SegmentType.NONE:
|
||||
return ArrayAnySegment(value=value)
|
||||
case _:
|
||||
# This should be unreachable.
|
||||
raise ValueError(f"not supported value {value}")
|
||||
raise ValueError(f"not supported value {value}")
|
||||
|
||||
|
||||
def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
|
||||
"""
|
||||
Build a segment with explicit type checking.
|
||||
|
||||
This function creates a segment from a value while enforcing type compatibility
|
||||
with the specified segment_type. It provides stricter type validation compared
|
||||
to the standard build_segment function.
|
||||
|
||||
Args:
|
||||
segment_type: The expected SegmentType for the resulting segment
|
||||
value: The value to be converted into a segment
|
||||
|
||||
Returns:
|
||||
Segment: A segment instance of the appropriate type
|
||||
|
||||
Raises:
|
||||
TypeMismatchError: If the value type doesn't match the expected segment_type
|
||||
|
||||
Special Cases:
|
||||
- For empty list [] values, if segment_type is array[*], returns the corresponding array type
|
||||
- Type validation is performed before segment creation
|
||||
|
||||
Examples:
|
||||
>>> build_segment_with_type(SegmentType.STRING, "hello")
|
||||
StringSegment(value="hello")
|
||||
|
||||
>>> build_segment_with_type(SegmentType.ARRAY_STRING, [])
|
||||
ArrayStringSegment(value=[])
|
||||
|
||||
>>> build_segment_with_type(SegmentType.STRING, 123)
|
||||
# Raises TypeMismatchError
|
||||
"""
|
||||
# Handle None values
|
||||
if value is None:
|
||||
if segment_type == SegmentType.NONE:
|
||||
return NoneSegment()
|
||||
else:
|
||||
raise TypeMismatchError(f"Expected {segment_type}, but got None")
|
||||
|
||||
# Handle empty list special case for array types
|
||||
if isinstance(value, list) and len(value) == 0:
|
||||
if segment_type == SegmentType.ARRAY_ANY:
|
||||
return ArrayAnySegment(value=value)
|
||||
elif segment_type == SegmentType.ARRAY_STRING:
|
||||
return ArrayStringSegment(value=value)
|
||||
elif segment_type == SegmentType.ARRAY_NUMBER:
|
||||
return ArrayNumberSegment(value=value)
|
||||
elif segment_type == SegmentType.ARRAY_OBJECT:
|
||||
return ArrayObjectSegment(value=value)
|
||||
elif segment_type == SegmentType.ARRAY_FILE:
|
||||
return ArrayFileSegment(value=value)
|
||||
else:
|
||||
raise TypeMismatchError(f"Expected {segment_type}, but got empty list")
|
||||
|
||||
# Build segment using existing logic to infer actual type
|
||||
inferred_segment = build_segment(value)
|
||||
inferred_type = inferred_segment.value_type
|
||||
|
||||
# Type compatibility checking
|
||||
if inferred_type == segment_type:
|
||||
return inferred_segment
|
||||
|
||||
# Type mismatch - raise error with descriptive message
|
||||
raise TypeMismatchError(
|
||||
f"Type mismatch: expected {segment_type}, but value '{value}' "
|
||||
f"(type: {type(value).__name__}) corresponds to {inferred_type}"
|
||||
)
|
||||
|
||||
|
||||
def segment_to_variable(
|
||||
*,
|
||||
segment: Segment,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,22 @@
|
|||
import abc
|
||||
import datetime
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class _NowFunction(Protocol):
|
||||
@abc.abstractmethod
|
||||
def __call__(self, tz: datetime.timezone | None) -> datetime.datetime:
|
||||
pass
|
||||
|
||||
|
||||
# _now_func is a callable with the _NowFunction signature.
|
||||
# Its sole purpose is to abstract time retrieval, enabling
|
||||
# developers to mock this behavior in tests and time-dependent scenarios.
|
||||
_now_func: _NowFunction = datetime.datetime.now
|
||||
|
||||
|
||||
def naive_utc_now() -> datetime.datetime:
|
||||
"""Return a naive datetime object (without timezone information)
|
||||
representing current UTC time.
|
||||
"""
|
||||
return _now_func(datetime.UTC).replace(tzinfo=None)
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
import json
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PydanticModelEncoder(json.JSONEncoder):
|
||||
def default(self, o):
|
||||
if isinstance(o, BaseModel):
|
||||
return o.model_dump()
|
||||
else:
|
||||
super().default(o)
|
||||
|
|
@ -22,7 +22,11 @@ class SMTPClient:
|
|||
if self.use_tls:
|
||||
if self.opportunistic_tls:
|
||||
smtp = smtplib.SMTP(self.server, self.port, timeout=10)
|
||||
# Send EHLO command with the HELO domain name as the server address
|
||||
smtp.ehlo(self.server)
|
||||
smtp.starttls()
|
||||
# Resend EHLO command to identify the TLS session
|
||||
smtp.ehlo(self.server)
|
||||
else:
|
||||
smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,20 @@
|
|||
"""All these exceptions are not meant to be caught by callers."""
|
||||
|
||||
|
||||
class WorkflowDataError(Exception):
|
||||
"""Base class for all workflow data related exceptions.
|
||||
|
||||
This should be used to indicate issues with workflow data integrity, such as
|
||||
no `graph` configuration, missing `nodes` field in `graph` configuration, or
|
||||
similar issues.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class NodeNotFoundError(WorkflowDataError):
|
||||
"""Raised when a node with the specified ID is not found in the workflow."""
|
||||
|
||||
def __init__(self, node_id: str):
|
||||
super().__init__(f"Node with ID '{node_id}' not found in the workflow.")
|
||||
self.node_id = node_id
|
||||
|
|
@ -611,6 +611,14 @@ class InstalledApp(Base):
|
|||
return tenant
|
||||
|
||||
|
||||
class ConversationSource(StrEnum):
|
||||
"""This enumeration is designed for use with `Conversation.from_source`."""
|
||||
|
||||
# NOTE(QuantumGhost): The enumeration members may not cover all possible cases.
|
||||
API = "api"
|
||||
CONSOLE = "console"
|
||||
|
||||
|
||||
class Conversation(Base):
|
||||
__tablename__ = "conversations"
|
||||
__table_args__ = (
|
||||
|
|
@ -632,7 +640,14 @@ class Conversation(Base):
|
|||
system_instruction = db.Column(db.Text)
|
||||
system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
status = db.Column(db.String(255), nullable=False)
|
||||
|
||||
# The `invoke_from` records how the conversation is created.
|
||||
#
|
||||
# Its value corresponds to the members of `InvokeFrom`.
|
||||
# (api/core/app/entities/app_invoke_entities.py)
|
||||
invoke_from = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# ref: ConversationSource.
|
||||
from_source = db.Column(db.String(255), nullable=False)
|
||||
from_end_user_id = db.Column(StringUUID)
|
||||
from_account_id = db.Column(StringUUID)
|
||||
|
|
@ -703,7 +718,6 @@ class Conversation(Base):
|
|||
if "model" in override_model_configs:
|
||||
app_model_config = AppModelConfig()
|
||||
app_model_config = app_model_config.from_model_config_dict(override_model_configs)
|
||||
assert app_model_config is not None, "app model config not found"
|
||||
model_config = app_model_config.to_dict()
|
||||
else:
|
||||
model_config["configs"] = override_model_configs
|
||||
|
|
@ -817,7 +831,12 @@ class Conversation(Base):
|
|||
|
||||
@property
|
||||
def first_message(self):
|
||||
return db.session.query(Message).filter(Message.conversation_id == self.id).first()
|
||||
return (
|
||||
db.session.query(Message)
|
||||
.filter(Message.conversation_id == self.id)
|
||||
.order_by(Message.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
|
||||
@property
|
||||
def app(self):
|
||||
|
|
@ -894,11 +913,11 @@ class Message(Base):
|
|||
_inputs: Mapped[dict] = mapped_column("inputs", db.JSON)
|
||||
query: Mapped[str] = db.Column(db.Text, nullable=False)
|
||||
message = db.Column(db.JSON, nullable=False)
|
||||
message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
message_tokens: Mapped[int] = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
message_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
|
||||
message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
|
||||
answer: Mapped[str] = db.Column(db.Text, nullable=False)
|
||||
answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
answer_tokens: Mapped[int] = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
|
||||
answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
|
||||
parent_message_id = db.Column(StringUUID, nullable=True)
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue