mirror of
https://github.com/langgenius/dify.git
synced 2026-04-28 20:17:29 +08:00
Merge branch 'main' into feat/memory-orchestration-fed
This commit is contained in:
commit
38d27100ac
@ -11,7 +11,7 @@
|
|||||||
"nodeGypDependencies": true,
|
"nodeGypDependencies": true,
|
||||||
"version": "lts"
|
"version": "lts"
|
||||||
},
|
},
|
||||||
"ghcr.io/devcontainers-contrib/features/npm-package:1": {
|
"ghcr.io/devcontainers-extra/features/npm-package:1": {
|
||||||
"package": "typescript",
|
"package": "typescript",
|
||||||
"version": "latest"
|
"version": "latest"
|
||||||
},
|
},
|
||||||
|
|||||||
10
README.md
10
README.md
@ -63,7 +63,7 @@ Dify is an open-source platform for developing LLM applications. Its intuitive i
|
|||||||
> - CPU >= 2 Core
|
> - CPU >= 2 Core
|
||||||
> - RAM >= 4 GiB
|
> - RAM >= 4 GiB
|
||||||
|
|
||||||
</br>
|
<br/>
|
||||||
|
|
||||||
The easiest way to start the Dify server is through [Docker Compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine:
|
The easiest way to start the Dify server is through [Docker Compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine:
|
||||||
|
|
||||||
@ -109,15 +109,15 @@ All of Dify's offerings come with corresponding APIs, so you could effortlessly
|
|||||||
|
|
||||||
## Using Dify
|
## Using Dify
|
||||||
|
|
||||||
- **Cloud </br>**
|
- **Cloud <br/>**
|
||||||
We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan.
|
We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan.
|
||||||
|
|
||||||
- **Self-hosting Dify Community Edition</br>**
|
- **Self-hosting Dify Community Edition<br/>**
|
||||||
Quickly get Dify running in your environment with this [starter guide](#quick-start).
|
Quickly get Dify running in your environment with this [starter guide](#quick-start).
|
||||||
Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions.
|
Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions.
|
||||||
|
|
||||||
- **Dify for enterprise / organizations</br>**
|
- **Dify for enterprise / organizations<br/>**
|
||||||
We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss enterprise needs. </br>
|
We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss enterprise needs. <br/>
|
||||||
|
|
||||||
> For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one click. It's an affordable AMI offering with the option to create apps with custom logo and branding.
|
> For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one click. It's an affordable AMI offering with the option to create apps with custom logo and branding.
|
||||||
|
|
||||||
|
|||||||
@ -156,6 +156,9 @@ SUPABASE_URL=your-server-url
|
|||||||
# CORS configuration
|
# CORS configuration
|
||||||
WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
||||||
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
||||||
|
# Set COOKIE_DOMAIN when the console frontend and API are on different subdomains.
|
||||||
|
# Provide the registrable domain (e.g. example.com); leading dots are optional.
|
||||||
|
COOKIE_DOMAIN=
|
||||||
|
|
||||||
# Vector database configuration
|
# Vector database configuration
|
||||||
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
|
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
|
||||||
|
|||||||
@ -80,7 +80,7 @@
|
|||||||
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
|
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv run celery -A app.celery worker -P gevent -c 2 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation
|
uv run celery -A app.celery worker -P gevent -c 2 --loglevel INFO -Q dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation
|
||||||
```
|
```
|
||||||
|
|
||||||
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:
|
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:
|
||||||
|
|||||||
@ -321,6 +321,8 @@ def migrate_knowledge_vector_database():
|
|||||||
)
|
)
|
||||||
|
|
||||||
datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
|
datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
|
||||||
|
if not datasets.items:
|
||||||
|
break
|
||||||
except SQLAlchemyError:
|
except SQLAlchemyError:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|||||||
@ -337,6 +337,11 @@ class HttpConfig(BaseSettings):
|
|||||||
HTTP-related configurations for the application
|
HTTP-related configurations for the application
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
COOKIE_DOMAIN: str = Field(
|
||||||
|
description="Explicit cookie domain for console/service cookies when sharing across subdomains",
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
|
||||||
API_COMPRESSION_ENABLED: bool = Field(
|
API_COMPRESSION_ENABLED: bool = Field(
|
||||||
description="Enable or disable gzip compression for HTTP responses",
|
description="Enable or disable gzip compression for HTTP responses",
|
||||||
default=False,
|
default=False,
|
||||||
|
|||||||
@ -1,10 +1,9 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from decimal import Decimal
|
|
||||||
|
|
||||||
import pytz
|
import pytz
|
||||||
import sqlalchemy as sa
|
|
||||||
from flask import jsonify
|
from flask import jsonify
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
from controllers.console.app.wraps import get_app_model
|
from controllers.console.app.wraps import get_app_model
|
||||||
@ -14,10 +13,16 @@ from libs.helper import DatetimeString
|
|||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.enums import WorkflowRunTriggeredFrom
|
from models.enums import WorkflowRunTriggeredFrom
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
|
from repositories.factory import DifyAPIRepositoryFactory
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
|
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
|
||||||
class WorkflowDailyRunsStatistic(Resource):
|
class WorkflowDailyRunsStatistic(Resource):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
|
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||||
|
|
||||||
@api.doc("get_workflow_daily_runs_statistic")
|
@api.doc("get_workflow_daily_runs_statistic")
|
||||||
@api.doc(description="Get workflow daily runs statistics")
|
@api.doc(description="Get workflow daily runs statistics")
|
||||||
@api.doc(params={"app_id": "Application ID"})
|
@api.doc(params={"app_id": "Application ID"})
|
||||||
@ -37,57 +42,44 @@ class WorkflowDailyRunsStatistic(Resource):
|
|||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT
|
|
||||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
|
||||||
COUNT(id) AS runs
|
|
||||||
FROM
|
|
||||||
workflow_runs
|
|
||||||
WHERE
|
|
||||||
app_id = :app_id
|
|
||||||
AND triggered_from = :triggered_from"""
|
|
||||||
arg_dict = {
|
|
||||||
"tz": account.timezone,
|
|
||||||
"app_id": app_model.id,
|
|
||||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
|
|
||||||
}
|
|
||||||
assert account.timezone is not None
|
assert account.timezone is not None
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
|
start_date = None
|
||||||
|
end_date = None
|
||||||
|
|
||||||
if args["start"]:
|
if args["start"]:
|
||||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||||
start_datetime = start_datetime.replace(second=0)
|
start_datetime = start_datetime.replace(second=0)
|
||||||
|
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_datetime_timezone = timezone.localize(start_datetime)
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
start_date = start_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += " AND created_at >= :start"
|
|
||||||
arg_dict["start"] = start_datetime_utc
|
|
||||||
|
|
||||||
if args["end"]:
|
if args["end"]:
|
||||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||||
end_datetime = end_datetime.replace(second=0)
|
end_datetime = end_datetime.replace(second=0)
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
end_date = end_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += " AND created_at < :end"
|
response_data = self._workflow_run_repo.get_daily_runs_statistics(
|
||||||
arg_dict["end"] = end_datetime_utc
|
tenant_id=app_model.tenant_id,
|
||||||
|
app_id=app_model.id,
|
||||||
sql_query += " GROUP BY date ORDER BY date"
|
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||||
|
start_date=start_date,
|
||||||
response_data = []
|
end_date=end_date,
|
||||||
|
timezone=account.timezone,
|
||||||
with db.engine.begin() as conn:
|
)
|
||||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
|
||||||
for i in rs:
|
|
||||||
response_data.append({"date": str(i.date), "runs": i.runs})
|
|
||||||
|
|
||||||
return jsonify({"data": response_data})
|
return jsonify({"data": response_data})
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-terminals")
|
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-terminals")
|
||||||
class WorkflowDailyTerminalsStatistic(Resource):
|
class WorkflowDailyTerminalsStatistic(Resource):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
|
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||||
|
|
||||||
@api.doc("get_workflow_daily_terminals_statistic")
|
@api.doc("get_workflow_daily_terminals_statistic")
|
||||||
@api.doc(description="Get workflow daily terminals statistics")
|
@api.doc(description="Get workflow daily terminals statistics")
|
||||||
@api.doc(params={"app_id": "Application ID"})
|
@api.doc(params={"app_id": "Application ID"})
|
||||||
@ -107,57 +99,44 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
|||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT
|
|
||||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
|
||||||
COUNT(DISTINCT workflow_runs.created_by) AS terminal_count
|
|
||||||
FROM
|
|
||||||
workflow_runs
|
|
||||||
WHERE
|
|
||||||
app_id = :app_id
|
|
||||||
AND triggered_from = :triggered_from"""
|
|
||||||
arg_dict = {
|
|
||||||
"tz": account.timezone,
|
|
||||||
"app_id": app_model.id,
|
|
||||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
|
|
||||||
}
|
|
||||||
assert account.timezone is not None
|
assert account.timezone is not None
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
|
start_date = None
|
||||||
|
end_date = None
|
||||||
|
|
||||||
if args["start"]:
|
if args["start"]:
|
||||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||||
start_datetime = start_datetime.replace(second=0)
|
start_datetime = start_datetime.replace(second=0)
|
||||||
|
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_datetime_timezone = timezone.localize(start_datetime)
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
start_date = start_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += " AND created_at >= :start"
|
|
||||||
arg_dict["start"] = start_datetime_utc
|
|
||||||
|
|
||||||
if args["end"]:
|
if args["end"]:
|
||||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||||
end_datetime = end_datetime.replace(second=0)
|
end_datetime = end_datetime.replace(second=0)
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
end_date = end_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += " AND created_at < :end"
|
response_data = self._workflow_run_repo.get_daily_terminals_statistics(
|
||||||
arg_dict["end"] = end_datetime_utc
|
tenant_id=app_model.tenant_id,
|
||||||
|
app_id=app_model.id,
|
||||||
sql_query += " GROUP BY date ORDER BY date"
|
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||||
|
start_date=start_date,
|
||||||
response_data = []
|
end_date=end_date,
|
||||||
|
timezone=account.timezone,
|
||||||
with db.engine.begin() as conn:
|
)
|
||||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
|
||||||
for i in rs:
|
|
||||||
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
|
|
||||||
|
|
||||||
return jsonify({"data": response_data})
|
return jsonify({"data": response_data})
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/token-costs")
|
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/token-costs")
|
||||||
class WorkflowDailyTokenCostStatistic(Resource):
|
class WorkflowDailyTokenCostStatistic(Resource):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
|
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||||
|
|
||||||
@api.doc("get_workflow_daily_token_cost_statistic")
|
@api.doc("get_workflow_daily_token_cost_statistic")
|
||||||
@api.doc(description="Get workflow daily token cost statistics")
|
@api.doc(description="Get workflow daily token cost statistics")
|
||||||
@api.doc(params={"app_id": "Application ID"})
|
@api.doc(params={"app_id": "Application ID"})
|
||||||
@ -177,62 +156,44 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
|||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT
|
|
||||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
|
||||||
SUM(workflow_runs.total_tokens) AS token_count
|
|
||||||
FROM
|
|
||||||
workflow_runs
|
|
||||||
WHERE
|
|
||||||
app_id = :app_id
|
|
||||||
AND triggered_from = :triggered_from"""
|
|
||||||
arg_dict = {
|
|
||||||
"tz": account.timezone,
|
|
||||||
"app_id": app_model.id,
|
|
||||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
|
|
||||||
}
|
|
||||||
assert account.timezone is not None
|
assert account.timezone is not None
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
|
start_date = None
|
||||||
|
end_date = None
|
||||||
|
|
||||||
if args["start"]:
|
if args["start"]:
|
||||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||||
start_datetime = start_datetime.replace(second=0)
|
start_datetime = start_datetime.replace(second=0)
|
||||||
|
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_datetime_timezone = timezone.localize(start_datetime)
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
start_date = start_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += " AND created_at >= :start"
|
|
||||||
arg_dict["start"] = start_datetime_utc
|
|
||||||
|
|
||||||
if args["end"]:
|
if args["end"]:
|
||||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||||
end_datetime = end_datetime.replace(second=0)
|
end_datetime = end_datetime.replace(second=0)
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
end_date = end_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += " AND created_at < :end"
|
response_data = self._workflow_run_repo.get_daily_token_cost_statistics(
|
||||||
arg_dict["end"] = end_datetime_utc
|
tenant_id=app_model.tenant_id,
|
||||||
|
app_id=app_model.id,
|
||||||
sql_query += " GROUP BY date ORDER BY date"
|
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||||
|
start_date=start_date,
|
||||||
response_data = []
|
end_date=end_date,
|
||||||
|
timezone=account.timezone,
|
||||||
with db.engine.begin() as conn:
|
)
|
||||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
|
||||||
for i in rs:
|
|
||||||
response_data.append(
|
|
||||||
{
|
|
||||||
"date": str(i.date),
|
|
||||||
"token_count": i.token_count,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return jsonify({"data": response_data})
|
return jsonify({"data": response_data})
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/average-app-interactions")
|
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/average-app-interactions")
|
||||||
class WorkflowAverageAppInteractionStatistic(Resource):
|
class WorkflowAverageAppInteractionStatistic(Resource):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
|
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||||
|
|
||||||
@api.doc("get_workflow_average_app_interaction_statistic")
|
@api.doc("get_workflow_average_app_interaction_statistic")
|
||||||
@api.doc(description="Get workflow average app interaction statistics")
|
@api.doc(description="Get workflow average app interaction statistics")
|
||||||
@api.doc(params={"app_id": "Application ID"})
|
@api.doc(params={"app_id": "Application ID"})
|
||||||
@ -252,67 +213,32 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
|||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT
|
|
||||||
AVG(sub.interactions) AS interactions,
|
|
||||||
sub.date
|
|
||||||
FROM
|
|
||||||
(
|
|
||||||
SELECT
|
|
||||||
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
|
||||||
c.created_by,
|
|
||||||
COUNT(c.id) AS interactions
|
|
||||||
FROM
|
|
||||||
workflow_runs c
|
|
||||||
WHERE
|
|
||||||
c.app_id = :app_id
|
|
||||||
AND c.triggered_from = :triggered_from
|
|
||||||
{{start}}
|
|
||||||
{{end}}
|
|
||||||
GROUP BY
|
|
||||||
date, c.created_by
|
|
||||||
) sub
|
|
||||||
GROUP BY
|
|
||||||
sub.date"""
|
|
||||||
arg_dict = {
|
|
||||||
"tz": account.timezone,
|
|
||||||
"app_id": app_model.id,
|
|
||||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
|
|
||||||
}
|
|
||||||
assert account.timezone is not None
|
assert account.timezone is not None
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
|
start_date = None
|
||||||
|
end_date = None
|
||||||
|
|
||||||
if args["start"]:
|
if args["start"]:
|
||||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||||
start_datetime = start_datetime.replace(second=0)
|
start_datetime = start_datetime.replace(second=0)
|
||||||
|
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_datetime_timezone = timezone.localize(start_datetime)
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
start_date = start_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query = sql_query.replace("{{start}}", " AND c.created_at >= :start")
|
|
||||||
arg_dict["start"] = start_datetime_utc
|
|
||||||
else:
|
|
||||||
sql_query = sql_query.replace("{{start}}", "")
|
|
||||||
|
|
||||||
if args["end"]:
|
if args["end"]:
|
||||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||||
end_datetime = end_datetime.replace(second=0)
|
end_datetime = end_datetime.replace(second=0)
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
end_date = end_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query = sql_query.replace("{{end}}", " AND c.created_at < :end")
|
response_data = self._workflow_run_repo.get_average_app_interaction_statistics(
|
||||||
arg_dict["end"] = end_datetime_utc
|
tenant_id=app_model.tenant_id,
|
||||||
else:
|
app_id=app_model.id,
|
||||||
sql_query = sql_query.replace("{{end}}", "")
|
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||||
|
start_date=start_date,
|
||||||
response_data = []
|
end_date=end_date,
|
||||||
|
timezone=account.timezone,
|
||||||
with db.engine.begin() as conn:
|
)
|
||||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
|
||||||
for i in rs:
|
|
||||||
response_data.append(
|
|
||||||
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
|
|
||||||
)
|
|
||||||
|
|
||||||
return jsonify({"data": response_data})
|
return jsonify({"data": response_data})
|
||||||
|
|||||||
@ -29,6 +29,7 @@ from libs.token import (
|
|||||||
clear_access_token_from_cookie,
|
clear_access_token_from_cookie,
|
||||||
clear_csrf_token_from_cookie,
|
clear_csrf_token_from_cookie,
|
||||||
clear_refresh_token_from_cookie,
|
clear_refresh_token_from_cookie,
|
||||||
|
extract_refresh_token,
|
||||||
set_access_token_to_cookie,
|
set_access_token_to_cookie,
|
||||||
set_csrf_token_to_cookie,
|
set_csrf_token_to_cookie,
|
||||||
set_refresh_token_to_cookie,
|
set_refresh_token_to_cookie,
|
||||||
@ -270,7 +271,7 @@ class EmailCodeLoginApi(Resource):
|
|||||||
class RefreshTokenApi(Resource):
|
class RefreshTokenApi(Resource):
|
||||||
def post(self):
|
def post(self):
|
||||||
# Get refresh token from cookie instead of request body
|
# Get refresh token from cookie instead of request body
|
||||||
refresh_token = request.cookies.get("refresh_token")
|
refresh_token = extract_refresh_token(request)
|
||||||
|
|
||||||
if not refresh_token:
|
if not refresh_token:
|
||||||
return {"result": "fail", "message": "No refresh token provided"}, 401
|
return {"result": "fail", "message": "No refresh token provided"}, 401
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from flask_restx import ( # type: ignore
|
|||||||
)
|
)
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
from controllers.console import console_ns
|
from controllers.console import api, console_ns
|
||||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_user, login_required
|
||||||
@ -12,9 +12,17 @@ from models import Account
|
|||||||
from models.dataset import Pipeline
|
from models.dataset import Pipeline
|
||||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||||
|
|
||||||
|
parser = (
|
||||||
|
reqparse.RequestParser()
|
||||||
|
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||||
|
.add_argument("credential_id", type=str, required=False, location="json")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
|
||||||
class DataSourceContentPreviewApi(Resource):
|
class DataSourceContentPreviewApi(Resource):
|
||||||
|
@api.expect(parser)
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@ -26,12 +34,6 @@ class DataSourceContentPreviewApi(Resource):
|
|||||||
if not isinstance(current_user, Account):
|
if not isinstance(current_user, Account):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
|
||||||
.add_argument("credential_id", type=str, required=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
inputs = args.get("inputs")
|
inputs = args.get("inputs")
|
||||||
|
|||||||
@ -39,6 +39,7 @@ class FileApi(Resource):
|
|||||||
return {
|
return {
|
||||||
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
|
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
|
||||||
"batch_count_limit": dify_config.UPLOAD_FILE_BATCH_LIMIT,
|
"batch_count_limit": dify_config.UPLOAD_FILE_BATCH_LIMIT,
|
||||||
|
"file_upload_limit": dify_config.BATCH_UPLOAD_LIMIT,
|
||||||
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
|
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
|
||||||
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
||||||
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from flask_restx import (
|
|||||||
Resource,
|
Resource,
|
||||||
reqparse,
|
reqparse,
|
||||||
)
|
)
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
@ -15,20 +16,21 @@ from controllers.console.wraps import (
|
|||||||
enterprise_license_required,
|
enterprise_license_required,
|
||||||
setup_required,
|
setup_required,
|
||||||
)
|
)
|
||||||
|
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
|
||||||
from core.mcp.auth.auth_flow import auth, handle_callback
|
from core.mcp.auth.auth_flow import auth, handle_callback
|
||||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
|
||||||
from core.mcp.error import MCPAuthError, MCPError
|
|
||||||
from core.mcp.mcp_client import MCPClient
|
from core.mcp.mcp_client import MCPClient
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.plugin.impl.oauth import OAuthHandler
|
from core.plugin.impl.oauth import OAuthHandler
|
||||||
from core.tools.entities.tool_entities import CredentialType
|
from core.tools.entities.tool_entities import CredentialType
|
||||||
|
from extensions.ext_database import db
|
||||||
from libs.helper import StrLen, alphanumeric, uuid_value
|
from libs.helper import StrLen, alphanumeric, uuid_value
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.provider_ids import ToolProviderID
|
from models.provider_ids import ToolProviderID
|
||||||
from services.plugin.oauth_service import OAuthProxyService
|
from services.plugin.oauth_service import OAuthProxyService
|
||||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
from services.tools.mcp_tools_manage_service import MCPToolManageService, OAuthDataType
|
||||||
from services.tools.tool_labels_service import ToolLabelsService
|
from services.tools.tool_labels_service import ToolLabelsService
|
||||||
from services.tools.tools_manage_service import ToolCommonService
|
from services.tools.tools_manage_service import ToolCommonService
|
||||||
from services.tools.tools_transform_service import ToolTransformService
|
from services.tools.tools_transform_service import ToolTransformService
|
||||||
@ -42,7 +44,9 @@ def is_valid_url(url: str) -> bool:
|
|||||||
try:
|
try:
|
||||||
parsed = urlparse(url)
|
parsed = urlparse(url)
|
||||||
return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"]
|
return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"]
|
||||||
except Exception:
|
except (ValueError, TypeError):
|
||||||
|
# ValueError: Invalid URL format
|
||||||
|
# TypeError: url is not a string
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
@ -886,29 +890,34 @@ class ToolProviderMCPApi(Resource):
|
|||||||
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
||||||
.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
|
.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
|
||||||
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
||||||
.add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30)
|
.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
|
||||||
.add_argument("sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300)
|
|
||||||
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
|
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
|
||||||
|
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
user, tenant_id = current_account_with_tenant()
|
user, tenant_id = current_account_with_tenant()
|
||||||
if not is_valid_url(args["server_url"]):
|
|
||||||
raise ValueError("Server URL is not valid.")
|
# Parse and validate models
|
||||||
return jsonable_encoder(
|
configuration = MCPConfiguration.model_validate(args["configuration"])
|
||||||
MCPToolManageService.create_mcp_provider(
|
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||||
|
|
||||||
|
# Create provider
|
||||||
|
with Session(db.engine) as session, session.begin():
|
||||||
|
service = MCPToolManageService(session=session)
|
||||||
|
result = service.create_provider(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
|
user_id=user.id,
|
||||||
server_url=args["server_url"],
|
server_url=args["server_url"],
|
||||||
name=args["name"],
|
name=args["name"],
|
||||||
icon=args["icon"],
|
icon=args["icon"],
|
||||||
icon_type=args["icon_type"],
|
icon_type=args["icon_type"],
|
||||||
icon_background=args["icon_background"],
|
icon_background=args["icon_background"],
|
||||||
user_id=user.id,
|
|
||||||
server_identifier=args["server_identifier"],
|
server_identifier=args["server_identifier"],
|
||||||
timeout=args["timeout"],
|
|
||||||
sse_read_timeout=args["sse_read_timeout"],
|
|
||||||
headers=args["headers"],
|
headers=args["headers"],
|
||||||
|
configuration=configuration,
|
||||||
|
authentication=authentication,
|
||||||
)
|
)
|
||||||
)
|
return jsonable_encoder(result)
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -923,31 +932,43 @@ class ToolProviderMCPApi(Resource):
|
|||||||
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||||
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||||
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
||||||
.add_argument("timeout", type=float, required=False, nullable=True, location="json")
|
.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
|
||||||
.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
|
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
|
||||||
.add_argument("headers", type=dict, required=False, nullable=True, location="json")
|
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if not is_valid_url(args["server_url"]):
|
configuration = MCPConfiguration.model_validate(args["configuration"])
|
||||||
if "[__HIDDEN__]" in args["server_url"]:
|
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||||
pass
|
|
||||||
else:
|
|
||||||
raise ValueError("Server URL is not valid.")
|
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
MCPToolManageService.update_mcp_provider(
|
|
||||||
tenant_id=current_tenant_id,
|
# Step 1: Validate server URL change if needed (includes URL format validation and network operation)
|
||||||
provider_id=args["provider_id"],
|
validation_result = None
|
||||||
server_url=args["server_url"],
|
with Session(db.engine) as session:
|
||||||
name=args["name"],
|
service = MCPToolManageService(session=session)
|
||||||
icon=args["icon"],
|
validation_result = service.validate_server_url_change(
|
||||||
icon_type=args["icon_type"],
|
tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"]
|
||||||
icon_background=args["icon_background"],
|
)
|
||||||
server_identifier=args["server_identifier"],
|
|
||||||
timeout=args.get("timeout"),
|
# No need to check for errors here, exceptions will be raised directly
|
||||||
sse_read_timeout=args.get("sse_read_timeout"),
|
|
||||||
headers=args.get("headers"),
|
# Step 2: Perform database update in a transaction
|
||||||
)
|
with Session(db.engine) as session, session.begin():
|
||||||
return {"result": "success"}
|
service = MCPToolManageService(session=session)
|
||||||
|
service.update_provider(
|
||||||
|
tenant_id=current_tenant_id,
|
||||||
|
provider_id=args["provider_id"],
|
||||||
|
server_url=args["server_url"],
|
||||||
|
name=args["name"],
|
||||||
|
icon=args["icon"],
|
||||||
|
icon_type=args["icon_type"],
|
||||||
|
icon_background=args["icon_background"],
|
||||||
|
server_identifier=args["server_identifier"],
|
||||||
|
headers=args["headers"],
|
||||||
|
configuration=configuration,
|
||||||
|
authentication=authentication,
|
||||||
|
validation_result=validation_result,
|
||||||
|
)
|
||||||
|
return {"result": "success"}
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -958,8 +979,11 @@ class ToolProviderMCPApi(Resource):
|
|||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
MCPToolManageService.delete_mcp_tool(tenant_id=current_tenant_id, provider_id=args["provider_id"])
|
|
||||||
return {"result": "success"}
|
with Session(db.engine) as session, session.begin():
|
||||||
|
service = MCPToolManageService(session=session)
|
||||||
|
service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])
|
||||||
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/workspaces/current/tool-provider/mcp/auth")
|
@console_ns.route("/workspaces/current/tool-provider/mcp/auth")
|
||||||
@ -976,37 +1000,53 @@ class ToolMCPAuthApi(Resource):
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
provider_id = args["provider_id"]
|
provider_id = args["provider_id"]
|
||||||
_, tenant_id = current_account_with_tenant()
|
_, tenant_id = current_account_with_tenant()
|
||||||
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
|
||||||
if not provider:
|
|
||||||
raise ValueError("provider not found")
|
|
||||||
try:
|
|
||||||
with MCPClient(
|
|
||||||
provider.decrypted_server_url,
|
|
||||||
provider_id,
|
|
||||||
tenant_id,
|
|
||||||
authed=False,
|
|
||||||
authorization_code=args["authorization_code"],
|
|
||||||
for_list=True,
|
|
||||||
headers=provider.decrypted_headers,
|
|
||||||
timeout=provider.timeout,
|
|
||||||
sse_read_timeout=provider.sse_read_timeout,
|
|
||||||
):
|
|
||||||
MCPToolManageService.update_mcp_provider_credentials(
|
|
||||||
mcp_provider=provider,
|
|
||||||
credentials=provider.decrypted_credentials,
|
|
||||||
authed=True,
|
|
||||||
)
|
|
||||||
return {"result": "success"}
|
|
||||||
|
|
||||||
except MCPAuthError:
|
with Session(db.engine) as session, session.begin():
|
||||||
auth_provider = OAuthClientProvider(provider_id, tenant_id, for_list=True)
|
service = MCPToolManageService(session=session)
|
||||||
return auth(auth_provider, provider.decrypted_server_url, args["authorization_code"])
|
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||||
|
if not db_provider:
|
||||||
|
raise ValueError("provider not found")
|
||||||
|
|
||||||
|
# Convert to entity
|
||||||
|
provider_entity = db_provider.to_entity()
|
||||||
|
server_url = provider_entity.decrypt_server_url()
|
||||||
|
headers = provider_entity.decrypt_authentication()
|
||||||
|
|
||||||
|
# Try to connect without active transaction
|
||||||
|
try:
|
||||||
|
# Use MCPClientWithAuthRetry to handle authentication automatically
|
||||||
|
with MCPClient(
|
||||||
|
server_url=server_url,
|
||||||
|
headers=headers,
|
||||||
|
timeout=provider_entity.timeout,
|
||||||
|
sse_read_timeout=provider_entity.sse_read_timeout,
|
||||||
|
):
|
||||||
|
# Update credentials in new transaction
|
||||||
|
with Session(db.engine) as session, session.begin():
|
||||||
|
service = MCPToolManageService(session=session)
|
||||||
|
service.update_provider_credentials(
|
||||||
|
provider_id=provider_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
credentials=provider_entity.credentials,
|
||||||
|
authed=True,
|
||||||
|
)
|
||||||
|
return {"result": "success"}
|
||||||
|
except MCPAuthError as e:
|
||||||
|
try:
|
||||||
|
auth_result = auth(provider_entity, args.get("authorization_code"))
|
||||||
|
with Session(db.engine) as session, session.begin():
|
||||||
|
service = MCPToolManageService(session=session)
|
||||||
|
response = service.execute_auth_actions(auth_result)
|
||||||
|
return response
|
||||||
|
except MCPRefreshTokenError as e:
|
||||||
|
with Session(db.engine) as session, session.begin():
|
||||||
|
service = MCPToolManageService(session=session)
|
||||||
|
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
|
||||||
|
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
|
||||||
except MCPError as e:
|
except MCPError as e:
|
||||||
MCPToolManageService.update_mcp_provider_credentials(
|
with Session(db.engine) as session, session.begin():
|
||||||
mcp_provider=provider,
|
service = MCPToolManageService(session=session)
|
||||||
credentials={},
|
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
|
||||||
authed=False,
|
|
||||||
)
|
|
||||||
raise ValueError(f"Failed to connect to MCP server: {e}") from e
|
raise ValueError(f"Failed to connect to MCP server: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
@ -1017,8 +1057,10 @@ class ToolMCPDetailApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider_id):
|
def get(self, provider_id):
|
||||||
_, tenant_id = current_account_with_tenant()
|
_, tenant_id = current_account_with_tenant()
|
||||||
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
with Session(db.engine) as session, session.begin():
|
||||||
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
|
service = MCPToolManageService(session=session)
|
||||||
|
provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||||
|
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/workspaces/current/tools/mcp")
|
@console_ns.route("/workspaces/current/tools/mcp")
|
||||||
@ -1029,9 +1071,12 @@ class ToolMCPListAllApi(Resource):
|
|||||||
def get(self):
|
def get(self):
|
||||||
_, tenant_id = current_account_with_tenant()
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id)
|
with Session(db.engine) as session, session.begin():
|
||||||
|
service = MCPToolManageService(session=session)
|
||||||
|
# Skip sensitive data decryption for list view to improve performance
|
||||||
|
tools = service.list_providers(tenant_id=tenant_id, include_sensitive=False)
|
||||||
|
|
||||||
return [tool.to_dict() for tool in tools]
|
return [tool.to_dict() for tool in tools]
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/workspaces/current/tool-provider/mcp/update/<path:provider_id>")
|
@console_ns.route("/workspaces/current/tool-provider/mcp/update/<path:provider_id>")
|
||||||
@ -1041,11 +1086,13 @@ class ToolMCPUpdateApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider_id):
|
def get(self, provider_id):
|
||||||
_, tenant_id = current_account_with_tenant()
|
_, tenant_id = current_account_with_tenant()
|
||||||
tools = MCPToolManageService.list_mcp_tool_from_remote_server(
|
with Session(db.engine) as session, session.begin():
|
||||||
tenant_id=tenant_id,
|
service = MCPToolManageService(session=session)
|
||||||
provider_id=provider_id,
|
tools = service.list_provider_tools(
|
||||||
)
|
tenant_id=tenant_id,
|
||||||
return jsonable_encoder(tools)
|
provider_id=provider_id,
|
||||||
|
)
|
||||||
|
return jsonable_encoder(tools)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/mcp/oauth/callback")
|
@console_ns.route("/mcp/oauth/callback")
|
||||||
@ -1059,5 +1106,15 @@ class ToolMCPCallbackApi(Resource):
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
state_key = args["state"]
|
state_key = args["state"]
|
||||||
authorization_code = args["code"]
|
authorization_code = args["code"]
|
||||||
handle_callback(state_key, authorization_code)
|
|
||||||
|
# Create service instance for handle_callback
|
||||||
|
with Session(db.engine) as session, session.begin():
|
||||||
|
mcp_service = MCPToolManageService(session=session)
|
||||||
|
# handle_callback now returns state data and tokens
|
||||||
|
state_data, tokens = handle_callback(state_key, authorization_code)
|
||||||
|
# Save tokens using the service layer
|
||||||
|
mcp_service.save_oauth_data(
|
||||||
|
state_data.provider_id, state_data.tenant_id, tokens.model_dump(), OAuthDataType.TOKENS
|
||||||
|
)
|
||||||
|
|
||||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||||
|
|||||||
@ -14,10 +14,25 @@ from services.file_service import FileService
|
|||||||
|
|
||||||
@files_ns.route("/<uuid:file_id>/image-preview")
|
@files_ns.route("/<uuid:file_id>/image-preview")
|
||||||
class ImagePreviewApi(Resource):
|
class ImagePreviewApi(Resource):
|
||||||
"""
|
"""Deprecated endpoint for retrieving image previews."""
|
||||||
Deprecated
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
@files_ns.doc("get_image_preview")
|
||||||
|
@files_ns.doc(description="Retrieve a signed image preview for a file")
|
||||||
|
@files_ns.doc(
|
||||||
|
params={
|
||||||
|
"file_id": "ID of the file to preview",
|
||||||
|
"timestamp": "Unix timestamp used in the signature",
|
||||||
|
"nonce": "Random string used in the signature",
|
||||||
|
"sign": "HMAC signature verifying the request",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@files_ns.doc(
|
||||||
|
responses={
|
||||||
|
200: "Image preview returned successfully",
|
||||||
|
400: "Missing or invalid signature parameters",
|
||||||
|
415: "Unsupported file type",
|
||||||
|
}
|
||||||
|
)
|
||||||
def get(self, file_id):
|
def get(self, file_id):
|
||||||
file_id = str(file_id)
|
file_id = str(file_id)
|
||||||
|
|
||||||
@ -43,6 +58,25 @@ class ImagePreviewApi(Resource):
|
|||||||
|
|
||||||
@files_ns.route("/<uuid:file_id>/file-preview")
|
@files_ns.route("/<uuid:file_id>/file-preview")
|
||||||
class FilePreviewApi(Resource):
|
class FilePreviewApi(Resource):
|
||||||
|
@files_ns.doc("get_file_preview")
|
||||||
|
@files_ns.doc(description="Download a file preview or attachment using signed parameters")
|
||||||
|
@files_ns.doc(
|
||||||
|
params={
|
||||||
|
"file_id": "ID of the file to preview",
|
||||||
|
"timestamp": "Unix timestamp used in the signature",
|
||||||
|
"nonce": "Random string used in the signature",
|
||||||
|
"sign": "HMAC signature verifying the request",
|
||||||
|
"as_attachment": "Whether to download the file as an attachment",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@files_ns.doc(
|
||||||
|
responses={
|
||||||
|
200: "File stream returned successfully",
|
||||||
|
400: "Missing or invalid signature parameters",
|
||||||
|
404: "File not found",
|
||||||
|
415: "Unsupported file type",
|
||||||
|
}
|
||||||
|
)
|
||||||
def get(self, file_id):
|
def get(self, file_id):
|
||||||
file_id = str(file_id)
|
file_id = str(file_id)
|
||||||
|
|
||||||
@ -101,6 +135,20 @@ class FilePreviewApi(Resource):
|
|||||||
|
|
||||||
@files_ns.route("/workspaces/<uuid:workspace_id>/webapp-logo")
|
@files_ns.route("/workspaces/<uuid:workspace_id>/webapp-logo")
|
||||||
class WorkspaceWebappLogoApi(Resource):
|
class WorkspaceWebappLogoApi(Resource):
|
||||||
|
@files_ns.doc("get_workspace_webapp_logo")
|
||||||
|
@files_ns.doc(description="Fetch the custom webapp logo for a workspace")
|
||||||
|
@files_ns.doc(
|
||||||
|
params={
|
||||||
|
"workspace_id": "Workspace identifier",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@files_ns.doc(
|
||||||
|
responses={
|
||||||
|
200: "Logo returned successfully",
|
||||||
|
404: "Webapp logo not configured",
|
||||||
|
415: "Unsupported file type",
|
||||||
|
}
|
||||||
|
)
|
||||||
def get(self, workspace_id):
|
def get(self, workspace_id):
|
||||||
workspace_id = str(workspace_id)
|
workspace_id = str(workspace_id)
|
||||||
|
|
||||||
|
|||||||
@ -13,6 +13,26 @@ from extensions.ext_database import db as global_db
|
|||||||
|
|
||||||
@files_ns.route("/tools/<uuid:file_id>.<string:extension>")
|
@files_ns.route("/tools/<uuid:file_id>.<string:extension>")
|
||||||
class ToolFileApi(Resource):
|
class ToolFileApi(Resource):
|
||||||
|
@files_ns.doc("get_tool_file")
|
||||||
|
@files_ns.doc(description="Download a tool file by ID using signed parameters")
|
||||||
|
@files_ns.doc(
|
||||||
|
params={
|
||||||
|
"file_id": "Tool file identifier",
|
||||||
|
"extension": "Expected file extension",
|
||||||
|
"timestamp": "Unix timestamp used in the signature",
|
||||||
|
"nonce": "Random string used in the signature",
|
||||||
|
"sign": "HMAC signature verifying the request",
|
||||||
|
"as_attachment": "Whether to download the file as an attachment",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@files_ns.doc(
|
||||||
|
responses={
|
||||||
|
200: "Tool file stream returned successfully",
|
||||||
|
403: "Forbidden - invalid signature",
|
||||||
|
404: "File not found",
|
||||||
|
415: "Unsupported file type",
|
||||||
|
}
|
||||||
|
)
|
||||||
def get(self, file_id, extension):
|
def get(self, file_id, extension):
|
||||||
file_id = str(file_id)
|
file_id = str(file_id)
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from dataclasses import dataclass
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, NewType, Union
|
from typing import Any, NewType, Union
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||||
from core.app.entities.queue_entities import (
|
from core.app.entities.queue_entities import (
|
||||||
QueueAgentLogEvent,
|
QueueAgentLogEvent,
|
||||||
QueueIterationCompletedEvent,
|
QueueIterationCompletedEvent,
|
||||||
@ -51,7 +51,7 @@ from core.workflow.workflow_entry import WorkflowEntry
|
|||||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from models import Account, EndUser
|
from models import Account, EndUser
|
||||||
from services.variable_truncator import VariableTruncator
|
from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator
|
||||||
|
|
||||||
NodeExecutionId = NewType("NodeExecutionId", str)
|
NodeExecutionId = NewType("NodeExecutionId", str)
|
||||||
|
|
||||||
@ -70,6 +70,8 @@ class _NodeSnapshot:
|
|||||||
|
|
||||||
|
|
||||||
class WorkflowResponseConverter:
|
class WorkflowResponseConverter:
|
||||||
|
_truncator: BaseTruncator
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@ -81,7 +83,13 @@ class WorkflowResponseConverter:
|
|||||||
self._user = user
|
self._user = user
|
||||||
self._system_variables = system_variables
|
self._system_variables = system_variables
|
||||||
self._workflow_inputs = self._prepare_workflow_inputs()
|
self._workflow_inputs = self._prepare_workflow_inputs()
|
||||||
self._truncator = VariableTruncator.default()
|
|
||||||
|
# Disable truncation for SERVICE_API calls to keep backward compatibility.
|
||||||
|
if application_generate_entity.invoke_from == InvokeFrom.SERVICE_API:
|
||||||
|
self._truncator = DummyVariableTruncator()
|
||||||
|
else:
|
||||||
|
self._truncator = VariableTruncator.default()
|
||||||
|
|
||||||
self._node_snapshots: dict[NodeExecutionId, _NodeSnapshot] = {}
|
self._node_snapshots: dict[NodeExecutionId, _NodeSnapshot] = {}
|
||||||
self._workflow_execution_id: str | None = None
|
self._workflow_execution_id: str | None = None
|
||||||
self._workflow_started_at: datetime | None = None
|
self._workflow_started_at: datetime | None = None
|
||||||
|
|||||||
328
api/core/entities/mcp_provider.py
Normal file
328
api/core/entities/mcp_provider.py
Normal file
@ -0,0 +1,328 @@
|
|||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from core.entities.provider_entities import BasicProviderConfig
|
||||||
|
from core.file import helpers as file_helpers
|
||||||
|
from core.helper import encrypter
|
||||||
|
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||||
|
from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens
|
||||||
|
from core.tools.entities.common_entities import I18nObject
|
||||||
|
from core.tools.entities.tool_entities import ToolProviderType
|
||||||
|
from core.tools.utils.encryption import create_provider_encrypter
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from models.tools import MCPToolProvider
|
||||||
|
|
||||||
|
# Constants
|
||||||
|
CLIENT_NAME = "Dify"
|
||||||
|
CLIENT_URI = "https://github.com/langgenius/dify"
|
||||||
|
DEFAULT_TOKEN_TYPE = "Bearer"
|
||||||
|
DEFAULT_EXPIRES_IN = 3600
|
||||||
|
MASK_CHAR = "*"
|
||||||
|
MIN_UNMASK_LENGTH = 6
|
||||||
|
|
||||||
|
|
||||||
|
class MCPSupportGrantType(StrEnum):
|
||||||
|
"""The supported grant types for MCP"""
|
||||||
|
|
||||||
|
AUTHORIZATION_CODE = "authorization_code"
|
||||||
|
CLIENT_CREDENTIALS = "client_credentials"
|
||||||
|
REFRESH_TOKEN = "refresh_token"
|
||||||
|
|
||||||
|
|
||||||
|
class MCPAuthentication(BaseModel):
|
||||||
|
client_id: str
|
||||||
|
client_secret: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class MCPConfiguration(BaseModel):
|
||||||
|
timeout: float = 30
|
||||||
|
sse_read_timeout: float = 300
|
||||||
|
|
||||||
|
|
||||||
|
class MCPProviderEntity(BaseModel):
|
||||||
|
"""MCP Provider domain entity for business logic operations"""
|
||||||
|
|
||||||
|
# Basic identification
|
||||||
|
id: str
|
||||||
|
provider_id: str # server_identifier
|
||||||
|
name: str
|
||||||
|
tenant_id: str
|
||||||
|
user_id: str
|
||||||
|
|
||||||
|
# Server connection info
|
||||||
|
server_url: str # encrypted URL
|
||||||
|
headers: dict[str, str] # encrypted headers
|
||||||
|
timeout: float
|
||||||
|
sse_read_timeout: float
|
||||||
|
|
||||||
|
# Authentication related
|
||||||
|
authed: bool
|
||||||
|
credentials: dict[str, Any] # encrypted credentials
|
||||||
|
code_verifier: str | None = None # for OAuth
|
||||||
|
|
||||||
|
# Tools and display info
|
||||||
|
tools: list[dict[str, Any]] # parsed tools list
|
||||||
|
icon: str | dict[str, str] # parsed icon
|
||||||
|
|
||||||
|
# Timestamps
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_db_model(cls, db_provider: "MCPToolProvider") -> "MCPProviderEntity":
|
||||||
|
"""Create entity from database model with decryption"""
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
id=db_provider.id,
|
||||||
|
provider_id=db_provider.server_identifier,
|
||||||
|
name=db_provider.name,
|
||||||
|
tenant_id=db_provider.tenant_id,
|
||||||
|
user_id=db_provider.user_id,
|
||||||
|
server_url=db_provider.server_url,
|
||||||
|
headers=db_provider.headers,
|
||||||
|
timeout=db_provider.timeout,
|
||||||
|
sse_read_timeout=db_provider.sse_read_timeout,
|
||||||
|
authed=db_provider.authed,
|
||||||
|
credentials=db_provider.credentials,
|
||||||
|
tools=db_provider.tool_dict,
|
||||||
|
icon=db_provider.icon or "",
|
||||||
|
created_at=db_provider.created_at,
|
||||||
|
updated_at=db_provider.updated_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def redirect_url(self) -> str:
|
||||||
|
"""OAuth redirect URL"""
|
||||||
|
return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def client_metadata(self) -> OAuthClientMetadata:
|
||||||
|
"""Metadata about this OAuth client."""
|
||||||
|
# Get grant type from credentials
|
||||||
|
credentials = self.decrypt_credentials()
|
||||||
|
|
||||||
|
# Try to get grant_type from different locations
|
||||||
|
grant_type = credentials.get("grant_type", MCPSupportGrantType.AUTHORIZATION_CODE)
|
||||||
|
|
||||||
|
# For nested structure, check if client_information has grant_types
|
||||||
|
if "client_information" in credentials and isinstance(credentials["client_information"], dict):
|
||||||
|
client_info = credentials["client_information"]
|
||||||
|
# If grant_types is specified in client_information, use it to determine grant_type
|
||||||
|
if "grant_types" in client_info and isinstance(client_info["grant_types"], list):
|
||||||
|
if "client_credentials" in client_info["grant_types"]:
|
||||||
|
grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS
|
||||||
|
elif "authorization_code" in client_info["grant_types"]:
|
||||||
|
grant_type = MCPSupportGrantType.AUTHORIZATION_CODE
|
||||||
|
|
||||||
|
# Configure based on grant type
|
||||||
|
is_client_credentials = grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS
|
||||||
|
|
||||||
|
grant_types = ["refresh_token"]
|
||||||
|
grant_types.append("client_credentials" if is_client_credentials else "authorization_code")
|
||||||
|
|
||||||
|
response_types = [] if is_client_credentials else ["code"]
|
||||||
|
redirect_uris = [] if is_client_credentials else [self.redirect_url]
|
||||||
|
|
||||||
|
return OAuthClientMetadata(
|
||||||
|
redirect_uris=redirect_uris,
|
||||||
|
token_endpoint_auth_method="none",
|
||||||
|
grant_types=grant_types,
|
||||||
|
response_types=response_types,
|
||||||
|
client_name=CLIENT_NAME,
|
||||||
|
client_uri=CLIENT_URI,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider_icon(self) -> dict[str, str] | str:
|
||||||
|
"""Get provider icon, handling both dict and string formats"""
|
||||||
|
if isinstance(self.icon, dict):
|
||||||
|
return self.icon
|
||||||
|
try:
|
||||||
|
return json.loads(self.icon)
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
# If not JSON, assume it's a file path
|
||||||
|
return file_helpers.get_signed_file_url(self.icon)
|
||||||
|
|
||||||
|
def to_api_response(self, user_name: str | None = None, include_sensitive: bool = True) -> dict[str, Any]:
|
||||||
|
"""Convert to API response format
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_name: User name to display
|
||||||
|
include_sensitive: If False, skip expensive decryption operations (for list view optimization)
|
||||||
|
"""
|
||||||
|
response = {
|
||||||
|
"id": self.id,
|
||||||
|
"author": user_name or "Anonymous",
|
||||||
|
"name": self.name,
|
||||||
|
"icon": self.provider_icon,
|
||||||
|
"type": ToolProviderType.MCP.value,
|
||||||
|
"is_team_authorization": self.authed,
|
||||||
|
"server_url": self.masked_server_url(),
|
||||||
|
"server_identifier": self.provider_id,
|
||||||
|
"updated_at": int(self.updated_at.timestamp()),
|
||||||
|
"label": I18nObject(en_US=self.name, zh_Hans=self.name).to_dict(),
|
||||||
|
"description": I18nObject(en_US="", zh_Hans="").to_dict(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add configuration
|
||||||
|
response["configuration"] = {
|
||||||
|
"timeout": str(self.timeout),
|
||||||
|
"sse_read_timeout": str(self.sse_read_timeout),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Skip expensive operations when sensitive data is not needed (e.g., list view)
|
||||||
|
if not include_sensitive:
|
||||||
|
response["masked_headers"] = {}
|
||||||
|
response["is_dynamic_registration"] = True
|
||||||
|
else:
|
||||||
|
# Add masked headers
|
||||||
|
response["masked_headers"] = self.masked_headers()
|
||||||
|
|
||||||
|
# Add authentication info if available
|
||||||
|
masked_creds = self.masked_credentials()
|
||||||
|
if masked_creds:
|
||||||
|
response["authentication"] = masked_creds
|
||||||
|
response["is_dynamic_registration"] = self.credentials.get("client_information", {}).get(
|
||||||
|
"is_dynamic_registration", True
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
def retrieve_client_information(self) -> OAuthClientInformation | None:
|
||||||
|
"""OAuth client information if available"""
|
||||||
|
credentials = self.decrypt_credentials()
|
||||||
|
if not credentials:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Check if we have nested client_information structure
|
||||||
|
if "client_information" not in credentials:
|
||||||
|
return None
|
||||||
|
client_info_data = credentials["client_information"]
|
||||||
|
if isinstance(client_info_data, dict):
|
||||||
|
if "encrypted_client_secret" in client_info_data:
|
||||||
|
client_info_data["client_secret"] = encrypter.decrypt_token(
|
||||||
|
self.tenant_id, client_info_data["encrypted_client_secret"]
|
||||||
|
)
|
||||||
|
return OAuthClientInformation.model_validate(client_info_data)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def retrieve_tokens(self) -> OAuthTokens | None:
|
||||||
|
"""OAuth tokens if available"""
|
||||||
|
if not self.credentials:
|
||||||
|
return None
|
||||||
|
credentials = self.decrypt_credentials()
|
||||||
|
return OAuthTokens(
|
||||||
|
access_token=credentials.get("access_token", ""),
|
||||||
|
token_type=credentials.get("token_type", DEFAULT_TOKEN_TYPE),
|
||||||
|
expires_in=int(credentials.get("expires_in", str(DEFAULT_EXPIRES_IN)) or DEFAULT_EXPIRES_IN),
|
||||||
|
refresh_token=credentials.get("refresh_token", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
def masked_server_url(self) -> str:
|
||||||
|
"""Masked server URL for display"""
|
||||||
|
parsed = urlparse(self.decrypt_server_url())
|
||||||
|
if parsed.path and parsed.path != "/":
|
||||||
|
masked = parsed._replace(path="/******")
|
||||||
|
return masked.geturl()
|
||||||
|
return parsed.geturl()
|
||||||
|
|
||||||
|
def _mask_value(self, value: str) -> str:
|
||||||
|
"""Mask a sensitive value for display"""
|
||||||
|
if len(value) > MIN_UNMASK_LENGTH:
|
||||||
|
return value[:2] + MASK_CHAR * (len(value) - 4) + value[-2:]
|
||||||
|
else:
|
||||||
|
return MASK_CHAR * len(value)
|
||||||
|
|
||||||
|
def masked_headers(self) -> dict[str, str]:
|
||||||
|
"""Masked headers for display"""
|
||||||
|
return {key: self._mask_value(value) for key, value in self.decrypt_headers().items()}
|
||||||
|
|
||||||
|
def masked_credentials(self) -> dict[str, str]:
|
||||||
|
"""Masked credentials for display"""
|
||||||
|
credentials = self.decrypt_credentials()
|
||||||
|
if not credentials:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
masked = {}
|
||||||
|
|
||||||
|
if "client_information" not in credentials or not isinstance(credentials["client_information"], dict):
|
||||||
|
return {}
|
||||||
|
client_info = credentials["client_information"]
|
||||||
|
# Mask sensitive fields from nested structure
|
||||||
|
if client_info.get("client_id"):
|
||||||
|
masked["client_id"] = self._mask_value(client_info["client_id"])
|
||||||
|
if client_info.get("encrypted_client_secret"):
|
||||||
|
masked["client_secret"] = self._mask_value(
|
||||||
|
encrypter.decrypt_token(self.tenant_id, client_info["encrypted_client_secret"])
|
||||||
|
)
|
||||||
|
if client_info.get("client_secret"):
|
||||||
|
masked["client_secret"] = self._mask_value(client_info["client_secret"])
|
||||||
|
return masked
|
||||||
|
|
||||||
|
def decrypt_server_url(self) -> str:
|
||||||
|
"""Decrypt server URL"""
|
||||||
|
return encrypter.decrypt_token(self.tenant_id, self.server_url)
|
||||||
|
|
||||||
|
def _decrypt_dict(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Generic method to decrypt dictionary fields"""
|
||||||
|
if not data:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Only decrypt fields that are actually encrypted
|
||||||
|
# For nested structures, client_information is not encrypted as a whole
|
||||||
|
encrypted_fields = []
|
||||||
|
for key, value in data.items():
|
||||||
|
# Skip nested objects - they are not encrypted
|
||||||
|
if isinstance(value, dict):
|
||||||
|
continue
|
||||||
|
# Only process string values that might be encrypted
|
||||||
|
if isinstance(value, str) and value:
|
||||||
|
encrypted_fields.append(key)
|
||||||
|
|
||||||
|
if not encrypted_fields:
|
||||||
|
return data
|
||||||
|
|
||||||
|
# Create dynamic config only for encrypted fields
|
||||||
|
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in encrypted_fields]
|
||||||
|
|
||||||
|
encrypter_instance, _ = create_provider_encrypter(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
config=config,
|
||||||
|
cache=NoOpProviderCredentialCache(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Decrypt only the encrypted fields
|
||||||
|
decrypted_data = encrypter_instance.decrypt({k: data[k] for k in encrypted_fields})
|
||||||
|
|
||||||
|
# Merge decrypted data with original data (preserving non-encrypted fields)
|
||||||
|
result = data.copy()
|
||||||
|
result.update(decrypted_data)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def decrypt_headers(self) -> dict[str, Any]:
|
||||||
|
"""Decrypt headers"""
|
||||||
|
return self._decrypt_dict(self.headers)
|
||||||
|
|
||||||
|
def decrypt_credentials(self) -> dict[str, Any]:
|
||||||
|
"""Decrypt credentials"""
|
||||||
|
return self._decrypt_dict(self.credentials)
|
||||||
|
|
||||||
|
def decrypt_authentication(self) -> dict[str, Any]:
|
||||||
|
"""Decrypt authentication"""
|
||||||
|
# Option 1: if headers is provided, use it and don't need to get token
|
||||||
|
headers = self.decrypt_headers()
|
||||||
|
|
||||||
|
# Option 2: Add OAuth token if authed and no headers provided
|
||||||
|
if not self.headers and self.authed:
|
||||||
|
token = self.retrieve_tokens()
|
||||||
|
if token:
|
||||||
|
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
|
||||||
|
return headers
|
||||||
@ -6,11 +6,15 @@ import secrets
|
|||||||
import urllib.parse
|
import urllib.parse
|
||||||
from urllib.parse import urljoin, urlparse
|
from urllib.parse import urljoin, urlparse
|
||||||
|
|
||||||
import httpx
|
from httpx import ConnectError, HTTPStatusError, RequestError
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType
|
||||||
|
from core.helper import ssrf_proxy
|
||||||
|
from core.mcp.entities import AuthAction, AuthActionType, AuthResult, OAuthCallbackState
|
||||||
|
from core.mcp.error import MCPRefreshTokenError
|
||||||
from core.mcp.types import (
|
from core.mcp.types import (
|
||||||
|
LATEST_PROTOCOL_VERSION,
|
||||||
OAuthClientInformation,
|
OAuthClientInformation,
|
||||||
OAuthClientInformationFull,
|
OAuthClientInformationFull,
|
||||||
OAuthClientMetadata,
|
OAuthClientMetadata,
|
||||||
@ -19,21 +23,10 @@ from core.mcp.types import (
|
|||||||
)
|
)
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
|
|
||||||
LATEST_PROTOCOL_VERSION = "1.0"
|
|
||||||
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
|
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
|
||||||
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
|
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
|
||||||
|
|
||||||
|
|
||||||
class OAuthCallbackState(BaseModel):
|
|
||||||
provider_id: str
|
|
||||||
tenant_id: str
|
|
||||||
server_url: str
|
|
||||||
metadata: OAuthMetadata | None = None
|
|
||||||
client_information: OAuthClientInformation
|
|
||||||
code_verifier: str
|
|
||||||
redirect_uri: str
|
|
||||||
|
|
||||||
|
|
||||||
def generate_pkce_challenge() -> tuple[str, str]:
|
def generate_pkce_challenge() -> tuple[str, str]:
|
||||||
"""Generate PKCE challenge and verifier."""
|
"""Generate PKCE challenge and verifier."""
|
||||||
code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
|
code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
|
||||||
@ -80,8 +73,13 @@ def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
|
|||||||
raise ValueError(f"Invalid state parameter: {str(e)}")
|
raise ValueError(f"Invalid state parameter: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState:
|
def handle_callback(state_key: str, authorization_code: str) -> tuple[OAuthCallbackState, OAuthTokens]:
|
||||||
"""Handle the callback from the OAuth provider."""
|
"""
|
||||||
|
Handle the callback from the OAuth provider.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of (callback_state, tokens) that can be used by the caller to save data.
|
||||||
|
"""
|
||||||
# Retrieve state data from Redis (state is automatically deleted after retrieval)
|
# Retrieve state data from Redis (state is automatically deleted after retrieval)
|
||||||
full_state_data = _retrieve_redis_state(state_key)
|
full_state_data = _retrieve_redis_state(state_key)
|
||||||
|
|
||||||
@ -93,30 +91,32 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackSta
|
|||||||
full_state_data.code_verifier,
|
full_state_data.code_verifier,
|
||||||
full_state_data.redirect_uri,
|
full_state_data.redirect_uri,
|
||||||
)
|
)
|
||||||
provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id, for_list=True)
|
|
||||||
provider.save_tokens(tokens)
|
return full_state_data, tokens
|
||||||
return full_state_data
|
|
||||||
|
|
||||||
|
|
||||||
def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
|
def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
|
||||||
"""Check if the server supports OAuth 2.0 Resource Discovery."""
|
"""Check if the server supports OAuth 2.0 Resource Discovery."""
|
||||||
b_scheme, b_netloc, b_path, _, b_query, b_fragment = urlparse(server_url, "", True)
|
b_scheme, b_netloc, _, _, b_query, b_fragment = urlparse(server_url, "", True)
|
||||||
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource{b_path}"
|
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource"
|
||||||
if b_query:
|
if b_query:
|
||||||
url_for_resource_discovery += f"?{b_query}"
|
url_for_resource_discovery += f"?{b_query}"
|
||||||
if b_fragment:
|
if b_fragment:
|
||||||
url_for_resource_discovery += f"#{b_fragment}"
|
url_for_resource_discovery += f"#{b_fragment}"
|
||||||
try:
|
try:
|
||||||
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
|
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
|
||||||
response = httpx.get(url_for_resource_discovery, headers=headers)
|
response = ssrf_proxy.get(url_for_resource_discovery, headers=headers)
|
||||||
if 200 <= response.status_code < 300:
|
if 200 <= response.status_code < 300:
|
||||||
body = response.json()
|
body = response.json()
|
||||||
if "authorization_server_url" in body:
|
# Support both singular and plural forms
|
||||||
|
if body.get("authorization_servers"):
|
||||||
|
return True, body["authorization_servers"][0]
|
||||||
|
elif body.get("authorization_server_url"):
|
||||||
return True, body["authorization_server_url"][0]
|
return True, body["authorization_server_url"][0]
|
||||||
else:
|
else:
|
||||||
return False, ""
|
return False, ""
|
||||||
return False, ""
|
return False, ""
|
||||||
except httpx.RequestError:
|
except RequestError:
|
||||||
# Not support resource discovery, fall back to well-known OAuth metadata
|
# Not support resource discovery, fall back to well-known OAuth metadata
|
||||||
return False, ""
|
return False, ""
|
||||||
|
|
||||||
@ -126,27 +126,37 @@ def discover_oauth_metadata(server_url: str, protocol_version: str | None = None
|
|||||||
# First check if the server supports OAuth 2.0 Resource Discovery
|
# First check if the server supports OAuth 2.0 Resource Discovery
|
||||||
support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url)
|
support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url)
|
||||||
if support_resource_discovery:
|
if support_resource_discovery:
|
||||||
url = oauth_discovery_url
|
# The oauth_discovery_url is the authorization server base URL
|
||||||
|
# Try OpenID Connect discovery first (more common), then OAuth 2.0
|
||||||
|
urls_to_try = [
|
||||||
|
urljoin(oauth_discovery_url + "/", ".well-known/oauth-authorization-server"),
|
||||||
|
urljoin(oauth_discovery_url + "/", ".well-known/openid-configuration"),
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
url = urljoin(server_url, "/.well-known/oauth-authorization-server")
|
urls_to_try = [urljoin(server_url, "/.well-known/oauth-authorization-server")]
|
||||||
|
|
||||||
try:
|
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
|
||||||
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
|
|
||||||
response = httpx.get(url, headers=headers)
|
for url in urls_to_try:
|
||||||
if response.status_code == 404:
|
try:
|
||||||
return None
|
response = ssrf_proxy.get(url, headers=headers)
|
||||||
if not response.is_success:
|
|
||||||
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
|
||||||
return OAuthMetadata.model_validate(response.json())
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
if isinstance(e, httpx.ConnectError):
|
|
||||||
response = httpx.get(url)
|
|
||||||
if response.status_code == 404:
|
if response.status_code == 404:
|
||||||
return None
|
continue
|
||||||
if not response.is_success:
|
if not response.is_success:
|
||||||
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
response.raise_for_status()
|
||||||
return OAuthMetadata.model_validate(response.json())
|
return OAuthMetadata.model_validate(response.json())
|
||||||
raise
|
except (RequestError, HTTPStatusError) as e:
|
||||||
|
if isinstance(e, ConnectError):
|
||||||
|
response = ssrf_proxy.get(url)
|
||||||
|
if response.status_code == 404:
|
||||||
|
continue # Try next URL
|
||||||
|
if not response.is_success:
|
||||||
|
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
||||||
|
return OAuthMetadata.model_validate(response.json())
|
||||||
|
# For other errors, try next URL
|
||||||
|
continue
|
||||||
|
|
||||||
|
return None # No metadata found
|
||||||
|
|
||||||
|
|
||||||
def start_authorization(
|
def start_authorization(
|
||||||
@ -213,7 +223,7 @@ def exchange_authorization(
|
|||||||
redirect_uri: str,
|
redirect_uri: str,
|
||||||
) -> OAuthTokens:
|
) -> OAuthTokens:
|
||||||
"""Exchanges an authorization code for an access token."""
|
"""Exchanges an authorization code for an access token."""
|
||||||
grant_type = "authorization_code"
|
grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value
|
||||||
|
|
||||||
if metadata:
|
if metadata:
|
||||||
token_url = metadata.token_endpoint
|
token_url = metadata.token_endpoint
|
||||||
@ -233,7 +243,7 @@ def exchange_authorization(
|
|||||||
if client_information.client_secret:
|
if client_information.client_secret:
|
||||||
params["client_secret"] = client_information.client_secret
|
params["client_secret"] = client_information.client_secret
|
||||||
|
|
||||||
response = httpx.post(token_url, data=params)
|
response = ssrf_proxy.post(token_url, data=params)
|
||||||
if not response.is_success:
|
if not response.is_success:
|
||||||
raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
|
raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
|
||||||
return OAuthTokens.model_validate(response.json())
|
return OAuthTokens.model_validate(response.json())
|
||||||
@ -246,7 +256,7 @@ def refresh_authorization(
|
|||||||
refresh_token: str,
|
refresh_token: str,
|
||||||
) -> OAuthTokens:
|
) -> OAuthTokens:
|
||||||
"""Exchange a refresh token for an updated access token."""
|
"""Exchange a refresh token for an updated access token."""
|
||||||
grant_type = "refresh_token"
|
grant_type = MCPSupportGrantType.REFRESH_TOKEN.value
|
||||||
|
|
||||||
if metadata:
|
if metadata:
|
||||||
token_url = metadata.token_endpoint
|
token_url = metadata.token_endpoint
|
||||||
@ -263,10 +273,55 @@ def refresh_authorization(
|
|||||||
|
|
||||||
if client_information.client_secret:
|
if client_information.client_secret:
|
||||||
params["client_secret"] = client_information.client_secret
|
params["client_secret"] = client_information.client_secret
|
||||||
|
try:
|
||||||
response = httpx.post(token_url, data=params)
|
response = ssrf_proxy.post(token_url, data=params)
|
||||||
|
except ssrf_proxy.MaxRetriesExceededError as e:
|
||||||
|
raise MCPRefreshTokenError(e) from e
|
||||||
if not response.is_success:
|
if not response.is_success:
|
||||||
raise ValueError(f"Token refresh failed: HTTP {response.status_code}")
|
raise MCPRefreshTokenError(response.text)
|
||||||
|
return OAuthTokens.model_validate(response.json())
|
||||||
|
|
||||||
|
|
||||||
|
def client_credentials_flow(
|
||||||
|
server_url: str,
|
||||||
|
metadata: OAuthMetadata | None,
|
||||||
|
client_information: OAuthClientInformation,
|
||||||
|
scope: str | None = None,
|
||||||
|
) -> OAuthTokens:
|
||||||
|
"""Execute Client Credentials Flow to get access token."""
|
||||||
|
grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
|
||||||
|
|
||||||
|
if metadata:
|
||||||
|
token_url = metadata.token_endpoint
|
||||||
|
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
|
||||||
|
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
|
||||||
|
else:
|
||||||
|
token_url = urljoin(server_url, "/token")
|
||||||
|
|
||||||
|
# Support both Basic Auth and body parameters for client authentication
|
||||||
|
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||||
|
data = {"grant_type": grant_type}
|
||||||
|
|
||||||
|
if scope:
|
||||||
|
data["scope"] = scope
|
||||||
|
|
||||||
|
# If client_secret is provided, use Basic Auth (preferred method)
|
||||||
|
if client_information.client_secret:
|
||||||
|
credentials = f"{client_information.client_id}:{client_information.client_secret}"
|
||||||
|
encoded_credentials = base64.b64encode(credentials.encode()).decode()
|
||||||
|
headers["Authorization"] = f"Basic {encoded_credentials}"
|
||||||
|
else:
|
||||||
|
# Fall back to including credentials in the body
|
||||||
|
data["client_id"] = client_information.client_id
|
||||||
|
if client_information.client_secret:
|
||||||
|
data["client_secret"] = client_information.client_secret
|
||||||
|
|
||||||
|
response = ssrf_proxy.post(token_url, headers=headers, data=data)
|
||||||
|
if not response.is_success:
|
||||||
|
raise ValueError(
|
||||||
|
f"Client credentials token request failed: HTTP {response.status_code}, Response: {response.text}"
|
||||||
|
)
|
||||||
|
|
||||||
return OAuthTokens.model_validate(response.json())
|
return OAuthTokens.model_validate(response.json())
|
||||||
|
|
||||||
|
|
||||||
@ -283,7 +338,7 @@ def register_client(
|
|||||||
else:
|
else:
|
||||||
registration_url = urljoin(server_url, "/register")
|
registration_url = urljoin(server_url, "/register")
|
||||||
|
|
||||||
response = httpx.post(
|
response = ssrf_proxy.post(
|
||||||
registration_url,
|
registration_url,
|
||||||
json=client_metadata.model_dump(),
|
json=client_metadata.model_dump(),
|
||||||
headers={"Content-Type": "application/json"},
|
headers={"Content-Type": "application/json"},
|
||||||
@ -294,28 +349,111 @@ def register_client(
|
|||||||
|
|
||||||
|
|
||||||
def auth(
|
def auth(
|
||||||
provider: OAuthClientProvider,
|
provider: MCPProviderEntity,
|
||||||
server_url: str,
|
|
||||||
authorization_code: str | None = None,
|
authorization_code: str | None = None,
|
||||||
state_param: str | None = None,
|
state_param: str | None = None,
|
||||||
for_list: bool = False,
|
) -> AuthResult:
|
||||||
) -> dict[str, str]:
|
"""
|
||||||
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
|
Orchestrates the full auth flow with a server using secure Redis state storage.
|
||||||
metadata = discover_oauth_metadata(server_url)
|
|
||||||
|
This function performs only network operations and returns actions that need
|
||||||
|
to be performed by the caller (such as saving data to database).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: The MCP provider entity
|
||||||
|
authorization_code: Optional authorization code from OAuth callback
|
||||||
|
state_param: Optional state parameter from OAuth callback
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AuthResult containing actions to be performed and response data
|
||||||
|
"""
|
||||||
|
actions: list[AuthAction] = []
|
||||||
|
server_url = provider.decrypt_server_url()
|
||||||
|
server_metadata = discover_oauth_metadata(server_url)
|
||||||
|
client_metadata = provider.client_metadata
|
||||||
|
provider_id = provider.id
|
||||||
|
tenant_id = provider.tenant_id
|
||||||
|
client_information = provider.retrieve_client_information()
|
||||||
|
redirect_url = provider.redirect_url
|
||||||
|
|
||||||
|
# Determine grant type based on server metadata
|
||||||
|
if not server_metadata:
|
||||||
|
raise ValueError("Failed to discover OAuth metadata from server")
|
||||||
|
|
||||||
|
supported_grant_types = server_metadata.grant_types_supported or []
|
||||||
|
|
||||||
|
# Convert to lowercase for comparison
|
||||||
|
supported_grant_types_lower = [gt.lower() for gt in supported_grant_types]
|
||||||
|
|
||||||
|
# Determine which grant type to use
|
||||||
|
effective_grant_type = None
|
||||||
|
if MCPSupportGrantType.AUTHORIZATION_CODE.value in supported_grant_types_lower:
|
||||||
|
effective_grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value
|
||||||
|
else:
|
||||||
|
effective_grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
|
||||||
|
|
||||||
|
# Get stored credentials
|
||||||
|
credentials = provider.decrypt_credentials()
|
||||||
|
|
||||||
# Handle client registration if needed
|
|
||||||
client_information = provider.client_information()
|
|
||||||
if not client_information:
|
if not client_information:
|
||||||
if authorization_code is not None:
|
if authorization_code is not None:
|
||||||
raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
|
raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
|
||||||
|
|
||||||
|
# For client credentials flow, we don't need to register client dynamically
|
||||||
|
if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
|
||||||
|
# Client should provide client_id and client_secret directly
|
||||||
|
raise ValueError("Client credentials flow requires client_id and client_secret to be provided")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
full_information = register_client(server_url, metadata, provider.client_metadata)
|
full_information = register_client(server_url, server_metadata, client_metadata)
|
||||||
except httpx.RequestError as e:
|
except RequestError as e:
|
||||||
raise ValueError(f"Could not register OAuth client: {e}")
|
raise ValueError(f"Could not register OAuth client: {e}")
|
||||||
provider.save_client_information(full_information)
|
|
||||||
|
# Return action to save client information
|
||||||
|
actions.append(
|
||||||
|
AuthAction(
|
||||||
|
action_type=AuthActionType.SAVE_CLIENT_INFO,
|
||||||
|
data={"client_information": full_information.model_dump()},
|
||||||
|
provider_id=provider_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
client_information = full_information
|
client_information = full_information
|
||||||
|
|
||||||
# Exchange authorization code for tokens
|
# Handle client credentials flow
|
||||||
|
if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
|
||||||
|
# Direct token request without user interaction
|
||||||
|
try:
|
||||||
|
scope = credentials.get("scope")
|
||||||
|
tokens = client_credentials_flow(
|
||||||
|
server_url,
|
||||||
|
server_metadata,
|
||||||
|
client_information,
|
||||||
|
scope,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return action to save tokens and grant type
|
||||||
|
token_data = tokens.model_dump()
|
||||||
|
token_data["grant_type"] = MCPSupportGrantType.CLIENT_CREDENTIALS.value
|
||||||
|
|
||||||
|
actions.append(
|
||||||
|
AuthAction(
|
||||||
|
action_type=AuthActionType.SAVE_TOKENS,
|
||||||
|
data=token_data,
|
||||||
|
provider_id=provider_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return AuthResult(actions=actions, response={"result": "success"})
|
||||||
|
except (RequestError, ValueError, KeyError) as e:
|
||||||
|
# RequestError: HTTP request failed
|
||||||
|
# ValueError: Invalid response data
|
||||||
|
# KeyError: Missing required fields in response
|
||||||
|
raise ValueError(f"Client credentials flow failed: {e}")
|
||||||
|
|
||||||
|
# Exchange authorization code for tokens (Authorization Code flow)
|
||||||
if authorization_code is not None:
|
if authorization_code is not None:
|
||||||
if not state_param:
|
if not state_param:
|
||||||
raise ValueError("State parameter is required when exchanging authorization code")
|
raise ValueError("State parameter is required when exchanging authorization code")
|
||||||
@ -335,35 +473,69 @@ def auth(
|
|||||||
|
|
||||||
tokens = exchange_authorization(
|
tokens = exchange_authorization(
|
||||||
server_url,
|
server_url,
|
||||||
metadata,
|
server_metadata,
|
||||||
client_information,
|
client_information,
|
||||||
authorization_code,
|
authorization_code,
|
||||||
code_verifier,
|
code_verifier,
|
||||||
redirect_uri,
|
redirect_uri,
|
||||||
)
|
)
|
||||||
provider.save_tokens(tokens)
|
|
||||||
return {"result": "success"}
|
|
||||||
|
|
||||||
provider_tokens = provider.tokens()
|
# Return action to save tokens
|
||||||
|
actions.append(
|
||||||
|
AuthAction(
|
||||||
|
action_type=AuthActionType.SAVE_TOKENS,
|
||||||
|
data=tokens.model_dump(),
|
||||||
|
provider_id=provider_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return AuthResult(actions=actions, response={"result": "success"})
|
||||||
|
|
||||||
|
provider_tokens = provider.retrieve_tokens()
|
||||||
|
|
||||||
# Handle token refresh or new authorization
|
# Handle token refresh or new authorization
|
||||||
if provider_tokens and provider_tokens.refresh_token:
|
if provider_tokens and provider_tokens.refresh_token:
|
||||||
try:
|
try:
|
||||||
new_tokens = refresh_authorization(server_url, metadata, client_information, provider_tokens.refresh_token)
|
new_tokens = refresh_authorization(
|
||||||
provider.save_tokens(new_tokens)
|
server_url, server_metadata, client_information, provider_tokens.refresh_token
|
||||||
return {"result": "success"}
|
)
|
||||||
except Exception as e:
|
|
||||||
|
# Return action to save new tokens
|
||||||
|
actions.append(
|
||||||
|
AuthAction(
|
||||||
|
action_type=AuthActionType.SAVE_TOKENS,
|
||||||
|
data=new_tokens.model_dump(),
|
||||||
|
provider_id=provider_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return AuthResult(actions=actions, response={"result": "success"})
|
||||||
|
except (RequestError, ValueError, KeyError) as e:
|
||||||
|
# RequestError: HTTP request failed
|
||||||
|
# ValueError: Invalid response data
|
||||||
|
# KeyError: Missing required fields in response
|
||||||
raise ValueError(f"Could not refresh OAuth tokens: {e}")
|
raise ValueError(f"Could not refresh OAuth tokens: {e}")
|
||||||
|
|
||||||
# Start new authorization flow
|
# Start new authorization flow (only for authorization code flow)
|
||||||
authorization_url, code_verifier = start_authorization(
|
authorization_url, code_verifier = start_authorization(
|
||||||
server_url,
|
server_url,
|
||||||
metadata,
|
server_metadata,
|
||||||
client_information,
|
client_information,
|
||||||
provider.redirect_url,
|
redirect_url,
|
||||||
provider.mcp_provider.id,
|
provider_id,
|
||||||
provider.mcp_provider.tenant_id,
|
tenant_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
provider.save_code_verifier(code_verifier)
|
# Return action to save code verifier
|
||||||
return {"authorization_url": authorization_url}
|
actions.append(
|
||||||
|
AuthAction(
|
||||||
|
action_type=AuthActionType.SAVE_CODE_VERIFIER,
|
||||||
|
data={"code_verifier": code_verifier},
|
||||||
|
provider_id=provider_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return AuthResult(actions=actions, response={"authorization_url": authorization_url})
|
||||||
|
|||||||
@ -1,77 +0,0 @@
|
|||||||
from configs import dify_config
|
|
||||||
from core.mcp.types import (
|
|
||||||
OAuthClientInformation,
|
|
||||||
OAuthClientInformationFull,
|
|
||||||
OAuthClientMetadata,
|
|
||||||
OAuthTokens,
|
|
||||||
)
|
|
||||||
from models.tools import MCPToolProvider
|
|
||||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
|
||||||
|
|
||||||
|
|
||||||
class OAuthClientProvider:
|
|
||||||
mcp_provider: MCPToolProvider
|
|
||||||
|
|
||||||
def __init__(self, provider_id: str, tenant_id: str, for_list: bool = False):
|
|
||||||
if for_list:
|
|
||||||
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
|
||||||
else:
|
|
||||||
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_server_identifier(provider_id, tenant_id)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def redirect_url(self) -> str:
|
|
||||||
"""The URL to redirect the user agent to after authorization."""
|
|
||||||
return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def client_metadata(self) -> OAuthClientMetadata:
|
|
||||||
"""Metadata about this OAuth client."""
|
|
||||||
return OAuthClientMetadata(
|
|
||||||
redirect_uris=[self.redirect_url],
|
|
||||||
token_endpoint_auth_method="none",
|
|
||||||
grant_types=["authorization_code", "refresh_token"],
|
|
||||||
response_types=["code"],
|
|
||||||
client_name="Dify",
|
|
||||||
client_uri="https://github.com/langgenius/dify",
|
|
||||||
)
|
|
||||||
|
|
||||||
def client_information(self) -> OAuthClientInformation | None:
|
|
||||||
"""Loads information about this OAuth client."""
|
|
||||||
client_information = self.mcp_provider.decrypted_credentials.get("client_information", {})
|
|
||||||
if not client_information:
|
|
||||||
return None
|
|
||||||
return OAuthClientInformation.model_validate(client_information)
|
|
||||||
|
|
||||||
def save_client_information(self, client_information: OAuthClientInformationFull):
|
|
||||||
"""Saves client information after dynamic registration."""
|
|
||||||
MCPToolManageService.update_mcp_provider_credentials(
|
|
||||||
self.mcp_provider,
|
|
||||||
{"client_information": client_information.model_dump()},
|
|
||||||
)
|
|
||||||
|
|
||||||
def tokens(self) -> OAuthTokens | None:
|
|
||||||
"""Loads any existing OAuth tokens for the current session."""
|
|
||||||
credentials = self.mcp_provider.decrypted_credentials
|
|
||||||
if not credentials:
|
|
||||||
return None
|
|
||||||
return OAuthTokens(
|
|
||||||
access_token=credentials.get("access_token", ""),
|
|
||||||
token_type=credentials.get("token_type", "Bearer"),
|
|
||||||
expires_in=int(credentials.get("expires_in", "3600") or 3600),
|
|
||||||
refresh_token=credentials.get("refresh_token", ""),
|
|
||||||
)
|
|
||||||
|
|
||||||
def save_tokens(self, tokens: OAuthTokens):
|
|
||||||
"""Stores new OAuth tokens for the current session."""
|
|
||||||
# update mcp provider credentials
|
|
||||||
token_dict = tokens.model_dump()
|
|
||||||
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True)
|
|
||||||
|
|
||||||
def save_code_verifier(self, code_verifier: str):
|
|
||||||
"""Saves a PKCE code verifier for the current session."""
|
|
||||||
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier})
|
|
||||||
|
|
||||||
def code_verifier(self) -> str:
|
|
||||||
"""Loads the PKCE code verifier for the current session."""
|
|
||||||
# get code verifier from mcp provider credentials
|
|
||||||
return str(self.mcp_provider.decrypted_credentials.get("code_verifier", ""))
|
|
||||||
191
api/core/mcp/auth_client.py
Normal file
191
api/core/mcp/auth_client.py
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
"""
|
||||||
|
MCP Client with Authentication Retry Support
|
||||||
|
|
||||||
|
This module provides an enhanced MCPClient that automatically handles
|
||||||
|
authentication failures and retries operations after refreshing tokens.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from core.entities.mcp_provider import MCPProviderEntity
|
||||||
|
from core.mcp.error import MCPAuthError
|
||||||
|
from core.mcp.mcp_client import MCPClient
|
||||||
|
from core.mcp.types import CallToolResult, Tool
|
||||||
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPClientWithAuthRetry(MCPClient):
|
||||||
|
"""
|
||||||
|
An enhanced MCPClient that provides automatic authentication retry.
|
||||||
|
|
||||||
|
This class extends MCPClient and intercepts MCPAuthError exceptions
|
||||||
|
to refresh authentication before retrying failed operations.
|
||||||
|
|
||||||
|
Note: This class uses lazy session creation - database sessions are only
|
||||||
|
created when authentication retry is actually needed, not on every request.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
server_url: str,
|
||||||
|
headers: dict[str, str] | None = None,
|
||||||
|
timeout: float | None = None,
|
||||||
|
sse_read_timeout: float | None = None,
|
||||||
|
provider_entity: MCPProviderEntity | None = None,
|
||||||
|
authorization_code: str | None = None,
|
||||||
|
by_server_id: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the MCP client with auth retry capability.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_url: The MCP server URL
|
||||||
|
headers: Optional headers for requests
|
||||||
|
timeout: Request timeout
|
||||||
|
sse_read_timeout: SSE read timeout
|
||||||
|
provider_entity: Provider entity for authentication
|
||||||
|
authorization_code: Optional authorization code for initial auth
|
||||||
|
by_server_id: Whether to look up provider by server ID
|
||||||
|
"""
|
||||||
|
super().__init__(server_url, headers, timeout, sse_read_timeout)
|
||||||
|
|
||||||
|
self.provider_entity = provider_entity
|
||||||
|
self.authorization_code = authorization_code
|
||||||
|
self.by_server_id = by_server_id
|
||||||
|
self._has_retried = False
|
||||||
|
|
||||||
|
def _handle_auth_error(self, error: MCPAuthError) -> None:
|
||||||
|
"""
|
||||||
|
Handle authentication error by refreshing tokens.
|
||||||
|
|
||||||
|
This method creates a short-lived database session only when authentication
|
||||||
|
retry is needed, minimizing database connection hold time.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error: The authentication error
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPAuthError: If authentication fails or max retries reached
|
||||||
|
"""
|
||||||
|
if not self.provider_entity:
|
||||||
|
raise error
|
||||||
|
if self._has_retried:
|
||||||
|
raise error
|
||||||
|
|
||||||
|
self._has_retried = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create a temporary session only for auth retry
|
||||||
|
# This session is short-lived and only exists during the auth operation
|
||||||
|
|
||||||
|
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||||
|
|
||||||
|
with Session(db.engine) as session, session.begin():
|
||||||
|
mcp_service = MCPToolManageService(session=session)
|
||||||
|
|
||||||
|
# Perform authentication using the service's auth method
|
||||||
|
mcp_service.auth_with_actions(self.provider_entity, self.authorization_code)
|
||||||
|
|
||||||
|
# Retrieve new tokens
|
||||||
|
self.provider_entity = mcp_service.get_provider_entity(
|
||||||
|
self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Session is closed here, before we update headers
|
||||||
|
token = self.provider_entity.retrieve_tokens()
|
||||||
|
if not token:
|
||||||
|
raise MCPAuthError("Authentication failed - no token received")
|
||||||
|
|
||||||
|
# Update headers with new token
|
||||||
|
self.headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
|
||||||
|
|
||||||
|
# Clear authorization code after first use
|
||||||
|
self.authorization_code = None
|
||||||
|
|
||||||
|
except MCPAuthError:
|
||||||
|
# Re-raise MCPAuthError as is
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
# Catch all exceptions during auth retry
|
||||||
|
logger.exception("Authentication retry failed")
|
||||||
|
raise MCPAuthError(f"Authentication retry failed: {e}") from e
|
||||||
|
|
||||||
|
def _execute_with_retry(self, func: Callable[..., Any], *args, **kwargs) -> Any:
|
||||||
|
"""
|
||||||
|
Execute a function with authentication retry logic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The function to execute
|
||||||
|
*args: Positional arguments for the function
|
||||||
|
**kwargs: Keyword arguments for the function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The result of the function call
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPAuthError: If authentication fails after retries
|
||||||
|
Any other exceptions from the function
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
except MCPAuthError as e:
|
||||||
|
self._handle_auth_error(e)
|
||||||
|
|
||||||
|
# Re-initialize the connection with new headers
|
||||||
|
if self._initialized:
|
||||||
|
# Clean up existing connection
|
||||||
|
self._exit_stack.close()
|
||||||
|
self._session = None
|
||||||
|
self._initialized = False
|
||||||
|
|
||||||
|
# Re-initialize with new headers
|
||||||
|
self._initialize()
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
finally:
|
||||||
|
# Reset retry flag after operation completes
|
||||||
|
self._has_retried = False
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
"""Enter the context manager with retry support."""
|
||||||
|
|
||||||
|
def initialize_with_retry():
|
||||||
|
super(MCPClientWithAuthRetry, self).__enter__()
|
||||||
|
return self
|
||||||
|
|
||||||
|
return self._execute_with_retry(initialize_with_retry)
|
||||||
|
|
||||||
|
def list_tools(self) -> list[Tool]:
|
||||||
|
"""
|
||||||
|
List available tools from the MCP server with auth retry.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of available tools
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPAuthError: If authentication fails after retries
|
||||||
|
"""
|
||||||
|
return self._execute_with_retry(super().list_tools)
|
||||||
|
|
||||||
|
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
|
||||||
|
"""
|
||||||
|
Invoke a tool on the MCP server with auth retry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: Name of the tool to invoke
|
||||||
|
tool_args: Arguments for the tool
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Result of the tool invocation
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPAuthError: If authentication fails after retries
|
||||||
|
"""
|
||||||
|
return self._execute_with_retry(super().invoke_tool, tool_name, tool_args)
|
||||||
0
api/core/mcp/auth_client_comparison.md
Normal file
0
api/core/mcp/auth_client_comparison.md
Normal file
@ -46,7 +46,7 @@ class SSETransport:
|
|||||||
url: str,
|
url: str,
|
||||||
headers: dict[str, Any] | None = None,
|
headers: dict[str, Any] | None = None,
|
||||||
timeout: float = 5.0,
|
timeout: float = 5.0,
|
||||||
sse_read_timeout: float = 5 * 60,
|
sse_read_timeout: float = 1 * 60,
|
||||||
):
|
):
|
||||||
"""Initialize the SSE transport.
|
"""Initialize the SSE transport.
|
||||||
|
|
||||||
@ -255,7 +255,7 @@ def sse_client(
|
|||||||
url: str,
|
url: str,
|
||||||
headers: dict[str, Any] | None = None,
|
headers: dict[str, Any] | None = None,
|
||||||
timeout: float = 5.0,
|
timeout: float = 5.0,
|
||||||
sse_read_timeout: float = 5 * 60,
|
sse_read_timeout: float = 1 * 60,
|
||||||
) -> Generator[tuple[ReadQueue, WriteQueue], None, None]:
|
) -> Generator[tuple[ReadQueue, WriteQueue], None, None]:
|
||||||
"""
|
"""
|
||||||
Client transport for SSE.
|
Client transport for SSE.
|
||||||
@ -276,31 +276,34 @@ def sse_client(
|
|||||||
read_queue: ReadQueue | None = None
|
read_queue: ReadQueue | None = None
|
||||||
write_queue: WriteQueue | None = None
|
write_queue: WriteQueue | None = None
|
||||||
|
|
||||||
with ThreadPoolExecutor() as executor:
|
executor = ThreadPoolExecutor()
|
||||||
try:
|
try:
|
||||||
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
|
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
|
||||||
with ssrf_proxy_sse_connect(
|
with ssrf_proxy_sse_connect(
|
||||||
url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
|
url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
|
||||||
) as event_source:
|
) as event_source:
|
||||||
event_source.response.raise_for_status()
|
event_source.response.raise_for_status()
|
||||||
|
|
||||||
read_queue, write_queue = transport.connect(executor, client, event_source)
|
read_queue, write_queue = transport.connect(executor, client, event_source)
|
||||||
|
|
||||||
yield read_queue, write_queue
|
yield read_queue, write_queue
|
||||||
|
|
||||||
except httpx.HTTPStatusError as exc:
|
except httpx.HTTPStatusError as exc:
|
||||||
if exc.response.status_code == 401:
|
if exc.response.status_code == 401:
|
||||||
raise MCPAuthError()
|
raise MCPAuthError()
|
||||||
raise MCPConnectionError()
|
raise MCPConnectionError()
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error connecting to SSE endpoint")
|
logger.exception("Error connecting to SSE endpoint")
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
# Clean up queues
|
# Clean up queues
|
||||||
if read_queue:
|
if read_queue:
|
||||||
read_queue.put(None)
|
read_queue.put(None)
|
||||||
if write_queue:
|
if write_queue:
|
||||||
write_queue.put(None)
|
write_queue.put(None)
|
||||||
|
|
||||||
|
# Shutdown executor without waiting to prevent hanging
|
||||||
|
executor.shutdown(wait=False)
|
||||||
|
|
||||||
|
|
||||||
def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage):
|
def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage):
|
||||||
|
|||||||
@ -434,45 +434,48 @@ def streamablehttp_client(
|
|||||||
server_to_client_queue: ServerToClientQueue = queue.Queue() # For messages FROM server TO client
|
server_to_client_queue: ServerToClientQueue = queue.Queue() # For messages FROM server TO client
|
||||||
client_to_server_queue: ClientToServerQueue = queue.Queue() # For messages FROM client TO server
|
client_to_server_queue: ClientToServerQueue = queue.Queue() # For messages FROM client TO server
|
||||||
|
|
||||||
with ThreadPoolExecutor(max_workers=2) as executor:
|
executor = ThreadPoolExecutor(max_workers=2)
|
||||||
try:
|
try:
|
||||||
with create_ssrf_proxy_mcp_http_client(
|
with create_ssrf_proxy_mcp_http_client(
|
||||||
headers=transport.request_headers,
|
headers=transport.request_headers,
|
||||||
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
|
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
|
||||||
) as client:
|
) as client:
|
||||||
# Define callbacks that need access to thread pool
|
# Define callbacks that need access to thread pool
|
||||||
def start_get_stream():
|
def start_get_stream():
|
||||||
"""Start a worker thread to handle server-initiated messages."""
|
"""Start a worker thread to handle server-initiated messages."""
|
||||||
executor.submit(transport.handle_get_stream, client, server_to_client_queue)
|
executor.submit(transport.handle_get_stream, client, server_to_client_queue)
|
||||||
|
|
||||||
# Start the post_writer worker thread
|
# Start the post_writer worker thread
|
||||||
executor.submit(
|
executor.submit(
|
||||||
transport.post_writer,
|
transport.post_writer,
|
||||||
client,
|
client,
|
||||||
client_to_server_queue, # Queue for messages FROM client TO server
|
client_to_server_queue, # Queue for messages FROM client TO server
|
||||||
server_to_client_queue, # Queue for messages FROM server TO client
|
server_to_client_queue, # Queue for messages FROM server TO client
|
||||||
start_get_stream,
|
start_get_stream,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
|
||||||
yield (
|
|
||||||
server_to_client_queue, # Queue for receiving messages FROM server
|
|
||||||
client_to_server_queue, # Queue for sending messages TO server
|
|
||||||
transport.get_session_id,
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
if transport.session_id and terminate_on_close:
|
|
||||||
transport.terminate_session(client)
|
|
||||||
|
|
||||||
# Signal threads to stop
|
|
||||||
client_to_server_queue.put(None)
|
|
||||||
finally:
|
|
||||||
# Clear any remaining items and add None sentinel to unblock any waiting threads
|
|
||||||
try:
|
try:
|
||||||
while not client_to_server_queue.empty():
|
yield (
|
||||||
client_to_server_queue.get_nowait()
|
server_to_client_queue, # Queue for receiving messages FROM server
|
||||||
except queue.Empty:
|
client_to_server_queue, # Queue for sending messages TO server
|
||||||
pass
|
transport.get_session_id,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
if transport.session_id and terminate_on_close:
|
||||||
|
transport.terminate_session(client)
|
||||||
|
|
||||||
client_to_server_queue.put(None)
|
# Signal threads to stop
|
||||||
server_to_client_queue.put(None)
|
client_to_server_queue.put(None)
|
||||||
|
finally:
|
||||||
|
# Clear any remaining items and add None sentinel to unblock any waiting threads
|
||||||
|
try:
|
||||||
|
while not client_to_server_queue.empty():
|
||||||
|
client_to_server_queue.get_nowait()
|
||||||
|
except queue.Empty:
|
||||||
|
pass
|
||||||
|
|
||||||
|
client_to_server_queue.put(None)
|
||||||
|
server_to_client_queue.put(None)
|
||||||
|
|
||||||
|
# Shutdown executor without waiting to prevent hanging
|
||||||
|
executor.shutdown(wait=False)
|
||||||
|
|||||||
@ -1,10 +1,13 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from enum import StrEnum
|
||||||
from typing import Any, Generic, TypeVar
|
from typing import Any, Generic, TypeVar
|
||||||
|
|
||||||
from core.mcp.session.base_session import BaseSession
|
from pydantic import BaseModel
|
||||||
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestId, RequestParams
|
|
||||||
|
|
||||||
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", LATEST_PROTOCOL_VERSION]
|
from core.mcp.session.base_session import BaseSession
|
||||||
|
from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAuthMetadata, RequestId, RequestParams
|
||||||
|
|
||||||
|
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
|
||||||
|
|
||||||
|
|
||||||
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
|
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
|
||||||
@ -17,3 +20,41 @@ class RequestContext(Generic[SessionT, LifespanContextT]):
|
|||||||
meta: RequestParams.Meta | None
|
meta: RequestParams.Meta | None
|
||||||
session: SessionT
|
session: SessionT
|
||||||
lifespan_context: LifespanContextT
|
lifespan_context: LifespanContextT
|
||||||
|
|
||||||
|
|
||||||
|
class AuthActionType(StrEnum):
|
||||||
|
"""Types of actions that can be performed during auth flow."""
|
||||||
|
|
||||||
|
SAVE_CLIENT_INFO = "save_client_info"
|
||||||
|
SAVE_TOKENS = "save_tokens"
|
||||||
|
SAVE_CODE_VERIFIER = "save_code_verifier"
|
||||||
|
START_AUTHORIZATION = "start_authorization"
|
||||||
|
SUCCESS = "success"
|
||||||
|
|
||||||
|
|
||||||
|
class AuthAction(BaseModel):
|
||||||
|
"""Represents an action that needs to be performed as a result of auth flow."""
|
||||||
|
|
||||||
|
action_type: AuthActionType
|
||||||
|
data: dict[str, Any]
|
||||||
|
provider_id: str | None = None
|
||||||
|
tenant_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class AuthResult(BaseModel):
|
||||||
|
"""Result of auth function containing actions to be performed and response data."""
|
||||||
|
|
||||||
|
actions: list[AuthAction]
|
||||||
|
response: dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthCallbackState(BaseModel):
|
||||||
|
"""State data stored in Redis during OAuth callback flow."""
|
||||||
|
|
||||||
|
provider_id: str
|
||||||
|
tenant_id: str
|
||||||
|
server_url: str
|
||||||
|
metadata: OAuthMetadata | None = None
|
||||||
|
client_information: OAuthClientInformation
|
||||||
|
code_verifier: str
|
||||||
|
redirect_uri: str
|
||||||
|
|||||||
@ -8,3 +8,7 @@ class MCPConnectionError(MCPError):
|
|||||||
|
|
||||||
class MCPAuthError(MCPConnectionError):
|
class MCPAuthError(MCPConnectionError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MCPRefreshTokenError(MCPError):
|
||||||
|
pass
|
||||||
|
|||||||
@ -7,9 +7,9 @@ from urllib.parse import urlparse
|
|||||||
|
|
||||||
from core.mcp.client.sse_client import sse_client
|
from core.mcp.client.sse_client import sse_client
|
||||||
from core.mcp.client.streamable_client import streamablehttp_client
|
from core.mcp.client.streamable_client import streamablehttp_client
|
||||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
from core.mcp.error import MCPConnectionError
|
||||||
from core.mcp.session.client_session import ClientSession
|
from core.mcp.session.client_session import ClientSession
|
||||||
from core.mcp.types import Tool
|
from core.mcp.types import CallToolResult, Tool
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -18,40 +18,18 @@ class MCPClient:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
server_url: str,
|
server_url: str,
|
||||||
provider_id: str,
|
|
||||||
tenant_id: str,
|
|
||||||
authed: bool = True,
|
|
||||||
authorization_code: str | None = None,
|
|
||||||
for_list: bool = False,
|
|
||||||
headers: dict[str, str] | None = None,
|
headers: dict[str, str] | None = None,
|
||||||
timeout: float | None = None,
|
timeout: float | None = None,
|
||||||
sse_read_timeout: float | None = None,
|
sse_read_timeout: float | None = None,
|
||||||
):
|
):
|
||||||
# Initialize info
|
|
||||||
self.provider_id = provider_id
|
|
||||||
self.tenant_id = tenant_id
|
|
||||||
self.client_type = "streamable"
|
|
||||||
self.server_url = server_url
|
self.server_url = server_url
|
||||||
self.headers = headers or {}
|
self.headers = headers or {}
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.sse_read_timeout = sse_read_timeout
|
self.sse_read_timeout = sse_read_timeout
|
||||||
|
|
||||||
# Authentication info
|
|
||||||
self.authed = authed
|
|
||||||
self.authorization_code = authorization_code
|
|
||||||
if authed:
|
|
||||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
|
||||||
|
|
||||||
self.provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=for_list)
|
|
||||||
self.token = self.provider.tokens()
|
|
||||||
|
|
||||||
# Initialize session and client objects
|
# Initialize session and client objects
|
||||||
self._session: ClientSession | None = None
|
self._session: ClientSession | None = None
|
||||||
self._streams_context: AbstractContextManager[Any] | None = None
|
|
||||||
self._session_context: ClientSession | None = None
|
|
||||||
self._exit_stack = ExitStack()
|
self._exit_stack = ExitStack()
|
||||||
|
|
||||||
# Whether the client has been initialized
|
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
@ -85,61 +63,42 @@ class MCPClient:
|
|||||||
logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
|
logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
|
||||||
self.connect_server(streamablehttp_client, "mcp")
|
self.connect_server(streamablehttp_client, "mcp")
|
||||||
|
|
||||||
def connect_server(
|
def connect_server(self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str) -> None:
|
||||||
self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str, first_try: bool = True
|
"""
|
||||||
):
|
Connect to the MCP server using streamable http or sse.
|
||||||
from core.mcp.auth.auth_flow import auth
|
Default to streamable http.
|
||||||
|
Args:
|
||||||
|
client_factory: The client factory to use(streamablehttp_client or sse_client).
|
||||||
|
method_name: The method name to use(mcp or sse).
|
||||||
|
"""
|
||||||
|
streams_context = client_factory(
|
||||||
|
url=self.server_url,
|
||||||
|
headers=self.headers,
|
||||||
|
timeout=self.timeout,
|
||||||
|
sse_read_timeout=self.sse_read_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
# Use exit_stack to manage context managers properly
|
||||||
headers = (
|
if method_name == "mcp":
|
||||||
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
|
read_stream, write_stream, _ = self._exit_stack.enter_context(streams_context)
|
||||||
if self.authed and self.token
|
streams = (read_stream, write_stream)
|
||||||
else self.headers
|
else: # sse_client
|
||||||
)
|
streams = self._exit_stack.enter_context(streams_context)
|
||||||
self._streams_context = client_factory(
|
|
||||||
url=self.server_url,
|
|
||||||
headers=headers,
|
|
||||||
timeout=self.timeout,
|
|
||||||
sse_read_timeout=self.sse_read_timeout,
|
|
||||||
)
|
|
||||||
if not self._streams_context:
|
|
||||||
raise MCPConnectionError("Failed to create connection context")
|
|
||||||
|
|
||||||
# Use exit_stack to manage context managers properly
|
session_context = ClientSession(*streams)
|
||||||
if method_name == "mcp":
|
self._session = self._exit_stack.enter_context(session_context)
|
||||||
read_stream, write_stream, _ = self._exit_stack.enter_context(self._streams_context)
|
self._session.initialize()
|
||||||
streams = (read_stream, write_stream)
|
|
||||||
else: # sse_client
|
|
||||||
streams = self._exit_stack.enter_context(self._streams_context)
|
|
||||||
|
|
||||||
self._session_context = ClientSession(*streams)
|
|
||||||
self._session = self._exit_stack.enter_context(self._session_context)
|
|
||||||
self._session.initialize()
|
|
||||||
return
|
|
||||||
|
|
||||||
except MCPAuthError:
|
|
||||||
if not self.authed:
|
|
||||||
raise
|
|
||||||
try:
|
|
||||||
auth(self.provider, self.server_url, self.authorization_code)
|
|
||||||
except Exception as e:
|
|
||||||
raise ValueError(f"Failed to authenticate: {e}")
|
|
||||||
self.token = self.provider.tokens()
|
|
||||||
if first_try:
|
|
||||||
return self.connect_server(client_factory, method_name, first_try=False)
|
|
||||||
|
|
||||||
def list_tools(self) -> list[Tool]:
|
def list_tools(self) -> list[Tool]:
|
||||||
"""Connect to an MCP server running with SSE transport"""
|
"""List available tools from the MCP server"""
|
||||||
# List available tools to verify connection
|
if not self._session:
|
||||||
if not self._initialized or not self._session:
|
|
||||||
raise ValueError("Session not initialized.")
|
raise ValueError("Session not initialized.")
|
||||||
response = self._session.list_tools()
|
response = self._session.list_tools()
|
||||||
tools = response.tools
|
return response.tools
|
||||||
return tools
|
|
||||||
|
|
||||||
def invoke_tool(self, tool_name: str, tool_args: dict):
|
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
|
||||||
"""Call a tool"""
|
"""Call a tool"""
|
||||||
if not self._initialized or not self._session:
|
if not self._session:
|
||||||
raise ValueError("Session not initialized.")
|
raise ValueError("Session not initialized.")
|
||||||
return self._session.call_tool(tool_name, tool_args)
|
return self._session.call_tool(tool_name, tool_args)
|
||||||
|
|
||||||
@ -153,6 +112,4 @@ class MCPClient:
|
|||||||
raise ValueError(f"Error during cleanup: {e}")
|
raise ValueError(f"Error during cleanup: {e}")
|
||||||
finally:
|
finally:
|
||||||
self._session = None
|
self._session = None
|
||||||
self._session_context = None
|
|
||||||
self._streams_context = None
|
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
|
|||||||
@ -201,11 +201,14 @@ class BaseSession(
|
|||||||
self._receiver_future.result(timeout=5.0) # Wait up to 5 seconds
|
self._receiver_future.result(timeout=5.0) # Wait up to 5 seconds
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
# If the receiver loop is still running after timeout, we'll force shutdown
|
# If the receiver loop is still running after timeout, we'll force shutdown
|
||||||
pass
|
# Cancel the future to interrupt the receiver loop
|
||||||
|
self._receiver_future.cancel()
|
||||||
|
|
||||||
# Shutdown the executor
|
# Shutdown the executor
|
||||||
if self._executor:
|
if self._executor:
|
||||||
self._executor.shutdown(wait=True)
|
# Use non-blocking shutdown to prevent hanging
|
||||||
|
# The receiver thread should have already exited due to the None message in the queue
|
||||||
|
self._executor.shutdown(wait=False)
|
||||||
|
|
||||||
def send_request(
|
def send_request(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -284,7 +284,7 @@ class ClientSession(
|
|||||||
|
|
||||||
def complete(
|
def complete(
|
||||||
self,
|
self,
|
||||||
ref: types.ResourceReference | types.PromptReference,
|
ref: types.ResourceTemplateReference | types.PromptReference,
|
||||||
argument: dict[str, str],
|
argument: dict[str, str],
|
||||||
) -> types.CompleteResult:
|
) -> types.CompleteResult:
|
||||||
"""Send a completion/complete request."""
|
"""Send a completion/complete request."""
|
||||||
|
|||||||
@ -1,13 +1,6 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import (
|
from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar
|
||||||
Annotated,
|
|
||||||
Any,
|
|
||||||
Generic,
|
|
||||||
Literal,
|
|
||||||
TypeAlias,
|
|
||||||
TypeVar,
|
|
||||||
)
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel
|
from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel
|
||||||
from pydantic.networks import AnyUrl, UrlConstraints
|
from pydantic.networks import AnyUrl, UrlConstraints
|
||||||
@ -33,6 +26,7 @@ for reference.
|
|||||||
LATEST_PROTOCOL_VERSION = "2025-03-26"
|
LATEST_PROTOCOL_VERSION = "2025-03-26"
|
||||||
# Server support 2024-11-05 to allow claude to use.
|
# Server support 2024-11-05 to allow claude to use.
|
||||||
SERVER_LATEST_PROTOCOL_VERSION = "2024-11-05"
|
SERVER_LATEST_PROTOCOL_VERSION = "2024-11-05"
|
||||||
|
DEFAULT_NEGOTIATED_VERSION = "2025-03-26"
|
||||||
ProgressToken = str | int
|
ProgressToken = str | int
|
||||||
Cursor = str
|
Cursor = str
|
||||||
Role = Literal["user", "assistant"]
|
Role = Literal["user", "assistant"]
|
||||||
@ -55,14 +49,22 @@ class RequestParams(BaseModel):
|
|||||||
meta: Meta | None = Field(alias="_meta", default=None)
|
meta: Meta | None = Field(alias="_meta", default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class PaginatedRequestParams(RequestParams):
|
||||||
|
cursor: Cursor | None = None
|
||||||
|
"""
|
||||||
|
An opaque token representing the current pagination position.
|
||||||
|
If provided, the server should return results starting after this cursor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class NotificationParams(BaseModel):
|
class NotificationParams(BaseModel):
|
||||||
class Meta(BaseModel):
|
class Meta(BaseModel):
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
meta: Meta | None = Field(alias="_meta", default=None)
|
meta: Meta | None = Field(alias="_meta", default=None)
|
||||||
"""
|
"""
|
||||||
This parameter name is reserved by MCP to allow clients and servers to attach
|
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||||
additional metadata to their notifications.
|
for notes on _meta usage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -79,12 +81,11 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]):
|
|||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
class PaginatedRequest(Request[RequestParamsT, MethodT]):
|
class PaginatedRequest(Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]):
|
||||||
cursor: Cursor | None = None
|
"""Base class for paginated requests,
|
||||||
"""
|
matching the schema's PaginatedRequest interface."""
|
||||||
An opaque token representing the current pagination position.
|
|
||||||
If provided, the server should return results starting after this cursor.
|
params: PaginatedRequestParams | None = None
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
|
class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
|
||||||
@ -98,13 +99,12 @@ class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
|
|||||||
class Result(BaseModel):
|
class Result(BaseModel):
|
||||||
"""Base class for JSON-RPC results."""
|
"""Base class for JSON-RPC results."""
|
||||||
|
|
||||||
model_config = ConfigDict(extra="allow")
|
|
||||||
|
|
||||||
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||||
"""
|
"""
|
||||||
This result property is reserved by the protocol to allow clients and servers to
|
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||||
attach additional metadata to their responses.
|
for notes on _meta usage.
|
||||||
"""
|
"""
|
||||||
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
class PaginatedResult(Result):
|
class PaginatedResult(Result):
|
||||||
@ -186,10 +186,26 @@ class EmptyResult(Result):
|
|||||||
"""A response that indicates success but carries no data."""
|
"""A response that indicates success but carries no data."""
|
||||||
|
|
||||||
|
|
||||||
class Implementation(BaseModel):
|
class BaseMetadata(BaseModel):
|
||||||
"""Describes the name and version of an MCP implementation."""
|
"""Base class for entities with name and optional title fields."""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
|
"""The programmatic name of the entity."""
|
||||||
|
|
||||||
|
title: str | None = None
|
||||||
|
"""
|
||||||
|
Intended for UI and end-user contexts — optimized to be human-readable and easily understood,
|
||||||
|
even by those unfamiliar with domain-specific terminology.
|
||||||
|
|
||||||
|
If not provided, the name should be used for display (except for Tool,
|
||||||
|
where `annotations.title` should be given precedence over using `name`,
|
||||||
|
if present).
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class Implementation(BaseMetadata):
|
||||||
|
"""Describes the name and version of an MCP implementation."""
|
||||||
|
|
||||||
version: str
|
version: str
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
@ -203,7 +219,7 @@ class RootsCapability(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class SamplingCapability(BaseModel):
|
class SamplingCapability(BaseModel):
|
||||||
"""Capability for logging operations."""
|
"""Capability for sampling operations."""
|
||||||
|
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
@ -252,6 +268,12 @@ class LoggingCapability(BaseModel):
|
|||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionsCapability(BaseModel):
|
||||||
|
"""Capability for completions operations."""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
class ServerCapabilities(BaseModel):
|
class ServerCapabilities(BaseModel):
|
||||||
"""Capabilities that a server may support."""
|
"""Capabilities that a server may support."""
|
||||||
|
|
||||||
@ -265,6 +287,8 @@ class ServerCapabilities(BaseModel):
|
|||||||
"""Present if the server offers any resources to read."""
|
"""Present if the server offers any resources to read."""
|
||||||
tools: ToolsCapability | None = None
|
tools: ToolsCapability | None = None
|
||||||
"""Present if the server offers any tools to call."""
|
"""Present if the server offers any tools to call."""
|
||||||
|
completions: CompletionsCapability | None = None
|
||||||
|
"""Present if the server offers autocompletion suggestions for prompts and resources."""
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
@ -284,7 +308,7 @@ class InitializeRequest(Request[InitializeRequestParams, Literal["initialize"]])
|
|||||||
to begin initialization.
|
to begin initialization.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
method: Literal["initialize"]
|
method: Literal["initialize"] = "initialize"
|
||||||
params: InitializeRequestParams
|
params: InitializeRequestParams
|
||||||
|
|
||||||
|
|
||||||
@ -305,7 +329,7 @@ class InitializedNotification(Notification[NotificationParams | None, Literal["n
|
|||||||
finished.
|
finished.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
method: Literal["notifications/initialized"]
|
method: Literal["notifications/initialized"] = "notifications/initialized"
|
||||||
params: NotificationParams | None = None
|
params: NotificationParams | None = None
|
||||||
|
|
||||||
|
|
||||||
@ -315,7 +339,7 @@ class PingRequest(Request[RequestParams | None, Literal["ping"]]):
|
|||||||
still alive.
|
still alive.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
method: Literal["ping"]
|
method: Literal["ping"] = "ping"
|
||||||
params: RequestParams | None = None
|
params: RequestParams | None = None
|
||||||
|
|
||||||
|
|
||||||
@ -334,6 +358,11 @@ class ProgressNotificationParams(NotificationParams):
|
|||||||
"""
|
"""
|
||||||
total: float | None = None
|
total: float | None = None
|
||||||
"""Total number of items to process (or total progress required), if known."""
|
"""Total number of items to process (or total progress required), if known."""
|
||||||
|
message: str | None = None
|
||||||
|
"""
|
||||||
|
Message related to progress. This should provide relevant human readable
|
||||||
|
progress information.
|
||||||
|
"""
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
@ -343,15 +372,14 @@ class ProgressNotification(Notification[ProgressNotificationParams, Literal["not
|
|||||||
long-running request.
|
long-running request.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
method: Literal["notifications/progress"]
|
method: Literal["notifications/progress"] = "notifications/progress"
|
||||||
params: ProgressNotificationParams
|
params: ProgressNotificationParams
|
||||||
|
|
||||||
|
|
||||||
class ListResourcesRequest(PaginatedRequest[RequestParams | None, Literal["resources/list"]]):
|
class ListResourcesRequest(PaginatedRequest[Literal["resources/list"]]):
|
||||||
"""Sent from the client to request a list of resources the server has."""
|
"""Sent from the client to request a list of resources the server has."""
|
||||||
|
|
||||||
method: Literal["resources/list"]
|
method: Literal["resources/list"] = "resources/list"
|
||||||
params: RequestParams | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class Annotations(BaseModel):
|
class Annotations(BaseModel):
|
||||||
@ -360,13 +388,11 @@ class Annotations(BaseModel):
|
|||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
class Resource(BaseModel):
|
class Resource(BaseMetadata):
|
||||||
"""A known resource that the server is capable of reading."""
|
"""A known resource that the server is capable of reading."""
|
||||||
|
|
||||||
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
|
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
|
||||||
"""The URI of this resource."""
|
"""The URI of this resource."""
|
||||||
name: str
|
|
||||||
"""A human-readable name for this resource."""
|
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
"""A description of what this resource represents."""
|
"""A description of what this resource represents."""
|
||||||
mimeType: str | None = None
|
mimeType: str | None = None
|
||||||
@ -379,10 +405,15 @@ class Resource(BaseModel):
|
|||||||
This can be used by Hosts to display file sizes and estimate context window usage.
|
This can be used by Hosts to display file sizes and estimate context window usage.
|
||||||
"""
|
"""
|
||||||
annotations: Annotations | None = None
|
annotations: Annotations | None = None
|
||||||
|
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||||
|
"""
|
||||||
|
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||||
|
for notes on _meta usage.
|
||||||
|
"""
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
class ResourceTemplate(BaseModel):
|
class ResourceTemplate(BaseMetadata):
|
||||||
"""A template description for resources available on the server."""
|
"""A template description for resources available on the server."""
|
||||||
|
|
||||||
uriTemplate: str
|
uriTemplate: str
|
||||||
@ -390,8 +421,6 @@ class ResourceTemplate(BaseModel):
|
|||||||
A URI template (according to RFC 6570) that can be used to construct resource
|
A URI template (according to RFC 6570) that can be used to construct resource
|
||||||
URIs.
|
URIs.
|
||||||
"""
|
"""
|
||||||
name: str
|
|
||||||
"""A human-readable name for the type of resource this template refers to."""
|
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
"""A human-readable description of what this template is for."""
|
"""A human-readable description of what this template is for."""
|
||||||
mimeType: str | None = None
|
mimeType: str | None = None
|
||||||
@ -400,6 +429,11 @@ class ResourceTemplate(BaseModel):
|
|||||||
included if all resources matching this template have the same type.
|
included if all resources matching this template have the same type.
|
||||||
"""
|
"""
|
||||||
annotations: Annotations | None = None
|
annotations: Annotations | None = None
|
||||||
|
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||||
|
"""
|
||||||
|
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||||
|
for notes on _meta usage.
|
||||||
|
"""
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
@ -409,11 +443,10 @@ class ListResourcesResult(PaginatedResult):
|
|||||||
resources: list[Resource]
|
resources: list[Resource]
|
||||||
|
|
||||||
|
|
||||||
class ListResourceTemplatesRequest(PaginatedRequest[RequestParams | None, Literal["resources/templates/list"]]):
|
class ListResourceTemplatesRequest(PaginatedRequest[Literal["resources/templates/list"]]):
|
||||||
"""Sent from the client to request a list of resource templates the server has."""
|
"""Sent from the client to request a list of resource templates the server has."""
|
||||||
|
|
||||||
method: Literal["resources/templates/list"]
|
method: Literal["resources/templates/list"] = "resources/templates/list"
|
||||||
params: RequestParams | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ListResourceTemplatesResult(PaginatedResult):
|
class ListResourceTemplatesResult(PaginatedResult):
|
||||||
@ -436,7 +469,7 @@ class ReadResourceRequestParams(RequestParams):
|
|||||||
class ReadResourceRequest(Request[ReadResourceRequestParams, Literal["resources/read"]]):
|
class ReadResourceRequest(Request[ReadResourceRequestParams, Literal["resources/read"]]):
|
||||||
"""Sent from the client to the server, to read a specific resource URI."""
|
"""Sent from the client to the server, to read a specific resource URI."""
|
||||||
|
|
||||||
method: Literal["resources/read"]
|
method: Literal["resources/read"] = "resources/read"
|
||||||
params: ReadResourceRequestParams
|
params: ReadResourceRequestParams
|
||||||
|
|
||||||
|
|
||||||
@ -447,6 +480,11 @@ class ResourceContents(BaseModel):
|
|||||||
"""The URI of this resource."""
|
"""The URI of this resource."""
|
||||||
mimeType: str | None = None
|
mimeType: str | None = None
|
||||||
"""The MIME type of this resource, if known."""
|
"""The MIME type of this resource, if known."""
|
||||||
|
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||||
|
"""
|
||||||
|
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||||
|
for notes on _meta usage.
|
||||||
|
"""
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
@ -481,7 +519,7 @@ class ResourceListChangedNotification(
|
|||||||
of resources it can read from has changed.
|
of resources it can read from has changed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
method: Literal["notifications/resources/list_changed"]
|
method: Literal["notifications/resources/list_changed"] = "notifications/resources/list_changed"
|
||||||
params: NotificationParams | None = None
|
params: NotificationParams | None = None
|
||||||
|
|
||||||
|
|
||||||
@ -502,7 +540,7 @@ class SubscribeRequest(Request[SubscribeRequestParams, Literal["resources/subscr
|
|||||||
whenever a particular resource changes.
|
whenever a particular resource changes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
method: Literal["resources/subscribe"]
|
method: Literal["resources/subscribe"] = "resources/subscribe"
|
||||||
params: SubscribeRequestParams
|
params: SubscribeRequestParams
|
||||||
|
|
||||||
|
|
||||||
@ -520,7 +558,7 @@ class UnsubscribeRequest(Request[UnsubscribeRequestParams, Literal["resources/un
|
|||||||
the server.
|
the server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
method: Literal["resources/unsubscribe"]
|
method: Literal["resources/unsubscribe"] = "resources/unsubscribe"
|
||||||
params: UnsubscribeRequestParams
|
params: UnsubscribeRequestParams
|
||||||
|
|
||||||
|
|
||||||
@ -543,15 +581,14 @@ class ResourceUpdatedNotification(
|
|||||||
changed and may need to be read again.
|
changed and may need to be read again.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
method: Literal["notifications/resources/updated"]
|
method: Literal["notifications/resources/updated"] = "notifications/resources/updated"
|
||||||
params: ResourceUpdatedNotificationParams
|
params: ResourceUpdatedNotificationParams
|
||||||
|
|
||||||
|
|
||||||
class ListPromptsRequest(PaginatedRequest[RequestParams | None, Literal["prompts/list"]]):
|
class ListPromptsRequest(PaginatedRequest[Literal["prompts/list"]]):
|
||||||
"""Sent from the client to request a list of prompts and prompt templates."""
|
"""Sent from the client to request a list of prompts and prompt templates."""
|
||||||
|
|
||||||
method: Literal["prompts/list"]
|
method: Literal["prompts/list"] = "prompts/list"
|
||||||
params: RequestParams | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class PromptArgument(BaseModel):
|
class PromptArgument(BaseModel):
|
||||||
@ -566,15 +603,18 @@ class PromptArgument(BaseModel):
|
|||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
class Prompt(BaseModel):
|
class Prompt(BaseMetadata):
|
||||||
"""A prompt or prompt template that the server offers."""
|
"""A prompt or prompt template that the server offers."""
|
||||||
|
|
||||||
name: str
|
|
||||||
"""The name of the prompt or prompt template."""
|
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
"""An optional description of what this prompt provides."""
|
"""An optional description of what this prompt provides."""
|
||||||
arguments: list[PromptArgument] | None = None
|
arguments: list[PromptArgument] | None = None
|
||||||
"""A list of arguments to use for templating the prompt."""
|
"""A list of arguments to use for templating the prompt."""
|
||||||
|
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||||
|
"""
|
||||||
|
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||||
|
for notes on _meta usage.
|
||||||
|
"""
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
@ -597,7 +637,7 @@ class GetPromptRequestParams(RequestParams):
|
|||||||
class GetPromptRequest(Request[GetPromptRequestParams, Literal["prompts/get"]]):
|
class GetPromptRequest(Request[GetPromptRequestParams, Literal["prompts/get"]]):
|
||||||
"""Used by the client to get a prompt provided by the server."""
|
"""Used by the client to get a prompt provided by the server."""
|
||||||
|
|
||||||
method: Literal["prompts/get"]
|
method: Literal["prompts/get"] = "prompts/get"
|
||||||
params: GetPromptRequestParams
|
params: GetPromptRequestParams
|
||||||
|
|
||||||
|
|
||||||
@ -608,6 +648,11 @@ class TextContent(BaseModel):
|
|||||||
text: str
|
text: str
|
||||||
"""The text content of the message."""
|
"""The text content of the message."""
|
||||||
annotations: Annotations | None = None
|
annotations: Annotations | None = None
|
||||||
|
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||||
|
"""
|
||||||
|
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||||
|
for notes on _meta usage.
|
||||||
|
"""
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
@ -623,6 +668,31 @@ class ImageContent(BaseModel):
|
|||||||
image types.
|
image types.
|
||||||
"""
|
"""
|
||||||
annotations: Annotations | None = None
|
annotations: Annotations | None = None
|
||||||
|
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||||
|
"""
|
||||||
|
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||||
|
for notes on _meta usage.
|
||||||
|
"""
|
||||||
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
|
class AudioContent(BaseModel):
|
||||||
|
"""Audio content for a message."""
|
||||||
|
|
||||||
|
type: Literal["audio"]
|
||||||
|
data: str
|
||||||
|
"""The base64-encoded audio data."""
|
||||||
|
mimeType: str
|
||||||
|
"""
|
||||||
|
The MIME type of the audio. Different providers may support different
|
||||||
|
audio types.
|
||||||
|
"""
|
||||||
|
annotations: Annotations | None = None
|
||||||
|
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||||
|
"""
|
||||||
|
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||||
|
for notes on _meta usage.
|
||||||
|
"""
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
@ -630,7 +700,7 @@ class SamplingMessage(BaseModel):
|
|||||||
"""Describes a message issued to or received from an LLM API."""
|
"""Describes a message issued to or received from an LLM API."""
|
||||||
|
|
||||||
role: Role
|
role: Role
|
||||||
content: TextContent | ImageContent
|
content: TextContent | ImageContent | AudioContent
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
@ -645,14 +715,36 @@ class EmbeddedResource(BaseModel):
|
|||||||
type: Literal["resource"]
|
type: Literal["resource"]
|
||||||
resource: TextResourceContents | BlobResourceContents
|
resource: TextResourceContents | BlobResourceContents
|
||||||
annotations: Annotations | None = None
|
annotations: Annotations | None = None
|
||||||
|
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||||
|
"""
|
||||||
|
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||||
|
for notes on _meta usage.
|
||||||
|
"""
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
|
class ResourceLink(Resource):
|
||||||
|
"""
|
||||||
|
A resource that the server is capable of reading, included in a prompt or tool call result.
|
||||||
|
|
||||||
|
Note: resource links returned by tools are not guaranteed to appear in the results of `resources/list` requests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: Literal["resource_link"]
|
||||||
|
|
||||||
|
|
||||||
|
ContentBlock = TextContent | ImageContent | AudioContent | ResourceLink | EmbeddedResource
|
||||||
|
"""A content block that can be used in prompts and tool results."""
|
||||||
|
|
||||||
|
Content: TypeAlias = ContentBlock
|
||||||
|
# """DEPRECATED: Content is deprecated, you should use ContentBlock directly."""
|
||||||
|
|
||||||
|
|
||||||
class PromptMessage(BaseModel):
|
class PromptMessage(BaseModel):
|
||||||
"""Describes a message returned as part of a prompt."""
|
"""Describes a message returned as part of a prompt."""
|
||||||
|
|
||||||
role: Role
|
role: Role
|
||||||
content: TextContent | ImageContent | EmbeddedResource
|
content: ContentBlock
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
@ -672,15 +764,14 @@ class PromptListChangedNotification(
|
|||||||
of prompts it offers has changed.
|
of prompts it offers has changed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
method: Literal["notifications/prompts/list_changed"]
|
method: Literal["notifications/prompts/list_changed"] = "notifications/prompts/list_changed"
|
||||||
params: NotificationParams | None = None
|
params: NotificationParams | None = None
|
||||||
|
|
||||||
|
|
||||||
class ListToolsRequest(PaginatedRequest[RequestParams | None, Literal["tools/list"]]):
|
class ListToolsRequest(PaginatedRequest[Literal["tools/list"]]):
|
||||||
"""Sent from the client to request a list of tools the server has."""
|
"""Sent from the client to request a list of tools the server has."""
|
||||||
|
|
||||||
method: Literal["tools/list"]
|
method: Literal["tools/list"] = "tools/list"
|
||||||
params: RequestParams | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ToolAnnotations(BaseModel):
|
class ToolAnnotations(BaseModel):
|
||||||
@ -731,17 +822,25 @@ class ToolAnnotations(BaseModel):
|
|||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
class Tool(BaseModel):
|
class Tool(BaseMetadata):
|
||||||
"""Definition for a tool the client can call."""
|
"""Definition for a tool the client can call."""
|
||||||
|
|
||||||
name: str
|
|
||||||
"""The name of the tool."""
|
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
"""A human-readable description of the tool."""
|
"""A human-readable description of the tool."""
|
||||||
inputSchema: dict[str, Any]
|
inputSchema: dict[str, Any]
|
||||||
"""A JSON Schema object defining the expected parameters for the tool."""
|
"""A JSON Schema object defining the expected parameters for the tool."""
|
||||||
|
outputSchema: dict[str, Any] | None = None
|
||||||
|
"""
|
||||||
|
An optional JSON Schema object defining the structure of the tool's output
|
||||||
|
returned in the structuredContent field of a CallToolResult.
|
||||||
|
"""
|
||||||
annotations: ToolAnnotations | None = None
|
annotations: ToolAnnotations | None = None
|
||||||
"""Optional additional tool information."""
|
"""Optional additional tool information."""
|
||||||
|
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||||
|
"""
|
||||||
|
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||||
|
for notes on _meta usage.
|
||||||
|
"""
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
@ -762,14 +861,16 @@ class CallToolRequestParams(RequestParams):
|
|||||||
class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]):
|
class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]):
|
||||||
"""Used by the client to invoke a tool provided by the server."""
|
"""Used by the client to invoke a tool provided by the server."""
|
||||||
|
|
||||||
method: Literal["tools/call"]
|
method: Literal["tools/call"] = "tools/call"
|
||||||
params: CallToolRequestParams
|
params: CallToolRequestParams
|
||||||
|
|
||||||
|
|
||||||
class CallToolResult(Result):
|
class CallToolResult(Result):
|
||||||
"""The server's response to a tool call."""
|
"""The server's response to a tool call."""
|
||||||
|
|
||||||
content: list[TextContent | ImageContent | EmbeddedResource]
|
content: list[ContentBlock]
|
||||||
|
structuredContent: dict[str, Any] | None = None
|
||||||
|
"""An optional JSON object that represents the structured result of the tool call."""
|
||||||
isError: bool = False
|
isError: bool = False
|
||||||
|
|
||||||
|
|
||||||
@ -779,7 +880,7 @@ class ToolListChangedNotification(Notification[NotificationParams | None, Litera
|
|||||||
of tools it offers has changed.
|
of tools it offers has changed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
method: Literal["notifications/tools/list_changed"]
|
method: Literal["notifications/tools/list_changed"] = "notifications/tools/list_changed"
|
||||||
params: NotificationParams | None = None
|
params: NotificationParams | None = None
|
||||||
|
|
||||||
|
|
||||||
@ -797,7 +898,7 @@ class SetLevelRequestParams(RequestParams):
|
|||||||
class SetLevelRequest(Request[SetLevelRequestParams, Literal["logging/setLevel"]]):
|
class SetLevelRequest(Request[SetLevelRequestParams, Literal["logging/setLevel"]]):
|
||||||
"""A request from the client to the server, to enable or adjust logging."""
|
"""A request from the client to the server, to enable or adjust logging."""
|
||||||
|
|
||||||
method: Literal["logging/setLevel"]
|
method: Literal["logging/setLevel"] = "logging/setLevel"
|
||||||
params: SetLevelRequestParams
|
params: SetLevelRequestParams
|
||||||
|
|
||||||
|
|
||||||
@ -808,7 +909,7 @@ class LoggingMessageNotificationParams(NotificationParams):
|
|||||||
"""The severity of this log message."""
|
"""The severity of this log message."""
|
||||||
logger: str | None = None
|
logger: str | None = None
|
||||||
"""An optional name of the logger issuing this message."""
|
"""An optional name of the logger issuing this message."""
|
||||||
data: Any = None
|
data: Any
|
||||||
"""
|
"""
|
||||||
The data to be logged, such as a string message or an object. Any JSON serializable
|
The data to be logged, such as a string message or an object. Any JSON serializable
|
||||||
type is allowed here.
|
type is allowed here.
|
||||||
@ -819,7 +920,7 @@ class LoggingMessageNotificationParams(NotificationParams):
|
|||||||
class LoggingMessageNotification(Notification[LoggingMessageNotificationParams, Literal["notifications/message"]]):
|
class LoggingMessageNotification(Notification[LoggingMessageNotificationParams, Literal["notifications/message"]]):
|
||||||
"""Notification of a log message passed from server to client."""
|
"""Notification of a log message passed from server to client."""
|
||||||
|
|
||||||
method: Literal["notifications/message"]
|
method: Literal["notifications/message"] = "notifications/message"
|
||||||
params: LoggingMessageNotificationParams
|
params: LoggingMessageNotificationParams
|
||||||
|
|
||||||
|
|
||||||
@ -914,7 +1015,7 @@ class CreateMessageRequestParams(RequestParams):
|
|||||||
class CreateMessageRequest(Request[CreateMessageRequestParams, Literal["sampling/createMessage"]]):
|
class CreateMessageRequest(Request[CreateMessageRequestParams, Literal["sampling/createMessage"]]):
|
||||||
"""A request from the server to sample an LLM via the client."""
|
"""A request from the server to sample an LLM via the client."""
|
||||||
|
|
||||||
method: Literal["sampling/createMessage"]
|
method: Literal["sampling/createMessage"] = "sampling/createMessage"
|
||||||
params: CreateMessageRequestParams
|
params: CreateMessageRequestParams
|
||||||
|
|
||||||
|
|
||||||
@ -925,14 +1026,14 @@ class CreateMessageResult(Result):
|
|||||||
"""The client's response to a sampling/create_message request from the server."""
|
"""The client's response to a sampling/create_message request from the server."""
|
||||||
|
|
||||||
role: Role
|
role: Role
|
||||||
content: TextContent | ImageContent
|
content: TextContent | ImageContent | AudioContent
|
||||||
model: str
|
model: str
|
||||||
"""The name of the model that generated the message."""
|
"""The name of the model that generated the message."""
|
||||||
stopReason: StopReason | None = None
|
stopReason: StopReason | None = None
|
||||||
"""The reason why sampling stopped, if known."""
|
"""The reason why sampling stopped, if known."""
|
||||||
|
|
||||||
|
|
||||||
class ResourceReference(BaseModel):
|
class ResourceTemplateReference(BaseModel):
|
||||||
"""A reference to a resource or resource template definition."""
|
"""A reference to a resource or resource template definition."""
|
||||||
|
|
||||||
type: Literal["ref/resource"]
|
type: Literal["ref/resource"]
|
||||||
@ -960,18 +1061,28 @@ class CompletionArgument(BaseModel):
|
|||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionContext(BaseModel):
|
||||||
|
"""Additional, optional context for completions."""
|
||||||
|
|
||||||
|
arguments: dict[str, str] | None = None
|
||||||
|
"""Previously-resolved variables in a URI template or prompt."""
|
||||||
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
class CompleteRequestParams(RequestParams):
|
class CompleteRequestParams(RequestParams):
|
||||||
"""Parameters for completion requests."""
|
"""Parameters for completion requests."""
|
||||||
|
|
||||||
ref: ResourceReference | PromptReference
|
ref: ResourceTemplateReference | PromptReference
|
||||||
argument: CompletionArgument
|
argument: CompletionArgument
|
||||||
|
context: CompletionContext | None = None
|
||||||
|
"""Additional, optional context for completions"""
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
class CompleteRequest(Request[CompleteRequestParams, Literal["completion/complete"]]):
|
class CompleteRequest(Request[CompleteRequestParams, Literal["completion/complete"]]):
|
||||||
"""A request from the client to the server, to ask for completion options."""
|
"""A request from the client to the server, to ask for completion options."""
|
||||||
|
|
||||||
method: Literal["completion/complete"]
|
method: Literal["completion/complete"] = "completion/complete"
|
||||||
params: CompleteRequestParams
|
params: CompleteRequestParams
|
||||||
|
|
||||||
|
|
||||||
@ -1010,7 +1121,7 @@ class ListRootsRequest(Request[RequestParams | None, Literal["roots/list"]]):
|
|||||||
structure or access specific locations that the client has permission to read from.
|
structure or access specific locations that the client has permission to read from.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
method: Literal["roots/list"]
|
method: Literal["roots/list"] = "roots/list"
|
||||||
params: RequestParams | None = None
|
params: RequestParams | None = None
|
||||||
|
|
||||||
|
|
||||||
@ -1029,6 +1140,11 @@ class Root(BaseModel):
|
|||||||
identifier for the root, which may be useful for display purposes or for
|
identifier for the root, which may be useful for display purposes or for
|
||||||
referencing the root in other parts of the application.
|
referencing the root in other parts of the application.
|
||||||
"""
|
"""
|
||||||
|
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||||
|
"""
|
||||||
|
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||||
|
for notes on _meta usage.
|
||||||
|
"""
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
@ -1054,7 +1170,7 @@ class RootsListChangedNotification(
|
|||||||
using the ListRootsRequest.
|
using the ListRootsRequest.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
method: Literal["notifications/roots/list_changed"]
|
method: Literal["notifications/roots/list_changed"] = "notifications/roots/list_changed"
|
||||||
params: NotificationParams | None = None
|
params: NotificationParams | None = None
|
||||||
|
|
||||||
|
|
||||||
@ -1074,7 +1190,7 @@ class CancelledNotification(Notification[CancelledNotificationParams, Literal["n
|
|||||||
previously-issued request.
|
previously-issued request.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
method: Literal["notifications/cancelled"]
|
method: Literal["notifications/cancelled"] = "notifications/cancelled"
|
||||||
params: CancelledNotificationParams
|
params: CancelledNotificationParams
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||||
from core.file import file_manager
|
from core.file import file_manager
|
||||||
@ -18,7 +19,9 @@ from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
from models.model import AppMode, Conversation, Message, MessageFile
|
from models.model import AppMode, Conversation, Message, MessageFile
|
||||||
from models.workflow import Workflow, WorkflowRun
|
from models.workflow import Workflow
|
||||||
|
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||||
|
from repositories.factory import DifyAPIRepositoryFactory
|
||||||
|
|
||||||
|
|
||||||
class TokenBufferMemory:
|
class TokenBufferMemory:
|
||||||
@ -29,6 +32,14 @@ class TokenBufferMemory:
|
|||||||
):
|
):
|
||||||
self.conversation = conversation
|
self.conversation = conversation
|
||||||
self.model_instance = model_instance
|
self.model_instance = model_instance
|
||||||
|
self._workflow_run_repo: APIWorkflowRunRepository | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def workflow_run_repo(self) -> APIWorkflowRunRepository:
|
||||||
|
if self._workflow_run_repo is None:
|
||||||
|
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
|
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||||
|
return self._workflow_run_repo
|
||||||
|
|
||||||
def _build_prompt_message_with_files(
|
def _build_prompt_message_with_files(
|
||||||
self,
|
self,
|
||||||
@ -50,7 +61,16 @@ class TokenBufferMemory:
|
|||||||
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
|
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
|
||||||
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
|
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
|
||||||
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||||
workflow_run = db.session.scalar(select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id))
|
app = self.conversation.app
|
||||||
|
if not app:
|
||||||
|
raise ValueError("App not found for conversation")
|
||||||
|
|
||||||
|
if not message.workflow_run_id:
|
||||||
|
raise ValueError("Workflow run ID not found")
|
||||||
|
|
||||||
|
workflow_run = self.workflow_run_repo.get_workflow_run_by_id(
|
||||||
|
tenant_id=app.tenant_id, app_id=app.id, run_id=message.workflow_run_id
|
||||||
|
)
|
||||||
if not workflow_run:
|
if not workflow_run:
|
||||||
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
|
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
|
||||||
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
|
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
|
||||||
|
|||||||
@ -12,7 +12,7 @@ from uuid import UUID, uuid4
|
|||||||
from cachetools import LRUCache
|
from cachetools import LRUCache
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
|
from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
|
||||||
from core.ops.entities.config_entity import (
|
from core.ops.entities.config_entity import (
|
||||||
@ -34,7 +34,8 @@ from core.ops.utils import get_message_data
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_storage import storage
|
from extensions.ext_storage import storage
|
||||||
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
|
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
|
||||||
from models.workflow import WorkflowAppLog, WorkflowRun
|
from models.workflow import WorkflowAppLog
|
||||||
|
from repositories.factory import DifyAPIRepositoryFactory
|
||||||
from tasks.ops_trace_task import process_trace_tasks
|
from tasks.ops_trace_task import process_trace_tasks
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -419,6 +420,18 @@ class OpsTraceManager:
|
|||||||
|
|
||||||
|
|
||||||
class TraceTask:
|
class TraceTask:
|
||||||
|
_workflow_run_repo = None
|
||||||
|
_repo_lock = threading.Lock()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_workflow_run_repo(cls):
|
||||||
|
if cls._workflow_run_repo is None:
|
||||||
|
with cls._repo_lock:
|
||||||
|
if cls._workflow_run_repo is None:
|
||||||
|
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
|
cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||||
|
return cls._workflow_run_repo
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
trace_type: Any,
|
trace_type: Any,
|
||||||
@ -486,27 +499,27 @@ class TraceTask:
|
|||||||
if not workflow_run_id:
|
if not workflow_run_id:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
workflow_run_repo = self._get_workflow_run_repo()
|
||||||
|
workflow_run = workflow_run_repo.get_workflow_run_by_id_without_tenant(run_id=workflow_run_id)
|
||||||
|
if not workflow_run:
|
||||||
|
raise ValueError("Workflow run not found")
|
||||||
|
|
||||||
|
workflow_id = workflow_run.workflow_id
|
||||||
|
tenant_id = workflow_run.tenant_id
|
||||||
|
workflow_run_id = workflow_run.id
|
||||||
|
workflow_run_elapsed_time = workflow_run.elapsed_time
|
||||||
|
workflow_run_status = workflow_run.status
|
||||||
|
workflow_run_inputs = workflow_run.inputs_dict
|
||||||
|
workflow_run_outputs = workflow_run.outputs_dict
|
||||||
|
workflow_run_version = workflow_run.version
|
||||||
|
error = workflow_run.error or ""
|
||||||
|
|
||||||
|
total_tokens = workflow_run.total_tokens
|
||||||
|
|
||||||
|
file_list = workflow_run_inputs.get("sys.file") or []
|
||||||
|
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
|
|
||||||
workflow_run = session.scalars(workflow_run_stmt).first()
|
|
||||||
if not workflow_run:
|
|
||||||
raise ValueError("Workflow run not found")
|
|
||||||
|
|
||||||
workflow_id = workflow_run.workflow_id
|
|
||||||
tenant_id = workflow_run.tenant_id
|
|
||||||
workflow_run_id = workflow_run.id
|
|
||||||
workflow_run_elapsed_time = workflow_run.elapsed_time
|
|
||||||
workflow_run_status = workflow_run.status
|
|
||||||
workflow_run_inputs = workflow_run.inputs_dict
|
|
||||||
workflow_run_outputs = workflow_run.outputs_dict
|
|
||||||
workflow_run_version = workflow_run.version
|
|
||||||
error = workflow_run.error or ""
|
|
||||||
|
|
||||||
total_tokens = workflow_run.total_tokens
|
|
||||||
|
|
||||||
file_list = workflow_run_inputs.get("sys.file") or []
|
|
||||||
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
|
|
||||||
|
|
||||||
# get workflow_app_log_id
|
# get workflow_app_log_id
|
||||||
workflow_app_log_data_stmt = select(WorkflowAppLog.id).where(
|
workflow_app_log_data_stmt = select(WorkflowAppLog.id).where(
|
||||||
WorkflowAppLog.tenant_id == tenant_id,
|
WorkflowAppLog.tenant_id == tenant_id,
|
||||||
@ -523,43 +536,43 @@ class TraceTask:
|
|||||||
)
|
)
|
||||||
message_id = session.scalar(message_data_stmt)
|
message_id = session.scalar(message_data_stmt)
|
||||||
|
|
||||||
metadata = {
|
metadata = {
|
||||||
"workflow_id": workflow_id,
|
"workflow_id": workflow_id,
|
||||||
"conversation_id": conversation_id,
|
"conversation_id": conversation_id,
|
||||||
"workflow_run_id": workflow_run_id,
|
"workflow_run_id": workflow_run_id,
|
||||||
"tenant_id": tenant_id,
|
"tenant_id": tenant_id,
|
||||||
"elapsed_time": workflow_run_elapsed_time,
|
"elapsed_time": workflow_run_elapsed_time,
|
||||||
"status": workflow_run_status,
|
"status": workflow_run_status,
|
||||||
"version": workflow_run_version,
|
"version": workflow_run_version,
|
||||||
"total_tokens": total_tokens,
|
"total_tokens": total_tokens,
|
||||||
"file_list": file_list,
|
"file_list": file_list,
|
||||||
"triggered_from": workflow_run.triggered_from,
|
"triggered_from": workflow_run.triggered_from,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"app_id": workflow_run.app_id,
|
"app_id": workflow_run.app_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
workflow_trace_info = WorkflowTraceInfo(
|
workflow_trace_info = WorkflowTraceInfo(
|
||||||
trace_id=self.trace_id,
|
trace_id=self.trace_id,
|
||||||
workflow_data=workflow_run.to_dict(),
|
workflow_data=workflow_run.to_dict(),
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
workflow_id=workflow_id,
|
workflow_id=workflow_id,
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
workflow_run_id=workflow_run_id,
|
workflow_run_id=workflow_run_id,
|
||||||
workflow_run_elapsed_time=workflow_run_elapsed_time,
|
workflow_run_elapsed_time=workflow_run_elapsed_time,
|
||||||
workflow_run_status=workflow_run_status,
|
workflow_run_status=workflow_run_status,
|
||||||
workflow_run_inputs=workflow_run_inputs,
|
workflow_run_inputs=workflow_run_inputs,
|
||||||
workflow_run_outputs=workflow_run_outputs,
|
workflow_run_outputs=workflow_run_outputs,
|
||||||
workflow_run_version=workflow_run_version,
|
workflow_run_version=workflow_run_version,
|
||||||
error=error,
|
error=error,
|
||||||
total_tokens=total_tokens,
|
total_tokens=total_tokens,
|
||||||
file_list=file_list,
|
file_list=file_list,
|
||||||
query=query,
|
query=query,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
workflow_app_log_id=workflow_app_log_id,
|
workflow_app_log_id=workflow_app_log_id,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
start_time=workflow_run.created_at,
|
start_time=workflow_run.created_at,
|
||||||
end_time=workflow_run.finished_at,
|
end_time=workflow_run.finished_at,
|
||||||
)
|
)
|
||||||
return workflow_trace_info
|
return workflow_trace_info
|
||||||
|
|
||||||
def message_trace(self, message_id: str | None):
|
def message_trace(self, message_id: str | None):
|
||||||
|
|||||||
@ -161,7 +161,7 @@ class OpenSearchVector(BaseVector):
|
|||||||
logger.exception("Error deleting document: %s", error)
|
logger.exception("Error deleting document: %s", error)
|
||||||
|
|
||||||
def delete(self):
|
def delete(self):
|
||||||
self._client.indices.delete(index=self._collection_name.lower())
|
self._client.indices.delete(index=self._collection_name.lower(), ignore_unavailable=True)
|
||||||
|
|
||||||
def text_exists(self, id: str) -> bool:
|
def text_exists(self, id: str) -> bool:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -217,3 +217,16 @@ class Tool(ABC):
|
|||||||
return ToolInvokeMessage(
|
return ToolInvokeMessage(
|
||||||
type=ToolInvokeMessage.MessageType.JSON, message=ToolInvokeMessage.JsonMessage(json_object=object)
|
type=ToolInvokeMessage.MessageType.JSON, message=ToolInvokeMessage.JsonMessage(json_object=object)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def create_variable_message(
|
||||||
|
self, variable_name: str, variable_value: Any, stream: bool = False
|
||||||
|
) -> ToolInvokeMessage:
|
||||||
|
"""
|
||||||
|
create a variable message
|
||||||
|
"""
|
||||||
|
return ToolInvokeMessage(
|
||||||
|
type=ToolInvokeMessage.MessageType.VARIABLE,
|
||||||
|
message=ToolInvokeMessage.VariableMessage(
|
||||||
|
variable_name=variable_name, variable_value=variable_value, stream=stream
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from typing import Any, Literal
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.tools.__base.tool import ToolParameter
|
from core.tools.__base.tool import ToolParameter
|
||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject
|
||||||
@ -44,10 +45,14 @@ class ToolProviderApiEntity(BaseModel):
|
|||||||
server_url: str | None = Field(default="", description="The server url of the tool")
|
server_url: str | None = Field(default="", description="The server url of the tool")
|
||||||
updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
|
updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
|
||||||
server_identifier: str | None = Field(default="", description="The server identifier of the MCP tool")
|
server_identifier: str | None = Field(default="", description="The server identifier of the MCP tool")
|
||||||
timeout: float | None = Field(default=30.0, description="The timeout of the MCP tool")
|
|
||||||
sse_read_timeout: float | None = Field(default=300.0, description="The SSE read timeout of the MCP tool")
|
|
||||||
masked_headers: dict[str, str] | None = Field(default=None, description="The masked headers of the MCP tool")
|
masked_headers: dict[str, str] | None = Field(default=None, description="The masked headers of the MCP tool")
|
||||||
original_headers: dict[str, str] | None = Field(default=None, description="The original headers of the MCP tool")
|
original_headers: dict[str, str] | None = Field(default=None, description="The original headers of the MCP tool")
|
||||||
|
authentication: MCPAuthentication | None = Field(default=None, description="The OAuth config of the MCP tool")
|
||||||
|
is_dynamic_registration: bool = Field(default=True, description="Whether the MCP tool is dynamically registered")
|
||||||
|
configuration: MCPConfiguration | None = Field(
|
||||||
|
default=None, description="The timeout and sse_read_timeout of the MCP tool"
|
||||||
|
)
|
||||||
|
|
||||||
@field_validator("tools", mode="before")
|
@field_validator("tools", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -70,8 +75,15 @@ class ToolProviderApiEntity(BaseModel):
|
|||||||
if self.type == ToolProviderType.MCP:
|
if self.type == ToolProviderType.MCP:
|
||||||
optional_fields.update(self.optional_field("updated_at", self.updated_at))
|
optional_fields.update(self.optional_field("updated_at", self.updated_at))
|
||||||
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
|
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
|
||||||
optional_fields.update(self.optional_field("timeout", self.timeout))
|
optional_fields.update(
|
||||||
optional_fields.update(self.optional_field("sse_read_timeout", self.sse_read_timeout))
|
self.optional_field(
|
||||||
|
"configuration", self.configuration.model_dump() if self.configuration else MCPConfiguration()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
optional_fields.update(
|
||||||
|
self.optional_field("authentication", self.authentication.model_dump() if self.authentication else None)
|
||||||
|
)
|
||||||
|
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
|
||||||
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
|
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
|
||||||
optional_fields.update(self.optional_field("original_headers", self.original_headers))
|
optional_fields.update(self.optional_field("original_headers", self.original_headers))
|
||||||
return {
|
return {
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import json
|
|
||||||
from typing import Any, Self
|
from typing import Any, Self
|
||||||
|
|
||||||
|
from core.entities.mcp_provider import MCPProviderEntity
|
||||||
from core.mcp.types import Tool as RemoteMCPTool
|
from core.mcp.types import Tool as RemoteMCPTool
|
||||||
from core.tools.__base.tool_provider import ToolProviderController
|
from core.tools.__base.tool_provider import ToolProviderController
|
||||||
from core.tools.__base.tool_runtime import ToolRuntime
|
from core.tools.__base.tool_runtime import ToolRuntime
|
||||||
@ -52,18 +52,25 @@ class MCPToolProviderController(ToolProviderController):
|
|||||||
"""
|
"""
|
||||||
from db provider
|
from db provider
|
||||||
"""
|
"""
|
||||||
tools = []
|
# Convert to entity first
|
||||||
tools_data = json.loads(db_provider.tools)
|
provider_entity = db_provider.to_entity()
|
||||||
remote_mcp_tools = [RemoteMCPTool.model_validate(tool) for tool in tools_data]
|
return cls.from_entity(provider_entity)
|
||||||
user = db_provider.load_user()
|
|
||||||
|
@classmethod
|
||||||
|
def from_entity(cls, entity: MCPProviderEntity) -> Self:
|
||||||
|
"""
|
||||||
|
create a MCPToolProviderController from a MCPProviderEntity
|
||||||
|
"""
|
||||||
|
remote_mcp_tools = [RemoteMCPTool(**tool) for tool in entity.tools]
|
||||||
|
|
||||||
tools = [
|
tools = [
|
||||||
ToolEntity(
|
ToolEntity(
|
||||||
identity=ToolIdentity(
|
identity=ToolIdentity(
|
||||||
author=user.name if user else "Anonymous",
|
author="Anonymous", # Tool level author is not stored
|
||||||
name=remote_mcp_tool.name,
|
name=remote_mcp_tool.name,
|
||||||
label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name),
|
label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name),
|
||||||
provider=db_provider.server_identifier,
|
provider=entity.provider_id,
|
||||||
icon=db_provider.icon,
|
icon=entity.icon if isinstance(entity.icon, str) else "",
|
||||||
),
|
),
|
||||||
parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema),
|
parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema),
|
||||||
description=ToolDescription(
|
description=ToolDescription(
|
||||||
@ -72,31 +79,32 @@ class MCPToolProviderController(ToolProviderController):
|
|||||||
),
|
),
|
||||||
llm=remote_mcp_tool.description or "",
|
llm=remote_mcp_tool.description or "",
|
||||||
),
|
),
|
||||||
|
output_schema=remote_mcp_tool.outputSchema or {},
|
||||||
has_runtime_parameters=len(remote_mcp_tool.inputSchema) > 0,
|
has_runtime_parameters=len(remote_mcp_tool.inputSchema) > 0,
|
||||||
)
|
)
|
||||||
for remote_mcp_tool in remote_mcp_tools
|
for remote_mcp_tool in remote_mcp_tools
|
||||||
]
|
]
|
||||||
if not db_provider.icon:
|
if not entity.icon:
|
||||||
raise ValueError("Database provider icon is required")
|
raise ValueError("Database provider icon is required")
|
||||||
return cls(
|
return cls(
|
||||||
entity=ToolProviderEntityWithPlugin(
|
entity=ToolProviderEntityWithPlugin(
|
||||||
identity=ToolProviderIdentity(
|
identity=ToolProviderIdentity(
|
||||||
author=user.name if user else "Anonymous",
|
author="Anonymous", # Provider level author is not stored in entity
|
||||||
name=db_provider.name,
|
name=entity.name,
|
||||||
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
|
label=I18nObject(en_US=entity.name, zh_Hans=entity.name),
|
||||||
description=I18nObject(en_US="", zh_Hans=""),
|
description=I18nObject(en_US="", zh_Hans=""),
|
||||||
icon=db_provider.icon,
|
icon=entity.icon if isinstance(entity.icon, str) else "",
|
||||||
),
|
),
|
||||||
plugin_id=None,
|
plugin_id=None,
|
||||||
credentials_schema=[],
|
credentials_schema=[],
|
||||||
tools=tools,
|
tools=tools,
|
||||||
),
|
),
|
||||||
provider_id=db_provider.server_identifier or "",
|
provider_id=entity.provider_id,
|
||||||
tenant_id=db_provider.tenant_id or "",
|
tenant_id=entity.tenant_id,
|
||||||
server_url=db_provider.decrypted_server_url,
|
server_url=entity.server_url,
|
||||||
headers=db_provider.decrypted_headers or {},
|
headers=entity.headers,
|
||||||
timeout=db_provider.timeout,
|
timeout=entity.timeout,
|
||||||
sse_read_timeout=db_provider.sse_read_timeout,
|
sse_read_timeout=entity.sse_read_timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||||
|
|||||||
@ -3,12 +3,13 @@ import json
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||||
from core.mcp.mcp_client import MCPClient
|
from core.mcp.error import MCPConnectionError
|
||||||
from core.mcp.types import ImageContent, TextContent
|
from core.mcp.types import CallToolResult, ImageContent, TextContent
|
||||||
from core.tools.__base.tool import Tool
|
from core.tools.__base.tool import Tool
|
||||||
from core.tools.__base.tool_runtime import ToolRuntime
|
from core.tools.__base.tool_runtime import ToolRuntime
|
||||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
|
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
|
||||||
|
from core.tools.errors import ToolInvokeError
|
||||||
|
|
||||||
|
|
||||||
class MCPTool(Tool):
|
class MCPTool(Tool):
|
||||||
@ -44,40 +45,32 @@ class MCPTool(Tool):
|
|||||||
app_id: str | None = None,
|
app_id: str | None = None,
|
||||||
message_id: str | None = None,
|
message_id: str | None = None,
|
||||||
) -> Generator[ToolInvokeMessage, None, None]:
|
) -> Generator[ToolInvokeMessage, None, None]:
|
||||||
from core.tools.errors import ToolInvokeError
|
result = self.invoke_remote_mcp_tool(tool_parameters)
|
||||||
|
# handle dify tool output
|
||||||
try:
|
|
||||||
with MCPClient(
|
|
||||||
self.server_url,
|
|
||||||
self.provider_id,
|
|
||||||
self.tenant_id,
|
|
||||||
authed=True,
|
|
||||||
headers=self.headers,
|
|
||||||
timeout=self.timeout,
|
|
||||||
sse_read_timeout=self.sse_read_timeout,
|
|
||||||
) as mcp_client:
|
|
||||||
tool_parameters = self._handle_none_parameter(tool_parameters)
|
|
||||||
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
|
||||||
except MCPAuthError as e:
|
|
||||||
raise ToolInvokeError("Please auth the tool first") from e
|
|
||||||
except MCPConnectionError as e:
|
|
||||||
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
|
||||||
except Exception as e:
|
|
||||||
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
|
|
||||||
|
|
||||||
for content in result.content:
|
for content in result.content:
|
||||||
if isinstance(content, TextContent):
|
if isinstance(content, TextContent):
|
||||||
yield from self._process_text_content(content)
|
yield from self._process_text_content(content)
|
||||||
elif isinstance(content, ImageContent):
|
elif isinstance(content, ImageContent):
|
||||||
yield self._process_image_content(content)
|
yield self._process_image_content(content)
|
||||||
|
# handle MCP structured output
|
||||||
|
if self.entity.output_schema and result.structuredContent:
|
||||||
|
for k, v in result.structuredContent.items():
|
||||||
|
yield self.create_variable_message(k, v)
|
||||||
|
|
||||||
def _process_text_content(self, content: TextContent) -> Generator[ToolInvokeMessage, None, None]:
|
def _process_text_content(self, content: TextContent) -> Generator[ToolInvokeMessage, None, None]:
|
||||||
"""Process text content and yield appropriate messages."""
|
"""Process text content and yield appropriate messages."""
|
||||||
try:
|
# Check if content looks like JSON before attempting to parse
|
||||||
content_json = json.loads(content.text)
|
text = content.text.strip()
|
||||||
yield from self._process_json_content(content_json)
|
if text and text[0] in ("{", "[") and text[-1] in ("}", "]"):
|
||||||
except json.JSONDecodeError:
|
try:
|
||||||
yield self.create_text_message(content.text)
|
content_json = json.loads(text)
|
||||||
|
yield from self._process_json_content(content_json)
|
||||||
|
return
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# If not JSON or parsing failed, treat as plain text
|
||||||
|
yield self.create_text_message(content.text)
|
||||||
|
|
||||||
def _process_json_content(self, content_json: Any) -> Generator[ToolInvokeMessage, None, None]:
|
def _process_json_content(self, content_json: Any) -> Generator[ToolInvokeMessage, None, None]:
|
||||||
"""Process JSON content based on its type."""
|
"""Process JSON content based on its type."""
|
||||||
@ -126,3 +119,44 @@ class MCPTool(Tool):
|
|||||||
for key, value in parameter.items()
|
for key, value in parameter.items()
|
||||||
if value is not None and not (isinstance(value, str) and value.strip() == "")
|
if value is not None and not (isinstance(value, str) and value.strip() == "")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def invoke_remote_mcp_tool(self, tool_parameters: dict[str, Any]) -> CallToolResult:
|
||||||
|
headers = self.headers.copy() if self.headers else {}
|
||||||
|
tool_parameters = self._handle_none_parameter(tool_parameters)
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||||
|
|
||||||
|
# Step 1: Load provider entity and credentials in a short-lived session
|
||||||
|
# This minimizes database connection hold time
|
||||||
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
|
mcp_service = MCPToolManageService(session=session)
|
||||||
|
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
|
||||||
|
|
||||||
|
# Decrypt and prepare all credentials before closing session
|
||||||
|
server_url = provider_entity.decrypt_server_url()
|
||||||
|
headers = provider_entity.decrypt_headers()
|
||||||
|
|
||||||
|
# Try to get existing token and add to headers
|
||||||
|
if not headers:
|
||||||
|
tokens = provider_entity.retrieve_tokens()
|
||||||
|
if tokens and tokens.access_token:
|
||||||
|
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
|
||||||
|
|
||||||
|
# Step 2: Session is now closed, perform network operations without holding database connection
|
||||||
|
# MCPClientWithAuthRetry will create a new session lazily only if auth retry is needed
|
||||||
|
try:
|
||||||
|
with MCPClientWithAuthRetry(
|
||||||
|
server_url=server_url,
|
||||||
|
headers=headers,
|
||||||
|
timeout=self.timeout,
|
||||||
|
sse_read_timeout=self.sse_read_timeout,
|
||||||
|
provider_entity=provider_entity,
|
||||||
|
) as mcp_client:
|
||||||
|
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||||
|
except MCPConnectionError as e:
|
||||||
|
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
||||||
|
except Exception as e:
|
||||||
|
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
|
||||||
|
|||||||
@ -14,17 +14,32 @@ from sqlalchemy.orm import Session
|
|||||||
from yarl import URL
|
from yarl import URL
|
||||||
|
|
||||||
import contexts
|
import contexts
|
||||||
|
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||||
|
from core.plugin.impl.tool import PluginToolManager
|
||||||
|
from core.tools.__base.tool_provider import ToolProviderController
|
||||||
|
from core.tools.__base.tool_runtime import ToolRuntime
|
||||||
|
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||||
|
from core.tools.mcp_tool.tool import MCPTool
|
||||||
|
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||||
|
from core.tools.plugin_tool.tool import PluginTool
|
||||||
|
from core.tools.utils.uuid_utils import is_valid_uuid
|
||||||
|
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||||
|
from core.workflow.runtime.variable_pool import VariablePool
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.provider_ids import ToolProviderID
|
||||||
|
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||||
|
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core.workflow.nodes.tool.entities import ToolEntity
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.agent.entities import AgentToolEntity
|
from core.agent.entities import AgentToolEntity
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||||
from core.helper.position_helper import is_filtered
|
from core.helper.position_helper import is_filtered
|
||||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.plugin.impl.tool import PluginToolManager
|
|
||||||
from core.tools.__base.tool import Tool
|
from core.tools.__base.tool import Tool
|
||||||
from core.tools.__base.tool_provider import ToolProviderController
|
|
||||||
from core.tools.__base.tool_runtime import ToolRuntime
|
|
||||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||||
from core.tools.builtin_tool.tool import BuiltinTool
|
from core.tools.builtin_tool.tool import BuiltinTool
|
||||||
@ -40,21 +55,11 @@ from core.tools.entities.tool_entities import (
|
|||||||
ToolProviderType,
|
ToolProviderType,
|
||||||
)
|
)
|
||||||
from core.tools.errors import ToolProviderNotFoundError
|
from core.tools.errors import ToolProviderNotFoundError
|
||||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
|
||||||
from core.tools.mcp_tool.tool import MCPTool
|
|
||||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
|
||||||
from core.tools.plugin_tool.tool import PluginTool
|
|
||||||
from core.tools.tool_label_manager import ToolLabelManager
|
from core.tools.tool_label_manager import ToolLabelManager
|
||||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||||
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
|
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
|
||||||
from core.tools.utils.uuid_utils import is_valid_uuid
|
|
||||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
|
||||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||||
from extensions.ext_database import db
|
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||||
from models.provider_ids import ToolProviderID
|
|
||||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
|
||||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
|
||||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
|
||||||
from services.tools.tools_transform_service import ToolTransformService
|
from services.tools.tools_transform_service import ToolTransformService
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -719,7 +724,9 @@ class ToolManager:
|
|||||||
)
|
)
|
||||||
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
|
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
|
||||||
if "mcp" in filters:
|
if "mcp" in filters:
|
||||||
mcp_providers = MCPToolManageService.retrieve_mcp_tools(tenant_id, for_list=True)
|
with Session(db.engine) as session:
|
||||||
|
mcp_service = MCPToolManageService(session=session)
|
||||||
|
mcp_providers = mcp_service.list_providers(tenant_id=tenant_id, for_list=True)
|
||||||
for mcp_provider in mcp_providers:
|
for mcp_provider in mcp_providers:
|
||||||
result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider
|
result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider
|
||||||
|
|
||||||
@ -774,17 +781,12 @@ class ToolManager:
|
|||||||
|
|
||||||
:return: the provider controller, the credentials
|
:return: the provider controller, the credentials
|
||||||
"""
|
"""
|
||||||
provider: MCPToolProvider | None = (
|
with Session(db.engine) as session:
|
||||||
db.session.query(MCPToolProvider)
|
mcp_service = MCPToolManageService(session=session)
|
||||||
.where(
|
try:
|
||||||
MCPToolProvider.server_identifier == provider_id,
|
provider = mcp_service.get_provider(server_identifier=provider_id, tenant_id=tenant_id)
|
||||||
MCPToolProvider.tenant_id == tenant_id,
|
except ValueError:
|
||||||
)
|
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if provider is None:
|
|
||||||
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
|
|
||||||
|
|
||||||
controller = MCPToolProviderController.from_db(provider)
|
controller = MCPToolProviderController.from_db(provider)
|
||||||
|
|
||||||
@ -922,16 +924,15 @@ class ToolManager:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str] | str:
|
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str] | str:
|
||||||
try:
|
try:
|
||||||
mcp_provider: MCPToolProvider | None = (
|
with Session(db.engine) as session:
|
||||||
db.session.query(MCPToolProvider)
|
mcp_service = MCPToolManageService(session=session)
|
||||||
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
|
try:
|
||||||
.first()
|
mcp_provider = mcp_service.get_provider_entity(
|
||||||
)
|
provider_id=provider_id, tenant_id=tenant_id, by_server_id=True
|
||||||
|
)
|
||||||
if mcp_provider is None:
|
return mcp_provider.provider_icon
|
||||||
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
|
except ValueError:
|
||||||
|
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
|
||||||
return mcp_provider.provider_icon
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||||
|
|
||||||
|
|||||||
@ -31,6 +31,7 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
|
|||||||
VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
|
VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
|
||||||
VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
|
VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
|
||||||
VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
|
VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
|
||||||
|
VariableEntityType.CHECKBOX: ToolParameter.ToolParameterType.BOOLEAN,
|
||||||
VariableEntityType.FILE: ToolParameter.ToolParameterType.FILE,
|
VariableEntityType.FILE: ToolParameter.ToolParameterType.FILE,
|
||||||
VariableEntityType.FILE_LIST: ToolParameter.ToolParameterType.FILES,
|
VariableEntityType.FILE_LIST: ToolParameter.ToolParameterType.FILES,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -24,6 +24,7 @@ from core.workflow.graph_events import (
|
|||||||
NodeRunLoopStartedEvent,
|
NodeRunLoopStartedEvent,
|
||||||
NodeRunLoopSucceededEvent,
|
NodeRunLoopSucceededEvent,
|
||||||
NodeRunPauseRequestedEvent,
|
NodeRunPauseRequestedEvent,
|
||||||
|
NodeRunRetrieverResourceEvent,
|
||||||
NodeRunRetryEvent,
|
NodeRunRetryEvent,
|
||||||
NodeRunStartedEvent,
|
NodeRunStartedEvent,
|
||||||
NodeRunStreamChunkEvent,
|
NodeRunStreamChunkEvent,
|
||||||
@ -112,6 +113,7 @@ class EventHandler:
|
|||||||
@_dispatch.register(NodeRunLoopSucceededEvent)
|
@_dispatch.register(NodeRunLoopSucceededEvent)
|
||||||
@_dispatch.register(NodeRunLoopFailedEvent)
|
@_dispatch.register(NodeRunLoopFailedEvent)
|
||||||
@_dispatch.register(NodeRunAgentLogEvent)
|
@_dispatch.register(NodeRunAgentLogEvent)
|
||||||
|
@_dispatch.register(NodeRunRetrieverResourceEvent)
|
||||||
def _(self, event: GraphNodeEventBase) -> None:
|
def _(self, event: GraphNodeEventBase) -> None:
|
||||||
self._event_collector.collect(event)
|
self._event_collector.collect(event)
|
||||||
|
|
||||||
|
|||||||
@ -23,6 +23,7 @@ class IterationNodeData(BaseIterationNodeData):
|
|||||||
is_parallel: bool = False # open the parallel mode or not
|
is_parallel: bool = False # open the parallel mode or not
|
||||||
parallel_nums: int = 10 # the numbers of parallel
|
parallel_nums: int = 10 # the numbers of parallel
|
||||||
error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED # how to handle the error
|
error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED # how to handle the error
|
||||||
|
flatten_output: bool = True # whether to flatten the output array if all elements are lists
|
||||||
|
|
||||||
|
|
||||||
class IterationStartNodeData(BaseNodeData):
|
class IterationStartNodeData(BaseNodeData):
|
||||||
|
|||||||
@ -98,6 +98,7 @@ class IterationNode(LLMUsageTrackingMixin, Node):
|
|||||||
"is_parallel": False,
|
"is_parallel": False,
|
||||||
"parallel_nums": 10,
|
"parallel_nums": 10,
|
||||||
"error_handle_mode": ErrorHandleMode.TERMINATED,
|
"error_handle_mode": ErrorHandleMode.TERMINATED,
|
||||||
|
"flatten_output": True,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -411,7 +412,14 @@ class IterationNode(LLMUsageTrackingMixin, Node):
|
|||||||
"""
|
"""
|
||||||
Flatten the outputs list if all elements are lists.
|
Flatten the outputs list if all elements are lists.
|
||||||
This maintains backward compatibility with version 1.8.1 behavior.
|
This maintains backward compatibility with version 1.8.1 behavior.
|
||||||
|
|
||||||
|
If flatten_output is False, returns outputs as-is (nested structure).
|
||||||
|
If flatten_output is True (default), flattens the list if all elements are lists.
|
||||||
"""
|
"""
|
||||||
|
# If flatten_output is disabled, return outputs as-is
|
||||||
|
if not self._node_data.flatten_output:
|
||||||
|
return outputs
|
||||||
|
|
||||||
if not outputs:
|
if not outputs:
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|||||||
@ -32,7 +32,7 @@ if [[ "${MODE}" == "worker" ]]; then
|
|||||||
|
|
||||||
exec celery -A celery_entrypoint.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \
|
exec celery -A celery_entrypoint.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \
|
||||||
--max-tasks-per-child ${MAX_TASKS_PER_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \
|
--max-tasks-per-child ${MAX_TASKS_PER_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \
|
||||||
-Q ${CELERY_QUEUES:-dataset,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation} \
|
-Q ${CELERY_QUEUES:-dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation} \
|
||||||
--prefetch-multiplier=1
|
--prefetch-multiplier=1
|
||||||
|
|
||||||
elif [[ "${MODE}" == "beat" ]]; then
|
elif [[ "${MODE}" == "beat" ]]; then
|
||||||
|
|||||||
@ -6,10 +6,11 @@ from flask_login import user_loaded_from_request, user_logged_in
|
|||||||
from werkzeug.exceptions import NotFound, Unauthorized
|
from werkzeug.exceptions import NotFound, Unauthorized
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
from constants import HEADER_NAME_APP_CODE
|
||||||
from dify_app import DifyApp
|
from dify_app import DifyApp
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.passport import PassportService
|
from libs.passport import PassportService
|
||||||
from libs.token import extract_access_token
|
from libs.token import extract_access_token, extract_webapp_passport
|
||||||
from models import Account, Tenant, TenantAccountJoin
|
from models import Account, Tenant, TenantAccountJoin
|
||||||
from models.model import AppMCPServer, EndUser
|
from models.model import AppMCPServer, EndUser
|
||||||
from services.account_service import AccountService
|
from services.account_service import AccountService
|
||||||
@ -61,14 +62,30 @@ def load_user_from_request(request_from_flask_login):
|
|||||||
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
|
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
|
||||||
return logged_in_account
|
return logged_in_account
|
||||||
elif request.blueprint == "web":
|
elif request.blueprint == "web":
|
||||||
decoded = PassportService().verify(auth_token)
|
app_code = request.headers.get(HEADER_NAME_APP_CODE)
|
||||||
end_user_id = decoded.get("end_user_id")
|
webapp_token = extract_webapp_passport(app_code, request) if app_code else None
|
||||||
if not end_user_id:
|
|
||||||
raise Unauthorized("Invalid Authorization token.")
|
if webapp_token:
|
||||||
end_user = db.session.query(EndUser).where(EndUser.id == decoded["end_user_id"]).first()
|
decoded = PassportService().verify(webapp_token)
|
||||||
if not end_user:
|
end_user_id = decoded.get("end_user_id")
|
||||||
raise NotFound("End user not found.")
|
if not end_user_id:
|
||||||
return end_user
|
raise Unauthorized("Invalid Authorization token.")
|
||||||
|
end_user = db.session.query(EndUser).where(EndUser.id == end_user_id).first()
|
||||||
|
if not end_user:
|
||||||
|
raise NotFound("End user not found.")
|
||||||
|
return end_user
|
||||||
|
else:
|
||||||
|
if not auth_token:
|
||||||
|
raise Unauthorized("Invalid Authorization token.")
|
||||||
|
decoded = PassportService().verify(auth_token)
|
||||||
|
end_user_id = decoded.get("end_user_id")
|
||||||
|
if end_user_id:
|
||||||
|
end_user = db.session.query(EndUser).where(EndUser.id == end_user_id).first()
|
||||||
|
if not end_user:
|
||||||
|
raise NotFound("End user not found.")
|
||||||
|
return end_user
|
||||||
|
else:
|
||||||
|
raise Unauthorized("Invalid Authorization token for web API.")
|
||||||
elif request.blueprint == "mcp":
|
elif request.blueprint == "mcp":
|
||||||
server_code = request.view_args.get("server_code") if request.view_args else None
|
server_code = request.view_args.get("server_code") if request.view_args else None
|
||||||
if not server_code:
|
if not server_code:
|
||||||
|
|||||||
@ -9,9 +9,8 @@ from werkzeug.exceptions import HTTPException
|
|||||||
from werkzeug.http import HTTP_STATUS_CODES
|
from werkzeug.http import HTTP_STATUS_CODES
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from constants import COOKIE_NAME_ACCESS_TOKEN, COOKIE_NAME_CSRF_TOKEN, COOKIE_NAME_REFRESH_TOKEN
|
|
||||||
from core.errors.error import AppInvokeQuotaExceededError
|
from core.errors.error import AppInvokeQuotaExceededError
|
||||||
from libs.token import is_secure
|
from libs.token import build_force_logout_cookie_headers
|
||||||
|
|
||||||
|
|
||||||
def http_status_message(code):
|
def http_status_message(code):
|
||||||
@ -73,15 +72,7 @@ def register_external_error_handlers(api: Api):
|
|||||||
error_code = getattr(e, "error_code", None)
|
error_code = getattr(e, "error_code", None)
|
||||||
if error_code == "unauthorized_and_force_logout":
|
if error_code == "unauthorized_and_force_logout":
|
||||||
# Add Set-Cookie headers to clear auth cookies
|
# Add Set-Cookie headers to clear auth cookies
|
||||||
|
headers["Set-Cookie"] = build_force_logout_cookie_headers()
|
||||||
secure = is_secure()
|
|
||||||
# response is not accessible, so we need to do it ugly
|
|
||||||
common_part = "Path=/; Expires=Thu, 01 Jan 1970 00:00:00 GMT; HttpOnly"
|
|
||||||
headers["Set-Cookie"] = [
|
|
||||||
f'{COOKIE_NAME_ACCESS_TOKEN}=""; {common_part}{"; Secure" if secure else ""}; SameSite=Lax',
|
|
||||||
f'{COOKIE_NAME_CSRF_TOKEN}=""; {common_part}{"; Secure" if secure else ""}; SameSite=Lax',
|
|
||||||
f'{COOKIE_NAME_REFRESH_TOKEN}=""; {common_part}{"; Secure" if secure else ""}; SameSite=Lax',
|
|
||||||
]
|
|
||||||
return data, status_code, headers
|
return data, status_code, headers
|
||||||
|
|
||||||
_ = handle_http_exception
|
_ = handle_http_exception
|
||||||
|
|||||||
@ -30,17 +30,28 @@ def is_secure() -> bool:
|
|||||||
return dify_config.CONSOLE_WEB_URL.startswith("https") and dify_config.CONSOLE_API_URL.startswith("https")
|
return dify_config.CONSOLE_WEB_URL.startswith("https") and dify_config.CONSOLE_API_URL.startswith("https")
|
||||||
|
|
||||||
|
|
||||||
|
def _cookie_domain() -> str | None:
|
||||||
|
"""
|
||||||
|
Returns the normalized cookie domain.
|
||||||
|
|
||||||
|
Leading dots are stripped from the configured domain. Historically, a leading dot
|
||||||
|
indicated that a cookie should be sent to all subdomains, but modern browsers treat
|
||||||
|
'example.com' and '.example.com' identically. This normalization ensures consistent
|
||||||
|
behavior and avoids confusion.
|
||||||
|
"""
|
||||||
|
domain = dify_config.COOKIE_DOMAIN.strip()
|
||||||
|
domain = domain.removeprefix(".")
|
||||||
|
return domain or None
|
||||||
|
|
||||||
|
|
||||||
def _real_cookie_name(cookie_name: str) -> str:
|
def _real_cookie_name(cookie_name: str) -> str:
|
||||||
if is_secure():
|
if is_secure() and _cookie_domain() is None:
|
||||||
return "__Host-" + cookie_name
|
return "__Host-" + cookie_name
|
||||||
else:
|
else:
|
||||||
return cookie_name
|
return cookie_name
|
||||||
|
|
||||||
|
|
||||||
def _try_extract_from_header(request: Request) -> str | None:
|
def _try_extract_from_header(request: Request) -> str | None:
|
||||||
"""
|
|
||||||
Try to extract access token from header
|
|
||||||
"""
|
|
||||||
auth_header = request.headers.get("Authorization")
|
auth_header = request.headers.get("Authorization")
|
||||||
if auth_header:
|
if auth_header:
|
||||||
if " " not in auth_header:
|
if " " not in auth_header:
|
||||||
@ -55,27 +66,19 @@ def _try_extract_from_header(request: Request) -> str | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def extract_refresh_token(request: Request) -> str | None:
|
||||||
|
return request.cookies.get(_real_cookie_name(COOKIE_NAME_REFRESH_TOKEN))
|
||||||
|
|
||||||
|
|
||||||
def extract_csrf_token(request: Request) -> str | None:
|
def extract_csrf_token(request: Request) -> str | None:
|
||||||
"""
|
|
||||||
Try to extract CSRF token from header or cookie.
|
|
||||||
"""
|
|
||||||
return request.headers.get(HEADER_NAME_CSRF_TOKEN)
|
return request.headers.get(HEADER_NAME_CSRF_TOKEN)
|
||||||
|
|
||||||
|
|
||||||
def extract_csrf_token_from_cookie(request: Request) -> str | None:
|
def extract_csrf_token_from_cookie(request: Request) -> str | None:
|
||||||
"""
|
|
||||||
Try to extract CSRF token from cookie.
|
|
||||||
"""
|
|
||||||
return request.cookies.get(_real_cookie_name(COOKIE_NAME_CSRF_TOKEN))
|
return request.cookies.get(_real_cookie_name(COOKIE_NAME_CSRF_TOKEN))
|
||||||
|
|
||||||
|
|
||||||
def extract_access_token(request: Request) -> str | None:
|
def extract_access_token(request: Request) -> str | None:
|
||||||
"""
|
|
||||||
Try to extract access token from cookie, header or params.
|
|
||||||
|
|
||||||
Access token is either for console session or webapp passport exchange.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _try_extract_from_cookie(request: Request) -> str | None:
|
def _try_extract_from_cookie(request: Request) -> str | None:
|
||||||
return request.cookies.get(_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN))
|
return request.cookies.get(_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN))
|
||||||
|
|
||||||
@ -83,20 +86,10 @@ def extract_access_token(request: Request) -> str | None:
|
|||||||
|
|
||||||
|
|
||||||
def extract_webapp_access_token(request: Request) -> str | None:
|
def extract_webapp_access_token(request: Request) -> str | None:
|
||||||
"""
|
|
||||||
Try to extract webapp access token from cookie, then header.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return request.cookies.get(_real_cookie_name(COOKIE_NAME_WEBAPP_ACCESS_TOKEN)) or _try_extract_from_header(request)
|
return request.cookies.get(_real_cookie_name(COOKIE_NAME_WEBAPP_ACCESS_TOKEN)) or _try_extract_from_header(request)
|
||||||
|
|
||||||
|
|
||||||
def extract_webapp_passport(app_code: str, request: Request) -> str | None:
|
def extract_webapp_passport(app_code: str, request: Request) -> str | None:
|
||||||
"""
|
|
||||||
Try to extract app token from header or params.
|
|
||||||
|
|
||||||
Webapp access token (part of passport) is only used for webapp session.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _try_extract_passport_token_from_cookie(request: Request) -> str | None:
|
def _try_extract_passport_token_from_cookie(request: Request) -> str | None:
|
||||||
return request.cookies.get(_real_cookie_name(COOKIE_NAME_PASSPORT + "-" + app_code))
|
return request.cookies.get(_real_cookie_name(COOKIE_NAME_PASSPORT + "-" + app_code))
|
||||||
|
|
||||||
@ -112,6 +105,7 @@ def set_access_token_to_cookie(request: Request, response: Response, token: str,
|
|||||||
_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN),
|
_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN),
|
||||||
value=token,
|
value=token,
|
||||||
httponly=True,
|
httponly=True,
|
||||||
|
domain=_cookie_domain(),
|
||||||
secure=is_secure(),
|
secure=is_secure(),
|
||||||
samesite=samesite,
|
samesite=samesite,
|
||||||
max_age=int(dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 60),
|
max_age=int(dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 60),
|
||||||
@ -124,6 +118,7 @@ def set_refresh_token_to_cookie(request: Request, response: Response, token: str
|
|||||||
_real_cookie_name(COOKIE_NAME_REFRESH_TOKEN),
|
_real_cookie_name(COOKIE_NAME_REFRESH_TOKEN),
|
||||||
value=token,
|
value=token,
|
||||||
httponly=True,
|
httponly=True,
|
||||||
|
domain=_cookie_domain(),
|
||||||
secure=is_secure(),
|
secure=is_secure(),
|
||||||
samesite="Lax",
|
samesite="Lax",
|
||||||
max_age=int(60 * 60 * 24 * dify_config.REFRESH_TOKEN_EXPIRE_DAYS),
|
max_age=int(60 * 60 * 24 * dify_config.REFRESH_TOKEN_EXPIRE_DAYS),
|
||||||
@ -136,6 +131,7 @@ def set_csrf_token_to_cookie(request: Request, response: Response, token: str):
|
|||||||
_real_cookie_name(COOKIE_NAME_CSRF_TOKEN),
|
_real_cookie_name(COOKIE_NAME_CSRF_TOKEN),
|
||||||
value=token,
|
value=token,
|
||||||
httponly=False,
|
httponly=False,
|
||||||
|
domain=_cookie_domain(),
|
||||||
secure=is_secure(),
|
secure=is_secure(),
|
||||||
samesite="Lax",
|
samesite="Lax",
|
||||||
max_age=int(60 * dify_config.ACCESS_TOKEN_EXPIRE_MINUTES),
|
max_age=int(60 * dify_config.ACCESS_TOKEN_EXPIRE_MINUTES),
|
||||||
@ -154,6 +150,7 @@ def _clear_cookie(
|
|||||||
"",
|
"",
|
||||||
expires=0,
|
expires=0,
|
||||||
path="/",
|
path="/",
|
||||||
|
domain=_cookie_domain(),
|
||||||
secure=is_secure(),
|
secure=is_secure(),
|
||||||
httponly=http_only,
|
httponly=http_only,
|
||||||
samesite=samesite,
|
samesite=samesite,
|
||||||
@ -176,6 +173,19 @@ def clear_csrf_token_from_cookie(response: Response):
|
|||||||
_clear_cookie(response, COOKIE_NAME_CSRF_TOKEN, http_only=False)
|
_clear_cookie(response, COOKIE_NAME_CSRF_TOKEN, http_only=False)
|
||||||
|
|
||||||
|
|
||||||
|
def build_force_logout_cookie_headers() -> list[str]:
|
||||||
|
"""
|
||||||
|
Generate Set-Cookie header values that clear all auth-related cookies.
|
||||||
|
This mirrors the behavior of the standard cookie clearing helpers while
|
||||||
|
allowing callers that do not have a Response instance to reuse the logic.
|
||||||
|
"""
|
||||||
|
response = Response()
|
||||||
|
clear_access_token_from_cookie(response)
|
||||||
|
clear_csrf_token_from_cookie(response)
|
||||||
|
clear_refresh_token_from_cookie(response)
|
||||||
|
return response.headers.getlist("Set-Cookie")
|
||||||
|
|
||||||
|
|
||||||
def check_csrf_token(request: Request, user_id: str):
|
def check_csrf_token(request: Request, user_id: str):
|
||||||
# some apis are sent by beacon, so we need to bypass csrf token check
|
# some apis are sent by beacon, so we need to bypass csrf token check
|
||||||
# since these APIs are post, they are already protected by SameSite: Lax, so csrf is not required.
|
# since these APIs are post, they are already protected by SameSite: Lax, so csrf is not required.
|
||||||
|
|||||||
@ -1,16 +1,13 @@
|
|||||||
import json
|
import json
|
||||||
from collections.abc import Mapping
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from typing import TYPE_CHECKING, Any, cast
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from deprecated import deprecated
|
from deprecated import deprecated
|
||||||
from sqlalchemy import ForeignKey, String, func
|
from sqlalchemy import ForeignKey, String, func
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
from core.helper import encrypter
|
|
||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject
|
||||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
|
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
|
||||||
@ -21,7 +18,7 @@ from .model import Account, App, Tenant
|
|||||||
from .types import StringUUID
|
from .types import StringUUID
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from core.mcp.types import Tool as MCPTool
|
from core.entities.mcp_provider import MCPProviderEntity
|
||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject
|
||||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
|
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
|
||||||
@ -331,126 +328,36 @@ class MCPToolProvider(TypeBase):
|
|||||||
def load_user(self) -> Account | None:
|
def load_user(self) -> Account | None:
|
||||||
return db.session.query(Account).where(Account.id == self.user_id).first()
|
return db.session.query(Account).where(Account.id == self.user_id).first()
|
||||||
|
|
||||||
@property
|
|
||||||
def tenant(self) -> Tenant | None:
|
|
||||||
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def credentials(self) -> dict[str, Any]:
|
def credentials(self) -> dict[str, Any]:
|
||||||
if not self.encrypted_credentials:
|
if not self.encrypted_credentials:
|
||||||
return {}
|
return {}
|
||||||
try:
|
try:
|
||||||
return cast(dict[str, Any], json.loads(self.encrypted_credentials)) or {}
|
return json.loads(self.encrypted_credentials)
|
||||||
except json.JSONDecodeError:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def mcp_tools(self) -> list["MCPTool"]:
|
|
||||||
from core.mcp.types import Tool as MCPTool
|
|
||||||
|
|
||||||
return [MCPTool.model_validate(tool) for tool in json.loads(self.tools)]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def provider_icon(self) -> Mapping[str, str] | str:
|
|
||||||
from core.file import helpers as file_helpers
|
|
||||||
|
|
||||||
assert self.icon
|
|
||||||
try:
|
|
||||||
return json.loads(self.icon)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
return file_helpers.get_signed_file_url(self.icon)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def decrypted_server_url(self) -> str:
|
|
||||||
return encrypter.decrypt_token(self.tenant_id, self.server_url)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def decrypted_headers(self) -> dict[str, Any]:
|
|
||||||
"""Get decrypted headers for MCP server requests."""
|
|
||||||
from core.entities.provider_entities import BasicProviderConfig
|
|
||||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
|
||||||
from core.tools.utils.encryption import create_provider_encrypter
|
|
||||||
|
|
||||||
try:
|
|
||||||
if not self.encrypted_headers:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
headers_data = json.loads(self.encrypted_headers)
|
|
||||||
|
|
||||||
# Create dynamic config for all headers as SECRET_INPUT
|
|
||||||
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers_data]
|
|
||||||
|
|
||||||
encrypter_instance, _ = create_provider_encrypter(
|
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
config=config,
|
|
||||||
cache=NoOpProviderCredentialCache(),
|
|
||||||
)
|
|
||||||
|
|
||||||
result = encrypter_instance.decrypt(headers_data)
|
|
||||||
return result
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def masked_headers(self) -> dict[str, Any]:
|
def headers(self) -> dict[str, Any]:
|
||||||
"""Get masked headers for frontend display."""
|
if self.encrypted_headers is None:
|
||||||
from core.entities.provider_entities import BasicProviderConfig
|
return {}
|
||||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
|
||||||
from core.tools.utils.encryption import create_provider_encrypter
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not self.encrypted_headers:
|
return json.loads(self.encrypted_headers)
|
||||||
return {}
|
|
||||||
|
|
||||||
headers_data = json.loads(self.encrypted_headers)
|
|
||||||
|
|
||||||
# Create dynamic config for all headers as SECRET_INPUT
|
|
||||||
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers_data]
|
|
||||||
|
|
||||||
encrypter_instance, _ = create_provider_encrypter(
|
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
config=config,
|
|
||||||
cache=NoOpProviderCredentialCache(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# First decrypt, then mask
|
|
||||||
decrypted_headers = encrypter_instance.decrypt(headers_data)
|
|
||||||
result = encrypter_instance.mask_tool_credentials(decrypted_headers)
|
|
||||||
return result
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def masked_server_url(self) -> str:
|
def tool_dict(self) -> list[dict[str, Any]]:
|
||||||
def mask_url(url: str, mask_char: str = "*") -> str:
|
try:
|
||||||
"""
|
return json.loads(self.tools) if self.tools else []
|
||||||
mask the url to a simple string
|
except (json.JSONDecodeError, TypeError):
|
||||||
"""
|
return []
|
||||||
parsed = urlparse(url)
|
|
||||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
|
||||||
|
|
||||||
if parsed.path and parsed.path != "/":
|
def to_entity(self) -> "MCPProviderEntity":
|
||||||
return f"{base_url}/{mask_char * 6}"
|
"""Convert to domain entity"""
|
||||||
else:
|
from core.entities.mcp_provider import MCPProviderEntity
|
||||||
return base_url
|
|
||||||
|
|
||||||
return mask_url(self.decrypted_server_url)
|
return MCPProviderEntity.from_db_model(self)
|
||||||
|
|
||||||
@property
|
|
||||||
def decrypted_credentials(self) -> dict[str, Any]:
|
|
||||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
|
||||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
|
||||||
from core.tools.utils.encryption import create_provider_encrypter
|
|
||||||
|
|
||||||
provider_controller = MCPToolProviderController.from_db(self)
|
|
||||||
|
|
||||||
encrypter, _ = create_provider_encrypter(
|
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
|
||||||
cache=NoOpProviderCredentialCache(),
|
|
||||||
)
|
|
||||||
|
|
||||||
return encrypter.decrypt(self.credentials)
|
|
||||||
|
|
||||||
|
|
||||||
class ToolModelInvoke(TypeBase):
|
class ToolModelInvoke(TypeBase):
|
||||||
|
|||||||
@ -117,7 +117,7 @@ dev = [
|
|||||||
"pytest-cov~=4.1.0",
|
"pytest-cov~=4.1.0",
|
||||||
"pytest-env~=1.1.3",
|
"pytest-env~=1.1.3",
|
||||||
"pytest-mock~=3.14.0",
|
"pytest-mock~=3.14.0",
|
||||||
"testcontainers~=4.10.0",
|
"testcontainers~=4.13.2",
|
||||||
"types-aiofiles~=24.1.0",
|
"types-aiofiles~=24.1.0",
|
||||||
"types-beautifulsoup4~=4.12.0",
|
"types-beautifulsoup4~=4.12.0",
|
||||||
"types-cachetools~=5.5.0",
|
"types-cachetools~=5.5.0",
|
||||||
|
|||||||
@ -28,7 +28,7 @@ Example:
|
|||||||
runs = repo.get_paginated_workflow_runs(
|
runs = repo.get_paginated_workflow_runs(
|
||||||
tenant_id="tenant-123",
|
tenant_id="tenant-123",
|
||||||
app_id="app-456",
|
app_id="app-456",
|
||||||
triggered_from="debugging",
|
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||||
limit=20
|
limit=20
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
@ -40,7 +40,14 @@ from typing import Protocol
|
|||||||
|
|
||||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||||
|
from models.enums import WorkflowRunTriggeredFrom
|
||||||
from models.workflow import WorkflowRun
|
from models.workflow import WorkflowRun
|
||||||
|
from repositories.types import (
|
||||||
|
AverageInteractionStats,
|
||||||
|
DailyRunsStats,
|
||||||
|
DailyTerminalsStats,
|
||||||
|
DailyTokenCostStats,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
|
class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
|
||||||
@ -56,7 +63,7 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
|
|||||||
self,
|
self,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
app_id: str,
|
app_id: str,
|
||||||
triggered_from: str,
|
triggered_from: WorkflowRunTriggeredFrom | Sequence[WorkflowRunTriggeredFrom],
|
||||||
limit: int = 20,
|
limit: int = 20,
|
||||||
last_id: str | None = None,
|
last_id: str | None = None,
|
||||||
status: str | None = None,
|
status: str | None = None,
|
||||||
@ -71,7 +78,7 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
|
|||||||
Args:
|
Args:
|
||||||
tenant_id: Tenant identifier for multi-tenant isolation
|
tenant_id: Tenant identifier for multi-tenant isolation
|
||||||
app_id: Application identifier
|
app_id: Application identifier
|
||||||
triggered_from: Filter by trigger source (e.g., "debugging", "app-run")
|
triggered_from: Filter by trigger source(s) (e.g., "debugging", "app-run", or list of values)
|
||||||
limit: Maximum number of records to return (default: 20)
|
limit: Maximum number of records to return (default: 20)
|
||||||
last_id: Cursor for pagination - ID of the last record from previous page
|
last_id: Cursor for pagination - ID of the last record from previous page
|
||||||
status: Optional filter by status (e.g., "running", "succeeded", "failed")
|
status: Optional filter by status (e.g., "running", "succeeded", "failed")
|
||||||
@ -109,6 +116,31 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
|
|||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
def get_workflow_run_by_id_without_tenant(
|
||||||
|
self,
|
||||||
|
run_id: str,
|
||||||
|
) -> WorkflowRun | None:
|
||||||
|
"""
|
||||||
|
Get a specific workflow run by ID without tenant/app context.
|
||||||
|
|
||||||
|
Retrieves a single workflow run using only the run ID, without
|
||||||
|
requiring tenant_id or app_id. This method is intended for internal
|
||||||
|
system operations like tracing and monitoring where the tenant context
|
||||||
|
is not available upfront.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: Workflow run identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
WorkflowRun object if found, None otherwise
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This method bypasses tenant isolation checks and should only be used
|
||||||
|
in trusted system contexts like ops trace collection. For user-facing
|
||||||
|
operations, use get_workflow_run_by_id() with proper tenant isolation.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
def get_workflow_runs_count(
|
def get_workflow_runs_count(
|
||||||
self,
|
self,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
@ -218,3 +250,119 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
|
|||||||
and ensure proper data retention policies are followed.
|
and ensure proper data retention policies are followed.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
def get_daily_runs_statistics(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
app_id: str,
|
||||||
|
triggered_from: str,
|
||||||
|
start_date: datetime | None = None,
|
||||||
|
end_date: datetime | None = None,
|
||||||
|
timezone: str = "UTC",
|
||||||
|
) -> list[DailyRunsStats]:
|
||||||
|
"""
|
||||||
|
Get daily runs statistics.
|
||||||
|
|
||||||
|
Retrieves daily workflow runs count grouped by date for a specific app
|
||||||
|
and trigger source. Used for workflow statistics dashboard.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: Tenant identifier for multi-tenant isolation
|
||||||
|
app_id: Application identifier
|
||||||
|
triggered_from: Filter by trigger source (e.g., "app-run")
|
||||||
|
start_date: Optional start date filter
|
||||||
|
end_date: Optional end date filter
|
||||||
|
timezone: Timezone for date grouping (default: "UTC")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of dictionaries containing date and runs count:
|
||||||
|
[{"date": "2024-01-01", "runs": 10}, ...]
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_daily_terminals_statistics(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
app_id: str,
|
||||||
|
triggered_from: str,
|
||||||
|
start_date: datetime | None = None,
|
||||||
|
end_date: datetime | None = None,
|
||||||
|
timezone: str = "UTC",
|
||||||
|
) -> list[DailyTerminalsStats]:
|
||||||
|
"""
|
||||||
|
Get daily terminals statistics.
|
||||||
|
|
||||||
|
Retrieves daily unique terminal count grouped by date for a specific app
|
||||||
|
and trigger source. Used for workflow statistics dashboard.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: Tenant identifier for multi-tenant isolation
|
||||||
|
app_id: Application identifier
|
||||||
|
triggered_from: Filter by trigger source (e.g., "app-run")
|
||||||
|
start_date: Optional start date filter
|
||||||
|
end_date: Optional end date filter
|
||||||
|
timezone: Timezone for date grouping (default: "UTC")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of dictionaries containing date and terminal count:
|
||||||
|
[{"date": "2024-01-01", "terminal_count": 5}, ...]
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_daily_token_cost_statistics(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
app_id: str,
|
||||||
|
triggered_from: str,
|
||||||
|
start_date: datetime | None = None,
|
||||||
|
end_date: datetime | None = None,
|
||||||
|
timezone: str = "UTC",
|
||||||
|
) -> list[DailyTokenCostStats]:
|
||||||
|
"""
|
||||||
|
Get daily token cost statistics.
|
||||||
|
|
||||||
|
Retrieves daily total token count grouped by date for a specific app
|
||||||
|
and trigger source. Used for workflow statistics dashboard.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: Tenant identifier for multi-tenant isolation
|
||||||
|
app_id: Application identifier
|
||||||
|
triggered_from: Filter by trigger source (e.g., "app-run")
|
||||||
|
start_date: Optional start date filter
|
||||||
|
end_date: Optional end date filter
|
||||||
|
timezone: Timezone for date grouping (default: "UTC")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of dictionaries containing date and token count:
|
||||||
|
[{"date": "2024-01-01", "token_count": 1000}, ...]
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_average_app_interaction_statistics(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
app_id: str,
|
||||||
|
triggered_from: str,
|
||||||
|
start_date: datetime | None = None,
|
||||||
|
end_date: datetime | None = None,
|
||||||
|
timezone: str = "UTC",
|
||||||
|
) -> list[AverageInteractionStats]:
|
||||||
|
"""
|
||||||
|
Get average app interaction statistics.
|
||||||
|
|
||||||
|
Retrieves daily average interactions per user grouped by date for a specific app
|
||||||
|
and trigger source. Used for workflow statistics dashboard.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: Tenant identifier for multi-tenant isolation
|
||||||
|
app_id: Application identifier
|
||||||
|
triggered_from: Filter by trigger source (e.g., "app-run")
|
||||||
|
start_date: Optional start date filter
|
||||||
|
end_date: Optional end date filter
|
||||||
|
timezone: Timezone for date grouping (default: "UTC")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of dictionaries containing date and average interactions:
|
||||||
|
[{"date": "2024-01-01", "interactions": 2.5}, ...]
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|||||||
@ -22,16 +22,25 @@ Implementation Notes:
|
|||||||
import logging
|
import logging
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import cast
|
from decimal import Decimal
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
from sqlalchemy import delete, func, select
|
from sqlalchemy import delete, func, select
|
||||||
from sqlalchemy.engine import CursorResult
|
from sqlalchemy.engine import CursorResult
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||||
from libs.time_parser import get_time_threshold
|
from libs.time_parser import get_time_threshold
|
||||||
|
from models.enums import WorkflowRunTriggeredFrom
|
||||||
from models.workflow import WorkflowRun
|
from models.workflow import WorkflowRun
|
||||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||||
|
from repositories.types import (
|
||||||
|
AverageInteractionStats,
|
||||||
|
DailyRunsStats,
|
||||||
|
DailyTerminalsStats,
|
||||||
|
DailyTokenCostStats,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -61,7 +70,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||||||
self,
|
self,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
app_id: str,
|
app_id: str,
|
||||||
triggered_from: str,
|
triggered_from: WorkflowRunTriggeredFrom | Sequence[WorkflowRunTriggeredFrom],
|
||||||
limit: int = 20,
|
limit: int = 20,
|
||||||
last_id: str | None = None,
|
last_id: str | None = None,
|
||||||
status: str | None = None,
|
status: str | None = None,
|
||||||
@ -78,9 +87,14 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||||||
base_stmt = select(WorkflowRun).where(
|
base_stmt = select(WorkflowRun).where(
|
||||||
WorkflowRun.tenant_id == tenant_id,
|
WorkflowRun.tenant_id == tenant_id,
|
||||||
WorkflowRun.app_id == app_id,
|
WorkflowRun.app_id == app_id,
|
||||||
WorkflowRun.triggered_from == triggered_from,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Handle triggered_from values
|
||||||
|
if isinstance(triggered_from, WorkflowRunTriggeredFrom):
|
||||||
|
triggered_from = [triggered_from]
|
||||||
|
if triggered_from:
|
||||||
|
base_stmt = base_stmt.where(WorkflowRun.triggered_from.in_(triggered_from))
|
||||||
|
|
||||||
# Add optional status filter
|
# Add optional status filter
|
||||||
if status:
|
if status:
|
||||||
base_stmt = base_stmt.where(WorkflowRun.status == status)
|
base_stmt = base_stmt.where(WorkflowRun.status == status)
|
||||||
@ -126,6 +140,17 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||||||
)
|
)
|
||||||
return session.scalar(stmt)
|
return session.scalar(stmt)
|
||||||
|
|
||||||
|
def get_workflow_run_by_id_without_tenant(
|
||||||
|
self,
|
||||||
|
run_id: str,
|
||||||
|
) -> WorkflowRun | None:
|
||||||
|
"""
|
||||||
|
Get a specific workflow run by ID without tenant/app context.
|
||||||
|
"""
|
||||||
|
with self._session_maker() as session:
|
||||||
|
stmt = select(WorkflowRun).where(WorkflowRun.id == run_id)
|
||||||
|
return session.scalar(stmt)
|
||||||
|
|
||||||
def get_workflow_runs_count(
|
def get_workflow_runs_count(
|
||||||
self,
|
self,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
@ -275,3 +300,213 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||||||
|
|
||||||
logger.info("Total deleted %s workflow runs for app %s", total_deleted, app_id)
|
logger.info("Total deleted %s workflow runs for app %s", total_deleted, app_id)
|
||||||
return total_deleted
|
return total_deleted
|
||||||
|
|
||||||
|
def get_daily_runs_statistics(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
app_id: str,
|
||||||
|
triggered_from: str,
|
||||||
|
start_date: datetime | None = None,
|
||||||
|
end_date: datetime | None = None,
|
||||||
|
timezone: str = "UTC",
|
||||||
|
) -> list[DailyRunsStats]:
|
||||||
|
"""
|
||||||
|
Get daily runs statistics using raw SQL for optimal performance.
|
||||||
|
"""
|
||||||
|
sql_query = """SELECT
|
||||||
|
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||||
|
COUNT(id) AS runs
|
||||||
|
FROM
|
||||||
|
workflow_runs
|
||||||
|
WHERE
|
||||||
|
tenant_id = :tenant_id
|
||||||
|
AND app_id = :app_id
|
||||||
|
AND triggered_from = :triggered_from"""
|
||||||
|
|
||||||
|
arg_dict: dict[str, Any] = {
|
||||||
|
"tz": timezone,
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"app_id": app_id,
|
||||||
|
"triggered_from": triggered_from,
|
||||||
|
}
|
||||||
|
|
||||||
|
if start_date:
|
||||||
|
sql_query += " AND created_at >= :start_date"
|
||||||
|
arg_dict["start_date"] = start_date
|
||||||
|
|
||||||
|
if end_date:
|
||||||
|
sql_query += " AND created_at < :end_date"
|
||||||
|
arg_dict["end_date"] = end_date
|
||||||
|
|
||||||
|
sql_query += " GROUP BY date ORDER BY date"
|
||||||
|
|
||||||
|
response_data = []
|
||||||
|
with self._session_maker() as session:
|
||||||
|
rs = session.execute(sa.text(sql_query), arg_dict)
|
||||||
|
for row in rs:
|
||||||
|
response_data.append({"date": str(row.date), "runs": row.runs})
|
||||||
|
|
||||||
|
return cast(list[DailyRunsStats], response_data)
|
||||||
|
|
||||||
|
def get_daily_terminals_statistics(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
app_id: str,
|
||||||
|
triggered_from: str,
|
||||||
|
start_date: datetime | None = None,
|
||||||
|
end_date: datetime | None = None,
|
||||||
|
timezone: str = "UTC",
|
||||||
|
) -> list[DailyTerminalsStats]:
|
||||||
|
"""
|
||||||
|
Get daily terminals statistics using raw SQL for optimal performance.
|
||||||
|
"""
|
||||||
|
sql_query = """SELECT
|
||||||
|
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||||
|
COUNT(DISTINCT created_by) AS terminal_count
|
||||||
|
FROM
|
||||||
|
workflow_runs
|
||||||
|
WHERE
|
||||||
|
tenant_id = :tenant_id
|
||||||
|
AND app_id = :app_id
|
||||||
|
AND triggered_from = :triggered_from"""
|
||||||
|
|
||||||
|
arg_dict: dict[str, Any] = {
|
||||||
|
"tz": timezone,
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"app_id": app_id,
|
||||||
|
"triggered_from": triggered_from,
|
||||||
|
}
|
||||||
|
|
||||||
|
if start_date:
|
||||||
|
sql_query += " AND created_at >= :start_date"
|
||||||
|
arg_dict["start_date"] = start_date
|
||||||
|
|
||||||
|
if end_date:
|
||||||
|
sql_query += " AND created_at < :end_date"
|
||||||
|
arg_dict["end_date"] = end_date
|
||||||
|
|
||||||
|
sql_query += " GROUP BY date ORDER BY date"
|
||||||
|
|
||||||
|
response_data = []
|
||||||
|
with self._session_maker() as session:
|
||||||
|
rs = session.execute(sa.text(sql_query), arg_dict)
|
||||||
|
for row in rs:
|
||||||
|
response_data.append({"date": str(row.date), "terminal_count": row.terminal_count})
|
||||||
|
|
||||||
|
return cast(list[DailyTerminalsStats], response_data)
|
||||||
|
|
||||||
|
def get_daily_token_cost_statistics(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
app_id: str,
|
||||||
|
triggered_from: str,
|
||||||
|
start_date: datetime | None = None,
|
||||||
|
end_date: datetime | None = None,
|
||||||
|
timezone: str = "UTC",
|
||||||
|
) -> list[DailyTokenCostStats]:
|
||||||
|
"""
|
||||||
|
Get daily token cost statistics using raw SQL for optimal performance.
|
||||||
|
"""
|
||||||
|
sql_query = """SELECT
|
||||||
|
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||||
|
SUM(total_tokens) AS token_count
|
||||||
|
FROM
|
||||||
|
workflow_runs
|
||||||
|
WHERE
|
||||||
|
tenant_id = :tenant_id
|
||||||
|
AND app_id = :app_id
|
||||||
|
AND triggered_from = :triggered_from"""
|
||||||
|
|
||||||
|
arg_dict: dict[str, Any] = {
|
||||||
|
"tz": timezone,
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"app_id": app_id,
|
||||||
|
"triggered_from": triggered_from,
|
||||||
|
}
|
||||||
|
|
||||||
|
if start_date:
|
||||||
|
sql_query += " AND created_at >= :start_date"
|
||||||
|
arg_dict["start_date"] = start_date
|
||||||
|
|
||||||
|
if end_date:
|
||||||
|
sql_query += " AND created_at < :end_date"
|
||||||
|
arg_dict["end_date"] = end_date
|
||||||
|
|
||||||
|
sql_query += " GROUP BY date ORDER BY date"
|
||||||
|
|
||||||
|
response_data = []
|
||||||
|
with self._session_maker() as session:
|
||||||
|
rs = session.execute(sa.text(sql_query), arg_dict)
|
||||||
|
for row in rs:
|
||||||
|
response_data.append(
|
||||||
|
{
|
||||||
|
"date": str(row.date),
|
||||||
|
"token_count": row.token_count,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return cast(list[DailyTokenCostStats], response_data)
|
||||||
|
|
||||||
|
def get_average_app_interaction_statistics(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
app_id: str,
|
||||||
|
triggered_from: str,
|
||||||
|
start_date: datetime | None = None,
|
||||||
|
end_date: datetime | None = None,
|
||||||
|
timezone: str = "UTC",
|
||||||
|
) -> list[AverageInteractionStats]:
|
||||||
|
"""
|
||||||
|
Get average app interaction statistics using raw SQL for optimal performance.
|
||||||
|
"""
|
||||||
|
sql_query = """SELECT
|
||||||
|
AVG(sub.interactions) AS interactions,
|
||||||
|
sub.date
|
||||||
|
FROM
|
||||||
|
(
|
||||||
|
SELECT
|
||||||
|
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||||
|
c.created_by,
|
||||||
|
COUNT(c.id) AS interactions
|
||||||
|
FROM
|
||||||
|
workflow_runs c
|
||||||
|
WHERE
|
||||||
|
c.tenant_id = :tenant_id
|
||||||
|
AND c.app_id = :app_id
|
||||||
|
AND c.triggered_from = :triggered_from
|
||||||
|
{{start}}
|
||||||
|
{{end}}
|
||||||
|
GROUP BY
|
||||||
|
date, c.created_by
|
||||||
|
) sub
|
||||||
|
GROUP BY
|
||||||
|
sub.date"""
|
||||||
|
|
||||||
|
arg_dict: dict[str, Any] = {
|
||||||
|
"tz": timezone,
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"app_id": app_id,
|
||||||
|
"triggered_from": triggered_from,
|
||||||
|
}
|
||||||
|
|
||||||
|
if start_date:
|
||||||
|
sql_query = sql_query.replace("{{start}}", " AND c.created_at >= :start_date")
|
||||||
|
arg_dict["start_date"] = start_date
|
||||||
|
else:
|
||||||
|
sql_query = sql_query.replace("{{start}}", "")
|
||||||
|
|
||||||
|
if end_date:
|
||||||
|
sql_query = sql_query.replace("{{end}}", " AND c.created_at < :end_date")
|
||||||
|
arg_dict["end_date"] = end_date
|
||||||
|
else:
|
||||||
|
sql_query = sql_query.replace("{{end}}", "")
|
||||||
|
|
||||||
|
response_data = []
|
||||||
|
with self._session_maker() as session:
|
||||||
|
rs = session.execute(sa.text(sql_query), arg_dict)
|
||||||
|
for row in rs:
|
||||||
|
response_data.append(
|
||||||
|
{"date": str(row.date), "interactions": float(row.interactions.quantize(Decimal("0.01")))}
|
||||||
|
)
|
||||||
|
|
||||||
|
return cast(list[AverageInteractionStats], response_data)
|
||||||
|
|||||||
21
api/repositories/types.py
Normal file
21
api/repositories/types.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from typing import TypedDict
|
||||||
|
|
||||||
|
|
||||||
|
class DailyRunsStats(TypedDict):
|
||||||
|
date: str
|
||||||
|
runs: int
|
||||||
|
|
||||||
|
|
||||||
|
class DailyTerminalsStats(TypedDict):
|
||||||
|
date: str
|
||||||
|
terminal_count: int
|
||||||
|
|
||||||
|
|
||||||
|
class DailyTokenCostStats(TypedDict):
|
||||||
|
date: str
|
||||||
|
token_count: int
|
||||||
|
|
||||||
|
|
||||||
|
class AverageInteractionStats(TypedDict):
|
||||||
|
date: str
|
||||||
|
interactions: float
|
||||||
@ -82,54 +82,51 @@ class AudioService:
|
|||||||
message_id: str | None = None,
|
message_id: str | None = None,
|
||||||
is_draft: bool = False,
|
is_draft: bool = False,
|
||||||
):
|
):
|
||||||
from app import app
|
|
||||||
|
|
||||||
def invoke_tts(text_content: str, app_model: App, voice: str | None = None, is_draft: bool = False):
|
def invoke_tts(text_content: str, app_model: App, voice: str | None = None, is_draft: bool = False):
|
||||||
with app.app_context():
|
if voice is None:
|
||||||
if voice is None:
|
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
if is_draft:
|
||||||
if is_draft:
|
workflow = WorkflowService().get_draft_workflow(app_model=app_model)
|
||||||
workflow = WorkflowService().get_draft_workflow(app_model=app_model)
|
else:
|
||||||
else:
|
workflow = app_model.workflow
|
||||||
workflow = app_model.workflow
|
if (
|
||||||
if (
|
workflow is None
|
||||||
workflow is None
|
or "text_to_speech" not in workflow.features_dict
|
||||||
or "text_to_speech" not in workflow.features_dict
|
or not workflow.features_dict["text_to_speech"].get("enabled")
|
||||||
or not workflow.features_dict["text_to_speech"].get("enabled")
|
):
|
||||||
):
|
raise ValueError("TTS is not enabled")
|
||||||
|
|
||||||
|
voice = workflow.features_dict["text_to_speech"].get("voice")
|
||||||
|
else:
|
||||||
|
if not is_draft:
|
||||||
|
if app_model.app_model_config is None:
|
||||||
|
raise ValueError("AppModelConfig not found")
|
||||||
|
text_to_speech_dict = app_model.app_model_config.text_to_speech_dict
|
||||||
|
|
||||||
|
if not text_to_speech_dict.get("enabled"):
|
||||||
raise ValueError("TTS is not enabled")
|
raise ValueError("TTS is not enabled")
|
||||||
|
|
||||||
voice = workflow.features_dict["text_to_speech"].get("voice")
|
voice = text_to_speech_dict.get("voice")
|
||||||
else:
|
|
||||||
if not is_draft:
|
|
||||||
if app_model.app_model_config is None:
|
|
||||||
raise ValueError("AppModelConfig not found")
|
|
||||||
text_to_speech_dict = app_model.app_model_config.text_to_speech_dict
|
|
||||||
|
|
||||||
if not text_to_speech_dict.get("enabled"):
|
model_manager = ModelManager()
|
||||||
raise ValueError("TTS is not enabled")
|
model_instance = model_manager.get_default_model_instance(
|
||||||
|
tenant_id=app_model.tenant_id, model_type=ModelType.TTS
|
||||||
voice = text_to_speech_dict.get("voice")
|
)
|
||||||
|
try:
|
||||||
model_manager = ModelManager()
|
if not voice:
|
||||||
model_instance = model_manager.get_default_model_instance(
|
voices = model_instance.get_tts_voices()
|
||||||
tenant_id=app_model.tenant_id, model_type=ModelType.TTS
|
if voices:
|
||||||
)
|
voice = voices[0].get("value")
|
||||||
try:
|
if not voice:
|
||||||
if not voice:
|
|
||||||
voices = model_instance.get_tts_voices()
|
|
||||||
if voices:
|
|
||||||
voice = voices[0].get("value")
|
|
||||||
if not voice:
|
|
||||||
raise ValueError("Sorry, no voice available.")
|
|
||||||
else:
|
|
||||||
raise ValueError("Sorry, no voice available.")
|
raise ValueError("Sorry, no voice available.")
|
||||||
|
else:
|
||||||
|
raise ValueError("Sorry, no voice available.")
|
||||||
|
|
||||||
return model_instance.invoke_tts(
|
return model_instance.invoke_tts(
|
||||||
content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice
|
content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
if message_id:
|
if message_id:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -1417,7 +1417,7 @@ class DocumentService:
|
|||||||
assert isinstance(current_user, Account)
|
assert isinstance(current_user, Account)
|
||||||
assert current_user.current_tenant_id is not None
|
assert current_user.current_tenant_id is not None
|
||||||
assert knowledge_config.data_source
|
assert knowledge_config.data_source
|
||||||
assert knowledge_config.data_source.info_list.file_info_list
|
assert knowledge_config.data_source.info_list
|
||||||
|
|
||||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||||
|
|
||||||
@ -1426,6 +1426,8 @@ class DocumentService:
|
|||||||
count = 0
|
count = 0
|
||||||
if knowledge_config.data_source:
|
if knowledge_config.data_source:
|
||||||
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
|
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
|
||||||
|
if not knowledge_config.data_source.info_list.file_info_list:
|
||||||
|
raise ValueError("File source info is required")
|
||||||
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
|
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
|
||||||
count = len(upload_file_list)
|
count = len(upload_file_list)
|
||||||
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
|
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
|
||||||
@ -1531,6 +1533,8 @@ class DocumentService:
|
|||||||
document_ids = []
|
document_ids = []
|
||||||
duplicate_document_ids = []
|
duplicate_document_ids = []
|
||||||
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
|
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
|
||||||
|
if not knowledge_config.data_source.info_list.file_info_list:
|
||||||
|
raise ValueError("File source info is required")
|
||||||
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
|
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
|
||||||
for file_id in upload_file_list:
|
for file_id in upload_file_list:
|
||||||
file = (
|
file = (
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from typing import Any, Union, cast
|
|||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from sqlalchemy import func, or_, select
|
from sqlalchemy import func, select
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
import contexts
|
import contexts
|
||||||
@ -94,6 +94,7 @@ class RagPipelineService:
|
|||||||
self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
|
self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
|
||||||
session_maker
|
session_maker
|
||||||
)
|
)
|
||||||
|
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict:
|
def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict:
|
||||||
@ -1015,48 +1016,21 @@ class RagPipelineService:
|
|||||||
:param args: request args
|
:param args: request args
|
||||||
"""
|
"""
|
||||||
limit = int(args.get("limit", 20))
|
limit = int(args.get("limit", 20))
|
||||||
|
last_id = args.get("last_id")
|
||||||
|
|
||||||
base_query = db.session.query(WorkflowRun).where(
|
triggered_from_values = [
|
||||||
WorkflowRun.tenant_id == pipeline.tenant_id,
|
WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN,
|
||||||
WorkflowRun.app_id == pipeline.id,
|
WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
|
||||||
or_(
|
]
|
||||||
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
|
|
||||||
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value,
|
return self._workflow_run_repo.get_paginated_workflow_runs(
|
||||||
),
|
tenant_id=pipeline.tenant_id,
|
||||||
|
app_id=pipeline.id,
|
||||||
|
triggered_from=triggered_from_values,
|
||||||
|
limit=limit,
|
||||||
|
last_id=last_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.get("last_id"):
|
|
||||||
last_workflow_run = base_query.where(
|
|
||||||
WorkflowRun.id == args.get("last_id"),
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not last_workflow_run:
|
|
||||||
raise ValueError("Last workflow run not exists")
|
|
||||||
|
|
||||||
workflow_runs = (
|
|
||||||
base_query.where(
|
|
||||||
WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id
|
|
||||||
)
|
|
||||||
.order_by(WorkflowRun.created_at.desc())
|
|
||||||
.limit(limit)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()
|
|
||||||
|
|
||||||
has_more = False
|
|
||||||
if len(workflow_runs) == limit:
|
|
||||||
current_page_first_workflow_run = workflow_runs[-1]
|
|
||||||
rest_count = base_query.where(
|
|
||||||
WorkflowRun.created_at < current_page_first_workflow_run.created_at,
|
|
||||||
WorkflowRun.id != current_page_first_workflow_run.id,
|
|
||||||
).count()
|
|
||||||
|
|
||||||
if rest_count > 0:
|
|
||||||
has_more = True
|
|
||||||
|
|
||||||
return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
|
|
||||||
|
|
||||||
def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> WorkflowRun | None:
|
def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> WorkflowRun | None:
|
||||||
"""
|
"""
|
||||||
Get workflow run detail
|
Get workflow run detail
|
||||||
@ -1064,18 +1038,12 @@ class RagPipelineService:
|
|||||||
:param app_model: app model
|
:param app_model: app model
|
||||||
:param run_id: workflow run id
|
:param run_id: workflow run id
|
||||||
"""
|
"""
|
||||||
workflow_run = (
|
return self._workflow_run_repo.get_workflow_run_by_id(
|
||||||
db.session.query(WorkflowRun)
|
tenant_id=pipeline.tenant_id,
|
||||||
.where(
|
app_id=pipeline.id,
|
||||||
WorkflowRun.tenant_id == pipeline.tenant_id,
|
run_id=run_id,
|
||||||
WorkflowRun.app_id == pipeline.id,
|
|
||||||
WorkflowRun.id == run_id,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return workflow_run
|
|
||||||
|
|
||||||
def get_rag_pipeline_workflow_run_node_executions(
|
def get_rag_pipeline_workflow_run_node_executions(
|
||||||
self,
|
self,
|
||||||
pipeline: Pipeline,
|
pipeline: Pipeline,
|
||||||
|
|||||||
@ -1,86 +1,118 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from enum import StrEnum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from sqlalchemy import or_
|
from pydantic import BaseModel, Field
|
||||||
|
from sqlalchemy import or_, select
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
|
||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||||
|
from core.mcp.auth.auth_flow import auth
|
||||||
|
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||||
from core.mcp.error import MCPAuthError, MCPError
|
from core.mcp.error import MCPAuthError, MCPError
|
||||||
from core.mcp.mcp_client import MCPClient
|
|
||||||
from core.tools.entities.api_entities import ToolProviderApiEntity
|
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||||
from core.tools.entities.common_entities import I18nObject
|
|
||||||
from core.tools.entities.tool_entities import ToolProviderType
|
|
||||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
|
||||||
from core.tools.utils.encryption import ProviderConfigEncrypter
|
from core.tools.utils.encryption import ProviderConfigEncrypter
|
||||||
from extensions.ext_database import db
|
|
||||||
from models.tools import MCPToolProvider
|
from models.tools import MCPToolProvider
|
||||||
from services.tools.tools_transform_service import ToolTransformService
|
from services.tools.tools_transform_service import ToolTransformService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Constants
|
||||||
UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]"
|
UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]"
|
||||||
|
CLIENT_NAME = "Dify"
|
||||||
|
EMPTY_TOOLS_JSON = "[]"
|
||||||
|
EMPTY_CREDENTIALS_JSON = "{}"
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthDataType(StrEnum):
|
||||||
|
"""Types of OAuth data that can be saved."""
|
||||||
|
|
||||||
|
TOKENS = "tokens"
|
||||||
|
CLIENT_INFO = "client_info"
|
||||||
|
CODE_VERIFIER = "code_verifier"
|
||||||
|
MIXED = "mixed"
|
||||||
|
|
||||||
|
|
||||||
|
class ReconnectResult(BaseModel):
|
||||||
|
"""Result of reconnecting to an MCP provider"""
|
||||||
|
|
||||||
|
authed: bool = Field(description="Whether the provider is authenticated")
|
||||||
|
tools: str = Field(description="JSON string of tool list")
|
||||||
|
encrypted_credentials: str = Field(description="JSON string of encrypted credentials")
|
||||||
|
|
||||||
|
|
||||||
|
class ServerUrlValidationResult(BaseModel):
|
||||||
|
"""Result of server URL validation check"""
|
||||||
|
|
||||||
|
needs_validation: bool
|
||||||
|
validation_passed: bool = False
|
||||||
|
reconnect_result: ReconnectResult | None = None
|
||||||
|
encrypted_server_url: str | None = None
|
||||||
|
server_url_hash: str | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def should_update_server_url(self) -> bool:
|
||||||
|
"""Check if server URL should be updated based on validation result"""
|
||||||
|
return self.needs_validation and self.validation_passed and self.reconnect_result is not None
|
||||||
|
|
||||||
|
|
||||||
class MCPToolManageService:
|
class MCPToolManageService:
|
||||||
"""
|
"""Service class for managing MCP tools and providers."""
|
||||||
Service class for managing mcp tools.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@staticmethod
|
def __init__(self, session: Session):
|
||||||
def _encrypt_headers(headers: dict[str, str], tenant_id: str) -> dict[str, str]:
|
self._session = session
|
||||||
|
|
||||||
|
# ========== Provider CRUD Operations ==========
|
||||||
|
|
||||||
|
def get_provider(
|
||||||
|
self, *, provider_id: str | None = None, server_identifier: str | None = None, tenant_id: str
|
||||||
|
) -> MCPToolProvider:
|
||||||
"""
|
"""
|
||||||
Encrypt headers using ProviderConfigEncrypter with all headers as SECRET_INPUT.
|
Get MCP provider by ID or server identifier.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
headers: Dictionary of headers to encrypt
|
provider_id: Provider ID (UUID)
|
||||||
tenant_id: Tenant ID for encryption
|
server_identifier: Server identifier
|
||||||
|
tenant_id: Tenant ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary with all headers encrypted
|
MCPToolProvider instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If provider not found
|
||||||
"""
|
"""
|
||||||
if not headers:
|
if server_identifier:
|
||||||
return {}
|
stmt = select(MCPToolProvider).where(
|
||||||
|
MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
stmt = select(MCPToolProvider).where(
|
||||||
|
MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id
|
||||||
|
)
|
||||||
|
|
||||||
from core.entities.provider_entities import BasicProviderConfig
|
provider = self._session.scalar(stmt)
|
||||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
if not provider:
|
||||||
from core.tools.utils.encryption import create_provider_encrypter
|
|
||||||
|
|
||||||
# Create dynamic config for all headers as SECRET_INPUT
|
|
||||||
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers]
|
|
||||||
|
|
||||||
encrypter_instance, _ = create_provider_encrypter(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
config=config,
|
|
||||||
cache=NoOpProviderCredentialCache(),
|
|
||||||
)
|
|
||||||
|
|
||||||
return encrypter_instance.encrypt(headers)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider:
|
|
||||||
res = (
|
|
||||||
db.session.query(MCPToolProvider)
|
|
||||||
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if not res:
|
|
||||||
raise ValueError("MCP tool not found")
|
raise ValueError("MCP tool not found")
|
||||||
return res
|
return provider
|
||||||
|
|
||||||
@staticmethod
|
def get_provider_entity(self, provider_id: str, tenant_id: str, by_server_id: bool = False) -> MCPProviderEntity:
|
||||||
def get_mcp_provider_by_server_identifier(server_identifier: str, tenant_id: str) -> MCPToolProvider:
|
"""Get provider entity by ID or server identifier."""
|
||||||
res = (
|
if by_server_id:
|
||||||
db.session.query(MCPToolProvider)
|
db_provider = self.get_provider(server_identifier=provider_id, tenant_id=tenant_id)
|
||||||
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier)
|
else:
|
||||||
.first()
|
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||||
)
|
return db_provider.to_entity()
|
||||||
if not res:
|
|
||||||
raise ValueError("MCP tool not found")
|
|
||||||
return res
|
|
||||||
|
|
||||||
@staticmethod
|
def create_provider(
|
||||||
def create_mcp_provider(
|
self,
|
||||||
|
*,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
name: str,
|
name: str,
|
||||||
server_url: str,
|
server_url: str,
|
||||||
@ -89,37 +121,30 @@ class MCPToolManageService:
|
|||||||
icon_type: str,
|
icon_type: str,
|
||||||
icon_background: str,
|
icon_background: str,
|
||||||
server_identifier: str,
|
server_identifier: str,
|
||||||
timeout: float,
|
configuration: MCPConfiguration,
|
||||||
sse_read_timeout: float,
|
authentication: MCPAuthentication | None = None,
|
||||||
headers: dict[str, str] | None = None,
|
headers: dict[str, str] | None = None,
|
||||||
) -> ToolProviderApiEntity:
|
) -> ToolProviderApiEntity:
|
||||||
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
"""Create a new MCP provider."""
|
||||||
existing_provider = (
|
# Validate URL format
|
||||||
db.session.query(MCPToolProvider)
|
if not self._is_valid_url(server_url):
|
||||||
.where(
|
raise ValueError("Server URL is not valid.")
|
||||||
MCPToolProvider.tenant_id == tenant_id,
|
|
||||||
or_(
|
|
||||||
MCPToolProvider.name == name,
|
|
||||||
MCPToolProvider.server_url_hash == server_url_hash,
|
|
||||||
MCPToolProvider.server_identifier == server_identifier,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if existing_provider:
|
|
||||||
if existing_provider.name == name:
|
|
||||||
raise ValueError(f"MCP tool {name} already exists")
|
|
||||||
if existing_provider.server_url_hash == server_url_hash:
|
|
||||||
raise ValueError(f"MCP tool {server_url} already exists")
|
|
||||||
if existing_provider.server_identifier == server_identifier:
|
|
||||||
raise ValueError(f"MCP tool {server_identifier} already exists")
|
|
||||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
|
||||||
# Encrypt headers
|
|
||||||
encrypted_headers = None
|
|
||||||
if headers:
|
|
||||||
encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id)
|
|
||||||
encrypted_headers = json.dumps(encrypted_headers_dict)
|
|
||||||
|
|
||||||
|
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
||||||
|
|
||||||
|
# Check for existing provider
|
||||||
|
self._check_provider_exists(tenant_id, name, server_url_hash, server_identifier)
|
||||||
|
|
||||||
|
# Encrypt sensitive data
|
||||||
|
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
||||||
|
encrypted_headers = self._prepare_encrypted_dict(headers, tenant_id) if headers else None
|
||||||
|
encrypted_credentials = None
|
||||||
|
if authentication is not None and authentication.client_id:
|
||||||
|
encrypted_credentials = self._build_and_encrypt_credentials(
|
||||||
|
authentication.client_id, authentication.client_secret, tenant_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create provider
|
||||||
mcp_tool = MCPToolProvider(
|
mcp_tool = MCPToolProvider(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
name=name,
|
name=name,
|
||||||
@ -127,93 +152,23 @@ class MCPToolManageService:
|
|||||||
server_url_hash=server_url_hash,
|
server_url_hash=server_url_hash,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
authed=False,
|
authed=False,
|
||||||
tools="[]",
|
tools=EMPTY_TOOLS_JSON,
|
||||||
icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon,
|
icon=self._prepare_icon(icon, icon_type, icon_background),
|
||||||
server_identifier=server_identifier,
|
server_identifier=server_identifier,
|
||||||
timeout=timeout,
|
timeout=configuration.timeout,
|
||||||
sse_read_timeout=sse_read_timeout,
|
sse_read_timeout=configuration.sse_read_timeout,
|
||||||
encrypted_headers=encrypted_headers,
|
encrypted_headers=encrypted_headers,
|
||||||
)
|
encrypted_credentials=encrypted_credentials,
|
||||||
db.session.add(mcp_tool)
|
|
||||||
db.session.commit()
|
|
||||||
return ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def retrieve_mcp_tools(tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
|
|
||||||
mcp_providers = (
|
|
||||||
db.session.query(MCPToolProvider)
|
|
||||||
.where(MCPToolProvider.tenant_id == tenant_id)
|
|
||||||
.order_by(MCPToolProvider.name)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
return [
|
|
||||||
ToolTransformService.mcp_provider_to_user_provider(mcp_provider, for_list=for_list)
|
|
||||||
for mcp_provider in mcp_providers
|
|
||||||
]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
|
|
||||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
|
||||||
server_url = mcp_provider.decrypted_server_url
|
|
||||||
authed = mcp_provider.authed
|
|
||||||
headers = mcp_provider.decrypted_headers
|
|
||||||
timeout = mcp_provider.timeout
|
|
||||||
sse_read_timeout = mcp_provider.sse_read_timeout
|
|
||||||
|
|
||||||
try:
|
|
||||||
with MCPClient(
|
|
||||||
server_url,
|
|
||||||
provider_id,
|
|
||||||
tenant_id,
|
|
||||||
authed=authed,
|
|
||||||
for_list=True,
|
|
||||||
headers=headers,
|
|
||||||
timeout=timeout,
|
|
||||||
sse_read_timeout=sse_read_timeout,
|
|
||||||
) as mcp_client:
|
|
||||||
tools = mcp_client.list_tools()
|
|
||||||
except MCPAuthError:
|
|
||||||
raise ValueError("Please auth the tool first")
|
|
||||||
except MCPError as e:
|
|
||||||
raise ValueError(f"Failed to connect to MCP server: {e}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
|
||||||
mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
|
|
||||||
mcp_provider.authed = True
|
|
||||||
mcp_provider.updated_at = datetime.now()
|
|
||||||
db.session.commit()
|
|
||||||
except Exception:
|
|
||||||
db.session.rollback()
|
|
||||||
raise
|
|
||||||
|
|
||||||
user = mcp_provider.load_user()
|
|
||||||
if not mcp_provider.icon:
|
|
||||||
raise ValueError("MCP provider icon is required")
|
|
||||||
return ToolProviderApiEntity(
|
|
||||||
id=mcp_provider.id,
|
|
||||||
name=mcp_provider.name,
|
|
||||||
tools=ToolTransformService.mcp_tool_to_user_tool(mcp_provider, tools),
|
|
||||||
type=ToolProviderType.MCP,
|
|
||||||
icon=mcp_provider.icon,
|
|
||||||
author=user.name if user else "Anonymous",
|
|
||||||
server_url=mcp_provider.masked_server_url,
|
|
||||||
updated_at=int(mcp_provider.updated_at.timestamp()),
|
|
||||||
description=I18nObject(en_US="", zh_Hans=""),
|
|
||||||
label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name),
|
|
||||||
plugin_unique_identifier=mcp_provider.server_identifier,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
self._session.add(mcp_tool)
|
||||||
def delete_mcp_tool(cls, tenant_id: str, provider_id: str):
|
self._session.flush()
|
||||||
mcp_tool = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
|
||||||
|
return mcp_providers
|
||||||
|
|
||||||
db.session.delete(mcp_tool)
|
def update_provider(
|
||||||
db.session.commit()
|
self,
|
||||||
|
*,
|
||||||
@classmethod
|
|
||||||
def update_mcp_provider(
|
|
||||||
cls,
|
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
provider_id: str,
|
provider_id: str,
|
||||||
name: str,
|
name: str,
|
||||||
@ -222,129 +177,546 @@ class MCPToolManageService:
|
|||||||
icon_type: str,
|
icon_type: str,
|
||||||
icon_background: str,
|
icon_background: str,
|
||||||
server_identifier: str,
|
server_identifier: str,
|
||||||
timeout: float | None = None,
|
|
||||||
sse_read_timeout: float | None = None,
|
|
||||||
headers: dict[str, str] | None = None,
|
headers: dict[str, str] | None = None,
|
||||||
):
|
configuration: MCPConfiguration,
|
||||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
authentication: MCPAuthentication | None = None,
|
||||||
|
validation_result: ServerUrlValidationResult | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Update an MCP provider.
|
||||||
|
|
||||||
reconnect_result = None
|
Args:
|
||||||
|
validation_result: Pre-validation result from validate_server_url_change.
|
||||||
|
If provided and contains reconnect_result, it will be used
|
||||||
|
instead of performing network operations.
|
||||||
|
"""
|
||||||
|
mcp_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||||
|
|
||||||
|
# Check for duplicate name (excluding current provider)
|
||||||
|
if name != mcp_provider.name:
|
||||||
|
stmt = select(MCPToolProvider).where(
|
||||||
|
MCPToolProvider.tenant_id == tenant_id,
|
||||||
|
MCPToolProvider.name == name,
|
||||||
|
MCPToolProvider.id != provider_id,
|
||||||
|
)
|
||||||
|
existing_provider = self._session.scalar(stmt)
|
||||||
|
if existing_provider:
|
||||||
|
raise ValueError(f"MCP tool {name} already exists")
|
||||||
|
|
||||||
|
# Get URL update data from validation result
|
||||||
encrypted_server_url = None
|
encrypted_server_url = None
|
||||||
server_url_hash = None
|
server_url_hash = None
|
||||||
|
reconnect_result = None
|
||||||
|
|
||||||
if UNCHANGED_SERVER_URL_PLACEHOLDER not in server_url:
|
if validation_result and validation_result.encrypted_server_url:
|
||||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
# Use all data from validation result
|
||||||
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
encrypted_server_url = validation_result.encrypted_server_url
|
||||||
|
server_url_hash = validation_result.server_url_hash
|
||||||
if server_url_hash != mcp_provider.server_url_hash:
|
reconnect_result = validation_result.reconnect_result
|
||||||
reconnect_result = cls._re_connect_mcp_provider(server_url, provider_id, tenant_id)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Update basic fields
|
||||||
mcp_provider.updated_at = datetime.now()
|
mcp_provider.updated_at = datetime.now()
|
||||||
mcp_provider.name = name
|
mcp_provider.name = name
|
||||||
mcp_provider.icon = (
|
mcp_provider.icon = self._prepare_icon(icon, icon_type, icon_background)
|
||||||
json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon
|
|
||||||
)
|
|
||||||
mcp_provider.server_identifier = server_identifier
|
mcp_provider.server_identifier = server_identifier
|
||||||
|
|
||||||
if encrypted_server_url is not None and server_url_hash is not None:
|
# Update server URL if changed
|
||||||
|
if encrypted_server_url and server_url_hash:
|
||||||
mcp_provider.server_url = encrypted_server_url
|
mcp_provider.server_url = encrypted_server_url
|
||||||
mcp_provider.server_url_hash = server_url_hash
|
mcp_provider.server_url_hash = server_url_hash
|
||||||
|
|
||||||
if reconnect_result:
|
if reconnect_result:
|
||||||
mcp_provider.authed = reconnect_result["authed"]
|
mcp_provider.authed = reconnect_result.authed
|
||||||
mcp_provider.tools = reconnect_result["tools"]
|
mcp_provider.tools = reconnect_result.tools
|
||||||
mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"]
|
mcp_provider.encrypted_credentials = reconnect_result.encrypted_credentials
|
||||||
|
|
||||||
if timeout is not None:
|
# Update optional configuration fields
|
||||||
mcp_provider.timeout = timeout
|
self._update_optional_fields(mcp_provider, configuration)
|
||||||
if sse_read_timeout is not None:
|
|
||||||
mcp_provider.sse_read_timeout = sse_read_timeout
|
# Update headers if provided
|
||||||
if headers is not None:
|
if headers is not None:
|
||||||
# Merge masked headers from frontend with existing real values
|
mcp_provider.encrypted_headers = self._process_headers(headers, mcp_provider, tenant_id)
|
||||||
if headers:
|
|
||||||
# existing decrypted and masked headers
|
|
||||||
existing_decrypted = mcp_provider.decrypted_headers
|
|
||||||
existing_masked = mcp_provider.masked_headers
|
|
||||||
|
|
||||||
# Build final headers: if value equals masked existing, keep original decrypted value
|
# Update credentials if provided
|
||||||
final_headers: dict[str, str] = {}
|
if authentication and authentication.client_id:
|
||||||
for key, incoming_value in headers.items():
|
mcp_provider.encrypted_credentials = self._process_credentials(authentication, mcp_provider, tenant_id)
|
||||||
if (
|
|
||||||
key in existing_masked
|
|
||||||
and key in existing_decrypted
|
|
||||||
and isinstance(incoming_value, str)
|
|
||||||
and incoming_value == existing_masked.get(key)
|
|
||||||
):
|
|
||||||
# unchanged, use original decrypted value
|
|
||||||
final_headers[key] = str(existing_decrypted[key])
|
|
||||||
else:
|
|
||||||
final_headers[key] = incoming_value
|
|
||||||
|
|
||||||
encrypted_headers_dict = MCPToolManageService._encrypt_headers(final_headers, tenant_id)
|
# Flush changes to database
|
||||||
mcp_provider.encrypted_headers = json.dumps(encrypted_headers_dict)
|
self._session.flush()
|
||||||
else:
|
|
||||||
# Explicitly clear headers if empty dict passed
|
|
||||||
mcp_provider.encrypted_headers = None
|
|
||||||
db.session.commit()
|
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
db.session.rollback()
|
self._handle_integrity_error(e, name, server_url, server_identifier)
|
||||||
error_msg = str(e.orig)
|
|
||||||
if "unique_mcp_provider_name" in error_msg:
|
|
||||||
raise ValueError(f"MCP tool {name} already exists")
|
|
||||||
if "unique_mcp_provider_server_url" in error_msg:
|
|
||||||
raise ValueError(f"MCP tool {server_url} already exists")
|
|
||||||
if "unique_mcp_provider_server_identifier" in error_msg:
|
|
||||||
raise ValueError(f"MCP tool {server_identifier} already exists")
|
|
||||||
raise
|
|
||||||
except Exception:
|
|
||||||
db.session.rollback()
|
|
||||||
raise
|
|
||||||
|
|
||||||
@classmethod
|
def delete_provider(self, *, tenant_id: str, provider_id: str) -> None:
|
||||||
def update_mcp_provider_credentials(
|
"""Delete an MCP provider."""
|
||||||
cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
|
mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||||
):
|
self._session.delete(mcp_tool)
|
||||||
provider_controller = MCPToolProviderController.from_db(mcp_provider)
|
|
||||||
|
def list_providers(
|
||||||
|
self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True
|
||||||
|
) -> list[ToolProviderApiEntity]:
|
||||||
|
"""List all MCP providers for a tenant.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: Tenant ID
|
||||||
|
for_list: If True, return provider ID; if False, return server identifier
|
||||||
|
include_sensitive: If False, skip expensive decryption operations (default: True for backward compatibility)
|
||||||
|
"""
|
||||||
|
from models.account import Account
|
||||||
|
|
||||||
|
stmt = select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id).order_by(MCPToolProvider.name)
|
||||||
|
mcp_providers = self._session.scalars(stmt).all()
|
||||||
|
|
||||||
|
if not mcp_providers:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Batch query all users to avoid N+1 problem
|
||||||
|
user_ids = {provider.user_id for provider in mcp_providers}
|
||||||
|
users = self._session.query(Account).where(Account.id.in_(user_ids)).all()
|
||||||
|
user_name_map = {user.id: user.name for user in users}
|
||||||
|
|
||||||
|
return [
|
||||||
|
ToolTransformService.mcp_provider_to_user_provider(
|
||||||
|
provider,
|
||||||
|
for_list=for_list,
|
||||||
|
user_name=user_name_map.get(provider.user_id),
|
||||||
|
include_sensitive=include_sensitive,
|
||||||
|
)
|
||||||
|
for provider in mcp_providers
|
||||||
|
]
|
||||||
|
|
||||||
|
# ========== Tool Operations ==========
|
||||||
|
|
||||||
|
def list_provider_tools(self, *, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
|
||||||
|
"""List tools from remote MCP server."""
|
||||||
|
# Load provider and convert to entity
|
||||||
|
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||||
|
provider_entity = db_provider.to_entity()
|
||||||
|
|
||||||
|
# Verify authentication
|
||||||
|
if not provider_entity.authed:
|
||||||
|
raise ValueError("Please auth the tool first")
|
||||||
|
|
||||||
|
# Prepare headers with auth token
|
||||||
|
headers = self._prepare_auth_headers(provider_entity)
|
||||||
|
|
||||||
|
# Retrieve tools from remote server
|
||||||
|
server_url = provider_entity.decrypt_server_url()
|
||||||
|
try:
|
||||||
|
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity)
|
||||||
|
except MCPError as e:
|
||||||
|
raise ValueError(f"Failed to connect to MCP server: {e}")
|
||||||
|
|
||||||
|
# Update database with retrieved tools
|
||||||
|
db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
|
||||||
|
db_provider.authed = True
|
||||||
|
db_provider.updated_at = datetime.now()
|
||||||
|
self._session.flush()
|
||||||
|
|
||||||
|
# Build API response
|
||||||
|
return self._build_tool_provider_response(db_provider, provider_entity, tools)
|
||||||
|
|
||||||
|
# ========== OAuth and Credentials Operations ==========
|
||||||
|
|
||||||
|
def update_provider_credentials(
|
||||||
|
self, *, provider_id: str, tenant_id: str, credentials: dict[str, Any], authed: bool | None = None
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Update provider credentials with encryption.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_id: Provider ID
|
||||||
|
tenant_id: Tenant ID
|
||||||
|
credentials: Credentials to save
|
||||||
|
authed: Whether provider is authenticated (None means keep current state)
|
||||||
|
"""
|
||||||
|
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||||
|
|
||||||
|
# Get provider from current session
|
||||||
|
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||||
|
|
||||||
|
# Encrypt new credentials
|
||||||
|
provider_controller = MCPToolProviderController.from_db(provider)
|
||||||
tool_configuration = ProviderConfigEncrypter(
|
tool_configuration = ProviderConfigEncrypter(
|
||||||
tenant_id=mcp_provider.tenant_id,
|
tenant_id=provider.tenant_id,
|
||||||
config=list(provider_controller.get_credentials_schema()),
|
config=list(provider_controller.get_credentials_schema()),
|
||||||
provider_config_cache=NoOpProviderCredentialCache(),
|
provider_config_cache=NoOpProviderCredentialCache(),
|
||||||
)
|
)
|
||||||
credentials = tool_configuration.encrypt(credentials)
|
encrypted_credentials = tool_configuration.encrypt(credentials)
|
||||||
mcp_provider.updated_at = datetime.now()
|
|
||||||
mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **credentials})
|
|
||||||
mcp_provider.authed = authed
|
|
||||||
if not authed:
|
|
||||||
mcp_provider.tools = "[]"
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
@classmethod
|
# Update provider
|
||||||
def _re_connect_mcp_provider(cls, server_url: str, provider_id: str, tenant_id: str):
|
provider.updated_at = datetime.now()
|
||||||
# Get the existing provider to access headers and timeout settings
|
provider.encrypted_credentials = json.dumps({**provider.credentials, **encrypted_credentials})
|
||||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
|
||||||
headers = mcp_provider.decrypted_headers
|
if authed is not None:
|
||||||
timeout = mcp_provider.timeout
|
provider.authed = authed
|
||||||
sse_read_timeout = mcp_provider.sse_read_timeout
|
if not authed:
|
||||||
|
provider.tools = EMPTY_TOOLS_JSON
|
||||||
|
|
||||||
|
# Flush changes to database
|
||||||
|
self._session.flush()
|
||||||
|
|
||||||
|
def save_oauth_data(
|
||||||
|
self, provider_id: str, tenant_id: str, data: dict[str, Any], data_type: OAuthDataType = OAuthDataType.MIXED
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Save OAuth-related data (tokens, client info, code verifier).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_id: Provider ID
|
||||||
|
tenant_id: Tenant ID
|
||||||
|
data: Data to save (tokens, client info, or code verifier)
|
||||||
|
data_type: Type of OAuth data to save
|
||||||
|
"""
|
||||||
|
# Determine if this makes the provider authenticated
|
||||||
|
authed = (
|
||||||
|
data_type == OAuthDataType.TOKENS or (data_type == OAuthDataType.MIXED and "access_token" in data) or None
|
||||||
|
)
|
||||||
|
|
||||||
|
# update_provider_credentials will validate provider existence
|
||||||
|
self.update_provider_credentials(provider_id=provider_id, tenant_id=tenant_id, credentials=data, authed=authed)
|
||||||
|
|
||||||
|
def clear_provider_credentials(self, *, provider_id: str, tenant_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Clear all credentials for a provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_id: Provider ID
|
||||||
|
tenant_id: Tenant ID
|
||||||
|
"""
|
||||||
|
# Get provider from current session
|
||||||
|
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||||
|
|
||||||
|
provider.tools = EMPTY_TOOLS_JSON
|
||||||
|
provider.encrypted_credentials = EMPTY_CREDENTIALS_JSON
|
||||||
|
provider.updated_at = datetime.now()
|
||||||
|
provider.authed = False
|
||||||
|
|
||||||
|
# ========== Private Helper Methods ==========
|
||||||
|
|
||||||
|
def _check_provider_exists(self, tenant_id: str, name: str, server_url_hash: str, server_identifier: str) -> None:
|
||||||
|
"""Check if provider with same attributes already exists."""
|
||||||
|
stmt = select(MCPToolProvider).where(
|
||||||
|
MCPToolProvider.tenant_id == tenant_id,
|
||||||
|
or_(
|
||||||
|
MCPToolProvider.name == name,
|
||||||
|
MCPToolProvider.server_url_hash == server_url_hash,
|
||||||
|
MCPToolProvider.server_identifier == server_identifier,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
existing_provider = self._session.scalar(stmt)
|
||||||
|
|
||||||
|
if existing_provider:
|
||||||
|
if existing_provider.name == name:
|
||||||
|
raise ValueError(f"MCP tool {name} already exists")
|
||||||
|
if existing_provider.server_url_hash == server_url_hash:
|
||||||
|
raise ValueError("MCP tool with this server URL already exists")
|
||||||
|
if existing_provider.server_identifier == server_identifier:
|
||||||
|
raise ValueError(f"MCP tool {server_identifier} already exists")
|
||||||
|
|
||||||
|
def _prepare_icon(self, icon: str, icon_type: str, icon_background: str) -> str:
|
||||||
|
"""Prepare icon data for storage."""
|
||||||
|
if icon_type == "emoji":
|
||||||
|
return json.dumps({"content": icon, "background": icon_background})
|
||||||
|
return icon
|
||||||
|
|
||||||
|
def _encrypt_dict_fields(self, data: dict[str, Any], secret_fields: list[str], tenant_id: str) -> dict[str, str]:
|
||||||
|
"""Encrypt specified fields in a dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Dictionary containing data to encrypt
|
||||||
|
secret_fields: List of field names to encrypt
|
||||||
|
tenant_id: Tenant ID for encryption
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON string of encrypted data
|
||||||
|
"""
|
||||||
|
from core.entities.provider_entities import BasicProviderConfig
|
||||||
|
from core.tools.utils.encryption import create_provider_encrypter
|
||||||
|
|
||||||
|
# Create config for secret fields
|
||||||
|
config = [
|
||||||
|
BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=field) for field in secret_fields
|
||||||
|
]
|
||||||
|
|
||||||
|
encrypter_instance, _ = create_provider_encrypter(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
config=config,
|
||||||
|
cache=NoOpProviderCredentialCache(),
|
||||||
|
)
|
||||||
|
|
||||||
|
encrypted_data = encrypter_instance.encrypt(data)
|
||||||
|
return encrypted_data
|
||||||
|
|
||||||
|
def _prepare_encrypted_dict(self, headers: dict[str, str], tenant_id: str) -> str:
|
||||||
|
"""Encrypt headers and prepare for storage."""
|
||||||
|
# All headers are treated as secret
|
||||||
|
return json.dumps(self._encrypt_dict_fields(headers, list(headers.keys()), tenant_id))
|
||||||
|
|
||||||
|
def _prepare_auth_headers(self, provider_entity: MCPProviderEntity) -> dict[str, str]:
|
||||||
|
"""Prepare headers with OAuth token if available."""
|
||||||
|
headers = provider_entity.decrypt_headers()
|
||||||
|
tokens = provider_entity.retrieve_tokens()
|
||||||
|
if tokens:
|
||||||
|
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
|
||||||
|
return headers
|
||||||
|
|
||||||
|
def _retrieve_remote_mcp_tools(
|
||||||
|
self,
|
||||||
|
server_url: str,
|
||||||
|
headers: dict[str, str],
|
||||||
|
provider_entity: MCPProviderEntity,
|
||||||
|
):
|
||||||
|
"""Retrieve tools from remote MCP server."""
|
||||||
|
with MCPClientWithAuthRetry(
|
||||||
|
server_url=server_url,
|
||||||
|
headers=headers,
|
||||||
|
timeout=provider_entity.timeout,
|
||||||
|
sse_read_timeout=provider_entity.sse_read_timeout,
|
||||||
|
provider_entity=provider_entity,
|
||||||
|
) as mcp_client:
|
||||||
|
return mcp_client.list_tools()
|
||||||
|
|
||||||
|
def execute_auth_actions(self, auth_result: Any) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
Execute the actions returned by the auth function.
|
||||||
|
|
||||||
|
This method processes the AuthResult and performs the necessary database operations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auth_result: The result from the auth function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The response from the auth result
|
||||||
|
"""
|
||||||
|
from core.mcp.entities import AuthAction, AuthActionType
|
||||||
|
|
||||||
|
action: AuthAction
|
||||||
|
for action in auth_result.actions:
|
||||||
|
if action.provider_id is None or action.tenant_id is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if action.action_type == AuthActionType.SAVE_CLIENT_INFO:
|
||||||
|
self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CLIENT_INFO)
|
||||||
|
elif action.action_type == AuthActionType.SAVE_TOKENS:
|
||||||
|
self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.TOKENS)
|
||||||
|
elif action.action_type == AuthActionType.SAVE_CODE_VERIFIER:
|
||||||
|
self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CODE_VERIFIER)
|
||||||
|
|
||||||
|
return auth_result.response
|
||||||
|
|
||||||
|
def auth_with_actions(
|
||||||
|
self, provider_entity: MCPProviderEntity, authorization_code: str | None = None
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
Perform authentication and execute all resulting actions.
|
||||||
|
|
||||||
|
This method is used by MCPClientWithAuthRetry for automatic re-authentication.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_entity: The MCP provider entity
|
||||||
|
authorization_code: Optional authorization code
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response dictionary from auth result
|
||||||
|
"""
|
||||||
|
auth_result = auth(provider_entity, authorization_code)
|
||||||
|
return self.execute_auth_actions(auth_result)
|
||||||
|
|
||||||
|
def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult:
|
||||||
|
"""Attempt to reconnect to MCP provider with new server URL."""
|
||||||
|
provider_entity = provider.to_entity()
|
||||||
|
headers = provider_entity.headers
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with MCPClient(
|
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity)
|
||||||
server_url,
|
return ReconnectResult(
|
||||||
provider_id,
|
authed=True,
|
||||||
tenant_id,
|
tools=json.dumps([tool.model_dump() for tool in tools]),
|
||||||
authed=False,
|
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
|
||||||
for_list=True,
|
)
|
||||||
headers=headers,
|
|
||||||
timeout=timeout,
|
|
||||||
sse_read_timeout=sse_read_timeout,
|
|
||||||
) as mcp_client:
|
|
||||||
tools = mcp_client.list_tools()
|
|
||||||
return {
|
|
||||||
"authed": True,
|
|
||||||
"tools": json.dumps([tool.model_dump() for tool in tools]),
|
|
||||||
"encrypted_credentials": "{}",
|
|
||||||
}
|
|
||||||
except MCPAuthError:
|
except MCPAuthError:
|
||||||
return {"authed": False, "tools": "[]", "encrypted_credentials": "{}"}
|
return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON)
|
||||||
except MCPError as e:
|
except MCPError as e:
|
||||||
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
|
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
|
||||||
|
|
||||||
|
def validate_server_url_change(
|
||||||
|
self, *, tenant_id: str, provider_id: str, new_server_url: str
|
||||||
|
) -> ServerUrlValidationResult:
|
||||||
|
"""
|
||||||
|
Validate server URL change by attempting to connect to the new server.
|
||||||
|
This method should be called BEFORE update_provider to perform network operations
|
||||||
|
outside of the database transaction.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ServerUrlValidationResult: Validation result with connection status and tools if successful
|
||||||
|
"""
|
||||||
|
# Handle hidden/unchanged URL
|
||||||
|
if UNCHANGED_SERVER_URL_PLACEHOLDER in new_server_url:
|
||||||
|
return ServerUrlValidationResult(needs_validation=False)
|
||||||
|
|
||||||
|
# Validate URL format
|
||||||
|
if not self._is_valid_url(new_server_url):
|
||||||
|
raise ValueError("Server URL is not valid.")
|
||||||
|
|
||||||
|
# Always encrypt and hash the URL
|
||||||
|
encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url)
|
||||||
|
new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest()
|
||||||
|
|
||||||
|
# Get current provider
|
||||||
|
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||||
|
|
||||||
|
# Check if URL is actually different
|
||||||
|
if new_server_url_hash == provider.server_url_hash:
|
||||||
|
# URL hasn't changed, but still return the encrypted data
|
||||||
|
return ServerUrlValidationResult(
|
||||||
|
needs_validation=False, encrypted_server_url=encrypted_server_url, server_url_hash=new_server_url_hash
|
||||||
|
)
|
||||||
|
|
||||||
|
# Perform validation by attempting to connect
|
||||||
|
reconnect_result = self._reconnect_provider(server_url=new_server_url, provider=provider)
|
||||||
|
return ServerUrlValidationResult(
|
||||||
|
needs_validation=True,
|
||||||
|
validation_passed=True,
|
||||||
|
reconnect_result=reconnect_result,
|
||||||
|
encrypted_server_url=encrypted_server_url,
|
||||||
|
server_url_hash=new_server_url_hash,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_tool_provider_response(
|
||||||
|
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
|
||||||
|
) -> ToolProviderApiEntity:
|
||||||
|
"""Build API response for tool provider."""
|
||||||
|
user = db_provider.load_user()
|
||||||
|
response = provider_entity.to_api_response(
|
||||||
|
user_name=user.name if user else None,
|
||||||
|
)
|
||||||
|
response["tools"] = ToolTransformService.mcp_tool_to_user_tool(db_provider, tools)
|
||||||
|
response["plugin_unique_identifier"] = provider_entity.provider_id
|
||||||
|
return ToolProviderApiEntity(**response)
|
||||||
|
|
||||||
|
def _handle_integrity_error(
|
||||||
|
self, error: IntegrityError, name: str, server_url: str, server_identifier: str
|
||||||
|
) -> None:
|
||||||
|
"""Handle database integrity errors with user-friendly messages."""
|
||||||
|
error_msg = str(error.orig)
|
||||||
|
if "unique_mcp_provider_name" in error_msg:
|
||||||
|
raise ValueError(f"MCP tool {name} already exists")
|
||||||
|
if "unique_mcp_provider_server_url" in error_msg:
|
||||||
|
raise ValueError(f"MCP tool {server_url} already exists")
|
||||||
|
if "unique_mcp_provider_server_identifier" in error_msg:
|
||||||
|
raise ValueError(f"MCP tool {server_identifier} already exists")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _is_valid_url(self, url: str) -> bool:
|
||||||
|
"""Validate URL format."""
|
||||||
|
if not url:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"]
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _update_optional_fields(self, mcp_provider: MCPToolProvider, configuration: MCPConfiguration) -> None:
|
||||||
|
"""Update optional configuration fields using setattr for cleaner code."""
|
||||||
|
field_mapping = {"timeout": configuration.timeout, "sse_read_timeout": configuration.sse_read_timeout}
|
||||||
|
|
||||||
|
for field, value in field_mapping.items():
|
||||||
|
if value is not None:
|
||||||
|
setattr(mcp_provider, field, value)
|
||||||
|
|
||||||
|
def _process_headers(self, headers: dict[str, str], mcp_provider: MCPToolProvider, tenant_id: str) -> str | None:
|
||||||
|
"""Process headers update, handling empty dict to clear headers."""
|
||||||
|
if not headers:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Merge with existing headers to preserve masked values
|
||||||
|
final_headers = self._merge_headers_with_masked(incoming_headers=headers, mcp_provider=mcp_provider)
|
||||||
|
return self._prepare_encrypted_dict(final_headers, tenant_id)
|
||||||
|
|
||||||
|
def _process_credentials(
|
||||||
|
self, authentication: MCPAuthentication, mcp_provider: MCPToolProvider, tenant_id: str
|
||||||
|
) -> str:
|
||||||
|
"""Process credentials update, handling masked values."""
|
||||||
|
# Merge with existing credentials
|
||||||
|
final_client_id, final_client_secret = self._merge_credentials_with_masked(
|
||||||
|
authentication.client_id, authentication.client_secret, mcp_provider
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build and encrypt
|
||||||
|
return self._build_and_encrypt_credentials(final_client_id, final_client_secret, tenant_id)
|
||||||
|
|
||||||
|
def _merge_headers_with_masked(
|
||||||
|
self, incoming_headers: dict[str, str], mcp_provider: MCPToolProvider
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""Merge incoming headers with existing ones, preserving unchanged masked values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
incoming_headers: Headers from frontend (may contain masked values)
|
||||||
|
mcp_provider: The MCP provider instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Final headers dict with proper values (original for unchanged masked, new for changed)
|
||||||
|
"""
|
||||||
|
mcp_provider_entity = mcp_provider.to_entity()
|
||||||
|
existing_decrypted = mcp_provider_entity.decrypt_headers()
|
||||||
|
existing_masked = mcp_provider_entity.masked_headers()
|
||||||
|
|
||||||
|
return {
|
||||||
|
key: (str(existing_decrypted[key]) if key in existing_masked and value == existing_masked[key] else value)
|
||||||
|
for key, value in incoming_headers.items()
|
||||||
|
if key in existing_decrypted or value != existing_masked.get(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
def _merge_credentials_with_masked(
|
||||||
|
self,
|
||||||
|
client_id: str,
|
||||||
|
client_secret: str | None,
|
||||||
|
mcp_provider: MCPToolProvider,
|
||||||
|
) -> tuple[
|
||||||
|
str,
|
||||||
|
str | None,
|
||||||
|
]:
|
||||||
|
"""Merge incoming credentials with existing ones, preserving unchanged masked values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client_id: Client ID from frontend (may be masked)
|
||||||
|
client_secret: Client secret from frontend (may be masked)
|
||||||
|
mcp_provider: The MCP provider instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (final_client_id, final_client_secret)
|
||||||
|
"""
|
||||||
|
mcp_provider_entity = mcp_provider.to_entity()
|
||||||
|
existing_decrypted = mcp_provider_entity.decrypt_credentials()
|
||||||
|
existing_masked = mcp_provider_entity.masked_credentials()
|
||||||
|
|
||||||
|
# Check if client_id is masked and unchanged
|
||||||
|
final_client_id = client_id
|
||||||
|
if existing_masked.get("client_id") and client_id == existing_masked["client_id"]:
|
||||||
|
# Use existing decrypted value
|
||||||
|
final_client_id = existing_decrypted.get("client_id", client_id)
|
||||||
|
|
||||||
|
# Check if client_secret is masked and unchanged
|
||||||
|
final_client_secret = client_secret
|
||||||
|
if existing_masked.get("client_secret") and client_secret == existing_masked["client_secret"]:
|
||||||
|
# Use existing decrypted value
|
||||||
|
final_client_secret = existing_decrypted.get("client_secret", client_secret)
|
||||||
|
|
||||||
|
return final_client_id, final_client_secret
|
||||||
|
|
||||||
|
def _build_and_encrypt_credentials(self, client_id: str, client_secret: str | None, tenant_id: str) -> str:
|
||||||
|
"""Build credentials and encrypt sensitive fields."""
|
||||||
|
# Create a flat structure with all credential data
|
||||||
|
credentials_data = {
|
||||||
|
"client_id": client_id,
|
||||||
|
"client_name": CLIENT_NAME,
|
||||||
|
"is_dynamic_registration": False,
|
||||||
|
}
|
||||||
|
secret_fields = []
|
||||||
|
if client_secret is not None:
|
||||||
|
credentials_data["encrypted_client_secret"] = client_secret
|
||||||
|
secret_fields = ["encrypted_client_secret"]
|
||||||
|
client_info = self._encrypt_dict_fields(credentials_data, secret_fields, tenant_id)
|
||||||
|
return json.dumps({"client_information": client_info})
|
||||||
|
|||||||
@ -3,9 +3,11 @@ import logging
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
|
|
||||||
|
from pydantic import ValidationError
|
||||||
from yarl import URL
|
from yarl import URL
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
from core.entities.mcp_provider import MCPConfiguration
|
||||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||||
from core.mcp.types import Tool as MCPTool
|
from core.mcp.types import Tool as MCPTool
|
||||||
from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity
|
from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity
|
||||||
@ -232,40 +234,57 @@ class ToolTransformService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def mcp_provider_to_user_provider(db_provider: MCPToolProvider, for_list: bool = False) -> ToolProviderApiEntity:
|
def mcp_provider_to_user_provider(
|
||||||
user = db_provider.load_user()
|
db_provider: MCPToolProvider,
|
||||||
return ToolProviderApiEntity(
|
for_list: bool = False,
|
||||||
id=db_provider.server_identifier if not for_list else db_provider.id,
|
user_name: str | None = None,
|
||||||
author=user.name if user else "Anonymous",
|
include_sensitive: bool = True,
|
||||||
name=db_provider.name,
|
) -> ToolProviderApiEntity:
|
||||||
icon=db_provider.provider_icon,
|
# Use provided user_name to avoid N+1 query, fallback to load_user() if not provided
|
||||||
type=ToolProviderType.MCP,
|
if user_name is None:
|
||||||
is_team_authorization=db_provider.authed,
|
user = db_provider.load_user()
|
||||||
server_url=db_provider.masked_server_url,
|
user_name = user.name if user else None
|
||||||
tools=ToolTransformService.mcp_tool_to_user_tool(
|
|
||||||
db_provider, [MCPTool.model_validate(tool) for tool in json.loads(db_provider.tools)]
|
# Convert to entity and use its API response method
|
||||||
),
|
provider_entity = db_provider.to_entity()
|
||||||
updated_at=int(db_provider.updated_at.timestamp()),
|
|
||||||
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
|
response = provider_entity.to_api_response(user_name=user_name, include_sensitive=include_sensitive)
|
||||||
description=I18nObject(en_US="", zh_Hans=""),
|
try:
|
||||||
server_identifier=db_provider.server_identifier,
|
mcp_tools = [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
|
||||||
timeout=db_provider.timeout,
|
except (ValidationError, json.JSONDecodeError):
|
||||||
sse_read_timeout=db_provider.sse_read_timeout,
|
mcp_tools = []
|
||||||
masked_headers=db_provider.masked_headers,
|
# Add additional fields specific to the transform
|
||||||
original_headers=db_provider.decrypted_headers,
|
response["id"] = db_provider.server_identifier if not for_list else db_provider.id
|
||||||
)
|
response["tools"] = ToolTransformService.mcp_tool_to_user_tool(db_provider, mcp_tools, user_name=user_name)
|
||||||
|
response["server_identifier"] = db_provider.server_identifier
|
||||||
|
|
||||||
|
# Convert configuration dict to MCPConfiguration object
|
||||||
|
if "configuration" in response and isinstance(response["configuration"], dict):
|
||||||
|
response["configuration"] = MCPConfiguration(
|
||||||
|
timeout=float(response["configuration"]["timeout"]),
|
||||||
|
sse_read_timeout=float(response["configuration"]["sse_read_timeout"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
return ToolProviderApiEntity(**response)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def mcp_tool_to_user_tool(mcp_provider: MCPToolProvider, tools: list[MCPTool]) -> list[ToolApiEntity]:
|
def mcp_tool_to_user_tool(
|
||||||
user = mcp_provider.load_user()
|
mcp_provider: MCPToolProvider, tools: list[MCPTool], user_name: str | None = None
|
||||||
|
) -> list[ToolApiEntity]:
|
||||||
|
# Use provided user_name to avoid N+1 query, fallback to load_user() if not provided
|
||||||
|
if user_name is None:
|
||||||
|
user = mcp_provider.load_user()
|
||||||
|
user_name = user.name if user else "Anonymous"
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ToolApiEntity(
|
ToolApiEntity(
|
||||||
author=user.name if user else "Anonymous",
|
author=user_name or "Anonymous",
|
||||||
name=tool.name,
|
name=tool.name,
|
||||||
label=I18nObject(en_US=tool.name, zh_Hans=tool.name),
|
label=I18nObject(en_US=tool.name, zh_Hans=tool.name),
|
||||||
description=I18nObject(en_US=tool.description or "", zh_Hans=tool.description or ""),
|
description=I18nObject(en_US=tool.description or "", zh_Hans=tool.description or ""),
|
||||||
parameters=ToolTransformService.convert_mcp_schema_to_parameter(tool.inputSchema),
|
parameters=ToolTransformService.convert_mcp_schema_to_parameter(tool.inputSchema),
|
||||||
labels=[],
|
labels=[],
|
||||||
|
output_schema=tool.outputSchema or {},
|
||||||
)
|
)
|
||||||
for tool in tools
|
for tool in tools
|
||||||
]
|
]
|
||||||
@ -412,7 +431,7 @@ class ToolTransformService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]:
|
def convert_mcp_schema_to_parameter(schema: dict[str, Any]) -> list["ToolParameter"]:
|
||||||
"""
|
"""
|
||||||
Convert MCP JSON schema to tool parameters
|
Convert MCP JSON schema to tool parameters
|
||||||
|
|
||||||
@ -421,7 +440,7 @@ class ToolTransformService:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def create_parameter(
|
def create_parameter(
|
||||||
name: str, description: str, param_type: str, required: bool, input_schema: dict | None = None
|
name: str, description: str, param_type: str, required: bool, input_schema: dict[str, Any] | None = None
|
||||||
) -> ToolParameter:
|
) -> ToolParameter:
|
||||||
"""Create a ToolParameter instance with given attributes"""
|
"""Create a ToolParameter instance with given attributes"""
|
||||||
input_schema_dict: dict[str, Any] = {"input_schema": input_schema} if input_schema else {}
|
input_schema_dict: dict[str, Any] = {"input_schema": input_schema} if input_schema else {}
|
||||||
@ -436,7 +455,9 @@ class ToolTransformService:
|
|||||||
**input_schema_dict,
|
**input_schema_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
def process_properties(props: dict, required: list, prefix: str = "") -> list[ToolParameter]:
|
def process_properties(
|
||||||
|
props: dict[str, dict[str, Any]], required: list[str], prefix: str = ""
|
||||||
|
) -> list[ToolParameter]:
|
||||||
"""Process properties recursively"""
|
"""Process properties recursively"""
|
||||||
TYPE_MAPPING = {"integer": "number", "float": "number"}
|
TYPE_MAPPING = {"integer": "number", "float": "number"}
|
||||||
COMPLEX_TYPES = ["array", "object"]
|
COMPLEX_TYPES = ["array", "object"]
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, Generic, TypeAlias, TypeVar, overload
|
from typing import Any, Generic, TypeAlias, TypeVar, overload
|
||||||
|
|
||||||
@ -66,7 +67,17 @@ class TruncationResult:
|
|||||||
truncated: bool
|
truncated: bool
|
||||||
|
|
||||||
|
|
||||||
class VariableTruncator:
|
class BaseTruncator(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def truncate(self, segment: Segment) -> TruncationResult:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def truncate_variable_mapping(self, v: Mapping[str, Any]) -> tuple[Mapping[str, Any], bool]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class VariableTruncator(BaseTruncator):
|
||||||
"""
|
"""
|
||||||
Handles variable truncation with structure-preserving strategies.
|
Handles variable truncation with structure-preserving strategies.
|
||||||
|
|
||||||
@ -418,3 +429,38 @@ class VariableTruncator:
|
|||||||
return _PartResult(val, self.calculate_json_size(val), False)
|
return _PartResult(val, self.calculate_json_size(val), False)
|
||||||
else:
|
else:
|
||||||
raise AssertionError("this statement should be unreachable.")
|
raise AssertionError("this statement should be unreachable.")
|
||||||
|
|
||||||
|
|
||||||
|
class DummyVariableTruncator(BaseTruncator):
|
||||||
|
"""
|
||||||
|
A no-op variable truncator that doesn't truncate any data.
|
||||||
|
|
||||||
|
This is used for Service API calls where truncation should be disabled
|
||||||
|
to maintain backward compatibility and provide complete data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def truncate_variable_mapping(self, v: Mapping[str, Any]) -> tuple[Mapping[str, Any], bool]:
|
||||||
|
"""
|
||||||
|
Return original mapping without truncation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
v: The variable mapping to process
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (original_mapping, False) where False indicates no truncation occurred
|
||||||
|
"""
|
||||||
|
return v, False
|
||||||
|
|
||||||
|
def truncate(self, segment: Segment) -> TruncationResult:
|
||||||
|
"""
|
||||||
|
Return original segment without truncation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
segment: The segment to process
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The original segment unchanged
|
||||||
|
"""
|
||||||
|
# For Service API, we want to preserve the original segment
|
||||||
|
# without any truncation, so just return it as-is
|
||||||
|
return TruncationResult(result=segment, truncated=False)
|
||||||
|
|||||||
258
api/tests/fixtures/workflow/iteration_flatten_output_disabled_workflow.yml
vendored
Normal file
258
api/tests/fixtures/workflow/iteration_flatten_output_disabled_workflow.yml
vendored
Normal file
@ -0,0 +1,258 @@
|
|||||||
|
app:
|
||||||
|
description: 'This workflow tests the iteration node with flatten_output=False.
|
||||||
|
|
||||||
|
|
||||||
|
It processes [1, 2, 3], outputs [item, item*2] for each iteration.
|
||||||
|
|
||||||
|
|
||||||
|
With flatten_output=False, it should output nested arrays:
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
{"output": [[1, 2], [2, 4], [3, 6]]}
|
||||||
|
|
||||||
|
```'
|
||||||
|
icon: 🤖
|
||||||
|
icon_background: '#FFEAD5'
|
||||||
|
mode: workflow
|
||||||
|
name: test_iteration_flatten_disabled
|
||||||
|
use_icon_as_answer_icon: false
|
||||||
|
dependencies: []
|
||||||
|
kind: app
|
||||||
|
version: 0.3.1
|
||||||
|
workflow:
|
||||||
|
conversation_variables: []
|
||||||
|
environment_variables: []
|
||||||
|
features:
|
||||||
|
file_upload:
|
||||||
|
enabled: false
|
||||||
|
opening_statement: ''
|
||||||
|
retriever_resource:
|
||||||
|
enabled: true
|
||||||
|
sensitive_word_avoidance:
|
||||||
|
enabled: false
|
||||||
|
speech_to_text:
|
||||||
|
enabled: false
|
||||||
|
suggested_questions: []
|
||||||
|
suggested_questions_after_answer:
|
||||||
|
enabled: false
|
||||||
|
text_to_speech:
|
||||||
|
enabled: false
|
||||||
|
graph:
|
||||||
|
edges:
|
||||||
|
- data:
|
||||||
|
isInIteration: false
|
||||||
|
isInLoop: false
|
||||||
|
sourceType: start
|
||||||
|
targetType: code
|
||||||
|
id: start-source-code-target
|
||||||
|
source: start_node
|
||||||
|
sourceHandle: source
|
||||||
|
target: code_node
|
||||||
|
targetHandle: target
|
||||||
|
type: custom
|
||||||
|
zIndex: 0
|
||||||
|
- data:
|
||||||
|
isInIteration: false
|
||||||
|
isInLoop: false
|
||||||
|
sourceType: code
|
||||||
|
targetType: iteration
|
||||||
|
id: code-source-iteration-target
|
||||||
|
source: code_node
|
||||||
|
sourceHandle: source
|
||||||
|
target: iteration_node
|
||||||
|
targetHandle: target
|
||||||
|
type: custom
|
||||||
|
zIndex: 0
|
||||||
|
- data:
|
||||||
|
isInIteration: true
|
||||||
|
isInLoop: false
|
||||||
|
iteration_id: iteration_node
|
||||||
|
sourceType: iteration-start
|
||||||
|
targetType: code
|
||||||
|
id: iteration-start-source-code-inner-target
|
||||||
|
source: iteration_nodestart
|
||||||
|
sourceHandle: source
|
||||||
|
target: code_inner_node
|
||||||
|
targetHandle: target
|
||||||
|
type: custom
|
||||||
|
zIndex: 1002
|
||||||
|
- data:
|
||||||
|
isInIteration: false
|
||||||
|
isInLoop: false
|
||||||
|
sourceType: iteration
|
||||||
|
targetType: end
|
||||||
|
id: iteration-source-end-target
|
||||||
|
source: iteration_node
|
||||||
|
sourceHandle: source
|
||||||
|
target: end_node
|
||||||
|
targetHandle: target
|
||||||
|
type: custom
|
||||||
|
zIndex: 0
|
||||||
|
nodes:
|
||||||
|
- data:
|
||||||
|
desc: ''
|
||||||
|
selected: false
|
||||||
|
title: Start
|
||||||
|
type: start
|
||||||
|
variables: []
|
||||||
|
height: 54
|
||||||
|
id: start_node
|
||||||
|
position:
|
||||||
|
x: 80
|
||||||
|
y: 282
|
||||||
|
positionAbsolute:
|
||||||
|
x: 80
|
||||||
|
y: 282
|
||||||
|
sourcePosition: right
|
||||||
|
targetPosition: left
|
||||||
|
type: custom
|
||||||
|
width: 244
|
||||||
|
- data:
|
||||||
|
code: "\ndef main() -> dict:\n return {\n \"result\": [1, 2, 3],\n\
|
||||||
|
\ }\n"
|
||||||
|
code_language: python3
|
||||||
|
desc: ''
|
||||||
|
outputs:
|
||||||
|
result:
|
||||||
|
children: null
|
||||||
|
type: array[number]
|
||||||
|
selected: false
|
||||||
|
title: Generate Array
|
||||||
|
type: code
|
||||||
|
variables: []
|
||||||
|
height: 54
|
||||||
|
id: code_node
|
||||||
|
position:
|
||||||
|
x: 384
|
||||||
|
y: 282
|
||||||
|
positionAbsolute:
|
||||||
|
x: 384
|
||||||
|
y: 282
|
||||||
|
selected: false
|
||||||
|
sourcePosition: right
|
||||||
|
targetPosition: left
|
||||||
|
type: custom
|
||||||
|
width: 244
|
||||||
|
- data:
|
||||||
|
desc: ''
|
||||||
|
error_handle_mode: terminated
|
||||||
|
flatten_output: false
|
||||||
|
height: 178
|
||||||
|
is_parallel: false
|
||||||
|
iterator_input_type: array[number]
|
||||||
|
iterator_selector:
|
||||||
|
- code_node
|
||||||
|
- result
|
||||||
|
output_selector:
|
||||||
|
- code_inner_node
|
||||||
|
- result
|
||||||
|
output_type: array[array[number]]
|
||||||
|
parallel_nums: 10
|
||||||
|
selected: false
|
||||||
|
start_node_id: iteration_nodestart
|
||||||
|
title: Iteration with Flatten Disabled
|
||||||
|
type: iteration
|
||||||
|
width: 388
|
||||||
|
height: 178
|
||||||
|
id: iteration_node
|
||||||
|
position:
|
||||||
|
x: 684
|
||||||
|
y: 282
|
||||||
|
positionAbsolute:
|
||||||
|
x: 684
|
||||||
|
y: 282
|
||||||
|
selected: false
|
||||||
|
sourcePosition: right
|
||||||
|
targetPosition: left
|
||||||
|
type: custom
|
||||||
|
width: 388
|
||||||
|
zIndex: 1
|
||||||
|
- data:
|
||||||
|
desc: ''
|
||||||
|
isInIteration: true
|
||||||
|
selected: false
|
||||||
|
title: ''
|
||||||
|
type: iteration-start
|
||||||
|
draggable: false
|
||||||
|
height: 48
|
||||||
|
id: iteration_nodestart
|
||||||
|
parentId: iteration_node
|
||||||
|
position:
|
||||||
|
x: 24
|
||||||
|
y: 68
|
||||||
|
positionAbsolute:
|
||||||
|
x: 708
|
||||||
|
y: 350
|
||||||
|
selectable: false
|
||||||
|
sourcePosition: right
|
||||||
|
targetPosition: left
|
||||||
|
type: custom-iteration-start
|
||||||
|
width: 44
|
||||||
|
zIndex: 1002
|
||||||
|
- data:
|
||||||
|
code: "\ndef main(arg1: int) -> dict:\n return {\n \"result\": [arg1,\
|
||||||
|
\ arg1 * 2],\n }\n"
|
||||||
|
code_language: python3
|
||||||
|
desc: ''
|
||||||
|
isInIteration: true
|
||||||
|
isInLoop: false
|
||||||
|
iteration_id: iteration_node
|
||||||
|
outputs:
|
||||||
|
result:
|
||||||
|
children: null
|
||||||
|
type: array[number]
|
||||||
|
selected: false
|
||||||
|
title: Generate Pair
|
||||||
|
type: code
|
||||||
|
variables:
|
||||||
|
- value_selector:
|
||||||
|
- iteration_node
|
||||||
|
- item
|
||||||
|
value_type: number
|
||||||
|
variable: arg1
|
||||||
|
height: 54
|
||||||
|
id: code_inner_node
|
||||||
|
parentId: iteration_node
|
||||||
|
position:
|
||||||
|
x: 128
|
||||||
|
y: 68
|
||||||
|
positionAbsolute:
|
||||||
|
x: 812
|
||||||
|
y: 350
|
||||||
|
selected: false
|
||||||
|
sourcePosition: right
|
||||||
|
targetPosition: left
|
||||||
|
type: custom
|
||||||
|
width: 244
|
||||||
|
zIndex: 1002
|
||||||
|
- data:
|
||||||
|
desc: ''
|
||||||
|
outputs:
|
||||||
|
- value_selector:
|
||||||
|
- iteration_node
|
||||||
|
- output
|
||||||
|
value_type: array[array[number]]
|
||||||
|
variable: output
|
||||||
|
selected: false
|
||||||
|
title: End
|
||||||
|
type: end
|
||||||
|
height: 90
|
||||||
|
id: end_node
|
||||||
|
position:
|
||||||
|
x: 1132
|
||||||
|
y: 282
|
||||||
|
positionAbsolute:
|
||||||
|
x: 1132
|
||||||
|
y: 282
|
||||||
|
selected: true
|
||||||
|
sourcePosition: right
|
||||||
|
targetPosition: left
|
||||||
|
type: custom
|
||||||
|
width: 244
|
||||||
|
viewport:
|
||||||
|
x: -476
|
||||||
|
y: 3
|
||||||
|
zoom: 1
|
||||||
|
|
||||||
258
api/tests/fixtures/workflow/iteration_flatten_output_enabled_workflow.yml
vendored
Normal file
258
api/tests/fixtures/workflow/iteration_flatten_output_enabled_workflow.yml
vendored
Normal file
@ -0,0 +1,258 @@
|
|||||||
|
app:
|
||||||
|
description: 'This workflow tests the iteration node with flatten_output=True.
|
||||||
|
|
||||||
|
|
||||||
|
It processes [1, 2, 3], outputs [item, item*2] for each iteration.
|
||||||
|
|
||||||
|
|
||||||
|
With flatten_output=True (default), it should output:
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
{"output": [1, 2, 2, 4, 3, 6]}
|
||||||
|
|
||||||
|
```'
|
||||||
|
icon: 🤖
|
||||||
|
icon_background: '#FFEAD5'
|
||||||
|
mode: workflow
|
||||||
|
name: test_iteration_flatten_enabled
|
||||||
|
use_icon_as_answer_icon: false
|
||||||
|
dependencies: []
|
||||||
|
kind: app
|
||||||
|
version: 0.3.1
|
||||||
|
workflow:
|
||||||
|
conversation_variables: []
|
||||||
|
environment_variables: []
|
||||||
|
features:
|
||||||
|
file_upload:
|
||||||
|
enabled: false
|
||||||
|
opening_statement: ''
|
||||||
|
retriever_resource:
|
||||||
|
enabled: true
|
||||||
|
sensitive_word_avoidance:
|
||||||
|
enabled: false
|
||||||
|
speech_to_text:
|
||||||
|
enabled: false
|
||||||
|
suggested_questions: []
|
||||||
|
suggested_questions_after_answer:
|
||||||
|
enabled: false
|
||||||
|
text_to_speech:
|
||||||
|
enabled: false
|
||||||
|
graph:
|
||||||
|
edges:
|
||||||
|
- data:
|
||||||
|
isInIteration: false
|
||||||
|
isInLoop: false
|
||||||
|
sourceType: start
|
||||||
|
targetType: code
|
||||||
|
id: start-source-code-target
|
||||||
|
source: start_node
|
||||||
|
sourceHandle: source
|
||||||
|
target: code_node
|
||||||
|
targetHandle: target
|
||||||
|
type: custom
|
||||||
|
zIndex: 0
|
||||||
|
- data:
|
||||||
|
isInIteration: false
|
||||||
|
isInLoop: false
|
||||||
|
sourceType: code
|
||||||
|
targetType: iteration
|
||||||
|
id: code-source-iteration-target
|
||||||
|
source: code_node
|
||||||
|
sourceHandle: source
|
||||||
|
target: iteration_node
|
||||||
|
targetHandle: target
|
||||||
|
type: custom
|
||||||
|
zIndex: 0
|
||||||
|
- data:
|
||||||
|
isInIteration: true
|
||||||
|
isInLoop: false
|
||||||
|
iteration_id: iteration_node
|
||||||
|
sourceType: iteration-start
|
||||||
|
targetType: code
|
||||||
|
id: iteration-start-source-code-inner-target
|
||||||
|
source: iteration_nodestart
|
||||||
|
sourceHandle: source
|
||||||
|
target: code_inner_node
|
||||||
|
targetHandle: target
|
||||||
|
type: custom
|
||||||
|
zIndex: 1002
|
||||||
|
- data:
|
||||||
|
isInIteration: false
|
||||||
|
isInLoop: false
|
||||||
|
sourceType: iteration
|
||||||
|
targetType: end
|
||||||
|
id: iteration-source-end-target
|
||||||
|
source: iteration_node
|
||||||
|
sourceHandle: source
|
||||||
|
target: end_node
|
||||||
|
targetHandle: target
|
||||||
|
type: custom
|
||||||
|
zIndex: 0
|
||||||
|
nodes:
|
||||||
|
- data:
|
||||||
|
desc: ''
|
||||||
|
selected: false
|
||||||
|
title: Start
|
||||||
|
type: start
|
||||||
|
variables: []
|
||||||
|
height: 54
|
||||||
|
id: start_node
|
||||||
|
position:
|
||||||
|
x: 80
|
||||||
|
y: 282
|
||||||
|
positionAbsolute:
|
||||||
|
x: 80
|
||||||
|
y: 282
|
||||||
|
sourcePosition: right
|
||||||
|
targetPosition: left
|
||||||
|
type: custom
|
||||||
|
width: 244
|
||||||
|
- data:
|
||||||
|
code: "\ndef main() -> dict:\n return {\n \"result\": [1, 2, 3],\n\
|
||||||
|
\ }\n"
|
||||||
|
code_language: python3
|
||||||
|
desc: ''
|
||||||
|
outputs:
|
||||||
|
result:
|
||||||
|
children: null
|
||||||
|
type: array[number]
|
||||||
|
selected: false
|
||||||
|
title: Generate Array
|
||||||
|
type: code
|
||||||
|
variables: []
|
||||||
|
height: 54
|
||||||
|
id: code_node
|
||||||
|
position:
|
||||||
|
x: 384
|
||||||
|
y: 282
|
||||||
|
positionAbsolute:
|
||||||
|
x: 384
|
||||||
|
y: 282
|
||||||
|
selected: false
|
||||||
|
sourcePosition: right
|
||||||
|
targetPosition: left
|
||||||
|
type: custom
|
||||||
|
width: 244
|
||||||
|
- data:
|
||||||
|
desc: ''
|
||||||
|
error_handle_mode: terminated
|
||||||
|
flatten_output: true
|
||||||
|
height: 178
|
||||||
|
is_parallel: false
|
||||||
|
iterator_input_type: array[number]
|
||||||
|
iterator_selector:
|
||||||
|
- code_node
|
||||||
|
- result
|
||||||
|
output_selector:
|
||||||
|
- code_inner_node
|
||||||
|
- result
|
||||||
|
output_type: array[array[number]]
|
||||||
|
parallel_nums: 10
|
||||||
|
selected: false
|
||||||
|
start_node_id: iteration_nodestart
|
||||||
|
title: Iteration with Flatten Enabled
|
||||||
|
type: iteration
|
||||||
|
width: 388
|
||||||
|
height: 178
|
||||||
|
id: iteration_node
|
||||||
|
position:
|
||||||
|
x: 684
|
||||||
|
y: 282
|
||||||
|
positionAbsolute:
|
||||||
|
x: 684
|
||||||
|
y: 282
|
||||||
|
selected: false
|
||||||
|
sourcePosition: right
|
||||||
|
targetPosition: left
|
||||||
|
type: custom
|
||||||
|
width: 388
|
||||||
|
zIndex: 1
|
||||||
|
- data:
|
||||||
|
desc: ''
|
||||||
|
isInIteration: true
|
||||||
|
selected: false
|
||||||
|
title: ''
|
||||||
|
type: iteration-start
|
||||||
|
draggable: false
|
||||||
|
height: 48
|
||||||
|
id: iteration_nodestart
|
||||||
|
parentId: iteration_node
|
||||||
|
position:
|
||||||
|
x: 24
|
||||||
|
y: 68
|
||||||
|
positionAbsolute:
|
||||||
|
x: 708
|
||||||
|
y: 350
|
||||||
|
selectable: false
|
||||||
|
sourcePosition: right
|
||||||
|
targetPosition: left
|
||||||
|
type: custom-iteration-start
|
||||||
|
width: 44
|
||||||
|
zIndex: 1002
|
||||||
|
- data:
|
||||||
|
code: "\ndef main(arg1: int) -> dict:\n return {\n \"result\": [arg1,\
|
||||||
|
\ arg1 * 2],\n }\n"
|
||||||
|
code_language: python3
|
||||||
|
desc: ''
|
||||||
|
isInIteration: true
|
||||||
|
isInLoop: false
|
||||||
|
iteration_id: iteration_node
|
||||||
|
outputs:
|
||||||
|
result:
|
||||||
|
children: null
|
||||||
|
type: array[number]
|
||||||
|
selected: false
|
||||||
|
title: Generate Pair
|
||||||
|
type: code
|
||||||
|
variables:
|
||||||
|
- value_selector:
|
||||||
|
- iteration_node
|
||||||
|
- item
|
||||||
|
value_type: number
|
||||||
|
variable: arg1
|
||||||
|
height: 54
|
||||||
|
id: code_inner_node
|
||||||
|
parentId: iteration_node
|
||||||
|
position:
|
||||||
|
x: 128
|
||||||
|
y: 68
|
||||||
|
positionAbsolute:
|
||||||
|
x: 812
|
||||||
|
y: 350
|
||||||
|
selected: false
|
||||||
|
sourcePosition: right
|
||||||
|
targetPosition: left
|
||||||
|
type: custom
|
||||||
|
width: 244
|
||||||
|
zIndex: 1002
|
||||||
|
- data:
|
||||||
|
desc: ''
|
||||||
|
outputs:
|
||||||
|
- value_selector:
|
||||||
|
- iteration_node
|
||||||
|
- output
|
||||||
|
value_type: array[number]
|
||||||
|
variable: output
|
||||||
|
selected: false
|
||||||
|
title: End
|
||||||
|
type: end
|
||||||
|
height: 90
|
||||||
|
id: end_node
|
||||||
|
position:
|
||||||
|
x: 1132
|
||||||
|
y: 282
|
||||||
|
positionAbsolute:
|
||||||
|
x: 1132
|
||||||
|
y: 282
|
||||||
|
selected: true
|
||||||
|
sourcePosition: right
|
||||||
|
targetPosition: left
|
||||||
|
type: custom
|
||||||
|
width: 244
|
||||||
|
viewport:
|
||||||
|
x: -476
|
||||||
|
y: 3
|
||||||
|
zoom: 1
|
||||||
|
|
||||||
@ -182,6 +182,28 @@ class TestOpenSearchVector:
|
|||||||
assert len(ids) == 1
|
assert len(ids) == 1
|
||||||
assert ids[0] == "mock_id"
|
assert ids[0] == "mock_id"
|
||||||
|
|
||||||
|
def test_delete_nonexistent_index(self):
|
||||||
|
"""Test deleting a non-existent index."""
|
||||||
|
# Create a vector instance with a non-existent collection name
|
||||||
|
self.vector._client.indices.exists.return_value = False
|
||||||
|
|
||||||
|
# Should not raise an exception
|
||||||
|
self.vector.delete()
|
||||||
|
|
||||||
|
# Verify that exists was called but delete was not
|
||||||
|
self.vector._client.indices.exists.assert_called_once_with(index=self.collection_name.lower())
|
||||||
|
self.vector._client.indices.delete.assert_not_called()
|
||||||
|
|
||||||
|
def test_delete_existing_index(self):
|
||||||
|
"""Test deleting an existing index."""
|
||||||
|
self.vector._client.indices.exists.return_value = True
|
||||||
|
|
||||||
|
self.vector.delete()
|
||||||
|
|
||||||
|
# Verify both exists and delete were called
|
||||||
|
self.vector._client.indices.exists.assert_called_once_with(index=self.collection_name.lower())
|
||||||
|
self.vector._client.indices.delete.assert_called_once_with(index=self.collection_name.lower())
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("setup_mock_redis")
|
@pytest.mark.usefixtures("setup_mock_redis")
|
||||||
class TestOpenSearchVectorWithRedis:
|
class TestOpenSearchVectorWithRedis:
|
||||||
|
|||||||
@ -20,12 +20,21 @@ class TestMCPToolManageService:
|
|||||||
patch("services.tools.mcp_tools_manage_service.ToolTransformService") as mock_tool_transform_service,
|
patch("services.tools.mcp_tools_manage_service.ToolTransformService") as mock_tool_transform_service,
|
||||||
):
|
):
|
||||||
# Setup default mock returns
|
# Setup default mock returns
|
||||||
|
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||||
|
from core.tools.entities.common_entities import I18nObject
|
||||||
|
|
||||||
mock_encrypter.encrypt_token.return_value = "encrypted_server_url"
|
mock_encrypter.encrypt_token.return_value = "encrypted_server_url"
|
||||||
mock_tool_transform_service.mcp_provider_to_user_provider.return_value = {
|
mock_tool_transform_service.mcp_provider_to_user_provider.return_value = ToolProviderApiEntity(
|
||||||
"id": "test_id",
|
id="test_id",
|
||||||
"name": "test_name",
|
author="test_author",
|
||||||
"type": ToolProviderType.MCP,
|
name="test_name",
|
||||||
}
|
type=ToolProviderType.MCP,
|
||||||
|
description=I18nObject(en_US="Test Description", zh_Hans="测试描述"),
|
||||||
|
icon={"type": "emoji", "content": "🤖"},
|
||||||
|
label=I18nObject(en_US="Test Label", zh_Hans="测试标签"),
|
||||||
|
labels=[],
|
||||||
|
tools=[],
|
||||||
|
)
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
"encrypter": mock_encrypter,
|
"encrypter": mock_encrypter,
|
||||||
@ -104,9 +113,9 @@ class TestMCPToolManageService:
|
|||||||
mcp_provider = MCPToolProvider(
|
mcp_provider = MCPToolProvider(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
name=fake.company(),
|
name=fake.company(),
|
||||||
server_identifier=fake.uuid4(),
|
server_identifier=str(fake.uuid4()),
|
||||||
server_url="encrypted_server_url",
|
server_url="encrypted_server_url",
|
||||||
server_url_hash=fake.sha256(),
|
server_url_hash=str(fake.sha256()),
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
authed=False,
|
authed=False,
|
||||||
tools="[]",
|
tools="[]",
|
||||||
@ -144,7 +153,10 @@ class TestMCPToolManageService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Act: Execute the method under test
|
# Act: Execute the method under test
|
||||||
result = MCPToolManageService.get_mcp_provider_by_provider_id(mcp_provider.id, tenant.id)
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
service = MCPToolManageService(db.session())
|
||||||
|
result = service.get_provider(provider_id=mcp_provider.id, tenant_id=tenant.id)
|
||||||
|
|
||||||
# Assert: Verify the expected outcomes
|
# Assert: Verify the expected outcomes
|
||||||
assert result is not None
|
assert result is not None
|
||||||
@ -154,8 +166,6 @@ class TestMCPToolManageService:
|
|||||||
assert result.user_id == account.id
|
assert result.user_id == account.id
|
||||||
|
|
||||||
# Verify database state
|
# Verify database state
|
||||||
from extensions.ext_database import db
|
|
||||||
|
|
||||||
db.session.refresh(result)
|
db.session.refresh(result)
|
||||||
assert result.id is not None
|
assert result.id is not None
|
||||||
assert result.server_identifier == mcp_provider.server_identifier
|
assert result.server_identifier == mcp_provider.server_identifier
|
||||||
@ -177,11 +187,14 @@ class TestMCPToolManageService:
|
|||||||
db_session_with_containers, mock_external_service_dependencies
|
db_session_with_containers, mock_external_service_dependencies
|
||||||
)
|
)
|
||||||
|
|
||||||
non_existent_id = fake.uuid4()
|
non_existent_id = str(fake.uuid4())
|
||||||
|
|
||||||
# Act & Assert: Verify proper error handling
|
# Act & Assert: Verify proper error handling
|
||||||
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
service = MCPToolManageService(db.session())
|
||||||
with pytest.raises(ValueError, match="MCP tool not found"):
|
with pytest.raises(ValueError, match="MCP tool not found"):
|
||||||
MCPToolManageService.get_mcp_provider_by_provider_id(non_existent_id, tenant.id)
|
service.get_provider(provider_id=non_existent_id, tenant_id=tenant.id)
|
||||||
|
|
||||||
def test_get_mcp_provider_by_provider_id_tenant_isolation(
|
def test_get_mcp_provider_by_provider_id_tenant_isolation(
|
||||||
self, db_session_with_containers, mock_external_service_dependencies
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
@ -210,8 +223,11 @@ class TestMCPToolManageService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Act & Assert: Verify tenant isolation
|
# Act & Assert: Verify tenant isolation
|
||||||
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
service = MCPToolManageService(db.session())
|
||||||
with pytest.raises(ValueError, match="MCP tool not found"):
|
with pytest.raises(ValueError, match="MCP tool not found"):
|
||||||
MCPToolManageService.get_mcp_provider_by_provider_id(mcp_provider1.id, tenant2.id)
|
service.get_provider(provider_id=mcp_provider1.id, tenant_id=tenant2.id)
|
||||||
|
|
||||||
def test_get_mcp_provider_by_server_identifier_success(
|
def test_get_mcp_provider_by_server_identifier_success(
|
||||||
self, db_session_with_containers, mock_external_service_dependencies
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
@ -235,7 +251,10 @@ class TestMCPToolManageService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Act: Execute the method under test
|
# Act: Execute the method under test
|
||||||
result = MCPToolManageService.get_mcp_provider_by_server_identifier(mcp_provider.server_identifier, tenant.id)
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
service = MCPToolManageService(db.session())
|
||||||
|
result = service.get_provider(server_identifier=mcp_provider.server_identifier, tenant_id=tenant.id)
|
||||||
|
|
||||||
# Assert: Verify the expected outcomes
|
# Assert: Verify the expected outcomes
|
||||||
assert result is not None
|
assert result is not None
|
||||||
@ -245,8 +264,6 @@ class TestMCPToolManageService:
|
|||||||
assert result.user_id == account.id
|
assert result.user_id == account.id
|
||||||
|
|
||||||
# Verify database state
|
# Verify database state
|
||||||
from extensions.ext_database import db
|
|
||||||
|
|
||||||
db.session.refresh(result)
|
db.session.refresh(result)
|
||||||
assert result.id is not None
|
assert result.id is not None
|
||||||
assert result.name == mcp_provider.name
|
assert result.name == mcp_provider.name
|
||||||
@ -268,11 +285,14 @@ class TestMCPToolManageService:
|
|||||||
db_session_with_containers, mock_external_service_dependencies
|
db_session_with_containers, mock_external_service_dependencies
|
||||||
)
|
)
|
||||||
|
|
||||||
non_existent_identifier = fake.uuid4()
|
non_existent_identifier = str(fake.uuid4())
|
||||||
|
|
||||||
# Act & Assert: Verify proper error handling
|
# Act & Assert: Verify proper error handling
|
||||||
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
service = MCPToolManageService(db.session())
|
||||||
with pytest.raises(ValueError, match="MCP tool not found"):
|
with pytest.raises(ValueError, match="MCP tool not found"):
|
||||||
MCPToolManageService.get_mcp_provider_by_server_identifier(non_existent_identifier, tenant.id)
|
service.get_provider(server_identifier=non_existent_identifier, tenant_id=tenant.id)
|
||||||
|
|
||||||
def test_get_mcp_provider_by_server_identifier_tenant_isolation(
|
def test_get_mcp_provider_by_server_identifier_tenant_isolation(
|
||||||
self, db_session_with_containers, mock_external_service_dependencies
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
@ -301,8 +321,11 @@ class TestMCPToolManageService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Act & Assert: Verify tenant isolation
|
# Act & Assert: Verify tenant isolation
|
||||||
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
service = MCPToolManageService(db.session())
|
||||||
with pytest.raises(ValueError, match="MCP tool not found"):
|
with pytest.raises(ValueError, match="MCP tool not found"):
|
||||||
MCPToolManageService.get_mcp_provider_by_server_identifier(mcp_provider1.server_identifier, tenant2.id)
|
service.get_provider(server_identifier=mcp_provider1.server_identifier, tenant_id=tenant2.id)
|
||||||
|
|
||||||
def test_create_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies):
|
def test_create_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||||
"""
|
"""
|
||||||
@ -322,15 +345,30 @@ class TestMCPToolManageService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup mocks for provider creation
|
# Setup mocks for provider creation
|
||||||
|
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||||
|
from core.tools.entities.common_entities import I18nObject
|
||||||
|
|
||||||
mock_external_service_dependencies["encrypter"].encrypt_token.return_value = "encrypted_server_url"
|
mock_external_service_dependencies["encrypter"].encrypt_token.return_value = "encrypted_server_url"
|
||||||
mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.return_value = {
|
mock_external_service_dependencies[
|
||||||
"id": "new_provider_id",
|
"tool_transform_service"
|
||||||
"name": "Test MCP Provider",
|
].mcp_provider_to_user_provider.return_value = ToolProviderApiEntity(
|
||||||
"type": ToolProviderType.MCP,
|
id="new_provider_id",
|
||||||
}
|
author=account.name,
|
||||||
|
name="Test MCP Provider",
|
||||||
|
type=ToolProviderType.MCP,
|
||||||
|
description=I18nObject(en_US="Test MCP Provider Description", zh_Hans="测试MCP提供者描述"),
|
||||||
|
icon={"type": "emoji", "content": "🤖"},
|
||||||
|
label=I18nObject(en_US="Test MCP Provider", zh_Hans="测试MCP提供者"),
|
||||||
|
labels=[],
|
||||||
|
tools=[],
|
||||||
|
)
|
||||||
|
|
||||||
# Act: Execute the method under test
|
# Act: Execute the method under test
|
||||||
result = MCPToolManageService.create_mcp_provider(
|
from core.entities.mcp_provider import MCPConfiguration
|
||||||
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
service = MCPToolManageService(db.session())
|
||||||
|
result = service.create_provider(
|
||||||
tenant_id=tenant.id,
|
tenant_id=tenant.id,
|
||||||
name="Test MCP Provider",
|
name="Test MCP Provider",
|
||||||
server_url="https://example.com/mcp",
|
server_url="https://example.com/mcp",
|
||||||
@ -339,14 +377,16 @@ class TestMCPToolManageService:
|
|||||||
icon_type="emoji",
|
icon_type="emoji",
|
||||||
icon_background="#FF6B6B",
|
icon_background="#FF6B6B",
|
||||||
server_identifier="test_identifier_123",
|
server_identifier="test_identifier_123",
|
||||||
timeout=30.0,
|
configuration=MCPConfiguration(
|
||||||
sse_read_timeout=300.0,
|
timeout=30.0,
|
||||||
|
sse_read_timeout=300.0,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert: Verify the expected outcomes
|
# Assert: Verify the expected outcomes
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result["name"] == "Test MCP Provider"
|
assert result.name == "Test MCP Provider"
|
||||||
assert result["type"] == ToolProviderType.MCP
|
assert result.type == ToolProviderType.MCP
|
||||||
|
|
||||||
# Verify database state
|
# Verify database state
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -386,7 +426,11 @@ class TestMCPToolManageService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Create first provider
|
# Create first provider
|
||||||
MCPToolManageService.create_mcp_provider(
|
from core.entities.mcp_provider import MCPConfiguration
|
||||||
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
service = MCPToolManageService(db.session())
|
||||||
|
service.create_provider(
|
||||||
tenant_id=tenant.id,
|
tenant_id=tenant.id,
|
||||||
name="Test MCP Provider",
|
name="Test MCP Provider",
|
||||||
server_url="https://example1.com/mcp",
|
server_url="https://example1.com/mcp",
|
||||||
@ -395,13 +439,15 @@ class TestMCPToolManageService:
|
|||||||
icon_type="emoji",
|
icon_type="emoji",
|
||||||
icon_background="#FF6B6B",
|
icon_background="#FF6B6B",
|
||||||
server_identifier="test_identifier_1",
|
server_identifier="test_identifier_1",
|
||||||
timeout=30.0,
|
configuration=MCPConfiguration(
|
||||||
sse_read_timeout=300.0,
|
timeout=30.0,
|
||||||
|
sse_read_timeout=300.0,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Act & Assert: Verify proper error handling for duplicate name
|
# Act & Assert: Verify proper error handling for duplicate name
|
||||||
with pytest.raises(ValueError, match="MCP tool Test MCP Provider already exists"):
|
with pytest.raises(ValueError, match="MCP tool Test MCP Provider already exists"):
|
||||||
MCPToolManageService.create_mcp_provider(
|
service.create_provider(
|
||||||
tenant_id=tenant.id,
|
tenant_id=tenant.id,
|
||||||
name="Test MCP Provider", # Duplicate name
|
name="Test MCP Provider", # Duplicate name
|
||||||
server_url="https://example2.com/mcp",
|
server_url="https://example2.com/mcp",
|
||||||
@ -410,8 +456,10 @@ class TestMCPToolManageService:
|
|||||||
icon_type="emoji",
|
icon_type="emoji",
|
||||||
icon_background="#4ECDC4",
|
icon_background="#4ECDC4",
|
||||||
server_identifier="test_identifier_2",
|
server_identifier="test_identifier_2",
|
||||||
timeout=45.0,
|
configuration=MCPConfiguration(
|
||||||
sse_read_timeout=400.0,
|
timeout=45.0,
|
||||||
|
sse_read_timeout=400.0,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_create_mcp_provider_duplicate_server_url(
|
def test_create_mcp_provider_duplicate_server_url(
|
||||||
@ -432,7 +480,11 @@ class TestMCPToolManageService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Create first provider
|
# Create first provider
|
||||||
MCPToolManageService.create_mcp_provider(
|
from core.entities.mcp_provider import MCPConfiguration
|
||||||
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
service = MCPToolManageService(db.session())
|
||||||
|
service.create_provider(
|
||||||
tenant_id=tenant.id,
|
tenant_id=tenant.id,
|
||||||
name="Test MCP Provider 1",
|
name="Test MCP Provider 1",
|
||||||
server_url="https://example.com/mcp",
|
server_url="https://example.com/mcp",
|
||||||
@ -441,13 +493,15 @@ class TestMCPToolManageService:
|
|||||||
icon_type="emoji",
|
icon_type="emoji",
|
||||||
icon_background="#FF6B6B",
|
icon_background="#FF6B6B",
|
||||||
server_identifier="test_identifier_1",
|
server_identifier="test_identifier_1",
|
||||||
timeout=30.0,
|
configuration=MCPConfiguration(
|
||||||
sse_read_timeout=300.0,
|
timeout=30.0,
|
||||||
|
sse_read_timeout=300.0,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Act & Assert: Verify proper error handling for duplicate server URL
|
# Act & Assert: Verify proper error handling for duplicate server URL
|
||||||
with pytest.raises(ValueError, match="MCP tool https://example.com/mcp already exists"):
|
with pytest.raises(ValueError, match="MCP tool with this server URL already exists"):
|
||||||
MCPToolManageService.create_mcp_provider(
|
service.create_provider(
|
||||||
tenant_id=tenant.id,
|
tenant_id=tenant.id,
|
||||||
name="Test MCP Provider 2",
|
name="Test MCP Provider 2",
|
||||||
server_url="https://example.com/mcp", # Duplicate URL
|
server_url="https://example.com/mcp", # Duplicate URL
|
||||||
@ -456,8 +510,10 @@ class TestMCPToolManageService:
|
|||||||
icon_type="emoji",
|
icon_type="emoji",
|
||||||
icon_background="#4ECDC4",
|
icon_background="#4ECDC4",
|
||||||
server_identifier="test_identifier_2",
|
server_identifier="test_identifier_2",
|
||||||
timeout=45.0,
|
configuration=MCPConfiguration(
|
||||||
sse_read_timeout=400.0,
|
timeout=45.0,
|
||||||
|
sse_read_timeout=400.0,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_create_mcp_provider_duplicate_server_identifier(
|
def test_create_mcp_provider_duplicate_server_identifier(
|
||||||
@ -478,7 +534,11 @@ class TestMCPToolManageService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Create first provider
|
# Create first provider
|
||||||
MCPToolManageService.create_mcp_provider(
|
from core.entities.mcp_provider import MCPConfiguration
|
||||||
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
service = MCPToolManageService(db.session())
|
||||||
|
service.create_provider(
|
||||||
tenant_id=tenant.id,
|
tenant_id=tenant.id,
|
||||||
name="Test MCP Provider 1",
|
name="Test MCP Provider 1",
|
||||||
server_url="https://example1.com/mcp",
|
server_url="https://example1.com/mcp",
|
||||||
@ -487,13 +547,15 @@ class TestMCPToolManageService:
|
|||||||
icon_type="emoji",
|
icon_type="emoji",
|
||||||
icon_background="#FF6B6B",
|
icon_background="#FF6B6B",
|
||||||
server_identifier="test_identifier_123",
|
server_identifier="test_identifier_123",
|
||||||
timeout=30.0,
|
configuration=MCPConfiguration(
|
||||||
sse_read_timeout=300.0,
|
timeout=30.0,
|
||||||
|
sse_read_timeout=300.0,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Act & Assert: Verify proper error handling for duplicate server identifier
|
# Act & Assert: Verify proper error handling for duplicate server identifier
|
||||||
with pytest.raises(ValueError, match="MCP tool test_identifier_123 already exists"):
|
with pytest.raises(ValueError, match="MCP tool test_identifier_123 already exists"):
|
||||||
MCPToolManageService.create_mcp_provider(
|
service.create_provider(
|
||||||
tenant_id=tenant.id,
|
tenant_id=tenant.id,
|
||||||
name="Test MCP Provider 2",
|
name="Test MCP Provider 2",
|
||||||
server_url="https://example2.com/mcp",
|
server_url="https://example2.com/mcp",
|
||||||
@ -502,8 +564,10 @@ class TestMCPToolManageService:
|
|||||||
icon_type="emoji",
|
icon_type="emoji",
|
||||||
icon_background="#4ECDC4",
|
icon_background="#4ECDC4",
|
||||||
server_identifier="test_identifier_123", # Duplicate identifier
|
server_identifier="test_identifier_123", # Duplicate identifier
|
||||||
timeout=45.0,
|
configuration=MCPConfiguration(
|
||||||
sse_read_timeout=400.0,
|
timeout=45.0,
|
||||||
|
sse_read_timeout=400.0,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_retrieve_mcp_tools_success(self, db_session_with_containers, mock_external_service_dependencies):
|
def test_retrieve_mcp_tools_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||||
@ -543,23 +607,59 @@ class TestMCPToolManageService:
|
|||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
# Setup mock for transformation service
|
# Setup mock for transformation service
|
||||||
|
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||||
|
from core.tools.entities.common_entities import I18nObject
|
||||||
|
|
||||||
mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.side_effect = [
|
mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.side_effect = [
|
||||||
{"id": provider1.id, "name": provider1.name, "type": ToolProviderType.MCP},
|
ToolProviderApiEntity(
|
||||||
{"id": provider2.id, "name": provider2.name, "type": ToolProviderType.MCP},
|
id=provider1.id,
|
||||||
{"id": provider3.id, "name": provider3.name, "type": ToolProviderType.MCP},
|
author=account.name,
|
||||||
|
name=provider1.name,
|
||||||
|
type=ToolProviderType.MCP,
|
||||||
|
description=I18nObject(en_US="Alpha Provider Description", zh_Hans="Alpha提供者描述"),
|
||||||
|
icon={"type": "emoji", "content": "🅰️"},
|
||||||
|
label=I18nObject(en_US=provider1.name, zh_Hans=provider1.name),
|
||||||
|
labels=[],
|
||||||
|
tools=[],
|
||||||
|
),
|
||||||
|
ToolProviderApiEntity(
|
||||||
|
id=provider2.id,
|
||||||
|
author=account.name,
|
||||||
|
name=provider2.name,
|
||||||
|
type=ToolProviderType.MCP,
|
||||||
|
description=I18nObject(en_US="Beta Provider Description", zh_Hans="Beta提供者描述"),
|
||||||
|
icon={"type": "emoji", "content": "🅱️"},
|
||||||
|
label=I18nObject(en_US=provider2.name, zh_Hans=provider2.name),
|
||||||
|
labels=[],
|
||||||
|
tools=[],
|
||||||
|
),
|
||||||
|
ToolProviderApiEntity(
|
||||||
|
id=provider3.id,
|
||||||
|
author=account.name,
|
||||||
|
name=provider3.name,
|
||||||
|
type=ToolProviderType.MCP,
|
||||||
|
description=I18nObject(en_US="Gamma Provider Description", zh_Hans="Gamma提供者描述"),
|
||||||
|
icon={"type": "emoji", "content": "Γ"},
|
||||||
|
label=I18nObject(en_US=provider3.name, zh_Hans=provider3.name),
|
||||||
|
labels=[],
|
||||||
|
tools=[],
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Act: Execute the method under test
|
# Act: Execute the method under test
|
||||||
result = MCPToolManageService.retrieve_mcp_tools(tenant.id, for_list=True)
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
service = MCPToolManageService(db.session())
|
||||||
|
result = service.list_providers(tenant_id=tenant.id, for_list=True)
|
||||||
|
|
||||||
# Assert: Verify the expected outcomes
|
# Assert: Verify the expected outcomes
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert len(result) == 3
|
assert len(result) == 3
|
||||||
|
|
||||||
# Verify correct ordering by name
|
# Verify correct ordering by name
|
||||||
assert result[0]["name"] == "Alpha Provider"
|
assert result[0].name == "Alpha Provider"
|
||||||
assert result[1]["name"] == "Beta Provider"
|
assert result[1].name == "Beta Provider"
|
||||||
assert result[2]["name"] == "Gamma Provider"
|
assert result[2].name == "Gamma Provider"
|
||||||
|
|
||||||
# Verify mock interactions
|
# Verify mock interactions
|
||||||
assert (
|
assert (
|
||||||
@ -584,7 +684,10 @@ class TestMCPToolManageService:
|
|||||||
# No MCP providers created for this tenant
|
# No MCP providers created for this tenant
|
||||||
|
|
||||||
# Act: Execute the method under test
|
# Act: Execute the method under test
|
||||||
result = MCPToolManageService.retrieve_mcp_tools(tenant.id, for_list=False)
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
service = MCPToolManageService(db.session())
|
||||||
|
result = service.list_providers(tenant_id=tenant.id, for_list=False)
|
||||||
|
|
||||||
# Assert: Verify the expected outcomes
|
# Assert: Verify the expected outcomes
|
||||||
assert result is not None
|
assert result is not None
|
||||||
@ -624,20 +727,46 @@ class TestMCPToolManageService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup mock for transformation service
|
# Setup mock for transformation service
|
||||||
|
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||||
|
from core.tools.entities.common_entities import I18nObject
|
||||||
|
|
||||||
mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.side_effect = [
|
mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.side_effect = [
|
||||||
{"id": provider1.id, "name": provider1.name, "type": ToolProviderType.MCP},
|
ToolProviderApiEntity(
|
||||||
{"id": provider2.id, "name": provider2.name, "type": ToolProviderType.MCP},
|
id=provider1.id,
|
||||||
|
author=account1.name,
|
||||||
|
name=provider1.name,
|
||||||
|
type=ToolProviderType.MCP,
|
||||||
|
description=I18nObject(en_US="Provider 1 Description", zh_Hans="提供者1描述"),
|
||||||
|
icon={"type": "emoji", "content": "1️⃣"},
|
||||||
|
label=I18nObject(en_US=provider1.name, zh_Hans=provider1.name),
|
||||||
|
labels=[],
|
||||||
|
tools=[],
|
||||||
|
),
|
||||||
|
ToolProviderApiEntity(
|
||||||
|
id=provider2.id,
|
||||||
|
author=account2.name,
|
||||||
|
name=provider2.name,
|
||||||
|
type=ToolProviderType.MCP,
|
||||||
|
description=I18nObject(en_US="Provider 2 Description", zh_Hans="提供者2描述"),
|
||||||
|
icon={"type": "emoji", "content": "2️⃣"},
|
||||||
|
label=I18nObject(en_US=provider2.name, zh_Hans=provider2.name),
|
||||||
|
labels=[],
|
||||||
|
tools=[],
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Act: Execute the method under test for both tenants
|
# Act: Execute the method under test for both tenants
|
||||||
result1 = MCPToolManageService.retrieve_mcp_tools(tenant1.id, for_list=True)
|
from extensions.ext_database import db
|
||||||
result2 = MCPToolManageService.retrieve_mcp_tools(tenant2.id, for_list=True)
|
|
||||||
|
service = MCPToolManageService(db.session())
|
||||||
|
result1 = service.list_providers(tenant_id=tenant1.id, for_list=True)
|
||||||
|
result2 = service.list_providers(tenant_id=tenant2.id, for_list=True)
|
||||||
|
|
||||||
# Assert: Verify tenant isolation
|
# Assert: Verify tenant isolation
|
||||||
assert len(result1) == 1
|
assert len(result1) == 1
|
||||||
assert len(result2) == 1
|
assert len(result2) == 1
|
||||||
assert result1[0]["id"] == provider1.id
|
assert result1[0].id == provider1.id
|
||||||
assert result2[0]["id"] == provider2.id
|
assert result2[0].id == provider2.id
|
||||||
|
|
||||||
def test_list_mcp_tool_from_remote_server_success(
|
def test_list_mcp_tool_from_remote_server_success(
|
||||||
self, db_session_with_containers, mock_external_service_dependencies
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
@ -661,17 +790,20 @@ class TestMCPToolManageService:
|
|||||||
mcp_provider = self._create_test_mcp_provider(
|
mcp_provider = self._create_test_mcp_provider(
|
||||||
db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id
|
db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id
|
||||||
)
|
)
|
||||||
mcp_provider.server_url = "encrypted_server_url"
|
# Use a valid base64 encoded string to avoid decryption errors
|
||||||
mcp_provider.authed = False
|
import base64
|
||||||
|
|
||||||
|
mcp_provider.server_url = base64.b64encode(b"encrypted_server_url").decode()
|
||||||
|
mcp_provider.authed = True # Provider must be authenticated to list tools
|
||||||
mcp_provider.tools = "[]"
|
mcp_provider.tools = "[]"
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
# Mock the decrypted_server_url property to avoid encryption issues
|
# Mock the decryption process at the rsa level to avoid key file issues
|
||||||
with patch("models.tools.encrypter") as mock_encrypter:
|
with patch("libs.rsa.decrypt") as mock_decrypt:
|
||||||
mock_encrypter.decrypt_token.return_value = "https://example.com/mcp"
|
mock_decrypt.return_value = "https://example.com/mcp"
|
||||||
|
|
||||||
# Mock MCPClient and its context manager
|
# Mock MCPClient and its context manager
|
||||||
mock_tools = [
|
mock_tools = [
|
||||||
@ -683,13 +815,16 @@ class TestMCPToolManageService:
|
|||||||
)(),
|
)(),
|
||||||
]
|
]
|
||||||
|
|
||||||
with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client:
|
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
|
||||||
# Setup mock client
|
# Setup mock client
|
||||||
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
||||||
mock_client_instance.list_tools.return_value = mock_tools
|
mock_client_instance.list_tools.return_value = mock_tools
|
||||||
|
|
||||||
# Act: Execute the method under test
|
# Act: Execute the method under test
|
||||||
result = MCPToolManageService.list_mcp_tool_from_remote_server(tenant.id, mcp_provider.id)
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
service = MCPToolManageService(db.session())
|
||||||
|
result = service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id)
|
||||||
|
|
||||||
# Assert: Verify the expected outcomes
|
# Assert: Verify the expected outcomes
|
||||||
assert result is not None
|
assert result is not None
|
||||||
@ -705,16 +840,8 @@ class TestMCPToolManageService:
|
|||||||
assert mcp_provider.updated_at is not None
|
assert mcp_provider.updated_at is not None
|
||||||
|
|
||||||
# Verify mock interactions
|
# Verify mock interactions
|
||||||
mock_mcp_client.assert_called_once_with(
|
# MCPClientWithAuthRetry is called with different parameters
|
||||||
"https://example.com/mcp",
|
mock_mcp_client.assert_called_once()
|
||||||
mcp_provider.id,
|
|
||||||
tenant.id,
|
|
||||||
authed=False,
|
|
||||||
for_list=True,
|
|
||||||
headers={},
|
|
||||||
timeout=30.0,
|
|
||||||
sse_read_timeout=300.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_list_mcp_tool_from_remote_server_auth_error(
|
def test_list_mcp_tool_from_remote_server_auth_error(
|
||||||
self, db_session_with_containers, mock_external_service_dependencies
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
@ -737,7 +864,10 @@ class TestMCPToolManageService:
|
|||||||
mcp_provider = self._create_test_mcp_provider(
|
mcp_provider = self._create_test_mcp_provider(
|
||||||
db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id
|
db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id
|
||||||
)
|
)
|
||||||
mcp_provider.server_url = "encrypted_server_url"
|
# Use a valid base64 encoded string to avoid decryption errors
|
||||||
|
import base64
|
||||||
|
|
||||||
|
mcp_provider.server_url = base64.b64encode(b"encrypted_server_url").decode()
|
||||||
mcp_provider.authed = False
|
mcp_provider.authed = False
|
||||||
mcp_provider.tools = "[]"
|
mcp_provider.tools = "[]"
|
||||||
|
|
||||||
@ -745,20 +875,23 @@ class TestMCPToolManageService:
|
|||||||
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
# Mock the decrypted_server_url property to avoid encryption issues
|
# Mock the decryption process at the rsa level to avoid key file issues
|
||||||
with patch("models.tools.encrypter") as mock_encrypter:
|
with patch("libs.rsa.decrypt") as mock_decrypt:
|
||||||
mock_encrypter.decrypt_token.return_value = "https://example.com/mcp"
|
mock_decrypt.return_value = "https://example.com/mcp"
|
||||||
|
|
||||||
# Mock MCPClient to raise authentication error
|
# Mock MCPClient to raise authentication error
|
||||||
with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client:
|
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
|
||||||
from core.mcp.error import MCPAuthError
|
from core.mcp.error import MCPAuthError
|
||||||
|
|
||||||
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
||||||
mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required")
|
mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required")
|
||||||
|
|
||||||
# Act & Assert: Verify proper error handling
|
# Act & Assert: Verify proper error handling
|
||||||
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
service = MCPToolManageService(db.session())
|
||||||
with pytest.raises(ValueError, match="Please auth the tool first"):
|
with pytest.raises(ValueError, match="Please auth the tool first"):
|
||||||
MCPToolManageService.list_mcp_tool_from_remote_server(tenant.id, mcp_provider.id)
|
service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id)
|
||||||
|
|
||||||
# Verify database state was not changed
|
# Verify database state was not changed
|
||||||
db.session.refresh(mcp_provider)
|
db.session.refresh(mcp_provider)
|
||||||
@ -786,32 +919,38 @@ class TestMCPToolManageService:
|
|||||||
mcp_provider = self._create_test_mcp_provider(
|
mcp_provider = self._create_test_mcp_provider(
|
||||||
db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id
|
db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id
|
||||||
)
|
)
|
||||||
mcp_provider.server_url = "encrypted_server_url"
|
# Use a valid base64 encoded string to avoid decryption errors
|
||||||
mcp_provider.authed = False
|
import base64
|
||||||
|
|
||||||
|
mcp_provider.server_url = base64.b64encode(b"encrypted_server_url").decode()
|
||||||
|
mcp_provider.authed = True # Provider must be authenticated to test connection errors
|
||||||
mcp_provider.tools = "[]"
|
mcp_provider.tools = "[]"
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
# Mock the decrypted_server_url property to avoid encryption issues
|
# Mock the decryption process at the rsa level to avoid key file issues
|
||||||
with patch("models.tools.encrypter") as mock_encrypter:
|
with patch("libs.rsa.decrypt") as mock_decrypt:
|
||||||
mock_encrypter.decrypt_token.return_value = "https://example.com/mcp"
|
mock_decrypt.return_value = "https://example.com/mcp"
|
||||||
|
|
||||||
# Mock MCPClient to raise connection error
|
# Mock MCPClient to raise connection error
|
||||||
with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client:
|
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
|
||||||
from core.mcp.error import MCPError
|
from core.mcp.error import MCPError
|
||||||
|
|
||||||
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
||||||
mock_client_instance.list_tools.side_effect = MCPError("Connection failed")
|
mock_client_instance.list_tools.side_effect = MCPError("Connection failed")
|
||||||
|
|
||||||
# Act & Assert: Verify proper error handling
|
# Act & Assert: Verify proper error handling
|
||||||
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
service = MCPToolManageService(db.session())
|
||||||
with pytest.raises(ValueError, match="Failed to connect to MCP server: Connection failed"):
|
with pytest.raises(ValueError, match="Failed to connect to MCP server: Connection failed"):
|
||||||
MCPToolManageService.list_mcp_tool_from_remote_server(tenant.id, mcp_provider.id)
|
service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id)
|
||||||
|
|
||||||
# Verify database state was not changed
|
# Verify database state was not changed
|
||||||
db.session.refresh(mcp_provider)
|
db.session.refresh(mcp_provider)
|
||||||
assert mcp_provider.authed is False
|
assert mcp_provider.authed is True # Provider remains authenticated
|
||||||
assert mcp_provider.tools == "[]"
|
assert mcp_provider.tools == "[]"
|
||||||
|
|
||||||
def test_delete_mcp_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
|
def test_delete_mcp_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||||
@ -840,7 +979,8 @@ class TestMCPToolManageService:
|
|||||||
assert db.session.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() is not None
|
assert db.session.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() is not None
|
||||||
|
|
||||||
# Act: Execute the method under test
|
# Act: Execute the method under test
|
||||||
MCPToolManageService.delete_mcp_tool(tenant.id, mcp_provider.id)
|
service = MCPToolManageService(db.session())
|
||||||
|
service.delete_provider(tenant_id=tenant.id, provider_id=mcp_provider.id)
|
||||||
|
|
||||||
# Assert: Verify the expected outcomes
|
# Assert: Verify the expected outcomes
|
||||||
# Provider should be deleted from database
|
# Provider should be deleted from database
|
||||||
@ -862,11 +1002,14 @@ class TestMCPToolManageService:
|
|||||||
db_session_with_containers, mock_external_service_dependencies
|
db_session_with_containers, mock_external_service_dependencies
|
||||||
)
|
)
|
||||||
|
|
||||||
non_existent_id = fake.uuid4()
|
non_existent_id = str(fake.uuid4())
|
||||||
|
|
||||||
# Act & Assert: Verify proper error handling
|
# Act & Assert: Verify proper error handling
|
||||||
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
service = MCPToolManageService(db.session())
|
||||||
with pytest.raises(ValueError, match="MCP tool not found"):
|
with pytest.raises(ValueError, match="MCP tool not found"):
|
||||||
MCPToolManageService.delete_mcp_tool(tenant.id, non_existent_id)
|
service.delete_provider(tenant_id=tenant.id, provider_id=non_existent_id)
|
||||||
|
|
||||||
def test_delete_mcp_tool_tenant_isolation(self, db_session_with_containers, mock_external_service_dependencies):
|
def test_delete_mcp_tool_tenant_isolation(self, db_session_with_containers, mock_external_service_dependencies):
|
||||||
"""
|
"""
|
||||||
@ -893,8 +1036,11 @@ class TestMCPToolManageService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Act & Assert: Verify tenant isolation
|
# Act & Assert: Verify tenant isolation
|
||||||
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
service = MCPToolManageService(db.session())
|
||||||
with pytest.raises(ValueError, match="MCP tool not found"):
|
with pytest.raises(ValueError, match="MCP tool not found"):
|
||||||
MCPToolManageService.delete_mcp_tool(tenant2.id, mcp_provider1.id)
|
service.delete_provider(tenant_id=tenant2.id, provider_id=mcp_provider1.id)
|
||||||
|
|
||||||
# Verify provider still exists in tenant1
|
# Verify provider still exists in tenant1
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -929,7 +1075,10 @@ class TestMCPToolManageService:
|
|||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
# Act: Execute the method under test
|
# Act: Execute the method under test
|
||||||
MCPToolManageService.update_mcp_provider(
|
from core.entities.mcp_provider import MCPConfiguration
|
||||||
|
|
||||||
|
service = MCPToolManageService(db.session())
|
||||||
|
service.update_provider(
|
||||||
tenant_id=tenant.id,
|
tenant_id=tenant.id,
|
||||||
provider_id=mcp_provider.id,
|
provider_id=mcp_provider.id,
|
||||||
name="Updated MCP Provider",
|
name="Updated MCP Provider",
|
||||||
@ -938,8 +1087,10 @@ class TestMCPToolManageService:
|
|||||||
icon_type="emoji",
|
icon_type="emoji",
|
||||||
icon_background="#4ECDC4",
|
icon_background="#4ECDC4",
|
||||||
server_identifier="updated_identifier_123",
|
server_identifier="updated_identifier_123",
|
||||||
timeout=45.0,
|
configuration=MCPConfiguration(
|
||||||
sse_read_timeout=400.0,
|
timeout=45.0,
|
||||||
|
sse_read_timeout=400.0,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert: Verify the expected outcomes
|
# Assert: Verify the expected outcomes
|
||||||
@ -953,70 +1104,10 @@ class TestMCPToolManageService:
|
|||||||
# Verify icon was updated
|
# Verify icon was updated
|
||||||
import json
|
import json
|
||||||
|
|
||||||
icon_data = json.loads(mcp_provider.icon)
|
icon_data = json.loads(mcp_provider.icon or "{}")
|
||||||
assert icon_data["content"] == "🚀"
|
assert icon_data["content"] == "🚀"
|
||||||
assert icon_data["background"] == "#4ECDC4"
|
assert icon_data["background"] == "#4ECDC4"
|
||||||
|
|
||||||
def test_update_mcp_provider_with_server_url_change(
|
|
||||||
self, db_session_with_containers, mock_external_service_dependencies
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Test successful update of MCP provider with server URL change.
|
|
||||||
|
|
||||||
This test verifies:
|
|
||||||
- Proper handling of server URL changes
|
|
||||||
- Correct reconnection logic
|
|
||||||
- Database state updates
|
|
||||||
- External service integration
|
|
||||||
"""
|
|
||||||
# Arrange: Create test data
|
|
||||||
fake = Faker()
|
|
||||||
account, tenant = self._create_test_account_and_tenant(
|
|
||||||
db_session_with_containers, mock_external_service_dependencies
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create MCP provider
|
|
||||||
mcp_provider = self._create_test_mcp_provider(
|
|
||||||
db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id
|
|
||||||
)
|
|
||||||
|
|
||||||
from extensions.ext_database import db
|
|
||||||
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Mock the reconnection method
|
|
||||||
with patch.object(MCPToolManageService, "_re_connect_mcp_provider") as mock_reconnect:
|
|
||||||
mock_reconnect.return_value = {
|
|
||||||
"authed": True,
|
|
||||||
"tools": '[{"name": "test_tool"}]',
|
|
||||||
"encrypted_credentials": "{}",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Act: Execute the method under test
|
|
||||||
MCPToolManageService.update_mcp_provider(
|
|
||||||
tenant_id=tenant.id,
|
|
||||||
provider_id=mcp_provider.id,
|
|
||||||
name="Updated MCP Provider",
|
|
||||||
server_url="https://new-example.com/mcp",
|
|
||||||
icon="🚀",
|
|
||||||
icon_type="emoji",
|
|
||||||
icon_background="#4ECDC4",
|
|
||||||
server_identifier="updated_identifier_123",
|
|
||||||
timeout=45.0,
|
|
||||||
sse_read_timeout=400.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assert: Verify the expected outcomes
|
|
||||||
db.session.refresh(mcp_provider)
|
|
||||||
assert mcp_provider.name == "Updated MCP Provider"
|
|
||||||
assert mcp_provider.server_identifier == "updated_identifier_123"
|
|
||||||
assert mcp_provider.timeout == 45.0
|
|
||||||
assert mcp_provider.sse_read_timeout == 400.0
|
|
||||||
assert mcp_provider.updated_at is not None
|
|
||||||
|
|
||||||
# Verify reconnection was called
|
|
||||||
mock_reconnect.assert_called_once_with("https://new-example.com/mcp", mcp_provider.id, tenant.id)
|
|
||||||
|
|
||||||
def test_update_mcp_provider_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies):
|
def test_update_mcp_provider_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies):
|
||||||
"""
|
"""
|
||||||
Test error handling when updating MCP provider with duplicate name.
|
Test error handling when updating MCP provider with duplicate name.
|
||||||
@ -1048,8 +1139,12 @@ class TestMCPToolManageService:
|
|||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
# Act & Assert: Verify proper error handling for duplicate name
|
# Act & Assert: Verify proper error handling for duplicate name
|
||||||
|
from core.entities.mcp_provider import MCPConfiguration
|
||||||
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
service = MCPToolManageService(db.session())
|
||||||
with pytest.raises(ValueError, match="MCP tool First Provider already exists"):
|
with pytest.raises(ValueError, match="MCP tool First Provider already exists"):
|
||||||
MCPToolManageService.update_mcp_provider(
|
service.update_provider(
|
||||||
tenant_id=tenant.id,
|
tenant_id=tenant.id,
|
||||||
provider_id=provider2.id,
|
provider_id=provider2.id,
|
||||||
name="First Provider", # Duplicate name
|
name="First Provider", # Duplicate name
|
||||||
@ -1058,8 +1153,10 @@ class TestMCPToolManageService:
|
|||||||
icon_type="emoji",
|
icon_type="emoji",
|
||||||
icon_background="#4ECDC4",
|
icon_background="#4ECDC4",
|
||||||
server_identifier="unique_identifier",
|
server_identifier="unique_identifier",
|
||||||
timeout=45.0,
|
configuration=MCPConfiguration(
|
||||||
sse_read_timeout=400.0,
|
timeout=45.0,
|
||||||
|
sse_read_timeout=400.0,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_update_mcp_provider_credentials_success(
|
def test_update_mcp_provider_credentials_success(
|
||||||
@ -1094,19 +1191,25 @@ class TestMCPToolManageService:
|
|||||||
|
|
||||||
# Mock the provider controller and encryption
|
# Mock the provider controller and encryption
|
||||||
with (
|
with (
|
||||||
patch("services.tools.mcp_tools_manage_service.MCPToolProviderController") as mock_controller,
|
patch("core.tools.mcp_tool.provider.MCPToolProviderController") as mock_controller,
|
||||||
patch("services.tools.mcp_tools_manage_service.ProviderConfigEncrypter") as mock_encrypter,
|
patch("core.tools.utils.encryption.ProviderConfigEncrypter") as mock_encrypter,
|
||||||
):
|
):
|
||||||
# Setup mocks
|
# Setup mocks
|
||||||
mock_controller_instance = mock_controller._from_db.return_value
|
mock_controller_instance = mock_controller.from_db.return_value
|
||||||
mock_controller_instance.get_credentials_schema.return_value = []
|
mock_controller_instance.get_credentials_schema.return_value = []
|
||||||
|
|
||||||
mock_encrypter_instance = mock_encrypter.return_value
|
mock_encrypter_instance = mock_encrypter.return_value
|
||||||
mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"}
|
mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"}
|
||||||
|
|
||||||
# Act: Execute the method under test
|
# Act: Execute the method under test
|
||||||
MCPToolManageService.update_mcp_provider_credentials(
|
from extensions.ext_database import db
|
||||||
mcp_provider=mcp_provider, credentials={"new_key": "new_value"}, authed=True
|
|
||||||
|
service = MCPToolManageService(db.session())
|
||||||
|
service.update_provider_credentials(
|
||||||
|
provider_id=mcp_provider.id,
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
credentials={"new_key": "new_value"},
|
||||||
|
authed=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert: Verify the expected outcomes
|
# Assert: Verify the expected outcomes
|
||||||
@ -1117,7 +1220,7 @@ class TestMCPToolManageService:
|
|||||||
# Verify credentials were encrypted and merged
|
# Verify credentials were encrypted and merged
|
||||||
import json
|
import json
|
||||||
|
|
||||||
credentials = json.loads(mcp_provider.encrypted_credentials)
|
credentials = json.loads(mcp_provider.encrypted_credentials or "{}")
|
||||||
assert "existing_key" in credentials
|
assert "existing_key" in credentials
|
||||||
assert "new_key" in credentials
|
assert "new_key" in credentials
|
||||||
|
|
||||||
@ -1152,19 +1255,25 @@ class TestMCPToolManageService:
|
|||||||
|
|
||||||
# Mock the provider controller and encryption
|
# Mock the provider controller and encryption
|
||||||
with (
|
with (
|
||||||
patch("services.tools.mcp_tools_manage_service.MCPToolProviderController") as mock_controller,
|
patch("core.tools.mcp_tool.provider.MCPToolProviderController") as mock_controller,
|
||||||
patch("services.tools.mcp_tools_manage_service.ProviderConfigEncrypter") as mock_encrypter,
|
patch("core.tools.utils.encryption.ProviderConfigEncrypter") as mock_encrypter,
|
||||||
):
|
):
|
||||||
# Setup mocks
|
# Setup mocks
|
||||||
mock_controller_instance = mock_controller._from_db.return_value
|
mock_controller_instance = mock_controller.from_db.return_value
|
||||||
mock_controller_instance.get_credentials_schema.return_value = []
|
mock_controller_instance.get_credentials_schema.return_value = []
|
||||||
|
|
||||||
mock_encrypter_instance = mock_encrypter.return_value
|
mock_encrypter_instance = mock_encrypter.return_value
|
||||||
mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"}
|
mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"}
|
||||||
|
|
||||||
# Act: Execute the method under test
|
# Act: Execute the method under test
|
||||||
MCPToolManageService.update_mcp_provider_credentials(
|
from extensions.ext_database import db
|
||||||
mcp_provider=mcp_provider, credentials={"new_key": "new_value"}, authed=False
|
|
||||||
|
service = MCPToolManageService(db.session())
|
||||||
|
service.update_provider_credentials(
|
||||||
|
provider_id=mcp_provider.id,
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
credentials={"new_key": "new_value"},
|
||||||
|
authed=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert: Verify the expected outcomes
|
# Assert: Verify the expected outcomes
|
||||||
@ -1199,41 +1308,37 @@ class TestMCPToolManageService:
|
|||||||
type("MockTool", (), {"model_dump": lambda self: {"name": "test_tool_2", "description": "Test tool 2"}})(),
|
type("MockTool", (), {"model_dump": lambda self: {"name": "test_tool_2", "description": "Test tool 2"}})(),
|
||||||
]
|
]
|
||||||
|
|
||||||
with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client:
|
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
|
||||||
# Setup mock client
|
# Setup mock client
|
||||||
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
||||||
mock_client_instance.list_tools.return_value = mock_tools
|
mock_client_instance.list_tools.return_value = mock_tools
|
||||||
|
|
||||||
# Act: Execute the method under test
|
# Act: Execute the method under test
|
||||||
result = MCPToolManageService._re_connect_mcp_provider(
|
from extensions.ext_database import db
|
||||||
"https://example.com/mcp", mcp_provider.id, tenant.id
|
|
||||||
|
service = MCPToolManageService(db.session())
|
||||||
|
result = service._reconnect_provider(
|
||||||
|
server_url="https://example.com/mcp",
|
||||||
|
provider=mcp_provider,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert: Verify the expected outcomes
|
# Assert: Verify the expected outcomes
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result["authed"] is True
|
assert result.authed is True
|
||||||
assert result["tools"] is not None
|
assert result.tools is not None
|
||||||
assert result["encrypted_credentials"] == "{}"
|
assert result.encrypted_credentials == "{}"
|
||||||
|
|
||||||
# Verify tools were properly serialized
|
# Verify tools were properly serialized
|
||||||
import json
|
import json
|
||||||
|
|
||||||
tools_data = json.loads(result["tools"])
|
tools_data = json.loads(result.tools)
|
||||||
assert len(tools_data) == 2
|
assert len(tools_data) == 2
|
||||||
assert tools_data[0]["name"] == "test_tool_1"
|
assert tools_data[0]["name"] == "test_tool_1"
|
||||||
assert tools_data[1]["name"] == "test_tool_2"
|
assert tools_data[1]["name"] == "test_tool_2"
|
||||||
|
|
||||||
# Verify mock interactions
|
# Verify mock interactions
|
||||||
mock_mcp_client.assert_called_once_with(
|
provider_entity = mcp_provider.to_entity()
|
||||||
"https://example.com/mcp",
|
mock_mcp_client.assert_called_once()
|
||||||
mcp_provider.id,
|
|
||||||
tenant.id,
|
|
||||||
authed=False,
|
|
||||||
for_list=True,
|
|
||||||
headers={},
|
|
||||||
timeout=30.0,
|
|
||||||
sse_read_timeout=300.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies):
|
def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies):
|
||||||
"""
|
"""
|
||||||
@ -1256,22 +1361,26 @@ class TestMCPToolManageService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Mock MCPClient to raise authentication error
|
# Mock MCPClient to raise authentication error
|
||||||
with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client:
|
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
|
||||||
from core.mcp.error import MCPAuthError
|
from core.mcp.error import MCPAuthError
|
||||||
|
|
||||||
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
||||||
mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required")
|
mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required")
|
||||||
|
|
||||||
# Act: Execute the method under test
|
# Act: Execute the method under test
|
||||||
result = MCPToolManageService._re_connect_mcp_provider(
|
from extensions.ext_database import db
|
||||||
"https://example.com/mcp", mcp_provider.id, tenant.id
|
|
||||||
|
service = MCPToolManageService(db.session())
|
||||||
|
result = service._reconnect_provider(
|
||||||
|
server_url="https://example.com/mcp",
|
||||||
|
provider=mcp_provider,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert: Verify the expected outcomes
|
# Assert: Verify the expected outcomes
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result["authed"] is False
|
assert result.authed is False
|
||||||
assert result["tools"] == "[]"
|
assert result.tools == "[]"
|
||||||
assert result["encrypted_credentials"] == "{}"
|
assert result.encrypted_credentials == "{}"
|
||||||
|
|
||||||
def test_re_connect_mcp_provider_connection_error(
|
def test_re_connect_mcp_provider_connection_error(
|
||||||
self, db_session_with_containers, mock_external_service_dependencies
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
@ -1295,12 +1404,18 @@ class TestMCPToolManageService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Mock MCPClient to raise connection error
|
# Mock MCPClient to raise connection error
|
||||||
with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client:
|
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
|
||||||
from core.mcp.error import MCPError
|
from core.mcp.error import MCPError
|
||||||
|
|
||||||
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
||||||
mock_client_instance.list_tools.side_effect = MCPError("Connection failed")
|
mock_client_instance.list_tools.side_effect = MCPError("Connection failed")
|
||||||
|
|
||||||
# Act & Assert: Verify proper error handling
|
# Act & Assert: Verify proper error handling
|
||||||
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
service = MCPToolManageService(db.session())
|
||||||
with pytest.raises(ValueError, match="Failed to re-connect MCP server: Connection failed"):
|
with pytest.raises(ValueError, match="Failed to re-connect MCP server: Connection failed"):
|
||||||
MCPToolManageService._re_connect_mcp_provider("https://example.com/mcp", mcp_provider.id, tenant.id)
|
service._reconnect_provider(
|
||||||
|
server_url="https://example.com/mcp",
|
||||||
|
provider=mcp_provider,
|
||||||
|
)
|
||||||
|
|||||||
@ -1,324 +0,0 @@
|
|||||||
"""
|
|
||||||
Unit tests for WorkflowResponseConverter focusing on process_data truncation functionality.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import uuid
|
|
||||||
from collections.abc import Mapping
|
|
||||||
from typing import Any
|
|
||||||
from unittest.mock import Mock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
|
||||||
from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity
|
|
||||||
from core.app.entities.queue_entities import (
|
|
||||||
QueueNodeRetryEvent,
|
|
||||||
QueueNodeStartedEvent,
|
|
||||||
QueueNodeSucceededEvent,
|
|
||||||
)
|
|
||||||
from core.workflow.enums import NodeType
|
|
||||||
from core.workflow.system_variable import SystemVariable
|
|
||||||
from libs.datetime_utils import naive_utc_now
|
|
||||||
from models import Account
|
|
||||||
|
|
||||||
|
|
||||||
class TestWorkflowResponseConverterCenarios:
|
|
||||||
"""Test process_data truncation in WorkflowResponseConverter."""
|
|
||||||
|
|
||||||
def create_mock_generate_entity(self) -> WorkflowAppGenerateEntity:
|
|
||||||
"""Create a mock WorkflowAppGenerateEntity."""
|
|
||||||
mock_entity = Mock(spec=WorkflowAppGenerateEntity)
|
|
||||||
mock_app_config = Mock()
|
|
||||||
mock_app_config.tenant_id = "test-tenant-id"
|
|
||||||
mock_entity.app_config = mock_app_config
|
|
||||||
mock_entity.inputs = {}
|
|
||||||
return mock_entity
|
|
||||||
|
|
||||||
def create_workflow_response_converter(self) -> WorkflowResponseConverter:
|
|
||||||
"""Create a WorkflowResponseConverter for testing."""
|
|
||||||
|
|
||||||
mock_entity = self.create_mock_generate_entity()
|
|
||||||
mock_user = Mock(spec=Account)
|
|
||||||
mock_user.id = "test-user-id"
|
|
||||||
mock_user.name = "Test User"
|
|
||||||
mock_user.email = "test@example.com"
|
|
||||||
|
|
||||||
system_variables = SystemVariable(workflow_id="wf-id", workflow_execution_id="initial-run-id")
|
|
||||||
return WorkflowResponseConverter(
|
|
||||||
application_generate_entity=mock_entity,
|
|
||||||
user=mock_user,
|
|
||||||
system_variables=system_variables,
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_node_started_event(self, *, node_execution_id: str | None = None) -> QueueNodeStartedEvent:
|
|
||||||
"""Create a QueueNodeStartedEvent for testing."""
|
|
||||||
return QueueNodeStartedEvent(
|
|
||||||
node_execution_id=node_execution_id or str(uuid.uuid4()),
|
|
||||||
node_id="test-node-id",
|
|
||||||
node_title="Test Node",
|
|
||||||
node_type=NodeType.CODE,
|
|
||||||
start_at=naive_utc_now(),
|
|
||||||
predecessor_node_id=None,
|
|
||||||
in_iteration_id=None,
|
|
||||||
in_loop_id=None,
|
|
||||||
provider_type="built-in",
|
|
||||||
provider_id="code",
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_node_succeeded_event(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
node_execution_id: str,
|
|
||||||
process_data: Mapping[str, Any] | None = None,
|
|
||||||
) -> QueueNodeSucceededEvent:
|
|
||||||
"""Create a QueueNodeSucceededEvent for testing."""
|
|
||||||
return QueueNodeSucceededEvent(
|
|
||||||
node_id="test-node-id",
|
|
||||||
node_type=NodeType.CODE,
|
|
||||||
node_execution_id=node_execution_id,
|
|
||||||
start_at=naive_utc_now(),
|
|
||||||
in_iteration_id=None,
|
|
||||||
in_loop_id=None,
|
|
||||||
inputs={},
|
|
||||||
process_data=process_data or {},
|
|
||||||
outputs={},
|
|
||||||
execution_metadata={},
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_node_retry_event(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
node_execution_id: str,
|
|
||||||
process_data: Mapping[str, Any] | None = None,
|
|
||||||
) -> QueueNodeRetryEvent:
|
|
||||||
"""Create a QueueNodeRetryEvent for testing."""
|
|
||||||
return QueueNodeRetryEvent(
|
|
||||||
inputs={"data": "inputs"},
|
|
||||||
outputs={"data": "outputs"},
|
|
||||||
process_data=process_data or {},
|
|
||||||
error="oops",
|
|
||||||
retry_index=1,
|
|
||||||
node_id="test-node-id",
|
|
||||||
node_type=NodeType.CODE,
|
|
||||||
node_title="test code",
|
|
||||||
provider_type="built-in",
|
|
||||||
provider_id="code",
|
|
||||||
node_execution_id=node_execution_id,
|
|
||||||
start_at=naive_utc_now(),
|
|
||||||
in_iteration_id=None,
|
|
||||||
in_loop_id=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_workflow_node_finish_response_uses_truncated_process_data(self):
|
|
||||||
"""Test that node finish response uses get_response_process_data()."""
|
|
||||||
converter = self.create_workflow_response_converter()
|
|
||||||
|
|
||||||
original_data = {"large_field": "x" * 10000, "metadata": "info"}
|
|
||||||
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
|
|
||||||
|
|
||||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
|
||||||
start_event = self.create_node_started_event()
|
|
||||||
converter.workflow_node_start_to_stream_response(
|
|
||||||
event=start_event,
|
|
||||||
task_id="test-task-id",
|
|
||||||
)
|
|
||||||
|
|
||||||
event = self.create_node_succeeded_event(
|
|
||||||
node_execution_id=start_event.node_execution_id,
|
|
||||||
process_data=original_data,
|
|
||||||
)
|
|
||||||
|
|
||||||
def fake_truncate(mapping):
|
|
||||||
if mapping == dict(original_data):
|
|
||||||
return truncated_data, True
|
|
||||||
return mapping, False
|
|
||||||
|
|
||||||
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
|
|
||||||
|
|
||||||
response = converter.workflow_node_finish_to_stream_response(
|
|
||||||
event=event,
|
|
||||||
task_id="test-task-id",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Response should use truncated data, not original
|
|
||||||
assert response is not None
|
|
||||||
assert response.data.process_data == truncated_data
|
|
||||||
assert response.data.process_data != original_data
|
|
||||||
assert response.data.process_data_truncated is True
|
|
||||||
|
|
||||||
def test_workflow_node_finish_response_without_truncation(self):
|
|
||||||
"""Test node finish response when no truncation is applied."""
|
|
||||||
converter = self.create_workflow_response_converter()
|
|
||||||
|
|
||||||
original_data = {"small": "data"}
|
|
||||||
|
|
||||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
|
||||||
start_event = self.create_node_started_event()
|
|
||||||
converter.workflow_node_start_to_stream_response(
|
|
||||||
event=start_event,
|
|
||||||
task_id="test-task-id",
|
|
||||||
)
|
|
||||||
|
|
||||||
event = self.create_node_succeeded_event(
|
|
||||||
node_execution_id=start_event.node_execution_id,
|
|
||||||
process_data=original_data,
|
|
||||||
)
|
|
||||||
|
|
||||||
def fake_truncate(mapping):
|
|
||||||
return mapping, False
|
|
||||||
|
|
||||||
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
|
|
||||||
|
|
||||||
response = converter.workflow_node_finish_to_stream_response(
|
|
||||||
event=event,
|
|
||||||
task_id="test-task-id",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Response should use original data
|
|
||||||
assert response is not None
|
|
||||||
assert response.data.process_data == original_data
|
|
||||||
assert response.data.process_data_truncated is False
|
|
||||||
|
|
||||||
def test_workflow_node_finish_response_with_none_process_data(self):
|
|
||||||
"""Test node finish response when process_data is None."""
|
|
||||||
converter = self.create_workflow_response_converter()
|
|
||||||
|
|
||||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
|
||||||
start_event = self.create_node_started_event()
|
|
||||||
converter.workflow_node_start_to_stream_response(
|
|
||||||
event=start_event,
|
|
||||||
task_id="test-task-id",
|
|
||||||
)
|
|
||||||
|
|
||||||
event = self.create_node_succeeded_event(
|
|
||||||
node_execution_id=start_event.node_execution_id,
|
|
||||||
process_data=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
def fake_truncate(mapping):
|
|
||||||
return mapping, False
|
|
||||||
|
|
||||||
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
|
|
||||||
|
|
||||||
response = converter.workflow_node_finish_to_stream_response(
|
|
||||||
event=event,
|
|
||||||
task_id="test-task-id",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Response should normalize missing process_data to an empty mapping
|
|
||||||
assert response is not None
|
|
||||||
assert response.data.process_data == {}
|
|
||||||
assert response.data.process_data_truncated is False
|
|
||||||
|
|
||||||
def test_workflow_node_retry_response_uses_truncated_process_data(self):
|
|
||||||
"""Test that node retry response uses get_response_process_data()."""
|
|
||||||
converter = self.create_workflow_response_converter()
|
|
||||||
|
|
||||||
original_data = {"large_field": "x" * 10000, "metadata": "info"}
|
|
||||||
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
|
|
||||||
|
|
||||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
|
||||||
start_event = self.create_node_started_event()
|
|
||||||
converter.workflow_node_start_to_stream_response(
|
|
||||||
event=start_event,
|
|
||||||
task_id="test-task-id",
|
|
||||||
)
|
|
||||||
|
|
||||||
event = self.create_node_retry_event(
|
|
||||||
node_execution_id=start_event.node_execution_id,
|
|
||||||
process_data=original_data,
|
|
||||||
)
|
|
||||||
|
|
||||||
def fake_truncate(mapping):
|
|
||||||
if mapping == dict(original_data):
|
|
||||||
return truncated_data, True
|
|
||||||
return mapping, False
|
|
||||||
|
|
||||||
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
|
|
||||||
|
|
||||||
response = converter.workflow_node_retry_to_stream_response(
|
|
||||||
event=event,
|
|
||||||
task_id="test-task-id",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Response should use truncated data, not original
|
|
||||||
assert response is not None
|
|
||||||
assert response.data.process_data == truncated_data
|
|
||||||
assert response.data.process_data != original_data
|
|
||||||
assert response.data.process_data_truncated is True
|
|
||||||
|
|
||||||
def test_workflow_node_retry_response_without_truncation(self):
|
|
||||||
"""Test node retry response when no truncation is applied."""
|
|
||||||
converter = self.create_workflow_response_converter()
|
|
||||||
|
|
||||||
original_data = {"small": "data"}
|
|
||||||
|
|
||||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
|
||||||
start_event = self.create_node_started_event()
|
|
||||||
converter.workflow_node_start_to_stream_response(
|
|
||||||
event=start_event,
|
|
||||||
task_id="test-task-id",
|
|
||||||
)
|
|
||||||
|
|
||||||
event = self.create_node_retry_event(
|
|
||||||
node_execution_id=start_event.node_execution_id,
|
|
||||||
process_data=original_data,
|
|
||||||
)
|
|
||||||
|
|
||||||
def fake_truncate(mapping):
|
|
||||||
return mapping, False
|
|
||||||
|
|
||||||
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
|
|
||||||
|
|
||||||
response = converter.workflow_node_retry_to_stream_response(
|
|
||||||
event=event,
|
|
||||||
task_id="test-task-id",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response is not None
|
|
||||||
assert response.data.process_data == original_data
|
|
||||||
assert response.data.process_data_truncated is False
|
|
||||||
|
|
||||||
def test_iteration_and_loop_nodes_return_none(self):
|
|
||||||
"""Test that iteration and loop nodes return None (no streaming events)."""
|
|
||||||
converter = self.create_workflow_response_converter()
|
|
||||||
|
|
||||||
iteration_event = QueueNodeSucceededEvent(
|
|
||||||
node_id="iteration-node",
|
|
||||||
node_type=NodeType.ITERATION,
|
|
||||||
node_execution_id=str(uuid.uuid4()),
|
|
||||||
start_at=naive_utc_now(),
|
|
||||||
in_iteration_id=None,
|
|
||||||
in_loop_id=None,
|
|
||||||
inputs={},
|
|
||||||
process_data={},
|
|
||||||
outputs={},
|
|
||||||
execution_metadata={},
|
|
||||||
)
|
|
||||||
|
|
||||||
response = converter.workflow_node_finish_to_stream_response(
|
|
||||||
event=iteration_event,
|
|
||||||
task_id="test-task-id",
|
|
||||||
)
|
|
||||||
assert response is None
|
|
||||||
|
|
||||||
loop_event = iteration_event.model_copy(update={"node_type": NodeType.LOOP})
|
|
||||||
response = converter.workflow_node_finish_to_stream_response(
|
|
||||||
event=loop_event,
|
|
||||||
task_id="test-task-id",
|
|
||||||
)
|
|
||||||
assert response is None
|
|
||||||
|
|
||||||
def test_finish_without_start_raises(self):
|
|
||||||
"""Ensure finish responses require a prior workflow start."""
|
|
||||||
converter = self.create_workflow_response_converter()
|
|
||||||
event = self.create_node_succeeded_event(
|
|
||||||
node_execution_id=str(uuid.uuid4()),
|
|
||||||
process_data={},
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
converter.workflow_node_finish_to_stream_response(
|
|
||||||
event=event,
|
|
||||||
task_id="test-task-id",
|
|
||||||
)
|
|
||||||
@ -0,0 +1,810 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for WorkflowResponseConverter focusing on process_data truncation functionality.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||||
|
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||||
|
from core.app.entities.queue_entities import (
|
||||||
|
QueueEvent,
|
||||||
|
QueueIterationStartEvent,
|
||||||
|
QueueLoopStartEvent,
|
||||||
|
QueueNodeExceptionEvent,
|
||||||
|
QueueNodeFailedEvent,
|
||||||
|
QueueNodeRetryEvent,
|
||||||
|
QueueNodeStartedEvent,
|
||||||
|
QueueNodeSucceededEvent,
|
||||||
|
)
|
||||||
|
from core.workflow.enums import NodeType
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
|
from libs.datetime_utils import naive_utc_now
|
||||||
|
from models import Account
|
||||||
|
from models.model import AppMode
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkflowResponseConverter:
|
||||||
|
"""Test truncation in WorkflowResponseConverter."""
|
||||||
|
|
||||||
|
def create_mock_generate_entity(self) -> WorkflowAppGenerateEntity:
|
||||||
|
"""Create a mock WorkflowAppGenerateEntity."""
|
||||||
|
mock_entity = Mock(spec=WorkflowAppGenerateEntity)
|
||||||
|
mock_app_config = Mock()
|
||||||
|
mock_app_config.tenant_id = "test-tenant-id"
|
||||||
|
mock_entity.invoke_from = InvokeFrom.WEB_APP
|
||||||
|
mock_entity.app_config = mock_app_config
|
||||||
|
mock_entity.inputs = {}
|
||||||
|
return mock_entity
|
||||||
|
|
||||||
|
def create_workflow_response_converter(self) -> WorkflowResponseConverter:
|
||||||
|
"""Create a WorkflowResponseConverter for testing."""
|
||||||
|
|
||||||
|
mock_entity = self.create_mock_generate_entity()
|
||||||
|
mock_user = Mock(spec=Account)
|
||||||
|
mock_user.id = "test-user-id"
|
||||||
|
mock_user.name = "Test User"
|
||||||
|
mock_user.email = "test@example.com"
|
||||||
|
|
||||||
|
system_variables = SystemVariable(workflow_id="wf-id", workflow_execution_id="initial-run-id")
|
||||||
|
return WorkflowResponseConverter(
|
||||||
|
application_generate_entity=mock_entity,
|
||||||
|
user=mock_user,
|
||||||
|
system_variables=system_variables,
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_node_started_event(self, *, node_execution_id: str | None = None) -> QueueNodeStartedEvent:
|
||||||
|
"""Create a QueueNodeStartedEvent for testing."""
|
||||||
|
return QueueNodeStartedEvent(
|
||||||
|
node_execution_id=node_execution_id or str(uuid.uuid4()),
|
||||||
|
node_id="test-node-id",
|
||||||
|
node_title="Test Node",
|
||||||
|
node_type=NodeType.CODE,
|
||||||
|
start_at=naive_utc_now(),
|
||||||
|
in_iteration_id=None,
|
||||||
|
in_loop_id=None,
|
||||||
|
provider_type="built-in",
|
||||||
|
provider_id="code",
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_node_succeeded_event(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
node_execution_id: str,
|
||||||
|
process_data: Mapping[str, Any] | None = None,
|
||||||
|
) -> QueueNodeSucceededEvent:
|
||||||
|
"""Create a QueueNodeSucceededEvent for testing."""
|
||||||
|
return QueueNodeSucceededEvent(
|
||||||
|
node_id="test-node-id",
|
||||||
|
node_type=NodeType.CODE,
|
||||||
|
node_execution_id=node_execution_id,
|
||||||
|
start_at=naive_utc_now(),
|
||||||
|
in_iteration_id=None,
|
||||||
|
in_loop_id=None,
|
||||||
|
inputs={},
|
||||||
|
process_data=process_data or {},
|
||||||
|
outputs={},
|
||||||
|
execution_metadata={},
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_node_retry_event(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
node_execution_id: str,
|
||||||
|
process_data: Mapping[str, Any] | None = None,
|
||||||
|
) -> QueueNodeRetryEvent:
|
||||||
|
"""Create a QueueNodeRetryEvent for testing."""
|
||||||
|
return QueueNodeRetryEvent(
|
||||||
|
inputs={"data": "inputs"},
|
||||||
|
outputs={"data": "outputs"},
|
||||||
|
process_data=process_data or {},
|
||||||
|
error="oops",
|
||||||
|
retry_index=1,
|
||||||
|
node_id="test-node-id",
|
||||||
|
node_type=NodeType.CODE,
|
||||||
|
node_title="test code",
|
||||||
|
provider_type="built-in",
|
||||||
|
provider_id="code",
|
||||||
|
node_execution_id=node_execution_id,
|
||||||
|
start_at=naive_utc_now(),
|
||||||
|
in_iteration_id=None,
|
||||||
|
in_loop_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_workflow_node_finish_response_uses_truncated_process_data(self):
|
||||||
|
"""Test that node finish response uses get_response_process_data()."""
|
||||||
|
converter = self.create_workflow_response_converter()
|
||||||
|
|
||||||
|
original_data = {"large_field": "x" * 10000, "metadata": "info"}
|
||||||
|
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
|
||||||
|
|
||||||
|
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||||
|
start_event = self.create_node_started_event()
|
||||||
|
converter.workflow_node_start_to_stream_response(
|
||||||
|
event=start_event,
|
||||||
|
task_id="test-task-id",
|
||||||
|
)
|
||||||
|
|
||||||
|
event = self.create_node_succeeded_event(
|
||||||
|
node_execution_id=start_event.node_execution_id,
|
||||||
|
process_data=original_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
def fake_truncate(mapping):
|
||||||
|
if mapping == dict(original_data):
|
||||||
|
return truncated_data, True
|
||||||
|
return mapping, False
|
||||||
|
|
||||||
|
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
|
||||||
|
|
||||||
|
response = converter.workflow_node_finish_to_stream_response(
|
||||||
|
event=event,
|
||||||
|
task_id="test-task-id",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Response should use truncated data, not original
|
||||||
|
assert response is not None
|
||||||
|
assert response.data.process_data == truncated_data
|
||||||
|
assert response.data.process_data != original_data
|
||||||
|
assert response.data.process_data_truncated is True
|
||||||
|
|
||||||
|
def test_workflow_node_finish_response_without_truncation(self):
|
||||||
|
"""Test node finish response when no truncation is applied."""
|
||||||
|
converter = self.create_workflow_response_converter()
|
||||||
|
|
||||||
|
original_data = {"small": "data"}
|
||||||
|
|
||||||
|
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||||
|
start_event = self.create_node_started_event()
|
||||||
|
converter.workflow_node_start_to_stream_response(
|
||||||
|
event=start_event,
|
||||||
|
task_id="test-task-id",
|
||||||
|
)
|
||||||
|
|
||||||
|
event = self.create_node_succeeded_event(
|
||||||
|
node_execution_id=start_event.node_execution_id,
|
||||||
|
process_data=original_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
def fake_truncate(mapping):
|
||||||
|
return mapping, False
|
||||||
|
|
||||||
|
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
|
||||||
|
|
||||||
|
response = converter.workflow_node_finish_to_stream_response(
|
||||||
|
event=event,
|
||||||
|
task_id="test-task-id",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Response should use original data
|
||||||
|
assert response is not None
|
||||||
|
assert response.data.process_data == original_data
|
||||||
|
assert response.data.process_data_truncated is False
|
||||||
|
|
||||||
|
def test_workflow_node_finish_response_with_none_process_data(self):
|
||||||
|
"""Test node finish response when process_data is None."""
|
||||||
|
converter = self.create_workflow_response_converter()
|
||||||
|
|
||||||
|
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||||
|
start_event = self.create_node_started_event()
|
||||||
|
converter.workflow_node_start_to_stream_response(
|
||||||
|
event=start_event,
|
||||||
|
task_id="test-task-id",
|
||||||
|
)
|
||||||
|
|
||||||
|
event = self.create_node_succeeded_event(
|
||||||
|
node_execution_id=start_event.node_execution_id,
|
||||||
|
process_data=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def fake_truncate(mapping):
|
||||||
|
return mapping, False
|
||||||
|
|
||||||
|
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
|
||||||
|
|
||||||
|
response = converter.workflow_node_finish_to_stream_response(
|
||||||
|
event=event,
|
||||||
|
task_id="test-task-id",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Response should normalize missing process_data to an empty mapping
|
||||||
|
assert response is not None
|
||||||
|
assert response.data.process_data == {}
|
||||||
|
assert response.data.process_data_truncated is False
|
||||||
|
|
||||||
|
def test_workflow_node_retry_response_uses_truncated_process_data(self):
|
||||||
|
"""Test that node retry response uses get_response_process_data()."""
|
||||||
|
converter = self.create_workflow_response_converter()
|
||||||
|
|
||||||
|
original_data = {"large_field": "x" * 10000, "metadata": "info"}
|
||||||
|
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
|
||||||
|
|
||||||
|
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||||
|
start_event = self.create_node_started_event()
|
||||||
|
converter.workflow_node_start_to_stream_response(
|
||||||
|
event=start_event,
|
||||||
|
task_id="test-task-id",
|
||||||
|
)
|
||||||
|
|
||||||
|
event = self.create_node_retry_event(
|
||||||
|
node_execution_id=start_event.node_execution_id,
|
||||||
|
process_data=original_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
def fake_truncate(mapping):
|
||||||
|
if mapping == dict(original_data):
|
||||||
|
return truncated_data, True
|
||||||
|
return mapping, False
|
||||||
|
|
||||||
|
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
|
||||||
|
|
||||||
|
response = converter.workflow_node_retry_to_stream_response(
|
||||||
|
event=event,
|
||||||
|
task_id="test-task-id",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Response should use truncated data, not original
|
||||||
|
assert response is not None
|
||||||
|
assert response.data.process_data == truncated_data
|
||||||
|
assert response.data.process_data != original_data
|
||||||
|
assert response.data.process_data_truncated is True
|
||||||
|
|
||||||
|
def test_workflow_node_retry_response_without_truncation(self):
|
||||||
|
"""Test node retry response when no truncation is applied."""
|
||||||
|
converter = self.create_workflow_response_converter()
|
||||||
|
|
||||||
|
original_data = {"small": "data"}
|
||||||
|
|
||||||
|
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||||
|
start_event = self.create_node_started_event()
|
||||||
|
converter.workflow_node_start_to_stream_response(
|
||||||
|
event=start_event,
|
||||||
|
task_id="test-task-id",
|
||||||
|
)
|
||||||
|
|
||||||
|
event = self.create_node_retry_event(
|
||||||
|
node_execution_id=start_event.node_execution_id,
|
||||||
|
process_data=original_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
def fake_truncate(mapping):
|
||||||
|
return mapping, False
|
||||||
|
|
||||||
|
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
|
||||||
|
|
||||||
|
response = converter.workflow_node_retry_to_stream_response(
|
||||||
|
event=event,
|
||||||
|
task_id="test-task-id",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response is not None
|
||||||
|
assert response.data.process_data == original_data
|
||||||
|
assert response.data.process_data_truncated is False
|
||||||
|
|
||||||
|
def test_iteration_and_loop_nodes_return_none(self):
|
||||||
|
"""Test that iteration and loop nodes return None (no streaming events)."""
|
||||||
|
converter = self.create_workflow_response_converter()
|
||||||
|
|
||||||
|
iteration_event = QueueNodeSucceededEvent(
|
||||||
|
node_id="iteration-node",
|
||||||
|
node_type=NodeType.ITERATION,
|
||||||
|
node_execution_id=str(uuid.uuid4()),
|
||||||
|
start_at=naive_utc_now(),
|
||||||
|
in_iteration_id=None,
|
||||||
|
in_loop_id=None,
|
||||||
|
inputs={},
|
||||||
|
process_data={},
|
||||||
|
outputs={},
|
||||||
|
execution_metadata={},
|
||||||
|
)
|
||||||
|
|
||||||
|
response = converter.workflow_node_finish_to_stream_response(
|
||||||
|
event=iteration_event,
|
||||||
|
task_id="test-task-id",
|
||||||
|
)
|
||||||
|
assert response is None
|
||||||
|
|
||||||
|
loop_event = iteration_event.model_copy(update={"node_type": NodeType.LOOP})
|
||||||
|
response = converter.workflow_node_finish_to_stream_response(
|
||||||
|
event=loop_event,
|
||||||
|
task_id="test-task-id",
|
||||||
|
)
|
||||||
|
assert response is None
|
||||||
|
|
||||||
|
def test_finish_without_start_raises(self):
|
||||||
|
"""Ensure finish responses require a prior workflow start."""
|
||||||
|
converter = self.create_workflow_response_converter()
|
||||||
|
event = self.create_node_succeeded_event(
|
||||||
|
node_execution_id=str(uuid.uuid4()),
|
||||||
|
process_data={},
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
converter.workflow_node_finish_to_stream_response(
|
||||||
|
event=event,
|
||||||
|
task_id="test-task-id",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TestCase:
|
||||||
|
"""Test case data for table-driven tests."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
invoke_from: InvokeFrom
|
||||||
|
expected_truncation_enabled: bool
|
||||||
|
description: str
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkflowResponseConverterServiceApiTruncation:
|
||||||
|
"""Test class for Service API truncation functionality in WorkflowResponseConverter."""
|
||||||
|
|
||||||
|
def create_test_app_generate_entity(self, invoke_from: InvokeFrom) -> WorkflowAppGenerateEntity:
|
||||||
|
"""Create a test WorkflowAppGenerateEntity with specified invoke_from."""
|
||||||
|
# Create a minimal WorkflowUIBasedAppConfig for testing
|
||||||
|
app_config = WorkflowUIBasedAppConfig(
|
||||||
|
tenant_id="test_tenant",
|
||||||
|
app_id="test_app",
|
||||||
|
app_mode=AppMode.WORKFLOW,
|
||||||
|
workflow_id="test_workflow_id",
|
||||||
|
)
|
||||||
|
|
||||||
|
entity = WorkflowAppGenerateEntity(
|
||||||
|
task_id="test_task_id",
|
||||||
|
app_id="test_app_id",
|
||||||
|
app_config=app_config,
|
||||||
|
tenant_id="test_tenant",
|
||||||
|
app_mode="workflow",
|
||||||
|
invoke_from=invoke_from,
|
||||||
|
inputs={"test_input": "test_value"},
|
||||||
|
user_id="test_user_id",
|
||||||
|
stream=True,
|
||||||
|
files=[],
|
||||||
|
workflow_execution_id="test_workflow_exec_id",
|
||||||
|
)
|
||||||
|
return entity
|
||||||
|
|
||||||
|
def create_test_user(self) -> Account:
|
||||||
|
"""Create a test user account."""
|
||||||
|
account = Account(
|
||||||
|
name="Test User",
|
||||||
|
email="test@example.com",
|
||||||
|
)
|
||||||
|
# Manually set the ID for testing purposes
|
||||||
|
account.id = "test_user_id"
|
||||||
|
return account
|
||||||
|
|
||||||
|
def create_test_system_variables(self) -> SystemVariable:
|
||||||
|
"""Create test system variables."""
|
||||||
|
return SystemVariable()
|
||||||
|
|
||||||
|
def create_test_converter(self, invoke_from: InvokeFrom) -> WorkflowResponseConverter:
|
||||||
|
"""Create WorkflowResponseConverter with specified invoke_from."""
|
||||||
|
entity = self.create_test_app_generate_entity(invoke_from)
|
||||||
|
user = self.create_test_user()
|
||||||
|
system_variables = self.create_test_system_variables()
|
||||||
|
|
||||||
|
converter = WorkflowResponseConverter(
|
||||||
|
application_generate_entity=entity,
|
||||||
|
user=user,
|
||||||
|
system_variables=system_variables,
|
||||||
|
)
|
||||||
|
# ensure `workflow_run_id` is set.
|
||||||
|
converter.workflow_start_to_stream_response(
|
||||||
|
task_id="test-task-id",
|
||||||
|
workflow_run_id="test-workflow-run-id",
|
||||||
|
workflow_id="test-workflow-id",
|
||||||
|
)
|
||||||
|
return converter
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"test_case",
|
||||||
|
[
|
||||||
|
TestCase(
|
||||||
|
name="service_api_truncation_disabled",
|
||||||
|
invoke_from=InvokeFrom.SERVICE_API,
|
||||||
|
expected_truncation_enabled=False,
|
||||||
|
description="Service API calls should have truncation disabled",
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
name="web_app_truncation_enabled",
|
||||||
|
invoke_from=InvokeFrom.WEB_APP,
|
||||||
|
expected_truncation_enabled=True,
|
||||||
|
description="Web app calls should have truncation enabled",
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
name="debugger_truncation_enabled",
|
||||||
|
invoke_from=InvokeFrom.DEBUGGER,
|
||||||
|
expected_truncation_enabled=True,
|
||||||
|
description="Debugger calls should have truncation enabled",
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
name="explore_truncation_enabled",
|
||||||
|
invoke_from=InvokeFrom.EXPLORE,
|
||||||
|
expected_truncation_enabled=True,
|
||||||
|
description="Explore calls should have truncation enabled",
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
name="published_truncation_enabled",
|
||||||
|
invoke_from=InvokeFrom.PUBLISHED,
|
||||||
|
expected_truncation_enabled=True,
|
||||||
|
description="Published app calls should have truncation enabled",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
ids=lambda x: x.name,
|
||||||
|
)
|
||||||
|
def test_truncator_selection_based_on_invoke_from(self, test_case: TestCase):
|
||||||
|
"""Test that the correct truncator is selected based on invoke_from."""
|
||||||
|
converter = self.create_test_converter(test_case.invoke_from)
|
||||||
|
|
||||||
|
# Test truncation behavior instead of checking private attribute
|
||||||
|
|
||||||
|
# Create a test event with large data
|
||||||
|
large_value = {"key": ["x"] * 2000} # Large data that would be truncated
|
||||||
|
|
||||||
|
event = QueueNodeSucceededEvent(
|
||||||
|
node_execution_id="test_node_exec_id",
|
||||||
|
node_id="test_node",
|
||||||
|
node_type=NodeType.LLM,
|
||||||
|
start_at=naive_utc_now(),
|
||||||
|
inputs=large_value,
|
||||||
|
process_data=large_value,
|
||||||
|
outputs=large_value,
|
||||||
|
error=None,
|
||||||
|
execution_metadata=None,
|
||||||
|
in_iteration_id=None,
|
||||||
|
in_loop_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = converter.workflow_node_finish_to_stream_response(
|
||||||
|
event=event,
|
||||||
|
task_id="test_task",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify response is not None
|
||||||
|
assert response is not None
|
||||||
|
|
||||||
|
# Verify truncation behavior matches expectations
|
||||||
|
if test_case.expected_truncation_enabled:
|
||||||
|
# Truncation should be enabled for non-service-api calls
|
||||||
|
assert response.data.inputs_truncated
|
||||||
|
assert response.data.process_data_truncated
|
||||||
|
assert response.data.outputs_truncated
|
||||||
|
else:
|
||||||
|
# SERVICE_API should not truncate
|
||||||
|
assert not response.data.inputs_truncated
|
||||||
|
assert not response.data.process_data_truncated
|
||||||
|
assert not response.data.outputs_truncated
|
||||||
|
|
||||||
|
def test_service_api_truncator_no_op_mapping(self):
|
||||||
|
"""Test that Service API truncator doesn't truncate variable mappings."""
|
||||||
|
converter = self.create_test_converter(InvokeFrom.SERVICE_API)
|
||||||
|
|
||||||
|
# Create a test event with large data
|
||||||
|
large_value: dict[str, Any] = {
|
||||||
|
"large_string": "x" * 10000, # Large string
|
||||||
|
"large_list": list(range(2000)), # Large array
|
||||||
|
"nested_data": {"deep_nested": {"very_deep": {"value": "x" * 5000}}},
|
||||||
|
}
|
||||||
|
|
||||||
|
event = QueueNodeSucceededEvent(
|
||||||
|
node_execution_id="test_node_exec_id",
|
||||||
|
node_id="test_node",
|
||||||
|
node_type=NodeType.LLM,
|
||||||
|
start_at=naive_utc_now(),
|
||||||
|
inputs=large_value,
|
||||||
|
process_data=large_value,
|
||||||
|
outputs=large_value,
|
||||||
|
error=None,
|
||||||
|
execution_metadata=None,
|
||||||
|
in_iteration_id=None,
|
||||||
|
in_loop_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = converter.workflow_node_finish_to_stream_response(
|
||||||
|
event=event,
|
||||||
|
task_id="test_task",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify response is not None
|
||||||
|
data = response.data
|
||||||
|
assert data.inputs == large_value
|
||||||
|
assert data.process_data == large_value
|
||||||
|
assert data.outputs == large_value
|
||||||
|
# Service API should not truncate
|
||||||
|
assert data.inputs_truncated is False
|
||||||
|
assert data.process_data_truncated is False
|
||||||
|
assert data.outputs_truncated is False
|
||||||
|
|
||||||
|
def test_web_app_truncator_works_normally(self):
|
||||||
|
"""Test that web app truncator still works normally."""
|
||||||
|
converter = self.create_test_converter(InvokeFrom.WEB_APP)
|
||||||
|
|
||||||
|
# Create a test event with large data
|
||||||
|
large_value = {
|
||||||
|
"large_string": "x" * 10000, # Large string
|
||||||
|
"large_list": list(range(2000)), # Large array
|
||||||
|
}
|
||||||
|
|
||||||
|
event = QueueNodeSucceededEvent(
|
||||||
|
node_execution_id="test_node_exec_id",
|
||||||
|
node_id="test_node",
|
||||||
|
node_type=NodeType.LLM,
|
||||||
|
start_at=naive_utc_now(),
|
||||||
|
inputs=large_value,
|
||||||
|
process_data=large_value,
|
||||||
|
outputs=large_value,
|
||||||
|
error=None,
|
||||||
|
execution_metadata=None,
|
||||||
|
in_iteration_id=None,
|
||||||
|
in_loop_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = converter.workflow_node_finish_to_stream_response(
|
||||||
|
event=event,
|
||||||
|
task_id="test_task",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify response is not None
|
||||||
|
assert response is not None
|
||||||
|
|
||||||
|
# Web app should truncate
|
||||||
|
data = response.data
|
||||||
|
assert data.inputs != large_value
|
||||||
|
assert data.process_data != large_value
|
||||||
|
assert data.outputs != large_value
|
||||||
|
# The exact behavior depends on VariableTruncator implementation
|
||||||
|
# Just verify that truncation flags are present
|
||||||
|
assert data.inputs_truncated is True
|
||||||
|
assert data.process_data_truncated is True
|
||||||
|
assert data.outputs_truncated is True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_event_by_type(
|
||||||
|
type_: QueueEvent, inputs: Mapping[str, Any], process_data: Mapping[str, Any], outputs: Mapping[str, Any]
|
||||||
|
) -> QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent:
|
||||||
|
if type_ == QueueEvent.NODE_SUCCEEDED:
|
||||||
|
return QueueNodeSucceededEvent(
|
||||||
|
node_execution_id="test_node_exec_id",
|
||||||
|
node_id="test_node",
|
||||||
|
node_type=NodeType.LLM,
|
||||||
|
start_at=naive_utc_now(),
|
||||||
|
inputs=inputs,
|
||||||
|
process_data=process_data,
|
||||||
|
outputs=outputs,
|
||||||
|
error=None,
|
||||||
|
execution_metadata=None,
|
||||||
|
in_iteration_id=None,
|
||||||
|
in_loop_id=None,
|
||||||
|
)
|
||||||
|
elif type_ == QueueEvent.NODE_FAILED:
|
||||||
|
return QueueNodeFailedEvent(
|
||||||
|
node_execution_id="test_node_exec_id",
|
||||||
|
node_id="test_node",
|
||||||
|
node_type=NodeType.LLM,
|
||||||
|
start_at=naive_utc_now(),
|
||||||
|
inputs=inputs,
|
||||||
|
process_data=process_data,
|
||||||
|
outputs=outputs,
|
||||||
|
error="oops",
|
||||||
|
execution_metadata=None,
|
||||||
|
in_iteration_id=None,
|
||||||
|
in_loop_id=None,
|
||||||
|
)
|
||||||
|
elif type_ == QueueEvent.NODE_EXCEPTION:
|
||||||
|
return QueueNodeExceptionEvent(
|
||||||
|
node_execution_id="test_node_exec_id",
|
||||||
|
node_id="test_node",
|
||||||
|
node_type=NodeType.LLM,
|
||||||
|
start_at=naive_utc_now(),
|
||||||
|
inputs=inputs,
|
||||||
|
process_data=process_data,
|
||||||
|
outputs=outputs,
|
||||||
|
error="oops",
|
||||||
|
execution_metadata=None,
|
||||||
|
in_iteration_id=None,
|
||||||
|
in_loop_id=None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise Exception("unknown type.")
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"event_type",
|
||||||
|
[
|
||||||
|
QueueEvent.NODE_SUCCEEDED,
|
||||||
|
QueueEvent.NODE_FAILED,
|
||||||
|
QueueEvent.NODE_EXCEPTION,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_service_api_node_finish_event_no_truncation(self, event_type: QueueEvent):
|
||||||
|
"""Test that Service API doesn't truncate node finish events."""
|
||||||
|
converter = self.create_test_converter(InvokeFrom.SERVICE_API)
|
||||||
|
# Create test event with large data
|
||||||
|
large_inputs = {"input1": "x" * 5000, "input2": list(range(2000))}
|
||||||
|
large_process_data = {"process1": "y" * 5000, "process2": {"nested": ["z"] * 2000}}
|
||||||
|
large_outputs = {"output1": "result" * 1000, "output2": list(range(2000))}
|
||||||
|
|
||||||
|
event = TestWorkflowResponseConverterServiceApiTruncation._create_event_by_type(
|
||||||
|
event_type, large_inputs, large_process_data, large_outputs
|
||||||
|
)
|
||||||
|
|
||||||
|
response = converter.workflow_node_finish_to_stream_response(
|
||||||
|
event=event,
|
||||||
|
task_id="test_task",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify response is not None
|
||||||
|
assert response is not None
|
||||||
|
|
||||||
|
# Verify response contains full data (not truncated)
|
||||||
|
assert response.data.inputs == large_inputs
|
||||||
|
assert response.data.process_data == large_process_data
|
||||||
|
assert response.data.outputs == large_outputs
|
||||||
|
assert not response.data.inputs_truncated
|
||||||
|
assert not response.data.process_data_truncated
|
||||||
|
assert not response.data.outputs_truncated
|
||||||
|
|
||||||
|
def test_service_api_node_retry_event_no_truncation(self):
|
||||||
|
"""Test that Service API doesn't truncate node retry events."""
|
||||||
|
converter = self.create_test_converter(InvokeFrom.SERVICE_API)
|
||||||
|
|
||||||
|
# Create test event with large data
|
||||||
|
large_inputs = {"retry_input": "x" * 5000}
|
||||||
|
large_process_data = {"retry_process": "y" * 5000}
|
||||||
|
large_outputs = {"retry_output": "z" * 5000}
|
||||||
|
|
||||||
|
# First, we need to store a snapshot by simulating a start event
|
||||||
|
start_event = QueueNodeStartedEvent(
|
||||||
|
node_execution_id="test_node_exec_id",
|
||||||
|
node_id="test_node",
|
||||||
|
node_type=NodeType.LLM,
|
||||||
|
node_title="Test Node",
|
||||||
|
node_run_index=1,
|
||||||
|
start_at=naive_utc_now(),
|
||||||
|
in_iteration_id=None,
|
||||||
|
in_loop_id=None,
|
||||||
|
agent_strategy=None,
|
||||||
|
provider_type="plugin",
|
||||||
|
provider_id="test/test_plugin",
|
||||||
|
)
|
||||||
|
converter.workflow_node_start_to_stream_response(event=start_event, task_id="test_task")
|
||||||
|
|
||||||
|
# Now create retry event
|
||||||
|
event = QueueNodeRetryEvent(
|
||||||
|
node_execution_id="test_node_exec_id",
|
||||||
|
node_id="test_node",
|
||||||
|
node_type=NodeType.LLM,
|
||||||
|
node_title="Test Node",
|
||||||
|
node_run_index=1,
|
||||||
|
start_at=naive_utc_now(),
|
||||||
|
inputs=large_inputs,
|
||||||
|
process_data=large_process_data,
|
||||||
|
outputs=large_outputs,
|
||||||
|
error="Retry error",
|
||||||
|
execution_metadata=None,
|
||||||
|
in_iteration_id=None,
|
||||||
|
in_loop_id=None,
|
||||||
|
retry_index=1,
|
||||||
|
provider_type="plugin",
|
||||||
|
provider_id="test/test_plugin",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = converter.workflow_node_retry_to_stream_response(
|
||||||
|
event=event,
|
||||||
|
task_id="test_task",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify response is not None
|
||||||
|
assert response is not None
|
||||||
|
|
||||||
|
# Verify response contains full data (not truncated)
|
||||||
|
assert response.data.inputs == large_inputs
|
||||||
|
assert response.data.process_data == large_process_data
|
||||||
|
assert response.data.outputs == large_outputs
|
||||||
|
assert not response.data.inputs_truncated
|
||||||
|
assert not response.data.process_data_truncated
|
||||||
|
assert not response.data.outputs_truncated
|
||||||
|
|
||||||
|
def test_service_api_iteration_events_no_truncation(self):
|
||||||
|
"""Test that Service API doesn't truncate iteration events."""
|
||||||
|
converter = self.create_test_converter(InvokeFrom.SERVICE_API)
|
||||||
|
|
||||||
|
# Test iteration start event
|
||||||
|
large_value = {"iteration_input": ["x"] * 2000}
|
||||||
|
|
||||||
|
start_event = QueueIterationStartEvent(
|
||||||
|
node_execution_id="test_iter_exec_id",
|
||||||
|
node_id="test_iteration",
|
||||||
|
node_type=NodeType.ITERATION,
|
||||||
|
node_title="Test Iteration",
|
||||||
|
node_run_index=0,
|
||||||
|
start_at=naive_utc_now(),
|
||||||
|
inputs=large_value,
|
||||||
|
metadata={},
|
||||||
|
)
|
||||||
|
|
||||||
|
response = converter.workflow_iteration_start_to_stream_response(
|
||||||
|
task_id="test_task",
|
||||||
|
workflow_execution_id="test_workflow_exec_id",
|
||||||
|
event=start_event,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response is not None
|
||||||
|
assert response.data.inputs == large_value
|
||||||
|
assert not response.data.inputs_truncated
|
||||||
|
|
||||||
|
def test_service_api_loop_events_no_truncation(self):
|
||||||
|
"""Test that Service API doesn't truncate loop events."""
|
||||||
|
converter = self.create_test_converter(InvokeFrom.SERVICE_API)
|
||||||
|
|
||||||
|
# Test loop start event
|
||||||
|
large_inputs = {"loop_input": ["x"] * 2000}
|
||||||
|
|
||||||
|
start_event = QueueLoopStartEvent(
|
||||||
|
node_execution_id="test_loop_exec_id",
|
||||||
|
node_id="test_loop",
|
||||||
|
node_type=NodeType.LOOP,
|
||||||
|
node_title="Test Loop",
|
||||||
|
start_at=naive_utc_now(),
|
||||||
|
inputs=large_inputs,
|
||||||
|
metadata={},
|
||||||
|
node_run_index=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = converter.workflow_loop_start_to_stream_response(
|
||||||
|
task_id="test_task",
|
||||||
|
workflow_execution_id="test_workflow_exec_id",
|
||||||
|
event=start_event,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response is not None
|
||||||
|
assert response.data.inputs == large_inputs
|
||||||
|
assert not response.data.inputs_truncated
|
||||||
|
|
||||||
|
def test_web_app_node_finish_event_truncation_works(self):
|
||||||
|
"""Test that web app still truncates node finish events."""
|
||||||
|
converter = self.create_test_converter(InvokeFrom.WEB_APP)
|
||||||
|
|
||||||
|
# Create test event with large data that should be truncated
|
||||||
|
large_inputs = {"input1": ["x"] * 2000}
|
||||||
|
large_process_data = {"process1": ["y"] * 2000}
|
||||||
|
large_outputs = {"output1": ["z"] * 2000}
|
||||||
|
|
||||||
|
event = QueueNodeSucceededEvent(
|
||||||
|
node_execution_id="test_node_exec_id",
|
||||||
|
node_id="test_node",
|
||||||
|
node_type=NodeType.LLM,
|
||||||
|
start_at=naive_utc_now(),
|
||||||
|
inputs=large_inputs,
|
||||||
|
process_data=large_process_data,
|
||||||
|
outputs=large_outputs,
|
||||||
|
error=None,
|
||||||
|
execution_metadata=None,
|
||||||
|
in_iteration_id=None,
|
||||||
|
in_loop_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = converter.workflow_node_finish_to_stream_response(
|
||||||
|
event=event,
|
||||||
|
task_id="test_task",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify response is not None
|
||||||
|
assert response is not None
|
||||||
|
|
||||||
|
# Verify response contains truncated data
|
||||||
|
# The exact behavior depends on VariableTruncator implementation
|
||||||
|
# Just verify truncation flags are set correctly (may or may not be truncated depending on size)
|
||||||
|
# At minimum, the truncation mechanism should work
|
||||||
|
assert isinstance(response.data.inputs, dict)
|
||||||
|
assert response.data.inputs_truncated
|
||||||
|
assert isinstance(response.data.process_data, dict)
|
||||||
|
assert response.data.process_data_truncated
|
||||||
|
assert isinstance(response.data.outputs, dict)
|
||||||
|
assert response.data.outputs_truncated
|
||||||
0
api/tests/unit_tests/core/mcp/__init__.py
Normal file
0
api/tests/unit_tests/core/mcp/__init__.py
Normal file
0
api/tests/unit_tests/core/mcp/auth/__init__.py
Normal file
0
api/tests/unit_tests/core/mcp/auth/__init__.py
Normal file
740
api/tests/unit_tests/core/mcp/auth/test_auth_flow.py
Normal file
740
api/tests/unit_tests/core/mcp/auth/test_auth_flow.py
Normal file
@ -0,0 +1,740 @@
|
|||||||
|
"""Unit tests for MCP OAuth authentication flow."""
|
||||||
|
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.entities.mcp_provider import MCPProviderEntity
|
||||||
|
from core.mcp.auth.auth_flow import (
|
||||||
|
OAUTH_STATE_EXPIRY_SECONDS,
|
||||||
|
OAUTH_STATE_REDIS_KEY_PREFIX,
|
||||||
|
OAuthCallbackState,
|
||||||
|
_create_secure_redis_state,
|
||||||
|
_retrieve_redis_state,
|
||||||
|
auth,
|
||||||
|
check_support_resource_discovery,
|
||||||
|
discover_oauth_metadata,
|
||||||
|
exchange_authorization,
|
||||||
|
generate_pkce_challenge,
|
||||||
|
handle_callback,
|
||||||
|
refresh_authorization,
|
||||||
|
register_client,
|
||||||
|
start_authorization,
|
||||||
|
)
|
||||||
|
from core.mcp.entities import AuthActionType, AuthResult
|
||||||
|
from core.mcp.types import (
|
||||||
|
OAuthClientInformation,
|
||||||
|
OAuthClientInformationFull,
|
||||||
|
OAuthClientMetadata,
|
||||||
|
OAuthMetadata,
|
||||||
|
OAuthTokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPKCEGeneration:
|
||||||
|
"""Test PKCE challenge generation."""
|
||||||
|
|
||||||
|
def test_generate_pkce_challenge(self):
|
||||||
|
"""Test PKCE challenge and verifier generation."""
|
||||||
|
code_verifier, code_challenge = generate_pkce_challenge()
|
||||||
|
|
||||||
|
# Verify format - should be URL-safe base64 without padding
|
||||||
|
assert "=" not in code_verifier
|
||||||
|
assert "+" not in code_verifier
|
||||||
|
assert "/" not in code_verifier
|
||||||
|
assert "=" not in code_challenge
|
||||||
|
assert "+" not in code_challenge
|
||||||
|
assert "/" not in code_challenge
|
||||||
|
|
||||||
|
# Verify length
|
||||||
|
assert len(code_verifier) > 40 # Should be around 54 characters
|
||||||
|
assert len(code_challenge) > 40 # Should be around 43 characters
|
||||||
|
|
||||||
|
def test_generate_pkce_challenge_uniqueness(self):
|
||||||
|
"""Test that PKCE generation produces unique values."""
|
||||||
|
results = set()
|
||||||
|
for _ in range(10):
|
||||||
|
code_verifier, code_challenge = generate_pkce_challenge()
|
||||||
|
results.add((code_verifier, code_challenge))
|
||||||
|
|
||||||
|
# All should be unique
|
||||||
|
assert len(results) == 10
|
||||||
|
|
||||||
|
|
||||||
|
class TestRedisStateManagement:
|
||||||
|
"""Test Redis state management functions."""
|
||||||
|
|
||||||
|
@patch("core.mcp.auth.auth_flow.redis_client")
|
||||||
|
def test_create_secure_redis_state(self, mock_redis):
|
||||||
|
"""Test creating secure Redis state."""
|
||||||
|
state_data = OAuthCallbackState(
|
||||||
|
provider_id="test-provider",
|
||||||
|
tenant_id="test-tenant",
|
||||||
|
server_url="https://example.com",
|
||||||
|
metadata=None,
|
||||||
|
client_information=OAuthClientInformation(client_id="test-client"),
|
||||||
|
code_verifier="test-verifier",
|
||||||
|
redirect_uri="https://redirect.example.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
state_key = _create_secure_redis_state(state_data)
|
||||||
|
|
||||||
|
# Verify state key format
|
||||||
|
assert len(state_key) > 20 # Should be a secure random token
|
||||||
|
|
||||||
|
# Verify Redis call
|
||||||
|
mock_redis.setex.assert_called_once()
|
||||||
|
call_args = mock_redis.setex.call_args
|
||||||
|
assert call_args[0][0].startswith(OAUTH_STATE_REDIS_KEY_PREFIX)
|
||||||
|
assert call_args[0][1] == OAUTH_STATE_EXPIRY_SECONDS
|
||||||
|
assert state_data.model_dump_json() in call_args[0][2]
|
||||||
|
|
||||||
|
@patch("core.mcp.auth.auth_flow.redis_client")
|
||||||
|
def test_retrieve_redis_state_success(self, mock_redis):
|
||||||
|
"""Test retrieving state from Redis."""
|
||||||
|
state_data = OAuthCallbackState(
|
||||||
|
provider_id="test-provider",
|
||||||
|
tenant_id="test-tenant",
|
||||||
|
server_url="https://example.com",
|
||||||
|
metadata=None,
|
||||||
|
client_information=OAuthClientInformation(client_id="test-client"),
|
||||||
|
code_verifier="test-verifier",
|
||||||
|
redirect_uri="https://redirect.example.com",
|
||||||
|
)
|
||||||
|
mock_redis.get.return_value = state_data.model_dump_json()
|
||||||
|
|
||||||
|
result = _retrieve_redis_state("test-state-key")
|
||||||
|
|
||||||
|
# Verify result
|
||||||
|
assert result.provider_id == "test-provider"
|
||||||
|
assert result.tenant_id == "test-tenant"
|
||||||
|
assert result.server_url == "https://example.com"
|
||||||
|
|
||||||
|
# Verify Redis calls
|
||||||
|
mock_redis.get.assert_called_once_with(f"{OAUTH_STATE_REDIS_KEY_PREFIX}test-state-key")
|
||||||
|
mock_redis.delete.assert_called_once_with(f"{OAUTH_STATE_REDIS_KEY_PREFIX}test-state-key")
|
||||||
|
|
||||||
|
@patch("core.mcp.auth.auth_flow.redis_client")
|
||||||
|
def test_retrieve_redis_state_not_found(self, mock_redis):
|
||||||
|
"""Test retrieving non-existent state from Redis."""
|
||||||
|
mock_redis.get.return_value = None
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
_retrieve_redis_state("nonexistent-key")
|
||||||
|
|
||||||
|
assert "State parameter has expired or does not exist" in str(exc_info.value)
|
||||||
|
|
||||||
|
@patch("core.mcp.auth.auth_flow.redis_client")
|
||||||
|
def test_retrieve_redis_state_invalid_json(self, mock_redis):
|
||||||
|
"""Test retrieving invalid JSON state from Redis."""
|
||||||
|
mock_redis.get.return_value = '{"invalid": json}'
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
_retrieve_redis_state("test-key")
|
||||||
|
|
||||||
|
assert "Invalid state parameter" in str(exc_info.value)
|
||||||
|
# State should still be deleted
|
||||||
|
mock_redis.delete.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
class TestOAuthDiscovery:
|
||||||
|
"""Test OAuth discovery functions."""
|
||||||
|
|
||||||
|
@patch("core.helper.ssrf_proxy.get")
|
||||||
|
def test_check_support_resource_discovery_success(self, mock_get):
|
||||||
|
"""Test successful resource discovery check."""
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {"authorization_server_url": ["https://auth.example.com"]}
|
||||||
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
|
supported, auth_url = check_support_resource_discovery("https://api.example.com/endpoint")
|
||||||
|
|
||||||
|
assert supported is True
|
||||||
|
assert auth_url == "https://auth.example.com"
|
||||||
|
mock_get.assert_called_once_with(
|
||||||
|
"https://api.example.com/.well-known/oauth-protected-resource",
|
||||||
|
headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"},
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("core.helper.ssrf_proxy.get")
|
||||||
|
def test_check_support_resource_discovery_not_supported(self, mock_get):
|
||||||
|
"""Test resource discovery not supported."""
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.status_code = 404
|
||||||
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
|
supported, auth_url = check_support_resource_discovery("https://api.example.com")
|
||||||
|
|
||||||
|
assert supported is False
|
||||||
|
assert auth_url == ""
|
||||||
|
|
||||||
|
@patch("core.helper.ssrf_proxy.get")
|
||||||
|
def test_check_support_resource_discovery_with_query_fragment(self, mock_get):
|
||||||
|
"""Test resource discovery with query and fragment."""
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {"authorization_server_url": ["https://auth.example.com"]}
|
||||||
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
|
supported, auth_url = check_support_resource_discovery("https://api.example.com/path?query=1#fragment")
|
||||||
|
|
||||||
|
assert supported is True
|
||||||
|
assert auth_url == "https://auth.example.com"
|
||||||
|
mock_get.assert_called_once_with(
|
||||||
|
"https://api.example.com/.well-known/oauth-protected-resource?query=1#fragment",
|
||||||
|
headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"},
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("core.helper.ssrf_proxy.get")
|
||||||
|
def test_discover_oauth_metadata_with_resource_discovery(self, mock_get):
|
||||||
|
"""Test OAuth metadata discovery with resource discovery support."""
|
||||||
|
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
|
||||||
|
mock_check.return_value = (True, "https://auth.example.com")
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.is_success = True
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"authorization_endpoint": "https://auth.example.com/authorize",
|
||||||
|
"token_endpoint": "https://auth.example.com/token",
|
||||||
|
"response_types_supported": ["code"],
|
||||||
|
}
|
||||||
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
|
metadata = discover_oauth_metadata("https://api.example.com")
|
||||||
|
|
||||||
|
assert metadata is not None
|
||||||
|
assert metadata.authorization_endpoint == "https://auth.example.com/authorize"
|
||||||
|
assert metadata.token_endpoint == "https://auth.example.com/token"
|
||||||
|
mock_get.assert_called_once_with(
|
||||||
|
"https://auth.example.com/.well-known/oauth-authorization-server",
|
||||||
|
headers={"MCP-Protocol-Version": "2025-03-26"},
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("core.helper.ssrf_proxy.get")
|
||||||
|
def test_discover_oauth_metadata_without_resource_discovery(self, mock_get):
|
||||||
|
"""Test OAuth metadata discovery without resource discovery."""
|
||||||
|
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
|
||||||
|
mock_check.return_value = (False, "")
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.is_success = True
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"authorization_endpoint": "https://api.example.com/oauth/authorize",
|
||||||
|
"token_endpoint": "https://api.example.com/oauth/token",
|
||||||
|
"response_types_supported": ["code"],
|
||||||
|
}
|
||||||
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
|
metadata = discover_oauth_metadata("https://api.example.com")
|
||||||
|
|
||||||
|
assert metadata is not None
|
||||||
|
assert metadata.authorization_endpoint == "https://api.example.com/oauth/authorize"
|
||||||
|
mock_get.assert_called_once_with(
|
||||||
|
"https://api.example.com/.well-known/oauth-authorization-server",
|
||||||
|
headers={"MCP-Protocol-Version": "2025-03-26"},
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("core.helper.ssrf_proxy.get")
|
||||||
|
def test_discover_oauth_metadata_not_found(self, mock_get):
|
||||||
|
"""Test OAuth metadata discovery when not found."""
|
||||||
|
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
|
||||||
|
mock_check.return_value = (False, "")
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.status_code = 404
|
||||||
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
|
metadata = discover_oauth_metadata("https://api.example.com")
|
||||||
|
|
||||||
|
assert metadata is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthorizationFlow:
|
||||||
|
"""Test authorization flow functions."""
|
||||||
|
|
||||||
|
@patch("core.mcp.auth.auth_flow._create_secure_redis_state")
|
||||||
|
def test_start_authorization_with_metadata(self, mock_create_state):
|
||||||
|
"""Test starting authorization with metadata."""
|
||||||
|
mock_create_state.return_value = "secure-state-key"
|
||||||
|
|
||||||
|
metadata = OAuthMetadata(
|
||||||
|
authorization_endpoint="https://auth.example.com/authorize",
|
||||||
|
token_endpoint="https://auth.example.com/token",
|
||||||
|
response_types_supported=["code"],
|
||||||
|
code_challenge_methods_supported=["S256"],
|
||||||
|
)
|
||||||
|
client_info = OAuthClientInformation(client_id="test-client-id")
|
||||||
|
|
||||||
|
auth_url, code_verifier = start_authorization(
|
||||||
|
"https://api.example.com",
|
||||||
|
metadata,
|
||||||
|
client_info,
|
||||||
|
"https://redirect.example.com",
|
||||||
|
"provider-id",
|
||||||
|
"tenant-id",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify URL format
|
||||||
|
assert auth_url.startswith("https://auth.example.com/authorize?")
|
||||||
|
assert "response_type=code" in auth_url
|
||||||
|
assert "client_id=test-client-id" in auth_url
|
||||||
|
assert "code_challenge=" in auth_url
|
||||||
|
assert "code_challenge_method=S256" in auth_url
|
||||||
|
assert "redirect_uri=https%3A%2F%2Fredirect.example.com" in auth_url
|
||||||
|
assert "state=secure-state-key" in auth_url
|
||||||
|
|
||||||
|
# Verify code verifier
|
||||||
|
assert len(code_verifier) > 40
|
||||||
|
|
||||||
|
# Verify state was stored
|
||||||
|
mock_create_state.assert_called_once()
|
||||||
|
state_data = mock_create_state.call_args[0][0]
|
||||||
|
assert state_data.provider_id == "provider-id"
|
||||||
|
assert state_data.tenant_id == "tenant-id"
|
||||||
|
assert state_data.code_verifier == code_verifier
|
||||||
|
|
||||||
|
def test_start_authorization_without_metadata(self):
|
||||||
|
"""Test starting authorization without metadata."""
|
||||||
|
with patch("core.mcp.auth.auth_flow._create_secure_redis_state") as mock_create_state:
|
||||||
|
mock_create_state.return_value = "secure-state-key"
|
||||||
|
|
||||||
|
client_info = OAuthClientInformation(client_id="test-client-id")
|
||||||
|
|
||||||
|
auth_url, code_verifier = start_authorization(
|
||||||
|
"https://api.example.com",
|
||||||
|
None,
|
||||||
|
client_info,
|
||||||
|
"https://redirect.example.com",
|
||||||
|
"provider-id",
|
||||||
|
"tenant-id",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should use default authorization endpoint
|
||||||
|
assert auth_url.startswith("https://api.example.com/authorize?")
|
||||||
|
|
||||||
|
def test_start_authorization_invalid_metadata(self):
|
||||||
|
"""Test starting authorization with invalid metadata."""
|
||||||
|
metadata = OAuthMetadata(
|
||||||
|
authorization_endpoint="https://auth.example.com/authorize",
|
||||||
|
token_endpoint="https://auth.example.com/token",
|
||||||
|
response_types_supported=["token"], # No "code" support
|
||||||
|
code_challenge_methods_supported=["plain"], # No "S256" support
|
||||||
|
)
|
||||||
|
client_info = OAuthClientInformation(client_id="test-client-id")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
start_authorization(
|
||||||
|
"https://api.example.com",
|
||||||
|
metadata,
|
||||||
|
client_info,
|
||||||
|
"https://redirect.example.com",
|
||||||
|
"provider-id",
|
||||||
|
"tenant-id",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "does not support response type code" in str(exc_info.value)
|
||||||
|
|
||||||
|
@patch("core.helper.ssrf_proxy.post")
|
||||||
|
def test_exchange_authorization_success(self, mock_post):
|
||||||
|
"""Test successful authorization code exchange."""
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.is_success = True
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"access_token": "new-access-token",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"expires_in": 3600,
|
||||||
|
"refresh_token": "new-refresh-token",
|
||||||
|
}
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
metadata = OAuthMetadata(
|
||||||
|
authorization_endpoint="https://auth.example.com/authorize",
|
||||||
|
token_endpoint="https://auth.example.com/token",
|
||||||
|
response_types_supported=["code"],
|
||||||
|
grant_types_supported=["authorization_code"],
|
||||||
|
)
|
||||||
|
client_info = OAuthClientInformation(client_id="test-client-id", client_secret="test-secret")
|
||||||
|
|
||||||
|
tokens = exchange_authorization(
|
||||||
|
"https://api.example.com",
|
||||||
|
metadata,
|
||||||
|
client_info,
|
||||||
|
"auth-code-123",
|
||||||
|
"code-verifier-xyz",
|
||||||
|
"https://redirect.example.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert tokens.access_token == "new-access-token"
|
||||||
|
assert tokens.token_type == "Bearer"
|
||||||
|
assert tokens.expires_in == 3600
|
||||||
|
assert tokens.refresh_token == "new-refresh-token"
|
||||||
|
|
||||||
|
# Verify request
|
||||||
|
mock_post.assert_called_once_with(
|
||||||
|
"https://auth.example.com/token",
|
||||||
|
data={
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"client_id": "test-client-id",
|
||||||
|
"client_secret": "test-secret",
|
||||||
|
"code": "auth-code-123",
|
||||||
|
"code_verifier": "code-verifier-xyz",
|
||||||
|
"redirect_uri": "https://redirect.example.com",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("core.helper.ssrf_proxy.post")
|
||||||
|
def test_exchange_authorization_failure(self, mock_post):
|
||||||
|
"""Test failed authorization code exchange."""
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.is_success = False
|
||||||
|
mock_response.status_code = 400
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
client_info = OAuthClientInformation(client_id="test-client-id")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
exchange_authorization(
|
||||||
|
"https://api.example.com",
|
||||||
|
None,
|
||||||
|
client_info,
|
||||||
|
"invalid-code",
|
||||||
|
"code-verifier",
|
||||||
|
"https://redirect.example.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "Token exchange failed: HTTP 400" in str(exc_info.value)
|
||||||
|
|
||||||
|
@patch("core.helper.ssrf_proxy.post")
|
||||||
|
def test_refresh_authorization_success(self, mock_post):
|
||||||
|
"""Test successful token refresh."""
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.is_success = True
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"access_token": "refreshed-access-token",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"expires_in": 3600,
|
||||||
|
"refresh_token": "new-refresh-token",
|
||||||
|
}
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
metadata = OAuthMetadata(
|
||||||
|
authorization_endpoint="https://auth.example.com/authorize",
|
||||||
|
token_endpoint="https://auth.example.com/token",
|
||||||
|
response_types_supported=["code"],
|
||||||
|
grant_types_supported=["refresh_token"],
|
||||||
|
)
|
||||||
|
client_info = OAuthClientInformation(client_id="test-client-id")
|
||||||
|
|
||||||
|
tokens = refresh_authorization("https://api.example.com", metadata, client_info, "old-refresh-token")
|
||||||
|
|
||||||
|
assert tokens.access_token == "refreshed-access-token"
|
||||||
|
assert tokens.refresh_token == "new-refresh-token"
|
||||||
|
|
||||||
|
# Verify request
|
||||||
|
mock_post.assert_called_once_with(
|
||||||
|
"https://auth.example.com/token",
|
||||||
|
data={
|
||||||
|
"grant_type": "refresh_token",
|
||||||
|
"client_id": "test-client-id",
|
||||||
|
"refresh_token": "old-refresh-token",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("core.helper.ssrf_proxy.post")
|
||||||
|
def test_register_client_success(self, mock_post):
|
||||||
|
"""Test successful client registration."""
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.is_success = True
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"client_id": "new-client-id",
|
||||||
|
"client_secret": "new-client-secret",
|
||||||
|
"client_name": "Dify",
|
||||||
|
"redirect_uris": ["https://redirect.example.com"],
|
||||||
|
}
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
metadata = OAuthMetadata(
|
||||||
|
authorization_endpoint="https://auth.example.com/authorize",
|
||||||
|
token_endpoint="https://auth.example.com/token",
|
||||||
|
registration_endpoint="https://auth.example.com/register",
|
||||||
|
response_types_supported=["code"],
|
||||||
|
)
|
||||||
|
client_metadata = OAuthClientMetadata(
|
||||||
|
client_name="Dify",
|
||||||
|
redirect_uris=["https://redirect.example.com"],
|
||||||
|
grant_types=["authorization_code"],
|
||||||
|
response_types=["code"],
|
||||||
|
)
|
||||||
|
|
||||||
|
client_info = register_client("https://api.example.com", metadata, client_metadata)
|
||||||
|
|
||||||
|
assert isinstance(client_info, OAuthClientInformationFull)
|
||||||
|
assert client_info.client_id == "new-client-id"
|
||||||
|
assert client_info.client_secret == "new-client-secret"
|
||||||
|
|
||||||
|
# Verify request
|
||||||
|
mock_post.assert_called_once_with(
|
||||||
|
"https://auth.example.com/register",
|
||||||
|
json=client_metadata.model_dump(),
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_register_client_no_endpoint(self):
|
||||||
|
"""Test client registration when no endpoint available."""
|
||||||
|
metadata = OAuthMetadata(
|
||||||
|
authorization_endpoint="https://auth.example.com/authorize",
|
||||||
|
token_endpoint="https://auth.example.com/token",
|
||||||
|
registration_endpoint=None,
|
||||||
|
response_types_supported=["code"],
|
||||||
|
)
|
||||||
|
client_metadata = OAuthClientMetadata(client_name="Dify", redirect_uris=["https://redirect.example.com"])
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
register_client("https://api.example.com", metadata, client_metadata)
|
||||||
|
|
||||||
|
assert "does not support dynamic client registration" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCallbackHandling:
|
||||||
|
"""Test OAuth callback handling."""
|
||||||
|
|
||||||
|
@patch("core.mcp.auth.auth_flow._retrieve_redis_state")
|
||||||
|
@patch("core.mcp.auth.auth_flow.exchange_authorization")
|
||||||
|
def test_handle_callback_success(self, mock_exchange, mock_retrieve_state):
|
||||||
|
"""Test successful callback handling."""
|
||||||
|
# Setup state
|
||||||
|
state_data = OAuthCallbackState(
|
||||||
|
provider_id="test-provider",
|
||||||
|
tenant_id="test-tenant",
|
||||||
|
server_url="https://api.example.com",
|
||||||
|
metadata=None,
|
||||||
|
client_information=OAuthClientInformation(client_id="test-client"),
|
||||||
|
code_verifier="test-verifier",
|
||||||
|
redirect_uri="https://redirect.example.com",
|
||||||
|
)
|
||||||
|
mock_retrieve_state.return_value = state_data
|
||||||
|
|
||||||
|
# Setup token exchange
|
||||||
|
tokens = OAuthTokens(
|
||||||
|
access_token="new-token",
|
||||||
|
token_type="Bearer",
|
||||||
|
expires_in=3600,
|
||||||
|
)
|
||||||
|
mock_exchange.return_value = tokens
|
||||||
|
|
||||||
|
# Setup service
|
||||||
|
mock_service = Mock()
|
||||||
|
|
||||||
|
state_result, tokens_result = handle_callback("state-key", "auth-code")
|
||||||
|
|
||||||
|
assert state_result == state_data
|
||||||
|
assert tokens_result == tokens
|
||||||
|
|
||||||
|
# Verify calls
|
||||||
|
mock_retrieve_state.assert_called_once_with("state-key")
|
||||||
|
mock_exchange.assert_called_once_with(
|
||||||
|
"https://api.example.com",
|
||||||
|
None,
|
||||||
|
state_data.client_information,
|
||||||
|
"auth-code",
|
||||||
|
"test-verifier",
|
||||||
|
"https://redirect.example.com",
|
||||||
|
)
|
||||||
|
# Note: handle_callback no longer saves tokens directly, it just returns them
|
||||||
|
# The caller (e.g., controller) is responsible for saving via execute_auth_actions
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthOrchestration:
|
||||||
|
"""Test the main auth orchestration function."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_provider(self):
|
||||||
|
"""Create a mock provider entity."""
|
||||||
|
provider = Mock(spec=MCPProviderEntity)
|
||||||
|
provider.id = "provider-id"
|
||||||
|
provider.tenant_id = "tenant-id"
|
||||||
|
provider.decrypt_server_url.return_value = "https://api.example.com"
|
||||||
|
provider.client_metadata = OAuthClientMetadata(
|
||||||
|
client_name="Dify",
|
||||||
|
redirect_uris=["https://redirect.example.com"],
|
||||||
|
)
|
||||||
|
provider.redirect_url = "https://redirect.example.com"
|
||||||
|
provider.retrieve_client_information.return_value = None
|
||||||
|
provider.retrieve_tokens.return_value = None
|
||||||
|
return provider
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_service(self):
|
||||||
|
"""Create a mock MCP service."""
|
||||||
|
return Mock()
|
||||||
|
|
||||||
|
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
|
||||||
|
@patch("core.mcp.auth.auth_flow.register_client")
|
||||||
|
@patch("core.mcp.auth.auth_flow.start_authorization")
|
||||||
|
def test_auth_new_registration(self, mock_start_auth, mock_register, mock_discover, mock_provider, mock_service):
|
||||||
|
"""Test auth flow for new client registration."""
|
||||||
|
# Setup
|
||||||
|
mock_discover.return_value = OAuthMetadata(
|
||||||
|
authorization_endpoint="https://auth.example.com/authorize",
|
||||||
|
token_endpoint="https://auth.example.com/token",
|
||||||
|
response_types_supported=["code"],
|
||||||
|
grant_types_supported=["authorization_code"],
|
||||||
|
)
|
||||||
|
mock_register.return_value = OAuthClientInformationFull(
|
||||||
|
client_id="new-client-id",
|
||||||
|
client_name="Dify",
|
||||||
|
redirect_uris=["https://redirect.example.com"],
|
||||||
|
)
|
||||||
|
mock_start_auth.return_value = ("https://auth.example.com/authorize?...", "code-verifier")
|
||||||
|
|
||||||
|
result = auth(mock_provider)
|
||||||
|
|
||||||
|
# auth() now returns AuthResult
|
||||||
|
assert isinstance(result, AuthResult)
|
||||||
|
assert result.response == {"authorization_url": "https://auth.example.com/authorize?..."}
|
||||||
|
|
||||||
|
# Verify that the result contains the correct actions
|
||||||
|
assert len(result.actions) == 2
|
||||||
|
# Check for SAVE_CLIENT_INFO action
|
||||||
|
client_info_action = next(a for a in result.actions if a.action_type == AuthActionType.SAVE_CLIENT_INFO)
|
||||||
|
assert client_info_action.data == {"client_information": mock_register.return_value.model_dump()}
|
||||||
|
assert client_info_action.provider_id == "provider-id"
|
||||||
|
assert client_info_action.tenant_id == "tenant-id"
|
||||||
|
|
||||||
|
# Check for SAVE_CODE_VERIFIER action
|
||||||
|
verifier_action = next(a for a in result.actions if a.action_type == AuthActionType.SAVE_CODE_VERIFIER)
|
||||||
|
assert verifier_action.data == {"code_verifier": "code-verifier"}
|
||||||
|
assert verifier_action.provider_id == "provider-id"
|
||||||
|
assert verifier_action.tenant_id == "tenant-id"
|
||||||
|
|
||||||
|
# Verify calls
|
||||||
|
mock_register.assert_called_once()
|
||||||
|
|
||||||
|
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
|
||||||
|
@patch("core.mcp.auth.auth_flow._retrieve_redis_state")
|
||||||
|
@patch("core.mcp.auth.auth_flow.exchange_authorization")
|
||||||
|
def test_auth_exchange_code(self, mock_exchange, mock_retrieve_state, mock_discover, mock_provider, mock_service):
|
||||||
|
"""Test auth flow for exchanging authorization code."""
|
||||||
|
# Setup metadata discovery
|
||||||
|
mock_discover.return_value = OAuthMetadata(
|
||||||
|
authorization_endpoint="https://auth.example.com/authorize",
|
||||||
|
token_endpoint="https://auth.example.com/token",
|
||||||
|
response_types_supported=["code"],
|
||||||
|
grant_types_supported=["authorization_code"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setup existing client
|
||||||
|
mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")
|
||||||
|
|
||||||
|
# Setup state retrieval
|
||||||
|
state_data = OAuthCallbackState(
|
||||||
|
provider_id="provider-id",
|
||||||
|
tenant_id="tenant-id",
|
||||||
|
server_url="https://api.example.com",
|
||||||
|
metadata=None,
|
||||||
|
client_information=OAuthClientInformation(client_id="existing-client"),
|
||||||
|
code_verifier="test-verifier",
|
||||||
|
redirect_uri="https://redirect.example.com",
|
||||||
|
)
|
||||||
|
mock_retrieve_state.return_value = state_data
|
||||||
|
|
||||||
|
# Setup token exchange
|
||||||
|
tokens = OAuthTokens(access_token="new-token", token_type="Bearer", expires_in=3600)
|
||||||
|
mock_exchange.return_value = tokens
|
||||||
|
|
||||||
|
result = auth(mock_provider, authorization_code="auth-code", state_param="state-key")
|
||||||
|
|
||||||
|
# auth() now returns AuthResult, not a dict
|
||||||
|
assert isinstance(result, AuthResult)
|
||||||
|
assert result.response == {"result": "success"}
|
||||||
|
|
||||||
|
# Verify that the result contains the correct action
|
||||||
|
assert len(result.actions) == 1
|
||||||
|
assert result.actions[0].action_type == AuthActionType.SAVE_TOKENS
|
||||||
|
assert result.actions[0].data == tokens.model_dump()
|
||||||
|
assert result.actions[0].provider_id == "provider-id"
|
||||||
|
assert result.actions[0].tenant_id == "tenant-id"
|
||||||
|
|
||||||
|
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
|
||||||
|
def test_auth_exchange_code_without_state(self, mock_discover, mock_provider, mock_service):
|
||||||
|
"""Test auth flow fails when exchanging code without state."""
|
||||||
|
# Setup metadata discovery
|
||||||
|
mock_discover.return_value = OAuthMetadata(
|
||||||
|
authorization_endpoint="https://auth.example.com/authorize",
|
||||||
|
token_endpoint="https://auth.example.com/token",
|
||||||
|
response_types_supported=["code"],
|
||||||
|
grant_types_supported=["authorization_code"],
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
auth(mock_provider, authorization_code="auth-code")
|
||||||
|
|
||||||
|
assert "State parameter is required" in str(exc_info.value)
|
||||||
|
|
||||||
|
@patch("core.mcp.auth.auth_flow.refresh_authorization")
|
||||||
|
def test_auth_refresh_token(self, mock_refresh, mock_provider, mock_service):
|
||||||
|
"""Test auth flow for refreshing tokens."""
|
||||||
|
# Setup existing client and tokens
|
||||||
|
mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")
|
||||||
|
mock_provider.retrieve_tokens.return_value = OAuthTokens(
|
||||||
|
access_token="old-token",
|
||||||
|
token_type="Bearer",
|
||||||
|
expires_in=0,
|
||||||
|
refresh_token="refresh-token",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setup refresh
|
||||||
|
new_tokens = OAuthTokens(
|
||||||
|
access_token="refreshed-token",
|
||||||
|
token_type="Bearer",
|
||||||
|
expires_in=3600,
|
||||||
|
refresh_token="new-refresh-token",
|
||||||
|
)
|
||||||
|
mock_refresh.return_value = new_tokens
|
||||||
|
|
||||||
|
with patch("core.mcp.auth.auth_flow.discover_oauth_metadata") as mock_discover:
|
||||||
|
mock_discover.return_value = OAuthMetadata(
|
||||||
|
authorization_endpoint="https://auth.example.com/authorize",
|
||||||
|
token_endpoint="https://auth.example.com/token",
|
||||||
|
response_types_supported=["code"],
|
||||||
|
grant_types_supported=["authorization_code"],
|
||||||
|
)
|
||||||
|
|
||||||
|
result = auth(mock_provider)
|
||||||
|
|
||||||
|
# auth() now returns AuthResult
|
||||||
|
assert isinstance(result, AuthResult)
|
||||||
|
assert result.response == {"result": "success"}
|
||||||
|
|
||||||
|
# Verify that the result contains the correct action
|
||||||
|
assert len(result.actions) == 1
|
||||||
|
assert result.actions[0].action_type == AuthActionType.SAVE_TOKENS
|
||||||
|
assert result.actions[0].data == new_tokens.model_dump()
|
||||||
|
assert result.actions[0].provider_id == "provider-id"
|
||||||
|
assert result.actions[0].tenant_id == "tenant-id"
|
||||||
|
|
||||||
|
# Verify refresh was called
|
||||||
|
mock_refresh.assert_called_once()
|
||||||
|
|
||||||
|
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
|
||||||
|
def test_auth_registration_fails_with_code(self, mock_discover, mock_provider, mock_service):
|
||||||
|
"""Test auth fails when no client info exists but code is provided."""
|
||||||
|
# Setup metadata discovery
|
||||||
|
mock_discover.return_value = OAuthMetadata(
|
||||||
|
authorization_endpoint="https://auth.example.com/authorize",
|
||||||
|
token_endpoint="https://auth.example.com/token",
|
||||||
|
response_types_supported=["code"],
|
||||||
|
grant_types_supported=["authorization_code"],
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_provider.retrieve_client_information.return_value = None
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
auth(mock_provider, authorization_code="auth-code")
|
||||||
|
|
||||||
|
assert "Existing OAuth client information is required" in str(exc_info.value)
|
||||||
239
api/tests/unit_tests/core/mcp/test_entities.py
Normal file
239
api/tests/unit_tests/core/mcp/test_entities.py
Normal file
@ -0,0 +1,239 @@
|
|||||||
|
"""Unit tests for MCP entities module."""
|
||||||
|
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from core.mcp.entities import (
|
||||||
|
SUPPORTED_PROTOCOL_VERSIONS,
|
||||||
|
LifespanContextT,
|
||||||
|
RequestContext,
|
||||||
|
SessionT,
|
||||||
|
)
|
||||||
|
from core.mcp.session.base_session import BaseSession
|
||||||
|
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestParams
|
||||||
|
|
||||||
|
|
||||||
|
class TestProtocolVersions:
|
||||||
|
"""Test protocol version constants."""
|
||||||
|
|
||||||
|
def test_supported_protocol_versions(self):
|
||||||
|
"""Test supported protocol versions list."""
|
||||||
|
assert isinstance(SUPPORTED_PROTOCOL_VERSIONS, list)
|
||||||
|
assert len(SUPPORTED_PROTOCOL_VERSIONS) >= 3
|
||||||
|
assert "2024-11-05" in SUPPORTED_PROTOCOL_VERSIONS
|
||||||
|
assert "2025-03-26" in SUPPORTED_PROTOCOL_VERSIONS
|
||||||
|
assert LATEST_PROTOCOL_VERSION in SUPPORTED_PROTOCOL_VERSIONS
|
||||||
|
|
||||||
|
def test_latest_protocol_version_is_supported(self):
|
||||||
|
"""Test that latest protocol version is in supported versions."""
|
||||||
|
assert LATEST_PROTOCOL_VERSION in SUPPORTED_PROTOCOL_VERSIONS
|
||||||
|
|
||||||
|
|
||||||
|
class TestRequestContext:
|
||||||
|
"""Test RequestContext dataclass."""
|
||||||
|
|
||||||
|
def test_request_context_creation(self):
|
||||||
|
"""Test creating a RequestContext instance."""
|
||||||
|
mock_session = Mock(spec=BaseSession)
|
||||||
|
mock_lifespan = {"key": "value"}
|
||||||
|
mock_meta = RequestParams.Meta(progressToken="test-token")
|
||||||
|
|
||||||
|
context = RequestContext(
|
||||||
|
request_id="test-request-123",
|
||||||
|
meta=mock_meta,
|
||||||
|
session=mock_session,
|
||||||
|
lifespan_context=mock_lifespan,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert context.request_id == "test-request-123"
|
||||||
|
assert context.meta == mock_meta
|
||||||
|
assert context.session == mock_session
|
||||||
|
assert context.lifespan_context == mock_lifespan
|
||||||
|
|
||||||
|
def test_request_context_with_none_meta(self):
|
||||||
|
"""Test creating RequestContext with None meta."""
|
||||||
|
mock_session = Mock(spec=BaseSession)
|
||||||
|
|
||||||
|
context = RequestContext(
|
||||||
|
request_id=42, # Can be int or string
|
||||||
|
meta=None,
|
||||||
|
session=mock_session,
|
||||||
|
lifespan_context=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert context.request_id == 42
|
||||||
|
assert context.meta is None
|
||||||
|
assert context.session == mock_session
|
||||||
|
assert context.lifespan_context is None
|
||||||
|
|
||||||
|
def test_request_context_attributes(self):
|
||||||
|
"""Test RequestContext attributes are accessible."""
|
||||||
|
mock_session = Mock(spec=BaseSession)
|
||||||
|
|
||||||
|
context = RequestContext(
|
||||||
|
request_id="test-123",
|
||||||
|
meta=None,
|
||||||
|
session=mock_session,
|
||||||
|
lifespan_context=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify attributes are accessible
|
||||||
|
assert hasattr(context, "request_id")
|
||||||
|
assert hasattr(context, "meta")
|
||||||
|
assert hasattr(context, "session")
|
||||||
|
assert hasattr(context, "lifespan_context")
|
||||||
|
|
||||||
|
# Verify values
|
||||||
|
assert context.request_id == "test-123"
|
||||||
|
assert context.meta is None
|
||||||
|
assert context.session == mock_session
|
||||||
|
assert context.lifespan_context is None
|
||||||
|
|
||||||
|
def test_request_context_generic_typing(self):
|
||||||
|
"""Test RequestContext with different generic types."""
|
||||||
|
# Create a mock session with specific type
|
||||||
|
mock_session = Mock(spec=BaseSession)
|
||||||
|
|
||||||
|
# Create context with string lifespan context
|
||||||
|
context_str = RequestContext[BaseSession, str](
|
||||||
|
request_id="test-1",
|
||||||
|
meta=None,
|
||||||
|
session=mock_session,
|
||||||
|
lifespan_context="string-context",
|
||||||
|
)
|
||||||
|
assert isinstance(context_str.lifespan_context, str)
|
||||||
|
|
||||||
|
# Create context with dict lifespan context
|
||||||
|
context_dict = RequestContext[BaseSession, dict](
|
||||||
|
request_id="test-2",
|
||||||
|
meta=None,
|
||||||
|
session=mock_session,
|
||||||
|
lifespan_context={"key": "value"},
|
||||||
|
)
|
||||||
|
assert isinstance(context_dict.lifespan_context, dict)
|
||||||
|
|
||||||
|
# Create context with custom object lifespan context
|
||||||
|
class CustomLifespan:
|
||||||
|
def __init__(self, data):
|
||||||
|
self.data = data
|
||||||
|
|
||||||
|
custom_lifespan = CustomLifespan("test-data")
|
||||||
|
context_custom = RequestContext[BaseSession, CustomLifespan](
|
||||||
|
request_id="test-3",
|
||||||
|
meta=None,
|
||||||
|
session=mock_session,
|
||||||
|
lifespan_context=custom_lifespan,
|
||||||
|
)
|
||||||
|
assert isinstance(context_custom.lifespan_context, CustomLifespan)
|
||||||
|
assert context_custom.lifespan_context.data == "test-data"
|
||||||
|
|
||||||
|
def test_request_context_with_progress_meta(self):
|
||||||
|
"""Test RequestContext with progress metadata."""
|
||||||
|
mock_session = Mock(spec=BaseSession)
|
||||||
|
progress_meta = RequestParams.Meta(progressToken="progress-123")
|
||||||
|
|
||||||
|
context = RequestContext(
|
||||||
|
request_id="req-456",
|
||||||
|
meta=progress_meta,
|
||||||
|
session=mock_session,
|
||||||
|
lifespan_context=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert context.meta is not None
|
||||||
|
assert context.meta.progressToken == "progress-123"
|
||||||
|
|
||||||
|
def test_request_context_equality(self):
|
||||||
|
"""Test RequestContext equality comparison."""
|
||||||
|
mock_session1 = Mock(spec=BaseSession)
|
||||||
|
mock_session2 = Mock(spec=BaseSession)
|
||||||
|
|
||||||
|
context1 = RequestContext(
|
||||||
|
request_id="test-123",
|
||||||
|
meta=None,
|
||||||
|
session=mock_session1,
|
||||||
|
lifespan_context="context",
|
||||||
|
)
|
||||||
|
|
||||||
|
context2 = RequestContext(
|
||||||
|
request_id="test-123",
|
||||||
|
meta=None,
|
||||||
|
session=mock_session1,
|
||||||
|
lifespan_context="context",
|
||||||
|
)
|
||||||
|
|
||||||
|
context3 = RequestContext(
|
||||||
|
request_id="test-456",
|
||||||
|
meta=None,
|
||||||
|
session=mock_session1,
|
||||||
|
lifespan_context="context",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Same values should be equal
|
||||||
|
assert context1 == context2
|
||||||
|
|
||||||
|
# Different request_id should not be equal
|
||||||
|
assert context1 != context3
|
||||||
|
|
||||||
|
# Different session should not be equal
|
||||||
|
context4 = RequestContext(
|
||||||
|
request_id="test-123",
|
||||||
|
meta=None,
|
||||||
|
session=mock_session2,
|
||||||
|
lifespan_context="context",
|
||||||
|
)
|
||||||
|
assert context1 != context4
|
||||||
|
|
||||||
|
def test_request_context_repr(self):
|
||||||
|
"""Test RequestContext string representation."""
|
||||||
|
mock_session = Mock(spec=BaseSession)
|
||||||
|
mock_session.__repr__ = Mock(return_value="<MockSession>")
|
||||||
|
|
||||||
|
context = RequestContext(
|
||||||
|
request_id="test-123",
|
||||||
|
meta=None,
|
||||||
|
session=mock_session,
|
||||||
|
lifespan_context={"data": "test"},
|
||||||
|
)
|
||||||
|
|
||||||
|
repr_str = repr(context)
|
||||||
|
assert "RequestContext" in repr_str
|
||||||
|
assert "test-123" in repr_str
|
||||||
|
assert "MockSession" in repr_str
|
||||||
|
|
||||||
|
|
||||||
|
class TestTypeVariables:
|
||||||
|
"""Test type variables defined in the module."""
|
||||||
|
|
||||||
|
def test_session_type_var(self):
|
||||||
|
"""Test SessionT type variable."""
|
||||||
|
|
||||||
|
# Create a custom session class
|
||||||
|
class CustomSession(BaseSession):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Use in generic context
|
||||||
|
def process_session(session: SessionT) -> SessionT:
|
||||||
|
return session
|
||||||
|
|
||||||
|
mock_session = Mock(spec=CustomSession)
|
||||||
|
result = process_session(mock_session)
|
||||||
|
assert result == mock_session
|
||||||
|
|
||||||
|
def test_lifespan_context_type_var(self):
|
||||||
|
"""Test LifespanContextT type variable."""
|
||||||
|
|
||||||
|
# Use in generic context
|
||||||
|
def process_lifespan(context: LifespanContextT) -> LifespanContextT:
|
||||||
|
return context
|
||||||
|
|
||||||
|
# Test with different types
|
||||||
|
str_context = "string-context"
|
||||||
|
assert process_lifespan(str_context) == str_context
|
||||||
|
|
||||||
|
dict_context = {"key": "value"}
|
||||||
|
assert process_lifespan(dict_context) == dict_context
|
||||||
|
|
||||||
|
class CustomContext:
|
||||||
|
pass
|
||||||
|
|
||||||
|
custom_context = CustomContext()
|
||||||
|
assert process_lifespan(custom_context) == custom_context
|
||||||
205
api/tests/unit_tests/core/mcp/test_error.py
Normal file
205
api/tests/unit_tests/core/mcp/test_error.py
Normal file
@ -0,0 +1,205 @@
|
|||||||
|
"""Unit tests for MCP error classes."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.mcp.error import MCPAuthError, MCPConnectionError, MCPError
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPError:
|
||||||
|
"""Test MCPError base exception class."""
|
||||||
|
|
||||||
|
def test_mcp_error_creation(self):
|
||||||
|
"""Test creating MCPError instance."""
|
||||||
|
error = MCPError("Test error message")
|
||||||
|
assert str(error) == "Test error message"
|
||||||
|
assert isinstance(error, Exception)
|
||||||
|
|
||||||
|
def test_mcp_error_inheritance(self):
|
||||||
|
"""Test MCPError inherits from Exception."""
|
||||||
|
error = MCPError()
|
||||||
|
assert isinstance(error, Exception)
|
||||||
|
assert type(error).__name__ == "MCPError"
|
||||||
|
|
||||||
|
def test_mcp_error_with_empty_message(self):
|
||||||
|
"""Test MCPError with empty message."""
|
||||||
|
error = MCPError()
|
||||||
|
assert str(error) == ""
|
||||||
|
|
||||||
|
def test_mcp_error_raise(self):
|
||||||
|
"""Test raising MCPError."""
|
||||||
|
with pytest.raises(MCPError) as exc_info:
|
||||||
|
raise MCPError("Something went wrong")
|
||||||
|
|
||||||
|
assert str(exc_info.value) == "Something went wrong"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPConnectionError:
|
||||||
|
"""Test MCPConnectionError exception class."""
|
||||||
|
|
||||||
|
def test_mcp_connection_error_creation(self):
|
||||||
|
"""Test creating MCPConnectionError instance."""
|
||||||
|
error = MCPConnectionError("Connection failed")
|
||||||
|
assert str(error) == "Connection failed"
|
||||||
|
assert isinstance(error, MCPError)
|
||||||
|
assert isinstance(error, Exception)
|
||||||
|
|
||||||
|
def test_mcp_connection_error_inheritance(self):
|
||||||
|
"""Test MCPConnectionError inheritance chain."""
|
||||||
|
error = MCPConnectionError()
|
||||||
|
assert isinstance(error, MCPConnectionError)
|
||||||
|
assert isinstance(error, MCPError)
|
||||||
|
assert isinstance(error, Exception)
|
||||||
|
|
||||||
|
def test_mcp_connection_error_raise(self):
|
||||||
|
"""Test raising MCPConnectionError."""
|
||||||
|
with pytest.raises(MCPConnectionError) as exc_info:
|
||||||
|
raise MCPConnectionError("Unable to connect to server")
|
||||||
|
|
||||||
|
assert str(exc_info.value) == "Unable to connect to server"
|
||||||
|
|
||||||
|
def test_mcp_connection_error_catch_as_mcp_error(self):
|
||||||
|
"""Test catching MCPConnectionError as MCPError."""
|
||||||
|
with pytest.raises(MCPError) as exc_info:
|
||||||
|
raise MCPConnectionError("Connection issue")
|
||||||
|
|
||||||
|
assert isinstance(exc_info.value, MCPConnectionError)
|
||||||
|
assert str(exc_info.value) == "Connection issue"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPAuthError:
|
||||||
|
"""Test MCPAuthError exception class."""
|
||||||
|
|
||||||
|
def test_mcp_auth_error_creation(self):
|
||||||
|
"""Test creating MCPAuthError instance."""
|
||||||
|
error = MCPAuthError("Authentication failed")
|
||||||
|
assert str(error) == "Authentication failed"
|
||||||
|
assert isinstance(error, MCPConnectionError)
|
||||||
|
assert isinstance(error, MCPError)
|
||||||
|
assert isinstance(error, Exception)
|
||||||
|
|
||||||
|
def test_mcp_auth_error_inheritance(self):
|
||||||
|
"""Test MCPAuthError inheritance chain."""
|
||||||
|
error = MCPAuthError()
|
||||||
|
assert isinstance(error, MCPAuthError)
|
||||||
|
assert isinstance(error, MCPConnectionError)
|
||||||
|
assert isinstance(error, MCPError)
|
||||||
|
assert isinstance(error, Exception)
|
||||||
|
|
||||||
|
def test_mcp_auth_error_raise(self):
|
||||||
|
"""Test raising MCPAuthError."""
|
||||||
|
with pytest.raises(MCPAuthError) as exc_info:
|
||||||
|
raise MCPAuthError("Invalid credentials")
|
||||||
|
|
||||||
|
assert str(exc_info.value) == "Invalid credentials"
|
||||||
|
|
||||||
|
def test_mcp_auth_error_catch_hierarchy(self):
|
||||||
|
"""Test catching MCPAuthError at different levels."""
|
||||||
|
# Catch as MCPAuthError
|
||||||
|
with pytest.raises(MCPAuthError) as exc_info:
|
||||||
|
raise MCPAuthError("Auth specific error")
|
||||||
|
assert str(exc_info.value) == "Auth specific error"
|
||||||
|
|
||||||
|
# Catch as MCPConnectionError
|
||||||
|
with pytest.raises(MCPConnectionError) as exc_info:
|
||||||
|
raise MCPAuthError("Auth connection error")
|
||||||
|
assert isinstance(exc_info.value, MCPAuthError)
|
||||||
|
assert str(exc_info.value) == "Auth connection error"
|
||||||
|
|
||||||
|
# Catch as MCPError
|
||||||
|
with pytest.raises(MCPError) as exc_info:
|
||||||
|
raise MCPAuthError("Auth base error")
|
||||||
|
assert isinstance(exc_info.value, MCPAuthError)
|
||||||
|
assert str(exc_info.value) == "Auth base error"
|
||||||
|
|
||||||
|
|
||||||
|
class TestErrorHierarchy:
|
||||||
|
"""Test the complete error hierarchy."""
|
||||||
|
|
||||||
|
def test_exception_hierarchy(self):
|
||||||
|
"""Test the complete exception hierarchy."""
|
||||||
|
# Create instances
|
||||||
|
base_error = MCPError("base")
|
||||||
|
connection_error = MCPConnectionError("connection")
|
||||||
|
auth_error = MCPAuthError("auth")
|
||||||
|
|
||||||
|
# Test type relationships
|
||||||
|
assert not isinstance(base_error, MCPConnectionError)
|
||||||
|
assert not isinstance(base_error, MCPAuthError)
|
||||||
|
|
||||||
|
assert isinstance(connection_error, MCPError)
|
||||||
|
assert not isinstance(connection_error, MCPAuthError)
|
||||||
|
|
||||||
|
assert isinstance(auth_error, MCPError)
|
||||||
|
assert isinstance(auth_error, MCPConnectionError)
|
||||||
|
|
||||||
|
def test_error_handling_patterns(self):
|
||||||
|
"""Test common error handling patterns."""
|
||||||
|
|
||||||
|
def raise_auth_error():
|
||||||
|
raise MCPAuthError("401 Unauthorized")
|
||||||
|
|
||||||
|
def raise_connection_error():
|
||||||
|
raise MCPConnectionError("Connection timeout")
|
||||||
|
|
||||||
|
def raise_base_error():
|
||||||
|
raise MCPError("Generic error")
|
||||||
|
|
||||||
|
# Pattern 1: Catch specific errors first
|
||||||
|
errors_caught = []
|
||||||
|
|
||||||
|
for error_func in [raise_auth_error, raise_connection_error, raise_base_error]:
|
||||||
|
try:
|
||||||
|
error_func()
|
||||||
|
except MCPAuthError:
|
||||||
|
errors_caught.append("auth")
|
||||||
|
except MCPConnectionError:
|
||||||
|
errors_caught.append("connection")
|
||||||
|
except MCPError:
|
||||||
|
errors_caught.append("base")
|
||||||
|
|
||||||
|
assert errors_caught == ["auth", "connection", "base"]
|
||||||
|
|
||||||
|
# Pattern 2: Catch all as base error
|
||||||
|
for error_func in [raise_auth_error, raise_connection_error, raise_base_error]:
|
||||||
|
with pytest.raises(MCPError) as exc_info:
|
||||||
|
error_func()
|
||||||
|
assert isinstance(exc_info.value, MCPError)
|
||||||
|
|
||||||
|
def test_error_with_cause(self):
|
||||||
|
"""Test errors with cause (chained exceptions)."""
|
||||||
|
original_error = ValueError("Original error")
|
||||||
|
|
||||||
|
def raise_chained_error():
|
||||||
|
try:
|
||||||
|
raise original_error
|
||||||
|
except ValueError as e:
|
||||||
|
raise MCPConnectionError("Connection failed") from e
|
||||||
|
|
||||||
|
with pytest.raises(MCPConnectionError) as exc_info:
|
||||||
|
raise_chained_error()
|
||||||
|
|
||||||
|
assert str(exc_info.value) == "Connection failed"
|
||||||
|
assert exc_info.value.__cause__ == original_error
|
||||||
|
|
||||||
|
def test_error_comparison(self):
|
||||||
|
"""Test error instance comparison."""
|
||||||
|
error1 = MCPError("Test message")
|
||||||
|
error2 = MCPError("Test message")
|
||||||
|
error3 = MCPError("Different message")
|
||||||
|
|
||||||
|
# Errors are not equal even with same message (different instances)
|
||||||
|
assert error1 != error2
|
||||||
|
assert error1 != error3
|
||||||
|
|
||||||
|
# But they have the same type
|
||||||
|
assert type(error1) == type(error2) == type(error3)
|
||||||
|
|
||||||
|
def test_error_representation(self):
|
||||||
|
"""Test error string representation."""
|
||||||
|
base_error = MCPError("Base error message")
|
||||||
|
connection_error = MCPConnectionError("Connection error message")
|
||||||
|
auth_error = MCPAuthError("Auth error message")
|
||||||
|
|
||||||
|
assert repr(base_error) == "MCPError('Base error message')"
|
||||||
|
assert repr(connection_error) == "MCPConnectionError('Connection error message')"
|
||||||
|
assert repr(auth_error) == "MCPAuthError('Auth error message')"
|
||||||
382
api/tests/unit_tests/core/mcp/test_mcp_client.py
Normal file
382
api/tests/unit_tests/core/mcp/test_mcp_client.py
Normal file
@ -0,0 +1,382 @@
|
|||||||
|
"""Unit tests for MCP client."""
|
||||||
|
|
||||||
|
from contextlib import ExitStack
|
||||||
|
from types import TracebackType
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.mcp.error import MCPConnectionError
|
||||||
|
from core.mcp.mcp_client import MCPClient
|
||||||
|
from core.mcp.types import CallToolResult, ListToolsResult, TextContent, Tool, ToolAnnotations
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPClient:
|
||||||
|
"""Test suite for MCPClient."""
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
"""Test client initialization."""
|
||||||
|
client = MCPClient(
|
||||||
|
server_url="http://test.example.com/mcp",
|
||||||
|
headers={"Authorization": "Bearer test"},
|
||||||
|
timeout=30.0,
|
||||||
|
sse_read_timeout=60.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert client.server_url == "http://test.example.com/mcp"
|
||||||
|
assert client.headers == {"Authorization": "Bearer test"}
|
||||||
|
assert client.timeout == 30.0
|
||||||
|
assert client.sse_read_timeout == 60.0
|
||||||
|
assert client._session is None
|
||||||
|
assert isinstance(client._exit_stack, ExitStack)
|
||||||
|
assert client._initialized is False
|
||||||
|
|
||||||
|
def test_init_defaults(self):
|
||||||
|
"""Test client initialization with defaults."""
|
||||||
|
client = MCPClient(server_url="http://test.example.com")
|
||||||
|
|
||||||
|
assert client.server_url == "http://test.example.com"
|
||||||
|
assert client.headers == {}
|
||||||
|
assert client.timeout is None
|
||||||
|
assert client.sse_read_timeout is None
|
||||||
|
|
||||||
|
@patch("core.mcp.mcp_client.streamablehttp_client")
|
||||||
|
@patch("core.mcp.mcp_client.ClientSession")
|
||||||
|
def test_initialize_with_mcp_url(self, mock_client_session, mock_streamable_client):
|
||||||
|
"""Test initialization with MCP URL."""
|
||||||
|
# Setup mocks
|
||||||
|
mock_read_stream = Mock()
|
||||||
|
mock_write_stream = Mock()
|
||||||
|
mock_client_context = Mock()
|
||||||
|
mock_streamable_client.return_value.__enter__.return_value = (
|
||||||
|
mock_read_stream,
|
||||||
|
mock_write_stream,
|
||||||
|
mock_client_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_session = Mock()
|
||||||
|
mock_client_session.return_value.__enter__.return_value = mock_session
|
||||||
|
|
||||||
|
client = MCPClient(server_url="http://test.example.com/mcp")
|
||||||
|
client._initialize()
|
||||||
|
|
||||||
|
# Verify streamable client was called
|
||||||
|
mock_streamable_client.assert_called_once_with(
|
||||||
|
url="http://test.example.com/mcp",
|
||||||
|
headers={},
|
||||||
|
timeout=None,
|
||||||
|
sse_read_timeout=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify session was created
|
||||||
|
mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream)
|
||||||
|
mock_session.initialize.assert_called_once()
|
||||||
|
assert client._session == mock_session
|
||||||
|
|
||||||
|
@patch("core.mcp.mcp_client.sse_client")
|
||||||
|
@patch("core.mcp.mcp_client.ClientSession")
|
||||||
|
def test_initialize_with_sse_url(self, mock_client_session, mock_sse_client):
|
||||||
|
"""Test initialization with SSE URL."""
|
||||||
|
# Setup mocks
|
||||||
|
mock_read_stream = Mock()
|
||||||
|
mock_write_stream = Mock()
|
||||||
|
mock_sse_client.return_value.__enter__.return_value = (mock_read_stream, mock_write_stream)
|
||||||
|
|
||||||
|
mock_session = Mock()
|
||||||
|
mock_client_session.return_value.__enter__.return_value = mock_session
|
||||||
|
|
||||||
|
client = MCPClient(server_url="http://test.example.com/sse")
|
||||||
|
client._initialize()
|
||||||
|
|
||||||
|
# Verify SSE client was called
|
||||||
|
mock_sse_client.assert_called_once_with(
|
||||||
|
url="http://test.example.com/sse",
|
||||||
|
headers={},
|
||||||
|
timeout=None,
|
||||||
|
sse_read_timeout=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify session was created
|
||||||
|
mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream)
|
||||||
|
mock_session.initialize.assert_called_once()
|
||||||
|
assert client._session == mock_session
|
||||||
|
|
||||||
|
@patch("core.mcp.mcp_client.sse_client")
|
||||||
|
@patch("core.mcp.mcp_client.streamablehttp_client")
|
||||||
|
@patch("core.mcp.mcp_client.ClientSession")
|
||||||
|
def test_initialize_with_unknown_method_fallback_to_sse(
|
||||||
|
self, mock_client_session, mock_streamable_client, mock_sse_client
|
||||||
|
):
|
||||||
|
"""Test initialization with unknown method falls back to SSE."""
|
||||||
|
# Setup mocks
|
||||||
|
mock_read_stream = Mock()
|
||||||
|
mock_write_stream = Mock()
|
||||||
|
mock_sse_client.return_value.__enter__.return_value = (mock_read_stream, mock_write_stream)
|
||||||
|
|
||||||
|
mock_session = Mock()
|
||||||
|
mock_client_session.return_value.__enter__.return_value = mock_session
|
||||||
|
|
||||||
|
client = MCPClient(server_url="http://test.example.com/unknown")
|
||||||
|
client._initialize()
|
||||||
|
|
||||||
|
# Verify SSE client was tried
|
||||||
|
mock_sse_client.assert_called_once()
|
||||||
|
mock_streamable_client.assert_not_called()
|
||||||
|
|
||||||
|
# Verify session was created
|
||||||
|
assert client._session == mock_session
|
||||||
|
|
||||||
|
@patch("core.mcp.mcp_client.sse_client")
|
||||||
|
@patch("core.mcp.mcp_client.streamablehttp_client")
|
||||||
|
@patch("core.mcp.mcp_client.ClientSession")
|
||||||
|
def test_initialize_fallback_from_sse_to_mcp(self, mock_client_session, mock_streamable_client, mock_sse_client):
|
||||||
|
"""Test initialization falls back from SSE to MCP on connection error."""
|
||||||
|
# Setup SSE to fail
|
||||||
|
mock_sse_client.side_effect = MCPConnectionError("SSE connection failed")
|
||||||
|
|
||||||
|
# Setup MCP to succeed
|
||||||
|
mock_read_stream = Mock()
|
||||||
|
mock_write_stream = Mock()
|
||||||
|
mock_client_context = Mock()
|
||||||
|
mock_streamable_client.return_value.__enter__.return_value = (
|
||||||
|
mock_read_stream,
|
||||||
|
mock_write_stream,
|
||||||
|
mock_client_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_session = Mock()
|
||||||
|
mock_client_session.return_value.__enter__.return_value = mock_session
|
||||||
|
|
||||||
|
client = MCPClient(server_url="http://test.example.com/unknown")
|
||||||
|
client._initialize()
|
||||||
|
|
||||||
|
# Verify both were tried
|
||||||
|
mock_sse_client.assert_called_once()
|
||||||
|
mock_streamable_client.assert_called_once()
|
||||||
|
|
||||||
|
# Verify session was created with MCP
|
||||||
|
assert client._session == mock_session
|
||||||
|
|
||||||
|
@patch("core.mcp.mcp_client.streamablehttp_client")
|
||||||
|
@patch("core.mcp.mcp_client.ClientSession")
|
||||||
|
def test_connect_server_mcp(self, mock_client_session, mock_streamable_client):
|
||||||
|
"""Test connect_server with MCP method."""
|
||||||
|
# Setup mocks
|
||||||
|
mock_read_stream = Mock()
|
||||||
|
mock_write_stream = Mock()
|
||||||
|
mock_client_context = Mock()
|
||||||
|
mock_streamable_client.return_value.__enter__.return_value = (
|
||||||
|
mock_read_stream,
|
||||||
|
mock_write_stream,
|
||||||
|
mock_client_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_session = Mock()
|
||||||
|
mock_client_session.return_value.__enter__.return_value = mock_session
|
||||||
|
|
||||||
|
client = MCPClient(server_url="http://test.example.com")
|
||||||
|
client.connect_server(mock_streamable_client, "mcp")
|
||||||
|
|
||||||
|
# Verify correct streams were passed
|
||||||
|
mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream)
|
||||||
|
mock_session.initialize.assert_called_once()
|
||||||
|
|
||||||
|
@patch("core.mcp.mcp_client.sse_client")
|
||||||
|
@patch("core.mcp.mcp_client.ClientSession")
|
||||||
|
def test_connect_server_sse(self, mock_client_session, mock_sse_client):
|
||||||
|
"""Test connect_server with SSE method."""
|
||||||
|
# Setup mocks
|
||||||
|
mock_read_stream = Mock()
|
||||||
|
mock_write_stream = Mock()
|
||||||
|
mock_sse_client.return_value.__enter__.return_value = (mock_read_stream, mock_write_stream)
|
||||||
|
|
||||||
|
mock_session = Mock()
|
||||||
|
mock_client_session.return_value.__enter__.return_value = mock_session
|
||||||
|
|
||||||
|
client = MCPClient(server_url="http://test.example.com")
|
||||||
|
client.connect_server(mock_sse_client, "sse")
|
||||||
|
|
||||||
|
# Verify correct streams were passed
|
||||||
|
mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream)
|
||||||
|
mock_session.initialize.assert_called_once()
|
||||||
|
|
||||||
|
def test_context_manager_enter(self):
|
||||||
|
"""Test context manager enter."""
|
||||||
|
client = MCPClient(server_url="http://test.example.com")
|
||||||
|
|
||||||
|
with patch.object(client, "_initialize") as mock_initialize:
|
||||||
|
result = client.__enter__()
|
||||||
|
|
||||||
|
assert result == client
|
||||||
|
assert client._initialized is True
|
||||||
|
mock_initialize.assert_called_once()
|
||||||
|
|
||||||
|
def test_context_manager_exit(self):
|
||||||
|
"""Test context manager exit."""
|
||||||
|
client = MCPClient(server_url="http://test.example.com")
|
||||||
|
|
||||||
|
with patch.object(client, "cleanup") as mock_cleanup:
|
||||||
|
exc_type: type[BaseException] | None = None
|
||||||
|
exc_val: BaseException | None = None
|
||||||
|
exc_tb: TracebackType | None = None
|
||||||
|
client.__exit__(exc_type, exc_val, exc_tb)
|
||||||
|
|
||||||
|
mock_cleanup.assert_called_once()
|
||||||
|
|
||||||
|
def test_list_tools_not_initialized(self):
|
||||||
|
"""Test list_tools when session not initialized."""
|
||||||
|
client = MCPClient(server_url="http://test.example.com")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
client.list_tools()
|
||||||
|
|
||||||
|
assert "Session not initialized" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_list_tools_success(self):
|
||||||
|
"""Test successful list_tools call."""
|
||||||
|
client = MCPClient(server_url="http://test.example.com")
|
||||||
|
|
||||||
|
# Setup mock session
|
||||||
|
mock_session = Mock()
|
||||||
|
expected_tools = [
|
||||||
|
Tool(
|
||||||
|
name="test-tool",
|
||||||
|
description="A test tool",
|
||||||
|
inputSchema={"type": "object", "properties": {}},
|
||||||
|
annotations=ToolAnnotations(title="Test Tool"),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
mock_session.list_tools.return_value = ListToolsResult(tools=expected_tools)
|
||||||
|
client._session = mock_session
|
||||||
|
|
||||||
|
result = client.list_tools()
|
||||||
|
|
||||||
|
assert result == expected_tools
|
||||||
|
mock_session.list_tools.assert_called_once()
|
||||||
|
|
||||||
|
def test_invoke_tool_not_initialized(self):
|
||||||
|
"""Test invoke_tool when session not initialized."""
|
||||||
|
client = MCPClient(server_url="http://test.example.com")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
client.invoke_tool("test-tool", {"arg": "value"})
|
||||||
|
|
||||||
|
assert "Session not initialized" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_invoke_tool_success(self):
|
||||||
|
"""Test successful invoke_tool call."""
|
||||||
|
client = MCPClient(server_url="http://test.example.com")
|
||||||
|
|
||||||
|
# Setup mock session
|
||||||
|
mock_session = Mock()
|
||||||
|
expected_result = CallToolResult(
|
||||||
|
content=[TextContent(type="text", text="Tool executed successfully")],
|
||||||
|
isError=False,
|
||||||
|
)
|
||||||
|
mock_session.call_tool.return_value = expected_result
|
||||||
|
client._session = mock_session
|
||||||
|
|
||||||
|
result = client.invoke_tool("test-tool", {"arg": "value"})
|
||||||
|
|
||||||
|
assert result == expected_result
|
||||||
|
mock_session.call_tool.assert_called_once_with("test-tool", {"arg": "value"})
|
||||||
|
|
||||||
|
def test_cleanup(self):
|
||||||
|
"""Test cleanup method."""
|
||||||
|
client = MCPClient(server_url="http://test.example.com")
|
||||||
|
mock_exit_stack = Mock(spec=ExitStack)
|
||||||
|
client._exit_stack = mock_exit_stack
|
||||||
|
client._session = Mock()
|
||||||
|
client._initialized = True
|
||||||
|
|
||||||
|
client.cleanup()
|
||||||
|
|
||||||
|
mock_exit_stack.close.assert_called_once()
|
||||||
|
assert client._session is None
|
||||||
|
assert client._initialized is False
|
||||||
|
|
||||||
|
def test_cleanup_with_error(self):
|
||||||
|
"""Test cleanup method with error."""
|
||||||
|
client = MCPClient(server_url="http://test.example.com")
|
||||||
|
mock_exit_stack = Mock(spec=ExitStack)
|
||||||
|
mock_exit_stack.close.side_effect = Exception("Cleanup error")
|
||||||
|
client._exit_stack = mock_exit_stack
|
||||||
|
client._session = Mock()
|
||||||
|
client._initialized = True
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
client.cleanup()
|
||||||
|
|
||||||
|
assert "Error during cleanup: Cleanup error" in str(exc_info.value)
|
||||||
|
assert client._session is None
|
||||||
|
assert client._initialized is False
|
||||||
|
|
||||||
|
@patch("core.mcp.mcp_client.streamablehttp_client")
|
||||||
|
@patch("core.mcp.mcp_client.ClientSession")
|
||||||
|
def test_full_context_manager_flow(self, mock_client_session, mock_streamable_client):
|
||||||
|
"""Test full context manager flow."""
|
||||||
|
# Setup mocks
|
||||||
|
mock_read_stream = Mock()
|
||||||
|
mock_write_stream = Mock()
|
||||||
|
mock_client_context = Mock()
|
||||||
|
mock_streamable_client.return_value.__enter__.return_value = (
|
||||||
|
mock_read_stream,
|
||||||
|
mock_write_stream,
|
||||||
|
mock_client_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_session = Mock()
|
||||||
|
mock_client_session.return_value.__enter__.return_value = mock_session
|
||||||
|
|
||||||
|
expected_tools = [Tool(name="test-tool", description="Test", inputSchema={})]
|
||||||
|
mock_session.list_tools.return_value = ListToolsResult(tools=expected_tools)
|
||||||
|
|
||||||
|
with MCPClient(server_url="http://test.example.com/mcp") as client:
|
||||||
|
assert client._initialized is True
|
||||||
|
assert client._session == mock_session
|
||||||
|
|
||||||
|
# Test tool operations
|
||||||
|
tools = client.list_tools()
|
||||||
|
assert tools == expected_tools
|
||||||
|
|
||||||
|
# After exit, should be cleaned up
|
||||||
|
assert client._initialized is False
|
||||||
|
assert client._session is None
|
||||||
|
|
||||||
|
def test_headers_passed_to_clients(self):
|
||||||
|
"""Test that headers are properly passed to underlying clients."""
|
||||||
|
custom_headers = {
|
||||||
|
"Authorization": "Bearer test-token",
|
||||||
|
"X-Custom-Header": "test-value",
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("core.mcp.mcp_client.streamablehttp_client") as mock_streamable_client:
|
||||||
|
with patch("core.mcp.mcp_client.ClientSession") as mock_client_session:
|
||||||
|
# Setup mocks
|
||||||
|
mock_read_stream = Mock()
|
||||||
|
mock_write_stream = Mock()
|
||||||
|
mock_client_context = Mock()
|
||||||
|
mock_streamable_client.return_value.__enter__.return_value = (
|
||||||
|
mock_read_stream,
|
||||||
|
mock_write_stream,
|
||||||
|
mock_client_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_session = Mock()
|
||||||
|
mock_client_session.return_value.__enter__.return_value = mock_session
|
||||||
|
|
||||||
|
client = MCPClient(
|
||||||
|
server_url="http://test.example.com/mcp",
|
||||||
|
headers=custom_headers,
|
||||||
|
timeout=30.0,
|
||||||
|
sse_read_timeout=60.0,
|
||||||
|
)
|
||||||
|
client._initialize()
|
||||||
|
|
||||||
|
# Verify headers were passed
|
||||||
|
mock_streamable_client.assert_called_once_with(
|
||||||
|
url="http://test.example.com/mcp",
|
||||||
|
headers=custom_headers,
|
||||||
|
timeout=30.0,
|
||||||
|
sse_read_timeout=60.0,
|
||||||
|
)
|
||||||
492
api/tests/unit_tests/core/mcp/test_types.py
Normal file
492
api/tests/unit_tests/core/mcp/test_types.py
Normal file
@ -0,0 +1,492 @@
|
|||||||
|
"""Unit tests for MCP types module."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from core.mcp.types import (
|
||||||
|
INTERNAL_ERROR,
|
||||||
|
INVALID_PARAMS,
|
||||||
|
INVALID_REQUEST,
|
||||||
|
LATEST_PROTOCOL_VERSION,
|
||||||
|
METHOD_NOT_FOUND,
|
||||||
|
PARSE_ERROR,
|
||||||
|
SERVER_LATEST_PROTOCOL_VERSION,
|
||||||
|
Annotations,
|
||||||
|
CallToolRequest,
|
||||||
|
CallToolRequestParams,
|
||||||
|
CallToolResult,
|
||||||
|
ClientCapabilities,
|
||||||
|
CompleteRequest,
|
||||||
|
CompleteRequestParams,
|
||||||
|
CompleteResult,
|
||||||
|
Completion,
|
||||||
|
CompletionArgument,
|
||||||
|
CompletionContext,
|
||||||
|
ErrorData,
|
||||||
|
ImageContent,
|
||||||
|
Implementation,
|
||||||
|
InitializeRequest,
|
||||||
|
InitializeRequestParams,
|
||||||
|
InitializeResult,
|
||||||
|
JSONRPCError,
|
||||||
|
JSONRPCMessage,
|
||||||
|
JSONRPCNotification,
|
||||||
|
JSONRPCRequest,
|
||||||
|
JSONRPCResponse,
|
||||||
|
ListToolsRequest,
|
||||||
|
ListToolsResult,
|
||||||
|
OAuthClientInformation,
|
||||||
|
OAuthClientMetadata,
|
||||||
|
OAuthMetadata,
|
||||||
|
OAuthTokens,
|
||||||
|
PingRequest,
|
||||||
|
ProgressNotification,
|
||||||
|
ProgressNotificationParams,
|
||||||
|
PromptReference,
|
||||||
|
RequestParams,
|
||||||
|
ResourceTemplateReference,
|
||||||
|
Result,
|
||||||
|
ServerCapabilities,
|
||||||
|
TextContent,
|
||||||
|
Tool,
|
||||||
|
ToolAnnotations,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestConstants:
|
||||||
|
"""Test module constants."""
|
||||||
|
|
||||||
|
def test_protocol_versions(self):
|
||||||
|
"""Test protocol version constants."""
|
||||||
|
assert LATEST_PROTOCOL_VERSION == "2025-03-26"
|
||||||
|
assert SERVER_LATEST_PROTOCOL_VERSION == "2024-11-05"
|
||||||
|
|
||||||
|
def test_error_codes(self):
|
||||||
|
"""Test JSON-RPC error code constants."""
|
||||||
|
assert PARSE_ERROR == -32700
|
||||||
|
assert INVALID_REQUEST == -32600
|
||||||
|
assert METHOD_NOT_FOUND == -32601
|
||||||
|
assert INVALID_PARAMS == -32602
|
||||||
|
assert INTERNAL_ERROR == -32603
|
||||||
|
|
||||||
|
|
||||||
|
class TestRequestParams:
|
||||||
|
"""Test RequestParams and related classes."""
|
||||||
|
|
||||||
|
def test_request_params_basic(self):
|
||||||
|
"""Test basic RequestParams creation."""
|
||||||
|
params = RequestParams()
|
||||||
|
assert params.meta is None
|
||||||
|
|
||||||
|
def test_request_params_with_meta(self):
|
||||||
|
"""Test RequestParams with meta."""
|
||||||
|
meta = RequestParams.Meta(progressToken="test-token")
|
||||||
|
params = RequestParams(_meta=meta)
|
||||||
|
assert params.meta is not None
|
||||||
|
assert params.meta.progressToken == "test-token"
|
||||||
|
|
||||||
|
def test_request_params_meta_extra_fields(self):
|
||||||
|
"""Test RequestParams.Meta allows extra fields."""
|
||||||
|
meta = RequestParams.Meta(progressToken="token", customField="value")
|
||||||
|
assert meta.progressToken == "token"
|
||||||
|
assert meta.customField == "value" # type: ignore
|
||||||
|
|
||||||
|
def test_request_params_serialization(self):
|
||||||
|
"""Test RequestParams serialization with _meta alias."""
|
||||||
|
meta = RequestParams.Meta(progressToken="test")
|
||||||
|
params = RequestParams(_meta=meta)
|
||||||
|
|
||||||
|
# Model dump should use the alias
|
||||||
|
dumped = params.model_dump(by_alias=True)
|
||||||
|
assert "_meta" in dumped
|
||||||
|
assert dumped["_meta"] is not None
|
||||||
|
assert dumped["_meta"]["progressToken"] == "test"
|
||||||
|
|
||||||
|
|
||||||
|
class TestJSONRPCMessages:
|
||||||
|
"""Test JSON-RPC message types."""
|
||||||
|
|
||||||
|
def test_jsonrpc_request(self):
|
||||||
|
"""Test JSONRPCRequest creation and validation."""
|
||||||
|
request = JSONRPCRequest(jsonrpc="2.0", id="test-123", method="test_method", params={"key": "value"})
|
||||||
|
|
||||||
|
assert request.jsonrpc == "2.0"
|
||||||
|
assert request.id == "test-123"
|
||||||
|
assert request.method == "test_method"
|
||||||
|
assert request.params == {"key": "value"}
|
||||||
|
|
||||||
|
def test_jsonrpc_request_numeric_id(self):
|
||||||
|
"""Test JSONRPCRequest with numeric ID."""
|
||||||
|
request = JSONRPCRequest(jsonrpc="2.0", id=123, method="test", params=None)
|
||||||
|
assert request.id == 123
|
||||||
|
|
||||||
|
def test_jsonrpc_notification(self):
|
||||||
|
"""Test JSONRPCNotification creation."""
|
||||||
|
notification = JSONRPCNotification(jsonrpc="2.0", method="notification_method", params={"data": "test"})
|
||||||
|
|
||||||
|
assert notification.jsonrpc == "2.0"
|
||||||
|
assert notification.method == "notification_method"
|
||||||
|
assert not hasattr(notification, "id") # Notifications don't have ID
|
||||||
|
|
||||||
|
def test_jsonrpc_response(self):
|
||||||
|
"""Test JSONRPCResponse creation."""
|
||||||
|
response = JSONRPCResponse(jsonrpc="2.0", id="req-123", result={"success": True})
|
||||||
|
|
||||||
|
assert response.jsonrpc == "2.0"
|
||||||
|
assert response.id == "req-123"
|
||||||
|
assert response.result == {"success": True}
|
||||||
|
|
||||||
|
def test_jsonrpc_error(self):
|
||||||
|
"""Test JSONRPCError creation."""
|
||||||
|
error_data = ErrorData(code=INVALID_PARAMS, message="Invalid parameters", data={"field": "missing"})
|
||||||
|
|
||||||
|
error = JSONRPCError(jsonrpc="2.0", id="req-123", error=error_data)
|
||||||
|
|
||||||
|
assert error.jsonrpc == "2.0"
|
||||||
|
assert error.id == "req-123"
|
||||||
|
assert error.error.code == INVALID_PARAMS
|
||||||
|
assert error.error.message == "Invalid parameters"
|
||||||
|
assert error.error.data == {"field": "missing"}
|
||||||
|
|
||||||
|
def test_jsonrpc_message_parsing(self):
|
||||||
|
"""Test JSONRPCMessage parsing different message types."""
|
||||||
|
# Parse request
|
||||||
|
request_json = '{"jsonrpc": "2.0", "id": 1, "method": "test", "params": null}'
|
||||||
|
msg = JSONRPCMessage.model_validate_json(request_json)
|
||||||
|
assert isinstance(msg.root, JSONRPCRequest)
|
||||||
|
|
||||||
|
# Parse response
|
||||||
|
response_json = '{"jsonrpc": "2.0", "id": 1, "result": {"data": "test"}}'
|
||||||
|
msg = JSONRPCMessage.model_validate_json(response_json)
|
||||||
|
assert isinstance(msg.root, JSONRPCResponse)
|
||||||
|
|
||||||
|
# Parse error
|
||||||
|
error_json = '{"jsonrpc": "2.0", "id": 1, "error": {"code": -32600, "message": "Invalid Request"}}'
|
||||||
|
msg = JSONRPCMessage.model_validate_json(error_json)
|
||||||
|
assert isinstance(msg.root, JSONRPCError)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCapabilities:
|
||||||
|
"""Test capability classes."""
|
||||||
|
|
||||||
|
def test_client_capabilities(self):
|
||||||
|
"""Test ClientCapabilities creation."""
|
||||||
|
caps = ClientCapabilities(
|
||||||
|
experimental={"feature": {"enabled": True}},
|
||||||
|
sampling={"model_config": {"extra": "allow"}},
|
||||||
|
roots={"listChanged": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert caps.experimental == {"feature": {"enabled": True}}
|
||||||
|
assert caps.sampling is not None
|
||||||
|
assert caps.roots.listChanged is True # type: ignore
|
||||||
|
|
||||||
|
def test_server_capabilities(self):
|
||||||
|
"""Test ServerCapabilities creation."""
|
||||||
|
caps = ServerCapabilities(
|
||||||
|
tools={"listChanged": True},
|
||||||
|
resources={"subscribe": True, "listChanged": False},
|
||||||
|
prompts={"listChanged": True},
|
||||||
|
logging={},
|
||||||
|
completions={},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert caps.tools.listChanged is True # type: ignore
|
||||||
|
assert caps.resources.subscribe is True # type: ignore
|
||||||
|
assert caps.resources.listChanged is False # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class TestInitialization:
|
||||||
|
"""Test initialization request/response types."""
|
||||||
|
|
||||||
|
def test_initialize_request(self):
|
||||||
|
"""Test InitializeRequest creation."""
|
||||||
|
client_info = Implementation(name="test-client", version="1.0.0")
|
||||||
|
capabilities = ClientCapabilities()
|
||||||
|
|
||||||
|
params = InitializeRequestParams(
|
||||||
|
protocolVersion=LATEST_PROTOCOL_VERSION, capabilities=capabilities, clientInfo=client_info
|
||||||
|
)
|
||||||
|
|
||||||
|
request = InitializeRequest(params=params)
|
||||||
|
|
||||||
|
assert request.method == "initialize"
|
||||||
|
assert request.params.protocolVersion == LATEST_PROTOCOL_VERSION
|
||||||
|
assert request.params.clientInfo.name == "test-client"
|
||||||
|
|
||||||
|
def test_initialize_result(self):
|
||||||
|
"""Test InitializeResult creation."""
|
||||||
|
server_info = Implementation(name="test-server", version="1.0.0")
|
||||||
|
capabilities = ServerCapabilities()
|
||||||
|
|
||||||
|
result = InitializeResult(
|
||||||
|
protocolVersion=LATEST_PROTOCOL_VERSION,
|
||||||
|
capabilities=capabilities,
|
||||||
|
serverInfo=server_info,
|
||||||
|
instructions="Welcome to test server",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.protocolVersion == LATEST_PROTOCOL_VERSION
|
||||||
|
assert result.serverInfo.name == "test-server"
|
||||||
|
assert result.instructions == "Welcome to test server"
|
||||||
|
|
||||||
|
|
||||||
|
class TestTools:
|
||||||
|
"""Test tool-related types."""
|
||||||
|
|
||||||
|
def test_tool_creation(self):
|
||||||
|
"""Test Tool creation with all fields."""
|
||||||
|
tool = Tool(
|
||||||
|
name="test_tool",
|
||||||
|
title="Test Tool",
|
||||||
|
description="A tool for testing",
|
||||||
|
inputSchema={"type": "object", "properties": {"input": {"type": "string"}}, "required": ["input"]},
|
||||||
|
outputSchema={"type": "object", "properties": {"result": {"type": "string"}}},
|
||||||
|
annotations=ToolAnnotations(
|
||||||
|
title="Test Tool", readOnlyHint=False, destructiveHint=False, idempotentHint=True
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert tool.name == "test_tool"
|
||||||
|
assert tool.title == "Test Tool"
|
||||||
|
assert tool.description == "A tool for testing"
|
||||||
|
assert tool.inputSchema["properties"]["input"]["type"] == "string"
|
||||||
|
assert tool.annotations.idempotentHint is True
|
||||||
|
|
||||||
|
def test_call_tool_request(self):
|
||||||
|
"""Test CallToolRequest creation."""
|
||||||
|
params = CallToolRequestParams(name="test_tool", arguments={"input": "test value"})
|
||||||
|
|
||||||
|
request = CallToolRequest(params=params)
|
||||||
|
|
||||||
|
assert request.method == "tools/call"
|
||||||
|
assert request.params.name == "test_tool"
|
||||||
|
assert request.params.arguments == {"input": "test value"}
|
||||||
|
|
||||||
|
def test_call_tool_result(self):
|
||||||
|
"""Test CallToolResult creation."""
|
||||||
|
result = CallToolResult(
|
||||||
|
content=[TextContent(type="text", text="Tool executed successfully")],
|
||||||
|
structuredContent={"status": "success", "data": "test"},
|
||||||
|
isError=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result.content) == 1
|
||||||
|
assert result.content[0].text == "Tool executed successfully" # type: ignore
|
||||||
|
assert result.structuredContent == {"status": "success", "data": "test"}
|
||||||
|
assert result.isError is False
|
||||||
|
|
||||||
|
def test_list_tools_request(self):
|
||||||
|
"""Test ListToolsRequest creation."""
|
||||||
|
request = ListToolsRequest()
|
||||||
|
assert request.method == "tools/list"
|
||||||
|
|
||||||
|
def test_list_tools_result(self):
|
||||||
|
"""Test ListToolsResult creation."""
|
||||||
|
tool1 = Tool(name="tool1", inputSchema={})
|
||||||
|
tool2 = Tool(name="tool2", inputSchema={})
|
||||||
|
|
||||||
|
result = ListToolsResult(tools=[tool1, tool2])
|
||||||
|
|
||||||
|
assert len(result.tools) == 2
|
||||||
|
assert result.tools[0].name == "tool1"
|
||||||
|
assert result.tools[1].name == "tool2"
|
||||||
|
|
||||||
|
|
||||||
|
class TestContent:
|
||||||
|
"""Test content types."""
|
||||||
|
|
||||||
|
def test_text_content(self):
|
||||||
|
"""Test TextContent creation."""
|
||||||
|
annotations = Annotations(audience=["user"], priority=0.8)
|
||||||
|
content = TextContent(type="text", text="Hello, world!", annotations=annotations)
|
||||||
|
|
||||||
|
assert content.type == "text"
|
||||||
|
assert content.text == "Hello, world!"
|
||||||
|
assert content.annotations is not None
|
||||||
|
assert content.annotations.priority == 0.8
|
||||||
|
|
||||||
|
def test_image_content(self):
|
||||||
|
"""Test ImageContent creation."""
|
||||||
|
content = ImageContent(type="image", data="base64encodeddata", mimeType="image/png")
|
||||||
|
|
||||||
|
assert content.type == "image"
|
||||||
|
assert content.data == "base64encodeddata"
|
||||||
|
assert content.mimeType == "image/png"
|
||||||
|
|
||||||
|
|
||||||
|
class TestOAuth:
|
||||||
|
"""Test OAuth-related types."""
|
||||||
|
|
||||||
|
def test_oauth_client_metadata(self):
|
||||||
|
"""Test OAuthClientMetadata creation."""
|
||||||
|
metadata = OAuthClientMetadata(
|
||||||
|
client_name="Test Client",
|
||||||
|
redirect_uris=["https://example.com/callback"],
|
||||||
|
grant_types=["authorization_code", "refresh_token"],
|
||||||
|
response_types=["code"],
|
||||||
|
token_endpoint_auth_method="none",
|
||||||
|
client_uri="https://example.com",
|
||||||
|
scope="read write",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert metadata.client_name == "Test Client"
|
||||||
|
assert len(metadata.redirect_uris) == 1
|
||||||
|
assert "authorization_code" in metadata.grant_types
|
||||||
|
|
||||||
|
def test_oauth_client_information(self):
|
||||||
|
"""Test OAuthClientInformation creation."""
|
||||||
|
info = OAuthClientInformation(client_id="test-client-id", client_secret="test-secret")
|
||||||
|
|
||||||
|
assert info.client_id == "test-client-id"
|
||||||
|
assert info.client_secret == "test-secret"
|
||||||
|
|
||||||
|
def test_oauth_client_information_without_secret(self):
|
||||||
|
"""Test OAuthClientInformation without secret."""
|
||||||
|
info = OAuthClientInformation(client_id="public-client")
|
||||||
|
|
||||||
|
assert info.client_id == "public-client"
|
||||||
|
assert info.client_secret is None
|
||||||
|
|
||||||
|
def test_oauth_tokens(self):
|
||||||
|
"""Test OAuthTokens creation."""
|
||||||
|
tokens = OAuthTokens(
|
||||||
|
access_token="access-token-123",
|
||||||
|
token_type="Bearer",
|
||||||
|
expires_in=3600,
|
||||||
|
refresh_token="refresh-token-456",
|
||||||
|
scope="read write",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert tokens.access_token == "access-token-123"
|
||||||
|
assert tokens.token_type == "Bearer"
|
||||||
|
assert tokens.expires_in == 3600
|
||||||
|
assert tokens.refresh_token == "refresh-token-456"
|
||||||
|
assert tokens.scope == "read write"
|
||||||
|
|
||||||
|
def test_oauth_metadata(self):
|
||||||
|
"""Test OAuthMetadata creation."""
|
||||||
|
metadata = OAuthMetadata(
|
||||||
|
authorization_endpoint="https://auth.example.com/authorize",
|
||||||
|
token_endpoint="https://auth.example.com/token",
|
||||||
|
registration_endpoint="https://auth.example.com/register",
|
||||||
|
response_types_supported=["code", "token"],
|
||||||
|
grant_types_supported=["authorization_code", "refresh_token"],
|
||||||
|
code_challenge_methods_supported=["plain", "S256"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert metadata.authorization_endpoint == "https://auth.example.com/authorize"
|
||||||
|
assert "code" in metadata.response_types_supported
|
||||||
|
assert "S256" in metadata.code_challenge_methods_supported
|
||||||
|
|
||||||
|
|
||||||
|
class TestNotifications:
|
||||||
|
"""Test notification types."""
|
||||||
|
|
||||||
|
def test_progress_notification(self):
|
||||||
|
"""Test ProgressNotification creation."""
|
||||||
|
params = ProgressNotificationParams(
|
||||||
|
progressToken="progress-123", progress=50.0, total=100.0, message="Processing... 50%"
|
||||||
|
)
|
||||||
|
|
||||||
|
notification = ProgressNotification(params=params)
|
||||||
|
|
||||||
|
assert notification.method == "notifications/progress"
|
||||||
|
assert notification.params.progressToken == "progress-123"
|
||||||
|
assert notification.params.progress == 50.0
|
||||||
|
assert notification.params.total == 100.0
|
||||||
|
assert notification.params.message == "Processing... 50%"
|
||||||
|
|
||||||
|
def test_ping_request(self):
|
||||||
|
"""Test PingRequest creation."""
|
||||||
|
request = PingRequest()
|
||||||
|
assert request.method == "ping"
|
||||||
|
assert request.params is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompletion:
|
||||||
|
"""Test completion-related types."""
|
||||||
|
|
||||||
|
def test_completion_context(self):
|
||||||
|
"""Test CompletionContext creation."""
|
||||||
|
context = CompletionContext(arguments={"template_var": "value"})
|
||||||
|
assert context.arguments == {"template_var": "value"}
|
||||||
|
|
||||||
|
def test_resource_template_reference(self):
|
||||||
|
"""Test ResourceTemplateReference creation."""
|
||||||
|
ref = ResourceTemplateReference(type="ref/resource", uri="file:///path/to/{filename}")
|
||||||
|
assert ref.type == "ref/resource"
|
||||||
|
assert ref.uri == "file:///path/to/{filename}"
|
||||||
|
|
||||||
|
def test_prompt_reference(self):
|
||||||
|
"""Test PromptReference creation."""
|
||||||
|
ref = PromptReference(type="ref/prompt", name="test_prompt")
|
||||||
|
assert ref.type == "ref/prompt"
|
||||||
|
assert ref.name == "test_prompt"
|
||||||
|
|
||||||
|
def test_complete_request(self):
|
||||||
|
"""Test CompleteRequest creation."""
|
||||||
|
ref = PromptReference(type="ref/prompt", name="test_prompt")
|
||||||
|
arg = CompletionArgument(name="arg1", value="val")
|
||||||
|
|
||||||
|
params = CompleteRequestParams(ref=ref, argument=arg, context=CompletionContext(arguments={"key": "value"}))
|
||||||
|
|
||||||
|
request = CompleteRequest(params=params)
|
||||||
|
|
||||||
|
assert request.method == "completion/complete"
|
||||||
|
assert request.params.ref.name == "test_prompt" # type: ignore
|
||||||
|
assert request.params.argument.name == "arg1"
|
||||||
|
|
||||||
|
def test_complete_result(self):
|
||||||
|
"""Test CompleteResult creation."""
|
||||||
|
completion = Completion(values=["option1", "option2", "option3"], total=10, hasMore=True)
|
||||||
|
|
||||||
|
result = CompleteResult(completion=completion)
|
||||||
|
|
||||||
|
assert len(result.completion.values) == 3
|
||||||
|
assert result.completion.total == 10
|
||||||
|
assert result.completion.hasMore is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidation:
|
||||||
|
"""Test validation of various types."""
|
||||||
|
|
||||||
|
def test_invalid_jsonrpc_version(self):
|
||||||
|
"""Test invalid JSON-RPC version validation."""
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
JSONRPCRequest(
|
||||||
|
jsonrpc="1.0", # Invalid version
|
||||||
|
id=1,
|
||||||
|
method="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_tool_annotations_validation(self):
|
||||||
|
"""Test ToolAnnotations with invalid values."""
|
||||||
|
# Valid annotations
|
||||||
|
annotations = ToolAnnotations(
|
||||||
|
title="Test", readOnlyHint=True, destructiveHint=False, idempotentHint=True, openWorldHint=False
|
||||||
|
)
|
||||||
|
assert annotations.title == "Test"
|
||||||
|
|
||||||
|
def test_extra_fields_allowed(self):
|
||||||
|
"""Test that extra fields are allowed in models."""
|
||||||
|
# Most models should allow extra fields
|
||||||
|
tool = Tool(
|
||||||
|
name="test",
|
||||||
|
inputSchema={},
|
||||||
|
customField="allowed", # type: ignore
|
||||||
|
)
|
||||||
|
assert tool.customField == "allowed" # type: ignore
|
||||||
|
|
||||||
|
def test_result_meta_alias(self):
|
||||||
|
"""Test Result model with _meta alias."""
|
||||||
|
# Create with the field name (not alias)
|
||||||
|
result = Result(_meta={"key": "value"})
|
||||||
|
|
||||||
|
# Verify the field is set correctly
|
||||||
|
assert result.meta == {"key": "value"}
|
||||||
|
|
||||||
|
# Dump with alias
|
||||||
|
dumped = result.model_dump(by_alias=True)
|
||||||
|
assert "_meta" in dumped
|
||||||
|
assert dumped["_meta"] == {"key": "value"}
|
||||||
355
api/tests/unit_tests/core/mcp/test_utils.py
Normal file
355
api/tests/unit_tests/core/mcp/test_utils.py
Normal file
@ -0,0 +1,355 @@
|
|||||||
|
"""Unit tests for MCP utils module."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from collections.abc import Generator
|
||||||
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import httpx_sse
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.mcp.utils import (
|
||||||
|
STATUS_FORCELIST,
|
||||||
|
create_mcp_error_response,
|
||||||
|
create_ssrf_proxy_mcp_http_client,
|
||||||
|
ssrf_proxy_sse_connect,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestConstants:
|
||||||
|
"""Test module constants."""
|
||||||
|
|
||||||
|
def test_status_forcelist(self):
|
||||||
|
"""Test STATUS_FORCELIST contains expected HTTP status codes."""
|
||||||
|
assert STATUS_FORCELIST == [429, 500, 502, 503, 504]
|
||||||
|
assert 429 in STATUS_FORCELIST # Too Many Requests
|
||||||
|
assert 500 in STATUS_FORCELIST # Internal Server Error
|
||||||
|
assert 502 in STATUS_FORCELIST # Bad Gateway
|
||||||
|
assert 503 in STATUS_FORCELIST # Service Unavailable
|
||||||
|
assert 504 in STATUS_FORCELIST # Gateway Timeout
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateSSRFProxyMCPHTTPClient:
|
||||||
|
"""Test create_ssrf_proxy_mcp_http_client function."""
|
||||||
|
|
||||||
|
@patch("core.mcp.utils.dify_config")
|
||||||
|
def test_create_client_with_all_url_proxy(self, mock_config):
|
||||||
|
"""Test client creation with SSRF_PROXY_ALL_URL configured."""
|
||||||
|
mock_config.SSRF_PROXY_ALL_URL = "http://proxy.example.com:8080"
|
||||||
|
mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = True
|
||||||
|
|
||||||
|
client = create_ssrf_proxy_mcp_http_client(
|
||||||
|
headers={"Authorization": "Bearer token"}, timeout=httpx.Timeout(30.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(client, httpx.Client)
|
||||||
|
assert client.headers["Authorization"] == "Bearer token"
|
||||||
|
assert client.timeout.connect == 30.0
|
||||||
|
assert client.follow_redirects is True
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
client.close()
|
||||||
|
|
||||||
|
@patch("core.mcp.utils.dify_config")
|
||||||
|
def test_create_client_with_http_https_proxies(self, mock_config):
|
||||||
|
"""Test client creation with separate HTTP/HTTPS proxies."""
|
||||||
|
mock_config.SSRF_PROXY_ALL_URL = None
|
||||||
|
mock_config.SSRF_PROXY_HTTP_URL = "http://http-proxy.example.com:8080"
|
||||||
|
mock_config.SSRF_PROXY_HTTPS_URL = "http://https-proxy.example.com:8443"
|
||||||
|
mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = False
|
||||||
|
|
||||||
|
client = create_ssrf_proxy_mcp_http_client()
|
||||||
|
|
||||||
|
assert isinstance(client, httpx.Client)
|
||||||
|
assert client.follow_redirects is True
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
client.close()
|
||||||
|
|
||||||
|
@patch("core.mcp.utils.dify_config")
|
||||||
|
def test_create_client_without_proxy(self, mock_config):
|
||||||
|
"""Test client creation without proxy configuration."""
|
||||||
|
mock_config.SSRF_PROXY_ALL_URL = None
|
||||||
|
mock_config.SSRF_PROXY_HTTP_URL = None
|
||||||
|
mock_config.SSRF_PROXY_HTTPS_URL = None
|
||||||
|
mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = True
|
||||||
|
|
||||||
|
headers = {"X-Custom-Header": "value"}
|
||||||
|
timeout = httpx.Timeout(timeout=30.0, connect=5.0, read=10.0, write=30.0)
|
||||||
|
|
||||||
|
client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout)
|
||||||
|
|
||||||
|
assert isinstance(client, httpx.Client)
|
||||||
|
assert client.headers["X-Custom-Header"] == "value"
|
||||||
|
assert client.timeout.connect == 5.0
|
||||||
|
assert client.timeout.read == 10.0
|
||||||
|
assert client.follow_redirects is True
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
client.close()
|
||||||
|
|
||||||
|
@patch("core.mcp.utils.dify_config")
|
||||||
|
def test_create_client_default_params(self, mock_config):
|
||||||
|
"""Test client creation with default parameters."""
|
||||||
|
mock_config.SSRF_PROXY_ALL_URL = None
|
||||||
|
mock_config.SSRF_PROXY_HTTP_URL = None
|
||||||
|
mock_config.SSRF_PROXY_HTTPS_URL = None
|
||||||
|
mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = True
|
||||||
|
|
||||||
|
client = create_ssrf_proxy_mcp_http_client()
|
||||||
|
|
||||||
|
assert isinstance(client, httpx.Client)
|
||||||
|
# httpx.Client adds default headers, so we just check it's a Headers object
|
||||||
|
assert isinstance(client.headers, httpx.Headers)
|
||||||
|
# When no timeout is provided, httpx uses its default timeout
|
||||||
|
assert client.timeout is not None
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
client.close()
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSRFProxySSEConnect:
|
||||||
|
"""Test ssrf_proxy_sse_connect function."""
|
||||||
|
|
||||||
|
@patch("core.mcp.utils.connect_sse")
|
||||||
|
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
|
||||||
|
def test_sse_connect_with_provided_client(self, mock_create_client, mock_connect_sse):
|
||||||
|
"""Test SSE connection with pre-configured client."""
|
||||||
|
# Setup mocks
|
||||||
|
mock_client = Mock(spec=httpx.Client)
|
||||||
|
mock_event_source = Mock(spec=httpx_sse.EventSource)
|
||||||
|
mock_context = MagicMock()
|
||||||
|
mock_context.__enter__.return_value = mock_event_source
|
||||||
|
mock_connect_sse.return_value = mock_context
|
||||||
|
|
||||||
|
# Call with provided client
|
||||||
|
result = ssrf_proxy_sse_connect(
|
||||||
|
"http://example.com/sse", client=mock_client, method="POST", headers={"Authorization": "Bearer token"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify client creation was not called
|
||||||
|
mock_create_client.assert_not_called()
|
||||||
|
|
||||||
|
# Verify connect_sse was called correctly
|
||||||
|
mock_connect_sse.assert_called_once_with(
|
||||||
|
mock_client, "POST", "http://example.com/sse", headers={"Authorization": "Bearer token"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify result
|
||||||
|
assert result == mock_context
|
||||||
|
|
||||||
|
@patch("core.mcp.utils.connect_sse")
|
||||||
|
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
|
||||||
|
@patch("core.mcp.utils.dify_config")
|
||||||
|
def test_sse_connect_without_client(self, mock_config, mock_create_client, mock_connect_sse):
|
||||||
|
"""Test SSE connection without pre-configured client."""
|
||||||
|
# Setup config
|
||||||
|
mock_config.SSRF_DEFAULT_TIME_OUT = 30.0
|
||||||
|
mock_config.SSRF_DEFAULT_CONNECT_TIME_OUT = 10.0
|
||||||
|
mock_config.SSRF_DEFAULT_READ_TIME_OUT = 60.0
|
||||||
|
mock_config.SSRF_DEFAULT_WRITE_TIME_OUT = 30.0
|
||||||
|
|
||||||
|
# Setup mocks
|
||||||
|
mock_client = Mock(spec=httpx.Client)
|
||||||
|
mock_create_client.return_value = mock_client
|
||||||
|
|
||||||
|
mock_event_source = Mock(spec=httpx_sse.EventSource)
|
||||||
|
mock_context = MagicMock()
|
||||||
|
mock_context.__enter__.return_value = mock_event_source
|
||||||
|
mock_connect_sse.return_value = mock_context
|
||||||
|
|
||||||
|
# Call without client
|
||||||
|
result = ssrf_proxy_sse_connect("http://example.com/sse", headers={"X-Custom": "value"})
|
||||||
|
|
||||||
|
# Verify client was created
|
||||||
|
mock_create_client.assert_called_once()
|
||||||
|
call_args = mock_create_client.call_args
|
||||||
|
assert call_args[1]["headers"] == {"X-Custom": "value"}
|
||||||
|
|
||||||
|
timeout = call_args[1]["timeout"]
|
||||||
|
# httpx.Timeout object has these attributes
|
||||||
|
assert isinstance(timeout, httpx.Timeout)
|
||||||
|
assert timeout.connect == 10.0
|
||||||
|
assert timeout.read == 60.0
|
||||||
|
assert timeout.write == 30.0
|
||||||
|
|
||||||
|
# Verify connect_sse was called
|
||||||
|
mock_connect_sse.assert_called_once_with(
|
||||||
|
mock_client,
|
||||||
|
"GET", # Default method
|
||||||
|
"http://example.com/sse",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify result
|
||||||
|
assert result == mock_context
|
||||||
|
|
||||||
|
@patch("core.mcp.utils.connect_sse")
|
||||||
|
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
|
||||||
|
def test_sse_connect_with_custom_timeout(self, mock_create_client, mock_connect_sse):
|
||||||
|
"""Test SSE connection with custom timeout."""
|
||||||
|
# Setup mocks
|
||||||
|
mock_client = Mock(spec=httpx.Client)
|
||||||
|
mock_create_client.return_value = mock_client
|
||||||
|
|
||||||
|
mock_event_source = Mock(spec=httpx_sse.EventSource)
|
||||||
|
mock_context = MagicMock()
|
||||||
|
mock_context.__enter__.return_value = mock_event_source
|
||||||
|
mock_connect_sse.return_value = mock_context
|
||||||
|
|
||||||
|
custom_timeout = httpx.Timeout(timeout=60.0, read=120.0)
|
||||||
|
|
||||||
|
# Call with custom timeout
|
||||||
|
result = ssrf_proxy_sse_connect("http://example.com/sse", timeout=custom_timeout)
|
||||||
|
|
||||||
|
# Verify client was created with custom timeout
|
||||||
|
mock_create_client.assert_called_once()
|
||||||
|
call_args = mock_create_client.call_args
|
||||||
|
assert call_args[1]["timeout"] == custom_timeout
|
||||||
|
|
||||||
|
# Verify result
|
||||||
|
assert result == mock_context
|
||||||
|
|
||||||
|
@patch("core.mcp.utils.connect_sse")
|
||||||
|
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
|
||||||
|
def test_sse_connect_error_cleanup(self, mock_create_client, mock_connect_sse):
|
||||||
|
"""Test SSE connection cleans up client on error."""
|
||||||
|
# Setup mocks
|
||||||
|
mock_client = Mock(spec=httpx.Client)
|
||||||
|
mock_create_client.return_value = mock_client
|
||||||
|
|
||||||
|
# Make connect_sse raise an exception
|
||||||
|
mock_connect_sse.side_effect = httpx.ConnectError("Connection failed")
|
||||||
|
|
||||||
|
# Call should raise the exception
|
||||||
|
with pytest.raises(httpx.ConnectError):
|
||||||
|
ssrf_proxy_sse_connect("http://example.com/sse")
|
||||||
|
|
||||||
|
# Verify client was cleaned up
|
||||||
|
mock_client.close.assert_called_once()
|
||||||
|
|
||||||
|
@patch("core.mcp.utils.connect_sse")
|
||||||
|
def test_sse_connect_error_no_cleanup_with_provided_client(self, mock_connect_sse):
|
||||||
|
"""Test SSE connection doesn't clean up provided client on error."""
|
||||||
|
# Setup mocks
|
||||||
|
mock_client = Mock(spec=httpx.Client)
|
||||||
|
|
||||||
|
# Make connect_sse raise an exception
|
||||||
|
mock_connect_sse.side_effect = httpx.ConnectError("Connection failed")
|
||||||
|
|
||||||
|
# Call should raise the exception
|
||||||
|
with pytest.raises(httpx.ConnectError):
|
||||||
|
ssrf_proxy_sse_connect("http://example.com/sse", client=mock_client)
|
||||||
|
|
||||||
|
# Verify client was NOT cleaned up (because it was provided)
|
||||||
|
mock_client.close.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateMCPErrorResponse:
|
||||||
|
"""Test create_mcp_error_response function."""
|
||||||
|
|
||||||
|
def test_create_error_response_basic(self):
|
||||||
|
"""Test creating basic error response."""
|
||||||
|
generator = create_mcp_error_response(request_id="req-123", code=-32600, message="Invalid Request")
|
||||||
|
|
||||||
|
# Generator should yield bytes
|
||||||
|
assert isinstance(generator, Generator)
|
||||||
|
|
||||||
|
# Get the response
|
||||||
|
response_bytes = next(generator)
|
||||||
|
assert isinstance(response_bytes, bytes)
|
||||||
|
|
||||||
|
# Parse the response
|
||||||
|
response_str = response_bytes.decode("utf-8")
|
||||||
|
response_json = json.loads(response_str)
|
||||||
|
|
||||||
|
assert response_json["jsonrpc"] == "2.0"
|
||||||
|
assert response_json["id"] == "req-123"
|
||||||
|
assert response_json["error"]["code"] == -32600
|
||||||
|
assert response_json["error"]["message"] == "Invalid Request"
|
||||||
|
assert response_json["error"]["data"] is None
|
||||||
|
|
||||||
|
# Generator should be exhausted
|
||||||
|
with pytest.raises(StopIteration):
|
||||||
|
next(generator)
|
||||||
|
|
||||||
|
def test_create_error_response_with_data(self):
|
||||||
|
"""Test creating error response with additional data."""
|
||||||
|
error_data = {"field": "username", "reason": "required"}
|
||||||
|
|
||||||
|
generator = create_mcp_error_response(
|
||||||
|
request_id=456, # Numeric ID
|
||||||
|
code=-32602,
|
||||||
|
message="Invalid params",
|
||||||
|
data=error_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
response_bytes = next(generator)
|
||||||
|
response_json = json.loads(response_bytes.decode("utf-8"))
|
||||||
|
|
||||||
|
assert response_json["id"] == 456
|
||||||
|
assert response_json["error"]["code"] == -32602
|
||||||
|
assert response_json["error"]["message"] == "Invalid params"
|
||||||
|
assert response_json["error"]["data"] == error_data
|
||||||
|
|
||||||
|
def test_create_error_response_without_request_id(self):
|
||||||
|
"""Test creating error response without request ID."""
|
||||||
|
generator = create_mcp_error_response(request_id=None, code=-32700, message="Parse error")
|
||||||
|
|
||||||
|
response_bytes = next(generator)
|
||||||
|
response_json = json.loads(response_bytes.decode("utf-8"))
|
||||||
|
|
||||||
|
# Should default to ID 1
|
||||||
|
assert response_json["id"] == 1
|
||||||
|
assert response_json["error"]["code"] == -32700
|
||||||
|
assert response_json["error"]["message"] == "Parse error"
|
||||||
|
|
||||||
|
def test_create_error_response_with_complex_data(self):
|
||||||
|
"""Test creating error response with complex error data."""
|
||||||
|
complex_data = {
|
||||||
|
"errors": [{"field": "name", "message": "Too short"}, {"field": "email", "message": "Invalid format"}],
|
||||||
|
"timestamp": "2024-01-01T00:00:00Z",
|
||||||
|
}
|
||||||
|
|
||||||
|
generator = create_mcp_error_response(
|
||||||
|
request_id="complex-req", code=-32602, message="Validation failed", data=complex_data
|
||||||
|
)
|
||||||
|
|
||||||
|
response_bytes = next(generator)
|
||||||
|
response_json = json.loads(response_bytes.decode("utf-8"))
|
||||||
|
|
||||||
|
assert response_json["error"]["data"] == complex_data
|
||||||
|
assert len(response_json["error"]["data"]["errors"]) == 2
|
||||||
|
|
||||||
|
def test_create_error_response_encoding(self):
|
||||||
|
"""Test error response with non-ASCII characters."""
|
||||||
|
generator = create_mcp_error_response(
|
||||||
|
request_id="unicode-req",
|
||||||
|
code=-32603,
|
||||||
|
message="内部错误", # Chinese characters
|
||||||
|
data={"details": "エラー詳細"}, # Japanese characters
|
||||||
|
)
|
||||||
|
|
||||||
|
response_bytes = next(generator)
|
||||||
|
|
||||||
|
# Should be valid UTF-8
|
||||||
|
response_str = response_bytes.decode("utf-8")
|
||||||
|
response_json = json.loads(response_str)
|
||||||
|
|
||||||
|
assert response_json["error"]["message"] == "内部错误"
|
||||||
|
assert response_json["error"]["data"]["details"] == "エラー詳細"
|
||||||
|
|
||||||
|
def test_create_error_response_yields_once(self):
|
||||||
|
"""Test that error response generator yields exactly once."""
|
||||||
|
generator = create_mcp_error_response(request_id="test", code=-32600, message="Test")
|
||||||
|
|
||||||
|
# First yield should work
|
||||||
|
first_yield = next(generator)
|
||||||
|
assert isinstance(first_yield, bytes)
|
||||||
|
|
||||||
|
# Second yield should raise StopIteration
|
||||||
|
with pytest.raises(StopIteration):
|
||||||
|
next(generator)
|
||||||
|
|
||||||
|
# Subsequent calls should also raise
|
||||||
|
with pytest.raises(StopIteration):
|
||||||
|
next(generator)
|
||||||
@ -0,0 +1,96 @@
|
|||||||
|
"""
|
||||||
|
Test cases for the Iteration node's flatten_output functionality.
|
||||||
|
|
||||||
|
This module tests the iteration node's ability to:
|
||||||
|
1. Flatten array outputs when flatten_output=True (default)
|
||||||
|
2. Preserve nested array structure when flatten_output=False
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .test_table_runner import TableTestRunner, WorkflowTestCase
|
||||||
|
|
||||||
|
|
||||||
|
def test_iteration_with_flatten_output_enabled():
|
||||||
|
"""
|
||||||
|
Test iteration node with flatten_output=True (default behavior).
|
||||||
|
|
||||||
|
The fixture implements an iteration that:
|
||||||
|
1. Iterates over [1, 2, 3]
|
||||||
|
2. For each item, outputs [item, item*2]
|
||||||
|
3. With flatten_output=True, should output [1, 2, 2, 4, 3, 6]
|
||||||
|
"""
|
||||||
|
runner = TableTestRunner()
|
||||||
|
|
||||||
|
test_case = WorkflowTestCase(
|
||||||
|
fixture_path="iteration_flatten_output_enabled_workflow",
|
||||||
|
inputs={},
|
||||||
|
expected_outputs={"output": [1, 2, 2, 4, 3, 6]},
|
||||||
|
description="Iteration with flatten_output=True flattens nested arrays",
|
||||||
|
use_auto_mock=False, # Run code nodes directly
|
||||||
|
)
|
||||||
|
|
||||||
|
result = runner.run_test_case(test_case)
|
||||||
|
|
||||||
|
assert result.success, f"Test failed: {result.error}"
|
||||||
|
assert result.actual_outputs is not None, "Should have outputs"
|
||||||
|
assert result.actual_outputs == {"output": [1, 2, 2, 4, 3, 6]}, (
|
||||||
|
f"Expected flattened output [1, 2, 2, 4, 3, 6], got {result.actual_outputs}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_iteration_with_flatten_output_disabled():
|
||||||
|
"""
|
||||||
|
Test iteration node with flatten_output=False.
|
||||||
|
|
||||||
|
The fixture implements an iteration that:
|
||||||
|
1. Iterates over [1, 2, 3]
|
||||||
|
2. For each item, outputs [item, item*2]
|
||||||
|
3. With flatten_output=False, should output [[1, 2], [2, 4], [3, 6]]
|
||||||
|
"""
|
||||||
|
runner = TableTestRunner()
|
||||||
|
|
||||||
|
test_case = WorkflowTestCase(
|
||||||
|
fixture_path="iteration_flatten_output_disabled_workflow",
|
||||||
|
inputs={},
|
||||||
|
expected_outputs={"output": [[1, 2], [2, 4], [3, 6]]},
|
||||||
|
description="Iteration with flatten_output=False preserves nested structure",
|
||||||
|
use_auto_mock=False, # Run code nodes directly
|
||||||
|
)
|
||||||
|
|
||||||
|
result = runner.run_test_case(test_case)
|
||||||
|
|
||||||
|
assert result.success, f"Test failed: {result.error}"
|
||||||
|
assert result.actual_outputs is not None, "Should have outputs"
|
||||||
|
assert result.actual_outputs == {"output": [[1, 2], [2, 4], [3, 6]]}, (
|
||||||
|
f"Expected nested output [[1, 2], [2, 4], [3, 6]], got {result.actual_outputs}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_iteration_flatten_output_comparison():
|
||||||
|
"""
|
||||||
|
Run both flatten_output configurations in parallel to verify the difference.
|
||||||
|
"""
|
||||||
|
runner = TableTestRunner()
|
||||||
|
|
||||||
|
test_cases = [
|
||||||
|
WorkflowTestCase(
|
||||||
|
fixture_path="iteration_flatten_output_enabled_workflow",
|
||||||
|
inputs={},
|
||||||
|
expected_outputs={"output": [1, 2, 2, 4, 3, 6]},
|
||||||
|
description="flatten_output=True: Flattened output",
|
||||||
|
use_auto_mock=False, # Run code nodes directly
|
||||||
|
),
|
||||||
|
WorkflowTestCase(
|
||||||
|
fixture_path="iteration_flatten_output_disabled_workflow",
|
||||||
|
inputs={},
|
||||||
|
expected_outputs={"output": [[1, 2], [2, 4], [3, 6]]},
|
||||||
|
description="flatten_output=False: Nested output",
|
||||||
|
use_auto_mock=False, # Run code nodes directly
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
suite_result = runner.run_table_tests(test_cases, parallel=True)
|
||||||
|
|
||||||
|
# Assert all tests passed
|
||||||
|
assert suite_result.passed_tests == 2, f"Expected 2 passed tests, got {suite_result.passed_tests}"
|
||||||
|
assert suite_result.failed_tests == 0, f"Expected 0 failed tests, got {suite_result.failed_tests}"
|
||||||
|
assert suite_result.success_rate == 100.0, f"Expected 100% success rate, got {suite_result.success_rate}"
|
||||||
@ -1,5 +1,10 @@
|
|||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from werkzeug.wrappers import Response
|
||||||
|
|
||||||
from constants import COOKIE_NAME_ACCESS_TOKEN, COOKIE_NAME_WEBAPP_ACCESS_TOKEN
|
from constants import COOKIE_NAME_ACCESS_TOKEN, COOKIE_NAME_WEBAPP_ACCESS_TOKEN
|
||||||
from libs.token import extract_access_token, extract_webapp_access_token
|
from libs import token
|
||||||
|
from libs.token import extract_access_token, extract_webapp_access_token, set_csrf_token_to_cookie
|
||||||
|
|
||||||
|
|
||||||
class MockRequest:
|
class MockRequest:
|
||||||
@ -23,3 +28,35 @@ def test_extract_access_token():
|
|||||||
for request, expected_console, expected_webapp in test_cases:
|
for request, expected_console, expected_webapp in test_cases:
|
||||||
assert extract_access_token(request) == expected_console # pyright: ignore[reportArgumentType]
|
assert extract_access_token(request) == expected_console # pyright: ignore[reportArgumentType]
|
||||||
assert extract_webapp_access_token(request) == expected_webapp # pyright: ignore[reportArgumentType]
|
assert extract_webapp_access_token(request) == expected_webapp # pyright: ignore[reportArgumentType]
|
||||||
|
|
||||||
|
|
||||||
|
def test_real_cookie_name_uses_host_prefix_without_domain(monkeypatch):
|
||||||
|
monkeypatch.setattr(token.dify_config, "CONSOLE_WEB_URL", "https://console.example.com", raising=False)
|
||||||
|
monkeypatch.setattr(token.dify_config, "CONSOLE_API_URL", "https://api.example.com", raising=False)
|
||||||
|
monkeypatch.setattr(token.dify_config, "COOKIE_DOMAIN", "", raising=False)
|
||||||
|
|
||||||
|
assert token._real_cookie_name("csrf_token") == "__Host-csrf_token"
|
||||||
|
|
||||||
|
|
||||||
|
def test_real_cookie_name_without_host_prefix_when_domain_present(monkeypatch):
|
||||||
|
monkeypatch.setattr(token.dify_config, "CONSOLE_WEB_URL", "https://console.example.com", raising=False)
|
||||||
|
monkeypatch.setattr(token.dify_config, "CONSOLE_API_URL", "https://api.example.com", raising=False)
|
||||||
|
monkeypatch.setattr(token.dify_config, "COOKIE_DOMAIN", ".example.com", raising=False)
|
||||||
|
|
||||||
|
assert token._real_cookie_name("csrf_token") == "csrf_token"
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_csrf_cookie_includes_domain_when_configured(monkeypatch):
|
||||||
|
monkeypatch.setattr(token.dify_config, "CONSOLE_WEB_URL", "https://console.example.com", raising=False)
|
||||||
|
monkeypatch.setattr(token.dify_config, "CONSOLE_API_URL", "https://api.example.com", raising=False)
|
||||||
|
monkeypatch.setattr(token.dify_config, "COOKIE_DOMAIN", ".example.com", raising=False)
|
||||||
|
|
||||||
|
response = Response()
|
||||||
|
request = MagicMock()
|
||||||
|
|
||||||
|
set_csrf_token_to_cookie(request, response, "abc123")
|
||||||
|
|
||||||
|
cookies = response.headers.getlist("Set-Cookie")
|
||||||
|
assert any("csrf_token=abc123" in c for c in cookies)
|
||||||
|
assert any("Domain=example.com" in c for c in cookies)
|
||||||
|
assert all("__Host-" not in c for c in cookies)
|
||||||
|
|||||||
@ -21,6 +21,7 @@ from core.file.enums import FileTransferMethod, FileType
|
|||||||
from core.file.models import File
|
from core.file.models import File
|
||||||
from core.variables.segments import (
|
from core.variables.segments import (
|
||||||
ArrayFileSegment,
|
ArrayFileSegment,
|
||||||
|
ArrayNumberSegment,
|
||||||
ArraySegment,
|
ArraySegment,
|
||||||
FileSegment,
|
FileSegment,
|
||||||
FloatSegment,
|
FloatSegment,
|
||||||
@ -30,6 +31,7 @@ from core.variables.segments import (
|
|||||||
StringSegment,
|
StringSegment,
|
||||||
)
|
)
|
||||||
from services.variable_truncator import (
|
from services.variable_truncator import (
|
||||||
|
DummyVariableTruncator,
|
||||||
MaxDepthExceededError,
|
MaxDepthExceededError,
|
||||||
TruncationResult,
|
TruncationResult,
|
||||||
UnknownTypeError,
|
UnknownTypeError,
|
||||||
@ -596,3 +598,32 @@ class TestIntegrationScenarios:
|
|||||||
truncated_mapping, truncated = truncator.truncate_variable_mapping(mapping)
|
truncated_mapping, truncated = truncator.truncate_variable_mapping(mapping)
|
||||||
assert truncated is False
|
assert truncated is False
|
||||||
assert truncated_mapping == mapping
|
assert truncated_mapping == mapping
|
||||||
|
|
||||||
|
|
||||||
|
def test_dummy_variable_truncator_methods():
|
||||||
|
"""Test DummyVariableTruncator methods work correctly."""
|
||||||
|
truncator = DummyVariableTruncator()
|
||||||
|
|
||||||
|
# Test truncate_variable_mapping
|
||||||
|
test_data: dict[str, Any] = {
|
||||||
|
"key1": "value1",
|
||||||
|
"key2": ["item1", "item2"],
|
||||||
|
"large_array": list(range(2000)),
|
||||||
|
}
|
||||||
|
result, is_truncated = truncator.truncate_variable_mapping(test_data)
|
||||||
|
|
||||||
|
assert result == test_data
|
||||||
|
assert not is_truncated
|
||||||
|
|
||||||
|
# Test truncate method
|
||||||
|
segment = StringSegment(value="test string")
|
||||||
|
result = truncator.truncate(segment)
|
||||||
|
assert isinstance(result, TruncationResult)
|
||||||
|
assert result.result == segment
|
||||||
|
assert result.truncated is False
|
||||||
|
|
||||||
|
segment = ArrayNumberSegment(value=list(range(2000)))
|
||||||
|
result = truncator.truncate(segment)
|
||||||
|
assert isinstance(result, TruncationResult)
|
||||||
|
assert result.result == segment
|
||||||
|
assert result.truncated is False
|
||||||
|
|||||||
@ -180,6 +180,25 @@ class TestMCPToolTransform:
|
|||||||
# Set tools data with null description
|
# Set tools data with null description
|
||||||
mock_provider_full.tools = '[{"name": "tool1", "description": null, "inputSchema": {}}]'
|
mock_provider_full.tools = '[{"name": "tool1", "description": null, "inputSchema": {}}]'
|
||||||
|
|
||||||
|
# Mock the to_entity and to_api_response methods
|
||||||
|
mock_entity = Mock()
|
||||||
|
mock_entity.to_api_response.return_value = {
|
||||||
|
"name": "Test MCP Provider",
|
||||||
|
"type": ToolProviderType.MCP,
|
||||||
|
"is_team_authorization": True,
|
||||||
|
"server_url": "https://*****.com/mcp",
|
||||||
|
"provider_icon": "icon.png",
|
||||||
|
"masked_headers": {"Authorization": "Bearer *****"},
|
||||||
|
"updated_at": 1234567890,
|
||||||
|
"labels": [],
|
||||||
|
"author": "Test User",
|
||||||
|
"description": I18nObject(en_US="Test MCP Provider Description", zh_Hans="Test MCP Provider Description"),
|
||||||
|
"icon": "icon.png",
|
||||||
|
"label": I18nObject(en_US="Test MCP Provider", zh_Hans="Test MCP Provider"),
|
||||||
|
"masked_credentials": {},
|
||||||
|
}
|
||||||
|
mock_provider_full.to_entity.return_value = mock_entity
|
||||||
|
|
||||||
# Call the method with for_list=True
|
# Call the method with for_list=True
|
||||||
result = ToolTransformService.mcp_provider_to_user_provider(mock_provider_full, for_list=True)
|
result = ToolTransformService.mcp_provider_to_user_provider(mock_provider_full, for_list=True)
|
||||||
|
|
||||||
@ -198,6 +217,27 @@ class TestMCPToolTransform:
|
|||||||
# Set tools data with description
|
# Set tools data with description
|
||||||
mock_provider_full.tools = '[{"name": "tool1", "description": "Tool description", "inputSchema": {}}]'
|
mock_provider_full.tools = '[{"name": "tool1", "description": "Tool description", "inputSchema": {}}]'
|
||||||
|
|
||||||
|
# Mock the to_entity and to_api_response methods
|
||||||
|
mock_entity = Mock()
|
||||||
|
mock_entity.to_api_response.return_value = {
|
||||||
|
"name": "Test MCP Provider",
|
||||||
|
"type": ToolProviderType.MCP,
|
||||||
|
"is_team_authorization": True,
|
||||||
|
"server_url": "https://*****.com/mcp",
|
||||||
|
"provider_icon": "icon.png",
|
||||||
|
"masked_headers": {"Authorization": "Bearer *****"},
|
||||||
|
"updated_at": 1234567890,
|
||||||
|
"labels": [],
|
||||||
|
"configuration": {"timeout": "30", "sse_read_timeout": "300"},
|
||||||
|
"original_headers": {"Authorization": "Bearer secret-token"},
|
||||||
|
"author": "Test User",
|
||||||
|
"description": I18nObject(en_US="Test MCP Provider Description", zh_Hans="Test MCP Provider Description"),
|
||||||
|
"icon": "icon.png",
|
||||||
|
"label": I18nObject(en_US="Test MCP Provider", zh_Hans="Test MCP Provider"),
|
||||||
|
"masked_credentials": {},
|
||||||
|
}
|
||||||
|
mock_provider_full.to_entity.return_value = mock_entity
|
||||||
|
|
||||||
# Call the method with for_list=False
|
# Call the method with for_list=False
|
||||||
result = ToolTransformService.mcp_provider_to_user_provider(mock_provider_full, for_list=False)
|
result = ToolTransformService.mcp_provider_to_user_provider(mock_provider_full, for_list=False)
|
||||||
|
|
||||||
@ -205,8 +245,9 @@ class TestMCPToolTransform:
|
|||||||
assert isinstance(result, ToolProviderApiEntity)
|
assert isinstance(result, ToolProviderApiEntity)
|
||||||
assert result.id == "server-identifier-456" # Should use server_identifier when for_list=False
|
assert result.id == "server-identifier-456" # Should use server_identifier when for_list=False
|
||||||
assert result.server_identifier == "server-identifier-456"
|
assert result.server_identifier == "server-identifier-456"
|
||||||
assert result.timeout == 30
|
assert result.configuration is not None
|
||||||
assert result.sse_read_timeout == 300
|
assert result.configuration.timeout == 30
|
||||||
|
assert result.configuration.sse_read_timeout == 300
|
||||||
assert result.original_headers == {"Authorization": "Bearer secret-token"}
|
assert result.original_headers == {"Authorization": "Bearer secret-token"}
|
||||||
assert len(result.tools) == 1
|
assert len(result.tools) == 1
|
||||||
assert result.tools[0].description.en_US == "Tool description"
|
assert result.tools[0].description.en_US == "Tool description"
|
||||||
|
|||||||
8
api/uv.lock
generated
8
api/uv.lock
generated
@ -1590,7 +1590,7 @@ dev = [
|
|||||||
{ name = "ruff", specifier = "~=0.14.0" },
|
{ name = "ruff", specifier = "~=0.14.0" },
|
||||||
{ name = "scipy-stubs", specifier = ">=1.15.3.0" },
|
{ name = "scipy-stubs", specifier = ">=1.15.3.0" },
|
||||||
{ name = "sseclient-py", specifier = ">=1.8.0" },
|
{ name = "sseclient-py", specifier = ">=1.8.0" },
|
||||||
{ name = "testcontainers", specifier = "~=4.10.0" },
|
{ name = "testcontainers", specifier = "~=4.13.2" },
|
||||||
{ name = "ty", specifier = "~=0.0.1a19" },
|
{ name = "ty", specifier = "~=0.0.1a19" },
|
||||||
{ name = "types-aiofiles", specifier = "~=24.1.0" },
|
{ name = "types-aiofiles", specifier = "~=24.1.0" },
|
||||||
{ name = "types-beautifulsoup4", specifier = "~=4.12.0" },
|
{ name = "types-beautifulsoup4", specifier = "~=4.12.0" },
|
||||||
@ -5907,7 +5907,7 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "testcontainers"
|
name = "testcontainers"
|
||||||
version = "4.10.0"
|
version = "4.13.2"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "docker" },
|
{ name = "docker" },
|
||||||
@ -5916,9 +5916,9 @@ dependencies = [
|
|||||||
{ name = "urllib3" },
|
{ name = "urllib3" },
|
||||||
{ name = "wrapt" },
|
{ name = "wrapt" },
|
||||||
]
|
]
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/a1/49/9c618aff1c50121d183cdfbc3a4a5cf2727a2cde1893efe6ca55c7009196/testcontainers-4.10.0.tar.gz", hash = "sha256:03f85c3e505d8b4edeb192c72a961cebbcba0dd94344ae778b4a159cb6dcf8d3", size = 63327, upload-time = "2025-04-02T16:13:27.582Z" }
|
sdist = { url = "https://files.pythonhosted.org/packages/18/51/edac83edab339d8b4dce9a7b659163afb1ea7e011bfed1d5573d495a4485/testcontainers-4.13.2.tar.gz", hash = "sha256:2315f1e21b059427a9d11e8921f85fef322fbe0d50749bcca4eaa11271708ba4", size = 78692, upload-time = "2025-10-07T21:53:07.531Z" }
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/1c/0a/824b0c1ecf224802125279c3effff2e25ed785ed046e67da6e53d928de4c/testcontainers-4.10.0-py3-none-any.whl", hash = "sha256:31ed1a81238c7e131a2a29df6db8f23717d892b592fa5a1977fd0dcd0c23fc23", size = 107414, upload-time = "2025-04-02T16:13:25.785Z" },
|
{ url = "https://files.pythonhosted.org/packages/2a/5e/73aa94770f1df0595364aed526f31d54440db5492911e2857318ed326e51/testcontainers-4.13.2-py3-none-any.whl", hash = "sha256:0209baf8f4274b568cde95bef2cadf7b1d33b375321f793790462e235cd684ee", size = 124771, upload-time = "2025-10-07T21:53:05.937Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|||||||
@ -201,6 +201,10 @@ ENABLE_WEBSITE_JINAREADER=true
|
|||||||
ENABLE_WEBSITE_FIRECRAWL=true
|
ENABLE_WEBSITE_FIRECRAWL=true
|
||||||
ENABLE_WEBSITE_WATERCRAWL=true
|
ENABLE_WEBSITE_WATERCRAWL=true
|
||||||
|
|
||||||
|
# Enable inline LaTeX rendering with single dollar signs ($...$) in the web frontend
|
||||||
|
# Default is false for security reasons to prevent conflicts with regular text
|
||||||
|
NEXT_PUBLIC_ENABLE_SINGLE_DOLLAR_LATEX=false
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# Database Configuration
|
# Database Configuration
|
||||||
# The database uses PostgreSQL. Please use the public schema.
|
# The database uses PostgreSQL. Please use the public schema.
|
||||||
@ -260,16 +264,18 @@ POSTGRES_MAINTENANCE_WORK_MEM=64MB
|
|||||||
POSTGRES_EFFECTIVE_CACHE_SIZE=4096MB
|
POSTGRES_EFFECTIVE_CACHE_SIZE=4096MB
|
||||||
|
|
||||||
# Sets the maximum allowed duration of any statement before termination.
|
# Sets the maximum allowed duration of any statement before termination.
|
||||||
# Default is 60000 milliseconds.
|
# Default is 0 (no timeout).
|
||||||
#
|
#
|
||||||
# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-STATEMENT-TIMEOUT
|
# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-STATEMENT-TIMEOUT
|
||||||
POSTGRES_STATEMENT_TIMEOUT=60000
|
# A value of 0 prevents the server from timing out statements.
|
||||||
|
POSTGRES_STATEMENT_TIMEOUT=0
|
||||||
|
|
||||||
# Sets the maximum allowed duration of any idle in-transaction session before termination.
|
# Sets the maximum allowed duration of any idle in-transaction session before termination.
|
||||||
# Default is 60000 milliseconds.
|
# Default is 0 (no timeout).
|
||||||
#
|
#
|
||||||
# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-IDLE-IN-TRANSACTION-SESSION-TIMEOUT
|
# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-IDLE-IN-TRANSACTION-SESSION-TIMEOUT
|
||||||
POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT=60000
|
# A value of 0 prevents the server from terminating idle sessions.
|
||||||
|
POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT=0
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# Redis Configuration
|
# Redis Configuration
|
||||||
@ -314,7 +320,7 @@ REDIS_CLUSTERS_PASSWORD=
|
|||||||
# Celery Configuration
|
# Celery Configuration
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
|
|
||||||
# Use standalone redis as the broker, and redis db 1 for celery broker. (redis_username is usually set by defualt as empty)
|
# Use standalone redis as the broker, and redis db 1 for celery broker. (redis_username is usually set by default as empty)
|
||||||
# Format as follows: `redis://<redis_username>:<redis_password>@<redis_host>:<redis_port>/<redis_database>`.
|
# Format as follows: `redis://<redis_username>:<redis_password>@<redis_host>:<redis_port>/<redis_database>`.
|
||||||
# Example: redis://:difyai123456@redis:6379/1
|
# Example: redis://:difyai123456@redis:6379/1
|
||||||
# If use Redis Sentinel, format as follows: `sentinel://<redis_username>:<redis_password>@<sentinel_host1>:<sentinel_port>/<redis_database>`
|
# If use Redis Sentinel, format as follows: `sentinel://<redis_username>:<redis_password>@<sentinel_host1>:<sentinel_port>/<redis_database>`
|
||||||
@ -342,6 +348,11 @@ WEB_API_CORS_ALLOW_ORIGINS=*
|
|||||||
# Specifies the allowed origins for cross-origin requests to the console API,
|
# Specifies the allowed origins for cross-origin requests to the console API,
|
||||||
# e.g. https://cloud.dify.ai or * for all origins.
|
# e.g. https://cloud.dify.ai or * for all origins.
|
||||||
CONSOLE_CORS_ALLOW_ORIGINS=*
|
CONSOLE_CORS_ALLOW_ORIGINS=*
|
||||||
|
# Set COOKIE_DOMAIN when the console frontend and API are on different subdomains.
|
||||||
|
# Provide the registrable domain (e.g. example.com); leading dots are optional.
|
||||||
|
COOKIE_DOMAIN=
|
||||||
|
# The frontend reads NEXT_PUBLIC_COOKIE_DOMAIN to align cookie handling with the API.
|
||||||
|
NEXT_PUBLIC_COOKIE_DOMAIN=
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# File Storage Configuration
|
# File Storage Configuration
|
||||||
|
|||||||
@ -81,6 +81,7 @@ services:
|
|||||||
environment:
|
environment:
|
||||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||||
APP_API_URL: ${APP_API_URL:-}
|
APP_API_URL: ${APP_API_URL:-}
|
||||||
|
NEXT_PUBLIC_COOKIE_DOMAIN: ${NEXT_PUBLIC_COOKIE_DOMAIN:-}
|
||||||
SENTRY_DSN: ${WEB_SENTRY_DSN:-}
|
SENTRY_DSN: ${WEB_SENTRY_DSN:-}
|
||||||
NEXT_TELEMETRY_DISABLED: ${NEXT_TELEMETRY_DISABLED:-0}
|
NEXT_TELEMETRY_DISABLED: ${NEXT_TELEMETRY_DISABLED:-0}
|
||||||
TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000}
|
TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000}
|
||||||
@ -115,8 +116,8 @@ services:
|
|||||||
-c 'work_mem=${POSTGRES_WORK_MEM:-4MB}'
|
-c 'work_mem=${POSTGRES_WORK_MEM:-4MB}'
|
||||||
-c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}'
|
-c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}'
|
||||||
-c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}'
|
-c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}'
|
||||||
-c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-60000}'
|
-c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-0}'
|
||||||
-c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-60000}'
|
-c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-0}'
|
||||||
volumes:
|
volumes:
|
||||||
- ./volumes/db/data:/var/lib/postgresql/data
|
- ./volumes/db/data:/var/lib/postgresql/data
|
||||||
healthcheck:
|
healthcheck:
|
||||||
|
|||||||
@ -15,8 +15,8 @@ services:
|
|||||||
-c 'work_mem=${POSTGRES_WORK_MEM:-4MB}'
|
-c 'work_mem=${POSTGRES_WORK_MEM:-4MB}'
|
||||||
-c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}'
|
-c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}'
|
||||||
-c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}'
|
-c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}'
|
||||||
-c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-60000}'
|
-c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-0}'
|
||||||
-c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-60000}'
|
-c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-0}'
|
||||||
volumes:
|
volumes:
|
||||||
- ${PGDATA_HOST_VOLUME:-./volumes/db/data}:/var/lib/postgresql/data
|
- ${PGDATA_HOST_VOLUME:-./volumes/db/data}:/var/lib/postgresql/data
|
||||||
ports:
|
ports:
|
||||||
|
|||||||
@ -51,6 +51,7 @@ x-shared-env: &shared-api-worker-env
|
|||||||
ENABLE_WEBSITE_JINAREADER: ${ENABLE_WEBSITE_JINAREADER:-true}
|
ENABLE_WEBSITE_JINAREADER: ${ENABLE_WEBSITE_JINAREADER:-true}
|
||||||
ENABLE_WEBSITE_FIRECRAWL: ${ENABLE_WEBSITE_FIRECRAWL:-true}
|
ENABLE_WEBSITE_FIRECRAWL: ${ENABLE_WEBSITE_FIRECRAWL:-true}
|
||||||
ENABLE_WEBSITE_WATERCRAWL: ${ENABLE_WEBSITE_WATERCRAWL:-true}
|
ENABLE_WEBSITE_WATERCRAWL: ${ENABLE_WEBSITE_WATERCRAWL:-true}
|
||||||
|
NEXT_PUBLIC_ENABLE_SINGLE_DOLLAR_LATEX: ${NEXT_PUBLIC_ENABLE_SINGLE_DOLLAR_LATEX:-false}
|
||||||
DB_USERNAME: ${DB_USERNAME:-postgres}
|
DB_USERNAME: ${DB_USERNAME:-postgres}
|
||||||
DB_PASSWORD: ${DB_PASSWORD:-difyai123456}
|
DB_PASSWORD: ${DB_PASSWORD:-difyai123456}
|
||||||
DB_HOST: ${DB_HOST:-db}
|
DB_HOST: ${DB_HOST:-db}
|
||||||
@ -68,8 +69,8 @@ x-shared-env: &shared-api-worker-env
|
|||||||
POSTGRES_WORK_MEM: ${POSTGRES_WORK_MEM:-4MB}
|
POSTGRES_WORK_MEM: ${POSTGRES_WORK_MEM:-4MB}
|
||||||
POSTGRES_MAINTENANCE_WORK_MEM: ${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}
|
POSTGRES_MAINTENANCE_WORK_MEM: ${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}
|
||||||
POSTGRES_EFFECTIVE_CACHE_SIZE: ${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}
|
POSTGRES_EFFECTIVE_CACHE_SIZE: ${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}
|
||||||
POSTGRES_STATEMENT_TIMEOUT: ${POSTGRES_STATEMENT_TIMEOUT:-60000}
|
POSTGRES_STATEMENT_TIMEOUT: ${POSTGRES_STATEMENT_TIMEOUT:-0}
|
||||||
POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT: ${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-60000}
|
POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT: ${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-0}
|
||||||
REDIS_HOST: ${REDIS_HOST:-redis}
|
REDIS_HOST: ${REDIS_HOST:-redis}
|
||||||
REDIS_PORT: ${REDIS_PORT:-6379}
|
REDIS_PORT: ${REDIS_PORT:-6379}
|
||||||
REDIS_USERNAME: ${REDIS_USERNAME:-}
|
REDIS_USERNAME: ${REDIS_USERNAME:-}
|
||||||
@ -98,6 +99,8 @@ x-shared-env: &shared-api-worker-env
|
|||||||
CELERY_SENTINEL_SOCKET_TIMEOUT: ${CELERY_SENTINEL_SOCKET_TIMEOUT:-0.1}
|
CELERY_SENTINEL_SOCKET_TIMEOUT: ${CELERY_SENTINEL_SOCKET_TIMEOUT:-0.1}
|
||||||
WEB_API_CORS_ALLOW_ORIGINS: ${WEB_API_CORS_ALLOW_ORIGINS:-*}
|
WEB_API_CORS_ALLOW_ORIGINS: ${WEB_API_CORS_ALLOW_ORIGINS:-*}
|
||||||
CONSOLE_CORS_ALLOW_ORIGINS: ${CONSOLE_CORS_ALLOW_ORIGINS:-*}
|
CONSOLE_CORS_ALLOW_ORIGINS: ${CONSOLE_CORS_ALLOW_ORIGINS:-*}
|
||||||
|
COOKIE_DOMAIN: ${COOKIE_DOMAIN:-}
|
||||||
|
NEXT_PUBLIC_COOKIE_DOMAIN: ${NEXT_PUBLIC_COOKIE_DOMAIN:-}
|
||||||
STORAGE_TYPE: ${STORAGE_TYPE:-opendal}
|
STORAGE_TYPE: ${STORAGE_TYPE:-opendal}
|
||||||
OPENDAL_SCHEME: ${OPENDAL_SCHEME:-fs}
|
OPENDAL_SCHEME: ${OPENDAL_SCHEME:-fs}
|
||||||
OPENDAL_FS_ROOT: ${OPENDAL_FS_ROOT:-storage}
|
OPENDAL_FS_ROOT: ${OPENDAL_FS_ROOT:-storage}
|
||||||
@ -690,6 +693,7 @@ services:
|
|||||||
environment:
|
environment:
|
||||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||||
APP_API_URL: ${APP_API_URL:-}
|
APP_API_URL: ${APP_API_URL:-}
|
||||||
|
NEXT_PUBLIC_COOKIE_DOMAIN: ${NEXT_PUBLIC_COOKIE_DOMAIN:-}
|
||||||
SENTRY_DSN: ${WEB_SENTRY_DSN:-}
|
SENTRY_DSN: ${WEB_SENTRY_DSN:-}
|
||||||
NEXT_TELEMETRY_DISABLED: ${NEXT_TELEMETRY_DISABLED:-0}
|
NEXT_TELEMETRY_DISABLED: ${NEXT_TELEMETRY_DISABLED:-0}
|
||||||
TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000}
|
TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000}
|
||||||
@ -724,8 +728,8 @@ services:
|
|||||||
-c 'work_mem=${POSTGRES_WORK_MEM:-4MB}'
|
-c 'work_mem=${POSTGRES_WORK_MEM:-4MB}'
|
||||||
-c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}'
|
-c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}'
|
||||||
-c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}'
|
-c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}'
|
||||||
-c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-60000}'
|
-c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-0}'
|
||||||
-c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-60000}'
|
-c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-0}'
|
||||||
volumes:
|
volumes:
|
||||||
- ./volumes/db/data:/var/lib/postgresql/data
|
- ./volumes/db/data:/var/lib/postgresql/data
|
||||||
healthcheck:
|
healthcheck:
|
||||||
|
|||||||
@ -41,16 +41,18 @@ POSTGRES_MAINTENANCE_WORK_MEM=64MB
|
|||||||
POSTGRES_EFFECTIVE_CACHE_SIZE=4096MB
|
POSTGRES_EFFECTIVE_CACHE_SIZE=4096MB
|
||||||
|
|
||||||
# Sets the maximum allowed duration of any statement before termination.
|
# Sets the maximum allowed duration of any statement before termination.
|
||||||
# Default is 60000 milliseconds.
|
# Default is 0 (no timeout).
|
||||||
#
|
#
|
||||||
# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-STATEMENT-TIMEOUT
|
# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-STATEMENT-TIMEOUT
|
||||||
POSTGRES_STATEMENT_TIMEOUT=60000
|
# A value of 0 prevents the server from timing out statements.
|
||||||
|
POSTGRES_STATEMENT_TIMEOUT=0
|
||||||
|
|
||||||
# Sets the maximum allowed duration of any idle in-transaction session before termination.
|
# Sets the maximum allowed duration of any idle in-transaction session before termination.
|
||||||
# Default is 60000 milliseconds.
|
# Default is 0 (no timeout).
|
||||||
#
|
#
|
||||||
# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-IDLE-IN-TRANSACTION-SESSION-TIMEOUT
|
# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-IDLE-IN-TRANSACTION-SESSION-TIMEOUT
|
||||||
POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT=60000
|
# A value of 0 prevents the server from terminating idle sessions.
|
||||||
|
POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT=0
|
||||||
|
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
# Environment Variables for redis Service
|
# Environment Variables for redis Service
|
||||||
|
|||||||
@ -34,6 +34,9 @@ NEXT_PUBLIC_CSP_WHITELIST=
|
|||||||
# Default is not allow to embed into iframe to prevent Clickjacking: https://owasp.org/www-community/attacks/Clickjacking
|
# Default is not allow to embed into iframe to prevent Clickjacking: https://owasp.org/www-community/attacks/Clickjacking
|
||||||
NEXT_PUBLIC_ALLOW_EMBED=
|
NEXT_PUBLIC_ALLOW_EMBED=
|
||||||
|
|
||||||
|
# Shared cookie domain when console UI and API use different subdomains (e.g. example.com)
|
||||||
|
NEXT_PUBLIC_COOKIE_DOMAIN=
|
||||||
|
|
||||||
# Allow rendering unsafe URLs which have "data:" scheme.
|
# Allow rendering unsafe URLs which have "data:" scheme.
|
||||||
NEXT_PUBLIC_ALLOW_UNSAFE_DATA_SCHEME=false
|
NEXT_PUBLIC_ALLOW_UNSAFE_DATA_SCHEME=false
|
||||||
|
|
||||||
@ -61,5 +64,9 @@ NEXT_PUBLIC_ENABLE_WEBSITE_JINAREADER=true
|
|||||||
NEXT_PUBLIC_ENABLE_WEBSITE_FIRECRAWL=true
|
NEXT_PUBLIC_ENABLE_WEBSITE_FIRECRAWL=true
|
||||||
NEXT_PUBLIC_ENABLE_WEBSITE_WATERCRAWL=true
|
NEXT_PUBLIC_ENABLE_WEBSITE_WATERCRAWL=true
|
||||||
|
|
||||||
|
# Enable inline LaTeX rendering with single dollar signs ($...$)
|
||||||
|
# Default is false for security reasons to prevent conflicts with regular text
|
||||||
|
NEXT_PUBLIC_ENABLE_SINGLE_DOLLAR_LATEX=false
|
||||||
|
|
||||||
# The maximum number of tree node depth for workflow
|
# The maximum number of tree node depth for workflow
|
||||||
NEXT_PUBLIC_MAX_TREE_DEPTH=50
|
NEXT_PUBLIC_MAX_TREE_DEPTH=50
|
||||||
|
|||||||
132
web/__tests__/embedded-user-id-auth.test.tsx
Normal file
132
web/__tests__/embedded-user-id-auth.test.tsx
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
import React from 'react'
|
||||||
|
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||||
|
|
||||||
|
import MailAndPasswordAuth from '@/app/(shareLayout)/webapp-signin/components/mail-and-password-auth'
|
||||||
|
import CheckCode from '@/app/(shareLayout)/webapp-signin/check-code/page'
|
||||||
|
|
||||||
|
jest.mock('react-i18next', () => ({
|
||||||
|
useTranslation: () => ({
|
||||||
|
t: (key: string) => key,
|
||||||
|
}),
|
||||||
|
}))
|
||||||
|
|
||||||
|
const replaceMock = jest.fn()
|
||||||
|
const backMock = jest.fn()
|
||||||
|
|
||||||
|
jest.mock('next/navigation', () => ({
|
||||||
|
usePathname: jest.fn(() => '/chatbot/test-app'),
|
||||||
|
useRouter: jest.fn(() => ({
|
||||||
|
replace: replaceMock,
|
||||||
|
back: backMock,
|
||||||
|
})),
|
||||||
|
useSearchParams: jest.fn(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
const mockStoreState = {
|
||||||
|
embeddedUserId: 'embedded-user-99',
|
||||||
|
shareCode: 'test-app',
|
||||||
|
}
|
||||||
|
|
||||||
|
const useWebAppStoreMock = jest.fn((selector?: (state: typeof mockStoreState) => any) => {
|
||||||
|
return selector ? selector(mockStoreState) : mockStoreState
|
||||||
|
})
|
||||||
|
|
||||||
|
jest.mock('@/context/web-app-context', () => ({
|
||||||
|
useWebAppStore: (selector?: (state: typeof mockStoreState) => any) => useWebAppStoreMock(selector),
|
||||||
|
}))
|
||||||
|
|
||||||
|
const webAppLoginMock = jest.fn()
|
||||||
|
const webAppEmailLoginWithCodeMock = jest.fn()
|
||||||
|
const sendWebAppEMailLoginCodeMock = jest.fn()
|
||||||
|
|
||||||
|
jest.mock('@/service/common', () => ({
|
||||||
|
webAppLogin: (...args: any[]) => webAppLoginMock(...args),
|
||||||
|
webAppEmailLoginWithCode: (...args: any[]) => webAppEmailLoginWithCodeMock(...args),
|
||||||
|
sendWebAppEMailLoginCode: (...args: any[]) => sendWebAppEMailLoginCodeMock(...args),
|
||||||
|
}))
|
||||||
|
|
||||||
|
const fetchAccessTokenMock = jest.fn()
|
||||||
|
|
||||||
|
jest.mock('@/service/share', () => ({
|
||||||
|
fetchAccessToken: (...args: any[]) => fetchAccessTokenMock(...args),
|
||||||
|
}))
|
||||||
|
|
||||||
|
const setWebAppAccessTokenMock = jest.fn()
|
||||||
|
const setWebAppPassportMock = jest.fn()
|
||||||
|
|
||||||
|
jest.mock('@/service/webapp-auth', () => ({
|
||||||
|
setWebAppAccessToken: (...args: any[]) => setWebAppAccessTokenMock(...args),
|
||||||
|
setWebAppPassport: (...args: any[]) => setWebAppPassportMock(...args),
|
||||||
|
webAppLogout: jest.fn(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
jest.mock('@/app/components/signin/countdown', () => () => <div data-testid="countdown" />)
|
||||||
|
|
||||||
|
jest.mock('@remixicon/react', () => ({
|
||||||
|
RiMailSendFill: () => <div data-testid="mail-icon" />,
|
||||||
|
RiArrowLeftLine: () => <div data-testid="arrow-icon" />,
|
||||||
|
}))
|
||||||
|
|
||||||
|
const { useSearchParams } = jest.requireMock('next/navigation') as {
|
||||||
|
useSearchParams: jest.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
jest.clearAllMocks()
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('embedded user id propagation in authentication flows', () => {
|
||||||
|
it('passes embedded user id when logging in with email and password', async () => {
|
||||||
|
const params = new URLSearchParams()
|
||||||
|
params.set('redirect_url', encodeURIComponent('/chatbot/test-app'))
|
||||||
|
useSearchParams.mockReturnValue(params)
|
||||||
|
|
||||||
|
webAppLoginMock.mockResolvedValue({ result: 'success', data: { access_token: 'login-token' } })
|
||||||
|
fetchAccessTokenMock.mockResolvedValue({ access_token: 'passport-token' })
|
||||||
|
|
||||||
|
render(<MailAndPasswordAuth isEmailSetup />)
|
||||||
|
|
||||||
|
fireEvent.change(screen.getByLabelText('login.email'), { target: { value: 'user@example.com' } })
|
||||||
|
fireEvent.change(screen.getByLabelText(/login\.password/), { target: { value: 'strong-password' } })
|
||||||
|
fireEvent.click(screen.getByRole('button', { name: 'login.signBtn' }))
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(fetchAccessTokenMock).toHaveBeenCalledWith({
|
||||||
|
appCode: 'test-app',
|
||||||
|
userId: 'embedded-user-99',
|
||||||
|
})
|
||||||
|
})
|
||||||
|
expect(setWebAppAccessTokenMock).toHaveBeenCalledWith('login-token')
|
||||||
|
expect(setWebAppPassportMock).toHaveBeenCalledWith('test-app', 'passport-token')
|
||||||
|
expect(replaceMock).toHaveBeenCalledWith('/chatbot/test-app')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('passes embedded user id when verifying email code', async () => {
|
||||||
|
const params = new URLSearchParams()
|
||||||
|
params.set('redirect_url', encodeURIComponent('/chatbot/test-app'))
|
||||||
|
params.set('email', encodeURIComponent('user@example.com'))
|
||||||
|
params.set('token', encodeURIComponent('token-abc'))
|
||||||
|
useSearchParams.mockReturnValue(params)
|
||||||
|
|
||||||
|
webAppEmailLoginWithCodeMock.mockResolvedValue({ result: 'success', data: { access_token: 'code-token' } })
|
||||||
|
fetchAccessTokenMock.mockResolvedValue({ access_token: 'passport-token' })
|
||||||
|
|
||||||
|
render(<CheckCode />)
|
||||||
|
|
||||||
|
fireEvent.change(
|
||||||
|
screen.getByPlaceholderText('login.checkCode.verificationCodePlaceholder'),
|
||||||
|
{ target: { value: '123456' } },
|
||||||
|
)
|
||||||
|
fireEvent.click(screen.getByRole('button', { name: 'login.checkCode.verify' }))
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(fetchAccessTokenMock).toHaveBeenCalledWith({
|
||||||
|
appCode: 'test-app',
|
||||||
|
userId: 'embedded-user-99',
|
||||||
|
})
|
||||||
|
})
|
||||||
|
expect(setWebAppAccessTokenMock).toHaveBeenCalledWith('code-token')
|
||||||
|
expect(setWebAppPassportMock).toHaveBeenCalledWith('test-app', 'passport-token')
|
||||||
|
expect(replaceMock).toHaveBeenCalledWith('/chatbot/test-app')
|
||||||
|
})
|
||||||
|
})
|
||||||
155
web/__tests__/embedded-user-id-store.test.tsx
Normal file
155
web/__tests__/embedded-user-id-store.test.tsx
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
import React from 'react'
|
||||||
|
import { render, screen, waitFor } from '@testing-library/react'
|
||||||
|
|
||||||
|
import WebAppStoreProvider, { useWebAppStore } from '@/context/web-app-context'
|
||||||
|
|
||||||
|
jest.mock('next/navigation', () => ({
|
||||||
|
usePathname: jest.fn(() => '/chatbot/sample-app'),
|
||||||
|
useSearchParams: jest.fn(() => {
|
||||||
|
const params = new URLSearchParams()
|
||||||
|
return params
|
||||||
|
}),
|
||||||
|
}))
|
||||||
|
|
||||||
|
jest.mock('@/service/use-share', () => {
|
||||||
|
const { AccessMode } = jest.requireActual('@/models/access-control')
|
||||||
|
return {
|
||||||
|
useGetWebAppAccessModeByCode: jest.fn(() => ({
|
||||||
|
isLoading: false,
|
||||||
|
data: { accessMode: AccessMode.PUBLIC },
|
||||||
|
})),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
jest.mock('@/app/components/base/chat/utils', () => ({
|
||||||
|
getProcessedSystemVariablesFromUrlParams: jest.fn(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
const { getProcessedSystemVariablesFromUrlParams: mockGetProcessedSystemVariablesFromUrlParams }
|
||||||
|
= jest.requireMock('@/app/components/base/chat/utils') as {
|
||||||
|
getProcessedSystemVariablesFromUrlParams: jest.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
jest.mock('@/context/global-public-context', () => {
|
||||||
|
const mockGlobalStoreState = {
|
||||||
|
isGlobalPending: false,
|
||||||
|
setIsGlobalPending: jest.fn(),
|
||||||
|
systemFeatures: {},
|
||||||
|
setSystemFeatures: jest.fn(),
|
||||||
|
}
|
||||||
|
const useGlobalPublicStore = Object.assign(
|
||||||
|
(selector?: (state: typeof mockGlobalStoreState) => any) =>
|
||||||
|
selector ? selector(mockGlobalStoreState) : mockGlobalStoreState,
|
||||||
|
{
|
||||||
|
setState: (updater: any) => {
|
||||||
|
if (typeof updater === 'function')
|
||||||
|
Object.assign(mockGlobalStoreState, updater(mockGlobalStoreState) ?? {})
|
||||||
|
|
||||||
|
else
|
||||||
|
Object.assign(mockGlobalStoreState, updater)
|
||||||
|
},
|
||||||
|
__mockState: mockGlobalStoreState,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
useGlobalPublicStore,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
const {
|
||||||
|
useGlobalPublicStore: useGlobalPublicStoreMock,
|
||||||
|
} = jest.requireMock('@/context/global-public-context') as {
|
||||||
|
useGlobalPublicStore: ((selector?: (state: any) => any) => any) & {
|
||||||
|
setState: (updater: any) => void
|
||||||
|
__mockState: {
|
||||||
|
isGlobalPending: boolean
|
||||||
|
setIsGlobalPending: jest.Mock
|
||||||
|
systemFeatures: Record<string, unknown>
|
||||||
|
setSystemFeatures: jest.Mock
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const mockGlobalStoreState = useGlobalPublicStoreMock.__mockState
|
||||||
|
|
||||||
|
const TestConsumer = () => {
|
||||||
|
const embeddedUserId = useWebAppStore(state => state.embeddedUserId)
|
||||||
|
const embeddedConversationId = useWebAppStore(state => state.embeddedConversationId)
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<div data-testid="embedded-user-id">{embeddedUserId ?? 'null'}</div>
|
||||||
|
<div data-testid="embedded-conversation-id">{embeddedConversationId ?? 'null'}</div>
|
||||||
|
</>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
const initialWebAppStore = (() => {
|
||||||
|
const snapshot = useWebAppStore.getState()
|
||||||
|
return {
|
||||||
|
shareCode: null as string | null,
|
||||||
|
appInfo: null,
|
||||||
|
appParams: null,
|
||||||
|
webAppAccessMode: snapshot.webAppAccessMode,
|
||||||
|
appMeta: null,
|
||||||
|
userCanAccessApp: false,
|
||||||
|
embeddedUserId: null,
|
||||||
|
embeddedConversationId: null,
|
||||||
|
updateShareCode: snapshot.updateShareCode,
|
||||||
|
updateAppInfo: snapshot.updateAppInfo,
|
||||||
|
updateAppParams: snapshot.updateAppParams,
|
||||||
|
updateWebAppAccessMode: snapshot.updateWebAppAccessMode,
|
||||||
|
updateWebAppMeta: snapshot.updateWebAppMeta,
|
||||||
|
updateUserCanAccessApp: snapshot.updateUserCanAccessApp,
|
||||||
|
updateEmbeddedUserId: snapshot.updateEmbeddedUserId,
|
||||||
|
updateEmbeddedConversationId: snapshot.updateEmbeddedConversationId,
|
||||||
|
}
|
||||||
|
})()
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
mockGlobalStoreState.isGlobalPending = false
|
||||||
|
mockGetProcessedSystemVariablesFromUrlParams.mockReset()
|
||||||
|
useWebAppStore.setState(initialWebAppStore, true)
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('WebAppStoreProvider embedded user id handling', () => {
|
||||||
|
it('hydrates embedded user and conversation ids from system variables', async () => {
|
||||||
|
mockGetProcessedSystemVariablesFromUrlParams.mockResolvedValue({
|
||||||
|
user_id: 'iframe-user-123',
|
||||||
|
conversation_id: 'conversation-456',
|
||||||
|
})
|
||||||
|
|
||||||
|
render(
|
||||||
|
<WebAppStoreProvider>
|
||||||
|
<TestConsumer />
|
||||||
|
</WebAppStoreProvider>,
|
||||||
|
)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByTestId('embedded-user-id')).toHaveTextContent('iframe-user-123')
|
||||||
|
expect(screen.getByTestId('embedded-conversation-id')).toHaveTextContent('conversation-456')
|
||||||
|
})
|
||||||
|
expect(useWebAppStore.getState().embeddedUserId).toBe('iframe-user-123')
|
||||||
|
expect(useWebAppStore.getState().embeddedConversationId).toBe('conversation-456')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('clears embedded user id when system variable is absent', async () => {
|
||||||
|
useWebAppStore.setState(state => ({
|
||||||
|
...state,
|
||||||
|
embeddedUserId: 'previous-user',
|
||||||
|
embeddedConversationId: 'existing-conversation',
|
||||||
|
}))
|
||||||
|
mockGetProcessedSystemVariablesFromUrlParams.mockResolvedValue({})
|
||||||
|
|
||||||
|
render(
|
||||||
|
<WebAppStoreProvider>
|
||||||
|
<TestConsumer />
|
||||||
|
</WebAppStoreProvider>,
|
||||||
|
)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByTestId('embedded-user-id')).toHaveTextContent('null')
|
||||||
|
expect(screen.getByTestId('embedded-conversation-id')).toHaveTextContent('null')
|
||||||
|
})
|
||||||
|
expect(useWebAppStore.getState().embeddedUserId).toBeNull()
|
||||||
|
expect(useWebAppStore.getState().embeddedConversationId).toBeNull()
|
||||||
|
})
|
||||||
|
})
|
||||||
@ -2,14 +2,14 @@ import React from 'react'
|
|||||||
import Main from '@/app/components/explore/installed-app'
|
import Main from '@/app/components/explore/installed-app'
|
||||||
|
|
||||||
export type IInstalledAppProps = {
|
export type IInstalledAppProps = {
|
||||||
params: {
|
params?: Promise<{
|
||||||
appId: string
|
appId: string
|
||||||
}
|
}>
|
||||||
}
|
}
|
||||||
|
|
||||||
// Using Next.js page convention for async server components
|
// Using Next.js page convention for async server components
|
||||||
async function InstalledApp({ params }: IInstalledAppProps) {
|
async function InstalledApp({ params }: IInstalledAppProps) {
|
||||||
const appId = (await params).appId
|
const { appId } = await (params ?? Promise.reject(new Error('Missing params')))
|
||||||
return (
|
return (
|
||||||
<Main id={appId} />
|
<Main id={appId} />
|
||||||
)
|
)
|
||||||
|
|||||||
@ -15,6 +15,7 @@ const Splash: FC<PropsWithChildren> = ({ children }) => {
|
|||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
const shareCode = useWebAppStore(s => s.shareCode)
|
const shareCode = useWebAppStore(s => s.shareCode)
|
||||||
const webAppAccessMode = useWebAppStore(s => s.webAppAccessMode)
|
const webAppAccessMode = useWebAppStore(s => s.webAppAccessMode)
|
||||||
|
const embeddedUserId = useWebAppStore(s => s.embeddedUserId)
|
||||||
const searchParams = useSearchParams()
|
const searchParams = useSearchParams()
|
||||||
const router = useRouter()
|
const router = useRouter()
|
||||||
const redirectUrl = searchParams.get('redirect_url')
|
const redirectUrl = searchParams.get('redirect_url')
|
||||||
@ -69,7 +70,10 @@ const Splash: FC<PropsWithChildren> = ({ children }) => {
|
|||||||
}
|
}
|
||||||
else if (userLoggedIn && !appLoggedIn) {
|
else if (userLoggedIn && !appLoggedIn) {
|
||||||
try {
|
try {
|
||||||
const { access_token } = await fetchAccessToken({ appCode: shareCode! })
|
const { access_token } = await fetchAccessToken({
|
||||||
|
appCode: shareCode!,
|
||||||
|
userId: embeddedUserId || undefined,
|
||||||
|
})
|
||||||
setWebAppPassport(shareCode!, access_token)
|
setWebAppPassport(shareCode!, access_token)
|
||||||
redirectOrFinish()
|
redirectOrFinish()
|
||||||
}
|
}
|
||||||
@ -85,7 +89,8 @@ const Splash: FC<PropsWithChildren> = ({ children }) => {
|
|||||||
router,
|
router,
|
||||||
message,
|
message,
|
||||||
webAppAccessMode,
|
webAppAccessMode,
|
||||||
tokenFromUrl])
|
tokenFromUrl,
|
||||||
|
embeddedUserId])
|
||||||
|
|
||||||
if (message) {
|
if (message) {
|
||||||
return <div className='flex h-full flex-col items-center justify-center gap-y-4'>
|
return <div className='flex h-full flex-col items-center justify-center gap-y-4'>
|
||||||
|
|||||||
@ -12,6 +12,7 @@ import { sendWebAppEMailLoginCode, webAppEmailLoginWithCode } from '@/service/co
|
|||||||
import I18NContext from '@/context/i18n'
|
import I18NContext from '@/context/i18n'
|
||||||
import { setWebAppAccessToken, setWebAppPassport } from '@/service/webapp-auth'
|
import { setWebAppAccessToken, setWebAppPassport } from '@/service/webapp-auth'
|
||||||
import { fetchAccessToken } from '@/service/share'
|
import { fetchAccessToken } from '@/service/share'
|
||||||
|
import { useWebAppStore } from '@/context/web-app-context'
|
||||||
|
|
||||||
export default function CheckCode() {
|
export default function CheckCode() {
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
@ -23,6 +24,7 @@ export default function CheckCode() {
|
|||||||
const [loading, setIsLoading] = useState(false)
|
const [loading, setIsLoading] = useState(false)
|
||||||
const { locale } = useContext(I18NContext)
|
const { locale } = useContext(I18NContext)
|
||||||
const redirectUrl = searchParams.get('redirect_url')
|
const redirectUrl = searchParams.get('redirect_url')
|
||||||
|
const embeddedUserId = useWebAppStore(s => s.embeddedUserId)
|
||||||
|
|
||||||
const getAppCodeFromRedirectUrl = useCallback(() => {
|
const getAppCodeFromRedirectUrl = useCallback(() => {
|
||||||
if (!redirectUrl)
|
if (!redirectUrl)
|
||||||
@ -63,7 +65,10 @@ export default function CheckCode() {
|
|||||||
const ret = await webAppEmailLoginWithCode({ email, code, token })
|
const ret = await webAppEmailLoginWithCode({ email, code, token })
|
||||||
if (ret.result === 'success') {
|
if (ret.result === 'success') {
|
||||||
setWebAppAccessToken(ret.data.access_token)
|
setWebAppAccessToken(ret.data.access_token)
|
||||||
const { access_token } = await fetchAccessToken({ appCode: appCode! })
|
const { access_token } = await fetchAccessToken({
|
||||||
|
appCode: appCode!,
|
||||||
|
userId: embeddedUserId || undefined,
|
||||||
|
})
|
||||||
setWebAppPassport(appCode!, access_token)
|
setWebAppPassport(appCode!, access_token)
|
||||||
router.replace(decodeURIComponent(redirectUrl))
|
router.replace(decodeURIComponent(redirectUrl))
|
||||||
}
|
}
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import { emailRegex } from '@/config'
|
|||||||
import { webAppLogin } from '@/service/common'
|
import { webAppLogin } from '@/service/common'
|
||||||
import Input from '@/app/components/base/input'
|
import Input from '@/app/components/base/input'
|
||||||
import I18NContext from '@/context/i18n'
|
import I18NContext from '@/context/i18n'
|
||||||
|
import { useWebAppStore } from '@/context/web-app-context'
|
||||||
import { noop } from 'lodash-es'
|
import { noop } from 'lodash-es'
|
||||||
import { fetchAccessToken } from '@/service/share'
|
import { fetchAccessToken } from '@/service/share'
|
||||||
import { setWebAppAccessToken, setWebAppPassport } from '@/service/webapp-auth'
|
import { setWebAppAccessToken, setWebAppPassport } from '@/service/webapp-auth'
|
||||||
@ -30,6 +31,7 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut
|
|||||||
|
|
||||||
const [isLoading, setIsLoading] = useState(false)
|
const [isLoading, setIsLoading] = useState(false)
|
||||||
const redirectUrl = searchParams.get('redirect_url')
|
const redirectUrl = searchParams.get('redirect_url')
|
||||||
|
const embeddedUserId = useWebAppStore(s => s.embeddedUserId)
|
||||||
|
|
||||||
const getAppCodeFromRedirectUrl = useCallback(() => {
|
const getAppCodeFromRedirectUrl = useCallback(() => {
|
||||||
if (!redirectUrl)
|
if (!redirectUrl)
|
||||||
@ -82,7 +84,10 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut
|
|||||||
if (res.result === 'success') {
|
if (res.result === 'success') {
|
||||||
setWebAppAccessToken(res.data.access_token)
|
setWebAppAccessToken(res.data.access_token)
|
||||||
|
|
||||||
const { access_token } = await fetchAccessToken({ appCode: appCode! })
|
const { access_token } = await fetchAccessToken({
|
||||||
|
appCode: appCode!,
|
||||||
|
userId: embeddedUserId || undefined,
|
||||||
|
})
|
||||||
setWebAppPassport(appCode!, access_token)
|
setWebAppPassport(appCode!, access_token)
|
||||||
router.replace(decodeURIComponent(redirectUrl))
|
router.replace(decodeURIComponent(redirectUrl))
|
||||||
}
|
}
|
||||||
|
|||||||
3
web/app/account/oauth/authorize/constants.ts
Normal file
3
web/app/account/oauth/authorize/constants.ts
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
export const OAUTH_AUTHORIZE_PENDING_KEY = 'oauth_authorize_pending'
|
||||||
|
export const REDIRECT_URL_KEY = 'oauth_redirect_url'
|
||||||
|
export const OAUTH_AUTHORIZE_PENDING_TTL = 60 * 3
|
||||||
@ -19,11 +19,11 @@ import {
|
|||||||
} from '@remixicon/react'
|
} from '@remixicon/react'
|
||||||
import dayjs from 'dayjs'
|
import dayjs from 'dayjs'
|
||||||
import { useIsLogin } from '@/service/use-common'
|
import { useIsLogin } from '@/service/use-common'
|
||||||
|
import {
|
||||||
export const OAUTH_AUTHORIZE_PENDING_KEY = 'oauth_authorize_pending'
|
OAUTH_AUTHORIZE_PENDING_KEY,
|
||||||
export const REDIRECT_URL_KEY = 'oauth_redirect_url'
|
OAUTH_AUTHORIZE_PENDING_TTL,
|
||||||
|
REDIRECT_URL_KEY,
|
||||||
const OAUTH_AUTHORIZE_PENDING_TTL = 60 * 3
|
} from './constants'
|
||||||
|
|
||||||
function setItemWithExpiry(key: string, value: string, ttl: number) {
|
function setItemWithExpiry(key: string, value: string, ttl: number) {
|
||||||
const item = {
|
const item = {
|
||||||
|
|||||||
@ -44,7 +44,7 @@ import { appDefaultIconBackground } from '@/config'
|
|||||||
import type { PublishWorkflowParams } from '@/types/workflow'
|
import type { PublishWorkflowParams } from '@/types/workflow'
|
||||||
import { useAppWhiteListSubjects, useGetUserCanAccessApp } from '@/service/access-control'
|
import { useAppWhiteListSubjects, useGetUserCanAccessApp } from '@/service/access-control'
|
||||||
import { AccessMode } from '@/models/access-control'
|
import { AccessMode } from '@/models/access-control'
|
||||||
import { fetchAppDetail } from '@/service/apps'
|
import { fetchAppDetailDirect } from '@/service/apps'
|
||||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||||
import { useFormatTimeFromNow } from '@/hooks/use-format-time-from-now'
|
import { useFormatTimeFromNow } from '@/hooks/use-format-time-from-now'
|
||||||
|
|
||||||
@ -162,11 +162,16 @@ const AppPublisher = ({
|
|||||||
}
|
}
|
||||||
}, [appDetail?.id])
|
}, [appDetail?.id])
|
||||||
|
|
||||||
const handleAccessControlUpdate = useCallback(() => {
|
const handleAccessControlUpdate = useCallback(async () => {
|
||||||
fetchAppDetail({ url: '/apps', id: appDetail!.id }).then((res) => {
|
if (!appDetail)
|
||||||
|
return
|
||||||
|
try {
|
||||||
|
const res = await fetchAppDetailDirect({ url: '/apps', id: appDetail.id })
|
||||||
setAppDetail(res)
|
setAppDetail(res)
|
||||||
|
}
|
||||||
|
finally {
|
||||||
setShowAppAccessControl(false)
|
setShowAppAccessControl(false)
|
||||||
})
|
}
|
||||||
}, [appDetail, setAppDetail])
|
}, [appDetail, setAppDetail])
|
||||||
|
|
||||||
const [embeddingModalOpen, setEmbeddingModalOpen] = useState(false)
|
const [embeddingModalOpen, setEmbeddingModalOpen] = useState(false)
|
||||||
|
|||||||
@ -14,7 +14,8 @@ import { TransferMethod } from '@/app/components/base/chat/types'
|
|||||||
import { useEventEmitterContextContext } from '@/context/event-emitter'
|
import { useEventEmitterContextContext } from '@/context/event-emitter'
|
||||||
import { useProviderContext } from '@/context/provider-context'
|
import { useProviderContext } from '@/context/provider-context'
|
||||||
import { useFeatures } from '@/app/components/base/features/hooks'
|
import { useFeatures } from '@/app/components/base/features/hooks'
|
||||||
import { noop } from 'lodash-es'
|
import { cloneDeep, noop } from 'lodash-es'
|
||||||
|
import { DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG } from '@/config'
|
||||||
|
|
||||||
type TextGenerationItemProps = {
|
type TextGenerationItemProps = {
|
||||||
modelAndParameter: ModelAndParameter
|
modelAndParameter: ModelAndParameter
|
||||||
@ -50,8 +51,8 @@ const TextGenerationItem: FC<TextGenerationItemProps> = ({
|
|||||||
const config: TextGenerationConfig = {
|
const config: TextGenerationConfig = {
|
||||||
pre_prompt: !isAdvancedMode ? modelConfig.configs.prompt_template : '',
|
pre_prompt: !isAdvancedMode ? modelConfig.configs.prompt_template : '',
|
||||||
prompt_type: promptMode,
|
prompt_type: promptMode,
|
||||||
chat_prompt_config: isAdvancedMode ? chatPromptConfig : {},
|
chat_prompt_config: isAdvancedMode ? chatPromptConfig : cloneDeep(DEFAULT_CHAT_PROMPT_CONFIG),
|
||||||
completion_prompt_config: isAdvancedMode ? completionPromptConfig : {},
|
completion_prompt_config: isAdvancedMode ? completionPromptConfig : cloneDeep(DEFAULT_COMPLETION_PROMPT_CONFIG),
|
||||||
user_input_form: promptVariablesToUserInputsForm(modelConfig.configs.prompt_variables),
|
user_input_form: promptVariablesToUserInputsForm(modelConfig.configs.prompt_variables),
|
||||||
dataset_query_variable: contextVar || '',
|
dataset_query_variable: contextVar || '',
|
||||||
// features
|
// features
|
||||||
@ -74,6 +75,7 @@ const TextGenerationItem: FC<TextGenerationItemProps> = ({
|
|||||||
datasets: [...postDatasets],
|
datasets: [...postDatasets],
|
||||||
} as any,
|
} as any,
|
||||||
},
|
},
|
||||||
|
system_parameters: modelConfig.system_parameters,
|
||||||
}
|
}
|
||||||
const {
|
const {
|
||||||
completion,
|
completion,
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import {
|
|||||||
import Chat from '@/app/components/base/chat/chat'
|
import Chat from '@/app/components/base/chat/chat'
|
||||||
import { useChat } from '@/app/components/base/chat/chat/hooks'
|
import { useChat } from '@/app/components/base/chat/chat/hooks'
|
||||||
import { useDebugConfigurationContext } from '@/context/debug-configuration'
|
import { useDebugConfigurationContext } from '@/context/debug-configuration'
|
||||||
import type { ChatConfig, ChatItem, ChatItemInTree, OnSend } from '@/app/components/base/chat/types'
|
import type { ChatConfig, ChatItem, OnSend } from '@/app/components/base/chat/types'
|
||||||
import { useProviderContext } from '@/context/provider-context'
|
import { useProviderContext } from '@/context/provider-context'
|
||||||
import {
|
import {
|
||||||
fetchConversationMessages,
|
fetchConversationMessages,
|
||||||
@ -126,7 +126,7 @@ const DebugWithSingleModel = (
|
|||||||
)
|
)
|
||||||
}, [appId, chatList, checkCanSend, completionParams, config, handleSend, inputs, modelConfig.mode, modelConfig.model_id, modelConfig.provider, textGenerationModelList])
|
}, [appId, chatList, checkCanSend, completionParams, config, handleSend, inputs, modelConfig.mode, modelConfig.model_id, modelConfig.provider, textGenerationModelList])
|
||||||
|
|
||||||
const doRegenerate = useCallback((chatItem: ChatItemInTree, editedQuestion?: { message: string, files?: FileEntity[] }) => {
|
const doRegenerate = useCallback((chatItem: ChatItem, editedQuestion?: { message: string, files?: FileEntity[] }) => {
|
||||||
const question = editedQuestion ? chatItem : chatList.find(item => item.id === chatItem.parentMessageId)!
|
const question = editedQuestion ? chatItem : chatList.find(item => item.id === chatItem.parentMessageId)!
|
||||||
const parentAnswer = chatList.find(item => item.id === question.parentMessageId)
|
const parentAnswer = chatList.find(item => item.id === question.parentMessageId)
|
||||||
doSend(editedQuestion ? editedQuestion.message : question.content,
|
doSend(editedQuestion ? editedQuestion.message : question.content,
|
||||||
|
|||||||
@ -12,12 +12,15 @@ import type {
|
|||||||
ChatConfig,
|
ChatConfig,
|
||||||
ChatItem,
|
ChatItem,
|
||||||
} from '@/app/components/base/chat/types'
|
} from '@/app/components/base/chat/types'
|
||||||
|
import cloneDeep from 'lodash-es/cloneDeep'
|
||||||
import {
|
import {
|
||||||
AgentStrategy,
|
AgentStrategy,
|
||||||
} from '@/types/app'
|
} from '@/types/app'
|
||||||
|
import { SupportUploadFileTypes } from '@/app/components/workflow/types'
|
||||||
import { promptVariablesToUserInputsForm } from '@/utils/model-config'
|
import { promptVariablesToUserInputsForm } from '@/utils/model-config'
|
||||||
import { useDebugConfigurationContext } from '@/context/debug-configuration'
|
import { useDebugConfigurationContext } from '@/context/debug-configuration'
|
||||||
import { useEventEmitterContextContext } from '@/context/event-emitter'
|
import { useEventEmitterContextContext } from '@/context/event-emitter'
|
||||||
|
import { DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG } from '@/config'
|
||||||
|
|
||||||
export const useDebugWithSingleOrMultipleModel = (appId: string) => {
|
export const useDebugWithSingleOrMultipleModel = (appId: string) => {
|
||||||
const localeDebugWithSingleOrMultipleModelConfigs = localStorage.getItem('app-debug-with-single-or-multiple-models')
|
const localeDebugWithSingleOrMultipleModelConfigs = localStorage.getItem('app-debug-with-single-or-multiple-models')
|
||||||
@ -95,16 +98,14 @@ export const useConfigFromDebugContext = () => {
|
|||||||
const config: ChatConfig = {
|
const config: ChatConfig = {
|
||||||
pre_prompt: !isAdvancedMode ? modelConfig.configs.prompt_template : '',
|
pre_prompt: !isAdvancedMode ? modelConfig.configs.prompt_template : '',
|
||||||
prompt_type: promptMode,
|
prompt_type: promptMode,
|
||||||
chat_prompt_config: isAdvancedMode ? chatPromptConfig : {},
|
chat_prompt_config: isAdvancedMode ? chatPromptConfig : cloneDeep(DEFAULT_CHAT_PROMPT_CONFIG),
|
||||||
completion_prompt_config: isAdvancedMode ? completionPromptConfig : {},
|
completion_prompt_config: isAdvancedMode ? completionPromptConfig : cloneDeep(DEFAULT_COMPLETION_PROMPT_CONFIG),
|
||||||
user_input_form: promptVariablesToUserInputsForm(modelConfig.configs.prompt_variables),
|
user_input_form: promptVariablesToUserInputsForm(modelConfig.configs.prompt_variables),
|
||||||
dataset_query_variable: contextVar || '',
|
dataset_query_variable: contextVar || '',
|
||||||
opening_statement: introduction,
|
opening_statement: introduction,
|
||||||
more_like_this: {
|
more_like_this: modelConfig.more_like_this ?? { enabled: false },
|
||||||
enabled: false,
|
|
||||||
},
|
|
||||||
suggested_questions: openingSuggestedQuestions,
|
suggested_questions: openingSuggestedQuestions,
|
||||||
suggested_questions_after_answer: suggestedQuestionsAfterAnswerConfig,
|
suggested_questions_after_answer: suggestedQuestionsAfterAnswerConfig ?? { enabled: false },
|
||||||
text_to_speech: textToSpeechConfig,
|
text_to_speech: textToSpeechConfig,
|
||||||
speech_to_text: speechToTextConfig,
|
speech_to_text: speechToTextConfig,
|
||||||
retriever_resource: citationConfig,
|
retriever_resource: citationConfig,
|
||||||
@ -121,8 +122,13 @@ export const useConfigFromDebugContext = () => {
|
|||||||
},
|
},
|
||||||
file_upload: {
|
file_upload: {
|
||||||
image: visionConfig,
|
image: visionConfig,
|
||||||
|
allowed_file_upload_methods: visionConfig.transfer_methods ?? [],
|
||||||
|
allowed_file_types: [SupportUploadFileTypes.image],
|
||||||
|
max_length: visionConfig.number_limits ?? 0,
|
||||||
|
number_limits: visionConfig.number_limits,
|
||||||
},
|
},
|
||||||
annotation_reply: annotationConfig,
|
annotation_reply: annotationConfig,
|
||||||
|
system_parameters: modelConfig.system_parameters,
|
||||||
|
|
||||||
supportAnnotation: true,
|
supportAnnotation: true,
|
||||||
appId,
|
appId,
|
||||||
|
|||||||
@ -3,6 +3,7 @@ import type { FC } from 'react'
|
|||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
import React, { useCallback, useEffect, useRef, useState } from 'react'
|
import React, { useCallback, useEffect, useRef, useState } from 'react'
|
||||||
import { produce, setAutoFreeze } from 'immer'
|
import { produce, setAutoFreeze } from 'immer'
|
||||||
|
import cloneDeep from 'lodash-es/cloneDeep'
|
||||||
import { useBoolean } from 'ahooks'
|
import { useBoolean } from 'ahooks'
|
||||||
import {
|
import {
|
||||||
RiAddLine,
|
RiAddLine,
|
||||||
@ -36,7 +37,7 @@ import ActionButton, { ActionButtonState } from '@/app/components/base/action-bu
|
|||||||
import type { ModelConfig as BackendModelConfig, VisionFile, VisionSettings } from '@/types/app'
|
import type { ModelConfig as BackendModelConfig, VisionFile, VisionSettings } from '@/types/app'
|
||||||
import { formatBooleanInputs, promptVariablesToUserInputsForm } from '@/utils/model-config'
|
import { formatBooleanInputs, promptVariablesToUserInputsForm } from '@/utils/model-config'
|
||||||
import TextGeneration from '@/app/components/app/text-generate/item'
|
import TextGeneration from '@/app/components/app/text-generate/item'
|
||||||
import { IS_CE_EDITION } from '@/config'
|
import { DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG, IS_CE_EDITION } from '@/config'
|
||||||
import type { Inputs } from '@/models/debug'
|
import type { Inputs } from '@/models/debug'
|
||||||
import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||||
import { ModelFeatureEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
import { ModelFeatureEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||||
@ -90,6 +91,7 @@ const Debug: FC<IDebug> = ({
|
|||||||
completionParams,
|
completionParams,
|
||||||
hasSetContextVar,
|
hasSetContextVar,
|
||||||
datasetConfigs,
|
datasetConfigs,
|
||||||
|
externalDataToolsConfig,
|
||||||
} = useContext(ConfigContext)
|
} = useContext(ConfigContext)
|
||||||
const { eventEmitter } = useEventEmitterContextContext()
|
const { eventEmitter } = useEventEmitterContextContext()
|
||||||
const { data: text2speechDefaultModel } = useDefaultModel(ModelTypeEnum.textEmbedding)
|
const { data: text2speechDefaultModel } = useDefaultModel(ModelTypeEnum.textEmbedding)
|
||||||
@ -223,8 +225,8 @@ const Debug: FC<IDebug> = ({
|
|||||||
const postModelConfig: BackendModelConfig = {
|
const postModelConfig: BackendModelConfig = {
|
||||||
pre_prompt: !isAdvancedMode ? modelConfig.configs.prompt_template : '',
|
pre_prompt: !isAdvancedMode ? modelConfig.configs.prompt_template : '',
|
||||||
prompt_type: promptMode,
|
prompt_type: promptMode,
|
||||||
chat_prompt_config: {},
|
chat_prompt_config: isAdvancedMode ? chatPromptConfig : cloneDeep(DEFAULT_CHAT_PROMPT_CONFIG),
|
||||||
completion_prompt_config: {},
|
completion_prompt_config: isAdvancedMode ? completionPromptConfig : cloneDeep(DEFAULT_COMPLETION_PROMPT_CONFIG),
|
||||||
user_input_form: promptVariablesToUserInputsForm(modelConfig.configs.prompt_variables),
|
user_input_form: promptVariablesToUserInputsForm(modelConfig.configs.prompt_variables),
|
||||||
dataset_query_variable: contextVar || '',
|
dataset_query_variable: contextVar || '',
|
||||||
dataset_configs: {
|
dataset_configs: {
|
||||||
@ -251,11 +253,8 @@ const Debug: FC<IDebug> = ({
|
|||||||
suggested_questions_after_answer: suggestedQuestionsAfterAnswerConfig,
|
suggested_questions_after_answer: suggestedQuestionsAfterAnswerConfig,
|
||||||
speech_to_text: speechToTextConfig,
|
speech_to_text: speechToTextConfig,
|
||||||
retriever_resource: citationConfig,
|
retriever_resource: citationConfig,
|
||||||
}
|
system_parameters: modelConfig.system_parameters,
|
||||||
|
external_data_tools: externalDataToolsConfig,
|
||||||
if (isAdvancedMode) {
|
|
||||||
postModelConfig.chat_prompt_config = chatPromptConfig
|
|
||||||
postModelConfig.completion_prompt_config = completionPromptConfig
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const data: Record<string, any> = {
|
const data: Record<string, any> = {
|
||||||
|
|||||||
@ -36,14 +36,14 @@ import type {
|
|||||||
} from '@/models/debug'
|
} from '@/models/debug'
|
||||||
import type { ExternalDataTool } from '@/models/common'
|
import type { ExternalDataTool } from '@/models/common'
|
||||||
import type { DataSet } from '@/models/datasets'
|
import type { DataSet } from '@/models/datasets'
|
||||||
import type { ModelConfig as BackendModelConfig, VisionSettings } from '@/types/app'
|
import type { ModelConfig as BackendModelConfig, UserInputFormItem, VisionSettings } from '@/types/app'
|
||||||
import ConfigContext from '@/context/debug-configuration'
|
import ConfigContext from '@/context/debug-configuration'
|
||||||
import Config from '@/app/components/app/configuration/config'
|
import Config from '@/app/components/app/configuration/config'
|
||||||
import Debug from '@/app/components/app/configuration/debug'
|
import Debug from '@/app/components/app/configuration/debug'
|
||||||
import Confirm from '@/app/components/base/confirm'
|
import Confirm from '@/app/components/base/confirm'
|
||||||
import { ModelFeatureEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
import { ModelFeatureEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||||
import { ToastContext } from '@/app/components/base/toast'
|
import { ToastContext } from '@/app/components/base/toast'
|
||||||
import { fetchAppDetail, updateAppModelConfig } from '@/service/apps'
|
import { fetchAppDetailDirect, updateAppModelConfig } from '@/service/apps'
|
||||||
import { promptVariablesToUserInputsForm, userInputsFormToPromptVariables } from '@/utils/model-config'
|
import { promptVariablesToUserInputsForm, userInputsFormToPromptVariables } from '@/utils/model-config'
|
||||||
import { fetchDatasets } from '@/service/datasets'
|
import { fetchDatasets } from '@/service/datasets'
|
||||||
import { useProviderContext } from '@/context/provider-context'
|
import { useProviderContext } from '@/context/provider-context'
|
||||||
@ -186,6 +186,8 @@ const Configuration: FC = () => {
|
|||||||
prompt_template: '',
|
prompt_template: '',
|
||||||
prompt_variables: [] as PromptVariable[],
|
prompt_variables: [] as PromptVariable[],
|
||||||
},
|
},
|
||||||
|
chat_prompt_config: clone(DEFAULT_CHAT_PROMPT_CONFIG),
|
||||||
|
completion_prompt_config: clone(DEFAULT_COMPLETION_PROMPT_CONFIG),
|
||||||
more_like_this: null,
|
more_like_this: null,
|
||||||
opening_statement: '',
|
opening_statement: '',
|
||||||
suggested_questions: [],
|
suggested_questions: [],
|
||||||
@ -196,6 +198,14 @@ const Configuration: FC = () => {
|
|||||||
suggested_questions_after_answer: null,
|
suggested_questions_after_answer: null,
|
||||||
retriever_resource: null,
|
retriever_resource: null,
|
||||||
annotation_reply: null,
|
annotation_reply: null,
|
||||||
|
external_data_tools: [],
|
||||||
|
system_parameters: {
|
||||||
|
audio_file_size_limit: 0,
|
||||||
|
file_size_limit: 0,
|
||||||
|
image_file_size_limit: 0,
|
||||||
|
video_file_size_limit: 0,
|
||||||
|
workflow_file_upload_limit: 0,
|
||||||
|
},
|
||||||
dataSets: [],
|
dataSets: [],
|
||||||
agentConfig: DEFAULT_AGENT_SETTING,
|
agentConfig: DEFAULT_AGENT_SETTING,
|
||||||
})
|
})
|
||||||
@ -543,169 +553,169 @@ const Configuration: FC = () => {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
setCollectionList(collectionList)
|
setCollectionList(collectionList)
|
||||||
fetchAppDetail({ url: '/apps', id: appId }).then(async (res: any) => {
|
const res = await fetchAppDetailDirect({ url: '/apps', id: appId })
|
||||||
setMode(res.mode)
|
setMode(res.mode)
|
||||||
const modelConfig = res.model_config
|
const modelConfig = res.model_config as BackendModelConfig
|
||||||
const promptMode = modelConfig.prompt_type === PromptMode.advanced ? PromptMode.advanced : PromptMode.simple
|
const promptMode = modelConfig.prompt_type === PromptMode.advanced ? PromptMode.advanced : PromptMode.simple
|
||||||
doSetPromptMode(promptMode)
|
doSetPromptMode(promptMode)
|
||||||
if (promptMode === PromptMode.advanced) {
|
if (promptMode === PromptMode.advanced) {
|
||||||
if (modelConfig.chat_prompt_config && modelConfig.chat_prompt_config.prompt.length > 0)
|
if (modelConfig.chat_prompt_config && modelConfig.chat_prompt_config.prompt.length > 0)
|
||||||
setChatPromptConfig(modelConfig.chat_prompt_config)
|
setChatPromptConfig(modelConfig.chat_prompt_config)
|
||||||
else
|
else
|
||||||
setChatPromptConfig(clone(DEFAULT_CHAT_PROMPT_CONFIG))
|
setChatPromptConfig(clone(DEFAULT_CHAT_PROMPT_CONFIG))
|
||||||
setCompletionPromptConfig(modelConfig.completion_prompt_config || clone(DEFAULT_COMPLETION_PROMPT_CONFIG) as any)
|
setCompletionPromptConfig(modelConfig.completion_prompt_config || clone(DEFAULT_COMPLETION_PROMPT_CONFIG) as any)
|
||||||
setCanReturnToSimpleMode(false)
|
setCanReturnToSimpleMode(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
const model = res.model_config.model
|
const model = modelConfig.model
|
||||||
|
|
||||||
let datasets: any = null
|
let datasets: any = null
|
||||||
// old dataset struct
|
// old dataset struct
|
||||||
if (modelConfig.agent_mode?.tools?.find(({ dataset }: any) => dataset?.enabled))
|
if (modelConfig.agent_mode?.tools?.find(({ dataset }: any) => dataset?.enabled))
|
||||||
datasets = modelConfig.agent_mode?.tools.filter(({ dataset }: any) => dataset?.enabled)
|
datasets = modelConfig.agent_mode?.tools.filter(({ dataset }: any) => dataset?.enabled)
|
||||||
// new dataset struct
|
// new dataset struct
|
||||||
else if (modelConfig.dataset_configs.datasets?.datasets?.length > 0)
|
else if (modelConfig.dataset_configs.datasets?.datasets?.length > 0)
|
||||||
datasets = modelConfig.dataset_configs?.datasets?.datasets
|
datasets = modelConfig.dataset_configs?.datasets?.datasets
|
||||||
|
|
||||||
if (dataSets && datasets?.length && datasets?.length > 0) {
|
if (dataSets && datasets?.length && datasets?.length > 0) {
|
||||||
const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: datasets.map(({ dataset }: any) => dataset.id) } })
|
const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: datasets.map(({ dataset }: any) => dataset.id) } })
|
||||||
datasets = dataSetsWithDetail
|
datasets = dataSetsWithDetail
|
||||||
setDataSets(datasets)
|
setDataSets(datasets)
|
||||||
}
|
}
|
||||||
|
|
||||||
setIntroduction(modelConfig.opening_statement)
|
setIntroduction(modelConfig.opening_statement)
|
||||||
setSuggestedQuestions(modelConfig.suggested_questions || [])
|
setSuggestedQuestions(modelConfig.suggested_questions || [])
|
||||||
if (modelConfig.more_like_this)
|
if (modelConfig.more_like_this)
|
||||||
setMoreLikeThisConfig(modelConfig.more_like_this)
|
setMoreLikeThisConfig(modelConfig.more_like_this)
|
||||||
|
|
||||||
if (modelConfig.suggested_questions_after_answer)
|
if (modelConfig.suggested_questions_after_answer)
|
||||||
setSuggestedQuestionsAfterAnswerConfig(modelConfig.suggested_questions_after_answer)
|
setSuggestedQuestionsAfterAnswerConfig(modelConfig.suggested_questions_after_answer)
|
||||||
|
|
||||||
if (modelConfig.speech_to_text)
|
if (modelConfig.speech_to_text)
|
||||||
setSpeechToTextConfig(modelConfig.speech_to_text)
|
setSpeechToTextConfig(modelConfig.speech_to_text)
|
||||||
|
|
||||||
if (modelConfig.text_to_speech)
|
if (modelConfig.text_to_speech)
|
||||||
setTextToSpeechConfig(modelConfig.text_to_speech)
|
setTextToSpeechConfig(modelConfig.text_to_speech)
|
||||||
|
|
||||||
if (modelConfig.retriever_resource)
|
if (modelConfig.retriever_resource)
|
||||||
setCitationConfig(modelConfig.retriever_resource)
|
setCitationConfig(modelConfig.retriever_resource)
|
||||||
|
|
||||||
if (modelConfig.annotation_reply) {
|
if (modelConfig.annotation_reply) {
|
||||||
let annotationConfig = modelConfig.annotation_reply
|
let annotationConfig = modelConfig.annotation_reply
|
||||||
if (modelConfig.annotation_reply.enabled) {
|
if (modelConfig.annotation_reply.enabled) {
|
||||||
annotationConfig = {
|
annotationConfig = {
|
||||||
...modelConfig.annotation_reply,
|
...modelConfig.annotation_reply,
|
||||||
embedding_model: {
|
embedding_model: {
|
||||||
...modelConfig.annotation_reply.embedding_model,
|
...modelConfig.annotation_reply.embedding_model,
|
||||||
embedding_provider_name: correctModelProvider(modelConfig.annotation_reply.embedding_model.embedding_provider_name),
|
embedding_provider_name: correctModelProvider(modelConfig.annotation_reply.embedding_model.embedding_provider_name),
|
||||||
},
|
},
|
||||||
}
|
|
||||||
}
|
}
|
||||||
setAnnotationConfig(annotationConfig, true)
|
|
||||||
}
|
}
|
||||||
|
setAnnotationConfig(annotationConfig, true)
|
||||||
|
}
|
||||||
|
|
||||||
if (modelConfig.sensitive_word_avoidance)
|
if (modelConfig.sensitive_word_avoidance)
|
||||||
setModerationConfig(modelConfig.sensitive_word_avoidance)
|
setModerationConfig(modelConfig.sensitive_word_avoidance)
|
||||||
|
|
||||||
if (modelConfig.external_data_tools)
|
if (modelConfig.external_data_tools)
|
||||||
setExternalDataToolsConfig(modelConfig.external_data_tools)
|
setExternalDataToolsConfig(modelConfig.external_data_tools)
|
||||||
|
|
||||||
const config = {
|
const config: PublishConfig = {
|
||||||
modelConfig: {
|
modelConfig: {
|
||||||
provider: correctModelProvider(model.provider),
|
provider: correctModelProvider(model.provider),
|
||||||
model_id: model.name,
|
model_id: model.name,
|
||||||
mode: model.mode,
|
mode: model.mode,
|
||||||
configs: {
|
configs: {
|
||||||
prompt_template: modelConfig.pre_prompt || '',
|
prompt_template: modelConfig.pre_prompt || '',
|
||||||
prompt_variables: userInputsFormToPromptVariables(
|
prompt_variables: userInputsFormToPromptVariables(
|
||||||
[
|
([
|
||||||
...modelConfig.user_input_form,
|
...modelConfig.user_input_form,
|
||||||
...(
|
...(
|
||||||
modelConfig.external_data_tools?.length
|
modelConfig.external_data_tools?.length
|
||||||
? modelConfig.external_data_tools.map((item: any) => {
|
? modelConfig.external_data_tools.map((item: any) => {
|
||||||
return {
|
return {
|
||||||
external_data_tool: {
|
external_data_tool: {
|
||||||
variable: item.variable as string,
|
variable: item.variable as string,
|
||||||
label: item.label as string,
|
label: item.label as string,
|
||||||
enabled: item.enabled,
|
enabled: item.enabled,
|
||||||
type: item.type as string,
|
type: item.type as string,
|
||||||
config: item.config,
|
config: item.config,
|
||||||
required: true,
|
required: true,
|
||||||
icon: item.icon,
|
icon: item.icon,
|
||||||
icon_background: item.icon_background,
|
icon_background: item.icon_background,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
: []
|
: []
|
||||||
),
|
),
|
||||||
],
|
]) as unknown as UserInputFormItem[],
|
||||||
modelConfig.dataset_query_variable,
|
modelConfig.dataset_query_variable,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
more_like_this: modelConfig.more_like_this,
|
more_like_this: modelConfig.more_like_this ?? { enabled: false },
|
||||||
opening_statement: modelConfig.opening_statement,
|
opening_statement: modelConfig.opening_statement,
|
||||||
suggested_questions: modelConfig.suggested_questions,
|
suggested_questions: modelConfig.suggested_questions ?? [],
|
||||||
sensitive_word_avoidance: modelConfig.sensitive_word_avoidance,
|
sensitive_word_avoidance: modelConfig.sensitive_word_avoidance,
|
||||||
speech_to_text: modelConfig.speech_to_text,
|
speech_to_text: modelConfig.speech_to_text,
|
||||||
text_to_speech: modelConfig.text_to_speech,
|
text_to_speech: modelConfig.text_to_speech,
|
||||||
file_upload: modelConfig.file_upload,
|
file_upload: modelConfig.file_upload ?? null,
|
||||||
suggested_questions_after_answer: modelConfig.suggested_questions_after_answer,
|
suggested_questions_after_answer: modelConfig.suggested_questions_after_answer ?? { enabled: false },
|
||||||
retriever_resource: modelConfig.retriever_resource,
|
retriever_resource: modelConfig.retriever_resource,
|
||||||
annotation_reply: modelConfig.annotation_reply,
|
annotation_reply: modelConfig.annotation_reply ?? null,
|
||||||
external_data_tools: modelConfig.external_data_tools,
|
external_data_tools: modelConfig.external_data_tools ?? [],
|
||||||
dataSets: datasets || [],
|
system_parameters: modelConfig.system_parameters,
|
||||||
agentConfig: res.mode === 'agent-chat' ? {
|
dataSets: datasets || [],
|
||||||
max_iteration: DEFAULT_AGENT_SETTING.max_iteration,
|
agentConfig: res.mode === 'agent-chat' ? {
|
||||||
...modelConfig.agent_mode,
|
max_iteration: DEFAULT_AGENT_SETTING.max_iteration,
|
||||||
|
...modelConfig.agent_mode,
|
||||||
// remove dataset
|
// remove dataset
|
||||||
enabled: true, // modelConfig.agent_mode?.enabled is not correct. old app: the value of app with dataset's is always true
|
enabled: true, // modelConfig.agent_mode?.enabled is not correct. old app: the value of app with dataset's is always true
|
||||||
tools: modelConfig.agent_mode?.tools.filter((tool: any) => {
|
tools: (modelConfig.agent_mode?.tools ?? []).filter((tool: any) => {
|
||||||
return !tool.dataset
|
return !tool.dataset
|
||||||
}).map((tool: any) => {
|
}).map((tool: any) => {
|
||||||
const toolInCollectionList = collectionList.find(c => tool.provider_id === c.id)
|
const toolInCollectionList = collectionList.find(c => tool.provider_id === c.id)
|
||||||
return {
|
return {
|
||||||
...tool,
|
...tool,
|
||||||
isDeleted: res.deleted_tools?.some((deletedTool: any) => deletedTool.id === tool.id && deletedTool.tool_name === tool.tool_name),
|
isDeleted: res.deleted_tools?.some((deletedTool: any) => deletedTool.id === tool.id && deletedTool.tool_name === tool.tool_name) ?? false,
|
||||||
notAuthor: toolInCollectionList?.is_team_authorization === false,
|
notAuthor: toolInCollectionList?.is_team_authorization === false,
|
||||||
...(tool.provider_type === 'builtin' ? {
|
...(tool.provider_type === 'builtin' ? {
|
||||||
provider_id: correctToolProvider(tool.provider_name, !!toolInCollectionList),
|
provider_id: correctToolProvider(tool.provider_name, !!toolInCollectionList),
|
||||||
provider_name: correctToolProvider(tool.provider_name, !!toolInCollectionList),
|
provider_name: correctToolProvider(tool.provider_name, !!toolInCollectionList),
|
||||||
} : {}),
|
} : {}),
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
} : DEFAULT_AGENT_SETTING,
|
strategy: modelConfig.agent_mode?.strategy ?? AgentStrategy.react,
|
||||||
},
|
} : DEFAULT_AGENT_SETTING,
|
||||||
completionParams: model.completion_params,
|
},
|
||||||
}
|
completionParams: model.completion_params,
|
||||||
|
}
|
||||||
|
|
||||||
if (modelConfig.file_upload)
|
if (modelConfig.file_upload)
|
||||||
handleSetVisionConfig(modelConfig.file_upload.image, true)
|
handleSetVisionConfig(modelConfig.file_upload.image, true)
|
||||||
|
|
||||||
syncToPublishedConfig(config)
|
syncToPublishedConfig(config)
|
||||||
setPublishedConfig(config)
|
setPublishedConfig(config)
|
||||||
const retrievalConfig = getMultipleRetrievalConfig({
|
const retrievalConfig = getMultipleRetrievalConfig({
|
||||||
...modelConfig.dataset_configs,
|
...modelConfig.dataset_configs,
|
||||||
reranking_model: modelConfig.dataset_configs.reranking_model && {
|
reranking_model: modelConfig.dataset_configs.reranking_model && {
|
||||||
provider: modelConfig.dataset_configs.reranking_model.reranking_provider_name,
|
provider: modelConfig.dataset_configs.reranking_model.reranking_provider_name,
|
||||||
model: modelConfig.dataset_configs.reranking_model.reranking_model_name,
|
model: modelConfig.dataset_configs.reranking_model.reranking_model_name,
|
||||||
},
|
},
|
||||||
}, datasets, datasets, {
|
}, datasets, datasets, {
|
||||||
provider: currentRerankProvider?.provider,
|
provider: currentRerankProvider?.provider,
|
||||||
model: currentRerankModel?.model,
|
model: currentRerankModel?.model,
|
||||||
})
|
|
||||||
setDatasetConfigs({
|
|
||||||
retrieval_model: RETRIEVE_TYPE.multiWay,
|
|
||||||
...modelConfig.dataset_configs,
|
|
||||||
...retrievalConfig,
|
|
||||||
...(retrievalConfig.reranking_model ? {
|
|
||||||
reranking_model: {
|
|
||||||
reranking_model_name: retrievalConfig.reranking_model.model,
|
|
||||||
reranking_provider_name: correctModelProvider(retrievalConfig.reranking_model.provider),
|
|
||||||
},
|
|
||||||
} : {}),
|
|
||||||
})
|
|
||||||
setHasFetchedDetail(true)
|
|
||||||
})
|
})
|
||||||
|
setDatasetConfigs({
|
||||||
|
...modelConfig.dataset_configs,
|
||||||
|
...retrievalConfig,
|
||||||
|
...(retrievalConfig.reranking_model ? {
|
||||||
|
reranking_model: {
|
||||||
|
reranking_model_name: retrievalConfig.reranking_model.model,
|
||||||
|
reranking_provider_name: correctModelProvider(retrievalConfig.reranking_model.provider),
|
||||||
|
},
|
||||||
|
} : {}),
|
||||||
|
} as DatasetConfigs)
|
||||||
|
setHasFetchedDetail(true)
|
||||||
})()
|
})()
|
||||||
}, [appId])
|
}, [appId])
|
||||||
|
|
||||||
@ -780,8 +790,8 @@ const Configuration: FC = () => {
|
|||||||
// Simple Mode prompt
|
// Simple Mode prompt
|
||||||
pre_prompt: !isAdvancedMode ? promptTemplate : '',
|
pre_prompt: !isAdvancedMode ? promptTemplate : '',
|
||||||
prompt_type: promptMode,
|
prompt_type: promptMode,
|
||||||
chat_prompt_config: {},
|
chat_prompt_config: isAdvancedMode ? chatPromptConfig : clone(DEFAULT_CHAT_PROMPT_CONFIG),
|
||||||
completion_prompt_config: {},
|
completion_prompt_config: isAdvancedMode ? completionPromptConfig : clone(DEFAULT_COMPLETION_PROMPT_CONFIG),
|
||||||
user_input_form: promptVariablesToUserInputsForm(promptVariables),
|
user_input_form: promptVariablesToUserInputsForm(promptVariables),
|
||||||
dataset_query_variable: contextVar || '',
|
dataset_query_variable: contextVar || '',
|
||||||
// features
|
// features
|
||||||
@ -798,6 +808,7 @@ const Configuration: FC = () => {
|
|||||||
...modelConfig.agentConfig,
|
...modelConfig.agentConfig,
|
||||||
strategy: isFunctionCall ? AgentStrategy.functionCall : AgentStrategy.react,
|
strategy: isFunctionCall ? AgentStrategy.functionCall : AgentStrategy.react,
|
||||||
},
|
},
|
||||||
|
external_data_tools: externalDataToolsConfig,
|
||||||
model: {
|
model: {
|
||||||
provider: modelAndParameter?.provider || modelConfig.provider,
|
provider: modelAndParameter?.provider || modelConfig.provider,
|
||||||
name: modelId,
|
name: modelId,
|
||||||
@ -810,11 +821,7 @@ const Configuration: FC = () => {
|
|||||||
datasets: [...postDatasets],
|
datasets: [...postDatasets],
|
||||||
} as any,
|
} as any,
|
||||||
},
|
},
|
||||||
}
|
system_parameters: modelConfig.system_parameters,
|
||||||
|
|
||||||
if (isAdvancedMode) {
|
|
||||||
data.chat_prompt_config = chatPromptConfig
|
|
||||||
data.completion_prompt_config = completionPromptConfig
|
|
||||||
}
|
}
|
||||||
|
|
||||||
await updateAppModelConfig({ url: `/apps/${appId}/model-config`, body: data })
|
await updateAppModelConfig({ url: `/apps/${appId}/model-config`, body: data })
|
||||||
|
|||||||
@ -983,7 +983,7 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
|
|||||||
if (!showDrawer)
|
if (!showDrawer)
|
||||||
setShowDrawer(true)
|
setShowDrawer(true)
|
||||||
|
|
||||||
if (!currentConversation || currentConversation.id !== conversationIdInUrl || (matchedConversation && currentConversation !== matchedConversation))
|
if (!currentConversation || currentConversation.id !== conversationIdInUrl || (!('created_at' in currentConversation) && matchedConversation))
|
||||||
setCurrentConversation(nextConversation)
|
setCurrentConversation(nextConversation)
|
||||||
|
|
||||||
if (pendingConversationCacheRef.current?.id === conversationIdInUrl || matchedConversation)
|
if (pendingConversationCacheRef.current?.id === conversationIdInUrl || matchedConversation)
|
||||||
|
|||||||
@ -3,7 +3,6 @@ import Chat from '../chat'
|
|||||||
import type {
|
import type {
|
||||||
ChatConfig,
|
ChatConfig,
|
||||||
ChatItem,
|
ChatItem,
|
||||||
ChatItemInTree,
|
|
||||||
OnSend,
|
OnSend,
|
||||||
} from '../types'
|
} from '../types'
|
||||||
import { useChat } from '../chat/hooks'
|
import { useChat } from '../chat/hooks'
|
||||||
@ -149,7 +148,7 @@ const ChatWrapper = () => {
|
|||||||
)
|
)
|
||||||
}, [chatList, handleNewConversationCompleted, handleSend, currentConversationId, currentConversationInputs, newConversationInputs, isInstalledApp, appId])
|
}, [chatList, handleNewConversationCompleted, handleSend, currentConversationId, currentConversationInputs, newConversationInputs, isInstalledApp, appId])
|
||||||
|
|
||||||
const doRegenerate = useCallback((chatItem: ChatItemInTree, editedQuestion?: { message: string, files?: FileEntity[] }) => {
|
const doRegenerate = useCallback((chatItem: ChatItem, editedQuestion?: { message: string, files?: FileEntity[] }) => {
|
||||||
const question = editedQuestion ? chatItem : chatList.find(item => item.id === chatItem.parentMessageId)!
|
const question = editedQuestion ? chatItem : chatList.find(item => item.id === chatItem.parentMessageId)!
|
||||||
const parentAnswer = chatList.find(item => item.id === question.parentMessageId)
|
const parentAnswer = chatList.find(item => item.id === question.parentMessageId)
|
||||||
doSend(editedQuestion ? editedQuestion.message : question.content,
|
doSend(editedQuestion ? editedQuestion.message : question.content,
|
||||||
|
|||||||
@ -3,7 +3,6 @@ import Chat from '../chat'
|
|||||||
import type {
|
import type {
|
||||||
ChatConfig,
|
ChatConfig,
|
||||||
ChatItem,
|
ChatItem,
|
||||||
ChatItemInTree,
|
|
||||||
OnSend,
|
OnSend,
|
||||||
} from '../types'
|
} from '../types'
|
||||||
import { useChat } from '../chat/hooks'
|
import { useChat } from '../chat/hooks'
|
||||||
@ -147,7 +146,7 @@ const ChatWrapper = () => {
|
|||||||
)
|
)
|
||||||
}, [currentConversationId, currentConversationInputs, newConversationInputs, chatList, handleSend, isInstalledApp, appId, handleNewConversationCompleted])
|
}, [currentConversationId, currentConversationInputs, newConversationInputs, chatList, handleSend, isInstalledApp, appId, handleNewConversationCompleted])
|
||||||
|
|
||||||
const doRegenerate = useCallback((chatItem: ChatItemInTree, editedQuestion?: { message: string, files?: FileEntity[] }) => {
|
const doRegenerate = useCallback((chatItem: ChatItem, editedQuestion?: { message: string, files?: FileEntity[] }) => {
|
||||||
const question = editedQuestion ? chatItem : chatList.find(item => item.id === chatItem.parentMessageId)!
|
const question = editedQuestion ? chatItem : chatList.find(item => item.id === chatItem.parentMessageId)!
|
||||||
const parentAnswer = chatList.find(item => item.id === question.parentMessageId)
|
const parentAnswer = chatList.find(item => item.id === question.parentMessageId)
|
||||||
doSend(editedQuestion ? editedQuestion.message : question.content,
|
doSend(editedQuestion ? editedQuestion.message : question.content,
|
||||||
|
|||||||
@ -70,16 +70,20 @@ export const useEmbeddedChatbot = () => {
|
|||||||
const appInfo = useWebAppStore(s => s.appInfo)
|
const appInfo = useWebAppStore(s => s.appInfo)
|
||||||
const appMeta = useWebAppStore(s => s.appMeta)
|
const appMeta = useWebAppStore(s => s.appMeta)
|
||||||
const appParams = useWebAppStore(s => s.appParams)
|
const appParams = useWebAppStore(s => s.appParams)
|
||||||
|
const embeddedConversationId = useWebAppStore(s => s.embeddedConversationId)
|
||||||
|
const embeddedUserId = useWebAppStore(s => s.embeddedUserId)
|
||||||
const appId = useMemo(() => appInfo?.app_id, [appInfo])
|
const appId = useMemo(() => appInfo?.app_id, [appInfo])
|
||||||
|
|
||||||
const [userId, setUserId] = useState<string>()
|
const [userId, setUserId] = useState<string>()
|
||||||
const [conversationId, setConversationId] = useState<string>()
|
const [conversationId, setConversationId] = useState<string>()
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
getProcessedSystemVariablesFromUrlParams().then(({ user_id, conversation_id }) => {
|
setUserId(embeddedUserId || undefined)
|
||||||
setUserId(user_id)
|
}, [embeddedUserId])
|
||||||
setConversationId(conversation_id)
|
|
||||||
})
|
useEffect(() => {
|
||||||
}, [])
|
setConversationId(embeddedConversationId || undefined)
|
||||||
|
}, [embeddedConversationId])
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const setLanguageFromParams = async () => {
|
const setLanguageFromParams = async () => {
|
||||||
|
|||||||
@ -85,7 +85,7 @@ export type OnSend = {
|
|||||||
(message: string, files: FileEntity[] | undefined, isRegenerate: boolean, lastAnswer?: ChatItem | null): void
|
(message: string, files: FileEntity[] | undefined, isRegenerate: boolean, lastAnswer?: ChatItem | null): void
|
||||||
}
|
}
|
||||||
|
|
||||||
export type OnRegenerate = (chatItem: ChatItem) => void
|
export type OnRegenerate = (chatItem: ChatItem, editedQuestion?: { message: string; files?: FileEntity[] }) => void
|
||||||
|
|
||||||
export type Callback = {
|
export type Callback = {
|
||||||
onSuccess: () => void
|
onSuccess: () => void
|
||||||
|
|||||||
@ -32,6 +32,7 @@ const meta = {
|
|||||||
},
|
},
|
||||||
args: {
|
args: {
|
||||||
show: false,
|
show: false,
|
||||||
|
children: null,
|
||||||
},
|
},
|
||||||
} satisfies Meta<typeof ContentDialog>
|
} satisfies Meta<typeof ContentDialog>
|
||||||
|
|
||||||
@ -92,6 +93,9 @@ const DemoWrapper = (props: Props) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export const Default: Story = {
|
export const Default: Story = {
|
||||||
|
args: {
|
||||||
|
children: null,
|
||||||
|
},
|
||||||
render: args => <DemoWrapper {...args} />,
|
render: args => <DemoWrapper {...args} />,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -99,6 +103,7 @@ export const NarrowPanel: Story = {
|
|||||||
render: args => <DemoWrapper {...args} />,
|
render: args => <DemoWrapper {...args} />,
|
||||||
args: {
|
args: {
|
||||||
className: 'max-w-[420px]',
|
className: 'max-w-[420px]',
|
||||||
|
children: null,
|
||||||
},
|
},
|
||||||
parameters: {
|
parameters: {
|
||||||
docs: {
|
docs: {
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user