mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into feat/webapp-verified-sso-main
This commit is contained in:
commit
c4821a333f
|
|
@ -1,12 +1,13 @@
|
|||
#!/bin/bash
|
||||
|
||||
npm add -g pnpm@10.8.0
|
||||
npm add -g pnpm@10.11.1
|
||||
cd web && pnpm install
|
||||
pipx install uv
|
||||
|
||||
echo 'alias start-api="cd /workspaces/dify/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
|
||||
echo 'alias start-worker="cd /workspaces/dify/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
|
||||
echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
|
||||
echo 'alias start-web-prod="cd /workspaces/dify/web && pnpm build && pnpm start"' >> ~/.bashrc
|
||||
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d"' >> ~/.bashrc
|
||||
echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down"' >> ~/.bashrc
|
||||
|
||||
|
|
|
|||
|
|
@ -1,25 +1,23 @@
|
|||
# Summary
|
||||
> [!IMPORTANT]
|
||||
>
|
||||
> 1. Make sure you have read our [contribution guidelines](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)
|
||||
> 2. Ensure there is an associated issue and you have been assigned to it
|
||||
> 3. Use the correct syntax to link this PR: `Fixes #<issue number>`.
|
||||
|
||||
Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
|
||||
## Summary
|
||||
|
||||
> [!Tip]
|
||||
> Close issue syntax: `Fixes #<issue number>` or `Resolves #<issue number>`, see [documentation](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue#linking-a-pull-request-to-an-issue-using-a-keyword) for more details.
|
||||
<!-- Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change. -->
|
||||
|
||||
|
||||
# Screenshots
|
||||
## Screenshots
|
||||
|
||||
| Before | After |
|
||||
|--------|-------|
|
||||
| ... | ... |
|
||||
|
||||
# Checklist
|
||||
|
||||
> [!IMPORTANT]
|
||||
> Please review the checklist below before submitting your pull request.
|
||||
## Checklist
|
||||
|
||||
- [ ] This change requires a documentation update, included: [Dify Document](https://github.com/langgenius/dify-docs)
|
||||
- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
|
||||
- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
|
||||
- [x] I've updated the documentation accordingly.
|
||||
- [x] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods
|
||||
|
||||
|
|
|
|||
|
|
@ -31,11 +31,19 @@ jobs:
|
|||
echo "FILES_CHANGED=false" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
version: 10
|
||||
run_install: false
|
||||
|
||||
- name: Set up Node.js
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 'lts/*'
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/package.json
|
||||
|
||||
- name: Install dependencies
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
|
|
|
|||
|
|
@ -235,7 +235,7 @@ At the same time, please consider supporting Dify by sharing it on social media
|
|||
|
||||
## Community & contact
|
||||
|
||||
- [Github Discussion](https://github.com/langgenius/dify/discussions). Best for: sharing feedback and asking questions.
|
||||
- [GitHub Discussion](https://github.com/langgenius/dify/discussions). Best for: sharing feedback and asking questions.
|
||||
- [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs you encounter using Dify.AI, and feature proposals. See our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
- [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community.
|
||||
- [X(Twitter)](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community.
|
||||
|
|
|
|||
|
|
@ -223,7 +223,7 @@ docker compose up -d
|
|||
</a>
|
||||
|
||||
## المجتمع والاتصال
|
||||
- [مناقشة Github](https://github.com/langgenius/dify/discussions). الأفضل لـ: مشاركة التعليقات وطرح الأسئلة.
|
||||
- [مناقشة GitHub](https://github.com/langgenius/dify/discussions). الأفضل لـ: مشاركة التعليقات وطرح الأسئلة.
|
||||
- [المشكلات على GitHub](https://github.com/langgenius/dify/issues). الأفضل لـ: الأخطاء التي تواجهها في استخدام Dify.AI، واقتراحات الميزات. انظر [دليل المساهمة](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
- [Discord](https://discord.gg/FngNHpbcY7). الأفضل لـ: مشاركة تطبيقاتك والترفيه مع المجتمع.
|
||||
- [تويتر](https://twitter.com/dify_ai). الأفضل لـ: مشاركة تطبيقاتك والترفيه مع المجتمع.
|
||||
|
|
|
|||
|
|
@ -234,7 +234,7 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন
|
|||
|
||||
## কমিউনিটি এবং যোগাযোগ
|
||||
|
||||
- [Github Discussion](https://github.com/langgenius/dify/discussions) ফিডব্যাক এবং প্রতিক্রিয়া জানানোর মাধ্যম।
|
||||
- [GitHub Discussion](https://github.com/langgenius/dify/discussions) ফিডব্যাক এবং প্রতিক্রিয়া জানানোর মাধ্যম।
|
||||
- [GitHub Issues](https://github.com/langgenius/dify/issues). Dify.AI ব্যবহার করে আপনি যেসব বাগের সম্মুখীন হন এবং ফিচার প্রস্তাবনা। আমাদের [অবদান নির্দেশিকা](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) দেখুন।
|
||||
- [Discord](https://discord.gg/FngNHpbcY7) আপনার এপ্লিকেশন শেয়ার এবং কমিউনিটি আড্ডার মাধ্যম।
|
||||
- [X(Twitter)](https://twitter.com/dify_ai) আপনার এপ্লিকেশন শেয়ার এবং কমিউনিটি আড্ডার মাধ্যম।
|
||||
|
|
|
|||
|
|
@ -243,7 +243,7 @@ docker compose up -d
|
|||
|
||||
我们欢迎您为 Dify 做出贡献,以帮助改善 Dify。包括:提交代码、问题、新想法,或分享您基于 Dify 创建的有趣且有用的 AI 应用程序。同时,我们也欢迎您在不同的活动、会议和社交媒体上分享 Dify。
|
||||
|
||||
- [Github Discussion](https://github.com/langgenius/dify/discussions). 👉:分享您的应用程序并与社区交流。
|
||||
- [GitHub Discussion](https://github.com/langgenius/dify/discussions). 👉:分享您的应用程序并与社区交流。
|
||||
- [GitHub Issues](https://github.com/langgenius/dify/issues)。👉:使用 Dify.AI 时遇到的错误和问题,请参阅[贡献指南](CONTRIBUTING.md)。
|
||||
- [电子邮件支持](mailto:hello@dify.ai?subject=[GitHub]Questions%20About%20Dify)。👉:关于使用 Dify.AI 的问题。
|
||||
- [Discord](https://discord.gg/FngNHpbcY7)。👉:分享您的应用程序并与社区交流。
|
||||
|
|
|
|||
|
|
@ -230,7 +230,7 @@ Falls Sie Code beitragen möchten, lesen Sie bitte unseren [Contribution Guide](
|
|||
|
||||
## Gemeinschaft & Kontakt
|
||||
|
||||
* [Github Discussion](https://github.com/langgenius/dify/discussions). Am besten geeignet für: den Austausch von Feedback und das Stellen von Fragen.
|
||||
* [GitHub Discussion](https://github.com/langgenius/dify/discussions). Am besten geeignet für: den Austausch von Feedback und das Stellen von Fragen.
|
||||
* [GitHub Issues](https://github.com/langgenius/dify/issues). Am besten für: Fehler, auf die Sie bei der Verwendung von Dify.AI stoßen, und Funktionsvorschläge. Siehe unseren [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
* [Discord](https://discord.gg/FngNHpbcY7). Am besten geeignet für: das Teilen Ihrer Anwendungen und den Austausch mit der Community.
|
||||
* [X(Twitter)](https://twitter.com/dify_ai). Am besten geeignet für: das Teilen Ihrer Anwendungen und den Austausch mit der Community.
|
||||
|
|
|
|||
|
|
@ -236,7 +236,7 @@ docker compose up -d
|
|||
|
||||
## コミュニティ & お問い合わせ
|
||||
|
||||
* [Github Discussion](https://github.com/langgenius/dify/discussions). 主に: フィードバックの共有や質問。
|
||||
* [GitHub Discussion](https://github.com/langgenius/dify/discussions). 主に: フィードバックの共有や質問。
|
||||
* [GitHub Issues](https://github.com/langgenius/dify/issues). 主に: Dify.AIを使用する際に発生するエラーや問題については、[貢献ガイド](CONTRIBUTING_JA.md)を参照してください
|
||||
* [Discord](https://discord.gg/FngNHpbcY7). 主に: アプリケーションの共有やコミュニティとの交流。
|
||||
* [X(Twitter)](https://twitter.com/dify_ai). 主に: アプリケーションの共有やコミュニティとの交流。
|
||||
|
|
|
|||
|
|
@ -235,7 +235,7 @@ At the same time, please consider supporting Dify by sharing it on social media
|
|||
|
||||
## Community & Contact
|
||||
|
||||
* [Github Discussion](https://github.com/langgenius/dify/discussions
|
||||
* [GitHub Discussion](https://github.com/langgenius/dify/discussions). Best for: sharing feedback and asking questions.
|
||||
* [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs you encounter using Dify.AI, and feature proposals. See our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
|
|
|
|||
|
|
@ -229,7 +229,7 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했
|
|||
|
||||
## 커뮤니티 & 연락처
|
||||
|
||||
* [Github 토론](https://github.com/langgenius/dify/discussions). 피드백 공유 및 질문하기에 적합합니다.
|
||||
* [GitHub 토론](https://github.com/langgenius/dify/discussions). 피드백 공유 및 질문하기에 적합합니다.
|
||||
* [GitHub 이슈](https://github.com/langgenius/dify/issues). Dify.AI 사용 중 발견한 버그와 기능 제안에 적합합니다. [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요.
|
||||
* [디스코드](https://discord.gg/FngNHpbcY7). 애플리케이션 공유 및 커뮤니티와 소통하기에 적합합니다.
|
||||
* [트위터](https://twitter.com/dify_ai). 애플리케이션 공유 및 커뮤니티와 소통하기에 적합합니다.
|
||||
|
|
|
|||
|
|
@ -229,7 +229,7 @@ Za tiste, ki bi radi prispevali kodo, si oglejte naš vodnik za prispevke . Hkra
|
|||
|
||||
## Skupnost in stik
|
||||
|
||||
* [Github Discussion](https://github.com/langgenius/dify/discussions). Najboljše za: izmenjavo povratnih informacij in postavljanje vprašanj.
|
||||
* [GitHub Discussion](https://github.com/langgenius/dify/discussions). Najboljše za: izmenjavo povratnih informacij in postavljanje vprašanj.
|
||||
* [GitHub Issues](https://github.com/langgenius/dify/issues). Najboljše za: hrošče, na katere naletite pri uporabi Dify.AI, in predloge funkcij. Oglejte si naš [vodnik za prispevke](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
* [Discord](https://discord.gg/FngNHpbcY7). Najboljše za: deljenje vaših aplikacij in druženje s skupnostjo.
|
||||
* [X(Twitter)](https://twitter.com/dify_ai). Najboljše za: deljenje vaših aplikacij in druženje s skupnostjo.
|
||||
|
|
|
|||
|
|
@ -227,7 +227,7 @@ Aynı zamanda, lütfen Dify'ı sosyal medyada, etkinliklerde ve konferanslarda p
|
|||
|
||||
## Topluluk & iletişim
|
||||
|
||||
* [Github Tartışmaları](https://github.com/langgenius/dify/discussions). En uygun: geri bildirim paylaşmak ve soru sormak için.
|
||||
* [GitHub Tartışmaları](https://github.com/langgenius/dify/discussions). En uygun: geri bildirim paylaşmak ve soru sormak için.
|
||||
* [GitHub Sorunları](https://github.com/langgenius/dify/issues). En uygun: Dify.AI kullanırken karşılaştığınız hatalar ve özellik önerileri için. [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) bakın.
|
||||
* [Discord](https://discord.gg/FngNHpbcY7). En uygun: uygulamalarınızı paylaşmak ve toplulukla vakit geçirmek için.
|
||||
* [X(Twitter)](https://twitter.com/dify_ai). En uygun: uygulamalarınızı paylaşmak ve toplulukla vakit geçirmek için.
|
||||
|
|
|
|||
|
|
@ -233,7 +233,7 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify
|
|||
|
||||
## 社群與聯絡方式
|
||||
|
||||
- [Github Discussion](https://github.com/langgenius/dify/discussions):最適合分享反饋和提問。
|
||||
- [GitHub Discussion](https://github.com/langgenius/dify/discussions):最適合分享反饋和提問。
|
||||
- [GitHub Issues](https://github.com/langgenius/dify/issues):最適合報告使用 Dify.AI 時遇到的問題和提出功能建議。請參閱我們的[貢獻指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。
|
||||
- [Discord](https://discord.gg/FngNHpbcY7):最適合分享您的應用程式並與社群互動。
|
||||
- [X(Twitter)](https://twitter.com/dify_ai):最適合分享您的應用程式並與社群互動。
|
||||
|
|
|
|||
|
|
@ -152,6 +152,7 @@ QDRANT_API_KEY=difyai123456
|
|||
QDRANT_CLIENT_TIMEOUT=20
|
||||
QDRANT_GRPC_ENABLED=false
|
||||
QDRANT_GRPC_PORT=6334
|
||||
QDRANT_REPLICATION_FACTOR=1
|
||||
|
||||
#Couchbase configuration
|
||||
COUCHBASE_CONNECTION_STRING=127.0.0.1
|
||||
|
|
|
|||
|
|
@ -846,6 +846,9 @@ def clear_orphaned_file_records(force: bool):
|
|||
{"type": "text", "table": "workflow_node_executions", "column": "outputs"},
|
||||
{"type": "text", "table": "conversations", "column": "introduction"},
|
||||
{"type": "text", "table": "conversations", "column": "system_instruction"},
|
||||
{"type": "text", "table": "accounts", "column": "avatar"},
|
||||
{"type": "text", "table": "apps", "column": "icon"},
|
||||
{"type": "text", "table": "sites", "column": "icon"},
|
||||
{"type": "json", "table": "messages", "column": "inputs"},
|
||||
{"type": "json", "table": "messages", "column": "message"},
|
||||
]
|
||||
|
|
|
|||
|
|
@ -33,3 +33,8 @@ class QdrantConfig(BaseSettings):
|
|||
description="Port number for gRPC connection to Qdrant server (default is 6334)",
|
||||
default=6334,
|
||||
)
|
||||
|
||||
QDRANT_REPLICATION_FACTOR: PositiveInt = Field(
|
||||
description="Replication factor for Qdrant collections (default is 1)",
|
||||
default=1,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
|
|||
|
||||
CURRENT_VERSION: str = Field(
|
||||
description="Dify version",
|
||||
default="1.4.0",
|
||||
default="1.4.1",
|
||||
)
|
||||
|
||||
COMMIT_SHA: str = Field(
|
||||
|
|
|
|||
|
|
@ -60,8 +60,7 @@ class NacosHttpClient:
|
|||
sign_str = tenant + "+"
|
||||
if group:
|
||||
sign_str = sign_str + group + "+"
|
||||
if sign_str:
|
||||
sign_str += ts
|
||||
sign_str += ts # Directly concatenate ts without conditional checks, because the nacos auth header forced it.
|
||||
return sign_str
|
||||
|
||||
def get_access_token(self, force_refresh=False):
|
||||
|
|
|
|||
|
|
@ -11,10 +11,6 @@ if TYPE_CHECKING:
|
|||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
|
||||
tenant_id: ContextVar[str] = ContextVar("tenant_id")
|
||||
|
||||
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
|
||||
|
||||
"""
|
||||
To avoid race conditions caused by gunicorn thread recycling, use RecyclableContextVar in place of a plain ContextVar.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -6,12 +6,12 @@ from sqlalchemy.orm import Session
|
|||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRunStatus
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
|
||||
|
|
@ -38,7 +38,7 @@ class WorkflowAppLogApi(Resource):
|
|||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
args.status = WorkflowRunStatus(args.status) if args.status else None
|
||||
args.status = WorkflowExecutionStatus(args.status) if args.status else None
|
||||
if args.created_at__before:
|
||||
args.created_at__before = isoparse(args.created_at__before)
|
||||
|
||||
|
|
|
|||
|
|
@ -41,12 +41,16 @@ class PluginListApi(Resource):
|
|||
@account_initialization_required
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=int, required=False, location="args", default=1)
|
||||
parser.add_argument("page_size", type=int, required=False, location="args", default=256)
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
plugins = PluginService.list(tenant_id)
|
||||
plugins_with_total = PluginService.list_with_total(tenant_id, args["page"], args["page_size"])
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder({"plugins": plugins})
|
||||
return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
|
||||
|
||||
|
||||
class PluginListLatestVersionsApi(Resource):
|
||||
|
|
|
|||
|
|
@ -2,12 +2,14 @@ from collections.abc import Callable
|
|||
from functools import wraps
|
||||
from typing import Optional
|
||||
|
||||
from flask import request
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
from flask_restful import reqparse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.login import _get_user
|
||||
from models.account import Account, Tenant
|
||||
from models.model import EndUser
|
||||
from services.account_service import AccountService
|
||||
|
|
@ -80,7 +82,12 @@ def get_user_tenant(view: Optional[Callable] = None):
|
|||
raise ValueError("tenant not found")
|
||||
|
||||
kwargs["tenant_model"] = tenant_model
|
||||
kwargs["user_model"] = get_user(tenant_id, user_id)
|
||||
|
||||
user = get_user(tenant_id, user_id)
|
||||
kwargs["user_model"] = user
|
||||
|
||||
current_app.login_manager._update_request_context_with_user(user) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,19 +3,19 @@ from flask_restful import Resource, marshal, marshal_with, reqparse
|
|||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.annotation_fields import (
|
||||
annotation_fields,
|
||||
)
|
||||
from libs.login import current_user
|
||||
from models.model import App, EndUser
|
||||
from models.model import App
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
|
||||
class AnnotationReplyActionApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
def post(self, app_model: App, end_user: EndUser, action):
|
||||
@validate_app_token
|
||||
def post(self, app_model: App, action):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("score_threshold", required=True, type=float, location="json")
|
||||
parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
|
||||
|
|
@ -31,8 +31,8 @@ class AnnotationReplyActionApi(Resource):
|
|||
|
||||
|
||||
class AnnotationReplyActionStatusApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
def get(self, app_model: App, end_user: EndUser, job_id, action):
|
||||
@validate_app_token
|
||||
def get(self, app_model: App, job_id, action):
|
||||
job_id = str(job_id)
|
||||
app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
|
||||
cache_result = redis_client.get(app_annotation_job_key)
|
||||
|
|
@ -49,8 +49,8 @@ class AnnotationReplyActionStatusApi(Resource):
|
|||
|
||||
|
||||
class AnnotationListApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
@validate_app_token
|
||||
def get(self, app_model: App):
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
keyword = request.args.get("keyword", default="", type=str)
|
||||
|
|
@ -65,9 +65,9 @@ class AnnotationListApi(Resource):
|
|||
}
|
||||
return response, 200
|
||||
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
@validate_app_token
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
def post(self, app_model: App):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("question", required=True, type=str, location="json")
|
||||
parser.add_argument("answer", required=True, type=str, location="json")
|
||||
|
|
@ -77,9 +77,9 @@ class AnnotationListApi(Resource):
|
|||
|
||||
|
||||
class AnnotationUpdateDeleteApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
@validate_app_token
|
||||
@marshal_with(annotation_fields)
|
||||
def put(self, app_model: App, end_user: EndUser, annotation_id):
|
||||
def put(self, app_model: App, annotation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -91,8 +91,8 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
|
||||
return annotation
|
||||
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
def delete(self, app_model: App, end_user: EndUser, annotation_id):
|
||||
@validate_app_token
|
||||
def delete(self, app_model: App, annotation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
|
|||
|
|
@ -24,12 +24,13 @@ from core.errors.error import (
|
|||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField
|
||||
from models.model import App, AppMode, EndUser
|
||||
from models.workflow import WorkflowRun, WorkflowRunStatus
|
||||
from models.workflow import WorkflowRun
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
|
|
@ -138,7 +139,7 @@ class WorkflowAppLogApi(Resource):
|
|||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
args.status = WorkflowRunStatus(args.status) if args.status else None
|
||||
args.status = WorkflowExecutionStatus(args.status) if args.status else None
|
||||
if args.created_at__before:
|
||||
args.created_at__before = isoparse(args.created_at__before)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,19 +1,21 @@
|
|||
from flask import request
|
||||
from flask_restful import marshal, reqparse
|
||||
from flask_restful import marshal, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services.dataset_service
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
from controllers.service_api.wraps import DatasetApiResource, validate_dataset_token
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.provider_manager import ProviderManager
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from fields.tag_fields import tag_fields
|
||||
from libs.login import current_user
|
||||
from models.dataset import Dataset, DatasetPermissionEnum
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||
from services.tag_service import TagService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
|
|
@ -320,5 +322,134 @@ class DatasetApi(DatasetApiResource):
|
|||
raise DatasetInUseError()
|
||||
|
||||
|
||||
class DatasetTagsApi(DatasetApiResource):
    """Service-API resource for listing, creating, updating and deleting "knowledge" type tags."""

    @validate_dataset_token
    @marshal_with(tag_fields)
    def get(self, _, dataset_id):
        """Get all knowledge type tags of the current user's tenant."""
        knowledge_tags = TagService.get_tags("knowledge", current_user.current_tenant_id)
        return knowledge_tags, 200

    @validate_dataset_token
    def post(self, _, dataset_id):
        """Add a knowledge type tag.

        Raises:
            Forbidden: if the current user is neither an editor nor a dataset editor.
        """
        if not current_user.is_editor and not current_user.is_dataset_editor:
            raise Forbidden()

        name_parser = reqparse.RequestParser()
        name_parser.add_argument(
            "name",
            nullable=False,
            required=True,
            help="Name must be between 1 to 50 characters.",
            type=DatasetTagsApi._validate_tag_name,
        )
        payload = name_parser.parse_args()
        # The tag type is fixed by this endpoint, never taken from the request.
        payload["type"] = "knowledge"

        created = TagService.save_tags(payload)
        # A tag that was just created is reported with zero bindings.
        body = {"id": created.id, "name": created.name, "type": created.type, "binding_count": 0}
        return body, 200

    @validate_dataset_token
    def patch(self, _, dataset_id):
        """Update the tag identified by the request's `tag_id` and report its binding count.

        Raises:
            Forbidden: if the current user is neither an editor nor a dataset editor.
        """
        if not current_user.is_editor and not current_user.is_dataset_editor:
            raise Forbidden()

        update_parser = reqparse.RequestParser()
        update_parser.add_argument(
            "name",
            nullable=False,
            required=True,
            help="Name must be between 1 to 50 characters.",
            type=DatasetTagsApi._validate_tag_name,
        )
        update_parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
        payload = update_parser.parse_args()

        updated = TagService.update_tags(payload, payload.get("tag_id"))
        binding_count = TagService.get_tag_binding_count(payload.get("tag_id"))

        body = {"id": updated.id, "name": updated.name, "type": updated.type, "binding_count": binding_count}
        return body, 200

    @validate_dataset_token
    def delete(self, _, dataset_id):
        """Delete a knowledge type tag.

        Raises:
            Forbidden: if the current user is not an editor.
        """
        if not current_user.is_editor:
            raise Forbidden()

        id_parser = reqparse.RequestParser()
        id_parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
        payload = id_parser.parse_args()
        TagService.delete_tag(payload.get("tag_id"))

        # NOTE(review): this returns a bare int, not a (body, status) tuple —
        # confirm which HTTP status the framework actually sends for it.
        return 204

    @staticmethod
    def _validate_tag_name(name):
        """Return `name` unchanged when it is non-empty and at most 50 characters long.

        Raises:
            ValueError: if `name` is empty/falsy or longer than 50 characters.
        """
        if name and 1 <= len(name) <= 50:
            return name
        raise ValueError("Name must be between 1 to 50 characters.")
|
||||
|
||||
|
||||
class DatasetTagBindingApi(DatasetApiResource):
    """Service-API resource that binds knowledge tags to a target dataset."""

    @validate_dataset_token
    def post(self, _, dataset_id):
        """Bind the tags in the JSON field `tag_ids` to the dataset in `target_id`.

        Raises:
            Forbidden: if the current user is neither an editor nor a dataset editor.
        """
        # Only users flagged as editor or dataset editor may bind tags.
        if not current_user.is_editor and not current_user.is_dataset_editor:
            raise Forbidden()

        binding_parser = reqparse.RequestParser()
        binding_parser.add_argument(
            "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
        )
        binding_parser.add_argument(
            "target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required."
        )
        payload = binding_parser.parse_args()
        # The tag type is fixed by this endpoint, never taken from the request.
        payload["type"] = "knowledge"

        TagService.save_tag_binding(payload)

        # NOTE(review): this returns a bare int, not a (body, status) tuple —
        # confirm which HTTP status the framework actually sends for it.
        return 204
|
||||
|
||||
|
||||
class DatasetTagUnbindingApi(DatasetApiResource):
    """Service-API resource that removes a knowledge tag binding from a target."""

    @validate_dataset_token
    def post(self, _, dataset_id):
        """Remove the binding between the request's `tag_id` and `target_id`.

        Raises:
            Forbidden: if the current user is neither an editor nor a dataset editor.
        """
        # Only users flagged as editor or dataset editor may unbind tags.
        if not current_user.is_editor and not current_user.is_dataset_editor:
            raise Forbidden()

        unbinding_parser = reqparse.RequestParser()
        unbinding_parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
        unbinding_parser.add_argument(
            "target_id", type=str, nullable=False, required=True, help="Target ID is required."
        )
        payload = unbinding_parser.parse_args()
        # The tag type is fixed by this endpoint, never taken from the request.
        payload["type"] = "knowledge"

        TagService.delete_tag_binding(payload)

        # NOTE(review): this returns a bare int, not a (body, status) tuple —
        # confirm which HTTP status the framework actually sends for it.
        return 204
|
||||
|
||||
|
||||
class DatasetTagsBindingStatusApi(DatasetApiResource):
    """Service-API resource that reports the knowledge tags looked up for one dataset."""

    @validate_dataset_token
    def get(self, _, *args, **kwargs):
        """Return id/name pairs of the knowledge tags found for the `dataset_id` keyword argument.

        The response body has the shape `{"data": [...], "total": <count>}`.
        """
        target_dataset_id = kwargs.get("dataset_id")
        bound_tags = TagService.get_tags_by_target_id(
            "knowledge", current_user.current_tenant_id, str(target_dataset_id)
        )

        # Expose only the identifying fields of each tag.
        data = []
        for bound_tag in bound_tags:
            data.append({"id": bound_tag.id, "name": bound_tag.name})

        return {"data": data, "total": len(bound_tags)}, 200
|
||||
|
||||
|
||||
api.add_resource(DatasetListApi, "/datasets")
|
||||
api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
|
||||
api.add_resource(DatasetTagsApi, "/datasets/tags")
|
||||
api.add_resource(DatasetTagBindingApi, "/datasets/tags/binding")
|
||||
api.add_resource(DatasetTagUnbindingApi, "/datasets/tags/unbinding")
|
||||
api.add_resource(DatasetTagsBindingStatusApi, "/datasets/<uuid:dataset_id>/tags")
|
||||
|
|
|
|||
|
|
@ -208,6 +208,28 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||
)
|
||||
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
def get(self, tenant_id, dataset_id, document_id, segment_id):
    """Return a single segment together with its document's `doc_form`.

    The dataset, the document and the segment are looked up in that order;
    the first one that is missing aborts the request.

    Raises:
        NotFound: if the dataset, the document or the segment cannot be found.
    """
    # Step 1: the dataset must exist under the tenant passed to this handler.
    dataset_id, tenant_id = str(dataset_id), str(tenant_id)
    dataset_query = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
    dataset = dataset_query.first()
    if not dataset:
        raise NotFound("Dataset not found.")

    # Step 2: check the user's model setting for this dataset.
    DatasetService.check_dataset_model_setting(dataset)

    # Step 3: the document must exist within that dataset.
    document_id = str(document_id)
    doc = DocumentService.get_document(dataset_id, document_id)
    if not doc:
        raise NotFound("Document not found.")

    # Step 4: the segment is looked up by id under the current user's tenant.
    segment_id = str(segment_id)
    seg = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
    if not seg:
        raise NotFound("Segment not found.")

    payload = {"data": marshal(seg, segment_fields), "doc_form": doc.doc_form}
    return payload, 200
|
||||
|
||||
|
||||
class ChildChunkApi(DatasetApiResource):
|
||||
"""Resource for child chunks."""
|
||||
|
|
|
|||
|
|
@ -99,7 +99,12 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
|
|||
if user_id:
|
||||
user_id = str(user_id)
|
||||
|
||||
kwargs["end_user"] = create_or_update_end_user_for_user_id(app_model, user_id)
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, user_id)
|
||||
kwargs["end_user"] = end_user
|
||||
|
||||
# Set EndUser as current logged-in user for flask_login.current_user
|
||||
current_app.login_manager._update_request_context_with_user(end_user) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=end_user) # type: ignore
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
|||
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
|
||||
|
||||
iteration_step = 1
|
||||
max_iteration_steps = min(app_config.agent.max_iteration if app_config.agent else 5, 5) + 1
|
||||
max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
tool_instances, prompt_messages_tools = self._init_prompt_tools()
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ class AgentEntity(BaseModel):
|
|||
strategy: Strategy
|
||||
prompt: Optional[AgentPromptEntity] = None
|
||||
tools: Optional[list[AgentToolEntity]] = None
|
||||
max_iteration: int = 5
|
||||
max_iteration: int = 10
|
||||
|
||||
|
||||
class AgentInvokeMessage(ToolInvokeMessage):
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
assert app_config.agent
|
||||
|
||||
iteration_step = 1
|
||||
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
|
||||
max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1
|
||||
|
||||
# continue to run until there is not any tool call
|
||||
function_call_state = True
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ class AgentConfigManager:
|
|||
strategy=strategy,
|
||||
prompt=agent_prompt_entity,
|
||||
tools=agent_tools,
|
||||
max_iteration=agent_dict.get("max_iteration", 5),
|
||||
max_iteration=agent_dict.get("max_iteration", 10),
|
||||
)
|
||||
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ class ModelConfigConverter:
|
|||
if not model_mode:
|
||||
model_mode = LLMMode.CHAT.value
|
||||
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
|
||||
model_mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE]).value
|
||||
model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE]).value
|
||||
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import uuid
|
|||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Literal, Optional, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from flask import Flask, copy_current_request_context, current_app, has_request_context
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
|
|
@ -27,8 +27,8 @@ from core.ops.ops_trace_manager import TraceQueueManager
|
|||
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
|
|
@ -158,7 +158,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
trace_manager=trace_manager,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
|
|
@ -240,7 +239,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
node_id=node_id, inputs=args["inputs"]
|
||||
),
|
||||
)
|
||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
|
|
@ -316,7 +314,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
extras={"auto_generate_conversation_name": False},
|
||||
single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
|
||||
)
|
||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
|
|
@ -399,18 +396,23 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
message_id=message.id,
|
||||
)
|
||||
|
||||
# new thread
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"conversation_id": conversation.id,
|
||||
"message_id": message.id,
|
||||
"context": contextvars.copy_context(),
|
||||
},
|
||||
)
|
||||
# new thread with request context and contextvars
|
||||
context = contextvars.copy_context()
|
||||
|
||||
@copy_current_request_context
|
||||
def worker_with_context():
|
||||
# Run the worker within the copied context
|
||||
return context.run(
|
||||
self._generate_worker,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation_id=conversation.id,
|
||||
message_id=message.id,
|
||||
context=context,
|
||||
)
|
||||
|
||||
worker_thread = threading.Thread(target=worker_with_context)
|
||||
|
||||
worker_thread.start()
|
||||
|
||||
|
|
@ -449,8 +451,22 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
"""
|
||||
for var, val in context.items():
|
||||
var.set(val)
|
||||
|
||||
# FIXME(-LAN-): Save current user before entering new app context
|
||||
from flask import g
|
||||
|
||||
saved_user = None
|
||||
if has_request_context() and hasattr(g, "_login_user"):
|
||||
saved_user = g._login_user
|
||||
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
# Restore user in new app context
|
||||
if saved_user is not None:
|
||||
from flask import g
|
||||
|
||||
g._login_user = saved_user
|
||||
|
||||
# get conversation and message
|
||||
conversation = self._get_conversation(conversation_id)
|
||||
message = self._get_message(message_id)
|
||||
|
|
|
|||
|
|
@ -140,7 +140,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
SystemVariableKey.DIALOGUE_COUNT: self._dialogue_count,
|
||||
SystemVariableKey.APP_ID: app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
|
||||
SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id,
|
||||
SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_run_id,
|
||||
}
|
||||
|
||||
# init variable pool
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator, Mapping
|
||||
|
|
@ -57,26 +56,23 @@ from core.app.entities.task_entities import (
|
|||
WorkflowTaskState,
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
|
||||
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from models import Conversation, EndUser, Message, MessageFile
|
||||
from models.account import Account
|
||||
from models.enums import CreatorUserRole
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -126,8 +122,14 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
|
||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
||||
SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_run_id,
|
||||
},
|
||||
workflow_info=CycleManagerWorkflowInfo(
|
||||
workflow_id=workflow.id,
|
||||
workflow_type=WorkflowType(workflow.type),
|
||||
version=workflow.version,
|
||||
graph_data=workflow.graph_dict,
|
||||
),
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
)
|
||||
|
|
@ -137,7 +139,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
)
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._message_cycle_manager = MessageCycleManage(
|
||||
self._message_cycle_manager = MessageCycleManager(
|
||||
application_generate_entity=application_generate_entity, task_state=self._task_state
|
||||
)
|
||||
|
||||
|
|
@ -158,7 +160,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
:return:
|
||||
"""
|
||||
# start generate conversation name thread
|
||||
self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name(
|
||||
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
|
||||
conversation_id=self._conversation_id, query=self._application_generate_entity.query
|
||||
)
|
||||
|
||||
|
|
@ -302,19 +304,17 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# init workflow run
|
||||
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start(
|
||||
session=session,
|
||||
workflow_id=self._workflow_id,
|
||||
)
|
||||
self._workflow_run_id = workflow_execution.id
|
||||
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
|
||||
self._workflow_run_id = workflow_execution.id_
|
||||
message = self._get_message(session=session)
|
||||
if not message:
|
||||
raise ValueError(f"Message not found: {self._message_id}")
|
||||
message.workflow_run_id = workflow_execution.id
|
||||
message.workflow_run_id = workflow_execution.id_
|
||||
workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution=workflow_execution,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
yield workflow_start_resp
|
||||
elif isinstance(
|
||||
|
|
@ -549,7 +549,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
workflow_run_id=self._workflow_run_id,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
status=WorkflowRunStatus.FAILED,
|
||||
status=WorkflowExecutionStatus.FAILED,
|
||||
error_message=event.error,
|
||||
conversation_id=self._conversation_id,
|
||||
trace_manager=trace_manager,
|
||||
|
|
@ -575,7 +575,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
workflow_run_id=self._workflow_run_id,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
status=WorkflowRunStatus.STOPPED,
|
||||
status=WorkflowExecutionStatus.STOPPED,
|
||||
error_message=event.get_stop_reason(),
|
||||
conversation_id=self._conversation_id,
|
||||
trace_manager=trace_manager,
|
||||
|
|
@ -603,22 +603,18 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
yield self._message_end_to_stream_response()
|
||||
break
|
||||
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||
self._message_cycle_manager._handle_retriever_resources(event)
|
||||
self._message_cycle_manager.handle_retriever_resources(event)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
message = self._get_message(session=session)
|
||||
message.message_metadata = (
|
||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||
)
|
||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||
session.commit()
|
||||
elif isinstance(event, QueueAnnotationReplyEvent):
|
||||
self._message_cycle_manager._handle_annotation_reply(event)
|
||||
self._message_cycle_manager.handle_annotation_reply(event)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
message = self._get_message(session=session)
|
||||
message.message_metadata = (
|
||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||
)
|
||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||
session.commit()
|
||||
elif isinstance(event, QueueTextChunkEvent):
|
||||
delta_text = event.text
|
||||
|
|
@ -635,12 +631,12 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
tts_publisher.publish(queue_message)
|
||||
|
||||
self._task_state.answer += delta_text
|
||||
yield self._message_cycle_manager._message_to_stream_response(
|
||||
yield self._message_cycle_manager.message_to_stream_response(
|
||||
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
|
||||
)
|
||||
elif isinstance(event, QueueMessageReplaceEvent):
|
||||
# published by moderation
|
||||
yield self._message_cycle_manager._message_replace_to_stream_response(
|
||||
yield self._message_cycle_manager.message_replace_to_stream_response(
|
||||
answer=event.text, reason=event.reason
|
||||
)
|
||||
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
|
||||
|
|
@ -652,7 +648,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
)
|
||||
if output_moderation_answer:
|
||||
self._task_state.answer = output_moderation_answer
|
||||
yield self._message_cycle_manager._message_replace_to_stream_response(
|
||||
yield self._message_cycle_manager.message_replace_to_stream_response(
|
||||
answer=output_moderation_answer,
|
||||
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
|
||||
)
|
||||
|
|
@ -681,9 +677,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
message = self._get_message(session=session)
|
||||
message.answer = self._task_state.answer
|
||||
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
|
||||
message.message_metadata = (
|
||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||
)
|
||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||
message_files = [
|
||||
MessageFile(
|
||||
message_id=message.id,
|
||||
|
|
@ -711,9 +705,9 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
message.answer_price_unit = usage.completion_price_unit
|
||||
message.total_price = usage.total_price
|
||||
message.currency = usage.currency
|
||||
self._task_state.metadata["usage"] = jsonable_encoder(usage)
|
||||
self._task_state.metadata.usage = usage
|
||||
else:
|
||||
self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage())
|
||||
self._task_state.metadata.usage = LLMUsage.empty_usage()
|
||||
message_was_created.send(
|
||||
message,
|
||||
application_generate_entity=self._application_generate_entity,
|
||||
|
|
@ -724,18 +718,16 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
Message end to stream response.
|
||||
:return:
|
||||
"""
|
||||
extras = {}
|
||||
if self._task_state.metadata:
|
||||
extras["metadata"] = self._task_state.metadata.copy()
|
||||
extras = self._task_state.metadata.model_dump()
|
||||
|
||||
if "annotation_reply" in extras["metadata"]:
|
||||
del extras["metadata"]["annotation_reply"]
|
||||
if self._task_state.metadata.annotation_reply:
|
||||
del extras["annotation_reply"]
|
||||
|
||||
return MessageEndStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=self._message_id,
|
||||
files=self._recorded_files,
|
||||
metadata=extras.get("metadata", {}),
|
||||
metadata=extras,
|
||||
)
|
||||
|
||||
def _handle_output_moderation_chunk(self, text: str) -> bool:
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import uuid
|
|||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Literal, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from flask import Flask, copy_current_request_context, current_app, has_request_context
|
||||
from pydantic import ValidationError
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -179,18 +179,23 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||
message_id=message.id,
|
||||
)
|
||||
|
||||
# new thread
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"context": contextvars.copy_context(),
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"conversation_id": conversation.id,
|
||||
"message_id": message.id,
|
||||
},
|
||||
)
|
||||
# new thread with request context and contextvars
|
||||
context = contextvars.copy_context()
|
||||
|
||||
@copy_current_request_context
|
||||
def worker_with_context():
|
||||
# Run the worker within the copied context
|
||||
return context.run(
|
||||
self._generate_worker,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
context=context,
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation_id=conversation.id,
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
worker_thread = threading.Thread(target=worker_with_context)
|
||||
|
||||
worker_thread.start()
|
||||
|
||||
|
|
@ -227,8 +232,21 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||
for var, val in context.items():
|
||||
var.set(val)
|
||||
|
||||
# FIXME(-LAN-): Save current user before entering new app context
|
||||
from flask import g
|
||||
|
||||
saved_user = None
|
||||
if has_request_context() and hasattr(g, "_login_user"):
|
||||
saved_user = g._login_user
|
||||
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
# Restore user in new app context
|
||||
if saved_user is not None:
|
||||
from flask import g
|
||||
|
||||
g._login_user = saved_user
|
||||
|
||||
# get conversation and message
|
||||
conversation = self._get_conversation(conversation_id)
|
||||
message = self._get_message(message_id)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import uuid
|
|||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Literal, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from flask import Flask, copy_current_request_context, current_app
|
||||
from pydantic import ValidationError
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -170,17 +170,18 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||
message_id=message.id,
|
||||
)
|
||||
|
||||
# new thread
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"conversation_id": conversation.id,
|
||||
"message_id": message.id,
|
||||
},
|
||||
)
|
||||
# new thread with request context
|
||||
@copy_current_request_context
|
||||
def worker_with_context():
|
||||
return self._generate_worker(
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation_id=conversation.id,
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
worker_thread = threading.Thread(target=worker_with_context)
|
||||
|
||||
worker_thread.start()
|
||||
|
||||
|
|
|
|||
|
|
@ -44,15 +44,14 @@ from core.app.entities.task_entities import (
|
|||
)
|
||||
from core.file import FILE_MODEL_IDENTITY, File
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.workflow.entities.node_execution_entities import NodeExecution
|
||||
from core.workflow.entities.workflow_execution_entities import WorkflowExecution
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from models import (
|
||||
Account,
|
||||
CreatorUserRole,
|
||||
EndUser,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowRun,
|
||||
)
|
||||
|
||||
|
|
@ -73,11 +72,10 @@ class WorkflowResponseConverter:
|
|||
) -> WorkflowStartStreamResponse:
|
||||
return WorkflowStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_execution.id,
|
||||
workflow_run_id=workflow_execution.id_,
|
||||
data=WorkflowStartStreamResponse.Data(
|
||||
id=workflow_execution.id,
|
||||
id=workflow_execution.id_,
|
||||
workflow_id=workflow_execution.workflow_id,
|
||||
sequence_number=workflow_execution.sequence_number,
|
||||
inputs=workflow_execution.inputs,
|
||||
created_at=int(workflow_execution.started_at.timestamp()),
|
||||
),
|
||||
|
|
@ -91,7 +89,7 @@ class WorkflowResponseConverter:
|
|||
workflow_execution: WorkflowExecution,
|
||||
) -> WorkflowFinishStreamResponse:
|
||||
created_by = None
|
||||
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id))
|
||||
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_))
|
||||
assert workflow_run is not None
|
||||
if workflow_run.created_by_role == CreatorUserRole.ACCOUNT:
|
||||
stmt = select(Account).where(Account.id == workflow_run.created_by)
|
||||
|
|
@ -122,11 +120,10 @@ class WorkflowResponseConverter:
|
|||
|
||||
return WorkflowFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_execution.id,
|
||||
workflow_run_id=workflow_execution.id_,
|
||||
data=WorkflowFinishStreamResponse.Data(
|
||||
id=workflow_execution.id,
|
||||
id=workflow_execution.id_,
|
||||
workflow_id=workflow_execution.workflow_id,
|
||||
sequence_number=workflow_execution.sequence_number,
|
||||
status=workflow_execution.status,
|
||||
outputs=workflow_execution.outputs,
|
||||
error=workflow_execution.error_message,
|
||||
|
|
@ -146,16 +143,16 @@ class WorkflowResponseConverter:
|
|||
*,
|
||||
event: QueueNodeStartedEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: NodeExecution,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> Optional[NodeStartStreamResponse]:
|
||||
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
|
||||
return None
|
||||
if not workflow_node_execution.workflow_run_id:
|
||||
if not workflow_node_execution.workflow_execution_id:
|
||||
return None
|
||||
|
||||
response = NodeStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_run_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_execution_id,
|
||||
data=NodeStartStreamResponse.Data(
|
||||
id=workflow_node_execution.id,
|
||||
node_id=workflow_node_execution.node_id,
|
||||
|
|
@ -196,18 +193,18 @@ class WorkflowResponseConverter:
|
|||
| QueueNodeInLoopFailedEvent
|
||||
| QueueNodeExceptionEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: NodeExecution,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> Optional[NodeFinishStreamResponse]:
|
||||
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
|
||||
return None
|
||||
if not workflow_node_execution.workflow_run_id:
|
||||
if not workflow_node_execution.workflow_execution_id:
|
||||
return None
|
||||
if not workflow_node_execution.finished_at:
|
||||
return None
|
||||
|
||||
return NodeFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_run_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_execution_id,
|
||||
data=NodeFinishStreamResponse.Data(
|
||||
id=workflow_node_execution.id,
|
||||
node_id=workflow_node_execution.node_id,
|
||||
|
|
@ -239,18 +236,18 @@ class WorkflowResponseConverter:
|
|||
*,
|
||||
event: QueueNodeRetryEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: NodeExecution,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
|
||||
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
|
||||
return None
|
||||
if not workflow_node_execution.workflow_run_id:
|
||||
if not workflow_node_execution.workflow_execution_id:
|
||||
return None
|
||||
if not workflow_node_execution.finished_at:
|
||||
return None
|
||||
|
||||
return NodeRetryStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_run_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_execution_id,
|
||||
data=NodeRetryStreamResponse.Data(
|
||||
id=workflow_node_execution.id,
|
||||
node_id=workflow_node_execution.node_id,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import uuid
|
|||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Literal, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from flask import Flask, copy_current_request_context, current_app
|
||||
from pydantic import ValidationError
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -151,16 +151,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
message_id=message.id,
|
||||
)
|
||||
|
||||
# new thread
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"message_id": message.id,
|
||||
},
|
||||
)
|
||||
# new thread with request context
|
||||
@copy_current_request_context
|
||||
def worker_with_context():
|
||||
return self._generate_worker(
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
worker_thread = threading.Thread(target=worker_with_context)
|
||||
|
||||
worker_thread.start()
|
||||
|
||||
|
|
@ -313,16 +314,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
message_id=message.id,
|
||||
)
|
||||
|
||||
# new thread
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"message_id": message.id,
|
||||
},
|
||||
)
|
||||
# new thread with request context
|
||||
@copy_current_request_context
|
||||
def worker_with_context():
|
||||
return self._generate_worker(
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
worker_thread = threading.Thread(target=worker_with_context)
|
||||
|
||||
worker_thread.start()
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import uuid
|
|||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Literal, Optional, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from flask import Flask, copy_current_request_context, current_app, has_request_context
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
|
|
@ -25,8 +25,8 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
|||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
|
|
@ -132,10 +132,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
invoke_from=invoke_from,
|
||||
call_depth=call_depth,
|
||||
trace_manager=trace_manager,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_execution_id=workflow_run_id,
|
||||
)
|
||||
|
||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
|
|
@ -207,17 +206,22 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
app_mode=app_model.mode,
|
||||
)
|
||||
|
||||
# new thread
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"context": contextvars.copy_context(),
|
||||
"workflow_thread_pool_id": workflow_thread_pool_id,
|
||||
},
|
||||
)
|
||||
# new thread with request context and contextvars
|
||||
context = contextvars.copy_context()
|
||||
|
||||
@copy_current_request_context
|
||||
def worker_with_context():
|
||||
# Run the worker within the copied context
|
||||
return context.run(
|
||||
self._generate_worker,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
context=context,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
)
|
||||
|
||||
worker_thread = threading.Thread(target=worker_with_context)
|
||||
|
||||
worker_thread.start()
|
||||
|
||||
|
|
@ -275,9 +279,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
|
||||
node_id=node_id, inputs=args["inputs"]
|
||||
),
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
workflow_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
|
|
@ -352,9 +355,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras={"auto_generate_conversation_name": False},
|
||||
single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
workflow_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
|
|
@ -408,8 +410,22 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
"""
|
||||
for var, val in context.items():
|
||||
var.set(val)
|
||||
|
||||
# FIXME(-LAN-): Save current user before entering new app context
|
||||
from flask import g
|
||||
|
||||
saved_user = None
|
||||
if has_request_context() and hasattr(g, "_login_user"):
|
||||
saved_user = g._login_user
|
||||
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
# Restore user in new app context
|
||||
if saved_user is not None:
|
||||
from flask import g
|
||||
|
||||
g._login_user = saved_user
|
||||
|
||||
# workflow app
|
||||
runner = WorkflowAppRunner(
|
||||
application_generate_entity=application_generate_entity,
|
||||
|
|
|
|||
|
|
@ -95,7 +95,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||
SystemVariableKey.USER_ID: user_id,
|
||||
SystemVariableKey.APP_ID: app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
|
||||
SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id,
|
||||
SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_execution_id,
|
||||
}
|
||||
|
||||
variable_pool = VariablePool(
|
||||
|
|
|
|||
|
|
@ -50,16 +50,15 @@ from core.app.entities.task_entities import (
|
|||
WorkflowAppStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.workflow_execution_entities import WorkflowExecution
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.enums import CreatorUserRole
|
||||
|
|
@ -69,7 +68,6 @@ from models.workflow import (
|
|||
WorkflowAppLog,
|
||||
WorkflowAppLogCreatedFrom,
|
||||
WorkflowRun,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -114,8 +112,14 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
SystemVariableKey.USER_ID: user_session_id,
|
||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
||||
SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_execution_id,
|
||||
},
|
||||
workflow_info=CycleManagerWorkflowInfo(
|
||||
workflow_id=workflow.id,
|
||||
workflow_type=WorkflowType(workflow.type),
|
||||
version=workflow.version,
|
||||
graph_data=workflow.graph_dict,
|
||||
),
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
)
|
||||
|
|
@ -125,9 +129,7 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
)
|
||||
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_id = workflow.id
|
||||
self._workflow_features_dict = workflow.features_dict
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._workflow_run_id = ""
|
||||
|
||||
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
|
|
@ -266,17 +268,13 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
# override graph runtime state
|
||||
graph_runtime_state = event.graph_runtime_state
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# init workflow run
|
||||
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start(
|
||||
session=session,
|
||||
workflow_id=self._workflow_id,
|
||||
)
|
||||
self._workflow_run_id = workflow_execution.id
|
||||
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution=workflow_execution,
|
||||
)
|
||||
# init workflow run
|
||||
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
|
||||
self._workflow_run_id = workflow_execution.id_
|
||||
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution=workflow_execution,
|
||||
)
|
||||
|
||||
yield start_resp
|
||||
elif isinstance(
|
||||
|
|
@ -511,9 +509,9 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
workflow_run_id=self._workflow_run_id,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
status=WorkflowRunStatus.FAILED
|
||||
status=WorkflowExecutionStatus.FAILED
|
||||
if isinstance(event, QueueWorkflowFailedEvent)
|
||||
else WorkflowRunStatus.STOPPED,
|
||||
else WorkflowExecutionStatus.STOPPED,
|
||||
error_message=event.error
|
||||
if isinstance(event, QueueWorkflowFailedEvent)
|
||||
else event.get_stop_reason(),
|
||||
|
|
@ -542,7 +540,6 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
if tts_publisher:
|
||||
tts_publisher.publish(queue_message)
|
||||
|
||||
self._task_state.answer += delta_text
|
||||
yield self._text_chunk_to_stream_response(
|
||||
delta_text, from_variable_selector=event.from_variable_selector
|
||||
)
|
||||
|
|
@ -557,7 +554,7 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
tts_publisher.publish(None)
|
||||
|
||||
def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None:
|
||||
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id))
|
||||
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_))
|
||||
assert workflow_run is not None
|
||||
invoke_from = self._application_generate_entity.invoke_from
|
||||
if invoke_from == InvokeFrom.SERVICE_API:
|
||||
|
|
|
|||
|
|
@ -29,8 +29,8 @@ from core.app.entities.queue_entities import (
|
|||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
AgentLogEvent,
|
||||
GraphEngineEvent,
|
||||
|
|
@ -295,7 +295,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
inputs: Mapping[str, Any] | None = {}
|
||||
process_data: Mapping[str, Any] | None = {}
|
||||
outputs: Mapping[str, Any] | None = {}
|
||||
execution_metadata: Mapping[NodeRunMetadataKey, Any] | None = {}
|
||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = {}
|
||||
if node_run_result:
|
||||
inputs = node_run_result.inputs
|
||||
process_data = node_run_result.process_data
|
||||
|
|
|
|||
|
|
@ -76,6 +76,8 @@ class AppGenerateEntity(BaseModel):
|
|||
App Generate Entity.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
task_id: str
|
||||
|
||||
# app config
|
||||
|
|
@ -99,9 +101,6 @@ class AppGenerateEntity(BaseModel):
|
|||
# tracing instance
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
|
||||
"""
|
||||
|
|
@ -205,7 +204,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
|||
|
||||
# app config
|
||||
app_config: WorkflowUIBasedAppConfig
|
||||
workflow_run_id: str
|
||||
workflow_execution_id: str
|
||||
|
||||
class SingleIterationRunEntity(BaseModel):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Any, Optional
|
||||
|
|
@ -6,7 +6,9 @@ from typing import Any, Optional
|
|||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
|
@ -282,7 +284,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
|
|||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
|
||||
retriever_resources: list[dict]
|
||||
retriever_resources: Sequence[RetrievalSourceMetadata]
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
|
|
@ -412,7 +414,7 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
|||
inputs: Optional[Mapping[str, Any]] = None
|
||||
process_data: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
|
||||
error: Optional[str] = None
|
||||
"""single iteration duration map"""
|
||||
|
|
@ -446,7 +448,7 @@ class QueueNodeRetryEvent(QueueNodeStartedEvent):
|
|||
inputs: Optional[Mapping[str, Any]] = None
|
||||
process_data: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
|
||||
error: str
|
||||
retry_index: int # retry index
|
||||
|
|
@ -480,7 +482,7 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent):
|
|||
inputs: Optional[Mapping[str, Any]] = None
|
||||
process_data: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
|
||||
error: str
|
||||
|
||||
|
|
@ -513,7 +515,7 @@ class QueueNodeInLoopFailedEvent(AppQueueEvent):
|
|||
inputs: Optional[Mapping[str, Any]] = None
|
||||
process_data: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
|
||||
error: str
|
||||
|
||||
|
|
@ -546,7 +548,7 @@ class QueueNodeExceptionEvent(AppQueueEvent):
|
|||
inputs: Optional[Mapping[str, Any]] = None
|
||||
process_data: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
|
||||
error: str
|
||||
|
||||
|
|
@ -579,7 +581,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
|
|||
inputs: Optional[Mapping[str, Any]] = None
|
||||
process_data: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
|
||||
error: str
|
||||
|
||||
|
|
|
|||
|
|
@ -2,12 +2,29 @@ from collections.abc import Mapping, Sequence
|
|||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class AnnotationReplyAccount(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
class AnnotationReply(BaseModel):
|
||||
id: str
|
||||
account: AnnotationReplyAccount
|
||||
|
||||
|
||||
class TaskStateMetadata(BaseModel):
|
||||
annotation_reply: AnnotationReply | None = None
|
||||
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(default_factory=list)
|
||||
usage: LLMUsage | None = None
|
||||
|
||||
|
||||
class TaskState(BaseModel):
|
||||
|
|
@ -15,7 +32,7 @@ class TaskState(BaseModel):
|
|||
TaskState entity
|
||||
"""
|
||||
|
||||
metadata: dict = {}
|
||||
metadata: TaskStateMetadata = Field(default_factory=TaskStateMetadata)
|
||||
|
||||
|
||||
class EasyUITaskState(TaskState):
|
||||
|
|
@ -189,7 +206,6 @@ class WorkflowStartStreamResponse(StreamResponse):
|
|||
|
||||
id: str
|
||||
workflow_id: str
|
||||
sequence_number: int
|
||||
inputs: Mapping[str, Any]
|
||||
created_at: int
|
||||
|
||||
|
|
@ -210,7 +226,6 @@ class WorkflowFinishStreamResponse(StreamResponse):
|
|||
|
||||
id: str
|
||||
workflow_id: str
|
||||
sequence_number: int
|
||||
status: str
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
error: Optional[str] = None
|
||||
|
|
@ -307,7 +322,7 @@ class NodeFinishStreamResponse(StreamResponse):
|
|||
status: str
|
||||
error: Optional[str] = None
|
||||
elapsed_time: float
|
||||
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
created_at: int
|
||||
finished_at: int
|
||||
files: Optional[Sequence[Mapping[str, Any]]] = []
|
||||
|
|
@ -376,7 +391,7 @@ class NodeRetryStreamResponse(StreamResponse):
|
|||
status: str
|
||||
error: Optional[str] = None
|
||||
elapsed_time: float
|
||||
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
created_at: int
|
||||
finished_at: int
|
||||
files: Optional[Sequence[Mapping[str, Any]]] = []
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
|
|
@ -43,7 +42,7 @@ from core.app.entities.task_entities import (
|
|||
StreamResponse,
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
|
||||
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
|
|
@ -51,7 +50,6 @@ from core.model_runtime.entities.message_entities import (
|
|||
AssistantPromptMessage,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
|
|
@ -63,7 +61,7 @@ from models.model import AppMode, Conversation, Message, MessageAgentThought
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleManage):
|
||||
class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
"""
|
||||
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
|
|
@ -104,6 +102,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||
)
|
||||
)
|
||||
|
||||
self._message_cycle_manager = MessageCycleManager(
|
||||
application_generate_entity=application_generate_entity,
|
||||
task_state=self._task_state,
|
||||
)
|
||||
|
||||
self._conversation_name_generate_thread: Optional[Thread] = None
|
||||
|
||||
def process(
|
||||
|
|
@ -115,7 +118,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||
]:
|
||||
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
|
||||
# start generate conversation name thread
|
||||
self._conversation_name_generate_thread = self._generate_conversation_name(
|
||||
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
|
||||
conversation_id=self._conversation_id, query=self._application_generate_entity.query or ""
|
||||
)
|
||||
|
||||
|
|
@ -136,9 +139,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||
if isinstance(stream_response, ErrorStreamResponse):
|
||||
raise stream_response.err
|
||||
elif isinstance(stream_response, MessageEndStreamResponse):
|
||||
extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)}
|
||||
extras = {"usage": self._task_state.llm_result.usage.model_dump()}
|
||||
if self._task_state.metadata:
|
||||
extras["metadata"] = self._task_state.metadata
|
||||
extras["metadata"] = self._task_state.metadata.model_dump()
|
||||
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
|
||||
if self._conversation_mode == AppMode.COMPLETION.value:
|
||||
response = CompletionAppBlockingResponse(
|
||||
|
|
@ -277,7 +280,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||
)
|
||||
if output_moderation_answer:
|
||||
self._task_state.llm_result.message.content = output_moderation_answer
|
||||
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
|
||||
yield self._message_cycle_manager.message_replace_to_stream_response(
|
||||
answer=output_moderation_answer
|
||||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Save message
|
||||
|
|
@ -286,9 +291,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||
message_end_resp = self._message_end_to_stream_response()
|
||||
yield message_end_resp
|
||||
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||
self._handle_retriever_resources(event)
|
||||
self._message_cycle_manager.handle_retriever_resources(event)
|
||||
elif isinstance(event, QueueAnnotationReplyEvent):
|
||||
annotation = self._handle_annotation_reply(event)
|
||||
annotation = self._message_cycle_manager.handle_annotation_reply(event)
|
||||
if annotation:
|
||||
self._task_state.llm_result.message.content = annotation.content
|
||||
elif isinstance(event, QueueAgentThoughtEvent):
|
||||
|
|
@ -296,7 +301,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||
if agent_thought_response is not None:
|
||||
yield agent_thought_response
|
||||
elif isinstance(event, QueueMessageFileEvent):
|
||||
response = self._message_file_to_stream_response(event)
|
||||
response = self._message_cycle_manager.message_file_to_stream_response(event)
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent):
|
||||
|
|
@ -318,7 +323,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||
self._task_state.llm_result.message.content = current_content
|
||||
|
||||
if isinstance(event, QueueLLMChunkEvent):
|
||||
yield self._message_to_stream_response(
|
||||
yield self._message_cycle_manager.message_to_stream_response(
|
||||
answer=cast(str, delta_text),
|
||||
message_id=self._message_id,
|
||||
)
|
||||
|
|
@ -328,7 +333,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||
message_id=self._message_id,
|
||||
)
|
||||
elif isinstance(event, QueueMessageReplaceEvent):
|
||||
yield self._message_replace_to_stream_response(answer=event.text)
|
||||
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text)
|
||||
elif isinstance(event, QueuePingEvent):
|
||||
yield self._ping_stream_response()
|
||||
else:
|
||||
|
|
@ -372,9 +377,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||
message.provider_response_latency = time.perf_counter() - self._start_at
|
||||
message.total_price = usage.total_price
|
||||
message.currency = usage.currency
|
||||
message.message_metadata = (
|
||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||
)
|
||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
|
|
@ -423,16 +426,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||
Message end to stream response.
|
||||
:return:
|
||||
"""
|
||||
self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage)
|
||||
|
||||
extras = {}
|
||||
if self._task_state.metadata:
|
||||
extras["metadata"] = self._task_state.metadata
|
||||
|
||||
self._task_state.metadata.usage = self._task_state.llm_result.usage
|
||||
metadata_dict = self._task_state.metadata.model_dump()
|
||||
return MessageEndStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=self._message_id,
|
||||
metadata=extras.get("metadata", {}),
|
||||
metadata=metadata_dict,
|
||||
)
|
||||
|
||||
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
|
||||
|
|
@ -455,8 +454,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||
agent_thought: Optional[MessageAgentThought] = (
|
||||
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first()
|
||||
)
|
||||
db.session.refresh(agent_thought)
|
||||
db.session.close()
|
||||
|
||||
if agent_thought:
|
||||
return AgentThoughtStreamResponse(
|
||||
|
|
|
|||
|
|
@ -17,6 +17,8 @@ from core.app.entities.queue_entities import (
|
|||
QueueRetrieverResourcesEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AnnotationReply,
|
||||
AnnotationReplyAccount,
|
||||
EasyUITaskState,
|
||||
MessageFileStreamResponse,
|
||||
MessageReplaceStreamResponse,
|
||||
|
|
@ -30,7 +32,7 @@ from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
|
|||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
|
||||
class MessageCycleManage:
|
||||
class MessageCycleManager:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
|
|
@ -45,7 +47,7 @@ class MessageCycleManage:
|
|||
self._application_generate_entity = application_generate_entity
|
||||
self._task_state = task_state
|
||||
|
||||
def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
|
||||
def generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
|
||||
"""
|
||||
Generate conversation name.
|
||||
:param conversation_id: conversation id
|
||||
|
|
@ -102,7 +104,7 @@ class MessageCycleManage:
|
|||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
|
||||
def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
|
||||
"""
|
||||
Handle annotation reply.
|
||||
:param event: event
|
||||
|
|
@ -111,25 +113,28 @@ class MessageCycleManage:
|
|||
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
|
||||
if annotation:
|
||||
account = annotation.account
|
||||
self._task_state.metadata["annotation_reply"] = {
|
||||
"id": annotation.id,
|
||||
"account": {"id": annotation.account_id, "name": account.name if account else "Dify user"},
|
||||
}
|
||||
self._task_state.metadata.annotation_reply = AnnotationReply(
|
||||
id=annotation.id,
|
||||
account=AnnotationReplyAccount(
|
||||
id=annotation.account_id,
|
||||
name=account.name if account else "Dify user",
|
||||
),
|
||||
)
|
||||
|
||||
return annotation
|
||||
|
||||
return None
|
||||
|
||||
def _handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None:
|
||||
def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None:
|
||||
"""
|
||||
Handle retriever resources.
|
||||
:param event: event
|
||||
:return:
|
||||
"""
|
||||
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
|
||||
self._task_state.metadata["retriever_resources"] = event.retriever_resources
|
||||
self._task_state.metadata.retriever_resources = event.retriever_resources
|
||||
|
||||
def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
|
||||
def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
|
||||
"""
|
||||
Message file to stream response.
|
||||
:param event: event
|
||||
|
|
@ -166,7 +171,7 @@ class MessageCycleManage:
|
|||
|
||||
return None
|
||||
|
||||
def _message_to_stream_response(
|
||||
def message_to_stream_response(
|
||||
self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None
|
||||
) -> MessageStreamResponse:
|
||||
"""
|
||||
|
|
@ -182,7 +187,7 @@ class MessageCycleManage:
|
|||
from_variable_selector=from_variable_selector,
|
||||
)
|
||||
|
||||
def _message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
|
||||
def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
|
||||
"""
|
||||
Message replace to stream response.
|
||||
:param answer: answer
|
||||
|
|
@ -1,8 +1,10 @@
|
|||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -85,7 +87,8 @@ class DatasetIndexToolCallbackHandler:
|
|||
|
||||
db.session.commit()
|
||||
|
||||
def return_retriever_resource_info(self, resource: list):
|
||||
# TODO(-LAN-): Improve type check
|
||||
def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]):
|
||||
"""Handle return_retriever_resource_info."""
|
||||
self._queue_manager.publish(
|
||||
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from core.helper.code_executor.python3.python3_transformer import Python3Templat
|
|||
from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT))
|
||||
|
||||
|
||||
class CodeExecutionError(Exception):
|
||||
|
|
@ -64,7 +65,7 @@ class CodeExecutor:
|
|||
:param code: code
|
||||
:return:
|
||||
"""
|
||||
url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / "v1" / "sandbox" / "run"
|
||||
url = code_execution_endpoint_url / "v1" / "sandbox" / "run"
|
||||
|
||||
headers = {"X-Api-Key": dify_config.CODE_EXECUTION_API_KEY}
|
||||
|
||||
|
|
|
|||
|
|
@ -7,29 +7,28 @@ from configs import dify_config
|
|||
from core.helper.download import download_with_size_limit
|
||||
from core.plugin.entities.marketplace import MarketplacePluginDeclaration
|
||||
|
||||
marketplace_api_url = URL(str(dify_config.MARKETPLACE_API_URL))
|
||||
|
||||
def get_plugin_pkg_url(plugin_unique_identifier: str):
|
||||
return (URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/plugins/download").with_query(
|
||||
unique_identifier=plugin_unique_identifier
|
||||
)
|
||||
|
||||
def get_plugin_pkg_url(plugin_unique_identifier: str) -> str:
|
||||
return str((marketplace_api_url / "api/v1/plugins/download").with_query(unique_identifier=plugin_unique_identifier))
|
||||
|
||||
|
||||
def download_plugin_pkg(plugin_unique_identifier: str):
|
||||
url = str(get_plugin_pkg_url(plugin_unique_identifier))
|
||||
return download_with_size_limit(url, dify_config.PLUGIN_MAX_PACKAGE_SIZE)
|
||||
return download_with_size_limit(get_plugin_pkg_url(plugin_unique_identifier), dify_config.PLUGIN_MAX_PACKAGE_SIZE)
|
||||
|
||||
|
||||
def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplacePluginDeclaration]:
|
||||
if len(plugin_ids) == 0:
|
||||
return []
|
||||
|
||||
url = str(URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/plugins/batch")
|
||||
url = str(marketplace_api_url / "api/v1/plugins/batch")
|
||||
response = requests.post(url, json={"plugin_ids": plugin_ids})
|
||||
response.raise_for_status()
|
||||
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
|
||||
|
||||
|
||||
def record_install_plugin_event(plugin_unique_identifier: str):
|
||||
url = str(URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/stats/plugins/install_count")
|
||||
url = str(marketplace_api_url / "api/v1/stats/plugins/install_count")
|
||||
response = requests.post(url, json={"unique_identifier": plugin_unique_identifier})
|
||||
response.raise_for_status()
|
||||
|
|
|
|||
|
|
@ -51,15 +51,19 @@ class LLMGenerator:
|
|||
response = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=list(prompts), model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
|
||||
prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
|
||||
),
|
||||
)
|
||||
answer = cast(str, response.message.content)
|
||||
cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL)
|
||||
if cleaned_answer is None:
|
||||
return ""
|
||||
result_dict = json.loads(cleaned_answer)
|
||||
answer = result_dict["Your Output"]
|
||||
try:
|
||||
result_dict = json.loads(cleaned_answer)
|
||||
answer = result_dict["Your Output"]
|
||||
except json.JSONDecodeError as e:
|
||||
logging.exception("Failed to generate name after answer, use query instead")
|
||||
answer = query
|
||||
name = answer.strip()
|
||||
|
||||
if len(name) > 75:
|
||||
|
|
|
|||
|
|
@ -1,61 +1,20 @@
|
|||
# Written by YORKI MINAKO🤡, Edited by Xiaoyi
|
||||
CONVERSATION_TITLE_PROMPT = """You need to decompose the user's input into "subject" and "intention" in order to accurately figure out what the user's input language actually is.
|
||||
Notice: the language type user uses could be diverse, which can be English, Chinese, Italian, Español, Arabic, Japanese, French, and etc.
|
||||
ENSURE your output is in the SAME language as the user's input!
|
||||
Your output is restricted only to: (Input language) Intention + Subject(short as possible)
|
||||
Your output MUST be a valid JSON.
|
||||
# Written by YORKI MINAKO🤡, Edited by Xiaoyi, Edited by yasu-oh
|
||||
CONVERSATION_TITLE_PROMPT = """You are asked to generate a concise chat title by decomposing the user’s input into two parts: “Intention” and “Subject”.
|
||||
|
||||
Tip: When the user's question is directed at you (the language model), you can add an emoji to make it more fun.
|
||||
1. Detect Input Language
|
||||
Automatically identify the language of the user’s input (e.g. English, Chinese, Italian, Español, Arabic, Japanese, French, and etc.).
|
||||
|
||||
2. Generate Title
|
||||
- Combine Intention + Subject into a single, as-short-as-possible phrase.
|
||||
- The title must be natural, friendly, and in the same language as the input.
|
||||
- If the input is a direct question to the model, you may add an emoji at the end.
|
||||
|
||||
example 1:
|
||||
User Input: hi, yesterday i had some burgers.
|
||||
3. Output Format
|
||||
Return **only** a valid JSON object with these exact keys and no additional text:
|
||||
{
|
||||
"Language Type": "The user's input is pure English",
|
||||
"Your Reasoning": "The language of my output must be pure English.",
|
||||
"Your Output": "sharing yesterday's food"
|
||||
}
|
||||
|
||||
example 2:
|
||||
User Input: hello
|
||||
{
|
||||
"Language Type": "The user's input is pure English",
|
||||
"Your Reasoning": "The language of my output must be pure English.",
|
||||
"Your Output": "Greeting myself☺️"
|
||||
}
|
||||
|
||||
|
||||
example 3:
|
||||
User Input: why mmap file: oom
|
||||
{
|
||||
"Language Type": "The user's input is written in pure English",
|
||||
"Your Reasoning": "The language of my output must be pure English.",
|
||||
"Your Output": "Asking about the reason for mmap file: oom"
|
||||
}
|
||||
|
||||
|
||||
example 4:
|
||||
User Input: www.convinceme.yesterday-you-ate-seafood.tv讲了什么?
|
||||
{
|
||||
"Language Type": "The user's input English-Chinese mixed",
|
||||
"Your Reasoning": "The English-part is an URL, the main intention is still written in Chinese, so the language of my output must be using Chinese.",
|
||||
"Your Output": "询问网站www.convinceme.yesterday-you-ate-seafood.tv"
|
||||
}
|
||||
|
||||
example 5:
|
||||
User Input: why小红的年龄is老than小明?
|
||||
{
|
||||
"Language Type": "The user's input is English-Chinese mixed",
|
||||
"Your Reasoning": "The English parts are filler words, the main intention is written in Chinese, besides, Chinese occupies a greater \"actual meaning\" than English, so the language of my output must be using Chinese.",
|
||||
"Your Output": "询问小红和小明的年龄"
|
||||
}
|
||||
|
||||
example 6:
|
||||
User Input: yo, 你今天咋样?
|
||||
{
|
||||
"Language Type": "The user's input is English-Chinese mixed",
|
||||
"Your Reasoning": "The English-part is a subjective particle, the main intention is written in Chinese, so the language of my output must be using Chinese.",
|
||||
"Your Output": "查询今日我的状态☺️"
|
||||
"Language Type": "<Detected language>",
|
||||
"Your Reasoning": "<Brief explanation in that language>",
|
||||
"Your Output": "<Intention + Subject>"
|
||||
}
|
||||
|
||||
User Input:
|
||||
|
|
|
|||
|
|
@ -17,19 +17,6 @@ class LLMMode(StrEnum):
|
|||
COMPLETION = "completion"
|
||||
CHAT = "chat"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "LLMMode":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
|
||||
class LLMUsage(ModelUsage):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -129,17 +129,18 @@ def jsonable_encoder(
|
|||
sqlalchemy_safe=sqlalchemy_safe,
|
||||
)
|
||||
if dataclasses.is_dataclass(obj):
|
||||
# FIXME: mypy error, try to fix it instead of using type: ignore
|
||||
obj_dict = dataclasses.asdict(obj) # type: ignore
|
||||
return jsonable_encoder(
|
||||
obj_dict,
|
||||
by_alias=by_alias,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
exclude_none=exclude_none,
|
||||
custom_encoder=custom_encoder,
|
||||
sqlalchemy_safe=sqlalchemy_safe,
|
||||
)
|
||||
# Ensure obj is a dataclass instance, not a dataclass type
|
||||
if not isinstance(obj, type):
|
||||
obj_dict = dataclasses.asdict(obj)
|
||||
return jsonable_encoder(
|
||||
obj_dict,
|
||||
by_alias=by_alias,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
exclude_none=exclude_none,
|
||||
custom_encoder=custom_encoder,
|
||||
sqlalchemy_safe=sqlalchemy_safe,
|
||||
)
|
||||
if isinstance(obj, Enum):
|
||||
return obj.value
|
||||
if isinstance(obj, PurePath):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,11 @@
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.ops.entities.config_entity import BaseTracingConfig
|
||||
from core.ops.entities.trace_entity import BaseTraceInfo
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App, TenantAccountJoin
|
||||
|
||||
|
||||
class BaseTraceInstance(ABC):
|
||||
|
|
@ -24,3 +28,38 @@ class BaseTraceInstance(ABC):
|
|||
Subclasses must implement specific tracing logic for activities.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_service_account_with_tenant(self, app_id: str) -> Account:
    """
    Load the account that created an app and attach its current tenant.

    Args:
        app_id: The ID of the app

    Returns:
        Account: The creator's account, after ``set_tenant_id`` has been
            called with the tenant id of its current tenant join

    Raises:
        ValueError: If the app, its creator account or the creator's
            current tenant cannot be found
    """
    with Session(db.engine, expire_on_commit=False) as session:
        # The app row is only needed to learn who created it.
        app = session.query(App).filter(App.id == app_id).first()
        if not app:
            raise ValueError(f"App with id {app_id} not found")

        creator_id = app.created_by
        if not creator_id:
            raise ValueError(f"App with id {app_id} has no creator (created_by is None)")

        service_account = session.query(Account).filter(Account.id == creator_id).first()
        if not service_account:
            raise ValueError(f"Creator account with id {creator_id} not found for app {app_id}")

        # Only the join row flagged as current decides which tenant is used.
        tenant_join = (
            session.query(TenantAccountJoin)
            .filter_by(account_id=service_account.id, current=True)
            .first()
        )
        if not tenant_join:
            raise ValueError(f"Current tenant not found for account {service_account.id}")

        service_account.set_tenant_id(tenant_join.tenant_id)
        return service_account
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from datetime import datetime
|
|||
from enum import StrEnum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
|
||||
|
||||
|
||||
class BaseTraceInfo(BaseModel):
|
||||
|
|
@ -24,10 +24,13 @@ class BaseTraceInfo(BaseModel):
|
|||
return v
|
||||
return ""
|
||||
|
||||
class Config:
|
||||
json_encoders = {
|
||||
datetime: lambda v: v.isoformat(),
|
||||
}
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@field_serializer("start_time", "end_time")
def serialize_datetime(self, dt: datetime | None) -> str | None:
    """Serialize ``start_time`` / ``end_time`` via ``isoformat()``, passing ``None`` through unchanged."""
    return None if dt is None else dt.isoformat()
|
||||
|
||||
|
||||
class WorkflowTraceInfo(BaseTraceInfo):
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from datetime import datetime, timedelta
|
|||
from typing import Optional
|
||||
|
||||
from langfuse import Langfuse # type: ignore
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import LangfuseConfig
|
||||
|
|
@ -31,7 +31,7 @@ from core.ops.utils import filter_none_values
|
|||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App, EndUser, WorkflowNodeExecutionTriggeredFrom
|
||||
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -114,22 +114,11 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
# through workflow_run_id get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
# Find the app's creator account
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get the app to find its creator
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
app = session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||
|
||||
service_account = session.query(Account).filter(Account.id == app.created_by).first()
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from typing import Optional, cast
|
|||
|
||||
from langsmith import Client
|
||||
from langsmith.schemas import RunBase
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import LangSmithConfig
|
||||
|
|
@ -28,10 +28,10 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
|
|||
)
|
||||
from core.ops.utils import filter_none_values, generate_dotted_order
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -139,22 +139,11 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||
# through workflow_run_id get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
# Find the app's creator account
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get the app to find its creator
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
app = session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||
|
||||
service_account = session.query(Account).filter(Account.id == app.created_by).first()
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
|
|
@ -185,7 +174,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
execution_metadata = node_execution.metadata if node_execution.metadata else {}
|
||||
node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
|
||||
node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
|
||||
metadata = {str(key): value for key, value in execution_metadata.items()}
|
||||
metadata.update(
|
||||
{
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from typing import Optional, cast
|
|||
|
||||
from opik import Opik, Trace
|
||||
from opik.id_helpers import uuid4_to_uuid7
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import OpikConfig
|
||||
|
|
@ -22,10 +22,10 @@ from core.ops.entities.trace_entity import (
|
|||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -154,22 +154,11 @@ class OpikDataTrace(BaseTraceInstance):
|
|||
# through workflow_run_id get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
# Find the app's creator account
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get the app to find its creator
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
app = session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||
|
||||
service_account = session.query(Account).filter(Account.id == app.created_by).first()
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
|
|
@ -246,7 +235,7 @@ class OpikDataTrace(BaseTraceInstance):
|
|||
parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id
|
||||
|
||||
if not total_tokens:
|
||||
total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
|
||||
total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
|
||||
|
||||
span_data = {
|
||||
"trace_id": opik_trace_id,
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ from core.ops.entities.trace_entity import (
|
|||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.utils import get_message_data
|
||||
from core.workflow.entities.workflow_execution_entities import WorkflowExecution
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
|
||||
|
|
@ -386,7 +386,7 @@ class TraceTask:
|
|||
):
|
||||
self.trace_type = trace_type
|
||||
self.message_id = message_id
|
||||
self.workflow_run_id = workflow_execution.id if workflow_execution else None
|
||||
self.workflow_run_id = workflow_execution.id_ if workflow_execution else None
|
||||
self.conversation_id = conversation_id
|
||||
self.user_id = user_id
|
||||
self.timer = timer
|
||||
|
|
@ -487,6 +487,7 @@ class TraceTask:
|
|||
"file_list": file_list,
|
||||
"triggered_from": workflow_run.triggered_from,
|
||||
"user_id": user_id,
|
||||
"app_id": workflow_run.app_id,
|
||||
}
|
||||
|
||||
workflow_trace_info = WorkflowTraceInfo(
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from typing import Any, Optional, cast
|
|||
|
||||
import wandb
|
||||
import weave
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import WeaveConfig
|
||||
|
|
@ -23,10 +23,10 @@ from core.ops.entities.trace_entity import (
|
|||
)
|
||||
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -133,22 +133,11 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
# through workflow_run_id get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
# Find the app's creator account
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get the app to find its creator
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
app = session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||
|
||||
service_account = session.query(Account).filter(Account.id == app.created_by).first()
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
|
|
@ -179,7 +168,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
execution_metadata = node_execution.metadata if node_execution.metadata else {}
|
||||
node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
|
||||
node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
|
||||
attributes = {str(k): v for k, v in execution_metadata.items()}
|
||||
attributes.update(
|
||||
{
|
||||
|
|
|
|||
|
|
@ -58,6 +58,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
|||
LLMNode.deduct_llm_quota(
|
||||
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
|
||||
)
|
||||
chunk.prompt_messages = []
|
||||
yield chunk
|
||||
|
||||
return handle()
|
||||
|
|
@ -68,7 +69,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
|||
def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
|
||||
yield LLMResultChunk(
|
||||
model=response.model,
|
||||
prompt_messages=response.prompt_messages,
|
||||
prompt_messages=[],
|
||||
system_fingerprint=response.system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from core.agent.plugin_entities import AgentProviderEntityWithPlugin
|
|||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from core.plugin.entities.base import BasePluginEntity
|
||||
from core.plugin.entities.plugin import PluginDeclaration
|
||||
from core.plugin.entities.plugin import PluginDeclaration, PluginEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
|
||||
|
||||
|
|
@ -167,3 +167,8 @@ class PluginOAuthAuthorizationUrlResponse(BaseModel):
|
|||
|
||||
class PluginOAuthCredentialsResponse(BaseModel):
    """Response model carrying the credentials mapping of an OAuth flow."""

    credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.")
|
||||
|
||||
|
||||
class PluginListResponse(BaseModel):
    """Response model for a plugin listing: the plugin entries plus a total."""

    # Plugin entries contained in this response.
    list: list[PluginEntity]
    # presumably the overall number of plugins across all pages — confirm against the daemon API
    total: int
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from typing import TypeVar
|
|||
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from requests.exceptions import HTTPError
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -30,8 +31,7 @@ from core.plugin.impl.exc import (
|
|||
PluginUniqueIdentifierError,
|
||||
)
|
||||
|
||||
plugin_daemon_inner_api_baseurl = dify_config.PLUGIN_DAEMON_URL
|
||||
plugin_daemon_inner_api_key = dify_config.PLUGIN_DAEMON_KEY
|
||||
plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL))
|
||||
|
||||
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
|
||||
|
||||
|
|
@ -52,9 +52,9 @@ class BasePluginClient:
|
|||
"""
|
||||
Make a request to the plugin daemon inner API.
|
||||
"""
|
||||
url = URL(str(plugin_daemon_inner_api_baseurl)) / path
|
||||
url = plugin_daemon_inner_api_baseurl / path
|
||||
headers = headers or {}
|
||||
headers["X-Api-Key"] = plugin_daemon_inner_api_key
|
||||
headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
|
||||
headers["Accept-Encoding"] = "gzip, deflate, br"
|
||||
|
||||
if headers.get("Content-Type") == "application/json" and isinstance(data, dict):
|
||||
|
|
@ -136,12 +136,31 @@ class BasePluginClient:
|
|||
"""
|
||||
Make a request to the plugin daemon inner API and return the response as a model.
|
||||
"""
|
||||
response = self._request(method, path, headers, data, params, files)
|
||||
json_response = response.json()
|
||||
if transformer:
|
||||
json_response = transformer(json_response)
|
||||
try:
|
||||
response = self._request(method, path, headers, data, params, files)
|
||||
response.raise_for_status()
|
||||
except HTTPError as e:
|
||||
msg = f"Failed to request plugin daemon, status: {e.response.status_code}, url: {path}"
|
||||
logging.exception(msg)
|
||||
raise e
|
||||
except Exception as e:
|
||||
msg = f"Failed to request plugin daemon, url: {path}"
|
||||
logging.exception(msg)
|
||||
raise ValueError(msg) from e
|
||||
|
||||
try:
|
||||
json_response = response.json()
|
||||
if transformer:
|
||||
json_response = transformer(json_response)
|
||||
rep = PluginDaemonBasicResponse[type](**json_response) # type: ignore
|
||||
except Exception:
|
||||
msg = (
|
||||
f"Failed to parse response from plugin daemon to PluginDaemonBasicResponse [{str(type.__name__)}],"
|
||||
f" url: {path}"
|
||||
)
|
||||
logging.exception(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
rep = PluginDaemonBasicResponse[type](**json_response) # type: ignore
|
||||
if rep.code != 0:
|
||||
try:
|
||||
error = PluginDaemonError(**json.loads(rep.message))
|
||||
|
|
|
|||
|
|
@ -9,7 +9,12 @@ from core.plugin.entities.plugin import (
|
|||
PluginInstallation,
|
||||
PluginInstallationSource,
|
||||
)
|
||||
from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginInstallTaskStartResponse, PluginUploadResponse
|
||||
from core.plugin.entities.plugin_daemon import (
|
||||
PluginInstallTask,
|
||||
PluginInstallTaskStartResponse,
|
||||
PluginListResponse,
|
||||
PluginUploadResponse,
|
||||
)
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
|
||||
|
||||
|
|
@ -27,11 +32,20 @@ class PluginInstaller(BasePluginClient):
|
|||
)
|
||||
|
||||
def list_plugins(self, tenant_id: str) -> list[PluginEntity]:
    """
    Return the plugins on the first page (page size 256) of a tenant's plugin list.

    :param tenant_id: tenant whose plugins are listed
    :return: the ``list`` field of the daemon's ``PluginListResponse``
    """
    path = f"plugin/{tenant_id}/management/list"
    first_page = {"page": 1, "page_size": 256}
    response = self._request_with_plugin_daemon_response("GET", path, PluginListResponse, params=first_page)
    return response.list
|
||||
|
||||
def list_plugins_with_total(self, tenant_id: str, page: int, page_size: int) -> PluginListResponse:
    """
    Fetch one page of a tenant's plugins.

    :param tenant_id: tenant whose plugins are listed
    :param page: page number forwarded to the daemon as ``page``
    :param page_size: page size forwarded to the daemon as ``page_size``
    :return: the daemon's ``PluginListResponse`` (``list`` plus ``total``)
    """
    # The response type must be PluginListResponse (not list[PluginEntity]) and the
    # caller's paging values must be forwarded rather than a fixed first page.
    return self._request_with_plugin_daemon_response(
        "GET",
        f"plugin/{tenant_id}/management/list",
        PluginListResponse,
        params={"page": page, "page_size": page_size},
    )
|
||||
|
||||
def upload_pkg(
|
||||
|
|
|
|||
|
|
@ -85,7 +85,6 @@ class BaiduVector(BaseVector):
|
|||
end = min(start + batch_size, total_count)
|
||||
rows = []
|
||||
assert len(metadatas) == total_count, "metadatas length should be equal to total_count"
|
||||
# FIXME do you need this assert?
|
||||
for i in range(start, end, 1):
|
||||
row = Row(
|
||||
id=metadatas[i].get("doc_id", str(uuid.uuid4())),
|
||||
|
|
|
|||
|
|
@ -142,7 +142,7 @@ class ElasticSearchVector(BaseVector):
|
|||
if score > score_threshold:
|
||||
if doc.metadata is not None:
|
||||
doc.metadata["score"] = score
|
||||
docs.append(doc)
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
||||
|
|
|
|||
|
|
@ -97,6 +97,10 @@ class MilvusVector(BaseVector):
|
|||
|
||||
try:
|
||||
milvus_version = self._client.get_server_version()
|
||||
# Check if it's Zilliz Cloud - it supports full-text search with Milvus 2.5 compatibility
|
||||
if "Zilliz Cloud" in milvus_version:
|
||||
return True
|
||||
# For standard Milvus installations, check version number
|
||||
return version.parse(milvus_version).base_version >= version.parse("2.5.0").base_version
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to check Milvus version: {str(e)}. Disabling hybrid search.")
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ class QdrantConfig(BaseModel):
|
|||
root_path: Optional[str] = None
|
||||
grpc_port: int = 6334
|
||||
prefer_grpc: bool = False
|
||||
replication_factor: int = 1
|
||||
|
||||
def to_qdrant_params(self):
|
||||
if self.endpoint and self.endpoint.startswith("path:"):
|
||||
|
|
@ -119,11 +120,13 @@ class QdrantVector(BaseVector):
|
|||
max_indexing_threads=0,
|
||||
on_disk=False,
|
||||
)
|
||||
|
||||
self._client.create_collection(
|
||||
collection_name=collection_name,
|
||||
vectors_config=vectors_config,
|
||||
hnsw_config=hnsw_config,
|
||||
timeout=int(self._client_config.timeout),
|
||||
replication_factor=self._client_config.replication_factor,
|
||||
)
|
||||
|
||||
# create group_id payload index
|
||||
|
|
@ -466,5 +469,6 @@ class QdrantVectorFactory(AbstractVectorFactory):
|
|||
timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
|
||||
grpc_port=dify_config.QDRANT_GRPC_PORT,
|
||||
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
|
||||
replication_factor=dify_config.QDRANT_REPLICATION_FACTOR,
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ class TidbOnQdrantConfig(BaseModel):
|
|||
root_path: Optional[str] = None
|
||||
grpc_port: int = 6334
|
||||
prefer_grpc: bool = False
|
||||
replication_factor: int = 1
|
||||
|
||||
def to_qdrant_params(self):
|
||||
if self.endpoint and self.endpoint.startswith("path:"):
|
||||
|
|
@ -134,6 +135,7 @@ class TidbOnQdrantVector(BaseVector):
|
|||
vectors_config=vectors_config,
|
||||
hnsw_config=hnsw_config,
|
||||
timeout=int(self._client_config.timeout),
|
||||
replication_factor=self._client_config.replication_factor,
|
||||
)
|
||||
|
||||
# create group_id payload index
|
||||
|
|
@ -484,6 +486,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
|||
timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT,
|
||||
grpc_port=dify_config.TIDB_ON_QDRANT_GRPC_PORT,
|
||||
prefer_grpc=dify_config.TIDB_ON_QDRANT_GRPC_ENABLED,
|
||||
replication_factor=dify_config.QDRANT_REPLICATION_FACTOR,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -245,4 +245,4 @@ class TidbService:
|
|||
return cluster_infos
|
||||
else:
|
||||
response.raise_for_status()
|
||||
return [] # FIXME for mypy, This line will not be reached as raise_for_status() will raise an exception
|
||||
return []
|
||||
|
|
|
|||
|
|
@ -0,0 +1,23 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RetrievalSourceMetadata(BaseModel):
    """
    Metadata describing a single retrieval source.

    Every field is optional and defaults to ``None``.
    """

    position: Optional[int] = None
    dataset_id: Optional[str] = None
    dataset_name: Optional[str] = None
    document_id: Optional[str] = None
    document_name: Optional[str] = None
    data_source_type: Optional[str] = None
    segment_id: Optional[str] = None
    retriever_from: Optional[str] = None
    score: Optional[float] = None
    hit_count: Optional[int] = None
    word_count: Optional[int] = None
    segment_position: Optional[int] = None
    index_node_hash: Optional[str] = None
    content: Optional[str] = None
    page: Optional[int] = None
    doc_metadata: Optional[dict[str, Any]] = None
    title: Optional[str] = None
|
||||
|
|
@ -27,6 +27,8 @@ class WebsiteInfo(BaseModel):
|
|||
website import info.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
provider: str
|
||||
job_id: str
|
||||
url: str
|
||||
|
|
@ -34,12 +36,6 @@ class WebsiteInfo(BaseModel):
|
|||
tenant_id: str
|
||||
only_main_content: bool = False
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
|
||||
|
||||
class ExtractSetting(BaseModel):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -45,13 +45,12 @@ class BaseDocumentTransformer(ABC):
|
|||
.. code-block:: python
|
||||
|
||||
class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
embeddings: Embeddings
|
||||
similarity_fn: Callable = cosine_similarity
|
||||
similarity_threshold: float = 0.95
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def transform_documents(
|
||||
self, documents: Sequence[Document], **kwargs: Any
|
||||
) -> Sequence[Document]:
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ from core.prompt.simple_prompt_transform import ModelMode
|
|||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.entities.metadata_entities import Condition, MetadataCondition
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
|
|
@ -198,21 +199,21 @@ class DatasetRetrieval:
|
|||
|
||||
dify_documents = [item for item in all_documents if item.provider == "dify"]
|
||||
external_documents = [item for item in all_documents if item.provider == "external"]
|
||||
document_context_list = []
|
||||
retrieval_resource_list = []
|
||||
document_context_list: list[DocumentContext] = []
|
||||
retrieval_resource_list: list[RetrievalSourceMetadata] = []
|
||||
# deal with external documents
|
||||
for item in external_documents:
|
||||
document_context_list.append(DocumentContext(content=item.page_content, score=item.metadata.get("score")))
|
||||
source = {
|
||||
"dataset_id": item.metadata.get("dataset_id"),
|
||||
"dataset_name": item.metadata.get("dataset_name"),
|
||||
"document_id": item.metadata.get("document_id") or item.metadata.get("title"),
|
||||
"document_name": item.metadata.get("title"),
|
||||
"data_source_type": "external",
|
||||
"retriever_from": invoke_from.to_source(),
|
||||
"score": item.metadata.get("score"),
|
||||
"content": item.page_content,
|
||||
}
|
||||
source = RetrievalSourceMetadata(
|
||||
dataset_id=item.metadata.get("dataset_id"),
|
||||
dataset_name=item.metadata.get("dataset_name"),
|
||||
document_id=item.metadata.get("document_id") or item.metadata.get("title"),
|
||||
document_name=item.metadata.get("title"),
|
||||
data_source_type="external",
|
||||
retriever_from=invoke_from.to_source(),
|
||||
score=item.metadata.get("score"),
|
||||
content=item.page_content,
|
||||
)
|
||||
retrieval_resource_list.append(source)
|
||||
# deal with dify documents
|
||||
if dify_documents:
|
||||
|
|
@ -248,32 +249,32 @@ class DatasetRetrieval:
|
|||
.first()
|
||||
)
|
||||
if dataset and document:
|
||||
source = {
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"document_id": document.id,
|
||||
"document_name": document.name,
|
||||
"data_source_type": document.data_source_type,
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": invoke_from.to_source(),
|
||||
"score": record.score or 0.0,
|
||||
"doc_metadata": document.doc_metadata,
|
||||
}
|
||||
source = RetrievalSourceMetadata(
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
document_id=document.id,
|
||||
document_name=document.name,
|
||||
data_source_type=document.data_source_type,
|
||||
segment_id=segment.id,
|
||||
retriever_from=invoke_from.to_source(),
|
||||
score=record.score or 0.0,
|
||||
doc_metadata=document.doc_metadata,
|
||||
)
|
||||
|
||||
if invoke_from.to_source() == "dev":
|
||||
source["hit_count"] = segment.hit_count
|
||||
source["word_count"] = segment.word_count
|
||||
source["segment_position"] = segment.position
|
||||
source["index_node_hash"] = segment.index_node_hash
|
||||
source.hit_count = segment.hit_count
|
||||
source.word_count = segment.word_count
|
||||
source.segment_position = segment.position
|
||||
source.index_node_hash = segment.index_node_hash
|
||||
if segment.answer:
|
||||
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
else:
|
||||
source["content"] = segment.content
|
||||
source.content = segment.content
|
||||
retrieval_resource_list.append(source)
|
||||
if hit_callback and retrieval_resource_list:
|
||||
retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.get("score") or 0.0, reverse=True)
|
||||
retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.score or 0.0, reverse=True)
|
||||
for position, item in enumerate(retrieval_resource_list, start=1):
|
||||
item["position"] = position
|
||||
item.position = position
|
||||
hit_callback.return_retriever_resource_info(retrieval_resource_list)
|
||||
if document_context_list:
|
||||
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
|
||||
|
|
@ -936,6 +937,9 @@ class DatasetRetrieval:
|
|||
return metadata_filter_document_ids, metadata_condition
|
||||
|
||||
def _replace_metadata_filter_value(self, text: str, inputs: dict) -> str:
|
||||
if not inputs:
|
||||
return text
|
||||
|
||||
def replacer(match):
|
||||
key = match.group(1)
|
||||
return str(inputs.get(key, f"{{{{{key}}}}}"))
|
||||
|
|
|
|||
|
|
@ -10,12 +10,12 @@ from sqlalchemy import select
|
|||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.workflow.entities.workflow_execution_entities import (
|
||||
from core.workflow.entities.workflow_execution import (
|
||||
WorkflowExecution,
|
||||
WorkflowExecutionStatus,
|
||||
WorkflowType,
|
||||
)
|
||||
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from models import (
|
||||
Account,
|
||||
CreatorUserRole,
|
||||
|
|
@ -104,10 +104,9 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
|
|||
status = WorkflowExecutionStatus(db_model.status)
|
||||
|
||||
return WorkflowExecution(
|
||||
id=db_model.id,
|
||||
id_=db_model.id,
|
||||
workflow_id=db_model.workflow_id,
|
||||
sequence_number=db_model.sequence_number,
|
||||
type=WorkflowType(db_model.type),
|
||||
workflow_type=WorkflowType(db_model.type),
|
||||
workflow_version=db_model.version,
|
||||
graph=graph,
|
||||
inputs=inputs,
|
||||
|
|
@ -140,14 +139,29 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
|
|||
raise ValueError("created_by_role is required in repository constructor")
|
||||
|
||||
db_model = WorkflowRun()
|
||||
db_model.id = domain_model.id
|
||||
db_model.id = domain_model.id_
|
||||
db_model.tenant_id = self._tenant_id
|
||||
if self._app_id is not None:
|
||||
db_model.app_id = self._app_id
|
||||
db_model.workflow_id = domain_model.workflow_id
|
||||
db_model.triggered_from = self._triggered_from
|
||||
db_model.sequence_number = domain_model.sequence_number
|
||||
db_model.type = domain_model.type
|
||||
|
||||
# Check if this is a new record
|
||||
with self._session_factory() as session:
|
||||
existing = session.scalar(select(WorkflowRun).where(WorkflowRun.id == domain_model.id_))
|
||||
if not existing:
|
||||
# For new records, get the next sequence number
|
||||
stmt = select(WorkflowRun.sequence_number).where(
|
||||
WorkflowRun.app_id == self._app_id,
|
||||
WorkflowRun.tenant_id == self._tenant_id,
|
||||
)
|
||||
max_sequence = session.scalar(stmt.order_by(WorkflowRun.sequence_number.desc()))
|
||||
db_model.sequence_number = (max_sequence or 0) + 1
|
||||
else:
|
||||
# For updates, keep the existing sequence number
|
||||
db_model.sequence_number = existing.sequence_number
|
||||
|
||||
db_model.type = domain_model.workflow_type
|
||||
db_model.version = domain_model.workflow_version
|
||||
db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None
|
||||
db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None
|
||||
|
|
|
|||
|
|
@ -12,19 +12,18 @@ from sqlalchemy.engine import Engine
|
|||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.entities.node_execution_entities import (
|
||||
NodeExecution,
|
||||
NodeExecutionStatus,
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
|
||||
from models import (
|
||||
Account,
|
||||
CreatorUserRole,
|
||||
EndUser,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowNodeExecutionModel,
|
||||
WorkflowNodeExecutionTriggeredFrom,
|
||||
)
|
||||
|
||||
|
|
@ -87,9 +86,9 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
|
||||
# Initialize in-memory cache for node executions
|
||||
# Key: node_execution_id, Value: WorkflowNodeExecution (DB model)
|
||||
self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}
|
||||
self._node_execution_cache: dict[str, WorkflowNodeExecutionModel] = {}
|
||||
|
||||
def _to_domain_model(self, db_model: WorkflowNodeExecution) -> NodeExecution:
|
||||
def _to_domain_model(self, db_model: WorkflowNodeExecutionModel) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Convert a database model to a domain model.
|
||||
|
||||
|
|
@ -103,16 +102,16 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
inputs = db_model.inputs_dict
|
||||
process_data = db_model.process_data_dict
|
||||
outputs = db_model.outputs_dict
|
||||
metadata = {NodeRunMetadataKey(k): v for k, v in db_model.execution_metadata_dict.items()}
|
||||
metadata = {WorkflowNodeExecutionMetadataKey(k): v for k, v in db_model.execution_metadata_dict.items()}
|
||||
|
||||
# Convert status to domain enum
|
||||
status = NodeExecutionStatus(db_model.status)
|
||||
status = WorkflowNodeExecutionStatus(db_model.status)
|
||||
|
||||
return NodeExecution(
|
||||
return WorkflowNodeExecution(
|
||||
id=db_model.id,
|
||||
node_execution_id=db_model.node_execution_id,
|
||||
workflow_id=db_model.workflow_id,
|
||||
workflow_run_id=db_model.workflow_run_id,
|
||||
workflow_execution_id=db_model.workflow_run_id,
|
||||
index=db_model.index,
|
||||
predecessor_node_id=db_model.predecessor_node_id,
|
||||
node_id=db_model.node_id,
|
||||
|
|
@ -129,7 +128,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
finished_at=db_model.finished_at,
|
||||
)
|
||||
|
||||
def to_db_model(self, domain_model: NodeExecution) -> WorkflowNodeExecution:
|
||||
def to_db_model(self, domain_model: WorkflowNodeExecution) -> WorkflowNodeExecutionModel:
|
||||
"""
|
||||
Convert a domain model to a database model.
|
||||
|
||||
|
|
@ -147,14 +146,14 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
if not self._creator_user_role:
|
||||
raise ValueError("created_by_role is required in repository constructor")
|
||||
|
||||
db_model = WorkflowNodeExecution()
|
||||
db_model = WorkflowNodeExecutionModel()
|
||||
db_model.id = domain_model.id
|
||||
db_model.tenant_id = self._tenant_id
|
||||
if self._app_id is not None:
|
||||
db_model.app_id = self._app_id
|
||||
db_model.workflow_id = domain_model.workflow_id
|
||||
db_model.triggered_from = self._triggered_from
|
||||
db_model.workflow_run_id = domain_model.workflow_run_id
|
||||
db_model.workflow_run_id = domain_model.workflow_execution_id
|
||||
db_model.index = domain_model.index
|
||||
db_model.predecessor_node_id = domain_model.predecessor_node_id
|
||||
db_model.node_execution_id = domain_model.node_execution_id
|
||||
|
|
@ -176,7 +175,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
db_model.finished_at = domain_model.finished_at
|
||||
return db_model
|
||||
|
||||
def save(self, execution: NodeExecution) -> None:
|
||||
def save(self, execution: WorkflowNodeExecution) -> None:
|
||||
"""
|
||||
Save or update a NodeExecution domain entity to the database.
|
||||
|
||||
|
|
@ -208,7 +207,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
logger.debug(f"Updating cache for node_execution_id: {db_model.node_execution_id}")
|
||||
self._node_execution_cache[db_model.node_execution_id] = db_model
|
||||
|
||||
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]:
|
||||
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve a NodeExecution by its node_execution_id.
|
||||
|
||||
|
|
@ -231,13 +230,13 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
# If not in cache, query the database
|
||||
logger.debug(f"Cache miss for node_execution_id: {node_execution_id}, querying database")
|
||||
with self._session_factory() as session:
|
||||
stmt = select(WorkflowNodeExecution).where(
|
||||
WorkflowNodeExecution.node_execution_id == node_execution_id,
|
||||
WorkflowNodeExecution.tenant_id == self._tenant_id,
|
||||
stmt = select(WorkflowNodeExecutionModel).where(
|
||||
WorkflowNodeExecutionModel.node_execution_id == node_execution_id,
|
||||
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
|
||||
)
|
||||
|
||||
if self._app_id:
|
||||
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
|
||||
|
||||
db_model = session.scalar(stmt)
|
||||
if db_model:
|
||||
|
|
@ -253,7 +252,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
self,
|
||||
workflow_run_id: str,
|
||||
order_config: Optional[OrderConfig] = None,
|
||||
) -> Sequence[WorkflowNodeExecution]:
|
||||
) -> Sequence[WorkflowNodeExecutionModel]:
|
||||
"""
|
||||
Retrieve all WorkflowNodeExecution database models for a specific workflow run.
|
||||
|
||||
|
|
@ -271,20 +270,20 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
A list of WorkflowNodeExecution database models
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
stmt = select(WorkflowNodeExecution).where(
|
||||
WorkflowNodeExecution.workflow_run_id == workflow_run_id,
|
||||
WorkflowNodeExecution.tenant_id == self._tenant_id,
|
||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
stmt = select(WorkflowNodeExecutionModel).where(
|
||||
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
|
||||
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
|
||||
WorkflowNodeExecutionModel.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
if self._app_id:
|
||||
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
|
||||
|
||||
# Apply ordering if provided
|
||||
if order_config and order_config.order_by:
|
||||
order_columns: list[UnaryExpression] = []
|
||||
for field in order_config.order_by:
|
||||
column = getattr(WorkflowNodeExecution, field, None)
|
||||
column = getattr(WorkflowNodeExecutionModel, field, None)
|
||||
if not column:
|
||||
continue
|
||||
if order_config.order_direction == "desc":
|
||||
|
|
@ -308,7 +307,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
self,
|
||||
workflow_run_id: str,
|
||||
order_config: Optional[OrderConfig] = None,
|
||||
) -> Sequence[NodeExecution]:
|
||||
) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all NodeExecution instances for a specific workflow run.
|
||||
|
||||
|
|
@ -335,7 +334,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
|
||||
return domain_models
|
||||
|
||||
def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]:
|
||||
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all running NodeExecution instances for a specific workflow run.
|
||||
|
||||
|
|
@ -349,15 +348,15 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
A list of running NodeExecution instances
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
stmt = select(WorkflowNodeExecution).where(
|
||||
WorkflowNodeExecution.workflow_run_id == workflow_run_id,
|
||||
WorkflowNodeExecution.tenant_id == self._tenant_id,
|
||||
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING,
|
||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
stmt = select(WorkflowNodeExecutionModel).where(
|
||||
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
|
||||
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
|
||||
WorkflowNodeExecutionModel.status == WorkflowNodeExecutionStatus.RUNNING,
|
||||
WorkflowNodeExecutionModel.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
if self._app_id:
|
||||
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
|
||||
|
||||
db_models = session.scalars(stmt).all()
|
||||
domain_models = []
|
||||
|
|
@ -382,10 +381,10 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
It also clears the in-memory cache.
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id)
|
||||
stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.tenant_id == self._tenant_id)
|
||||
|
||||
if self._app_id:
|
||||
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
|
||||
|
||||
result = session.execute(stmt)
|
||||
session.commit()
|
||||
|
|
|
|||
|
|
@ -168,7 +168,7 @@ class ApiTool(Tool):
|
|||
cookies[parameter["name"]] = value
|
||||
|
||||
elif parameter["in"] == "header":
|
||||
headers[parameter["name"]] = value
|
||||
headers[parameter["name"]] = str(value)
|
||||
|
||||
# check if there is a request body and handle it
|
||||
if "requestBody" in self.api_bundle.openapi and self.api_bundle.openapi["requestBody"] is not None:
|
||||
|
|
|
|||
|
|
@ -279,7 +279,6 @@ class ToolParameter(PluginParameter):
|
|||
:param options: the options of the parameter
|
||||
"""
|
||||
# convert options to ToolParameterOption
|
||||
# FIXME fix the type error
|
||||
if options:
|
||||
option_objs = [
|
||||
PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
|
|||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.models.document import Document as RagDocument
|
||||
from core.rag.rerank.rerank_model import RerankModelRunner
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
|
|
@ -107,7 +108,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
|||
else:
|
||||
document_context_list.append(segment.get_sign_content())
|
||||
if self.return_resource:
|
||||
context_list = []
|
||||
context_list: list[RetrievalSourceMetadata] = []
|
||||
resource_number = 1
|
||||
for segment in sorted_segments:
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
||||
|
|
@ -121,28 +122,28 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
|||
.first()
|
||||
)
|
||||
if dataset and document:
|
||||
source = {
|
||||
"position": resource_number,
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"document_id": document.id,
|
||||
"document_name": document.name,
|
||||
"data_source_type": document.data_source_type,
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": self.retriever_from,
|
||||
"score": document_score_list.get(segment.index_node_id, None),
|
||||
"doc_metadata": document.doc_metadata,
|
||||
}
|
||||
source = RetrievalSourceMetadata(
|
||||
position=resource_number,
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
document_id=document.id,
|
||||
document_name=document.name,
|
||||
data_source_type=document.data_source_type,
|
||||
segment_id=segment.id,
|
||||
retriever_from=self.retriever_from,
|
||||
score=document_score_list.get(segment.index_node_id, None),
|
||||
doc_metadata=document.doc_metadata,
|
||||
)
|
||||
|
||||
if self.retriever_from == "dev":
|
||||
source["hit_count"] = segment.hit_count
|
||||
source["word_count"] = segment.word_count
|
||||
source["segment_position"] = segment.position
|
||||
source["index_node_hash"] = segment.index_node_hash
|
||||
source.hit_count = segment.hit_count
|
||||
source.word_count = segment.word_count
|
||||
source.segment_position = segment.position
|
||||
source.index_node_hash = segment.index_node_hash
|
||||
if segment.answer:
|
||||
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
else:
|
||||
source["content"] = segment.content
|
||||
source.content = segment.content
|
||||
context_list.append(source)
|
||||
resource_number += 1
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from pydantic import BaseModel, Field
|
|||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.models.document import Document as RetrievalDocument
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
|
|
@ -14,7 +15,7 @@ from models.dataset import Dataset
|
|||
from models.dataset import Document as DatasetDocument
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
default_retrieval_model = {
|
||||
default_retrieval_model: dict[str, Any] = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
|
|
@ -79,7 +80,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||
else:
|
||||
document_ids_filter = None
|
||||
if dataset.provider == "external":
|
||||
results = []
|
||||
results: list[RetrievalDocument] = []
|
||||
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
|
|
@ -100,21 +101,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||
document.metadata["dataset_name"] = dataset.name
|
||||
results.append(document)
|
||||
# deal with external documents
|
||||
context_list = []
|
||||
context_list: list[RetrievalSourceMetadata] = []
|
||||
for position, item in enumerate(results, start=1):
|
||||
if item.metadata is not None:
|
||||
source = {
|
||||
"position": position,
|
||||
"dataset_id": item.metadata.get("dataset_id"),
|
||||
"dataset_name": item.metadata.get("dataset_name"),
|
||||
"document_id": item.metadata.get("document_id") or item.metadata.get("title"),
|
||||
"document_name": item.metadata.get("title"),
|
||||
"data_source_type": "external",
|
||||
"retriever_from": self.retriever_from,
|
||||
"score": item.metadata.get("score"),
|
||||
"title": item.metadata.get("title"),
|
||||
"content": item.page_content,
|
||||
}
|
||||
source = RetrievalSourceMetadata(
|
||||
position=position,
|
||||
dataset_id=item.metadata.get("dataset_id"),
|
||||
dataset_name=item.metadata.get("dataset_name"),
|
||||
document_id=item.metadata.get("document_id") or item.metadata.get("title"),
|
||||
document_name=item.metadata.get("title"),
|
||||
data_source_type="external",
|
||||
retriever_from=self.retriever_from,
|
||||
score=item.metadata.get("score"),
|
||||
title=item.metadata.get("title"),
|
||||
content=item.page_content,
|
||||
)
|
||||
context_list.append(source)
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.return_retriever_resource_info(context_list)
|
||||
|
|
@ -125,7 +126,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||
return ""
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
|
||||
retrieval_resource_list = []
|
||||
retrieval_resource_list: list[RetrievalSourceMetadata] = []
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
|
|
@ -163,7 +164,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||
for item in documents:
|
||||
if item.metadata is not None and item.metadata.get("score"):
|
||||
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
||||
document_context_list = []
|
||||
document_context_list: list[DocumentContext] = []
|
||||
records = RetrievalService.format_retrieval_documents(documents)
|
||||
if records:
|
||||
for record in records:
|
||||
|
|
@ -197,37 +198,37 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||
.first()
|
||||
)
|
||||
if dataset and document:
|
||||
source = {
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"document_id": document.id, # type: ignore
|
||||
"document_name": document.name, # type: ignore
|
||||
"data_source_type": document.data_source_type, # type: ignore
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": self.retriever_from,
|
||||
"score": record.score or 0.0,
|
||||
"doc_metadata": document.doc_metadata, # type: ignore
|
||||
}
|
||||
source = RetrievalSourceMetadata(
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
document_id=document.id, # type: ignore
|
||||
document_name=document.name, # type: ignore
|
||||
data_source_type=document.data_source_type, # type: ignore
|
||||
segment_id=segment.id,
|
||||
retriever_from=self.retriever_from,
|
||||
score=record.score or 0.0,
|
||||
doc_metadata=document.doc_metadata, # type: ignore
|
||||
)
|
||||
|
||||
if self.retriever_from == "dev":
|
||||
source["hit_count"] = segment.hit_count
|
||||
source["word_count"] = segment.word_count
|
||||
source["segment_position"] = segment.position
|
||||
source["index_node_hash"] = segment.index_node_hash
|
||||
source.hit_count = segment.hit_count
|
||||
source.word_count = segment.word_count
|
||||
source.segment_position = segment.position
|
||||
source.index_node_hash = segment.index_node_hash
|
||||
if segment.answer:
|
||||
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
else:
|
||||
source["content"] = segment.content
|
||||
source.content = segment.content
|
||||
retrieval_resource_list.append(source)
|
||||
|
||||
if self.return_resource and retrieval_resource_list:
|
||||
retrieval_resource_list = sorted(
|
||||
retrieval_resource_list,
|
||||
key=lambda x: x.get("score") or 0.0,
|
||||
key=lambda x: x.score or 0.0,
|
||||
reverse=True,
|
||||
)
|
||||
for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore
|
||||
item["position"] = position # type: ignore
|
||||
item.position = position # type: ignore
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.return_retriever_resource_info(retrieval_resource_list)
|
||||
if document_context_list:
|
||||
|
|
|
|||
|
|
@ -66,7 +66,6 @@ class ToolFileMessageTransformer:
|
|||
if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
|
||||
raise ValueError("unexpected message type")
|
||||
|
||||
# FIXME: should do a type check here.
|
||||
assert isinstance(message.message.blob, bytes)
|
||||
tool_file_manager = ToolFileManager()
|
||||
file = tool_file_manager.create_file_by_raw(
|
||||
|
|
|
|||
|
|
@ -55,6 +55,13 @@ class ApiBasedToolSchemaParser:
|
|||
# convert parameters
|
||||
parameters = []
|
||||
if "parameters" in interface["operation"]:
|
||||
for i, parameter in enumerate(interface["operation"]["parameters"]):
|
||||
if "$ref" in parameter:
|
||||
root = openapi
|
||||
reference = parameter["$ref"].split("/")[1:]
|
||||
for ref in reference:
|
||||
root = root[ref]
|
||||
interface["operation"]["parameters"][i] = root
|
||||
for parameter in interface["operation"]["parameters"]:
|
||||
tool_parameter = ToolParameter(
|
||||
name=parameter["name"],
|
||||
|
|
|
|||
|
|
@ -1,21 +1,13 @@
|
|||
import hashlib
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import site
|
||||
import subprocess
|
||||
import tempfile
|
||||
import unicodedata
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, Optional, cast
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
import chardet
|
||||
import cloudscraper # type: ignore
|
||||
from bs4 import BeautifulSoup, CData, Comment, NavigableString # type: ignore
|
||||
from regex import regex # type: ignore
|
||||
from readabilipy import simple_json_from_html_string # type: ignore
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
from core.rag.extractor import extract_processor
|
||||
|
|
@ -23,9 +15,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
|
|||
|
||||
FULL_TEMPLATE = """
|
||||
TITLE: {title}
|
||||
AUTHORS: {authors}
|
||||
PUBLISH DATE: {publish_date}
|
||||
TOP_IMAGE_URL: {top_image}
|
||||
AUTHOR: {author}
|
||||
TEXT:
|
||||
|
||||
{text}
|
||||
|
|
@ -73,8 +63,8 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str:
|
|||
response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
|
||||
elif response.status_code == 403:
|
||||
scraper = cloudscraper.create_scraper()
|
||||
scraper.perform_request = ssrf_proxy.make_request
|
||||
response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
|
||||
scraper.perform_request = ssrf_proxy.make_request # type: ignore
|
||||
response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) # type: ignore
|
||||
|
||||
if response.status_code != 200:
|
||||
return "URL returned status code {}.".format(response.status_code)
|
||||
|
|
@ -90,273 +80,36 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str:
|
|||
else:
|
||||
content = response.text
|
||||
|
||||
a = extract_using_readabilipy(content)
|
||||
article = extract_using_readabilipy(content)
|
||||
|
||||
if not a["plain_text"] or not a["plain_text"].strip():
|
||||
if not article.text:
|
||||
return ""
|
||||
|
||||
res = FULL_TEMPLATE.format(
|
||||
title=a["title"],
|
||||
authors=a["byline"],
|
||||
publish_date=a["date"],
|
||||
top_image="",
|
||||
text=a["plain_text"] or "",
|
||||
title=article.title,
|
||||
author=article.auther,
|
||||
text=article.text,
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def extract_using_readabilipy(html):
|
||||
with tempfile.NamedTemporaryFile(delete=False, mode="w+") as f_html:
|
||||
f_html.write(html)
|
||||
f_html.close()
|
||||
html_path = f_html.name
|
||||
|
||||
# Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file
|
||||
article_json_path = html_path + ".json"
|
||||
jsdir = os.path.join(find_module_path("readabilipy"), "javascript")
|
||||
with chdir(jsdir):
|
||||
subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])
|
||||
|
||||
# Read output of call to Readability.parse() from JSON file and return as Python dictionary
|
||||
input_json = json.loads(Path(article_json_path).read_text(encoding="utf-8"))
|
||||
|
||||
# Deleting files after processing
|
||||
os.unlink(article_json_path)
|
||||
os.unlink(html_path)
|
||||
|
||||
article_json: dict[str, Any] = {
|
||||
"title": None,
|
||||
"byline": None,
|
||||
"date": None,
|
||||
"content": None,
|
||||
"plain_content": None,
|
||||
"plain_text": None,
|
||||
}
|
||||
# Populate article fields from readability fields where present
|
||||
if input_json:
|
||||
if input_json.get("title"):
|
||||
article_json["title"] = input_json["title"]
|
||||
if input_json.get("byline"):
|
||||
article_json["byline"] = input_json["byline"]
|
||||
if input_json.get("date"):
|
||||
article_json["date"] = input_json["date"]
|
||||
if input_json.get("content"):
|
||||
article_json["content"] = input_json["content"]
|
||||
article_json["plain_content"] = plain_content(article_json["content"], False, False)
|
||||
article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
|
||||
if input_json.get("textContent"):
|
||||
article_json["plain_text"] = input_json["textContent"]
|
||||
article_json["plain_text"] = re.sub(r"\n\s*\n", "\n", article_json["plain_text"])
|
||||
|
||||
return article_json
|
||||
@dataclass
|
||||
class Article:
|
||||
title: str
|
||||
auther: str
|
||||
text: Sequence[dict]
|
||||
|
||||
|
||||
def find_module_path(module_name):
|
||||
for package_path in site.getsitepackages():
|
||||
potential_path = os.path.join(package_path, module_name)
|
||||
if os.path.exists(potential_path):
|
||||
return potential_path
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def chdir(path):
|
||||
"""Change directory in context and return to original on exit"""
|
||||
# From https://stackoverflow.com/a/37996581, couldn't find a built-in
|
||||
original_path = os.getcwd()
|
||||
os.chdir(path)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
os.chdir(original_path)
|
||||
|
||||
|
||||
def extract_text_blocks_as_plain_text(paragraph_html):
|
||||
# Load article as DOM
|
||||
soup = BeautifulSoup(paragraph_html, "html.parser")
|
||||
# Select all lists
|
||||
list_elements = soup.find_all(["ul", "ol"])
|
||||
# Prefix text in all list items with "* " and make lists paragraphs
|
||||
for list_element in list_elements:
|
||||
plain_items = "".join(
|
||||
list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all("li")]))
|
||||
)
|
||||
list_element.string = plain_items
|
||||
list_element.name = "p"
|
||||
# Select all text blocks
|
||||
text_blocks = [s.parent for s in soup.find_all(string=True)]
|
||||
text_blocks = [plain_text_leaf_node(block) for block in text_blocks]
|
||||
# Drop empty paragraphs
|
||||
text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks))
|
||||
return text_blocks
|
||||
|
||||
|
||||
def plain_text_leaf_node(element):
|
||||
# Extract all text, stripped of any child HTML elements and normalize it
|
||||
plain_text = normalize_text(element.get_text())
|
||||
if plain_text != "" and element.name == "li":
|
||||
plain_text = "* {}, ".format(plain_text)
|
||||
if plain_text == "":
|
||||
plain_text = None
|
||||
if "data-node-index" in element.attrs:
|
||||
plain = {"node_index": element["data-node-index"], "text": plain_text}
|
||||
else:
|
||||
plain = {"text": plain_text}
|
||||
return plain
|
||||
|
||||
|
||||
def plain_content(readability_content, content_digests, node_indexes):
|
||||
# Load article as DOM
|
||||
soup = BeautifulSoup(readability_content, "html.parser")
|
||||
# Make all elements plain
|
||||
elements = plain_elements(soup.contents, content_digests, node_indexes)
|
||||
if node_indexes:
|
||||
# Add node index attributes to nodes
|
||||
elements = [add_node_indexes(element) for element in elements]
|
||||
# Replace article contents with plain elements
|
||||
soup.contents = elements
|
||||
return str(soup)
|
||||
|
||||
|
||||
def plain_elements(elements, content_digests, node_indexes):
|
||||
# Get plain content versions of all elements
|
||||
elements = [plain_element(element, content_digests, node_indexes) for element in elements]
|
||||
if content_digests:
|
||||
# Add content digest attribute to nodes
|
||||
elements = [add_content_digest(element) for element in elements]
|
||||
return elements
|
||||
|
||||
|
||||
def plain_element(element, content_digests, node_indexes):
|
||||
# For lists, we make each item plain text
|
||||
if is_leaf(element):
|
||||
# For leaf node elements, extract the text content, discarding any HTML tags
|
||||
# 1. Get element contents as text
|
||||
plain_text = element.get_text()
|
||||
# 2. Normalize the extracted text string to a canonical representation
|
||||
plain_text = normalize_text(plain_text)
|
||||
# 3. Update element content to be plain text
|
||||
element.string = plain_text
|
||||
elif is_text(element):
|
||||
if is_non_printing(element):
|
||||
# The simplified HTML may have come from Readability.js so might
|
||||
# have non-printing text (e.g. Comment or CData). In this case, we
|
||||
# keep the structure, but ensure that the string is empty.
|
||||
element = type(element)("")
|
||||
else:
|
||||
plain_text = element.string
|
||||
plain_text = normalize_text(plain_text)
|
||||
element = type(element)(plain_text)
|
||||
else:
|
||||
# If not a leaf node or leaf type call recursively on child nodes, replacing
|
||||
element.contents = plain_elements(element.contents, content_digests, node_indexes)
|
||||
return element
|
||||
|
||||
|
||||
def add_node_indexes(element, node_index="0"):
|
||||
# Can't add attributes to string types
|
||||
if is_text(element):
|
||||
return element
|
||||
# Add index to current element
|
||||
element["data-node-index"] = node_index
|
||||
# Add index to child elements
|
||||
for local_idx, child in enumerate([c for c in element.contents if not is_text(c)], start=1):
|
||||
# Can't add attributes to leaf string types
|
||||
child_index = "{stem}.{local}".format(stem=node_index, local=local_idx)
|
||||
add_node_indexes(child, node_index=child_index)
|
||||
return element
|
||||
|
||||
|
||||
def normalize_text(text):
|
||||
"""Normalize unicode and whitespace."""
|
||||
# Normalize unicode first to try and standardize whitespace characters as much as possible before normalizing them
|
||||
text = strip_control_characters(text)
|
||||
text = normalize_unicode(text)
|
||||
text = normalize_whitespace(text)
|
||||
return text
|
||||
|
||||
|
||||
def strip_control_characters(text):
|
||||
"""Strip out unicode control characters which might break the parsing."""
|
||||
# Unicode control characters
|
||||
# [Cc]: Other, Control [includes new lines]
|
||||
# [Cf]: Other, Format
|
||||
# [Cn]: Other, Not Assigned
|
||||
# [Co]: Other, Private Use
|
||||
# [Cs]: Other, Surrogate
|
||||
control_chars = {"Cc", "Cf", "Cn", "Co", "Cs"}
|
||||
retained_chars = ["\t", "\n", "\r", "\f"]
|
||||
|
||||
# Remove non-printing control characters
|
||||
return "".join(
|
||||
[
|
||||
"" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char
|
||||
for char in text
|
||||
]
|
||||
def extract_using_readabilipy(html: str):
|
||||
json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=True)
|
||||
article = Article(
|
||||
title=json_article.get("title") or "",
|
||||
auther=json_article.get("byline") or "",
|
||||
text=json_article.get("plain_text") or [],
|
||||
)
|
||||
|
||||
|
||||
def normalize_unicode(text):
|
||||
"""Normalize unicode such that things that are visually equivalent map to the same unicode string where possible."""
|
||||
normal_form: Literal["NFC", "NFD", "NFKC", "NFKD"] = "NFKC"
|
||||
text = unicodedata.normalize(normal_form, text)
|
||||
return text
|
||||
|
||||
|
||||
def normalize_whitespace(text):
|
||||
"""Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed."""
|
||||
text = regex.sub(r"\s+", " ", text)
|
||||
# Remove leading and trailing whitespace
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
def is_leaf(element):
|
||||
return element.name in {"p", "li"}
|
||||
|
||||
|
||||
def is_text(element):
|
||||
return isinstance(element, NavigableString)
|
||||
|
||||
|
||||
def is_non_printing(element):
|
||||
return any(isinstance(element, _e) for _e in [Comment, CData])
|
||||
|
||||
|
||||
def add_content_digest(element):
|
||||
if not is_text(element):
|
||||
element["data-content-digest"] = content_digest(element)
|
||||
return element
|
||||
|
||||
|
||||
def content_digest(element):
|
||||
digest: Any
|
||||
if is_text(element):
|
||||
# Hash
|
||||
trimmed_string = element.string.strip()
|
||||
if trimmed_string == "":
|
||||
digest = ""
|
||||
else:
|
||||
digest = hashlib.sha256(trimmed_string.encode("utf-8")).hexdigest()
|
||||
else:
|
||||
contents = element.contents
|
||||
num_contents = len(contents)
|
||||
if num_contents == 0:
|
||||
# No hash when no child elements exist
|
||||
digest = ""
|
||||
elif num_contents == 1:
|
||||
# If single child, use digest of child
|
||||
digest = content_digest(contents[0])
|
||||
else:
|
||||
# Build content digest from the "non-empty" digests of child nodes
|
||||
digest = hashlib.sha256()
|
||||
child_digests = list(filter(lambda x: x != "", [content_digest(content) for content in contents]))
|
||||
for child in child_digests:
|
||||
digest.update(child.encode("utf-8"))
|
||||
digest = digest.hexdigest()
|
||||
return digest
|
||||
return article
|
||||
|
||||
|
||||
def get_image_upload_file_ids(content):
|
||||
|
|
|
|||
|
|
@ -1,36 +1,10 @@
|
|||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class NodeRunMetadataKey(StrEnum):
|
||||
"""
|
||||
Node Run Metadata Key.
|
||||
"""
|
||||
|
||||
TOTAL_TOKENS = "total_tokens"
|
||||
TOTAL_PRICE = "total_price"
|
||||
CURRENCY = "currency"
|
||||
TOOL_INFO = "tool_info"
|
||||
AGENT_LOG = "agent_log"
|
||||
ITERATION_ID = "iteration_id"
|
||||
ITERATION_INDEX = "iteration_index"
|
||||
LOOP_ID = "loop_id"
|
||||
LOOP_INDEX = "loop_index"
|
||||
PARALLEL_ID = "parallel_id"
|
||||
PARALLEL_START_NODE_ID = "parallel_start_node_id"
|
||||
PARENT_PARALLEL_ID = "parent_parallel_id"
|
||||
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
|
||||
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
|
||||
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
|
||||
LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs
|
||||
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
|
||||
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class NodeRunResult(BaseModel):
|
||||
|
|
@ -43,7 +17,7 @@ class NodeRunResult(BaseModel):
|
|||
inputs: Optional[Mapping[str, Any]] = None # node inputs
|
||||
process_data: Optional[Mapping[str, Any]] = None # process data
|
||||
outputs: Optional[Mapping[str, Any]] = None # node outputs
|
||||
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # node metadata
|
||||
metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None # node metadata
|
||||
llm_usage: Optional[LLMUsage] = None # llm usage
|
||||
|
||||
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
|
||||
|
|
|
|||
|
|
@ -36,12 +36,10 @@ class WorkflowExecution(BaseModel):
|
|||
user, tenant, and app attributes.
|
||||
"""
|
||||
|
||||
id: str = Field(...)
|
||||
id_: str = Field(...)
|
||||
workflow_id: str = Field(...)
|
||||
workflow_version: str = Field(...)
|
||||
sequence_number: int = Field(...)
|
||||
|
||||
type: WorkflowType = Field(...)
|
||||
workflow_type: WorkflowType = Field(...)
|
||||
graph: Mapping[str, Any] = Field(...)
|
||||
|
||||
inputs: Mapping[str, Any] = Field(...)
|
||||
|
|
@ -69,20 +67,18 @@ class WorkflowExecution(BaseModel):
|
|||
def new(
|
||||
cls,
|
||||
*,
|
||||
id: str,
|
||||
id_: str,
|
||||
workflow_id: str,
|
||||
sequence_number: int,
|
||||
type: WorkflowType,
|
||||
workflow_type: WorkflowType,
|
||||
workflow_version: str,
|
||||
graph: Mapping[str, Any],
|
||||
inputs: Mapping[str, Any],
|
||||
started_at: datetime,
|
||||
) -> "WorkflowExecution":
|
||||
return WorkflowExecution(
|
||||
id=id,
|
||||
id_=id_,
|
||||
workflow_id=workflow_id,
|
||||
sequence_number=sequence_number,
|
||||
type=type,
|
||||
workflow_type=workflow_type,
|
||||
workflow_version=workflow_version,
|
||||
graph=graph,
|
||||
inputs=inputs,
|
||||
|
|
@ -13,11 +13,35 @@ from typing import Any, Optional
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
|
||||
|
||||
class NodeExecutionStatus(StrEnum):
|
||||
class WorkflowNodeExecutionMetadataKey(StrEnum):
|
||||
"""
|
||||
Node Run Metadata Key.
|
||||
"""
|
||||
|
||||
TOTAL_TOKENS = "total_tokens"
|
||||
TOTAL_PRICE = "total_price"
|
||||
CURRENCY = "currency"
|
||||
TOOL_INFO = "tool_info"
|
||||
AGENT_LOG = "agent_log"
|
||||
ITERATION_ID = "iteration_id"
|
||||
ITERATION_INDEX = "iteration_index"
|
||||
LOOP_ID = "loop_id"
|
||||
LOOP_INDEX = "loop_index"
|
||||
PARALLEL_ID = "parallel_id"
|
||||
PARALLEL_START_NODE_ID = "parallel_start_node_id"
|
||||
PARENT_PARALLEL_ID = "parent_parallel_id"
|
||||
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
|
||||
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
|
||||
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
|
||||
LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs
|
||||
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
|
||||
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
|
||||
|
||||
|
||||
class WorkflowNodeExecutionStatus(StrEnum):
|
||||
"""
|
||||
Node Execution Status Enum.
|
||||
"""
|
||||
|
|
@ -29,7 +53,7 @@ class NodeExecutionStatus(StrEnum):
|
|||
RETRY = "retry"
|
||||
|
||||
|
||||
class NodeExecution(BaseModel):
|
||||
class WorkflowNodeExecution(BaseModel):
|
||||
"""
|
||||
Domain model for workflow node execution.
|
||||
|
||||
|
|
@ -46,7 +70,7 @@ class NodeExecution(BaseModel):
|
|||
id: str # Unique identifier for this execution record
|
||||
node_execution_id: Optional[str] = None # Optional secondary ID for cross-referencing
|
||||
workflow_id: str # ID of the workflow this node belongs to
|
||||
workflow_run_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging)
|
||||
workflow_execution_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging)
|
||||
|
||||
# Execution positioning and flow
|
||||
index: int # Sequence number for ordering in trace visualization
|
||||
|
|
@ -61,12 +85,12 @@ class NodeExecution(BaseModel):
|
|||
outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node
|
||||
|
||||
# Execution state
|
||||
status: NodeExecutionStatus = NodeExecutionStatus.RUNNING # Current execution status
|
||||
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING # Current execution status
|
||||
error: Optional[str] = None # Error message if execution failed
|
||||
elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds
|
||||
|
||||
# Additional metadata
|
||||
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.)
|
||||
metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.)
|
||||
|
||||
# Timing information
|
||||
created_at: datetime # When execution started
|
||||
|
|
@ -77,7 +101,7 @@ class NodeExecution(BaseModel):
|
|||
inputs: Optional[Mapping[str, Any]] = None,
|
||||
process_data: Optional[Mapping[str, Any]] = None,
|
||||
outputs: Optional[Mapping[str, Any]] = None,
|
||||
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None,
|
||||
metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Update the model from mappings.
|
||||
|
|
@ -13,4 +13,4 @@ class SystemVariableKey(StrEnum):
|
|||
DIALOGUE_COUNT = "dialogue_count"
|
||||
APP_ID = "app_id"
|
||||
WORKFLOW_ID = "workflow_id"
|
||||
WORKFLOW_RUN_ID = "workflow_run_id"
|
||||
WORKFLOW_EXECUTION_ID = "workflow_run_id"
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.nodes import NodeType
|
||||
|
|
@ -82,7 +83,7 @@ class NodeRunStreamChunkEvent(BaseNodeEvent):
|
|||
|
||||
|
||||
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
|
||||
retriever_resources: list[dict] = Field(..., description="retriever resources")
|
||||
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
|
||||
context: str = Field(..., description="context")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from typing import Optional
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class RouteNodeState(BaseModel):
|
||||
|
|
|
|||
|
|
@ -9,13 +9,14 @@ from copy import copy, deepcopy
|
|||
from datetime import UTC, datetime
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
from flask import Flask, current_app, has_request_context
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
BaseAgentEvent,
|
||||
|
|
@ -52,9 +53,8 @@ from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
|
|||
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -540,8 +540,21 @@ class GraphEngine:
|
|||
for var, val in context.items():
|
||||
var.set(val)
|
||||
|
||||
# FIXME(-LAN-): Save current user before entering new app context
|
||||
from flask import g
|
||||
|
||||
saved_user = None
|
||||
if has_request_context() and hasattr(g, "_login_user"):
|
||||
saved_user = g._login_user
|
||||
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
# Restore user in new app context
|
||||
if saved_user is not None:
|
||||
from flask import g
|
||||
|
||||
g._login_user = saved_user
|
||||
|
||||
q.put(
|
||||
ParallelBranchRunStartedEvent(
|
||||
parallel_id=parallel_id,
|
||||
|
|
@ -593,8 +606,6 @@ class GraphEngine:
|
|||
error=str(e),
|
||||
)
|
||||
)
|
||||
finally:
|
||||
db.session.remove()
|
||||
|
||||
def _run_node(
|
||||
self,
|
||||
|
|
@ -632,7 +643,6 @@ class GraphEngine:
|
|||
agent_strategy=agent_strategy,
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
max_retries = node_instance.node_data.retry_config.max_retries
|
||||
retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
|
||||
retries = 0
|
||||
|
|
@ -746,10 +756,12 @@ class GraphEngine:
|
|||
and node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH
|
||||
):
|
||||
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
|
||||
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
||||
if run_result.metadata and run_result.metadata.get(
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS
|
||||
):
|
||||
# plus state total_tokens
|
||||
self.graph_runtime_state.total_tokens += int(
|
||||
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
|
||||
run_result.metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
if run_result.llm_usage:
|
||||
|
|
@ -772,13 +784,17 @@ class GraphEngine:
|
|||
|
||||
if parallel_id and parallel_start_node_id:
|
||||
metadata_dict = dict(run_result.metadata)
|
||||
metadata_dict[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
|
||||
metadata_dict[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
|
||||
metadata_dict[WorkflowNodeExecutionMetadataKey.PARALLEL_ID] = parallel_id
|
||||
metadata_dict[WorkflowNodeExecutionMetadataKey.PARALLEL_START_NODE_ID] = (
|
||||
parallel_start_node_id
|
||||
)
|
||||
if parent_parallel_id and parent_parallel_start_node_id:
|
||||
metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
|
||||
metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
|
||||
parent_parallel_start_node_id
|
||||
metadata_dict[WorkflowNodeExecutionMetadataKey.PARENT_PARALLEL_ID] = (
|
||||
parent_parallel_id
|
||||
)
|
||||
metadata_dict[
|
||||
WorkflowNodeExecutionMetadataKey.PARENT_PARALLEL_START_NODE_ID
|
||||
] = parent_parallel_start_node_id
|
||||
run_result.metadata = metadata_dict
|
||||
|
||||
yield NodeRunSucceededEvent(
|
||||
|
|
@ -843,8 +859,6 @@ class GraphEngine:
|
|||
except Exception as e:
|
||||
logger.exception(f"Node {node_instance.node_data.title} run failed")
|
||||
raise e
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
|
||||
"""
|
||||
|
|
@ -910,7 +924,7 @@ class GraphEngine:
|
|||
"error": error_result.error,
|
||||
"inputs": error_result.inputs,
|
||||
"metadata": {
|
||||
NodeRunMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy,
|
||||
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy,
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,9 @@ import json
|
|||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
|
|
@ -15,6 +18,7 @@ from core.tools.tool_manager import ToolManager
|
|||
from core.variables.segments import StringSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
|
|
@ -25,7 +29,6 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
|||
from extensions.ext_database import db
|
||||
from factories.agent_factory import get_plugin_agent_strategy
|
||||
from models.model import Conversation
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class AgentNode(ToolNode):
|
||||
|
|
@ -320,15 +323,12 @@ class AgentNode(ToolNode):
|
|||
return None
|
||||
conversation_id = conversation_id_variable.value
|
||||
|
||||
# get conversation
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
|
||||
conversation = session.scalar(stmt)
|
||||
|
||||
if not conversation:
|
||||
return None
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
|
|
@ -356,7 +356,9 @@ class AgentNode(ToolNode):
|
|||
|
||||
def _remove_unsupported_model_features_for_old_version(self, model_schema: AIModelEntity) -> AIModelEntity:
|
||||
if model_schema.features:
|
||||
for feature in model_schema.features:
|
||||
if feature.value not in AgentOldVersionModelFeatures:
|
||||
for feature in model_schema.features[:]: # Create a copy to safely modify during iteration
|
||||
try:
|
||||
AgentOldVersionModelFeatures(feature.value) # Try to create enum member from value
|
||||
except ValueError:
|
||||
model_schema.features.remove(feature)
|
||||
return model_schema
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from enum import Enum
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -26,7 +26,7 @@ class ParamsAutoGenerated(Enum):
|
|||
OPEN = 1
|
||||
|
||||
|
||||
class AgentOldVersionModelFeatures(Enum):
|
||||
class AgentOldVersionModelFeatures(StrEnum):
|
||||
"""
|
||||
Enum class for old SDK version llm feature.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from typing import Any, cast
|
|||
|
||||
from core.variables import ArrayFileSegment, FileSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
|
||||
from core.workflow.nodes.answer.entities import (
|
||||
AnswerNodeData,
|
||||
|
|
@ -13,7 +14,6 @@ from core.workflow.nodes.answer.entities import (
|
|||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class AnswerNode(BaseNode[AnswerNodeData]):
|
||||
|
|
|
|||
|
|
@ -4,9 +4,9 @@ from collections.abc import Generator, Mapping, Sequence
|
|||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import BaseNodeData
|
||||
|
||||
|
|
|
|||
|
|
@ -8,10 +8,10 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
|
|||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.variables.segments import ArrayFileSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.code.entities import CodeNodeData
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .exc import (
|
||||
CodeNodeError,
|
||||
|
|
@ -167,8 +167,11 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
value=value,
|
||||
variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]",
|
||||
)
|
||||
elif isinstance(first_element, dict) and all(
|
||||
value is None or isinstance(value, dict) for value in output_value
|
||||
elif (
|
||||
isinstance(first_element, dict)
|
||||
and all(value is None or isinstance(value, dict) for value in output_value)
|
||||
or isinstance(first_element, list)
|
||||
and all(value is None or isinstance(value, list) for value in output_value)
|
||||
):
|
||||
for i, value in enumerate(output_value):
|
||||
if value is not None:
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import tempfile
|
|||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
import chardet
|
||||
import docx
|
||||
import pandas as pd
|
||||
import pypandoc # type: ignore
|
||||
|
|
@ -25,9 +26,9 @@ from core.helper import ssrf_proxy
|
|||
from core.variables import ArrayFileSegment
|
||||
from core.variables.segments import FileSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import DocumentExtractorNodeData
|
||||
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
|
||||
|
|
@ -180,26 +181,64 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str)
|
|||
|
||||
def _extract_text_from_plain_text(file_content: bytes) -> str:
|
||||
try:
|
||||
return file_content.decode("utf-8", "ignore")
|
||||
except UnicodeDecodeError as e:
|
||||
raise TextExtractionError("Failed to decode plain text file") from e
|
||||
# Detect encoding using chardet
|
||||
result = chardet.detect(file_content)
|
||||
encoding = result["encoding"]
|
||||
|
||||
# Fallback to utf-8 if detection fails
|
||||
if not encoding:
|
||||
encoding = "utf-8"
|
||||
|
||||
return file_content.decode(encoding, errors="ignore")
|
||||
except (UnicodeDecodeError, LookupError) as e:
|
||||
# If decoding fails, try with utf-8 as last resort
|
||||
try:
|
||||
return file_content.decode("utf-8", errors="ignore")
|
||||
except UnicodeDecodeError:
|
||||
raise TextExtractionError(f"Failed to decode plain text file: {e}") from e
|
||||
|
||||
|
||||
def _extract_text_from_json(file_content: bytes) -> str:
|
||||
try:
|
||||
json_data = json.loads(file_content.decode("utf-8", "ignore"))
|
||||
# Detect encoding using chardet
|
||||
result = chardet.detect(file_content)
|
||||
encoding = result["encoding"]
|
||||
|
||||
# Fallback to utf-8 if detection fails
|
||||
if not encoding:
|
||||
encoding = "utf-8"
|
||||
|
||||
json_data = json.loads(file_content.decode(encoding, errors="ignore"))
|
||||
return json.dumps(json_data, indent=2, ensure_ascii=False)
|
||||
except (UnicodeDecodeError, json.JSONDecodeError) as e:
|
||||
raise TextExtractionError(f"Failed to decode or parse JSON file: {e}") from e
|
||||
except (UnicodeDecodeError, LookupError, json.JSONDecodeError) as e:
|
||||
# If decoding fails, try with utf-8 as last resort
|
||||
try:
|
||||
json_data = json.loads(file_content.decode("utf-8", errors="ignore"))
|
||||
return json.dumps(json_data, indent=2, ensure_ascii=False)
|
||||
except (UnicodeDecodeError, json.JSONDecodeError):
|
||||
raise TextExtractionError(f"Failed to decode or parse JSON file: {e}") from e
|
||||
|
||||
|
||||
def _extract_text_from_yaml(file_content: bytes) -> str:
|
||||
"""Extract the content from yaml file"""
|
||||
try:
|
||||
yaml_data = yaml.safe_load_all(file_content.decode("utf-8", "ignore"))
|
||||
# Detect encoding using chardet
|
||||
result = chardet.detect(file_content)
|
||||
encoding = result["encoding"]
|
||||
|
||||
# Fallback to utf-8 if detection fails
|
||||
if not encoding:
|
||||
encoding = "utf-8"
|
||||
|
||||
yaml_data = yaml.safe_load_all(file_content.decode(encoding, errors="ignore"))
|
||||
return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
|
||||
except (UnicodeDecodeError, yaml.YAMLError) as e:
|
||||
raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e
|
||||
except (UnicodeDecodeError, LookupError, yaml.YAMLError) as e:
|
||||
# If decoding fails, try with utf-8 as last resort
|
||||
try:
|
||||
yaml_data = yaml.safe_load_all(file_content.decode("utf-8", errors="ignore"))
|
||||
return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
|
||||
except (UnicodeDecodeError, yaml.YAMLError):
|
||||
raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e
|
||||
|
||||
|
||||
def _extract_text_from_pdf(file_content: bytes) -> str:
|
||||
|
|
@ -338,7 +377,20 @@ def _extract_text_from_file(file: File):
|
|||
|
||||
def _extract_text_from_csv(file_content: bytes) -> str:
|
||||
try:
|
||||
csv_file = io.StringIO(file_content.decode("utf-8", "ignore"))
|
||||
# Detect encoding using chardet
|
||||
result = chardet.detect(file_content)
|
||||
encoding = result["encoding"]
|
||||
|
||||
# Fallback to utf-8 if detection fails
|
||||
if not encoding:
|
||||
encoding = "utf-8"
|
||||
|
||||
try:
|
||||
csv_file = io.StringIO(file_content.decode(encoding, errors="ignore"))
|
||||
except (UnicodeDecodeError, LookupError):
|
||||
# If decoding fails, try with utf-8 as last resort
|
||||
csv_file = io.StringIO(file_content.decode("utf-8", errors="ignore"))
|
||||
|
||||
csv_reader = csv.reader(csv_file)
|
||||
rows = list(csv_reader)
|
||||
|
||||
|
|
@ -366,7 +418,7 @@ def _extract_text_from_excel(file_content: bytes) -> str:
|
|||
df = excel_file.parse(sheet_name=sheet_name)
|
||||
df.dropna(how="all", inplace=True)
|
||||
# Create Markdown table two times to separate tables with a newline
|
||||
markdown_table += df.to_markdown(index=False) + "\n\n"
|
||||
markdown_table += df.to_markdown(index=False, floatfmt="") + "\n\n"
|
||||
except Exception as e:
|
||||
continue
|
||||
return markdown_table
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class EndNode(BaseNode[EndNodeData]):
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue