Merge branch 'main' into feat/memory-orchestration-be

# Conflicts:
#	api/core/app/apps/advanced_chat/app_runner.py
#	api/models/workflow.py
commit e53147266c
Author: Stream
Date: 2025-11-05 17:40:48 +08:00

859 changed files with 32107 additions and 10379 deletions


@ -11,7 +11,7 @@
"nodeGypDependencies": true,
"version": "lts"
},
"ghcr.io/devcontainers-contrib/features/npm-package:1": {
"ghcr.io/devcontainers-extra/features/npm-package:1": {
"package": "typescript",
"version": "latest"
},


@ -6,7 +6,7 @@ cd web && pnpm install
pipx install uv
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage\"" >> ~/.bashrc
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc


@ -103,6 +103,11 @@ jobs:
run: |
pnpm run lint
- name: Web type check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run type-check
docker-compose-template:
name: Docker Compose Template
runs-on: ubuntu-latest

.gitignore

@ -97,6 +97,7 @@ __pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat-schedule.db
celerybeat.pid
# SageMath parsed files


@ -8,8 +8,7 @@
"module": "flask",
"env": {
"FLASK_APP": "app.py",
"FLASK_ENV": "development",
"GEVENT_SUPPORT": "True"
"FLASK_ENV": "development"
},
"args": [
"run",
@ -28,9 +27,7 @@
"type": "debugpy",
"request": "launch",
"module": "celery",
"env": {
"GEVENT_SUPPORT": "True"
},
"env": {},
"args": [
"-A",
"app.celery",
@ -40,7 +37,7 @@
"-c",
"1",
"-Q",
"dataset,generation,mail,ops_trace",
"dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline",
"--loglevel",
"INFO"
],


@ -63,7 +63,7 @@ Dify is an open-source platform for developing LLM applications. Its intuitive i
> - CPU >= 2 Core
> - RAM >= 4 GiB
</br>
<br/>
The easiest way to start the Dify server is through [Docker Compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine:
@ -109,15 +109,15 @@ All of Dify's offerings come with corresponding APIs, so you could effortlessly
## Using Dify
- **Cloud </br>**
- **Cloud <br/>**
We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan.
- **Self-hosting Dify Community Edition</br>**
- **Self-hosting Dify Community Edition<br/>**
Quickly get Dify running in your environment with this [starter guide](#quick-start).
Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions.
- **Dify for enterprise / organizations</br>**
We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss enterprise needs. </br>
- **Dify for enterprise / organizations<br/>**
We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss enterprise needs. <br/>
> For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one click. It's an affordable AMI offering with the option to create apps with custom logo and branding.


@ -156,6 +156,9 @@ SUPABASE_URL=your-server-url
# CORS configuration
WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
# Set COOKIE_DOMAIN when the console frontend and API are on different subdomains.
# Provide the registrable domain (e.g. example.com); leading dots are optional.
COOKIE_DOMAIN=
# Vector database configuration
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`, `vastbase`, `tidb`, `tidb_on_qdrant`, `baidu`, `lindorm`, `huawei_cloud`, `upstash`, `matrixone`.
@ -368,6 +371,12 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
# Comma-separated list of file extensions blocked from upload for security reasons.
# Extensions should be lowercase without dots (e.g., exe,bat,sh,dll).
# Empty by default to allow all file types.
# Recommended: exe,bat,cmd,com,scr,vbs,ps1,msi,dll
UPLOAD_FILE_EXTENSION_BLACKLIST=
# Model configuration
MULTIMODAL_SEND_FORMAT=base64
PROMPT_GENERATION_MAX_TOKENS=512
@ -605,3 +614,6 @@ SWAGGER_UI_PATH=/swagger-ui.html
# Whether to encrypt dataset IDs when exporting DSL files (default: true)
# Set to false to export dataset IDs as plain text for easier cross-environment import
DSL_EXPORT_ENCRYPT_DATASET_ID=true
# Maximum number of segments for dataset segments API (0 for unlimited)
DATASET_MAX_SEGMENTS_PER_REQUEST=0


@ -54,7 +54,7 @@
"--loglevel",
"DEBUG",
"-Q",
"dataset,generation,mail,ops_trace,app_deletion"
"dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline"
]
}
]


@ -15,7 +15,11 @@ FROM base AS packages
# RUN sed -i 's@deb.debian.org@mirrors.aliyun.com@g' /etc/apt/sources.list.d/debian.sources
RUN apt-get update \
&& apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev
&& apt-get install -y --no-install-recommends \
# basic environment
g++ \
# for building gmpy2
libmpfr-dev libmpc-dev
# Install Python dependencies
COPY pyproject.toml uv.lock ./
@ -49,7 +53,9 @@ RUN \
# Install dependencies
&& apt-get install -y --no-install-recommends \
# basic environment
curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
curl nodejs \
# for gmpy2 \
libgmp-dev libmpfr-dev libmpc-dev \
# For Security
expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
# install fonts to support the use of tools like pypdfium2


@ -80,7 +80,7 @@
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
```bash
uv run celery -A app.celery worker -P gevent -c 2 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline
```
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:


@ -13,23 +13,12 @@ if is_db_command():
app = create_migrations_app()
else:
# It seems that JetBrains Python debugger does not work well with gevent,
# so we need to disable gevent in debug mode.
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
# if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
# from gevent import monkey
# Gunicorn and Celery handle monkey patching automatically in production by
# specifying the `gevent` worker class. Manual monkey patching is not required here.
#
# # gevent
# monkey.patch_all()
# See `api/docker/entrypoint.sh` (lines 33 and 47) for details.
#
# from grpc.experimental import gevent as grpc_gevent # type: ignore
#
# # grpc gevent
# grpc_gevent.init_gevent()
# import psycogreen.gevent # type: ignore
#
# psycogreen.gevent.patch_psycopg()
# For third-party library patching, refer to `gunicorn.conf.py` and `celery_entrypoint.py`.
from app_factory import create_app


@ -321,6 +321,8 @@ def migrate_knowledge_vector_database():
)
datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
if not datasets.items:
break
except SQLAlchemyError:
raise


@ -331,12 +331,42 @@ class FileUploadConfig(BaseSettings):
default=10,
)
inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
description=(
"Comma-separated list of file extensions that are blocked from upload. "
"Extensions should be lowercase without dots (e.g., 'exe,bat,sh,dll'). "
"Empty by default to allow all file types."
),
validation_alias=AliasChoices("UPLOAD_FILE_EXTENSION_BLACKLIST"),
default="",
)
@computed_field # type: ignore[misc]
@property
def UPLOAD_FILE_EXTENSION_BLACKLIST(self) -> set[str]:
"""
Parse and return the blacklist as a set of lowercase extensions.
Returns an empty set if no blacklist is configured.
"""
if not self.inner_UPLOAD_FILE_EXTENSION_BLACKLIST:
return set()
return {
ext.strip().lower().strip(".")
for ext in self.inner_UPLOAD_FILE_EXTENSION_BLACKLIST.split(",")
if ext.strip()
}
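
For illustration, a minimal sketch of how an upload path might consume the computed blacklist set; the helper name and call site are assumptions, not the actual service code:

```python
# Hypothetical guard built on dify_config.UPLOAD_FILE_EXTENSION_BLACKLIST,
# which the computed field above normalizes to a lowercase set.
def is_extension_blocked(filename: str, blacklist: set[str]) -> bool:
    ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
    return bool(ext) and ext in blacklist

# is_extension_blocked("setup.exe", {"exe", "bat"})  -> True
# is_extension_blocked("notes.txt", {"exe", "bat"})  -> False
```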
class HttpConfig(BaseSettings):
"""
HTTP-related configurations for the application
"""
COOKIE_DOMAIN: str = Field(
description="Explicit cookie domain for console/service cookies when sharing across subdomains",
default="",
)
API_COMPRESSION_ENABLED: bool = Field(
description="Enable or disable gzip compression for HTTP responses",
default=False,
@ -915,6 +945,11 @@ class DataSetConfig(BaseSettings):
default=True,
)
DATASET_MAX_SEGMENTS_PER_REQUEST: NonNegativeInt = Field(
description="Maximum number of segments for dataset segments API (0 for unlimited)",
default=0,
)
class WorkspaceConfig(BaseSettings):
"""


@ -22,6 +22,11 @@ class WeaviateConfig(BaseSettings):
default=True,
)
WEAVIATE_GRPC_ENDPOINT: str | None = Field(
description="URL of the Weaviate gRPC server (e.g., 'grpc://localhost:50051' or 'grpcs://weaviate.example.com:443')",
default=None,
)
WEAVIATE_BATCH_SIZE: PositiveInt = Field(
description="Number of objects to be processed in a single batch operation (default is 100)",
default=100,


@ -56,11 +56,15 @@ else:
}
DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)
# console
COOKIE_NAME_ACCESS_TOKEN = "access_token"
COOKIE_NAME_REFRESH_TOKEN = "refresh_token"
COOKIE_NAME_PASSPORT = "passport"
COOKIE_NAME_CSRF_TOKEN = "csrf_token"
# webapp
COOKIE_NAME_WEBAPP_ACCESS_TOKEN = "webapp_access_token"
COOKIE_NAME_PASSPORT = "passport"
HEADER_NAME_CSRF_TOKEN = "X-CSRF-Token"
HEADER_NAME_APP_CODE = "X-App-Code"
HEADER_NAME_PASSPORT = "X-App-Passport"


@ -31,3 +31,9 @@ def supported_language(lang):
error = f"{lang} is not a valid language."
raise ValueError(error)
def get_valid_language(lang: str | None) -> str:
if lang and lang in languages:
return lang
return languages[0]
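
A quick usage note, assuming `languages` is an ordered list of supported locale codes with the default first:

```python
# Self-contained illustration with hypothetical values; the real list
# lives in constants/languages.py.
languages = ["en-US", "zh-Hans"]

def get_valid_language(lang):
    if lang and lang in languages:
        return lang
    return languages[0]

assert get_valid_language("zh-Hans") == "zh-Hans"
assert get_valid_language("xx-XX") == "en-US"   # unknown -> default
assert get_valid_language(None) == "en-US"      # missing -> default
```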


@ -25,6 +25,12 @@ class UnsupportedFileTypeError(BaseHTTPException):
code = 415
class BlockedFileExtensionError(BaseHTTPException):
error_code = "file_extension_blocked"
description = "The file extension is blocked for security reasons."
code = 400
class TooManyFilesError(BaseHTTPException):
error_code = "too_many_files"
description = "Only one file is allowed."


@ -16,6 +16,7 @@ from fields.annotation_fields import (
annotation_fields,
annotation_hit_history_fields,
)
from libs.helper import uuid_value
from libs.login import login_required
from services.annotation_service import AppAnnotationService
@ -175,8 +176,10 @@ class AnnotationApi(Resource):
api.model(
"CreateAnnotationRequest",
{
"question": fields.String(required=True, description="Question text"),
"answer": fields.String(required=True, description="Answer text"),
"message_id": fields.String(description="Message ID (optional)"),
"question": fields.String(description="Question text (required when message_id not provided)"),
"answer": fields.String(description="Answer text (use 'answer' or 'content')"),
"content": fields.String(description="Content text (use 'answer' or 'content')"),
"annotation_reply": fields.Raw(description="Annotation reply data"),
},
)
@ -193,11 +196,14 @@ class AnnotationApi(Resource):
app_id = str(app_id)
parser = (
reqparse.RequestParser()
.add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
.add_argument("message_id", required=False, type=uuid_value, location="json")
.add_argument("question", required=False, type=str, location="json")
.add_argument("answer", required=False, type=str, location="json")
.add_argument("content", required=False, type=str, location="json")
.add_argument("annotation_reply", required=False, type=dict, location="json")
)
args = parser.parse_args()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
return annotation
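
A hedged illustration of the two request shapes the relaxed parser now accepts; the values are made up, and the actual validation lives in `AppAnnotationService.up_insert_app_annotation_from_message`:

```python
# Upsert from an existing message (the question is derived from the message):
create_from_message = {
    "message_id": "b9a5b1f0-1234-5678-9abc-def012345678",  # hypothetical UUID
    "answer": "Corrected answer text",
}

# Direct creation, matching the old required-fields behavior:
create_directly = {
    "question": "What is Dify?",
    "answer": "An open-source platform for developing LLM applications.",
}
```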
@setup_required


@ -1,7 +1,5 @@
from datetime import datetime
import pytz
import sqlalchemy as sa
from flask import abort
from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range
from sqlalchemy import func, or_
@ -19,7 +17,7 @@ from fields.conversation_fields import (
conversation_pagination_fields,
conversation_with_summary_pagination_fields,
)
from libs.datetime_utils import naive_utc_now
from libs.datetime_utils import naive_utc_now, parse_time_range
from libs.helper import DatetimeString
from libs.login import current_account_with_tenant, login_required
from models import Conversation, EndUser, Message, MessageAnnotation
@ -90,25 +88,17 @@ class CompletionConversationApi(Resource):
account = current_user
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
query = query.where(Conversation.created_at >= start_datetime_utc)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=59)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
end_datetime_utc = end_datetime_utc.replace(second=59)
query = query.where(Conversation.created_at < end_datetime_utc)
# FIXME, the type ignore in this file
@ -270,29 +260,21 @@ class ChatConversationApi(Resource):
account = current_user
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
match args["sort_by"]:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at >= start_datetime_utc)
case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at >= start_datetime_utc)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=59)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
end_datetime_utc = end_datetime_utc.replace(second=59)
match args["sort_by"]:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at <= end_datetime_utc)
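
The hunks above (and the statistics controllers below) all delegate to `libs.datetime_utils.parse_time_range`. A minimal sketch of what such a helper plausibly does, inferred from the call sites; the exact signature and rounding rules are assumptions:

```python
from datetime import datetime

import pytz


def parse_time_range(start: str | None, end: str | None, tz_name: str):
    """Parse 'YYYY-MM-DD HH:MM' strings in the account timezone into UTC.

    Returns (start_utc, end_utc); either element is None when its input is
    empty. Raises ValueError for malformed input or start > end.
    """
    tz = pytz.timezone(tz_name)

    def to_utc(value: str | None):
        if not value:
            return None
        local = tz.localize(datetime.strptime(value, "%Y-%m-%d %H:%M"))
        return local.astimezone(pytz.utc)

    start_utc, end_utc = to_utc(start), to_utc(end)
    if start_utc and end_utc and start_utc > end_utc:
        raise ValueError("start time must not be later than end time")
    return start_utc, end_utc
```

Note that callers still adjust seconds themselves (e.g. `replace(second=59)` for an inclusive conversation end bound), consistent with the sketch returning minute-precision values.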


@ -16,7 +16,6 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
edit_permission_required,
setup_required,
)
@ -24,12 +23,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from fields.conversation_fields import annotation_fields, message_detail_fields
from fields.conversation_fields import message_detail_fields
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import current_account_with_tenant, login_required
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.annotation_service import AppAnnotationService
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService
@ -194,45 +192,6 @@ class MessageFeedbackApi(Resource):
return {"result": "success"}
@console_ns.route("/apps/<uuid:app_id>/annotations")
class MessageAnnotationApi(Resource):
@api.doc("create_message_annotation")
@api.doc(description="Create message annotation")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"MessageAnnotationRequest",
{
"message_id": fields.String(description="Message ID"),
"question": fields.String(required=True, description="Question text"),
"answer": fields.String(required=True, description="Answer text"),
"annotation_reply": fields.Raw(description="Annotation reply"),
},
)
)
@api.response(200, "Annotation created successfully", annotation_fields)
@api.response(403, "Insufficient permissions")
@marshal_with(annotation_fields)
@get_app_model
@setup_required
@login_required
@cloud_edition_billing_resource_check("annotation")
@account_initialization_required
@edit_permission_required
def post(self, app_model):
parser = (
reqparse.RequestParser()
.add_argument("message_id", required=False, type=uuid_value, location="json")
.add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
.add_argument("annotation_reply", required=False, type=dict, location="json")
)
args = parser.parse_args()
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)
return annotation
@console_ns.route("/apps/<uuid:app_id>/annotations/count")
class MessageAnnotationCountApi(Resource):
@api.doc("get_annotation_count")


@ -1,9 +1,7 @@
from datetime import datetime
from decimal import Decimal
import pytz
import sqlalchemy as sa
from flask import jsonify
from flask import abort, jsonify
from flask_restx import Resource, fields, reqparse
from controllers.console import api, console_ns
@ -11,6 +9,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
from libs.helper import DatetimeString
from libs.login import current_account_with_tenant, login_required
from models import AppMode, Message
@ -56,26 +55,16 @@ WHERE
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
@ -120,8 +109,11 @@ class DailyConversationStatistic(Resource):
)
args = parser.parse_args()
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
stmt = (
sa.select(
@ -134,18 +126,10 @@ class DailyConversationStatistic(Resource):
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER)
)
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
if start_datetime_utc:
stmt = stmt.where(Message.created_at >= start_datetime_utc)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
stmt = stmt.where(Message.created_at < end_datetime_utc)
stmt = stmt.group_by("date").order_by("date")
@ -198,26 +182,17 @@ WHERE
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
@ -273,26 +248,17 @@ WHERE
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
@ -357,26 +323,17 @@ FROM
AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND c.created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND c.created_at < :end"
arg_dict["end"] = end_datetime_utc
@ -446,26 +403,17 @@ WHERE
AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND m.created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND m.created_at < :end"
arg_dict["end"] = end_datetime_utc
@ -525,26 +473,17 @@ WHERE
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
@ -602,26 +541,17 @@ WHERE
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc


@ -104,7 +104,18 @@ class DraftWorkflowApi(Resource):
},
)
)
@api.response(200, "Draft workflow synced successfully", workflow_fields)
@api.response(
200,
"Draft workflow synced successfully",
api.model(
"SyncDraftWorkflowResponse",
{
"result": fields.String,
"hash": fields.String,
"updated_at": fields.String,
},
),
)
@api.response(400, "Invalid workflow configuration")
@api.response(403, "Permission denied")
@edit_permission_required


@ -1,23 +1,26 @@
from datetime import datetime
from decimal import Decimal
import pytz
import sqlalchemy as sa
from flask import jsonify
from flask import abort, jsonify
from flask_restx import Resource, reqparse
from sqlalchemy.orm import sessionmaker
from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
from libs.helper import DatetimeString
from libs.login import current_account_with_tenant, login_required
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from repositories.factory import DifyAPIRepositoryFactory
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
class WorkflowDailyRunsStatistic(Resource):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@api.doc("get_workflow_daily_runs_statistic")
@api.doc(description="Get workflow daily runs statistics")
@api.doc(params={"app_id": "Application ID"})
@ -37,57 +40,32 @@ class WorkflowDailyRunsStatistic(Resource):
)
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(id) AS runs
FROM
workflow_runs
WHERE
app_id = :app_id
AND triggered_from = :triggered_from"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
try:
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "runs": i.runs})
response_data = self._workflow_run_repo.get_daily_runs_statistics(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
start_date=start_date,
end_date=end_date,
timezone=account.timezone,
)
return jsonify({"data": response_data})
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-terminals")
class WorkflowDailyTerminalsStatistic(Resource):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@api.doc("get_workflow_daily_terminals_statistic")
@api.doc(description="Get workflow daily terminals statistics")
@api.doc(params={"app_id": "Application ID"})
@ -107,57 +85,32 @@ class WorkflowDailyTerminalsStatistic(Resource):
)
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(DISTINCT workflow_runs.created_by) AS terminal_count
FROM
workflow_runs
WHERE
app_id = :app_id
AND triggered_from = :triggered_from"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
try:
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
response_data = self._workflow_run_repo.get_daily_terminals_statistics(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
start_date=start_date,
end_date=end_date,
timezone=account.timezone,
)
return jsonify({"data": response_data})
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/token-costs")
class WorkflowDailyTokenCostStatistic(Resource):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@api.doc("get_workflow_daily_token_cost_statistic")
@api.doc(description="Get workflow daily token cost statistics")
@api.doc(params={"app_id": "Application ID"})
@ -177,62 +130,32 @@ class WorkflowDailyTokenCostStatistic(Resource):
)
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
SUM(workflow_runs.total_tokens) AS token_count
FROM
workflow_runs
WHERE
app_id = :app_id
AND triggered_from = :triggered_from"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
try:
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append(
{
"date": str(i.date),
"token_count": i.token_count,
}
)
response_data = self._workflow_run_repo.get_daily_token_cost_statistics(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
start_date=start_date,
end_date=end_date,
timezone=account.timezone,
)
return jsonify({"data": response_data})
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/average-app-interactions")
class WorkflowAverageAppInteractionStatistic(Resource):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@api.doc("get_workflow_average_app_interaction_statistic")
@api.doc(description="Get workflow average app interaction statistics")
@api.doc(params={"app_id": "Application ID"})
@ -252,67 +175,20 @@ class WorkflowAverageAppInteractionStatistic(Resource):
)
args = parser.parse_args()
sql_query = """SELECT
AVG(sub.interactions) AS interactions,
sub.date
FROM
(
SELECT
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
c.created_by,
COUNT(c.id) AS interactions
FROM
workflow_runs c
WHERE
c.app_id = :app_id
AND c.triggered_from = :triggered_from
{{start}}
{{end}}
GROUP BY
date, c.created_by
) sub
GROUP BY
sub.date"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
try:
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query = sql_query.replace("{{start}}", " AND c.created_at >= :start")
arg_dict["start"] = start_datetime_utc
else:
sql_query = sql_query.replace("{{start}}", "")
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query = sql_query.replace("{{end}}", " AND c.created_at < :end")
arg_dict["end"] = end_datetime_utc
else:
sql_query = sql_query.replace("{{end}}", "")
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append(
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
)
response_data = self._workflow_run_repo.get_average_app_interaction_statistics(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
start_date=start_date,
end_date=end_date,
timezone=account.timezone,
)
return jsonify({"data": response_data})
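
The four endpoints above now delegate their SQL to a repository. A hypothetical `Protocol` capturing the interface they call; the method names come from the diff, while parameter and return types are assumptions:

```python
from datetime import datetime
from typing import Any, Protocol


class APIWorkflowRunRepositorySketch(Protocol):
    """Sketch of the repository interface used by the statistics endpoints.

    get_daily_terminals_statistics, get_daily_token_cost_statistics and
    get_average_app_interaction_statistics are assumed to share this shape.
    """

    def get_daily_runs_statistics(
        self,
        *,
        tenant_id: str,
        app_id: str,
        triggered_from: Any,
        start_date: datetime | None,
        end_date: datetime | None,
        timezone: str,
    ) -> list[dict[str, Any]]: ...
```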


@ -4,7 +4,7 @@ from flask_restx import Resource, reqparse
import services
from configs import dify_config
from constants.languages import languages
from constants.languages import get_valid_language
from controllers.console import console_ns
from controllers.console.auth.error import (
AuthenticationFailedError,
@ -29,6 +29,7 @@ from libs.token import (
clear_access_token_from_cookie,
clear_csrf_token_from_cookie,
clear_refresh_token_from_cookie,
extract_refresh_token,
set_access_token_to_cookie,
set_csrf_token_to_cookie,
set_refresh_token_to_cookie,
@ -204,10 +205,12 @@ class EmailCodeLoginApi(Resource):
.add_argument("email", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
user_email = args["email"]
language = args["language"]
token_data = AccountService.get_email_code_login_data(args["token"])
if token_data is None:
@ -241,7 +244,9 @@ class EmailCodeLoginApi(Resource):
if account is None:
try:
account = AccountService.create_account_and_tenant(
email=user_email, name=user_email, interface_language=languages[0]
email=user_email,
name=user_email,
interface_language=get_valid_language(language),
)
except WorkSpaceNotAllowedCreateError:
raise NotAllowedCreateWorkspace()
@ -266,7 +271,7 @@ class EmailCodeLoginApi(Resource):
class RefreshTokenApi(Resource):
def post(self):
# Get refresh token from cookie instead of request body
refresh_token = request.cookies.get("refresh_token")
refresh_token = extract_refresh_token(request)
if not refresh_token:
return {"result": "fail", "message": "No refresh token provided"}, 401


@ -2,6 +2,7 @@ from flask_restx import Resource, reqparse
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from enums.cloud_plan import CloudPlan
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
@ -16,7 +17,13 @@ class Subscription(Resource):
current_user, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
.add_argument(
"plan",
type=str,
required=True,
location="args",
choices=[CloudPlan.PROFESSIONAL, CloudPlan.TEAM],
)
.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
)
args = parser.parse_args()
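
The string choices are replaced by enum members. A sketch of the enum consistent with the values the old parser accepted; any members beyond these two are unknown:

```python
from enum import StrEnum  # Python 3.11+; a (str, Enum) mixin works on older versions


class CloudPlan(StrEnum):
    PROFESSIONAL = "professional"
    TEAM = "team"
```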


@ -746,7 +746,7 @@ class DocumentApi(DocumentResource):
"name": document.name,
"created_from": document.created_from,
"created_by": document.created_by,
"created_at": document.created_at.timestamp(),
"created_at": int(document.created_at.timestamp()),
"tokens": document.tokens,
"indexing_status": document.indexing_status,
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
@ -779,7 +779,7 @@ class DocumentApi(DocumentResource):
"name": document.name,
"created_from": document.created_from,
"created_by": document.created_by,
"created_at": document.created_at.timestamp(),
"created_at": int(document.created_at.timestamp()),
"tokens": document.tokens,
"indexing_status": document.indexing_status,
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,


@ -4,7 +4,7 @@ from flask_restx import ( # type: ignore
)
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import current_user, login_required
@ -12,9 +12,17 @@ from models import Account
from models.dataset import Pipeline
from services.rag_pipeline.rag_pipeline import RagPipelineService
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
class DataSourceContentPreviewApi(Resource):
@api.expect(parser)
@setup_required
@login_required
@account_initialization_required
@ -26,12 +34,6 @@ class DataSourceContentPreviewApi(Resource):
if not isinstance(current_user, Account):
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
args = parser.parse_args()
inputs = args.get("inputs")


@ -22,7 +22,7 @@ from core.errors.error import (
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from libs import helper
from libs.login import current_user as current_user_
from libs.login import current_account_with_tenant
from models.model import AppMode, InstalledApp
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
@ -31,8 +31,6 @@ from .. import console_ns
logger = logging.getLogger(__name__)
current_user = current_user_._get_current_object() # type: ignore
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
class InstalledAppWorkflowRunApi(InstalledAppResource):
@ -40,6 +38,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
"""
Run workflow
"""
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if not app_model:
raise NotWorkflowAppError()
@ -53,7 +52,6 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
.add_argument("files", type=list, required=False, location="json")
)
args = parser.parse_args()
assert current_user is not None
try:
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
@ -89,7 +87,6 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
assert current_user is not None
# Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)


@ -66,13 +66,7 @@ class APIBasedExtensionAPI(Resource):
@account_initialization_required
@marshal_with(api_based_extension_fields)
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=True, location="json")
.add_argument("api_endpoint", type=str, required=True, location="json")
.add_argument("api_key", type=str, required=True, location="json")
)
args = parser.parse_args()
args = api.payload
_, current_tenant_id = current_account_with_tenant()
extension_data = APIBasedExtension(
@ -125,13 +119,7 @@ class APIBasedExtensionDetailAPI(Resource):
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=True, location="json")
.add_argument("api_endpoint", type=str, required=True, location="json")
.add_argument("api_key", type=str, required=True, location="json")
)
args = parser.parse_args()
args = api.payload
extension_data_from_db.name = args["name"]
extension_data_from_db.api_endpoint = args["api_endpoint"]


@ -8,6 +8,7 @@ import services
from configs import dify_config
from constants import DOCUMENT_EXTENSIONS
from controllers.common.errors import (
BlockedFileExtensionError,
FilenameNotExistsError,
FileTooLargeError,
NoFileUploadedError,
@ -39,6 +40,7 @@ class FileApi(Resource):
return {
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
"batch_count_limit": dify_config.UPLOAD_FILE_BATCH_LIMIT,
"file_upload_limit": dify_config.BATCH_UPLOAD_LIMIT,
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
@ -82,6 +84,8 @@ class FileApi(Resource):
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
except services.errors.file.BlockedFileExtensionError as blocked_extension_error:
raise BlockedFileExtensionError(blocked_extension_error.description)
return upload_file, 201


@ -74,12 +74,17 @@ class SetupApi(Resource):
.add_argument("email", type=email, required=True, location="json")
.add_argument("name", type=StrLen(30), required=True, location="json")
.add_argument("password", type=valid_password, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
# setup
RegisterService.setup(
email=args["email"], name=args["name"], password=args["password"], ip_address=extract_remote_ip(request)
email=args["email"],
name=args["name"],
password=args["password"],
ip_address=extract_remote_ip(request),
language=args["language"],
)
return {"result": "success"}, 201


@ -6,6 +6,7 @@ from flask_restx import (
Resource,
reqparse,
)
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from configs import dify_config
@ -15,20 +16,21 @@ from controllers.console.wraps import (
enterprise_license_required,
setup_required,
)
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
from core.mcp.mcp_client import MCPClient
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import CredentialType
from extensions.ext_database import db
from libs.helper import StrLen, alphanumeric, uuid_value
from libs.login import current_account_with_tenant, login_required
from models.provider_ids import ToolProviderID
from services.plugin.oauth_service import OAuthProxyService
from services.tools.api_tools_manage_service import ApiToolManageService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.tools.mcp_tools_manage_service import MCPToolManageService
from services.tools.mcp_tools_manage_service import MCPToolManageService, OAuthDataType
from services.tools.tool_labels_service import ToolLabelsService
from services.tools.tools_manage_service import ToolCommonService
from services.tools.tools_transform_service import ToolTransformService
@ -42,7 +44,9 @@ def is_valid_url(url: str) -> bool:
try:
parsed = urlparse(url)
return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"]
except Exception:
except (ValueError, TypeError):
# ValueError: Invalid URL format
# TypeError: url is not a string
return False
@ -886,29 +890,34 @@ class ToolProviderMCPApi(Resource):
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
.add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30)
.add_argument("sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300)
.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
)
args = parser.parse_args()
user, tenant_id = current_account_with_tenant()
if not is_valid_url(args["server_url"]):
raise ValueError("Server URL is not valid.")
return jsonable_encoder(
MCPToolManageService.create_mcp_provider(
# Parse and validate models
configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
# Create provider
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
result = service.create_provider(
tenant_id=tenant_id,
user_id=user.id,
server_url=args["server_url"],
name=args["name"],
icon=args["icon"],
icon_type=args["icon_type"],
icon_background=args["icon_background"],
user_id=user.id,
server_identifier=args["server_identifier"],
timeout=args["timeout"],
sse_read_timeout=args["sse_read_timeout"],
headers=args["headers"],
configuration=configuration,
authentication=authentication,
)
)
return jsonable_encoder(result)
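
A hypothetical shape for the `configuration` sub-object validated above, inferred from the flat `timeout` and `sse_read_timeout` arguments this hunk removes and from `provider_entity.timeout` / `provider_entity.sse_read_timeout` still being read later in this file; the actual fields of `MCPConfiguration` (and of `MCPAuthentication`, not sketched here) live in `core.entities.mcp_provider` and are assumptions:

```python
from pydantic import BaseModel


class MCPConfigurationSketch(BaseModel):
    # Plausible fields only; the real model is core.entities.mcp_provider.MCPConfiguration.
    timeout: float = 30
    sse_read_timeout: float = 300
```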
@setup_required
@login_required
@ -923,31 +932,43 @@ class ToolProviderMCPApi(Resource):
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
.add_argument("timeout", type=float, required=False, nullable=True, location="json")
.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
.add_argument("headers", type=dict, required=False, nullable=True, location="json")
.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
)
args = parser.parse_args()
if not is_valid_url(args["server_url"]):
if "[__HIDDEN__]" in args["server_url"]:
pass
else:
raise ValueError("Server URL is not valid.")
configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
_, current_tenant_id = current_account_with_tenant()
MCPToolManageService.update_mcp_provider(
tenant_id=current_tenant_id,
provider_id=args["provider_id"],
server_url=args["server_url"],
name=args["name"],
icon=args["icon"],
icon_type=args["icon_type"],
icon_background=args["icon_background"],
server_identifier=args["server_identifier"],
timeout=args.get("timeout"),
sse_read_timeout=args.get("sse_read_timeout"),
headers=args.get("headers"),
)
return {"result": "success"}
# Step 1: Validate server URL change if needed (includes URL format validation and network operation)
validation_result = None
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
validation_result = service.validate_server_url_change(
tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"]
)
# No need to check for errors here; exceptions will be raised directly
# Step 2: Perform database update in a transaction
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.update_provider(
tenant_id=current_tenant_id,
provider_id=args["provider_id"],
server_url=args["server_url"],
name=args["name"],
icon=args["icon"],
icon_type=args["icon_type"],
icon_background=args["icon_background"],
server_identifier=args["server_identifier"],
headers=args["headers"],
configuration=configuration,
authentication=authentication,
validation_result=validation_result,
)
return {"result": "success"}
@setup_required
@login_required
@ -958,8 +979,11 @@ class ToolProviderMCPApi(Resource):
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])
return {"result": "success"}
@console_ns.route("/workspaces/current/tool-provider/mcp/auth")
@ -976,37 +1000,53 @@ class ToolMCPAuthApi(Resource):
args = parser.parse_args()
provider_id = args["provider_id"]
_, tenant_id = current_account_with_tenant()
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
if not db_provider:
raise ValueError("provider not found")
# Convert to entity
provider_entity = db_provider.to_entity()
server_url = provider_entity.decrypt_server_url()
headers = provider_entity.decrypt_authentication()
# Try to connect without active transaction
try:
# Probe the MCP server connection; MCPAuthError is handled below to start the OAuth flow
with MCPClient(
server_url=server_url,
headers=headers,
timeout=provider_entity.timeout,
sse_read_timeout=provider_entity.sse_read_timeout,
):
# Update credentials in new transaction
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.update_provider_credentials(
provider_id=provider_id,
tenant_id=tenant_id,
credentials=provider_entity.credentials,
authed=True,
)
return {"result": "success"}
except MCPAuthError as e:
try:
auth_result = auth(provider_entity, args.get("authorization_code"))
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
response = service.execute_auth_actions(auth_result)
return response
except MCPRefreshTokenError as e:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
except MCPError as e:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
raise ValueError(f"Failed to connect to MCP server: {e}") from e
@ -1017,8 +1057,10 @@ class ToolMCPDetailApi(Resource):
@account_initialization_required
def get(self, provider_id):
_, tenant_id = current_account_with_tenant()
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
@console_ns.route("/workspaces/current/tools/mcp")
@ -1029,9 +1071,12 @@ class ToolMCPListAllApi(Resource):
def get(self):
_, tenant_id = current_account_with_tenant()
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
# Skip sensitive data decryption for list view to improve performance
tools = service.list_providers(tenant_id=tenant_id, include_sensitive=False)
return [tool.to_dict() for tool in tools]
@console_ns.route("/workspaces/current/tool-provider/mcp/update/<path:provider_id>")
@ -1041,11 +1086,13 @@ class ToolMCPUpdateApi(Resource):
@account_initialization_required
def get(self, provider_id):
_, tenant_id = current_account_with_tenant()
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
tools = service.list_provider_tools(
tenant_id=tenant_id,
provider_id=provider_id,
)
return jsonable_encoder(tools)
@console_ns.route("/mcp/oauth/callback")
@ -1059,5 +1106,15 @@ class ToolMCPCallbackApi(Resource):
args = parser.parse_args()
state_key = args["state"]
authorization_code = args["code"]
# Create service instance for handle_callback
with Session(db.engine) as session, session.begin():
mcp_service = MCPToolManageService(session=session)
# handle_callback now returns state data and tokens
state_data, tokens = handle_callback(state_key, authorization_code)
# Save tokens using the service layer
mcp_service.save_oauth_data(
state_data.provider_id, state_data.tenant_id, tokens.model_dump(), OAuthDataType.TOKENS
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")

View File

@ -21,6 +21,7 @@ from controllers.console.wraps import (
cloud_edition_billing_resource_check,
setup_required,
)
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
@ -83,7 +84,7 @@ class TenantListApi(Resource):
"name": tenant.name,
"status": tenant.status,
"created_at": tenant.created_at,
"plan": features.billing.subscription.plan if features.billing.enabled else CloudPlan.SANDBOX,
"current": tenant.id == current_tenant_id if current_tenant_id else False,
}

View File

@ -10,6 +10,7 @@ from flask import abort, request
from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.login import current_account_with_tenant
@ -133,7 +134,7 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
features = FeatureService.get_features(current_tenant_id)
if features.billing.enabled:
if resource == "add_segment":
if features.billing.subscription.plan == CloudPlan.SANDBOX:
abort(
403,
"To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.",

View File

@ -14,10 +14,25 @@ from services.file_service import FileService
@files_ns.route("/<uuid:file_id>/image-preview")
class ImagePreviewApi(Resource):
"""Deprecated endpoint for retrieving image previews."""
@files_ns.doc("get_image_preview")
@files_ns.doc(description="Retrieve a signed image preview for a file")
@files_ns.doc(
params={
"file_id": "ID of the file to preview",
"timestamp": "Unix timestamp used in the signature",
"nonce": "Random string used in the signature",
"sign": "HMAC signature verifying the request",
}
)
@files_ns.doc(
responses={
200: "Image preview returned successfully",
400: "Missing or invalid signature parameters",
415: "Unsupported file type",
}
)
def get(self, file_id):
file_id = str(file_id)
@ -43,6 +58,25 @@ class ImagePreviewApi(Resource):
@files_ns.route("/<uuid:file_id>/file-preview")
class FilePreviewApi(Resource):
@files_ns.doc("get_file_preview")
@files_ns.doc(description="Download a file preview or attachment using signed parameters")
@files_ns.doc(
params={
"file_id": "ID of the file to preview",
"timestamp": "Unix timestamp used in the signature",
"nonce": "Random string used in the signature",
"sign": "HMAC signature verifying the request",
"as_attachment": "Whether to download the file as an attachment",
}
)
@files_ns.doc(
responses={
200: "File stream returned successfully",
400: "Missing or invalid signature parameters",
404: "File not found",
415: "Unsupported file type",
}
)
def get(self, file_id):
file_id = str(file_id)
@ -101,6 +135,20 @@ class FilePreviewApi(Resource):
@files_ns.route("/workspaces/<uuid:workspace_id>/webapp-logo")
class WorkspaceWebappLogoApi(Resource):
@files_ns.doc("get_workspace_webapp_logo")
@files_ns.doc(description="Fetch the custom webapp logo for a workspace")
@files_ns.doc(
params={
"workspace_id": "Workspace identifier",
}
)
@files_ns.doc(
responses={
200: "Logo returned successfully",
404: "Webapp logo not configured",
415: "Unsupported file type",
}
)
def get(self, workspace_id):
workspace_id = str(workspace_id)

View File

@ -13,6 +13,26 @@ from extensions.ext_database import db as global_db
@files_ns.route("/tools/<uuid:file_id>.<string:extension>")
class ToolFileApi(Resource):
@files_ns.doc("get_tool_file")
@files_ns.doc(description="Download a tool file by ID using signed parameters")
@files_ns.doc(
params={
"file_id": "Tool file identifier",
"extension": "Expected file extension",
"timestamp": "Unix timestamp used in the signature",
"nonce": "Random string used in the signature",
"sign": "HMAC signature verifying the request",
"as_attachment": "Whether to download the file as an attachment",
}
)
@files_ns.doc(
responses={
200: "Tool file stream returned successfully",
403: "Forbidden - invalid signature",
404: "File not found",
415: "Unsupported file type",
}
)
def get(self, file_id, extension):
file_id = str(file_id)

View File

@ -193,15 +193,16 @@ class MCPAppApi(Resource):
except ValidationError as e:
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None:
"""Get end user - manages its own database session"""
with Session(db.engine, expire_on_commit=False) as session, session.begin():
return (
session.query(EndUser)
.where(EndUser.tenant_id == tenant_id)
.where(EndUser.session_id == mcp_server_id)
.where(EndUser.type == "mcp")
.first()
)
def _create_end_user(
self, client_name: str, tenant_id: str, app_id: str, mcp_server_id: str, session: Session
@ -229,7 +230,7 @@ class MCPAppApi(Resource):
request_id: Union[int, str],
) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError | None:
"""Handle MCP request and return response"""
end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id)
if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest):
client_info = mcp_request.root.params.clientInfo

View File

@ -592,7 +592,7 @@ class DocumentApi(DatasetApiResource):
"name": document.name,
"created_from": document.created_from,
"created_by": document.created_by,
"created_at": int(document.created_at.timestamp()),
"tokens": document.tokens,
"indexing_status": document.indexing_status,
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
@ -625,7 +625,7 @@ class DocumentApi(DatasetApiResource):
"name": document.name,
"created_from": document.created_from,
"created_by": document.created_by,
"created_at": int(document.created_at.timestamp()),
"tokens": document.tokens,
"indexing_status": document.indexing_status,
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,

View File

@ -2,6 +2,7 @@ from flask import request
from flask_restx import marshal, reqparse
from werkzeug.exceptions import NotFound
from configs import dify_config
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import (
@ -107,6 +108,10 @@ class SegmentApi(DatasetApiResource):
# validate args
args = segment_create_parser.parse_args()
if args["segments"] is not None:
segments_limit = dify_config.DATASET_MAX_SEGMENTS_PER_REQUEST
if segments_limit > 0 and len(args["segments"]) > segments_limit:
raise ValueError(f"Exceeded maximum segments limit of {segments_limit}.")
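# Worked example (hypothetical config value): with DATASET_MAX_SEGMENTS_PER_REQUEST=100,
# a payload of 150 segments is rejected before any per-segment validation runs, while a
# limit of 0 disables this cap entirely.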
for args_item in args["segments"]:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(args["segments"], document, dataset)

View File

@ -13,6 +13,7 @@ from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
@ -66,6 +67,7 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
kwargs["app_model"] = app_model
# If caller needs end-user context, attach EndUser to current_user
if fetch_user_arg:
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
user_id = request.args.get("user")
@ -74,7 +76,6 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
user_id = request.form.get("user")
else:
user_id = None
if not user_id and fetch_user_arg.required:
@ -89,6 +90,28 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
# Set EndUser as current logged-in user for flask_login.current_user
current_app.login_manager._update_request_context_with_user(end_user) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=end_user) # type: ignore
else:
# For service API without end-user context, ensure an Account is logged in
# so services relying on current_account_with_tenant() work correctly.
tenant_owner_info = (
db.session.query(Tenant, Account)
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
.join(Account, TenantAccountJoin.account_id == Account.id)
.where(
Tenant.id == app_model.tenant_id,
TenantAccountJoin.role == "owner",
Tenant.status == TenantStatus.NORMAL,
)
.one_or_none()
)
if tenant_owner_info:
tenant_model, account = tenant_owner_info
account.current_tenant = tenant_model
current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
else:
raise Unauthorized("Tenant owner account not found or tenant is not active.")
return view_func(*args, **kwargs)
@ -138,7 +161,7 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
features = FeatureService.get_features(api_token.tenant_id)
if features.billing.enabled:
if resource == "add_segment":
if features.billing.subscription.plan == CloudPlan.SANDBOX:
raise Forbidden(
"To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."
)

View File

@ -17,8 +17,8 @@ from libs.helper import email
from libs.passport import PassportService
from libs.password import valid_password
from libs.token import (
clear_webapp_access_token_from_cookie,
extract_webapp_access_token,
)
from services.account_service import AccountService
from services.app_service import AppService
@ -81,7 +81,7 @@ class LoginStatusApi(Resource):
)
def get(self):
app_code = request.args.get("app_code")
token = extract_webapp_access_token(request)
if not app_code:
return {
"logged_in": bool(token),
@ -128,7 +128,7 @@ class LogoutApi(Resource):
response = make_response({"result": "success"})
# enterprise SSO sets same site to None in https deployment
# so we need to logout by calling api
clear_webapp_access_token_from_cookie(response, samesite="None")
return response

View File

@ -12,10 +12,8 @@ from controllers.web import web_ns
from controllers.web.error import WebAppAuthRequiredError
from extensions.ext_database import db
from libs.passport import PassportService
from libs.token import extract_webapp_access_token
from models.model import App, EndUser, Site
from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService, WebAppAuthType
@ -37,23 +35,18 @@ class PassportResource(Resource):
system_features = FeatureService.get_system_features()
app_code = request.headers.get(HEADER_NAME_APP_CODE)
user_id = request.args.get("user_id")
access_token = extract_webapp_access_token(request)
if app_code is None:
raise Unauthorized("X-App-Code header is missing.")
enterprise_user_decoded = decode_enterprise_webapp_user_id(access_token)
app_auth_type = WebAppAuthService.get_app_auth_type(app_code=app_code)
if app_auth_type != WebAppAuthType.PUBLIC:
if not enterprise_user_decoded:
raise WebAppAuthRequiredError()
return exchange_token_for_existing_web_user(
app_code=app_code, enterprise_user_decoded=enterprise_user_decoded, auth_type=app_auth_type
)
# get site from db and check if it is normal
site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal"))
@ -124,7 +117,7 @@ def decode_enterprise_webapp_user_id(jwt_token: str | None):
return decoded
def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict, auth_type: WebAppAuthType):
"""
Exchange a token for an existing web user session.
"""
@ -145,13 +138,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound()
if auth_type == WebAppAuthType.PUBLIC:
return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded)
elif auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external":
raise WebAppAuthRequiredError("Please login as external user.")
elif auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal":
raise WebAppAuthRequiredError("Please login as internal user.")
end_user = None

View File

@ -1,6 +1,6 @@
import logging
import time
from collections.abc import Mapping, MutableMapping, Sequence
from typing import Any, cast
from sqlalchemy import select
@ -28,6 +28,7 @@ from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.graph_events import GraphRunSucceededEvent
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
@ -67,11 +68,13 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
app: App,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
):
super().__init__(
queue_manager=queue_manager,
variable_loader=variable_loader,
app_id=application_generate_entity.app_config.app_id,
graph_engine_layers=graph_engine_layers,
)
self.application_generate_entity = application_generate_entity
self.conversation = conversation
@ -207,6 +210,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
)
workflow_entry.graph_engine.layer(persistence_layer)
for layer in self._graph_engine_layers:
workflow_entry.graph_engine.layer(layer)
generator = workflow_entry.run()

View File

@ -1,3 +1,4 @@
import json
import logging
import re
import time
@ -60,6 +61,7 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes import NodeType
@ -391,6 +393,14 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if should_direct_answer:
return
current_time = time.perf_counter()
if self._task_state.first_token_time is None and delta_text.strip():
self._task_state.first_token_time = current_time
self._task_state.is_streaming_response = True
if delta_text.strip():
self._task_state.last_token_time = current_time
# Only publish tts message at text chunk streaming
if tts_publisher and queue_message:
tts_publisher.publish(queue_message)
@ -772,7 +782,33 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
message.answer = answer_text
message.updated_at = naive_utc_now()
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline.start_at
# Set usage first before dumping metadata
if graph_runtime_state and graph_runtime_state.llm_usage:
usage = graph_runtime_state.llm_usage
message.message_tokens = usage.prompt_tokens
message.message_unit_price = usage.prompt_unit_price
message.message_price_unit = usage.prompt_price_unit
message.answer_tokens = usage.completion_tokens
message.answer_unit_price = usage.completion_unit_price
message.answer_price_unit = usage.completion_price_unit
message.total_price = usage.total_price
message.currency = usage.currency
self._task_state.metadata.usage = usage
else:
usage = LLMUsage.empty_usage()
self._task_state.metadata.usage = usage
# Add streaming metrics to usage if available
if self._task_state.is_streaming_response and self._task_state.first_token_time:
start_time = self._base_task_pipeline.start_at
first_token_time = self._task_state.first_token_time
last_token_time = self._task_state.last_token_time or first_token_time
usage.time_to_first_token = round(first_token_time - start_time, 3)
usage.time_to_generate = round(last_token_time - first_token_time, 3)
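# Worked example (illustrative timings): with start_at=100.00, the first non-empty delta
# at 100.42 and the last at 103.37, this yields time_to_first_token = 0.42 and
# time_to_generate = 2.95 (seconds, rounded to three decimals).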
metadata = self._task_state.metadata.model_dump()
message.message_metadata = json.dumps(jsonable_encoder(metadata))
message_files = [
MessageFile(
message_id=message.id,
@ -790,20 +826,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
]
session.add_all(message_files)
def _seed_graph_runtime_state_from_queue_manager(self) -> None:
"""Bootstrap the cached runtime state from the queue manager when present."""
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state

View File

@ -144,7 +144,7 @@ class AgentChatAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template,
inputs=dict(inputs),
files=list(files),
query=query,
memory=memory,
)
@ -172,7 +172,7 @@ class AgentChatAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template,
inputs=dict(inputs),
files=list(files),
query=query,
memory=memory,
)

View File

@ -79,7 +79,7 @@ class AppRunner:
prompt_template_entity: PromptTemplateEntity,
inputs: Mapping[str, str],
files: Sequence["File"],
query: str = "",
context: str | None = None,
memory: TokenBufferMemory | None = None,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
@ -105,7 +105,7 @@ class AppRunner:
app_mode=AppMode.value_of(app_record.mode),
prompt_template_entity=prompt_template_entity,
inputs=inputs,
query=query,
files=files,
context=context,
memory=memory,

View File

@ -4,7 +4,7 @@ from dataclasses import dataclass
from datetime import datetime
from typing import Any, NewType, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueIterationCompletedEvent,
@ -51,7 +51,7 @@ from core.workflow.workflow_entry import WorkflowEntry
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from libs.datetime_utils import naive_utc_now
from models import Account, EndUser
from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator
NodeExecutionId = NewType("NodeExecutionId", str)
@ -70,6 +70,8 @@ class _NodeSnapshot:
class WorkflowResponseConverter:
_truncator: BaseTruncator
def __init__(
self,
*,
@ -81,7 +83,13 @@ class WorkflowResponseConverter:
self._user = user
self._system_variables = system_variables
self._workflow_inputs = self._prepare_workflow_inputs()
# Disable truncation for SERVICE_API calls to keep backward compatibility.
if application_generate_entity.invoke_from == InvokeFrom.SERVICE_API:
self._truncator = DummyVariableTruncator()
else:
self._truncator = VariableTruncator.default()
self._node_snapshots: dict[NodeExecutionId, _NodeSnapshot] = {}
self._workflow_execution_id: str | None = None
self._workflow_started_at: datetime | None = None

View File

@ -190,7 +190,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
conversation_id=conversation.id,
inputs=application_generate_entity.inputs,
query=application_generate_entity.query,
message="",
message_tokens=0,
message_unit_price=0,

View File

@ -40,6 +40,7 @@ from core.workflow.repositories.draft_variable_repository import DraftVariableSa
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.flask_utils import preserve_flask_contexts
@ -255,7 +256,7 @@ class PipelineGenerator(BaseAppGenerator):
json_text = json.dumps(text)
upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id)
features = FeatureService.get_features(dataset.tenant_id)
if features.billing.enabled and features.billing.subscription.plan == CloudPlan.SANDBOX:
tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"

View File

@ -135,6 +135,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
)
workflow_entry.graph_engine.layer(persistence_layer)
for layer in self._graph_engine_layers:
workflow_entry.graph_engine.layer(layer)
generator = workflow_entry.run()

View File

@ -1,5 +1,5 @@
import time
from collections.abc import Mapping, Sequence
from typing import Any, cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -27,6 +27,7 @@ from core.app.entities.queue_entities import (
)
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunFailedEvent,
@ -69,10 +70,12 @@ class WorkflowBasedAppRunner:
queue_manager: AppQueueManager,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
app_id: str,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
):
self._queue_manager = queue_manager
self._variable_loader = variable_loader
self._app_id = app_id
self._graph_engine_layers = graph_engine_layers
def _init_graph(
self,

View File

@ -129,7 +129,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
app_config: EasyUIBasedAppConfig = None # type: ignore
model_conf: ModelConfigWithCredentialsEntity
query: str = ""
# pydantic configs
model_config = ConfigDict(protected_namespaces=())

View File

@ -48,6 +48,9 @@ class WorkflowTaskState(TaskState):
"""
answer: str = ""
first_token_time: float | None = None
last_token_time: float | None = None
is_streaming_response: bool = False
class StreamEvent(StrEnum):

View File

@ -0,0 +1,71 @@
from sqlalchemy import Engine
from sqlalchemy.orm import sessionmaker
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from core.workflow.graph_events.graph import GraphRunPausedEvent
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.factory import DifyAPIRepositoryFactory
class PauseStatePersistenceLayer(GraphEngineLayer):
def __init__(self, session_factory: Engine | sessionmaker, state_owner_user_id: str):
"""Create a PauseStatePersistenceLayer.
The `state_owner_user_id` is used when creating the state file for a pause.
It should generally be the id of the workflow's creator.
"""
if isinstance(session_factory, Engine):
session_factory = sessionmaker(session_factory)
self._session_maker = session_factory
self._state_owner_user_id = state_owner_user_id
def _get_repo(self) -> APIWorkflowRunRepository:
return DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_maker)
def on_graph_start(self) -> None:
"""
Called when graph execution starts.
This is called after the engine has been initialized but before any nodes
are executed. Layers can use this to set up resources or log start information.
"""
pass
def on_event(self, event: GraphEngineEvent) -> None:
"""
Called for every event emitted by the engine.
This method receives all events generated during graph execution, including:
- Graph lifecycle events (start, success, failure)
- Node execution events (start, success, failure, retry)
- Stream events for response nodes
- Container events (iteration, loop)
Args:
event: The event emitted by the engine
"""
if not isinstance(event, GraphRunPausedEvent):
return
assert self.graph_runtime_state is not None
workflow_run_id: str | None = self.graph_runtime_state.system_variable.workflow_execution_id
assert workflow_run_id is not None
repo = self._get_repo()
repo.create_workflow_pause(
workflow_run_id=workflow_run_id,
state_owner_user_id=self._state_owner_user_id,
state=self.graph_runtime_state.dumps(),
)
def on_graph_end(self, error: Exception | None) -> None:
"""
Called when graph execution ends.
This is called after all nodes have been executed or when execution is
aborted. Layers can use this to clean up resources or log final state.
Args:
error: The exception that caused execution to fail, or None if successful
"""
pass
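# A minimal wiring sketch for this layer (assumptions: `db.engine` is available and the
# runner forwards `graph_engine_layers` to the graph engine, as AdvancedChatAppRunner and
# WorkflowAppRunner do elsewhere in this changeset; the call site below is illustrative,
# not a definitive integration):
#
#     from extensions.ext_database import db
#
#     pause_layer = PauseStatePersistenceLayer(db.engine, state_owner_user_id=workflow.created_by)
#     runner = AdvancedChatAppRunner(
#         application_generate_entity=entity,
#         queue_manager=queue_manager,
#         ...,  # remaining runner arguments elided
#         graph_engine_layers=[pause_layer],
#     )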

View File

@ -121,7 +121,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
# start generate conversation name thread
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
conversation_id=self._conversation_id, query=self._application_generate_entity.query
)
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)

View File

@ -140,7 +140,27 @@ class MessageCycleManager:
if not self._application_generate_entity.app_config.additional_features:
raise ValueError("Additional features not found")
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
merged_resources = [r for r in self._task_state.metadata.retriever_resources or [] if r]
existing_ids = {(r.dataset_id, r.document_id) for r in merged_resources if r.dataset_id and r.document_id}
# Add new unique resources from the event
for resource in event.retriever_resources or []:
if not resource:
continue
is_duplicate = (
resource.dataset_id
and resource.document_id
and (resource.dataset_id, resource.document_id) in existing_ids
)
if not is_duplicate:
merged_resources.append(resource)
for i, resource in enumerate(merged_resources, 1):
resource.position = i
self._task_state.metadata.retriever_resources = merged_resources
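# Worked example (hypothetical resources): merging existing [A(ds1, doc1)] with event
# resources [A'(ds1, doc1), B(ds2, doc5)] keeps A, appends only B, and renumbers
# position to 1 and 2; duplicates are keyed on (dataset_id, document_id).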
def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> MessageFileStreamResponse | None:
"""

View File

@ -0,0 +1,329 @@
import json
from datetime import datetime
from enum import StrEnum
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
from pydantic import BaseModel
from configs import dify_config
from core.entities.provider_entities import BasicProviderConfig
from core.file import helpers as file_helpers
from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
if TYPE_CHECKING:
from models.tools import MCPToolProvider
# Constants
CLIENT_NAME = "Dify"
CLIENT_URI = "https://github.com/langgenius/dify"
DEFAULT_TOKEN_TYPE = "Bearer"
DEFAULT_EXPIRES_IN = 3600
MASK_CHAR = "*"
MIN_UNMASK_LENGTH = 6
class MCPSupportGrantType(StrEnum):
"""The supported grant types for MCP"""
AUTHORIZATION_CODE = "authorization_code"
CLIENT_CREDENTIALS = "client_credentials"
REFRESH_TOKEN = "refresh_token"
class MCPAuthentication(BaseModel):
client_id: str
client_secret: str | None = None
class MCPConfiguration(BaseModel):
timeout: float = 30
sse_read_timeout: float = 300
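# Example payloads these models validate (illustrative values; client_secret is optional
# for public clients):
#
#     MCPConfiguration.model_validate({"timeout": 30, "sse_read_timeout": 300})
#     MCPAuthentication.model_validate({"client_id": "my-client", "client_secret": "s3cret"})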
class MCPProviderEntity(BaseModel):
"""MCP Provider domain entity for business logic operations"""
# Basic identification
id: str
provider_id: str # server_identifier
name: str
tenant_id: str
user_id: str
# Server connection info
server_url: str # encrypted URL
headers: dict[str, str] # encrypted headers
timeout: float
sse_read_timeout: float
# Authentication related
authed: bool
credentials: dict[str, Any] # encrypted credentials
code_verifier: str | None = None # for OAuth
# Tools and display info
tools: list[dict[str, Any]] # parsed tools list
icon: str | dict[str, str] # parsed icon
# Timestamps
created_at: datetime
updated_at: datetime
@classmethod
def from_db_model(cls, db_provider: "MCPToolProvider") -> "MCPProviderEntity":
"""Create entity from database model with decryption"""
return cls(
id=db_provider.id,
provider_id=db_provider.server_identifier,
name=db_provider.name,
tenant_id=db_provider.tenant_id,
user_id=db_provider.user_id,
server_url=db_provider.server_url,
headers=db_provider.headers,
timeout=db_provider.timeout,
sse_read_timeout=db_provider.sse_read_timeout,
authed=db_provider.authed,
credentials=db_provider.credentials,
tools=db_provider.tool_dict,
icon=db_provider.icon or "",
created_at=db_provider.created_at,
updated_at=db_provider.updated_at,
)
@property
def redirect_url(self) -> str:
"""OAuth redirect URL"""
return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
@property
def client_metadata(self) -> OAuthClientMetadata:
"""Metadata about this OAuth client."""
# Get grant type from credentials
credentials = self.decrypt_credentials()
# Try to get grant_type from different locations
grant_type = credentials.get("grant_type", MCPSupportGrantType.AUTHORIZATION_CODE)
# For nested structure, check if client_information has grant_types
if "client_information" in credentials and isinstance(credentials["client_information"], dict):
client_info = credentials["client_information"]
# If grant_types is specified in client_information, use it to determine grant_type
if "grant_types" in client_info and isinstance(client_info["grant_types"], list):
if "client_credentials" in client_info["grant_types"]:
grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS
elif "authorization_code" in client_info["grant_types"]:
grant_type = MCPSupportGrantType.AUTHORIZATION_CODE
# Configure based on grant type
is_client_credentials = grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS
grant_types = ["refresh_token"]
grant_types.append("client_credentials" if is_client_credentials else "authorization_code")
response_types = [] if is_client_credentials else ["code"]
redirect_uris = [] if is_client_credentials else [self.redirect_url]
return OAuthClientMetadata(
redirect_uris=redirect_uris,
token_endpoint_auth_method="none",
grant_types=grant_types,
response_types=response_types,
client_name=CLIENT_NAME,
client_uri=CLIENT_URI,
)
@property
def provider_icon(self) -> dict[str, str] | str:
"""Get provider icon, handling both dict and string formats"""
if isinstance(self.icon, dict):
return self.icon
try:
return json.loads(self.icon)
except (json.JSONDecodeError, TypeError):
# If not JSON, assume it's a file path
return file_helpers.get_signed_file_url(self.icon)
def to_api_response(self, user_name: str | None = None, include_sensitive: bool = True) -> dict[str, Any]:
"""Convert to API response format
Args:
user_name: User name to display
include_sensitive: If False, skip expensive decryption operations (for list view optimization)
"""
response = {
"id": self.id,
"author": user_name or "Anonymous",
"name": self.name,
"icon": self.provider_icon,
"type": ToolProviderType.MCP.value,
"is_team_authorization": self.authed,
"server_url": self.masked_server_url(),
"server_identifier": self.provider_id,
"updated_at": int(self.updated_at.timestamp()),
"label": I18nObject(en_US=self.name, zh_Hans=self.name).to_dict(),
"description": I18nObject(en_US="", zh_Hans="").to_dict(),
}
# Add configuration
response["configuration"] = {
"timeout": str(self.timeout),
"sse_read_timeout": str(self.sse_read_timeout),
}
# Skip expensive operations when sensitive data is not needed (e.g., list view)
if not include_sensitive:
response["masked_headers"] = {}
response["is_dynamic_registration"] = True
else:
# Add masked headers
response["masked_headers"] = self.masked_headers()
# Add authentication info if available
masked_creds = self.masked_credentials()
if masked_creds:
response["authentication"] = masked_creds
response["is_dynamic_registration"] = self.credentials.get("client_information", {}).get(
"is_dynamic_registration", True
)
return response
def retrieve_client_information(self) -> OAuthClientInformation | None:
"""OAuth client information if available"""
credentials = self.decrypt_credentials()
if not credentials:
return None
# Check if we have nested client_information structure
if "client_information" not in credentials:
return None
client_info_data = credentials["client_information"]
if isinstance(client_info_data, dict):
if "encrypted_client_secret" in client_info_data:
client_info_data["client_secret"] = encrypter.decrypt_token(
self.tenant_id, client_info_data["encrypted_client_secret"]
)
return OAuthClientInformation.model_validate(client_info_data)
return None
def retrieve_tokens(self) -> OAuthTokens | None:
"""OAuth tokens if available"""
if not self.credentials:
return None
credentials = self.decrypt_credentials()
return OAuthTokens(
access_token=credentials.get("access_token", ""),
token_type=credentials.get("token_type", DEFAULT_TOKEN_TYPE),
expires_in=int(credentials.get("expires_in", str(DEFAULT_EXPIRES_IN)) or DEFAULT_EXPIRES_IN),
refresh_token=credentials.get("refresh_token", ""),
)
def masked_server_url(self) -> str:
"""Masked server URL for display"""
parsed = urlparse(self.decrypt_server_url())
if parsed.path and parsed.path != "/":
masked = parsed._replace(path="/******")
return masked.geturl()
return parsed.geturl()
def _mask_value(self, value: str) -> str:
"""Mask a sensitive value for display"""
if len(value) > MIN_UNMASK_LENGTH:
return value[:2] + MASK_CHAR * (len(value) - 4) + value[-2:]
else:
return MASK_CHAR * len(value)
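# e.g. _mask_value("sk-abcdef123") -> "sk********23" (length 12 > MIN_UNMASK_LENGTH, so
# the first and last two characters stay visible), while _mask_value("abc") -> "***".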
def masked_headers(self) -> dict[str, str]:
"""Masked headers for display"""
return {key: self._mask_value(value) for key, value in self.decrypt_headers().items()}
def masked_credentials(self) -> dict[str, str]:
"""Masked credentials for display"""
credentials = self.decrypt_credentials()
if not credentials:
return {}
masked = {}
if "client_information" not in credentials or not isinstance(credentials["client_information"], dict):
return {}
client_info = credentials["client_information"]
# Mask sensitive fields from nested structure
if client_info.get("client_id"):
masked["client_id"] = self._mask_value(client_info["client_id"])
if client_info.get("encrypted_client_secret"):
masked["client_secret"] = self._mask_value(
encrypter.decrypt_token(self.tenant_id, client_info["encrypted_client_secret"])
)
if client_info.get("client_secret"):
masked["client_secret"] = self._mask_value(client_info["client_secret"])
return masked
def decrypt_server_url(self) -> str:
"""Decrypt server URL"""
return encrypter.decrypt_token(self.tenant_id, self.server_url)
def _decrypt_dict(self, data: dict[str, Any]) -> dict[str, Any]:
"""Generic method to decrypt dictionary fields"""
from core.tools.utils.encryption import create_provider_encrypter
if not data:
return {}
# Only decrypt fields that are actually encrypted
# For nested structures, client_information is not encrypted as a whole
encrypted_fields = []
for key, value in data.items():
# Skip nested objects - they are not encrypted
if isinstance(value, dict):
continue
# Only process string values that might be encrypted
if isinstance(value, str) and value:
encrypted_fields.append(key)
if not encrypted_fields:
return data
# Create dynamic config only for encrypted fields
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in encrypted_fields]
encrypter_instance, _ = create_provider_encrypter(
tenant_id=self.tenant_id,
config=config,
cache=NoOpProviderCredentialCache(),
)
# Decrypt only the encrypted fields
decrypted_data = encrypter_instance.decrypt({k: data[k] for k in encrypted_fields})
# Merge decrypted data with original data (preserving non-encrypted fields)
result = data.copy()
result.update(decrypted_data)
return result
def decrypt_headers(self) -> dict[str, Any]:
"""Decrypt headers"""
return self._decrypt_dict(self.headers)
def decrypt_credentials(self) -> dict[str, Any]:
"""Decrypt credentials"""
return self._decrypt_dict(self.credentials)
def decrypt_authentication(self) -> dict[str, Any]:
"""Decrypt authentication"""
# Option 1: if headers are provided, use them as-is; no token lookup is needed
headers = self.decrypt_headers()
# Option 2: Add OAuth token if authed and no headers provided
if not self.headers and self.authed:
token = self.retrieve_tokens()
if token:
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
return headers

View File

@ -74,6 +74,10 @@ class File(BaseModel):
storage_key: str | None = None,
dify_model_identity: str | None = FILE_MODEL_IDENTITY,
url: str | None = None,
# Legacy compatibility fields - explicitly handle known extra fields
tool_file_id: str | None = None,
upload_file_id: str | None = None,
datasource_file_id: str | None = None,
):
super().__init__(
id=id,

View File

@ -6,10 +6,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
class NodeJsTemplateTransformer(TemplateTransformer):
@classmethod
def get_runner_script(cls) -> str:
runner_script = dedent(f""" {cls._code_placeholder}
// decode and prepare input object
var inputs_obj = JSON.parse(Buffer.from('{cls._inputs_placeholder}', 'base64').toString('utf-8'))
@ -21,6 +18,5 @@ class NodeJsTemplateTransformer(TemplateTransformer):
var output_json = JSON.stringify(output_obj)
var result = `<<RESULT>>${{output_json}}<<RESULT>>`
console.log(result)
""")
return runner_script

View File

@ -6,9 +6,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
class Python3TemplateTransformer(TemplateTransformer):
@classmethod
def get_runner_script(cls) -> str:
runner_script = dedent(f""" {cls._code_placeholder}
import json
from base64 import b64decode

View File

@ -29,6 +29,18 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP
return [MarketplacePluginDeclaration.model_validate(plugin) for plugin in response.json()["data"]["plugins"]]
def batch_fetch_plugin_by_ids(plugin_ids: list[str]) -> list[dict]:
if not plugin_ids:
return []
url = str(marketplace_api_url / "api/v1/plugins/batch")
response = httpx.post(url, json={"plugin_ids": plugin_ids}, headers={"X-Dify-Version": dify_config.project.version})
response.raise_for_status()
data = response.json()
return data.get("data", {}).get("plugins", [])
def batch_fetch_plugin_manifests_ignore_deserialization_error(
plugin_ids: list[str],
) -> Sequence[MarketplacePluginDeclaration]:

View File

@ -415,7 +415,6 @@ class IndexingRunner:
document_id=dataset_document.id,
after_indexing_status="splitting",
extra_update_params={
DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs),
DatasetDocument.parsing_completed_at: naive_utc_now(),
},
)
@ -755,6 +754,7 @@ class IndexingRunner:
extra_update_params={
DatasetDocument.cleaning_completed_at: cur_time,
DatasetDocument.splitting_completed_at: cur_time,
DatasetDocument.word_count: sum(len(doc.page_content) for doc in documents),
},
)

View File

@ -6,11 +6,15 @@ import secrets
import urllib.parse
from urllib.parse import urljoin, urlparse
import httpx
from httpx import ConnectError, HTTPStatusError, RequestError
from pydantic import ValidationError
from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType
from core.helper import ssrf_proxy
from core.mcp.entities import AuthAction, AuthActionType, AuthResult, OAuthCallbackState
from core.mcp.error import MCPRefreshTokenError
from core.mcp.types import (
LATEST_PROTOCOL_VERSION,
OAuthClientInformation,
OAuthClientInformationFull,
OAuthClientMetadata,
@ -19,21 +23,10 @@ from core.mcp.types import (
)
from extensions.ext_redis import redis_client
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
def generate_pkce_challenge() -> tuple[str, str]:
"""Generate PKCE challenge and verifier."""
code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
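# PKCE (RFC 7636): the verifier stays client-side until token exchange; the challenge
# sent with the authorization request is BASE64URL(SHA256(verifier)) under the S256 method.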
@ -80,8 +73,13 @@ def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
raise ValueError(f"Invalid state parameter: {str(e)}")
def handle_callback(state_key: str, authorization_code: str) -> tuple[OAuthCallbackState, OAuthTokens]:
"""
Handle the callback from the OAuth provider.
Returns:
A tuple of (callback_state, tokens) that can be used by the caller to save data.
"""
# Retrieve state data from Redis (state is automatically deleted after retrieval)
full_state_data = _retrieve_redis_state(state_key)
@ -93,30 +91,32 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackSta
full_state_data.code_verifier,
full_state_data.redirect_uri,
)
return full_state_data, tokens
def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
"""Check if the server supports OAuth 2.0 Resource Discovery."""
b_scheme, b_netloc, _, _, b_query, b_fragment = urlparse(server_url, "", True)
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource"
if b_query:
url_for_resource_discovery += f"?{b_query}"
if b_fragment:
url_for_resource_discovery += f"#{b_fragment}"
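# e.g. a (hypothetical) server_url "https://mcp.example.com/sse?v=1" resolves to
# "https://mcp.example.com/.well-known/oauth-protected-resource?v=1"; the path is
# dropped and only the query/fragment carry over.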
try:
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
response = ssrf_proxy.get(url_for_resource_discovery, headers=headers)
if 200 <= response.status_code < 300:
body = response.json()
# Support both singular and plural forms
if body.get("authorization_servers"):
return True, body["authorization_servers"][0]
elif body.get("authorization_server_url"):
return True, body["authorization_server_url"][0]
else:
return False, ""
return False, ""
except RequestError:
# Not support resource discovery, fall back to well-known OAuth metadata
return False, ""
@ -126,27 +126,37 @@ def discover_oauth_metadata(server_url: str, protocol_version: str | None = None
# First check if the server supports OAuth 2.0 Resource Discovery
support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url)
if support_resource_discovery:
# The oauth_discovery_url is the authorization server base URL
# Try OpenID Connect discovery first (more common), then OAuth 2.0
urls_to_try = [
urljoin(oauth_discovery_url + "/", ".well-known/oauth-authorization-server"),
urljoin(oauth_discovery_url + "/", ".well-known/openid-configuration"),
]
else:
urls_to_try = [urljoin(server_url, "/.well-known/oauth-authorization-server")]
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
for url in urls_to_try:
    try:
        response = ssrf_proxy.get(url, headers=headers)
        if response.status_code == 404:
            continue
        response.raise_for_status()
        return OAuthMetadata.model_validate(response.json())
    except (RequestError, HTTPStatusError) as e:
        if isinstance(e, ConnectError):
            response = ssrf_proxy.get(url)
            if response.status_code == 404:
                continue  # Try next URL
            if not response.is_success:
                raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
            return OAuthMetadata.model_validate(response.json())
        # For other errors, try next URL
        continue
return None  # No metadata found
def start_authorization(
@ -213,7 +223,7 @@ def exchange_authorization(
redirect_uri: str,
) -> OAuthTokens:
"""Exchanges an authorization code for an access token."""
grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value
if metadata:
token_url = metadata.token_endpoint
@ -233,7 +243,7 @@ def exchange_authorization(
if client_information.client_secret:
params["client_secret"] = client_information.client_secret
response = ssrf_proxy.post(token_url, data=params)
if not response.is_success:
raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
return OAuthTokens.model_validate(response.json())
@ -246,7 +256,7 @@ def refresh_authorization(
refresh_token: str,
) -> OAuthTokens:
"""Exchange a refresh token for an updated access token."""
grant_type = MCPSupportGrantType.REFRESH_TOKEN.value
if metadata:
token_url = metadata.token_endpoint
@ -263,10 +273,55 @@ def refresh_authorization(
if client_information.client_secret:
params["client_secret"] = client_information.client_secret
try:
response = ssrf_proxy.post(token_url, data=params)
except ssrf_proxy.MaxRetriesExceededError as e:
raise MCPRefreshTokenError(e) from e
if not response.is_success:
raise MCPRefreshTokenError(response.text)
return OAuthTokens.model_validate(response.json())
def client_credentials_flow(
server_url: str,
metadata: OAuthMetadata | None,
client_information: OAuthClientInformation,
scope: str | None = None,
) -> OAuthTokens:
"""Execute Client Credentials Flow to get access token."""
grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
if metadata:
token_url = metadata.token_endpoint
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
else:
token_url = urljoin(server_url, "/token")
# Support both Basic Auth and body parameters for client authentication
headers = {"Content-Type": "application/x-www-form-urlencoded"}
data = {"grant_type": grant_type}
if scope:
data["scope"] = scope
# If client_secret is provided, use Basic Auth (preferred method)
if client_information.client_secret:
credentials = f"{client_information.client_id}:{client_information.client_secret}"
encoded_credentials = base64.b64encode(credentials.encode()).decode()
headers["Authorization"] = f"Basic {encoded_credentials}"
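# e.g. client_id "abc" with client_secret "xyz" produces
# "Authorization: Basic YWJjOnh5eg==" (base64 of "abc:xyz").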
else:
# Fall back to including credentials in the body
data["client_id"] = client_information.client_id
if client_information.client_secret:
data["client_secret"] = client_information.client_secret
response = ssrf_proxy.post(token_url, headers=headers, data=data)
if not response.is_success:
raise ValueError(
f"Client credentials token request failed: HTTP {response.status_code}, Response: {response.text}"
)
return OAuthTokens.model_validate(response.json())
@ -283,7 +338,7 @@ def register_client(
else:
registration_url = urljoin(server_url, "/register")
response = ssrf_proxy.post(
registration_url,
json=client_metadata.model_dump(),
headers={"Content-Type": "application/json"},
@ -294,28 +349,111 @@ def register_client(
def auth(
provider: OAuthClientProvider,
server_url: str,
provider: MCPProviderEntity,
authorization_code: str | None = None,
state_param: str | None = None,
for_list: bool = False,
) -> dict[str, str]:
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
metadata = discover_oauth_metadata(server_url)
) -> AuthResult:
"""
Orchestrates the full auth flow with a server using secure Redis state storage.
This function performs only network operations and returns actions that need
to be performed by the caller (such as saving data to the database).
Args:
provider: The MCP provider entity
authorization_code: Optional authorization code from OAuth callback
state_param: Optional state parameter from OAuth callback
Returns:
AuthResult containing actions to be performed and response data
"""
actions: list[AuthAction] = []
server_url = provider.decrypt_server_url()
server_metadata = discover_oauth_metadata(server_url)
client_metadata = provider.client_metadata
provider_id = provider.id
tenant_id = provider.tenant_id
client_information = provider.retrieve_client_information()
redirect_url = provider.redirect_url
# Determine grant type based on server metadata
if not server_metadata:
raise ValueError("Failed to discover OAuth metadata from server")
supported_grant_types = server_metadata.grant_types_supported or []
# Convert to lowercase for comparison
supported_grant_types_lower = [gt.lower() for gt in supported_grant_types]
# Determine which grant type to use
effective_grant_type = None
if MCPSupportGrantType.AUTHORIZATION_CODE.value in supported_grant_types_lower:
effective_grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value
else:
effective_grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
# Get stored credentials
credentials = provider.decrypt_credentials()
# Handle client registration if needed
client_information = provider.client_information()
if not client_information:
if authorization_code is not None:
raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
# For client credentials flow, we don't need to register client dynamically
if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
# Client should provide client_id and client_secret directly
raise ValueError("Client credentials flow requires client_id and client_secret to be provided")
try:
full_information = register_client(server_url, metadata, provider.client_metadata)
except httpx.RequestError as e:
full_information = register_client(server_url, server_metadata, client_metadata)
except RequestError as e:
raise ValueError(f"Could not register OAuth client: {e}")
provider.save_client_information(full_information)
# Return action to save client information
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_CLIENT_INFO,
data={"client_information": full_information.model_dump()},
provider_id=provider_id,
tenant_id=tenant_id,
)
)
client_information = full_information
# Exchange authorization code for tokens
# Handle client credentials flow
if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
# Direct token request without user interaction
try:
scope = credentials.get("scope")
tokens = client_credentials_flow(
server_url,
server_metadata,
client_information,
scope,
)
# Return action to save tokens and grant type
token_data = tokens.model_dump()
token_data["grant_type"] = MCPSupportGrantType.CLIENT_CREDENTIALS.value
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_TOKENS,
data=token_data,
provider_id=provider_id,
tenant_id=tenant_id,
)
)
return AuthResult(actions=actions, response={"result": "success"})
except (RequestError, ValueError, KeyError) as e:
# RequestError: HTTP request failed
# ValueError: Invalid response data
# KeyError: Missing required fields in response
raise ValueError(f"Client credentials flow failed: {e}")
# Exchange authorization code for tokens (Authorization Code flow)
if authorization_code is not None:
if not state_param:
raise ValueError("State parameter is required when exchanging authorization code")
@ -335,35 +473,69 @@ def auth(
tokens = exchange_authorization(
server_url,
metadata,
server_metadata,
client_information,
authorization_code,
code_verifier,
redirect_uri,
)
provider.save_tokens(tokens)
return {"result": "success"}
provider_tokens = provider.tokens()
# Return action to save tokens
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_TOKENS,
data=tokens.model_dump(),
provider_id=provider_id,
tenant_id=tenant_id,
)
)
return AuthResult(actions=actions, response={"result": "success"})
provider_tokens = provider.retrieve_tokens()
# Handle token refresh or new authorization
if provider_tokens and provider_tokens.refresh_token:
try:
new_tokens = refresh_authorization(server_url, metadata, client_information, provider_tokens.refresh_token)
provider.save_tokens(new_tokens)
return {"result": "success"}
except Exception as e:
new_tokens = refresh_authorization(
server_url, server_metadata, client_information, provider_tokens.refresh_token
)
# Return action to save new tokens
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_TOKENS,
data=new_tokens.model_dump(),
provider_id=provider_id,
tenant_id=tenant_id,
)
)
return AuthResult(actions=actions, response={"result": "success"})
except (RequestError, ValueError, KeyError) as e:
# RequestError: HTTP request failed
# ValueError: Invalid response data
# KeyError: Missing required fields in response
raise ValueError(f"Could not refresh OAuth tokens: {e}")
# Start new authorization flow
# Start new authorization flow (only for authorization code flow)
authorization_url, code_verifier = start_authorization(
server_url,
metadata,
server_metadata,
client_information,
provider.redirect_url,
provider.mcp_provider.id,
provider.mcp_provider.tenant_id,
redirect_url,
provider_id,
tenant_id,
)
provider.save_code_verifier(code_verifier)
return {"authorization_url": authorization_url}
# Return action to save code verifier
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_CODE_VERIFIER,
data={"code_verifier": code_verifier},
provider_id=provider_id,
tenant_id=tenant_id,
)
)
return AuthResult(actions=actions, response={"authorization_url": authorization_url})

View File

@ -1,77 +0,0 @@
from configs import dify_config
from core.mcp.types import (
OAuthClientInformation,
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthTokens,
)
from models.tools import MCPToolProvider
from services.tools.mcp_tools_manage_service import MCPToolManageService
class OAuthClientProvider:
mcp_provider: MCPToolProvider
def __init__(self, provider_id: str, tenant_id: str, for_list: bool = False):
if for_list:
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
else:
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_server_identifier(provider_id, tenant_id)
@property
def redirect_url(self) -> str:
"""The URL to redirect the user agent to after authorization."""
return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
@property
def client_metadata(self) -> OAuthClientMetadata:
"""Metadata about this OAuth client."""
return OAuthClientMetadata(
redirect_uris=[self.redirect_url],
token_endpoint_auth_method="none",
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
client_name="Dify",
client_uri="https://github.com/langgenius/dify",
)
def client_information(self) -> OAuthClientInformation | None:
"""Loads information about this OAuth client."""
client_information = self.mcp_provider.decrypted_credentials.get("client_information", {})
if not client_information:
return None
return OAuthClientInformation.model_validate(client_information)
def save_client_information(self, client_information: OAuthClientInformationFull):
"""Saves client information after dynamic registration."""
MCPToolManageService.update_mcp_provider_credentials(
self.mcp_provider,
{"client_information": client_information.model_dump()},
)
def tokens(self) -> OAuthTokens | None:
"""Loads any existing OAuth tokens for the current session."""
credentials = self.mcp_provider.decrypted_credentials
if not credentials:
return None
return OAuthTokens(
access_token=credentials.get("access_token", ""),
token_type=credentials.get("token_type", "Bearer"),
expires_in=int(credentials.get("expires_in", "3600") or 3600),
refresh_token=credentials.get("refresh_token", ""),
)
def save_tokens(self, tokens: OAuthTokens):
"""Stores new OAuth tokens for the current session."""
# update mcp provider credentials
token_dict = tokens.model_dump()
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True)
def save_code_verifier(self, code_verifier: str):
"""Saves a PKCE code verifier for the current session."""
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier})
def code_verifier(self) -> str:
"""Loads the PKCE code verifier for the current session."""
# get code verifier from mcp provider credentials
return str(self.mcp_provider.decrypted_credentials.get("code_verifier", ""))

api/core/mcp/auth_client.py (Normal file, 191 lines)
View File

@ -0,0 +1,191 @@
"""
MCP Client with Authentication Retry Support
This module provides an enhanced MCPClient that automatically handles
authentication failures and retries operations after refreshing tokens.
"""
import logging
from collections.abc import Callable
from typing import Any
from sqlalchemy.orm import Session
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.error import MCPAuthError
from core.mcp.mcp_client import MCPClient
from core.mcp.types import CallToolResult, Tool
from extensions.ext_database import db
logger = logging.getLogger(__name__)
class MCPClientWithAuthRetry(MCPClient):
"""
An enhanced MCPClient that provides automatic authentication retry.
This class extends MCPClient and intercepts MCPAuthError exceptions
to refresh authentication before retrying failed operations.
Note: This class uses lazy session creation - database sessions are only
created when authentication retry is actually needed, not on every request.
"""
def __init__(
self,
server_url: str,
headers: dict[str, str] | None = None,
timeout: float | None = None,
sse_read_timeout: float | None = None,
provider_entity: MCPProviderEntity | None = None,
authorization_code: str | None = None,
by_server_id: bool = False,
):
"""
Initialize the MCP client with auth retry capability.
Args:
server_url: The MCP server URL
headers: Optional headers for requests
timeout: Request timeout
sse_read_timeout: SSE read timeout
provider_entity: Provider entity for authentication
authorization_code: Optional authorization code for initial auth
by_server_id: Whether to look up provider by server ID
"""
super().__init__(server_url, headers, timeout, sse_read_timeout)
self.provider_entity = provider_entity
self.authorization_code = authorization_code
self.by_server_id = by_server_id
self._has_retried = False
def _handle_auth_error(self, error: MCPAuthError) -> None:
"""
Handle authentication error by refreshing tokens.
This method creates a short-lived database session only when authentication
retry is needed, minimizing database connection hold time.
Args:
error: The authentication error
Raises:
MCPAuthError: If authentication fails or max retries reached
"""
if not self.provider_entity:
raise error
if self._has_retried:
raise error
self._has_retried = True
try:
# Create a temporary session only for auth retry
# This session is short-lived and only exists during the auth operation
from services.tools.mcp_tools_manage_service import MCPToolManageService
with Session(db.engine) as session, session.begin():
mcp_service = MCPToolManageService(session=session)
# Perform authentication using the service's auth method
mcp_service.auth_with_actions(self.provider_entity, self.authorization_code)
# Retrieve new tokens
self.provider_entity = mcp_service.get_provider_entity(
self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id
)
# Session is closed here, before we update headers
token = self.provider_entity.retrieve_tokens()
if not token:
raise MCPAuthError("Authentication failed - no token received")
# Update headers with new token
self.headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
# Clear authorization code after first use
self.authorization_code = None
except MCPAuthError:
# Re-raise MCPAuthError as is
raise
except Exception as e:
# Catch all exceptions during auth retry
logger.exception("Authentication retry failed")
raise MCPAuthError(f"Authentication retry failed: {e}") from e
def _execute_with_retry(self, func: Callable[..., Any], *args, **kwargs) -> Any:
"""
Execute a function with authentication retry logic.
Args:
func: The function to execute
*args: Positional arguments for the function
**kwargs: Keyword arguments for the function
Returns:
The result of the function call
Raises:
MCPAuthError: If authentication fails after retries
Any other exceptions from the function
"""
try:
return func(*args, **kwargs)
except MCPAuthError as e:
self._handle_auth_error(e)
# Re-initialize the connection with new headers
if self._initialized:
# Clean up existing connection
self._exit_stack.close()
self._session = None
self._initialized = False
# Re-initialize with new headers
self._initialize()
self._initialized = True
return func(*args, **kwargs)
finally:
# Reset retry flag after operation completes
self._has_retried = False
def __enter__(self):
"""Enter the context manager with retry support."""
def initialize_with_retry():
super(MCPClientWithAuthRetry, self).__enter__()
return self
return self._execute_with_retry(initialize_with_retry)
def list_tools(self) -> list[Tool]:
"""
List available tools from the MCP server with auth retry.
Returns:
List of available tools
Raises:
MCPAuthError: If authentication fails after retries
"""
return self._execute_with_retry(super().list_tools)
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
"""
Invoke a tool on the MCP server with auth retry.
Args:
tool_name: Name of the tool to invoke
tool_args: Arguments for the tool
Returns:
Result of the tool invocation
Raises:
MCPAuthError: If authentication fails after retries
"""
return self._execute_with_retry(super().invoke_tool, tool_name, tool_args)
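A usage sketch: because the retry logic lives inside the client, callers never handle MCPAuthError themselves. The server URL and the load_provider_entity helper below are placeholders, not code from this changeset:

provider = load_provider_entity()  # hypothetical: yields an MCPProviderEntity
with MCPClientWithAuthRetry(
    server_url="https://mcp.example.com/mcp",  # placeholder URL
    provider_entity=provider,
) as client:
    tools = client.list_tools()                        # retried once on MCPAuthError
    result = client.invoke_tool("echo", {"text": "hi"})  # likewise retried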

View File

@ -46,7 +46,7 @@ class SSETransport:
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5.0,
sse_read_timeout: float = 5 * 60,
sse_read_timeout: float = 1 * 60,
):
"""Initialize the SSE transport.
@ -255,7 +255,7 @@ def sse_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5.0,
sse_read_timeout: float = 5 * 60,
sse_read_timeout: float = 1 * 60,
) -> Generator[tuple[ReadQueue, WriteQueue], None, None]:
"""
Client transport for SSE.
@ -276,31 +276,34 @@ def sse_client(
read_queue: ReadQueue | None = None
write_queue: WriteQueue | None = None
with ThreadPoolExecutor() as executor:
try:
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
with ssrf_proxy_sse_connect(
url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
) as event_source:
event_source.response.raise_for_status()
executor = ThreadPoolExecutor()
try:
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
with ssrf_proxy_sse_connect(
url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
) as event_source:
event_source.response.raise_for_status()
read_queue, write_queue = transport.connect(executor, client, event_source)
read_queue, write_queue = transport.connect(executor, client, event_source)
yield read_queue, write_queue
yield read_queue, write_queue
except httpx.HTTPStatusError as exc:
if exc.response.status_code == 401:
raise MCPAuthError()
raise MCPConnectionError()
except Exception:
logger.exception("Error connecting to SSE endpoint")
raise
finally:
# Clean up queues
if read_queue:
read_queue.put(None)
if write_queue:
write_queue.put(None)
except httpx.HTTPStatusError as exc:
if exc.response.status_code == 401:
raise MCPAuthError()
raise MCPConnectionError()
except Exception:
logger.exception("Error connecting to SSE endpoint")
raise
finally:
# Clean up queues
if read_queue:
read_queue.put(None)
if write_queue:
write_queue.put(None)
# Shutdown executor without waiting to prevent hanging
executor.shutdown(wait=False)
def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage):

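The None sentinel plus executor.shutdown(wait=False) is the load-bearing change here, and the same pattern reappears in the streamable transport below: the sentinel unblocks any thread parked on queue.get(), and the non-waiting shutdown lets the context manager exit even if a worker is still tearing down. A distilled, standalone sketch using only the standard library:

import queue
from concurrent.futures import ThreadPoolExecutor


def consumer(q: queue.Queue) -> None:
    while True:
        item = q.get()   # blocks until an item (or the sentinel) arrives
        if item is None:
            return       # sentinel received: exit cleanly
        print("got", item)


q: queue.Queue = queue.Queue()
executor = ThreadPoolExecutor(max_workers=1)
try:
    executor.submit(consumer, q)
    q.put("message")
finally:
    q.put(None)                    # unblock the consumer thread
    executor.shutdown(wait=False)  # do not hang waiting for teardown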
View File

@ -434,45 +434,48 @@ def streamablehttp_client(
server_to_client_queue: ServerToClientQueue = queue.Queue() # For messages FROM server TO client
client_to_server_queue: ClientToServerQueue = queue.Queue() # For messages FROM client TO server
with ThreadPoolExecutor(max_workers=2) as executor:
try:
with create_ssrf_proxy_mcp_http_client(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
) as client:
# Define callbacks that need access to thread pool
def start_get_stream():
"""Start a worker thread to handle server-initiated messages."""
executor.submit(transport.handle_get_stream, client, server_to_client_queue)
executor = ThreadPoolExecutor(max_workers=2)
try:
with create_ssrf_proxy_mcp_http_client(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
) as client:
# Define callbacks that need access to thread pool
def start_get_stream():
"""Start a worker thread to handle server-initiated messages."""
executor.submit(transport.handle_get_stream, client, server_to_client_queue)
# Start the post_writer worker thread
executor.submit(
transport.post_writer,
client,
client_to_server_queue, # Queue for messages FROM client TO server
server_to_client_queue, # Queue for messages FROM server TO client
start_get_stream,
)
# Start the post_writer worker thread
executor.submit(
transport.post_writer,
client,
client_to_server_queue, # Queue for messages FROM client TO server
server_to_client_queue, # Queue for messages FROM server TO client
start_get_stream,
)
try:
yield (
server_to_client_queue, # Queue for receiving messages FROM server
client_to_server_queue, # Queue for sending messages TO server
transport.get_session_id,
)
finally:
if transport.session_id and terminate_on_close:
transport.terminate_session(client)
# Signal threads to stop
client_to_server_queue.put(None)
finally:
# Clear any remaining items and add None sentinel to unblock any waiting threads
try:
while not client_to_server_queue.empty():
client_to_server_queue.get_nowait()
except queue.Empty:
pass
yield (
server_to_client_queue, # Queue for receiving messages FROM server
client_to_server_queue, # Queue for sending messages TO server
transport.get_session_id,
)
finally:
if transport.session_id and terminate_on_close:
transport.terminate_session(client)
client_to_server_queue.put(None)
server_to_client_queue.put(None)
# Signal threads to stop
client_to_server_queue.put(None)
finally:
# Clear any remaining items and add None sentinel to unblock any waiting threads
try:
while not client_to_server_queue.empty():
client_to_server_queue.get_nowait()
except queue.Empty:
pass
client_to_server_queue.put(None)
server_to_client_queue.put(None)
# Shutdown executor without waiting to prevent hanging
executor.shutdown(wait=False)


View File

@ -1,10 +1,13 @@
from dataclasses import dataclass
from enum import StrEnum
from typing import Any, Generic, TypeVar
from core.mcp.session.base_session import BaseSession
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestId, RequestParams
from pydantic import BaseModel
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", LATEST_PROTOCOL_VERSION]
from core.mcp.session.base_session import BaseSession
from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAuthMetadata, RequestId, RequestParams
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
@ -17,3 +20,41 @@ class RequestContext(Generic[SessionT, LifespanContextT]):
meta: RequestParams.Meta | None
session: SessionT
lifespan_context: LifespanContextT
class AuthActionType(StrEnum):
"""Types of actions that can be performed during auth flow."""
SAVE_CLIENT_INFO = "save_client_info"
SAVE_TOKENS = "save_tokens"
SAVE_CODE_VERIFIER = "save_code_verifier"
START_AUTHORIZATION = "start_authorization"
SUCCESS = "success"
class AuthAction(BaseModel):
"""Represents an action that needs to be performed as a result of auth flow."""
action_type: AuthActionType
data: dict[str, Any]
provider_id: str | None = None
tenant_id: str | None = None
class AuthResult(BaseModel):
"""Result of auth function containing actions to be performed and response data."""
actions: list[AuthAction]
response: dict[str, str]
class OAuthCallbackState(BaseModel):
"""State data stored in Redis during OAuth callback flow."""
provider_id: str
tenant_id: str
server_url: str
metadata: OAuthMetadata | None = None
client_information: OAuthClientInformation
code_verifier: str
redirect_uri: str
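OAuthCallbackState is what the flow serializes into Redis, keyed by the state parameter. A sketch of the round trip; the redis client, key prefix, and five-minute TTL are assumptions for illustration, and client_information/code_verifier are assumed to be in scope:

import secrets

state_key = secrets.token_urlsafe(32)
callback_state = OAuthCallbackState(
    provider_id="provider-123",   # placeholder
    tenant_id="tenant-456",       # placeholder
    server_url="https://mcp.example.com",
    client_information=client_information,
    code_verifier=code_verifier,
    redirect_uri="https://console.example.com/mcp/oauth/callback",
)
# Store under the state key with a short TTL (assumed conventions).
redis_client.setex(f"oauth_state:{state_key}", 300, callback_state.model_dump_json())

# On callback, recover and validate the stored state:
raw = redis_client.get(f"oauth_state:{state_key}")
restored = OAuthCallbackState.model_validate_json(raw)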

View File

@ -8,3 +8,7 @@ class MCPConnectionError(MCPError):
class MCPAuthError(MCPConnectionError):
pass
class MCPRefreshTokenError(MCPError):
pass
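Because MCPAuthError subclasses MCPConnectionError while MCPRefreshTokenError hangs directly off MCPError, handlers must order their except clauses from most to least specific. A small sketch; restart_oauth_flow and reauthenticate are hypothetical helpers:

try:
    tokens = refresh_authorization(server_url, metadata, client_info, refresh_token)
except MCPRefreshTokenError:
    # Refresh token rejected: send the user back through full authorization.
    restart_oauth_flow()   # hypothetical helper
except MCPAuthError:
    # 401-style failure (subclass of MCPConnectionError): re-authenticate.
    reauthenticate()       # hypothetical helper
except MCPConnectionError:
    # Transport-level failure: surface as a connectivity problem.
    raise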

View File

@ -7,9 +7,9 @@ from urllib.parse import urlparse
from core.mcp.client.sse_client import sse_client
from core.mcp.client.streamable_client import streamablehttp_client
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.error import MCPConnectionError
from core.mcp.session.client_session import ClientSession
from core.mcp.types import Tool
from core.mcp.types import CallToolResult, Tool
logger = logging.getLogger(__name__)
@ -18,40 +18,18 @@ class MCPClient:
def __init__(
self,
server_url: str,
provider_id: str,
tenant_id: str,
authed: bool = True,
authorization_code: str | None = None,
for_list: bool = False,
headers: dict[str, str] | None = None,
timeout: float | None = None,
sse_read_timeout: float | None = None,
):
# Initialize info
self.provider_id = provider_id
self.tenant_id = tenant_id
self.client_type = "streamable"
self.server_url = server_url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
# Authentication info
self.authed = authed
self.authorization_code = authorization_code
if authed:
from core.mcp.auth.auth_provider import OAuthClientProvider
self.provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=for_list)
self.token = self.provider.tokens()
# Initialize session and client objects
self._session: ClientSession | None = None
self._streams_context: AbstractContextManager[Any] | None = None
self._session_context: ClientSession | None = None
self._exit_stack = ExitStack()
# Whether the client has been initialized
self._initialized = False
def __enter__(self):
@ -85,61 +63,42 @@ class MCPClient:
logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
self.connect_server(streamablehttp_client, "mcp")
def connect_server(
self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str, first_try: bool = True
):
from core.mcp.auth.auth_flow import auth
def connect_server(self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str) -> None:
"""
Connect to the MCP server using streamable HTTP or SSE.
Defaults to streamable HTTP.
Args:
client_factory: The client factory to use (streamablehttp_client or sse_client).
method_name: The method name to use (mcp or sse).
"""
streams_context = client_factory(
url=self.server_url,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
try:
headers = (
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
if self.authed and self.token
else self.headers
)
self._streams_context = client_factory(
url=self.server_url,
headers=headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
if not self._streams_context:
raise MCPConnectionError("Failed to create connection context")
# Use exit_stack to manage context managers properly
if method_name == "mcp":
read_stream, write_stream, _ = self._exit_stack.enter_context(streams_context)
streams = (read_stream, write_stream)
else: # sse_client
streams = self._exit_stack.enter_context(streams_context)
# Use exit_stack to manage context managers properly
if method_name == "mcp":
read_stream, write_stream, _ = self._exit_stack.enter_context(self._streams_context)
streams = (read_stream, write_stream)
else: # sse_client
streams = self._exit_stack.enter_context(self._streams_context)
self._session_context = ClientSession(*streams)
self._session = self._exit_stack.enter_context(self._session_context)
self._session.initialize()
return
except MCPAuthError:
if not self.authed:
raise
try:
auth(self.provider, self.server_url, self.authorization_code)
except Exception as e:
raise ValueError(f"Failed to authenticate: {e}")
self.token = self.provider.tokens()
if first_try:
return self.connect_server(client_factory, method_name, first_try=False)
session_context = ClientSession(*streams)
self._session = self._exit_stack.enter_context(session_context)
self._session.initialize()
def list_tools(self) -> list[Tool]:
"""Connect to an MCP server running with SSE transport"""
# List available tools to verify connection
if not self._initialized or not self._session:
"""List available tools from the MCP server"""
if not self._session:
raise ValueError("Session not initialized.")
response = self._session.list_tools()
tools = response.tools
return tools
return response.tools
def invoke_tool(self, tool_name: str, tool_args: dict):
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
"""Call a tool"""
if not self._initialized or not self._session:
if not self._session:
raise ValueError("Session not initialized.")
return self._session.call_tool(tool_name, tool_args)
@ -153,6 +112,4 @@ class MCPClient:
raise ValueError(f"Error during cleanup: {e}")
finally:
self._session = None
self._session_context = None
self._streams_context = None
self._initialized = False

View File

@ -201,11 +201,14 @@ class BaseSession(
self._receiver_future.result(timeout=5.0) # Wait up to 5 seconds
except TimeoutError:
# If the receiver loop is still running after timeout, we'll force shutdown
pass
# Cancel the future to interrupt the receiver loop
self._receiver_future.cancel()
# Shutdown the executor
if self._executor:
self._executor.shutdown(wait=True)
# Use non-blocking shutdown to prevent hanging
# The receiver thread should have already exited due to the None message in the queue
self._executor.shutdown(wait=False)
def send_request(
self,

View File

@ -109,12 +109,16 @@ class ClientSession(
self._message_handler = message_handler or _default_message_handler
def initialize(self) -> types.InitializeResult:
sampling = types.SamplingCapability()
roots = types.RootsCapability(
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
# they're supported?
listChanged=True,
# Only set capabilities if non-default callbacks are provided
# This prevents servers from attempting callbacks when we don't actually support them
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
roots = (
types.RootsCapability(
# Only enable listChanged if we have a custom callback
listChanged=True,
)
if self._list_roots_callback is not _default_list_roots_callback
else None
)
result = self.send_request(
@ -284,7 +288,7 @@ class ClientSession(
def complete(
self,
ref: types.ResourceReference | types.PromptReference,
ref: types.ResourceTemplateReference | types.PromptReference,
argument: dict[str, str],
) -> types.CompleteResult:
"""Send a completion/complete request."""

View File

@ -1,13 +1,6 @@
from collections.abc import Callable
from dataclasses import dataclass
from typing import (
Annotated,
Any,
Generic,
Literal,
TypeAlias,
TypeVar,
)
from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar
from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel
from pydantic.networks import AnyUrl, UrlConstraints
@ -33,6 +26,7 @@ for reference.
LATEST_PROTOCOL_VERSION = "2025-03-26"
# The server supports 2024-11-05 to allow Claude to connect.
SERVER_LATEST_PROTOCOL_VERSION = "2024-11-05"
DEFAULT_NEGOTIATED_VERSION = "2025-03-26"
ProgressToken = str | int
Cursor = str
Role = Literal["user", "assistant"]
@ -55,14 +49,22 @@ class RequestParams(BaseModel):
meta: Meta | None = Field(alias="_meta", default=None)
class PaginatedRequestParams(RequestParams):
cursor: Cursor | None = None
"""
An opaque token representing the current pagination position.
If provided, the server should return results starting after this cursor.
"""
class NotificationParams(BaseModel):
class Meta(BaseModel):
model_config = ConfigDict(extra="allow")
meta: Meta | None = Field(alias="_meta", default=None)
"""
This parameter name is reserved by MCP to allow clients and servers to attach
additional metadata to their notifications.
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
@ -79,12 +81,11 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]):
model_config = ConfigDict(extra="allow")
class PaginatedRequest(Request[RequestParamsT, MethodT]):
cursor: Cursor | None = None
"""
An opaque token representing the current pagination position.
If provided, the server should return results starting after this cursor.
"""
class PaginatedRequest(Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]):
"""Base class for paginated requests,
matching the schema's PaginatedRequest interface."""
params: PaginatedRequestParams | None = None
class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
@ -98,13 +99,12 @@ class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
class Result(BaseModel):
"""Base class for JSON-RPC results."""
model_config = ConfigDict(extra="allow")
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
This result property is reserved by the protocol to allow clients and servers to
attach additional metadata to their responses.
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
class PaginatedResult(Result):
@ -186,10 +186,26 @@ class EmptyResult(Result):
"""A response that indicates success but carries no data."""
class Implementation(BaseModel):
"""Describes the name and version of an MCP implementation."""
class BaseMetadata(BaseModel):
"""Base class for entities with name and optional title fields."""
name: str
"""The programmatic name of the entity."""
title: str | None = None
"""
Intended for UI and end-user contexts optimized to be human-readable and easily understood,
even by those unfamiliar with domain-specific terminology.
If not provided, the name should be used for display (except for Tool,
where `annotations.title` should be given precedence over using `name`,
if present).
"""
class Implementation(BaseMetadata):
"""Describes the name and version of an MCP implementation."""
version: str
model_config = ConfigDict(extra="allow")
@ -203,7 +219,7 @@ class RootsCapability(BaseModel):
class SamplingCapability(BaseModel):
"""Capability for logging operations."""
"""Capability for sampling operations."""
model_config = ConfigDict(extra="allow")
@ -252,6 +268,12 @@ class LoggingCapability(BaseModel):
model_config = ConfigDict(extra="allow")
class CompletionsCapability(BaseModel):
"""Capability for completions operations."""
model_config = ConfigDict(extra="allow")
class ServerCapabilities(BaseModel):
"""Capabilities that a server may support."""
@ -265,6 +287,8 @@ class ServerCapabilities(BaseModel):
"""Present if the server offers any resources to read."""
tools: ToolsCapability | None = None
"""Present if the server offers any tools to call."""
completions: CompletionsCapability | None = None
"""Present if the server offers autocompletion suggestions for prompts and resources."""
model_config = ConfigDict(extra="allow")
@ -284,7 +308,7 @@ class InitializeRequest(Request[InitializeRequestParams, Literal["initialize"]])
to begin initialization.
"""
method: Literal["initialize"]
method: Literal["initialize"] = "initialize"
params: InitializeRequestParams
@ -305,7 +329,7 @@ class InitializedNotification(Notification[NotificationParams | None, Literal["n
finished.
"""
method: Literal["notifications/initialized"]
method: Literal["notifications/initialized"] = "notifications/initialized"
params: NotificationParams | None = None
@ -315,7 +339,7 @@ class PingRequest(Request[RequestParams | None, Literal["ping"]]):
still alive.
"""
method: Literal["ping"]
method: Literal["ping"] = "ping"
params: RequestParams | None = None
@ -334,6 +358,11 @@ class ProgressNotificationParams(NotificationParams):
"""
total: float | None = None
"""Total number of items to process (or total progress required), if known."""
message: str | None = None
"""
Message related to progress. This should provide relevant human readable
progress information.
"""
model_config = ConfigDict(extra="allow")
@ -343,15 +372,14 @@ class ProgressNotification(Notification[ProgressNotificationParams, Literal["not
long-running request.
"""
method: Literal["notifications/progress"]
method: Literal["notifications/progress"] = "notifications/progress"
params: ProgressNotificationParams
class ListResourcesRequest(PaginatedRequest[RequestParams | None, Literal["resources/list"]]):
class ListResourcesRequest(PaginatedRequest[Literal["resources/list"]]):
"""Sent from the client to request a list of resources the server has."""
method: Literal["resources/list"]
params: RequestParams | None = None
method: Literal["resources/list"] = "resources/list"
class Annotations(BaseModel):
@ -360,13 +388,11 @@ class Annotations(BaseModel):
model_config = ConfigDict(extra="allow")
class Resource(BaseModel):
class Resource(BaseMetadata):
"""A known resource that the server is capable of reading."""
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
"""The URI of this resource."""
name: str
"""A human-readable name for this resource."""
description: str | None = None
"""A description of what this resource represents."""
mimeType: str | None = None
@ -379,10 +405,15 @@ class Resource(BaseModel):
This can be used by Hosts to display file sizes and estimate context window usage.
"""
annotations: Annotations | None = None
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
class ResourceTemplate(BaseModel):
class ResourceTemplate(BaseMetadata):
"""A template description for resources available on the server."""
uriTemplate: str
@ -390,8 +421,6 @@ class ResourceTemplate(BaseModel):
A URI template (according to RFC 6570) that can be used to construct resource
URIs.
"""
name: str
"""A human-readable name for the type of resource this template refers to."""
description: str | None = None
"""A human-readable description of what this template is for."""
mimeType: str | None = None
@ -400,6 +429,11 @@ class ResourceTemplate(BaseModel):
included if all resources matching this template have the same type.
"""
annotations: Annotations | None = None
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
@ -409,11 +443,10 @@ class ListResourcesResult(PaginatedResult):
resources: list[Resource]
class ListResourceTemplatesRequest(PaginatedRequest[RequestParams | None, Literal["resources/templates/list"]]):
class ListResourceTemplatesRequest(PaginatedRequest[Literal["resources/templates/list"]]):
"""Sent from the client to request a list of resource templates the server has."""
method: Literal["resources/templates/list"]
params: RequestParams | None = None
method: Literal["resources/templates/list"] = "resources/templates/list"
class ListResourceTemplatesResult(PaginatedResult):
@ -436,7 +469,7 @@ class ReadResourceRequestParams(RequestParams):
class ReadResourceRequest(Request[ReadResourceRequestParams, Literal["resources/read"]]):
"""Sent from the client to the server, to read a specific resource URI."""
method: Literal["resources/read"]
method: Literal["resources/read"] = "resources/read"
params: ReadResourceRequestParams
@ -447,6 +480,11 @@ class ResourceContents(BaseModel):
"""The URI of this resource."""
mimeType: str | None = None
"""The MIME type of this resource, if known."""
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
@ -481,7 +519,7 @@ class ResourceListChangedNotification(
of resources it can read from has changed.
"""
method: Literal["notifications/resources/list_changed"]
method: Literal["notifications/resources/list_changed"] = "notifications/resources/list_changed"
params: NotificationParams | None = None
@ -502,7 +540,7 @@ class SubscribeRequest(Request[SubscribeRequestParams, Literal["resources/subscr
whenever a particular resource changes.
"""
method: Literal["resources/subscribe"]
method: Literal["resources/subscribe"] = "resources/subscribe"
params: SubscribeRequestParams
@ -520,7 +558,7 @@ class UnsubscribeRequest(Request[UnsubscribeRequestParams, Literal["resources/un
the server.
"""
method: Literal["resources/unsubscribe"]
method: Literal["resources/unsubscribe"] = "resources/unsubscribe"
params: UnsubscribeRequestParams
@ -543,15 +581,14 @@ class ResourceUpdatedNotification(
changed and may need to be read again.
"""
method: Literal["notifications/resources/updated"]
method: Literal["notifications/resources/updated"] = "notifications/resources/updated"
params: ResourceUpdatedNotificationParams
class ListPromptsRequest(PaginatedRequest[RequestParams | None, Literal["prompts/list"]]):
class ListPromptsRequest(PaginatedRequest[Literal["prompts/list"]]):
"""Sent from the client to request a list of prompts and prompt templates."""
method: Literal["prompts/list"]
params: RequestParams | None = None
method: Literal["prompts/list"] = "prompts/list"
class PromptArgument(BaseModel):
@ -566,15 +603,18 @@ class PromptArgument(BaseModel):
model_config = ConfigDict(extra="allow")
class Prompt(BaseModel):
class Prompt(BaseMetadata):
"""A prompt or prompt template that the server offers."""
name: str
"""The name of the prompt or prompt template."""
description: str | None = None
"""An optional description of what this prompt provides."""
arguments: list[PromptArgument] | None = None
"""A list of arguments to use for templating the prompt."""
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
@ -597,7 +637,7 @@ class GetPromptRequestParams(RequestParams):
class GetPromptRequest(Request[GetPromptRequestParams, Literal["prompts/get"]]):
"""Used by the client to get a prompt provided by the server."""
method: Literal["prompts/get"]
method: Literal["prompts/get"] = "prompts/get"
params: GetPromptRequestParams
@ -608,6 +648,11 @@ class TextContent(BaseModel):
text: str
"""The text content of the message."""
annotations: Annotations | None = None
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
@ -623,6 +668,31 @@ class ImageContent(BaseModel):
image types.
"""
annotations: Annotations | None = None
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
class AudioContent(BaseModel):
"""Audio content for a message."""
type: Literal["audio"]
data: str
"""The base64-encoded audio data."""
mimeType: str
"""
The MIME type of the audio. Different providers may support different
audio types.
"""
annotations: Annotations | None = None
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
@ -630,7 +700,7 @@ class SamplingMessage(BaseModel):
"""Describes a message issued to or received from an LLM API."""
role: Role
content: TextContent | ImageContent
content: TextContent | ImageContent | AudioContent
model_config = ConfigDict(extra="allow")
@ -645,14 +715,36 @@ class EmbeddedResource(BaseModel):
type: Literal["resource"]
resource: TextResourceContents | BlobResourceContents
annotations: Annotations | None = None
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
class ResourceLink(Resource):
"""
A resource that the server is capable of reading, included in a prompt or tool call result.
Note: resource links returned by tools are not guaranteed to appear in the results of `resources/list` requests.
"""
type: Literal["resource_link"]
ContentBlock = TextContent | ImageContent | AudioContent | ResourceLink | EmbeddedResource
"""A content block that can be used in prompts and tool results."""
Content: TypeAlias = ContentBlock
# """DEPRECATED: Content is deprecated, you should use ContentBlock directly."""
class PromptMessage(BaseModel):
"""Describes a message returned as part of a prompt."""
role: Role
content: TextContent | ImageContent | EmbeddedResource
content: ContentBlock
model_config = ConfigDict(extra="allow")
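With ContentBlock now a five-way union, a PromptMessage can carry audio or a resource link as readily as text. A construction sketch using the types above; the base64 payload and file URI are dummies:

text = TextContent(type="text", text="Transcript attached below.")
audio = AudioContent(type="audio", data="UklGRg==", mimeType="audio/wav")  # dummy base64
messages = [
    PromptMessage(role="user", content=text),
    PromptMessage(role="user", content=audio),
]
# A resource link also validates, since ResourceLink is part of ContentBlock:
link = ResourceLink(
    type="resource_link",
    uri="file:///tmp/meeting.wav",  # placeholder URI
    name="meeting-recording",
)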
@ -672,15 +764,14 @@ class PromptListChangedNotification(
of prompts it offers has changed.
"""
method: Literal["notifications/prompts/list_changed"]
method: Literal["notifications/prompts/list_changed"] = "notifications/prompts/list_changed"
params: NotificationParams | None = None
class ListToolsRequest(PaginatedRequest[RequestParams | None, Literal["tools/list"]]):
class ListToolsRequest(PaginatedRequest[Literal["tools/list"]]):
"""Sent from the client to request a list of tools the server has."""
method: Literal["tools/list"]
params: RequestParams | None = None
method: Literal["tools/list"] = "tools/list"
class ToolAnnotations(BaseModel):
@ -731,17 +822,25 @@ class ToolAnnotations(BaseModel):
model_config = ConfigDict(extra="allow")
class Tool(BaseModel):
class Tool(BaseMetadata):
"""Definition for a tool the client can call."""
name: str
"""The name of the tool."""
description: str | None = None
"""A human-readable description of the tool."""
inputSchema: dict[str, Any]
"""A JSON Schema object defining the expected parameters for the tool."""
outputSchema: dict[str, Any] | None = None
"""
An optional JSON Schema object defining the structure of the tool's output
returned in the structuredContent field of a CallToolResult.
"""
annotations: ToolAnnotations | None = None
"""Optional additional tool information."""
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
@ -762,14 +861,16 @@ class CallToolRequestParams(RequestParams):
class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]):
"""Used by the client to invoke a tool provided by the server."""
method: Literal["tools/call"]
method: Literal["tools/call"] = "tools/call"
params: CallToolRequestParams
class CallToolResult(Result):
"""The server's response to a tool call."""
content: list[TextContent | ImageContent | EmbeddedResource]
content: list[ContentBlock]
structuredContent: dict[str, Any] | None = None
"""An optional JSON object that represents the structured result of the tool call."""
isError: bool = False
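structuredContent pairs with the new Tool.outputSchema: a tool can return machine-readable output alongside the human-readable content blocks. A sketch of a result that populates both:

result = CallToolResult(
    content=[TextContent(type="text", text="2 matches found")],
    structuredContent={"matches": 2, "query": "dify"},  # shape governed by outputSchema, if declared
    isError=False,
)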
@ -779,7 +880,7 @@ class ToolListChangedNotification(Notification[NotificationParams | None, Litera
of tools it offers has changed.
"""
method: Literal["notifications/tools/list_changed"]
method: Literal["notifications/tools/list_changed"] = "notifications/tools/list_changed"
params: NotificationParams | None = None
@ -797,7 +898,7 @@ class SetLevelRequestParams(RequestParams):
class SetLevelRequest(Request[SetLevelRequestParams, Literal["logging/setLevel"]]):
"""A request from the client to the server, to enable or adjust logging."""
method: Literal["logging/setLevel"]
method: Literal["logging/setLevel"] = "logging/setLevel"
params: SetLevelRequestParams
@ -808,7 +909,7 @@ class LoggingMessageNotificationParams(NotificationParams):
"""The severity of this log message."""
logger: str | None = None
"""An optional name of the logger issuing this message."""
data: Any = None
data: Any
"""
The data to be logged, such as a string message or an object. Any JSON serializable
type is allowed here.
@ -819,7 +920,7 @@ class LoggingMessageNotificationParams(NotificationParams):
class LoggingMessageNotification(Notification[LoggingMessageNotificationParams, Literal["notifications/message"]]):
"""Notification of a log message passed from server to client."""
method: Literal["notifications/message"]
method: Literal["notifications/message"] = "notifications/message"
params: LoggingMessageNotificationParams
@ -914,7 +1015,7 @@ class CreateMessageRequestParams(RequestParams):
class CreateMessageRequest(Request[CreateMessageRequestParams, Literal["sampling/createMessage"]]):
"""A request from the server to sample an LLM via the client."""
method: Literal["sampling/createMessage"]
method: Literal["sampling/createMessage"] = "sampling/createMessage"
params: CreateMessageRequestParams
@ -925,14 +1026,14 @@ class CreateMessageResult(Result):
"""The client's response to a sampling/create_message request from the server."""
role: Role
content: TextContent | ImageContent
content: TextContent | ImageContent | AudioContent
model: str
"""The name of the model that generated the message."""
stopReason: StopReason | None = None
"""The reason why sampling stopped, if known."""
class ResourceReference(BaseModel):
class ResourceTemplateReference(BaseModel):
"""A reference to a resource or resource template definition."""
type: Literal["ref/resource"]
@ -960,18 +1061,28 @@ class CompletionArgument(BaseModel):
model_config = ConfigDict(extra="allow")
class CompletionContext(BaseModel):
"""Additional, optional context for completions."""
arguments: dict[str, str] | None = None
"""Previously-resolved variables in a URI template or prompt."""
model_config = ConfigDict(extra="allow")
class CompleteRequestParams(RequestParams):
"""Parameters for completion requests."""
ref: ResourceReference | PromptReference
ref: ResourceTemplateReference | PromptReference
argument: CompletionArgument
context: CompletionContext | None = None
"""Additional, optional context for completions"""
model_config = ConfigDict(extra="allow")
class CompleteRequest(Request[CompleteRequestParams, Literal["completion/complete"]]):
"""A request from the client to the server, to ask for completion options."""
method: Literal["completion/complete"]
method: Literal["completion/complete"] = "completion/complete"
params: CompleteRequestParams
@ -1010,7 +1121,7 @@ class ListRootsRequest(Request[RequestParams | None, Literal["roots/list"]]):
structure or access specific locations that the client has permission to read from.
"""
method: Literal["roots/list"]
method: Literal["roots/list"] = "roots/list"
params: RequestParams | None = None
@ -1029,6 +1140,11 @@ class Root(BaseModel):
identifier for the root, which may be useful for display purposes or for
referencing the root in other parts of the application.
"""
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
@ -1054,7 +1170,7 @@ class RootsListChangedNotification(
using the ListRootsRequest.
"""
method: Literal["notifications/roots/list_changed"]
method: Literal["notifications/roots/list_changed"] = "notifications/roots/list_changed"
params: NotificationParams | None = None
@ -1074,7 +1190,7 @@ class CancelledNotification(Notification[CancelledNotificationParams, Literal["n
previously-issued request.
"""
method: Literal["notifications/cancelled"]
method: Literal["notifications/cancelled"] = "notifications/cancelled"
params: CancelledNotificationParams
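One consequence of defaulting the method literals throughout this file is that request objects can be constructed without repeating the method string, while pydantic still rejects a mismatched value. A short sketch against the types above:

req = ListToolsRequest()        # method defaults to "tools/list", params to None
assert req.method == "tools/list"

ping = PingRequest()            # method defaults to "ping"

# Supplying a mismatched literal still fails validation:
# PingRequest(method="tools/list")  -> pydantic ValidationError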

View File

@ -1,6 +1,7 @@
from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file import file_manager
@ -18,7 +19,9 @@ from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from factories import file_factory
from models.model import AppMode, Conversation, Message, MessageFile
from models.workflow import Workflow, WorkflowRun
from models.workflow import Workflow
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.factory import DifyAPIRepositoryFactory
class TokenBufferMemory:
@ -29,6 +32,14 @@ class TokenBufferMemory:
):
self.conversation = conversation
self.model_instance = model_instance
self._workflow_run_repo: APIWorkflowRunRepository | None = None
@property
def workflow_run_repo(self) -> APIWorkflowRunRepository:
if self._workflow_run_repo is None:
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
return self._workflow_run_repo
def _build_prompt_message_with_files(
self,
@ -50,7 +61,16 @@ class TokenBufferMemory:
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow_run = db.session.scalar(select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id))
app = self.conversation.app
if not app:
raise ValueError("App not found for conversation")
if not message.workflow_run_id:
raise ValueError("Workflow run ID not found")
workflow_run = self.workflow_run_repo.get_workflow_run_by_id(
tenant_id=app.tenant_id, app_id=app.id, run_id=message.workflow_run_id
)
if not workflow_run:
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
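The lazy repository property keeps TokenBufferMemory cheap to construct: the sessionmaker and repository are built only on first access. Reduced to its skeleton, as an illustrative sketch rather than the exact class:

from sqlalchemy.orm import sessionmaker


class LazyRepoHolder:
    def __init__(self, engine):
        self._engine = engine
        self._repo = None  # not created until first access

    @property
    def repo(self):
        if self._repo is None:
            session_maker = sessionmaker(bind=self._engine, expire_on_commit=False)
            self._repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
        return self._repo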

View File

@ -38,6 +38,8 @@ class LLMUsageMetadata(TypedDict, total=False):
prompt_price: Union[float, str]
completion_price: Union[float, str]
latency: float
time_to_first_token: float
time_to_generate: float
class LLMUsage(ModelUsage):
@ -57,6 +59,8 @@ class LLMUsage(ModelUsage):
total_price: Decimal
currency: str
latency: float
time_to_first_token: float | None = None
time_to_generate: float | None = None
@classmethod
def empty_usage(cls):
@ -73,6 +77,8 @@ class LLMUsage(ModelUsage):
total_price=Decimal("0.0"),
currency="USD",
latency=0.0,
time_to_first_token=None,
time_to_generate=None,
)
@classmethod
@ -108,6 +114,8 @@ class LLMUsage(ModelUsage):
prompt_price=Decimal(str(metadata.get("prompt_price", 0))),
completion_price=Decimal(str(metadata.get("completion_price", 0))),
latency=metadata.get("latency", 0.0),
time_to_first_token=metadata.get("time_to_first_token"),
time_to_generate=metadata.get("time_to_generate"),
)
def plus(self, other: LLMUsage) -> LLMUsage:
@ -133,6 +141,8 @@ class LLMUsage(ModelUsage):
total_price=self.total_price + other.total_price,
currency=other.currency,
latency=self.latency + other.latency,
time_to_first_token=other.time_to_first_token,
time_to_generate=other.time_to_generate,
)
def __add__(self, other: LLMUsage) -> LLMUsage:

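The two new timing fields are typically measured around a streaming loop: time_to_first_token from request start to the first chunk, time_to_generate from the first chunk to completion. A sketch with a placeholder stream, not the project's actual invoke path:

import time


def consume_stream(stream):
    """Measure TTFT and generation time around any iterable of chunks."""
    start = time.perf_counter()
    first_token_at = None
    for _chunk in stream:
        if first_token_at is None:
            first_token_at = time.perf_counter()  # first token observed
    end = time.perf_counter()
    if first_token_at is None:
        return None, None  # the stream produced nothing
    return first_token_at - start, end - first_token_at


# Example: ttft, ttg = consume_stream(["he", "llo"])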
View File

@ -62,6 +62,9 @@ class MessageTraceInfo(BaseTraceInfo):
file_list: Union[str, dict[str, Any], list] | None = None
message_file_data: Any | None = None
conversation_mode: str
gen_ai_server_time_to_first_token: float | None = None
llm_streaming_time_to_generate: float | None = None
is_streaming_request: bool = False
class ModerationTraceInfo(BaseTraceInfo):

View File

@ -12,9 +12,9 @@ from uuid import UUID, uuid4
from cachetools import LRUCache
from flask import current_app
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
from core.ops.entities.config_entity import (
OPS_FILE_PATH,
TracingProviderEnum,
@ -34,7 +34,8 @@ from core.ops.utils import get_message_data
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
from models.workflow import WorkflowAppLog, WorkflowRun
from models.workflow import WorkflowAppLog
from repositories.factory import DifyAPIRepositoryFactory
from tasks.ops_trace_task import process_trace_tasks
if TYPE_CHECKING:
@ -140,6 +141,8 @@ provider_config_map = OpsTraceProviderConfigMap()
class OpsTraceManager:
ops_trace_instances_cache: LRUCache = LRUCache(maxsize=128)
decrypted_configs_cache: LRUCache = LRUCache(maxsize=128)
_decryption_cache_lock = threading.RLock()
@classmethod
def encrypt_tracing_config(
@ -160,7 +163,7 @@ class OpsTraceManager:
provider_config_map[tracing_provider]["other_keys"],
)
new_config = {}
new_config: dict[str, Any] = {}
# Encrypt necessary keys
for key in secret_keys:
if key in tracing_config:
@ -190,20 +193,41 @@ class OpsTraceManager:
:param tracing_config: tracing config
:return:
"""
config_class, secret_keys, other_keys = (
provider_config_map[tracing_provider]["config_class"],
provider_config_map[tracing_provider]["secret_keys"],
provider_config_map[tracing_provider]["other_keys"],
config_json = json.dumps(tracing_config, sort_keys=True)
decrypted_config_key = (
tenant_id,
tracing_provider,
config_json,
)
new_config = {}
for key in secret_keys:
if key in tracing_config:
new_config[key] = decrypt_token(tenant_id, tracing_config[key])
for key in other_keys:
new_config[key] = tracing_config.get(key, "")
# First check without lock for performance
cached_config = cls.decrypted_configs_cache.get(decrypted_config_key)
if cached_config is not None:
return dict(cached_config)
return config_class(**new_config).model_dump()
with cls._decryption_cache_lock:
# Second check (double-checked locking) to prevent race conditions
cached_config = cls.decrypted_configs_cache.get(decrypted_config_key)
if cached_config is not None:
return dict(cached_config)
config_class, secret_keys, other_keys = (
provider_config_map[tracing_provider]["config_class"],
provider_config_map[tracing_provider]["secret_keys"],
provider_config_map[tracing_provider]["other_keys"],
)
new_config: dict[str, Any] = {}
keys_to_decrypt = [key for key in secret_keys if key in tracing_config]
if keys_to_decrypt:
decrypted_values = batch_decrypt_token(tenant_id, [tracing_config[key] for key in keys_to_decrypt])
new_config.update(zip(keys_to_decrypt, decrypted_values))
for key in other_keys:
new_config[key] = tracing_config.get(key, "")
decrypted_config = config_class(**new_config).model_dump()
cls.decrypted_configs_cache[decrypted_config_key] = decrypted_config
return dict(decrypted_config)
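For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of the double-checked locking cache used above (function and key names are hypothetical; assumes cachetools is installed):

import threading
from cachetools import LRUCache

_cache: LRUCache = LRUCache(maxsize=128)
_lock = threading.RLock()

def get_or_compute(key, compute):
    # First check without the lock: cheap fast path for hot keys.
    value = _cache.get(key)
    if value is not None:
        return value
    with _lock:
        # Second check: another thread may have filled the cache
        # while we were waiting for the lock.
        value = _cache.get(key)
        if value is not None:
            return value
        value = compute()  # expensive work (e.g. batch decryption)
        _cache[key] = value
        return value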
@classmethod
def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict):
@ -218,7 +242,7 @@ class OpsTraceManager:
provider_config_map[tracing_provider]["secret_keys"],
provider_config_map[tracing_provider]["other_keys"],
)
new_config = {}
new_config: dict[str, Any] = {}
for key in secret_keys:
if key in decrypt_tracing_config:
new_config[key] = obfuscated_token(decrypt_tracing_config[key])
@ -419,6 +443,18 @@ class OpsTraceManager:
class TraceTask:
_workflow_run_repo = None
_repo_lock = threading.Lock()
@classmethod
def _get_workflow_run_repo(cls):
if cls._workflow_run_repo is None:
with cls._repo_lock:
if cls._workflow_run_repo is None:
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
return cls._workflow_run_repo
def __init__(
self,
trace_type: Any,
@ -486,27 +522,27 @@ class TraceTask:
if not workflow_run_id:
return {}
workflow_run_repo = self._get_workflow_run_repo()
workflow_run = workflow_run_repo.get_workflow_run_by_id_without_tenant(run_id=workflow_run_id)
if not workflow_run:
raise ValueError("Workflow run not found")
workflow_id = workflow_run.workflow_id
tenant_id = workflow_run.tenant_id
workflow_run_id = workflow_run.id
workflow_run_elapsed_time = workflow_run.elapsed_time
workflow_run_status = workflow_run.status
workflow_run_inputs = workflow_run.inputs_dict
workflow_run_outputs = workflow_run.outputs_dict
workflow_run_version = workflow_run.version
error = workflow_run.error or ""
total_tokens = workflow_run.total_tokens
file_list = workflow_run_inputs.get("sys.file") or []
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
with Session(db.engine) as session:
workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
workflow_run = session.scalars(workflow_run_stmt).first()
if not workflow_run:
raise ValueError("Workflow run not found")
workflow_id = workflow_run.workflow_id
tenant_id = workflow_run.tenant_id
workflow_run_id = workflow_run.id
workflow_run_elapsed_time = workflow_run.elapsed_time
workflow_run_status = workflow_run.status
workflow_run_inputs = workflow_run.inputs_dict
workflow_run_outputs = workflow_run.outputs_dict
workflow_run_version = workflow_run.version
error = workflow_run.error or ""
total_tokens = workflow_run.total_tokens
file_list = workflow_run_inputs.get("sys.file") or []
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
# get workflow_app_log_id
workflow_app_log_data_stmt = select(WorkflowAppLog.id).where(
WorkflowAppLog.tenant_id == tenant_id,
@ -523,43 +559,43 @@ class TraceTask:
)
message_id = session.scalar(message_data_stmt)
metadata = {
"workflow_id": workflow_id,
"conversation_id": conversation_id,
"workflow_run_id": workflow_run_id,
"tenant_id": tenant_id,
"elapsed_time": workflow_run_elapsed_time,
"status": workflow_run_status,
"version": workflow_run_version,
"total_tokens": total_tokens,
"file_list": file_list,
"triggered_from": workflow_run.triggered_from,
"user_id": user_id,
"app_id": workflow_run.app_id,
}
metadata = {
"workflow_id": workflow_id,
"conversation_id": conversation_id,
"workflow_run_id": workflow_run_id,
"tenant_id": tenant_id,
"elapsed_time": workflow_run_elapsed_time,
"status": workflow_run_status,
"version": workflow_run_version,
"total_tokens": total_tokens,
"file_list": file_list,
"triggered_from": workflow_run.triggered_from,
"user_id": user_id,
"app_id": workflow_run.app_id,
}
workflow_trace_info = WorkflowTraceInfo(
trace_id=self.trace_id,
workflow_data=workflow_run.to_dict(),
conversation_id=conversation_id,
workflow_id=workflow_id,
tenant_id=tenant_id,
workflow_run_id=workflow_run_id,
workflow_run_elapsed_time=workflow_run_elapsed_time,
workflow_run_status=workflow_run_status,
workflow_run_inputs=workflow_run_inputs,
workflow_run_outputs=workflow_run_outputs,
workflow_run_version=workflow_run_version,
error=error,
total_tokens=total_tokens,
file_list=file_list,
query=query,
metadata=metadata,
workflow_app_log_id=workflow_app_log_id,
message_id=message_id,
start_time=workflow_run.created_at,
end_time=workflow_run.finished_at,
)
workflow_trace_info = WorkflowTraceInfo(
trace_id=self.trace_id,
workflow_data=workflow_run.to_dict(),
conversation_id=conversation_id,
workflow_id=workflow_id,
tenant_id=tenant_id,
workflow_run_id=workflow_run_id,
workflow_run_elapsed_time=workflow_run_elapsed_time,
workflow_run_status=workflow_run_status,
workflow_run_inputs=workflow_run_inputs,
workflow_run_outputs=workflow_run_outputs,
workflow_run_version=workflow_run_version,
error=error,
total_tokens=total_tokens,
file_list=file_list,
query=query,
metadata=metadata,
workflow_app_log_id=workflow_app_log_id,
message_id=message_id,
start_time=workflow_run.created_at,
end_time=workflow_run.finished_at,
)
return workflow_trace_info
def message_trace(self, message_id: str | None):
@ -583,6 +619,8 @@ class TraceTask:
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
file_list.append(file_url)
streaming_metrics = self._extract_streaming_metrics(message_data)
metadata = {
"conversation_id": message_data.conversation_id,
"ls_provider": message_data.model_provider,
@ -615,6 +653,9 @@ class TraceTask:
metadata=metadata,
message_file_data=message_file_data,
conversation_mode=conversation_mode,
gen_ai_server_time_to_first_token=streaming_metrics.get("gen_ai_server_time_to_first_token"),
llm_streaming_time_to_generate=streaming_metrics.get("llm_streaming_time_to_generate"),
is_streaming_request=streaming_metrics.get("is_streaming_request", False),
)
return message_trace_info
@ -840,6 +881,24 @@ class TraceTask:
return generate_name_trace_info
def _extract_streaming_metrics(self, message_data) -> dict:
if not message_data.message_metadata:
return {}
try:
metadata = json.loads(message_data.message_metadata)
usage = metadata.get("usage", {})
time_to_first_token = usage.get("time_to_first_token")
time_to_generate = usage.get("time_to_generate")
return {
"gen_ai_server_time_to_first_token": time_to_first_token,
"llm_streaming_time_to_generate": time_to_generate,
"is_streaming_request": time_to_first_token is not None,
}
except (json.JSONDecodeError, AttributeError):
return {}
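For reference, a minimal sketch of the message_metadata shape that _extract_streaming_metrics expects; the numeric values are invented:

import json

message_metadata = json.dumps({
    "usage": {
        "time_to_first_token": 0.42,  # seconds until the first streamed token
        "time_to_generate": 3.8,      # seconds for the full streamed response
    }
})
usage = json.loads(message_metadata).get("usage", {})
# A present time_to_first_token is what flags the request as streaming.
assert usage.get("time_to_first_token") is not None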
trace_manager_timer: threading.Timer | None = None
trace_manager_queue: queue.Queue = queue.Queue()

View File

@ -5,12 +5,18 @@ Tencent APM Trace Client - handles network operations, metrics, and API communic
from __future__ import annotations
import importlib
import json
import logging
import os
import socket
from typing import TYPE_CHECKING
from urllib.parse import urlparse
try:
from importlib.metadata import version
except ImportError:
from importlib_metadata import version # type: ignore[import-not-found]
if TYPE_CHECKING:
from opentelemetry.metrics import Meter
from opentelemetry.metrics._internal.instrument import Histogram
@ -27,12 +33,27 @@ from opentelemetry.util.types import AttributeValue
from configs import dify_config
from .entities.tencent_semconv import LLM_OPERATION_DURATION
from .entities.semconv import (
GEN_AI_SERVER_TIME_TO_FIRST_TOKEN,
GEN_AI_STREAMING_TIME_TO_GENERATE,
GEN_AI_TOKEN_USAGE,
GEN_AI_TRACE_DURATION,
LLM_OPERATION_DURATION,
)
from .entities.tencent_trace_entity import SpanData
logger = logging.getLogger(__name__)
def _get_opentelemetry_sdk_version() -> str:
"""Get OpenTelemetry SDK version dynamically."""
try:
return version("opentelemetry-sdk")
except Exception:
logger.debug("Failed to get opentelemetry-sdk version, using default")
return "1.27.0" # fallback version
class TencentTraceClient:
"""Tencent APM trace client using OpenTelemetry OTLP exporter"""
@ -57,6 +78,9 @@ class TencentTraceClient:
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
ResourceAttributes.HOST_NAME: socket.gethostname(),
ResourceAttributes.TELEMETRY_SDK_LANGUAGE: "python",
ResourceAttributes.TELEMETRY_SDK_NAME: "opentelemetry",
ResourceAttributes.TELEMETRY_SDK_VERSION: _get_opentelemetry_sdk_version(),
}
)
# Prepare gRPC endpoint/metadata
@ -80,18 +104,23 @@ class TencentTraceClient:
)
self.tracer_provider.add_span_processor(self.span_processor)
self.tracer = self.tracer_provider.get_tracer("dify.tencent_apm")
# Use the Dify API version as the tracer version
self.tracer = self.tracer_provider.get_tracer("dify-sdk", dify_config.project.version)
# Store span contexts for parent-child relationships
self.span_contexts: dict[int, trace_api.SpanContext] = {}
self.meter: Meter | None = None
self.meter_provider: MeterProvider | None = None
self.hist_llm_duration: Histogram | None = None
self.hist_token_usage: Histogram | None = None
self.hist_time_to_first_token: Histogram | None = None
self.hist_time_to_generate: Histogram | None = None
self.hist_trace_duration: Histogram | None = None
self.metric_reader: MetricReader | None = None
# Metrics exporter and instruments
try:
from opentelemetry import metrics
from opentelemetry.sdk.metrics import Histogram, MeterProvider
from opentelemetry.sdk.metrics.export import AggregationTemporality, PeriodicExportingMetricReader
@ -99,7 +128,7 @@ class TencentTraceClient:
use_http_protobuf = protocol in {"http/protobuf", "http-protobuf"}
use_http_json = protocol in {"http/json", "http-json"}
# Set preferred temporality for histograms to DELTA
# Tencent APM works best with delta aggregation temporality
preferred_temporality: dict[type, AggregationTemporality] = {Histogram: AggregationTemporality.DELTA}
def _create_metric_exporter(exporter_cls, **kwargs):
@ -174,23 +203,66 @@ class TencentTraceClient:
)
if metric_reader is not None:
# Use instance-level MeterProvider instead of global to support config changes
# without worker restart. Each TencentTraceClient manages its own MeterProvider.
provider = MeterProvider(resource=self.resource, metric_readers=[metric_reader])
metrics.set_meter_provider(provider)
self.meter = metrics.get_meter("dify-sdk", dify_config.project.version)
self.meter_provider = provider
self.meter = provider.get_meter("dify-sdk", dify_config.project.version)
# LLM operation duration histogram
self.hist_llm_duration = self.meter.create_histogram(
name=LLM_OPERATION_DURATION,
unit="s",
description="LLM operation duration (seconds)",
)
# Token usage histogram with exponential buckets
self.hist_token_usage = self.meter.create_histogram(
name=GEN_AI_TOKEN_USAGE,
unit="token",
description="Number of tokens used in prompt and completions",
)
# Time to first token histogram
self.hist_time_to_first_token = self.meter.create_histogram(
name=GEN_AI_SERVER_TIME_TO_FIRST_TOKEN,
unit="s",
description="Time to first token for streaming LLM responses (seconds)",
)
# Time to generate histogram
self.hist_time_to_generate = self.meter.create_histogram(
name=GEN_AI_STREAMING_TIME_TO_GENERATE,
unit="s",
description="Total time to generate streaming LLM responses (seconds)",
)
# Trace duration histogram
self.hist_trace_duration = self.meter.create_histogram(
name=GEN_AI_TRACE_DURATION,
unit="s",
description="End-to-end GenAI trace duration (seconds)",
)
self.metric_reader = metric_reader
else:
self.meter = None
self.meter_provider = None
self.hist_llm_duration = None
self.hist_token_usage = None
self.hist_time_to_first_token = None
self.hist_time_to_generate = None
self.hist_trace_duration = None
self.metric_reader = None
except Exception:
logger.exception("[Tencent APM] Metrics initialization failed; metrics disabled")
self.meter = None
self.meter_provider = None
self.hist_llm_duration = None
self.hist_token_usage = None
self.hist_time_to_first_token = None
self.hist_time_to_generate = None
self.hist_trace_duration = None
self.metric_reader = None
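A self-contained sketch of creating and recording one of these histograms with the OpenTelemetry SDK; it uses an in-memory reader and invented values rather than the OTLP setup above:

from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import InMemoryMetricReader

reader = InMemoryMetricReader()
provider = MeterProvider(metric_readers=[reader])
meter = provider.get_meter("dify-sdk", "0.0.0")  # name/version as above

ttft_hist = meter.create_histogram(
    name="gen_ai.server.time_to_first_token",
    unit="s",
    description="Time to first token for streaming LLM responses (seconds)",
)
ttft_hist.record(0.42, {"gen_ai.system": "openai", "stream": "true"})
print(reader.get_metrics_data())  # inspect the recorded data point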
def add_span(self, span_data: SpanData) -> None:
@ -212,10 +284,158 @@ class TencentTraceClient:
if attributes:
for k, v in attributes.items():
attrs[k] = str(v) if not isinstance(v, (str, int, float, bool)) else v # type: ignore[assignment]
logger.info(
"[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
LLM_OPERATION_DURATION,
latency_seconds,
json.dumps(attrs, ensure_ascii=False),
)
self.hist_llm_duration.record(latency_seconds, attrs) # type: ignore[attr-defined]
except Exception:
logger.debug("[Tencent APM] Failed to record LLM duration", exc_info=True)
def record_token_usage(
self,
token_count: int,
token_type: str,
operation_name: str,
request_model: str,
response_model: str,
server_address: str,
provider: str,
) -> None:
"""Record token usage histogram.
Args:
token_count: Number of tokens used
token_type: "input" or "output"
operation_name: Operation name (e.g., "chat")
request_model: Model used in request
response_model: Model used in response
server_address: Server address
provider: Model provider name
"""
try:
if not hasattr(self, "hist_token_usage") or self.hist_token_usage is None:
return
attributes = {
"gen_ai.operation.name": operation_name,
"gen_ai.request.model": request_model,
"gen_ai.response.model": response_model,
"gen_ai.system": provider,
"gen_ai.token.type": token_type,
"server.address": server_address,
}
logger.info(
"[Tencent Metrics] Metric: %s | Value: %d | Attributes: %s",
GEN_AI_TOKEN_USAGE,
token_count,
json.dumps(attributes, ensure_ascii=False),
)
self.hist_token_usage.record(token_count, attributes) # type: ignore[attr-defined]
except Exception:
logger.debug("[Tencent APM] Failed to record token usage", exc_info=True)
def record_time_to_first_token(
self, ttft_seconds: float, provider: str, model: str, operation_name: str = "chat"
) -> None:
"""Record time to first token histogram.
Args:
ttft_seconds: Time to first token in seconds
provider: Model provider name
model: Model name
operation_name: Operation name (default: "chat")
"""
try:
if not hasattr(self, "hist_time_to_first_token") or self.hist_time_to_first_token is None:
return
attributes = {
"gen_ai.operation.name": operation_name,
"gen_ai.system": provider,
"gen_ai.request.model": model,
"gen_ai.response.model": model,
"stream": "true",
}
logger.info(
"[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
GEN_AI_SERVER_TIME_TO_FIRST_TOKEN,
ttft_seconds,
json.dumps(attributes, ensure_ascii=False),
)
self.hist_time_to_first_token.record(ttft_seconds, attributes) # type: ignore[attr-defined]
except Exception:
logger.debug("[Tencent APM] Failed to record time to first token", exc_info=True)
def record_time_to_generate(
self, ttg_seconds: float, provider: str, model: str, operation_name: str = "chat"
) -> None:
"""Record time to generate histogram.
Args:
ttg_seconds: Time to generate in seconds
provider: Model provider name
model: Model name
operation_name: Operation name (default: "chat")
"""
try:
if not hasattr(self, "hist_time_to_generate") or self.hist_time_to_generate is None:
return
attributes = {
"gen_ai.operation.name": operation_name,
"gen_ai.system": provider,
"gen_ai.request.model": model,
"gen_ai.response.model": model,
"stream": "true",
}
logger.info(
"[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
GEN_AI_STREAMING_TIME_TO_GENERATE,
ttg_seconds,
json.dumps(attributes, ensure_ascii=False),
)
self.hist_time_to_generate.record(ttg_seconds, attributes) # type: ignore[attr-defined]
except Exception:
logger.debug("[Tencent APM] Failed to record time to generate", exc_info=True)
def record_trace_duration(self, duration_seconds: float, attributes: dict[str, str] | None = None) -> None:
"""Record end-to-end trace duration histogram in seconds.
Args:
duration_seconds: Trace duration in seconds
attributes: Optional attributes (e.g., conversation_mode, app_id)
"""
try:
if not hasattr(self, "hist_trace_duration") or self.hist_trace_duration is None:
return
attrs: dict[str, str] = {}
if attributes:
for k, v in attributes.items():
attrs[k] = str(v) if not isinstance(v, (str, int, float, bool)) else v # type: ignore[assignment]
logger.info(
"[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
GEN_AI_TRACE_DURATION,
duration_seconds,
json.dumps(attrs, ensure_ascii=False),
)
self.hist_trace_duration.record(duration_seconds, attrs) # type: ignore[attr-defined]
except Exception:
logger.debug("[Tencent APM] Failed to record trace duration", exc_info=True)
def _create_and_export_span(self, span_data: SpanData) -> None:
"""Create span using OpenTelemetry Tracer API"""
try:
@ -296,11 +516,19 @@ class TencentTraceClient:
if self.tracer_provider:
self.tracer_provider.shutdown()
# Shutdown instance-level meter provider
if self.meter_provider is not None:
try:
self.meter_provider.shutdown() # type: ignore[attr-defined]
except Exception:
logger.debug("[Tencent APM] Error shutting down meter provider", exc_info=True)
if self.metric_reader is not None:
try:
self.metric_reader.shutdown() # type: ignore[attr-defined]
except Exception:
pass
logger.debug("[Tencent APM] Error shutting down metric reader", exc_info=True)
except Exception:
logger.exception("[Tencent APM] Error during client shutdown")

View File

@ -47,6 +47,9 @@ GEN_AI_COMPLETION = "gen_ai.completion"
GEN_AI_RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason"
# Streaming Span Attributes
GEN_AI_IS_STREAMING_REQUEST = "llm.is_streaming" # Same as OpenLLMetry semconv
# Tool
TOOL_NAME = "tool.name"
@ -62,6 +65,19 @@ INSTRUMENTATION_LANGUAGE = "python"
# Metrics
LLM_OPERATION_DURATION = "gen_ai.client.operation.duration"
GEN_AI_TOKEN_USAGE = "gen_ai.client.token.usage"
GEN_AI_SERVER_TIME_TO_FIRST_TOKEN = "gen_ai.server.time_to_first_token"
GEN_AI_STREAMING_TIME_TO_GENERATE = "gen_ai.streaming.time_to_generate"
# End-to-end LLM trace duration metric, exclusive to Tencent APM
GEN_AI_TRACE_DURATION = "gen_ai.trace.duration"
# Token Usage Attributes
GEN_AI_OPERATION_NAME = "gen_ai.operation.name"
GEN_AI_REQUEST_MODEL = "gen_ai.request.model"
GEN_AI_RESPONSE_MODEL = "gen_ai.response.model"
GEN_AI_SYSTEM = "gen_ai.system"
GEN_AI_TOKEN_TYPE = "gen_ai.token.type"
SERVER_ADDRESS = "server.address"
class GenAISpanKind(Enum):

View File

@ -14,10 +14,11 @@ from core.ops.entities.trace_entity import (
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.ops.tencent_trace.entities.tencent_semconv import (
from core.ops.tencent_trace.entities.semconv import (
GEN_AI_COMPLETION,
GEN_AI_FRAMEWORK,
GEN_AI_IS_ENTRY,
GEN_AI_IS_STREAMING_REQUEST,
GEN_AI_MODEL_NAME,
GEN_AI_PROMPT,
GEN_AI_PROVIDER,
@ -156,6 +157,25 @@ class TencentSpanBuilder:
outputs = node_execution.outputs or {}
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
attributes = {
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_SPAN_KIND: GenAISpanKind.GENERATION.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_MODEL_NAME: process_data.get("model_name", ""),
GEN_AI_PROVIDER: process_data.get("model_provider", ""),
GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)),
GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)),
GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)),
GEN_AI_PROMPT: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
GEN_AI_COMPLETION: str(outputs.get("text", "")),
GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason", ""),
INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
OUTPUT_VALUE: str(outputs.get("text", "")),
}
if usage_data.get("time_to_first_token") is not None:
attributes[GEN_AI_IS_STREAMING_REQUEST] = "true"
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
@ -163,21 +183,7 @@ class TencentSpanBuilder:
name="GENERATION",
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_SPAN_KIND: GenAISpanKind.GENERATION.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_MODEL_NAME: process_data.get("model_name", ""),
GEN_AI_PROVIDER: process_data.get("model_provider", ""),
GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)),
GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)),
GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)),
GEN_AI_PROMPT: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
GEN_AI_COMPLETION: str(outputs.get("text", "")),
GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason", ""),
INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
OUTPUT_VALUE: str(outputs.get("text", "")),
},
attributes=attributes,
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
)
@ -191,6 +197,19 @@ class TencentSpanBuilder:
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
attributes = {
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.WORKFLOW.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_IS_ENTRY: "true",
INPUT_VALUE: str(trace_info.inputs or ""),
OUTPUT_VALUE: str(trace_info.outputs or ""),
}
if trace_info.is_streaming_request:
attributes[GEN_AI_IS_STREAMING_REQUEST] = "true"
return SpanData(
trace_id=trace_id,
parent_span_id=None,
@ -198,15 +217,7 @@ class TencentSpanBuilder:
name="message",
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.WORKFLOW.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_IS_ENTRY: "true",
INPUT_VALUE: str(trace_info.inputs or ""),
OUTPUT_VALUE: str(trace_info.outputs or ""),
},
attributes=attributes,
status=status,
links=links,
)

View File

@ -90,6 +90,9 @@ class TencentDataTrace(BaseTraceInstance):
self._process_workflow_nodes(trace_info, trace_id)
# Record trace duration for entry span
self._record_workflow_trace_duration(trace_info)
except Exception:
logger.exception("[Tencent APM] Failed to process workflow trace")
@ -107,6 +110,11 @@ class TencentDataTrace(BaseTraceInstance):
self.trace_client.add_span(message_span)
self._record_message_llm_metrics(trace_info)
# Record trace duration for entry span
self._record_message_trace_duration(trace_info)
except Exception:
logger.exception("[Tencent APM] Failed to process message trace")
@ -290,24 +298,219 @@ class TencentDataTrace(BaseTraceInstance):
def _record_llm_metrics(self, node_execution: WorkflowNodeExecution) -> None:
"""Record LLM performance metrics"""
try:
if not hasattr(self.trace_client, "record_llm_duration"):
return
process_data = node_execution.process_data or {}
usage = process_data.get("usage", {})
latency_s = float(usage.get("latency", 0.0))
outputs = node_execution.outputs or {}
usage = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
if latency_s > 0:
attributes = {
"provider": process_data.get("model_provider", ""),
"model": process_data.get("model_name", ""),
"span_kind": "GENERATION",
}
self.trace_client.record_llm_duration(latency_s, attributes)
model_provider = process_data.get("model_provider", "unknown")
model_name = process_data.get("model_name", "unknown")
model_mode = process_data.get("model_mode", "chat")
# Record LLM duration
if hasattr(self.trace_client, "record_llm_duration"):
latency_s = float(usage.get("latency", 0.0))
if latency_s > 0:
# Determine if streaming from usage metrics
is_streaming = usage.get("time_to_first_token") is not None
attributes = {
"gen_ai.system": model_provider,
"gen_ai.response.model": model_name,
"gen_ai.operation.name": model_mode,
"stream": "true" if is_streaming else "false",
}
self.trace_client.record_llm_duration(latency_s, attributes)
# Record streaming metrics from usage
time_to_first_token = usage.get("time_to_first_token")
if time_to_first_token is not None and hasattr(self.trace_client, "record_time_to_first_token"):
ttft_seconds = float(time_to_first_token)
if ttft_seconds > 0:
self.trace_client.record_time_to_first_token(
ttft_seconds=ttft_seconds, provider=model_provider, model=model_name, operation_name=model_mode
)
time_to_generate = usage.get("time_to_generate")
if time_to_generate is not None and hasattr(self.trace_client, "record_time_to_generate"):
ttg_seconds = float(time_to_generate)
if ttg_seconds > 0:
self.trace_client.record_time_to_generate(
ttg_seconds=ttg_seconds, provider=model_provider, model=model_name, operation_name=model_mode
)
# Record token usage
if hasattr(self.trace_client, "record_token_usage"):
# Extract token counts
input_tokens = int(usage.get("prompt_tokens", 0))
output_tokens = int(usage.get("completion_tokens", 0))
if input_tokens > 0 or output_tokens > 0:
server_address = model_provider
# Record input tokens
if input_tokens > 0:
self.trace_client.record_token_usage(
token_count=input_tokens,
token_type="input",
operation_name=model_mode,
request_model=model_name,
response_model=model_name,
server_address=server_address,
provider=model_provider,
)
# Record output tokens
if output_tokens > 0:
self.trace_client.record_token_usage(
token_count=output_tokens,
token_type="output",
operation_name=model_mode,
request_model=model_name,
response_model=model_name,
server_address=server_address,
provider=model_provider,
)
except Exception:
logger.debug("[Tencent APM] Failed to record LLM metrics")
def _record_message_llm_metrics(self, trace_info: MessageTraceInfo) -> None:
"""Record LLM metrics for message traces"""
try:
trace_metadata = trace_info.metadata or {}
message_data = trace_info.message_data or {}
provider_latency = 0.0
if isinstance(message_data, dict):
provider_latency = float(message_data.get("provider_response_latency", 0.0) or 0.0)
else:
provider_latency = float(getattr(message_data, "provider_response_latency", 0.0) or 0.0)
model_provider = trace_metadata.get("ls_provider") or (
message_data.get("model_provider", "") if isinstance(message_data, dict) else ""
)
model_name = trace_metadata.get("ls_model_name") or (
message_data.get("model_id", "") if isinstance(message_data, dict) else ""
)
# Record LLM duration
if provider_latency > 0 and hasattr(self.trace_client, "record_llm_duration"):
is_streaming = trace_info.is_streaming_request
duration_attributes = {
"gen_ai.system": model_provider,
"gen_ai.response.model": model_name,
"gen_ai.operation.name": "chat", # Message traces are always chat
"stream": "true" if is_streaming else "false",
}
self.trace_client.record_llm_duration(provider_latency, duration_attributes)
# Record streaming metrics for message traces
if trace_info.is_streaming_request:
# Record time to first token
if trace_info.gen_ai_server_time_to_first_token is not None and hasattr(
self.trace_client, "record_time_to_first_token"
):
ttft_seconds = float(trace_info.gen_ai_server_time_to_first_token)
if ttft_seconds > 0:
self.trace_client.record_time_to_first_token(
ttft_seconds=ttft_seconds, provider=str(model_provider or ""), model=str(model_name or "")
)
# Record time to generate
if trace_info.llm_streaming_time_to_generate is not None and hasattr(
self.trace_client, "record_time_to_generate"
):
ttg_seconds = float(trace_info.llm_streaming_time_to_generate)
if ttg_seconds > 0:
self.trace_client.record_time_to_generate(
ttg_seconds=ttg_seconds, provider=str(model_provider or ""), model=str(model_name or "")
)
# Record token usage
if hasattr(self.trace_client, "record_token_usage"):
input_tokens = int(trace_info.message_tokens or 0)
output_tokens = int(trace_info.answer_tokens or 0)
if input_tokens > 0:
self.trace_client.record_token_usage(
token_count=input_tokens,
token_type="input",
operation_name="chat",
request_model=str(model_name or ""),
response_model=str(model_name or ""),
server_address=str(model_provider or ""),
provider=str(model_provider or ""),
)
if output_tokens > 0:
self.trace_client.record_token_usage(
token_count=output_tokens,
token_type="output",
operation_name="chat",
request_model=str(model_name or ""),
response_model=str(model_name or ""),
server_address=str(model_provider or ""),
provider=str(model_provider or ""),
)
except Exception:
logger.debug("[Tencent APM] Failed to record message LLM metrics")
def _record_workflow_trace_duration(self, trace_info: WorkflowTraceInfo) -> None:
"""Record end-to-end workflow trace duration."""
try:
if not hasattr(self.trace_client, "record_trace_duration"):
return
# Calculate duration from start_time and end_time to match span duration
if trace_info.start_time and trace_info.end_time:
duration_s = (trace_info.end_time - trace_info.start_time).total_seconds()
else:
# Fallback to workflow_run_elapsed_time if timestamps not available
duration_s = float(trace_info.workflow_run_elapsed_time)
if duration_s > 0:
attributes = {
"conversation_mode": "workflow",
"workflow_status": trace_info.workflow_run_status,
}
# Add conversation_id if available
if trace_info.conversation_id:
attributes["has_conversation"] = "true"
else:
attributes["has_conversation"] = "false"
self.trace_client.record_trace_duration(duration_s, attributes)
except Exception:
logger.debug("[Tencent APM] Failed to record workflow trace duration")
def _record_message_trace_duration(self, trace_info: MessageTraceInfo) -> None:
"""Record end-to-end message trace duration."""
try:
if not hasattr(self.trace_client, "record_trace_duration"):
return
# Calculate duration from start_time and end_time
if trace_info.start_time and trace_info.end_time:
duration = (trace_info.end_time - trace_info.start_time).total_seconds()
if duration > 0:
attributes = {
"conversation_mode": trace_info.conversation_mode,
}
# Add streaming flag if available
if hasattr(trace_info, "is_streaming_request"):
attributes["stream"] = "true" if trace_info.is_streaming_request else "false"
self.trace_client.record_trace_duration(duration, attributes)
except Exception:
logger.debug("[Tencent APM] Failed to record message trace duration")
def __del__(self):
"""Ensure proper cleanup on garbage collection."""
try:

View File

@ -76,7 +76,7 @@ class PluginParameter(BaseModel):
auto_generate: PluginParameterAutoGenerate | None = None
template: PluginParameterTemplate | None = None
required: bool = False
default: Union[float, int, str] | None = None
default: Union[float, int, str, bool] | None = None
min: Union[float, int] | None = None
max: Union[float, int] | None = None
precision: int | None = None

View File

@ -40,7 +40,7 @@ class PluginDaemonBadRequestError(PluginDaemonClientSideError):
description: str = "Bad Request"
class PluginInvokeError(PluginDaemonClientSideError):
class PluginInvokeError(PluginDaemonClientSideError, ValueError):
description: str = "Invoke Error"
def _get_error_object(self) -> Mapping:

View File

@ -161,7 +161,7 @@ class OpenSearchVector(BaseVector):
logger.exception("Error deleting document: %s", error)
def delete(self):
self._client.indices.delete(index=self._collection_name.lower())
self._client.indices.delete(index=self._collection_name.lower(), ignore_unavailable=True)
def text_exists(self, id: str) -> bool:
try:

View File

@ -39,11 +39,13 @@ class WeaviateConfig(BaseModel):
Attributes:
endpoint: Weaviate server endpoint URL
grpc_endpoint: Optional Weaviate gRPC server endpoint URL
api_key: Optional API key for authentication
batch_size: Number of objects to batch per insert operation
"""
endpoint: str
grpc_endpoint: str | None = None
api_key: str | None = None
batch_size: int = 100
@ -88,9 +90,22 @@ class WeaviateVector(BaseVector):
http_secure = p.scheme == "https"
http_port = p.port or (443 if http_secure else 80)
grpc_host = host
grpc_secure = http_secure
grpc_port = 443 if grpc_secure else 50051
# Parse gRPC configuration
if config.grpc_endpoint:
# URLs without a scheme won't be parsed correctly on some Python versions,
# see https://bugs.python.org/issue27657
grpc_endpoint_with_scheme = (
config.grpc_endpoint if "://" in config.grpc_endpoint else f"grpc://{config.grpc_endpoint}"
)
grpc_p = urlparse(grpc_endpoint_with_scheme)
grpc_host = grpc_p.hostname or "localhost"
grpc_port = grpc_p.port or (443 if grpc_p.scheme == "grpcs" else 50051)
grpc_secure = grpc_p.scheme == "grpcs"
else:
# Infer from HTTP endpoint as fallback
grpc_host = host
grpc_secure = http_secure
grpc_port = 443 if grpc_secure else 50051
client = weaviate.connect_to_custom(
http_host=host,
@ -100,6 +115,7 @@ class WeaviateVector(BaseVector):
grpc_port=grpc_port,
grpc_secure=grpc_secure,
auth_credentials=Auth.api_key(config.api_key) if config.api_key else None,
skip_init_checks=True, # Skip PyPI version check to avoid unnecessary HTTP requests
)
if not client.is_ready():
@ -431,6 +447,7 @@ class WeaviateVectorFactory(AbstractVectorFactory):
collection_name=collection_name,
config=WeaviateConfig(
endpoint=dify_config.WEAVIATE_ENDPOINT or "",
grpc_endpoint=dify_config.WEAVIATE_GRPC_ENDPOINT or "",
api_key=dify_config.WEAVIATE_API_KEY,
batch_size=dify_config.WEAVIATE_BATCH_SIZE,
),
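A small sketch of why the gRPC endpoint is normalized before parsing; the endpoint string is hypothetical, and results can vary by Python version when no scheme is present:

from urllib.parse import urlparse

endpoint = "weaviate-grpc.internal:50051"  # hypothetical endpoint, no scheme
# Without a scheme, urlparse() may misread host:port strings
# (https://bugs.python.org/issue27657), so one is prepended first:
normalized = endpoint if "://" in endpoint else f"grpc://{endpoint}"
p = urlparse(normalized)
print(p.hostname, p.port, p.scheme == "grpcs")
# -> weaviate-grpc.internal 50051 False (plaintext gRPC)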

View File

@ -72,6 +72,19 @@ default_retrieval_model: dict[str, Any] = {
class DatasetRetrieval:
def __init__(self, application_generate_entity=None):
self.application_generate_entity = application_generate_entity
self._llm_usage = LLMUsage.empty_usage()
@property
def llm_usage(self) -> LLMUsage:
return self._llm_usage.model_copy()
def _record_usage(self, usage: LLMUsage | None) -> None:
if usage is None or usage.total_tokens <= 0:
return
if self._llm_usage.total_tokens == 0:
self._llm_usage = usage
else:
self._llm_usage = self._llm_usage.plus(usage)
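A minimal sketch of the accumulation contract; router_usage and filter_usage are fabricated here via model_copy, assuming plus() sums token counts:

from core.model_runtime.entities.llm_entities import LLMUsage

retrieval = DatasetRetrieval()
router_usage = LLMUsage.empty_usage().model_copy(update={"total_tokens": 120})
filter_usage = LLMUsage.empty_usage().model_copy(update={"total_tokens": 80})

retrieval._record_usage(None)                    # ignored
retrieval._record_usage(LLMUsage.empty_usage())  # ignored: zero tokens
retrieval._record_usage(router_usage)            # first usage stored as-is
retrieval._record_usage(filter_usage)            # merged via LLMUsage.plus()
print(retrieval.llm_usage.total_tokens)          # 200; caller gets a copy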
def retrieve(
self,
@ -312,15 +325,18 @@ class DatasetRetrieval:
)
tools.append(message_tool)
dataset_id = None
router_usage = LLMUsage.empty_usage()
if planning_strategy == PlanningStrategy.REACT_ROUTER:
react_multi_dataset_router = ReactMultiDatasetRouter()
dataset_id = react_multi_dataset_router.invoke(
dataset_id, router_usage = react_multi_dataset_router.invoke(
query, tools, model_config, model_instance, user_id, tenant_id
)
elif planning_strategy == PlanningStrategy.ROUTER:
function_call_router = FunctionCallMultiDatasetRouter()
dataset_id = function_call_router.invoke(query, tools, model_config, model_instance)
dataset_id, router_usage = function_call_router.invoke(query, tools, model_config, model_instance)
self._record_usage(router_usage)
if dataset_id:
# get retrieval model config
@ -983,7 +999,8 @@ class DatasetRetrieval:
)
# handle invoke result
result_text, _ = self._handle_invoke_result(invoke_result=invoke_result)
result_text, usage = self._handle_invoke_result(invoke_result=invoke_result)
self._record_usage(usage)
result_text_json = parse_and_check_json_markdown(result_text, [])
automatic_metadata_filters = []

View File

@ -2,7 +2,7 @@ from typing import Union
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage
@ -13,15 +13,15 @@ class FunctionCallMultiDatasetRouter:
dataset_tools: list[PromptMessageTool],
model_config: ModelConfigWithCredentialsEntity,
model_instance: ModelInstance,
) -> Union[str, None]:
) -> tuple[Union[str, None], LLMUsage]:
"""Given input, decided what to do.
Returns:
Action specifying what tool to use.
"""
if len(dataset_tools) == 0:
return None
return None, LLMUsage.empty_usage()
elif len(dataset_tools) == 1:
return dataset_tools[0].name
return dataset_tools[0].name, LLMUsage.empty_usage()
try:
prompt_messages = [
@ -34,9 +34,10 @@ class FunctionCallMultiDatasetRouter:
stream=False,
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
)
usage = result.usage or LLMUsage.empty_usage()
if result.message.tool_calls:
# get retrieval model config
return result.message.tool_calls[0].function.name
return None
return result.message.tool_calls[0].function.name, usage
return None, usage
except Exception:
return None
return None, LLMUsage.empty_usage()

View File

@ -58,15 +58,15 @@ class ReactMultiDatasetRouter:
model_instance: ModelInstance,
user_id: str,
tenant_id: str,
) -> Union[str, None]:
) -> tuple[Union[str, None], LLMUsage]:
"""Given input, decided what to do.
Returns:
Action specifying what tool to use.
"""
if len(dataset_tools) == 0:
return None
return None, LLMUsage.empty_usage()
elif len(dataset_tools) == 1:
return dataset_tools[0].name
return dataset_tools[0].name, LLMUsage.empty_usage()
try:
return self._react_invoke(
@ -78,7 +78,7 @@ class ReactMultiDatasetRouter:
tenant_id=tenant_id,
)
except Exception:
return None
return None, LLMUsage.empty_usage()
def _react_invoke(
self,
@ -91,7 +91,7 @@ class ReactMultiDatasetRouter:
prefix: str = PREFIX,
suffix: str = SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
) -> Union[str, None]:
) -> tuple[Union[str, None], LLMUsage]:
prompt: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
if model_config.mode == "chat":
prompt = self.create_chat_prompt(
@ -120,7 +120,7 @@ class ReactMultiDatasetRouter:
memory=None,
model_config=model_config,
)
result_text, _ = self._invoke_llm(
result_text, usage = self._invoke_llm(
completion_param=model_config.parameters,
model_instance=model_instance,
prompt_messages=prompt_messages,
@ -131,8 +131,8 @@ class ReactMultiDatasetRouter:
output_parser = StructuredChatOutputParser()
react_decision = output_parser.parse(result_text)
if isinstance(react_decision, ReactAction):
return react_decision.tool
return None
return react_decision.tool, usage
return None, usage
def _invoke_llm(
self,

View File

@ -217,3 +217,16 @@ class Tool(ABC):
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.JSON, message=ToolInvokeMessage.JsonMessage(json_object=object)
)
def create_variable_message(
self, variable_name: str, variable_value: Any, stream: bool = False
) -> ToolInvokeMessage:
"""
create a variable message
"""
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.VARIABLE,
message=ToolInvokeMessage.VariableMessage(
variable_name=variable_name, variable_value=variable_value, stream=stream
),
)
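A hypothetical call site inside a tool's _invoke generator, mirroring how MCPTool surfaces structured output below:

# Surface a named output to the workflow as a variable message:
yield self.create_variable_message("structured_result", {"rows": 3}, stream=False)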

View File

@ -4,6 +4,7 @@ from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool import ToolParameter
from core.tools.entities.common_entities import I18nObject
@ -44,10 +45,14 @@ class ToolProviderApiEntity(BaseModel):
server_url: str | None = Field(default="", description="The server url of the tool")
updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
server_identifier: str | None = Field(default="", description="The server identifier of the MCP tool")
timeout: float | None = Field(default=30.0, description="The timeout of the MCP tool")
sse_read_timeout: float | None = Field(default=300.0, description="The SSE read timeout of the MCP tool")
masked_headers: dict[str, str] | None = Field(default=None, description="The masked headers of the MCP tool")
original_headers: dict[str, str] | None = Field(default=None, description="The original headers of the MCP tool")
authentication: MCPAuthentication | None = Field(default=None, description="The OAuth config of the MCP tool")
is_dynamic_registration: bool = Field(default=True, description="Whether the MCP tool is dynamically registered")
configuration: MCPConfiguration | None = Field(
default=None, description="The timeout and sse_read_timeout of the MCP tool"
)
@field_validator("tools", mode="before")
@classmethod
@ -70,8 +75,15 @@ class ToolProviderApiEntity(BaseModel):
if self.type == ToolProviderType.MCP:
optional_fields.update(self.optional_field("updated_at", self.updated_at))
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
optional_fields.update(self.optional_field("timeout", self.timeout))
optional_fields.update(self.optional_field("sse_read_timeout", self.sse_read_timeout))
optional_fields.update(
self.optional_field(
"configuration", self.configuration.model_dump() if self.configuration else MCPConfiguration()
)
)
optional_fields.update(
self.optional_field("authentication", self.authentication.model_dump() if self.authentication else None)
)
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
optional_fields.update(self.optional_field("original_headers", self.original_headers))
return {

View File

@ -1,6 +1,6 @@
import json
from typing import Any, Self
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.types import Tool as RemoteMCPTool
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
@ -52,18 +52,25 @@ class MCPToolProviderController(ToolProviderController):
"""
from db provider
"""
tools = []
tools_data = json.loads(db_provider.tools)
remote_mcp_tools = [RemoteMCPTool.model_validate(tool) for tool in tools_data]
user = db_provider.load_user()
# Convert to entity first
provider_entity = db_provider.to_entity()
return cls.from_entity(provider_entity)
@classmethod
def from_entity(cls, entity: MCPProviderEntity) -> Self:
"""
create a MCPToolProviderController from a MCPProviderEntity
"""
remote_mcp_tools = [RemoteMCPTool(**tool) for tool in entity.tools]
tools = [
ToolEntity(
identity=ToolIdentity(
author=user.name if user else "Anonymous",
author="Anonymous", # Tool level author is not stored
name=remote_mcp_tool.name,
label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name),
provider=db_provider.server_identifier,
icon=db_provider.icon,
provider=entity.provider_id,
icon=entity.icon if isinstance(entity.icon, str) else "",
),
parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema),
description=ToolDescription(
@ -72,31 +79,32 @@ class MCPToolProviderController(ToolProviderController):
),
llm=remote_mcp_tool.description or "",
),
output_schema=remote_mcp_tool.outputSchema or {},
has_runtime_parameters=len(remote_mcp_tool.inputSchema) > 0,
)
for remote_mcp_tool in remote_mcp_tools
]
if not db_provider.icon:
if not entity.icon:
raise ValueError("Database provider icon is required")
return cls(
entity=ToolProviderEntityWithPlugin(
identity=ToolProviderIdentity(
author=user.name if user else "Anonymous",
name=db_provider.name,
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
author="Anonymous", # Provider level author is not stored in entity
name=entity.name,
label=I18nObject(en_US=entity.name, zh_Hans=entity.name),
description=I18nObject(en_US="", zh_Hans=""),
icon=db_provider.icon,
icon=entity.icon if isinstance(entity.icon, str) else "",
),
plugin_id=None,
credentials_schema=[],
tools=tools,
),
provider_id=db_provider.server_identifier or "",
tenant_id=db_provider.tenant_id or "",
server_url=db_provider.decrypted_server_url,
headers=db_provider.decrypted_headers or {},
timeout=db_provider.timeout,
sse_read_timeout=db_provider.sse_read_timeout,
provider_id=entity.provider_id,
tenant_id=entity.tenant_id,
server_url=entity.server_url,
headers=entity.headers,
timeout=entity.timeout,
sse_read_timeout=entity.sse_read_timeout,
)
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):

View File

@ -3,12 +3,13 @@ import json
from collections.abc import Generator
from typing import Any
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.mcp_client import MCPClient
from core.mcp.types import ImageContent, TextContent
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPConnectionError
from core.mcp.types import CallToolResult, ImageContent, TextContent
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
from core.tools.errors import ToolInvokeError
class MCPTool(Tool):
@ -44,40 +45,32 @@ class MCPTool(Tool):
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
from core.tools.errors import ToolInvokeError
try:
with MCPClient(
self.server_url,
self.provider_id,
self.tenant_id,
authed=True,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
) as mcp_client:
tool_parameters = self._handle_none_parameter(tool_parameters)
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPAuthError as e:
raise ToolInvokeError("Please auth the tool first") from e
except MCPConnectionError as e:
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
except Exception as e:
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
result = self.invoke_remote_mcp_tool(tool_parameters)
# handle dify tool output
for content in result.content:
if isinstance(content, TextContent):
yield from self._process_text_content(content)
elif isinstance(content, ImageContent):
yield self._process_image_content(content)
# handle MCP structured output
if self.entity.output_schema and result.structuredContent:
for k, v in result.structuredContent.items():
yield self.create_variable_message(k, v)
def _process_text_content(self, content: TextContent) -> Generator[ToolInvokeMessage, None, None]:
"""Process text content and yield appropriate messages."""
try:
content_json = json.loads(content.text)
yield from self._process_json_content(content_json)
except json.JSONDecodeError:
yield self.create_text_message(content.text)
# Check if content looks like JSON before attempting to parse
text = content.text.strip()
if text and text[0] in ("{", "[") and text[-1] in ("}", "]"):
try:
content_json = json.loads(text)
yield from self._process_json_content(content_json)
return
except json.JSONDecodeError:
pass
# If not JSON or parsing failed, treat as plain text
yield self.create_text_message(content.text)
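A quick sketch of which inputs take each path under the bracket heuristic; the sample strings are invented:

for text in ('{"ok": true}', "[1, 2, 3]", "plain sentence", "{broken"):
    t = text.strip()
    looks_like_json = bool(t) and t[0] in ("{", "[") and t[-1] in ("}", "]")
    print(text, "->", "try JSON" if looks_like_json else "plain text")
# '{broken' fails the closing-bracket check, so it is emitted as plain
# text without paying for a json.loads() attempt.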
def _process_json_content(self, content_json: Any) -> Generator[ToolInvokeMessage, None, None]:
"""Process JSON content based on its type."""
@ -126,3 +119,44 @@ class MCPTool(Tool):
for key, value in parameter.items()
if value is not None and not (isinstance(value, str) and value.strip() == "")
}
def invoke_remote_mcp_tool(self, tool_parameters: dict[str, Any]) -> CallToolResult:
headers = self.headers.copy() if self.headers else {}
tool_parameters = self._handle_none_parameter(tool_parameters)
from sqlalchemy.orm import Session
from extensions.ext_database import db
from services.tools.mcp_tools_manage_service import MCPToolManageService
# Step 1: Load provider entity and credentials in a short-lived session
# This minimizes database connection hold time
with Session(db.engine, expire_on_commit=False) as session:
mcp_service = MCPToolManageService(session=session)
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
# Decrypt and prepare all credentials before closing session
server_url = provider_entity.decrypt_server_url()
headers = provider_entity.decrypt_headers()
# Try to get existing token and add to headers
if not headers:
tokens = provider_entity.retrieve_tokens()
if tokens and tokens.access_token:
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
# Step 2: Session is now closed, perform network operations without holding database connection
# MCPClientWithAuthRetry will create a new session lazily only if auth retry is needed
try:
with MCPClientWithAuthRetry(
server_url=server_url,
headers=headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
provider_entity=provider_entity,
) as mcp_client:
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPConnectionError as e:
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
except Exception as e:
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e

View File

@ -228,29 +228,38 @@ class ToolEngine:
"""
Handle tool response
"""
result = ""
parts: list[str] = []
json_parts: list[str] = []
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.TEXT:
result += cast(ToolInvokeMessage.TextMessage, response.message).text
parts.append(cast(ToolInvokeMessage.TextMessage, response.message).text)
elif response.type == ToolInvokeMessage.MessageType.LINK:
result += (
parts.append(
f"result link: {cast(ToolInvokeMessage.TextMessage, response.message).text}."
+ " please tell user to check it."
)
elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
result += (
parts.append(
"image has been created and sent to user already, "
+ "you do not need to create it, just tell the user to check it now."
)
elif response.type == ToolInvokeMessage.MessageType.JSON:
result += json.dumps(
safe_json_value(cast(ToolInvokeMessage.JsonMessage, response.message).json_object),
ensure_ascii=False,
json_parts.append(
json.dumps(
safe_json_value(cast(ToolInvokeMessage.JsonMessage, response.message).json_object),
ensure_ascii=False,
)
)
else:
result += str(response.message)
parts.append(str(response.message))
return result
# Add JSON parts, avoiding duplicates from text parts.
if json_parts:
existing_parts = set(parts)
parts.extend(p for p in json_parts if p not in existing_parts)
return "".join(parts)
@staticmethod
def _extract_tool_response_binary_and_text(

View File

@ -14,17 +14,32 @@ from sqlalchemy.orm import Session
from yarl import URL
import contexts
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.mcp_tool.tool import MCPTool
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool
from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.workflow.runtime.variable_pool import VariablePool
from extensions.ext_database import db
from models.provider_ids import ToolProviderID
from services.enterprise.plugin_manager_service import PluginCredentialType
from services.tools.mcp_tools_manage_service import MCPToolManageService
if TYPE_CHECKING:
from core.workflow.nodes.tool.entities import ToolEntity
from configs import dify_config
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import is_filtered
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool import Tool
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.builtin_tool.tool import BuiltinTool
@ -40,21 +55,11 @@ from core.tools.entities.tool_entities import (
ToolProviderType,
)
from core.tools.errors import ToolProviderNotFoundError
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.mcp_tool.tool import MCPTool
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.provider_ids import ToolProviderID
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
from services.enterprise.plugin_manager_service import PluginCredentialType
from services.tools.mcp_tools_manage_service import MCPToolManageService
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
from services.tools.tools_transform_service import ToolTransformService
if TYPE_CHECKING:
@ -719,7 +724,9 @@ class ToolManager:
)
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
if "mcp" in filters:
mcp_providers = MCPToolManageService.retrieve_mcp_tools(tenant_id, for_list=True)
with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session)
mcp_providers = mcp_service.list_providers(tenant_id=tenant_id, for_list=True)
for mcp_provider in mcp_providers:
result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider
@ -774,17 +781,12 @@ class ToolManager:
:return: the provider controller, the credentials
"""
provider: MCPToolProvider | None = (
db.session.query(MCPToolProvider)
.where(
MCPToolProvider.server_identifier == provider_id,
MCPToolProvider.tenant_id == tenant_id,
)
.first()
)
if provider is None:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session)
try:
provider = mcp_service.get_provider(server_identifier=provider_id, tenant_id=tenant_id)
except ValueError:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
controller = MCPToolProviderController.from_db(provider)
@ -922,16 +924,15 @@ class ToolManager:
@classmethod
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str] | str:
try:
mcp_provider: MCPToolProvider | None = (
db.session.query(MCPToolProvider)
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
.first()
)
if mcp_provider is None:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
return mcp_provider.provider_icon
with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session)
try:
mcp_provider = mcp_service.get_provider_entity(
provider_id=provider_id, tenant_id=tenant_id, by_server_id=True
)
return mcp_provider.provider_icon
except ValueError:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}

View File

@ -62,6 +62,11 @@ class ApiBasedToolSchemaParser:
root = root[ref]
interface["operation"]["parameters"][i] = root
for parameter in interface["operation"]["parameters"]:
# Handle complex type defaults that are not supported by PluginParameter
default_value = None
if "schema" in parameter and "default" in parameter["schema"]:
default_value = ApiBasedToolSchemaParser._sanitize_default_value(parameter["schema"]["default"])
tool_parameter = ToolParameter(
name=parameter["name"],
label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
@ -72,9 +77,7 @@ class ApiBasedToolSchemaParser:
required=parameter.get("required", False),
form=ToolParameter.ToolParameterForm.LLM,
llm_description=parameter.get("description"),
default=parameter["schema"]["default"]
if "schema" in parameter and "default" in parameter["schema"]
else None,
default=default_value,
placeholder=I18nObject(
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
),
@ -134,6 +137,11 @@ class ApiBasedToolSchemaParser:
required = body_schema.get("required", [])
properties = body_schema.get("properties", {})
for name, property in properties.items():
# Handle complex type defaults that are not supported by PluginParameter
default_value = ApiBasedToolSchemaParser._sanitize_default_value(
property.get("default", None)
)
tool = ToolParameter(
name=name,
label=I18nObject(en_US=name, zh_Hans=name),
@ -144,12 +152,11 @@ class ApiBasedToolSchemaParser:
required=name in required,
form=ToolParameter.ToolParameterForm.LLM,
llm_description=property.get("description", ""),
default=property.get("default", None),
default=default_value,
placeholder=I18nObject(
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
),
)
# check if there is a type
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
if typ:
@ -197,6 +204,22 @@ class ApiBasedToolSchemaParser:
return bundles
@staticmethod
def _sanitize_default_value(value):
"""
Sanitize default values for PluginParameter compatibility.
Complex types (list, dict) are converted to None to avoid validation errors.
Args:
value: The default value from OpenAPI schema
Returns:
None for complex types (list, dict), otherwise the original value
"""
if isinstance(value, (list, dict)):
return None
return value
@staticmethod
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
parameter = parameter or {}
@ -217,7 +240,11 @@ class ApiBasedToolSchemaParser:
return ToolParameter.ToolParameterType.STRING
elif typ == "array":
items = parameter.get("items") or parameter.get("schema", {}).get("items")
return ToolParameter.ToolParameterType.FILES if items and items.get("format") == "binary" else None
if items and items.get("format") == "binary":
return ToolParameter.ToolParameterType.FILES
else:
# For regular arrays, return ARRAY type instead of None
return ToolParameter.ToolParameterType.ARRAY
else:
return None
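Taken together, the two parser changes drop complex OpenAPI defaults before they reach `PluginParameter` validation and map plain arrays to a real parameter type instead of `None`. A rough illustration of both behaviors; the import paths are assumptions based on this repository's layout, and the property fragments are made up:

from core.tools.entities.tool_entities import ToolParameter  # path assumed
from core.tools.utils.parser import ApiBasedToolSchemaParser  # path assumed

# complex defaults are converted to None; scalars pass through untouched
assert ApiBasedToolSchemaParser._sanitize_default_value(["a", "b"]) is None
assert ApiBasedToolSchemaParser._sanitize_default_value({"k": 1}) is None
assert ApiBasedToolSchemaParser._sanitize_default_value("fallback") == "fallback"

# arrays of binary items still map to FILES; other arrays now map to ARRAY
# (assuming, as the surrounding parser does, that the type is read from the
# property's "type" field)
binary_array = {"type": "array", "items": {"type": "string", "format": "binary"}}
plain_array = {"type": "array", "items": {"type": "string"}}
assert ApiBasedToolSchemaParser._get_tool_parameter_type(binary_array) == ToolParameter.ToolParameterType.FILES
assert ApiBasedToolSchemaParser._get_tool_parameter_type(plain_array) == ToolParameter.ToolParameterType.ARRAY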

View File

@ -31,6 +31,7 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
VariableEntityType.CHECKBOX: ToolParameter.ToolParameterType.BOOLEAN,
VariableEntityType.FILE: ToolParameter.ToolParameterType.FILE,
VariableEntityType.FILE_LIST: ToolParameter.ToolParameterType.FILES,
}
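With the new entry, a CHECKBOX start variable surfaces as a boolean tool parameter. A minimal sketch of a lookup helper built on the mapping above; the helper itself and its error handling are assumptions, not part of the diff:

def to_parameter_type(variable_type: VariableEntityType) -> ToolParameter.ToolParameterType:
    """Resolve a start-form variable type via the mapping, failing loudly on gaps."""
    try:
        return VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable_type]
    except KeyError:
        raise ValueError(f"unsupported variable type: {variable_type}") from None

assert to_parameter_type(VariableEntityType.CHECKBOX) == ToolParameter.ToolParameterType.BOOLEAN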

View File

@ -1,13 +1,14 @@
import json
import logging
from collections.abc import Generator
from typing import Any
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
from flask import has_request_context
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import (
@ -49,6 +50,7 @@ class WorkflowTool(Tool):
self.workflow_entities = workflow_entities
self.workflow_call_depth = workflow_call_depth
self.label = label
self._latest_usage = LLMUsage.empty_usage()
super().__init__(entity=entity, runtime=runtime)
@ -84,10 +86,11 @@ class WorkflowTool(Tool):
assert self.runtime.invoke_from is not None
user = self._resolve_user(user_id=user_id)
if user is None:
raise ToolInvokeError("User not found")
self._latest_usage = LLMUsage.empty_usage()
result = generator.generate(
app_model=app,
workflow=workflow,
@ -111,9 +114,68 @@ class WorkflowTool(Tool):
for file in files:
yield self.create_file_message(file) # type: ignore
self._latest_usage = self._derive_usage_from_result(data)
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
yield self.create_json_message(outputs)
@property
def latest_usage(self) -> LLMUsage:
return self._latest_usage
@classmethod
def _derive_usage_from_result(cls, data: Mapping[str, Any]) -> LLMUsage:
usage_dict = cls._extract_usage_dict(data)
if usage_dict is not None:
return LLMUsage.from_metadata(cast(LLMUsageMetadata, dict(usage_dict)))
total_tokens = data.get("total_tokens")
total_price = data.get("total_price")
if total_tokens is None and total_price is None:
return LLMUsage.empty_usage()
usage_metadata: dict[str, Any] = {}
if total_tokens is not None:
try:
usage_metadata["total_tokens"] = int(str(total_tokens))
except (TypeError, ValueError):
pass
if total_price is not None:
usage_metadata["total_price"] = str(total_price)
currency = data.get("currency")
if currency is not None:
usage_metadata["currency"] = currency
if not usage_metadata:
return LLMUsage.empty_usage()
return LLMUsage.from_metadata(cast(LLMUsageMetadata, usage_metadata))
@classmethod
def _extract_usage_dict(cls, payload: Mapping[str, Any]) -> Mapping[str, Any] | None:
usage_candidate = payload.get("usage")
if isinstance(usage_candidate, Mapping):
return usage_candidate
metadata_candidate = payload.get("metadata")
if isinstance(metadata_candidate, Mapping):
usage_candidate = metadata_candidate.get("usage")
if isinstance(usage_candidate, Mapping):
return usage_candidate
for value in payload.values():
if isinstance(value, Mapping):
found = cls._extract_usage_dict(value)
if found is not None:
return found
elif isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
for item in value:
if isinstance(item, Mapping):
found = cls._extract_usage_dict(item)
if found is not None:
return found
return None
def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
"""
fork a new tool with metadata
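Because `_extract_usage_dict` walks mappings and sequences recursively, usage reported anywhere in a nested workflow result is found before the top-level `total_tokens` fallback is tried. A rough illustration with a made-up payload; the field values and the printed attribute are assumptions:

# hypothetical workflow result payload, shaped to match the extraction logic above
payload = {
    "outputs": {"answer": "done"},
    "metadata": {
        "usage": {
            "total_tokens": 128,
            "total_price": "0.0005",
            "currency": "USD",
        }
    },
}

usage = WorkflowTool._derive_usage_from_result(payload)
# metadata.usage is found by the recursive walk, so LLMUsage.from_metadata
# receives the full usage mapping rather than the top-level fallback fields
print(usage.total_tokens)  # expected: 128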

View File

@ -4,6 +4,7 @@ from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution
from .workflow_pause import WorkflowPauseEntity
__all__ = [
"AgentNodeStrategyInit",
@ -12,4 +13,5 @@ __all__ = [
"VariablePool",
"WorkflowExecution",
"WorkflowNodeExecution",
"WorkflowPauseEntity",
]

View File

@ -0,0 +1,49 @@
from enum import StrEnum, auto
from typing import Annotated, Any, ClassVar, TypeAlias
from pydantic import BaseModel, Discriminator, Tag
class _PauseReasonType(StrEnum):
HUMAN_INPUT_REQUIRED = auto()
SCHEDULED_PAUSE = auto()
class _PauseReasonBase(BaseModel):
TYPE: ClassVar[_PauseReasonType]
class HumanInputRequired(_PauseReasonBase):
TYPE = _PauseReasonType.HUMAN_INPUT_REQUIRED
class SchedulingPause(_PauseReasonBase):
TYPE = _PauseReasonType.SCHEDULED_PAUSE
message: str
def _get_pause_reason_discriminator(v: Any) -> _PauseReasonType | None:
if isinstance(v, _PauseReasonBase):
return v.TYPE
elif isinstance(v, dict):
reason_type_str = v.get("TYPE")
if reason_type_str is None:
return None
try:
reason_type = _PauseReasonType(reason_type_str)
except ValueError:
return None
return reason_type
else:
# return None if the discriminator value isn't found
return None
PauseReason: TypeAlias = Annotated[
(
Annotated[HumanInputRequired, Tag(_PauseReasonType.HUMAN_INPUT_REQUIRED)]
| Annotated[SchedulingPause, Tag(_PauseReasonType.SCHEDULED_PAUSE)]
),
Discriminator(_get_pause_reason_discriminator),
]
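The callable discriminator lets the union be validated from either model instances or plain dicts carrying a TYPE tag (the StrEnum values are the lowercased member names, e.g. "scheduled_pause"). A minimal round-trip sketch using a pydantic TypeAdapter; nothing here beyond the types above is part of the diff:

from pydantic import TypeAdapter

adapter = TypeAdapter(PauseReason)

# dict input: the discriminator reads the "TYPE" key and maps it to a tag
reason = adapter.validate_python({"TYPE": "scheduled_pause", "message": "waiting for the next window"})
assert isinstance(reason, SchedulingPause) and reason.message == "waiting for the next window"

# model input: the discriminator reads the class-level TYPE instead
assert isinstance(adapter.validate_python(HumanInputRequired()), HumanInputRequired)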

Some files were not shown because too many files have changed in this diff