Merge branch 'feat/mcp-06-18' into deploy/dev

This commit is contained in:
Novice 2025-10-14 21:40:29 +08:00
commit 4cb42499ae
No known key found for this signature in database
GPG Key ID: EE3F68E3105DAAAB
187 changed files with 9526 additions and 3521 deletions

View File

@ -1,6 +1,7 @@
#!/bin/bash
yq eval '.services.weaviate.ports += ["8080:8080"]' -i docker/docker-compose.yaml
yq eval '.services.weaviate.ports += ["50051:50051"]' -i docker/docker-compose.yaml
yq eval '.services.qdrant.ports += ["6333:6333"]' -i docker/docker-compose.yaml
yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml
yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml

View File

@ -189,6 +189,11 @@ class PluginConfig(BaseSettings):
default="plugin-api-key",
)
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
default=300.0,
)
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
PLUGIN_REMOTE_INSTALL_HOST: str = Field(

View File

@ -1,4 +1,5 @@
import flask_restx
from flask import Response
from flask_restx import Resource, fields, marshal_with
from flask_restx._http import HTTPStatus
from sqlalchemy import select
@ -7,8 +8,7 @@ from werkzeug.exceptions import Forbidden
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import current_user, login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset
from models.model import ApiToken, App
@ -57,9 +57,9 @@ class BaseApiKeyListResource(Resource):
def get(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
_, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
keys = db.session.scalars(
select(ApiToken).where(
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
@ -71,9 +71,8 @@ class BaseApiKeyListResource(Resource):
def post(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
current_user, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
if not current_user.has_edit_permission:
raise Forbidden()
@ -93,7 +92,7 @@ class BaseApiKeyListResource(Resource):
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
api_token = ApiToken()
setattr(api_token, self.resource_id_field, resource_id)
api_token.tenant_id = current_user.current_tenant_id
api_token.tenant_id = current_tenant_id
api_token.token = key
api_token.type = self.resource_type
db.session.add(api_token)
@ -112,9 +111,8 @@ class BaseApiKeyResource(Resource):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
api_key_id = str(api_key_id)
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
current_user, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
@ -158,7 +156,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
"""Create a new API key for an app"""
return super().post(resource_id)
def after_request(self, resp):
def after_request(self, resp: Response):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
@ -208,7 +206,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
"""Create a new API key for a dataset"""
return super().post(resource_id)
def after_request(self, resp):
def after_request(self, resp: Response):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
@ -229,7 +227,7 @@ class DatasetApiKeyResource(BaseApiKeyResource):
"""Delete an API key for a dataset"""
return super().delete(resource_id, api_key_id)
def after_request(self, resp):
def after_request(self, resp: Response):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp

View File

@ -1,7 +1,6 @@
from typing import Literal
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
@ -17,7 +16,7 @@ from fields.annotation_fields import (
annotation_fields,
annotation_hit_history_fields,
)
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from services.annotation_service import AppAnnotationService
@ -43,7 +42,9 @@ class AnnotationReplyActionApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
def post(self, app_id, action: Literal["enable", "disable"]):
if not current_user.is_editor:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
app_id = str(app_id)
@ -70,7 +71,9 @@ class AppAnnotationSettingDetailApi(Resource):
@login_required
@account_initialization_required
def get(self, app_id):
if not current_user.is_editor:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
app_id = str(app_id)
@ -99,7 +102,9 @@ class AppAnnotationSettingUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self, app_id, annotation_setting_id):
if not current_user.is_editor:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
app_id = str(app_id)
@ -125,7 +130,9 @@ class AnnotationReplyActionStatusApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
def get(self, app_id, job_id, action):
if not current_user.is_editor:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
job_id = str(job_id)
@ -160,7 +167,9 @@ class AnnotationApi(Resource):
@login_required
@account_initialization_required
def get(self, app_id):
if not current_user.is_editor:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
page = request.args.get("page", default=1, type=int)
@ -199,7 +208,9 @@ class AnnotationApi(Resource):
@cloud_edition_billing_resource_check("annotation")
@marshal_with(annotation_fields)
def post(self, app_id):
if not current_user.is_editor:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
app_id = str(app_id)
@ -214,7 +225,9 @@ class AnnotationApi(Resource):
@login_required
@account_initialization_required
def delete(self, app_id):
if not current_user.is_editor:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
app_id = str(app_id)
@ -250,7 +263,9 @@ class AnnotationExportApi(Resource):
@login_required
@account_initialization_required
def get(self, app_id):
if not current_user.is_editor:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
app_id = str(app_id)
@ -273,7 +288,9 @@ class AnnotationUpdateDeleteApi(Resource):
@cloud_edition_billing_resource_check("annotation")
@marshal_with(annotation_fields)
def post(self, app_id, annotation_id):
if not current_user.is_editor:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
app_id = str(app_id)
@ -289,7 +306,9 @@ class AnnotationUpdateDeleteApi(Resource):
@login_required
@account_initialization_required
def delete(self, app_id, annotation_id):
if not current_user.is_editor:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
app_id = str(app_id)
@ -311,7 +330,9 @@ class AnnotationBatchImportApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
def post(self, app_id):
if not current_user.is_editor:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
app_id = str(app_id)
@ -342,7 +363,9 @@ class AnnotationBatchImportStatusApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
def get(self, app_id, job_id):
if not current_user.is_editor:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
job_id = str(job_id)
@ -377,7 +400,9 @@ class AnnotationHitHistoryListApi(Resource):
@login_required
@account_initialization_required
def get(self, app_id, annotation_id):
if not current_user.is_editor:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
page = request.args.get("page", default=1, type=int)

View File

@ -1,6 +1,3 @@
from typing import cast
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
@ -13,8 +10,7 @@ from controllers.console.wraps import (
)
from extensions.ext_database import db
from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
from libs.login import login_required
from models import Account
from libs.login import current_account_with_tenant, login_required
from models.model import App
from services.app_dsl_service import AppDslService, ImportStatus
from services.enterprise.enterprise_service import EnterpriseService
@ -32,7 +28,8 @@ class AppImportApi(Resource):
@cloud_edition_billing_resource_check("apps")
def post(self):
# Check user role first
if not current_user.is_editor:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -51,7 +48,7 @@ class AppImportApi(Resource):
with Session(db.engine) as session:
import_service = AppDslService(session)
# Import app
account = cast(Account, current_user)
account = current_user
result = import_service.import_app(
account=account,
import_mode=args["mode"],
@ -85,14 +82,15 @@ class AppImportConfirmApi(Resource):
@marshal_with(app_import_fields)
def post(self, import_id):
# Check user role first
if not current_user.is_editor:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
# Create service with session
with Session(db.engine) as session:
import_service = AppDslService(session)
# Confirm import
account = cast(Account, current_user)
account = current_user
result = import_service.confirm_import(import_id=import_id, account=account)
session.commit()
@ -110,7 +108,8 @@ class AppImportCheckDependenciesApi(Resource):
@account_initialization_required
@marshal_with(app_import_check_dependencies_fields)
def get(self, app_model: App):
if not current_user.is_editor:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
with Session(db.engine) as session:

View File

@ -1,6 +1,5 @@
from collections.abc import Sequence
from flask_login import current_user
from flask_restx import Resource, fields, reqparse
from controllers.console import api, console_ns
@ -17,7 +16,7 @@ from core.helper.code_executor.python3.python3_code_provider import Python3CodeP
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models import App
from services.workflow_service import WorkflowService
@ -48,11 +47,11 @@ class RuleGenerateApi(Resource):
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
account = current_user
try:
rules = LLMGenerator.generate_rule_config(
tenant_id=account.current_tenant_id,
tenant_id=current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=args["no_variable"],
@ -99,11 +98,11 @@ class RuleCodeGenerateApi(Resource):
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
parser.add_argument("code_language", type=str, required=False, default="javascript", location="json")
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
account = current_user
try:
code_result = LLMGenerator.generate_code(
tenant_id=account.current_tenant_id,
tenant_id=current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
code_language=args["code_language"],
@ -144,11 +143,11 @@ class RuleStructuredOutputGenerateApi(Resource):
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
account = current_user
try:
structured_output = LLMGenerator.generate_structured_output(
tenant_id=account.current_tenant_id,
tenant_id=current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
)
@ -198,6 +197,7 @@ class InstructionGenerateApi(Resource):
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
code_template = (
Python3CodeProvider.get_default_code()
if args["language"] == "python"
@ -222,21 +222,21 @@ class InstructionGenerateApi(Resource):
match node_type:
case "llm":
return LLMGenerator.generate_rule_config(
current_user.current_tenant_id,
current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=True,
)
case "agent":
return LLMGenerator.generate_rule_config(
current_user.current_tenant_id,
current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=True,
)
case "code":
return LLMGenerator.generate_code(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
code_language=args["language"],
@ -245,7 +245,7 @@ class InstructionGenerateApi(Resource):
return {"error": f"invalid node type: {node_type}"}
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
return LLMGenerator.instruction_modify_legacy(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
flow_id=args["flow_id"],
current=args["current"],
instruction=args["instruction"],
@ -254,7 +254,7 @@ class InstructionGenerateApi(Resource):
)
if args["node_id"] != "" and args["current"] != "": # For workflow node
return LLMGenerator.instruction_modify_workflow(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
flow_id=args["flow_id"],
node_id=args["node_id"],
current=args["current"],

View File

@ -2,7 +2,6 @@ import json
from typing import cast
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields
from werkzeug.exceptions import Forbidden
@ -15,8 +14,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from models.model import AppMode, AppModelConfig
from services.app_model_config_service import AppModelConfigService
@ -54,16 +52,14 @@ class ModelConfigResource(Resource):
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
def post(self, app_model):
"""Modify app model config"""
if not isinstance(current_user, Account):
raise Forbidden()
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
assert current_user.current_tenant_id is not None, "The tenant information should be loaded."
# validate config
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
config=cast(dict, request.json),
app_mode=AppMode.value_of(app_model.mode),
)
@ -95,12 +91,12 @@ class ModelConfigResource(Resource):
# get tool
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
@ -134,7 +130,7 @@ class ModelConfigResource(Resource):
else:
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
)
@ -142,7 +138,7 @@ class ModelConfigResource(Resource):
continue
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,

View File

@ -1,4 +1,3 @@
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound
@ -9,7 +8,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
from extensions.ext_database import db
from fields.app_fields import app_site_fields
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models import Account, Site
@ -76,9 +75,10 @@ class AppSite(Resource):
@marshal_with(app_site_fields)
def post(self, app_model):
args = parse_app_site_args()
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be editor, admin, or owner
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
@ -131,6 +131,8 @@ class AppSiteAccessTokenReset(Resource):
@marshal_with(app_site_fields)
def post(self, app_model):
# The role of the current user in the ta table must be admin or owner
current_user, _ = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()

View File

@ -1,10 +1,9 @@
from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console.auth.error import ApiKeyAuthFailedError
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from services.auth.api_key_auth_service import ApiKeyAuthService
from ..wraps import account_initialization_required, setup_required
@ -16,7 +15,8 @@ class ApiKeyAuthDataSource(Resource):
@login_required
@account_initialization_required
def get(self):
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id)
_, current_tenant_id = current_account_with_tenant()
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id)
if data_source_api_key_bindings:
return {
"sources": [
@ -41,6 +41,8 @@ class ApiKeyAuthDataSourceBinding(Resource):
@account_initialization_required
def post(self):
# The role of the current user in the table must be admin or owner
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
@ -50,7 +52,7 @@ class ApiKeyAuthDataSourceBinding(Resource):
args = parser.parse_args()
ApiKeyAuthService.validate_api_key_auth_args(args)
try:
ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args)
ApiKeyAuthService.create_provider_auth(current_tenant_id, args)
except Exception as e:
raise ApiKeyAuthFailedError(str(e))
return {"result": "success"}, 200
@ -63,9 +65,11 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
@account_initialization_required
def delete(self, binding_id):
# The role of the current user in the table must be admin or owner
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)
return {"result": "success"}, 204

View File

@ -3,7 +3,6 @@ from collections.abc import Generator
from typing import cast
from flask import request
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -20,7 +19,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models import DataSourceOauthBinding, Document
from services.dataset_service import DatasetService, DocumentService
from services.datasource_provider_service import DatasourceProviderService
@ -37,10 +36,12 @@ class DataSourceApi(Resource):
@account_initialization_required
@marshal_with(integrate_list_fields)
def get(self):
_, current_tenant_id = current_account_with_tenant()
# get workspace data source integrates
data_source_integrates = db.session.scalars(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.tenant_id == current_tenant_id,
DataSourceOauthBinding.disabled == False,
)
).all()
@ -120,13 +121,15 @@ class DataSourceNotionListApi(Resource):
@account_initialization_required
@marshal_with(integrate_notion_info_list_fields)
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = request.args.get("dataset_id", default=None, type=str)
credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id:
raise ValueError("Credential id is required.")
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
@ -146,7 +149,7 @@ class DataSourceNotionListApi(Resource):
documents = session.scalars(
select(Document).filter_by(
dataset_id=dataset_id,
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
data_source_type="notion_import",
enabled=True,
)
@ -161,7 +164,7 @@ class DataSourceNotionListApi(Resource):
datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id="langgenius/notion_datasource/notion_datasource",
datasource_name="notion_datasource",
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
)
datasource_provider_service = DatasourceProviderService()
@ -210,12 +213,14 @@ class DataSourceNotionApi(Resource):
@login_required
@account_initialization_required
def get(self, workspace_id, page_id, page_type):
_, current_tenant_id = current_account_with_tenant()
credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id:
raise ValueError("Credential id is required.")
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
@ -229,7 +234,7 @@ class DataSourceNotionApi(Resource):
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"),
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
)
text_docs = extractor.extract()
@ -239,6 +244,8 @@ class DataSourceNotionApi(Resource):
@login_required
@account_initialization_required
def post(self):
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
@ -263,7 +270,7 @@ class DataSourceNotionApi(Resource):
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id,
"tenant_id": current_tenant_id,
}
),
document_model=args["doc_form"],
@ -271,7 +278,7 @@ class DataSourceNotionApi(Resource):
extract_settings.append(extract_setting)
indexing_runner = IndexingRunner()
response = indexing_runner.indexing_estimate(
current_user.current_tenant_id,
current_tenant_id,
extract_settings,
args["process_rule"],
args["doc_form"],

View File

@ -45,6 +45,79 @@ def _validate_name(name: str) -> str:
return name
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
"""
Get supported retrieval methods based on vector database type.
Args:
vector_type: Vector database type, can be None
is_mock: Whether this is a Mock API, affects MILVUS handling
Returns:
Dictionary containing supported retrieval methods
Raises:
ValueError: If vector_type is None or unsupported
"""
if vector_type is None:
raise ValueError("Vector store type is not configured.")
# Define vector database types that only support semantic search
semantic_only_types = {
VectorType.RELYT,
VectorType.TIDB_VECTOR,
VectorType.CHROMA,
VectorType.PGVECTO_RS,
VectorType.VIKINGDB,
VectorType.UPSTASH,
}
# Define vector database types that support all retrieval methods
full_search_types = {
VectorType.QDRANT,
VectorType.WEAVIATE,
VectorType.OPENSEARCH,
VectorType.ANALYTICDB,
VectorType.MYSCALE,
VectorType.ORACLE,
VectorType.ELASTICSEARCH,
VectorType.ELASTICSEARCH_JA,
VectorType.PGVECTOR,
VectorType.VASTBASE,
VectorType.TIDB_ON_QDRANT,
VectorType.LINDORM,
VectorType.COUCHBASE,
VectorType.OPENGAUSS,
VectorType.OCEANBASE,
VectorType.TABLESTORE,
VectorType.HUAWEI_CLOUD,
VectorType.TENCENT,
VectorType.MATRIXONE,
VectorType.CLICKZETTA,
VectorType.BAIDU,
VectorType.ALIBABACLOUD_MYSQL,
}
semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
full_methods = {
"retrieval_method": [
RetrievalMethod.SEMANTIC_SEARCH.value,
RetrievalMethod.FULL_TEXT_SEARCH.value,
RetrievalMethod.HYBRID_SEARCH.value,
]
}
if vector_type == VectorType.MILVUS:
return semantic_methods if is_mock else full_methods
if vector_type in semantic_only_types:
return semantic_methods
elif vector_type in full_search_types:
return full_methods
else:
raise ValueError(f"Unsupported vector db type {vector_type}.")
@console_ns.route("/datasets")
class DatasetListApi(Resource):
@api.doc("get_datasets")
@ -777,50 +850,7 @@ class DatasetRetrievalSettingApi(Resource):
@account_initialization_required
def get(self):
vector_type = dify_config.VECTOR_STORE
match vector_type:
case (
VectorType.RELYT
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.PGVECTO_RS
| VectorType.VIKINGDB
| VectorType.UPSTASH
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH]}
case (
VectorType.QDRANT
| VectorType.WEAVIATE
| VectorType.OPENSEARCH
| VectorType.ANALYTICDB
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.ELASTICSEARCH_JA
| VectorType.PGVECTOR
| VectorType.VASTBASE
| VectorType.TIDB_ON_QDRANT
| VectorType.LINDORM
| VectorType.COUCHBASE
| VectorType.MILVUS
| VectorType.OPENGAUSS
| VectorType.OCEANBASE
| VectorType.TABLESTORE
| VectorType.HUAWEI_CLOUD
| VectorType.TENCENT
| VectorType.MATRIXONE
| VectorType.CLICKZETTA
| VectorType.BAIDU
| VectorType.ALIBABACLOUD_MYSQL
):
return {
"retrieval_method": [
RetrievalMethod.SEMANTIC_SEARCH,
RetrievalMethod.FULL_TEXT_SEARCH,
RetrievalMethod.HYBRID_SEARCH,
]
}
case _:
raise ValueError(f"Unsupported vector db type {vector_type}.")
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=False)
@console_ns.route("/datasets/retrieval-setting/<string:vector_type>")
@ -833,49 +863,7 @@ class DatasetRetrievalSettingMockApi(Resource):
@login_required
@account_initialization_required
def get(self, vector_type):
match vector_type:
case (
VectorType.MILVUS
| VectorType.RELYT
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.PGVECTO_RS
| VectorType.VIKINGDB
| VectorType.UPSTASH
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH]}
case (
VectorType.QDRANT
| VectorType.WEAVIATE
| VectorType.OPENSEARCH
| VectorType.ANALYTICDB
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.ELASTICSEARCH_JA
| VectorType.COUCHBASE
| VectorType.PGVECTOR
| VectorType.VASTBASE
| VectorType.LINDORM
| VectorType.OPENGAUSS
| VectorType.OCEANBASE
| VectorType.TABLESTORE
| VectorType.TENCENT
| VectorType.HUAWEI_CLOUD
| VectorType.MATRIXONE
| VectorType.CLICKZETTA
| VectorType.BAIDU
| VectorType.ALIBABACLOUD_MYSQL
):
return {
"retrieval_method": [
RetrievalMethod.SEMANTIC_SEARCH,
RetrievalMethod.FULL_TEXT_SEARCH,
RetrievalMethod.HYBRID_SEARCH,
]
}
case _:
raise ValueError(f"Unsupported vector db type {vector_type}.")
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=True)
@console_ns.route("/datasets/<uuid:dataset_id>/error-docs")

View File

@ -1,7 +1,6 @@
import uuid
from flask import request
from flask_login import current_user
from flask_restx import Resource, marshal, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound
@ -27,7 +26,7 @@ from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.segment_fields import child_chunk_fields, segment_fields
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.dataset import ChildChunk, DocumentSegment
from models.model import UploadFile
from services.dataset_service import DatasetService, DocumentService, SegmentService
@ -43,6 +42,8 @@ class DatasetDocumentSegmentListApi(Resource):
@login_required
@account_initialization_required
def get(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -79,7 +80,7 @@ class DatasetDocumentSegmentListApi(Resource):
select(DocumentSegment)
.where(
DocumentSegment.document_id == str(document_id),
DocumentSegment.tenant_id == current_user.current_tenant_id,
DocumentSegment.tenant_id == current_tenant_id,
)
.order_by(DocumentSegment.position.asc())
)
@ -115,6 +116,8 @@ class DatasetDocumentSegmentListApi(Resource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id):
current_user, _ = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -148,6 +151,8 @@ class DatasetDocumentSegmentApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
@ -171,7 +176,7 @@ class DatasetDocumentSegmentApi(Resource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
@ -204,6 +209,8 @@ class DatasetDocumentSegmentAddApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -221,7 +228,7 @@ class DatasetDocumentSegmentAddApi(Resource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
@ -255,6 +262,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -272,7 +281,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
@ -287,7 +296,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
)
if not segment:
@ -317,6 +326,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -333,7 +344,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
)
if not segment:
@ -361,6 +372,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -396,7 +409,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
upload_file_id,
dataset_id,
document_id,
current_user.current_tenant_id,
current_tenant_id,
current_user.id,
)
except Exception as e:
@ -427,6 +440,8 @@ class ChildChunkAddApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -441,7 +456,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
)
if not segment:
@ -453,7 +468,7 @@ class ChildChunkAddApi(Resource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
@ -483,6 +498,8 @@ class ChildChunkAddApi(Resource):
@login_required
@account_initialization_required
def get(self, dataset_id, document_id, segment_id):
_, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -499,7 +516,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
)
if not segment:
@ -530,6 +547,8 @@ class ChildChunkAddApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -546,7 +565,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
)
if not segment:
@ -580,6 +599,8 @@ class ChildChunkUpdateApi(Resource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -596,7 +617,7 @@ class ChildChunkUpdateApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
)
if not segment:
@ -607,7 +628,7 @@ class ChildChunkUpdateApi(Resource):
db.session.query(ChildChunk)
.where(
ChildChunk.id == str(child_chunk_id),
ChildChunk.tenant_id == current_user.current_tenant_id,
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id,
)
@ -634,6 +655,8 @@ class ChildChunkUpdateApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -650,7 +673,7 @@ class ChildChunkUpdateApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
)
if not segment:
@ -661,7 +684,7 @@ class ChildChunkUpdateApi(Resource):
db.session.query(ChildChunk)
.where(
ChildChunk.id == str(child_chunk_id),
ChildChunk.tenant_id == current_user.current_tenant_id,
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id,
)

View File

@ -1,5 +1,4 @@
from flask import make_response, redirect, request
from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden, NotFound
@ -13,7 +12,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler
from libs.helper import StrLen
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.provider_ids import DatasourceProviderID
from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService
@ -25,9 +24,10 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
@login_required
@account_initialization_required
def get(self, provider_id: str):
user = current_user
tenant_id = user.current_tenant_id
if not current_user.is_editor:
current_user, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
if not current_user.has_edit_permission:
raise Forbidden()
credential_id = request.args.get("credential_id")
@ -52,7 +52,7 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
authorization_url_response = oauth_handler.get_authorization_url(
tenant_id=tenant_id,
user_id=user.id,
user_id=current_user.id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
@ -131,7 +131,9 @@ class DatasourceAuth(Resource):
@login_required
@account_initialization_required
def post(self, provider_id: str):
if not current_user.is_editor:
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -145,7 +147,7 @@ class DatasourceAuth(Resource):
try:
datasource_provider_service.add_datasource_api_key_provider(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider_id=datasource_provider_id,
credentials=args["credentials"],
name=args["name"],
@ -160,8 +162,10 @@ class DatasourceAuth(Resource):
def get(self, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
_, current_tenant_id = current_account_with_tenant()
datasources = datasource_provider_service.list_datasource_credentials(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
)
@ -174,17 +178,19 @@ class DatasourceAuthDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self, provider_id: str):
current_user, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id)
plugin_id = datasource_provider_id.plugin_id
provider_name = datasource_provider_id.provider_name
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
auth_id=args["credential_id"],
provider=provider_name,
plugin_id=plugin_id,
@ -198,17 +204,19 @@ class DatasourceAuthUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider_id: str):
current_user, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id)
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_credentials(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
auth_id=args["credential_id"],
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
@ -224,10 +232,10 @@ class DatasourceAuthListApi(Resource):
@login_required
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_all_datasource_credentials(
tenant_id=current_user.current_tenant_id
)
datasources = datasource_provider_service.get_all_datasource_credentials(tenant_id=current_tenant_id)
return {"result": jsonable_encoder(datasources)}, 200
@ -237,10 +245,10 @@ class DatasourceHardCodeAuthListApi(Resource):
@login_required
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_hard_code_datasource_credentials(
tenant_id=current_user.current_tenant_id
)
datasources = datasource_provider_service.get_hard_code_datasource_credentials(tenant_id=current_tenant_id)
return {"result": jsonable_encoder(datasources)}, 200
@ -250,7 +258,9 @@ class DatasourceAuthOauthCustomClient(Resource):
@login_required
@account_initialization_required
def post(self, provider_id: str):
if not current_user.is_editor:
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
@ -259,7 +269,7 @@ class DatasourceAuthOauthCustomClient(Resource):
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.setup_oauth_custom_client_params(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id,
client_params=args.get("client_params", {}),
enabled=args.get("enable_oauth_custom_client", False),
@ -270,10 +280,12 @@ class DatasourceAuthOauthCustomClient(Resource):
@login_required
@account_initialization_required
def delete(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_oauth_custom_client_params(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id,
)
return {"result": "success"}, 200
@ -285,7 +297,9 @@ class DatasourceAuthDefaultApi(Resource):
@login_required
@account_initialization_required
def post(self, provider_id: str):
if not current_user.is_editor:
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
@ -293,7 +307,7 @@ class DatasourceAuthDefaultApi(Resource):
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.set_default_datasource_provider(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id,
credential_id=args["id"],
)
@ -306,7 +320,9 @@ class DatasourceUpdateProviderNameApi(Resource):
@login_required
@account_initialization_required
def post(self, provider_id: str):
if not current_user.is_editor:
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
@ -315,7 +331,7 @@ class DatasourceUpdateProviderNameApi(Resource):
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_provider_name(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id,
name=args["name"],
credential_id=args["credential_id"],

View File

@ -1,4 +1,3 @@
from flask_login import current_user
from flask_restx import Resource, marshal, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
@ -13,7 +12,7 @@ from controllers.console.wraps import (
)
from extensions.ext_database import db
from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.dataset import DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
@ -38,7 +37,7 @@ class CreateRagPipelineDatasetApi(Resource):
)
args = parser.parse_args()
current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
@ -58,12 +57,12 @@ class CreateRagPipelineDatasetApi(Resource):
with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
)
if rag_pipeline_dataset_create_entity.permission == "partial_members":
DatasetPermissionService.update_partial_member_list(
current_user.current_tenant_id,
current_tenant_id,
import_info["dataset_id"],
rag_pipeline_dataset_create_entity.partial_member_list,
)
@ -81,10 +80,12 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_dataset_editor:
raise Forbidden()
dataset = DatasetService.create_empty_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(
name="",
description="",

View File

@ -12,8 +12,8 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
from extensions.ext_database import db
from fields.installed_app_fields import installed_app_list_fields
from libs.datetime_utils import naive_utc_now
from libs.login import current_user, login_required
from models import Account, App, InstalledApp, RecommendedApp
from libs.login import current_account_with_tenant, login_required
from models import App, InstalledApp, RecommendedApp
from services.account_service import TenantService
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
@ -29,9 +29,7 @@ class InstalledAppsListApi(Resource):
@marshal_with(installed_app_list_fields)
def get(self):
app_id = request.args.get("app_id", default=None, type=str)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
current_tenant_id = current_user.current_tenant_id
current_user, current_tenant_id = current_account_with_tenant()
if app_id:
installed_apps = db.session.scalars(
@ -121,9 +119,8 @@ class InstalledAppsListApi(Resource):
if recommended_app is None:
raise NotFound("App not found")
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
current_tenant_id = current_user.current_tenant_id
_, current_tenant_id = current_account_with_tenant()
app = db.session.query(App).where(App.id == args["app_id"]).first()
if app is None:
@ -163,9 +160,8 @@ class InstalledAppApi(InstalledAppResource):
"""
def delete(self, installed_app):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
_, current_tenant_id = current_account_with_tenant()
if installed_app.app_owner_tenant_id == current_tenant_id:
raise BadRequest("You can't uninstall an app owned by the current tenant")
db.session.delete(installed_app)

View File

@ -4,7 +4,7 @@ from constants import HIDDEN_VALUE
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from fields.api_based_extension_fields import api_based_extension_fields
from libs.login import current_user, login_required
from libs.login import current_account_with_tenant, current_user, login_required
from models.account import Account
from models.api_based_extension import APIBasedExtension
from services.api_based_extension_service import APIBasedExtensionService
@ -47,9 +47,7 @@ class APIBasedExtensionAPI(Resource):
@account_initialization_required
@marshal_with(api_based_extension_fields)
def get(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
@api.doc("create_api_based_extension")
@ -77,9 +75,10 @@ class APIBasedExtensionAPI(Resource):
parser.add_argument("api_endpoint", type=str, required=True, location="json")
parser.add_argument("api_key", type=str, required=True, location="json")
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
extension_data = APIBasedExtension(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
name=args["name"],
api_endpoint=args["api_endpoint"],
api_key=args["api_key"],
@ -102,7 +101,7 @@ class APIBasedExtensionDetailAPI(Resource):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
@ -128,9 +127,9 @@ class APIBasedExtensionDetailAPI(Resource):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id
_, current_tenant_id = current_account_with_tenant()
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
@ -157,9 +156,9 @@ class APIBasedExtensionDetailAPI(Resource):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id
_, current_tenant_id = current_account_with_tenant()
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
APIBasedExtensionService.delete(extension_data_from_db)

View File

@ -1,7 +1,6 @@
from flask_restx import Resource, fields
from libs.login import current_user, login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from services.feature_service import FeatureService
from . import api, console_ns
@ -23,9 +22,9 @@ class FeatureApi(Resource):
@cloud_utm_record
def get(self):
"""Get feature configuration for current tenant"""
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
return FeatureService.get_features(current_user.current_tenant_id).model_dump()
_, current_tenant_id = current_account_with_tenant()
return FeatureService.get_features(current_tenant_id).model_dump()
@console_ns.route("/system-features")

View File

@ -14,8 +14,7 @@ from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from extensions.ext_database import db
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
from libs.login import current_user
from models.account import Account
from libs.login import current_account_with_tenant
from services.file_service import FileService
from . import console_ns
@ -64,8 +63,7 @@ class RemoteFileUploadApi(Resource):
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
try:
assert isinstance(current_user, Account)
user = current_user
user, _ = current_account_with_tenant()
upload_file = FileService(db.engine).upload_file(
filename=file_info.filename,
content=content,

View File

@ -5,18 +5,10 @@ from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginPermissionDeniedError
from libs.login import current_user, login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from services.plugin.endpoint_service import EndpointService
def _current_account_with_tenant() -> tuple[Account, str]:
assert isinstance(current_user, Account)
tenant_id = current_user.current_tenant_id
assert tenant_id is not None
return current_user, tenant_id
@console_ns.route("/workspaces/current/endpoints/create")
class EndpointCreateApi(Resource):
@api.doc("create_endpoint")
@ -41,7 +33,7 @@ class EndpointCreateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user, tenant_id = _current_account_with_tenant()
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
@ -87,7 +79,7 @@ class EndpointListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user, tenant_id = _current_account_with_tenant()
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
@ -130,7 +122,7 @@ class EndpointListForSinglePluginApi(Resource):
@login_required
@account_initialization_required
def get(self):
user, tenant_id = _current_account_with_tenant()
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
@ -172,7 +164,7 @@ class EndpointDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
user, tenant_id = _current_account_with_tenant()
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
@ -212,7 +204,7 @@ class EndpointUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user, tenant_id = _current_account_with_tenant()
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
@ -255,7 +247,7 @@ class EndpointEnableApi(Resource):
@login_required
@account_initialization_required
def post(self):
user, tenant_id = _current_account_with_tenant()
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
@ -288,7 +280,7 @@ class EndpointDisableApi(Resource):
@login_required
@account_initialization_required
def post(self):
user, tenant_id = _current_account_with_tenant()
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)

View File

@ -25,7 +25,7 @@ from controllers.console.wraps import (
from extensions.ext_database import db
from fields.member_fields import account_with_role_list_fields
from libs.helper import extract_remote_ip
from libs.login import current_user, login_required
from libs.login import current_account_with_tenant, login_required
from models.account import Account, TenantAccountRole
from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountAlreadyInTenantError
@ -41,8 +41,7 @@ class MemberListApi(Resource):
@account_initialization_required
@marshal_with(account_with_role_list_fields)
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_tenant_members(current_user.current_tenant)
@ -69,9 +68,7 @@ class MemberInviteEmailApi(Resource):
interface_language = args["language"]
if not TenantAccountRole.is_non_owner_role(invitee_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
inviter = current_user
if not inviter.current_tenant:
raise ValueError("No current tenant")
@ -120,8 +117,7 @@ class MemberCancelInviteApi(Resource):
@login_required
@account_initialization_required
def delete(self, member_id):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
member = db.session.query(Account).where(Account.id == str(member_id)).first()
@ -160,9 +156,7 @@ class MemberUpdateRoleApi(Resource):
if not TenantAccountRole.is_valid_role(new_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
member = db.session.get(Account, str(member_id))
@ -189,8 +183,7 @@ class DatasetOperatorMemberListApi(Resource):
@account_initialization_required
@marshal_with(account_with_role_list_fields)
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
@ -212,10 +205,8 @@ class SendOwnerTransferEmailApi(Resource):
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
current_user, _ = current_account_with_tenant()
# check if the current user is the owner of the workspace
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):
@ -250,8 +241,7 @@ class OwnerTransferCheckApi(Resource):
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
# check if the current user is the owner of the workspace
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):
@ -296,8 +286,7 @@ class OwnerTransfer(Resource):
args = parser.parse_args()
# check if the current user is the owner of the workspace
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):

View File

@ -1,7 +1,6 @@
import io
from flask import send_file
from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
@ -11,8 +10,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import StrLen, uuid_value
from libs.login import login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService
@ -23,11 +21,8 @@ class ModelProviderListApi(Resource):
@login_required
@account_initialization_required
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument(
@ -52,11 +47,8 @@ class ModelProviderCredentialApi(Resource):
@login_required
@account_initialization_required
def get(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
# if credential_id is not provided, return current used credential
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
@ -73,8 +65,7 @@ class ModelProviderCredentialApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
@ -85,11 +76,9 @@ class ModelProviderCredentialApi(Resource):
model_provider_service = ModelProviderService()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
try:
model_provider_service.create_provider_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
credentials=args["credentials"],
credential_name=args["name"],
@ -103,8 +92,7 @@ class ModelProviderCredentialApi(Resource):
@login_required
@account_initialization_required
def put(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
@ -116,11 +104,9 @@ class ModelProviderCredentialApi(Resource):
model_provider_service = ModelProviderService()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
try:
model_provider_service.update_provider_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
credentials=args["credentials"],
credential_id=args["credential_id"],
@ -135,19 +121,16 @@ class ModelProviderCredentialApi(Resource):
@login_required
@account_initialization_required
def delete(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
model_provider_service = ModelProviderService()
model_provider_service.remove_provider_credential(
tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"]
tenant_id=current_tenant_id, provider=provider, credential_id=args["credential_id"]
)
return {"result": "success"}, 204
@ -159,19 +142,16 @@ class ModelProviderCredentialSwitchApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
service = ModelProviderService()
service.switch_active_provider_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
credential_id=args["credential_id"],
)
@ -184,15 +164,12 @@ class ModelProviderValidateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
tenant_id = current_tenant_id
model_provider_service = ModelProviderService()
@ -240,14 +217,11 @@ class PreferredProviderTypeUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
tenant_id = current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument(
@ -276,14 +250,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
def get(self, provider: str):
if provider != "anthropic":
raise ValueError(f"provider name {provider} is invalid")
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
BillingService.is_tenant_owner_or_admin(current_user)
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
data = BillingService.get_model_provider_payment_link(
provider_name=provider,
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
account_id=current_user.id,
prefilled_email=current_user.email,
)

View File

@ -1,6 +1,5 @@
import logging
from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
@ -10,7 +9,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import StrLen, uuid_value
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from services.model_load_balancing_service import ModelLoadBalancingService
from services.model_provider_service import ModelProviderService
@ -23,6 +22,8 @@ class DefaultModelApi(Resource):
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument(
"model_type",
@ -34,8 +35,6 @@ class DefaultModelApi(Resource):
)
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
default_model_entity = model_provider_service.get_default_model_of_model_type(
tenant_id=tenant_id, model_type=args["model_type"]
@ -47,15 +46,14 @@ class DefaultModelApi(Resource):
@login_required
@account_initialization_required
def post(self):
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json")
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
model_settings = args["model_settings"]
for model_setting in model_settings:
@ -92,7 +90,7 @@ class ModelProviderModelApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
@ -104,11 +102,11 @@ class ModelProviderModelApi(Resource):
@account_initialization_required
def post(self, provider: str):
# To save the model's load balance configs
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument(
@ -129,7 +127,7 @@ class ModelProviderModelApi(Resource):
raise ValueError("credential_id is required when configuring a custom-model")
service = ModelProviderService()
service.switch_active_custom_model_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
@ -164,11 +162,11 @@ class ModelProviderModelApi(Resource):
@login_required
@account_initialization_required
def delete(self, provider: str):
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument(
@ -195,7 +193,7 @@ class ModelProviderModelCredentialApi(Resource):
@login_required
@account_initialization_required
def get(self, provider: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="args")
@ -257,6 +255,8 @@ class ModelProviderModelCredentialApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
@ -274,7 +274,6 @@ class ModelProviderModelCredentialApi(Resource):
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
try:
@ -301,6 +300,8 @@ class ModelProviderModelCredentialApi(Resource):
@login_required
@account_initialization_required
def put(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
@ -323,7 +324,7 @@ class ModelProviderModelCredentialApi(Resource):
try:
model_provider_service.update_model_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
@ -340,6 +341,8 @@ class ModelProviderModelCredentialApi(Resource):
@login_required
@account_initialization_required
def delete(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
@ -357,7 +360,7 @@ class ModelProviderModelCredentialApi(Resource):
model_provider_service = ModelProviderService()
model_provider_service.remove_model_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
@ -373,6 +376,8 @@ class ModelProviderModelCredentialSwitchApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
@ -390,7 +395,7 @@ class ModelProviderModelCredentialSwitchApi(Resource):
service = ModelProviderService()
service.add_model_credential_to_model_list(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
@ -407,7 +412,7 @@ class ModelProviderModelEnableApi(Resource):
@login_required
@account_initialization_required
def patch(self, provider: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
@ -437,7 +442,7 @@ class ModelProviderModelDisableApi(Resource):
@login_required
@account_initialization_required
def patch(self, provider: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
@ -465,7 +470,7 @@ class ModelProviderModelValidateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
@ -514,8 +519,7 @@ class ModelProviderModelParameterRuleApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService()
parameter_rules = model_provider_service.get_model_parameter_rules(
@ -531,8 +535,7 @@ class ModelProviderAvailableModelApi(Resource):
@login_required
@account_initialization_required
def get(self, model_type):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)

View File

@ -1,7 +1,6 @@
import io
from flask import request, send_file
from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
@ -11,7 +10,7 @@ from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginDaemonClientSideError
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
from services.plugin.plugin_parameter_service import PluginParameterService
@ -26,7 +25,7 @@ class PluginDebuggingKeyApi(Resource):
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return {
@ -44,7 +43,7 @@ class PluginListApi(Resource):
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=False, location="args", default=1)
parser.add_argument("page_size", type=int, required=False, location="args", default=256)
@ -81,7 +80,7 @@ class PluginListInstallationsFromIdsApi(Resource):
@login_required
@account_initialization_required
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plugin_ids", type=list, required=True, location="json")
@ -120,7 +119,7 @@ class PluginUploadFromPkgApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
file = request.files["pkg"]
@ -144,7 +143,7 @@ class PluginUploadFromGithubApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("repo", type=str, required=True, location="json")
@ -167,7 +166,7 @@ class PluginUploadFromBundleApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
file = request.files["bundle"]
@ -191,7 +190,7 @@ class PluginInstallFromPkgApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
@ -217,7 +216,7 @@ class PluginInstallFromGithubApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("repo", type=str, required=True, location="json")
@ -247,7 +246,7 @@ class PluginInstallFromMarketplaceApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
@ -273,7 +272,7 @@ class PluginFetchMarketplacePkgApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
@ -299,7 +298,7 @@ class PluginFetchManifestApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
@ -324,7 +323,7 @@ class PluginFetchInstallTasksApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
@ -346,7 +345,7 @@ class PluginFetchInstallTaskApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self, task_id: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)})
@ -361,7 +360,7 @@ class PluginDeleteInstallTaskApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self, task_id: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return {"success": PluginService.delete_install_task(tenant_id, task_id)}
@ -376,7 +375,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return {"success": PluginService.delete_all_install_task_items(tenant_id)}
@ -391,7 +390,7 @@ class PluginDeleteInstallTaskItemApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self, task_id: str, identifier: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)}
@ -406,7 +405,7 @@ class PluginUpgradeFromMarketplaceApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
@ -430,7 +429,7 @@ class PluginUpgradeFromGithubApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
@ -466,7 +465,7 @@ class PluginUninstallApi(Resource):
req.add_argument("plugin_installation_id", type=str, required=True, location="json")
args = req.parse_args()
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
@ -480,6 +479,7 @@ class PluginChangePermissionApi(Resource):
@login_required
@account_initialization_required
def post(self):
current_user, current_tenant_id = current_account_with_tenant()
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
@ -492,7 +492,7 @@ class PluginChangePermissionApi(Resource):
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
tenant_id = user.current_tenant_id
tenant_id = current_tenant_id
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
@ -503,7 +503,7 @@ class PluginFetchPermissionApi(Resource):
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
permission = PluginPermissionService.get_permission(tenant_id)
if not permission:
@ -529,10 +529,10 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
@account_initialization_required
def get(self):
# check if the user is admin or owner
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
tenant_id = current_user.current_tenant_id
user_id = current_user.id
parser = reqparse.RequestParser()
@ -565,7 +565,7 @@ class PluginChangePreferencesApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
@ -574,8 +574,6 @@ class PluginChangePreferencesApi(Resource):
req.add_argument("auto_upgrade", type=dict, required=True, location="json")
args = req.parse_args()
tenant_id = user.current_tenant_id
permission = args["permission"]
install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone"))
@ -621,7 +619,7 @@ class PluginFetchPreferencesApi(Resource):
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
permission = PluginPermissionService.get_permission(tenant_id)
permission_dict = {
@ -661,7 +659,7 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
@account_initialization_required
def post(self):
# exclude one single plugin
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
req = reqparse.RequestParser()
req.add_argument("plugin_id", type=str, required=True, location="json")

View File

@ -2,11 +2,11 @@ import io
from urllib.parse import urlparse
from flask import make_response, redirect, request, send_file
from flask_login import current_user
from flask_restx import (
Resource,
reqparse,
)
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from configs import dify_config
@ -16,15 +16,16 @@ from controllers.console.wraps import (
enterprise_license_required,
setup_required,
)
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import CredentialType
from extensions.ext_database import db
from libs.helper import StrLen, alphanumeric, uuid_value
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.provider_ids import ToolProviderID
from services.plugin.oauth_service import OAuthProxyService
from services.tools.api_tools_manage_service import ApiToolManageService
@ -43,7 +44,9 @@ def is_valid_url(url: str) -> bool:
try:
parsed = urlparse(url)
return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"]
except Exception:
except (ValueError, TypeError):
# ValueError: Invalid URL format
# TypeError: url is not a string
return False
@ -53,10 +56,9 @@ class ToolProviderListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
req = reqparse.RequestParser()
req.add_argument(
@ -78,9 +80,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
user = current_user
tenant_id = user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.list_builtin_tool_provider_tools(
@ -96,9 +96,7 @@ class ToolBuiltinProviderInfoApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
user = current_user
tenant_id = user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
@ -109,11 +107,10 @@ class ToolBuiltinProviderDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self, provider):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
req = reqparse.RequestParser()
req.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = req.parse_args()
@ -131,10 +128,9 @@ class ToolBuiltinProviderAddApi(Resource):
@login_required
@account_initialization_required
def post(self, provider):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
@ -161,13 +157,12 @@ class ToolBuiltinProviderUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
@ -193,7 +188,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credentials(
@ -218,13 +213,12 @@ class ToolApiProviderAddApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
@ -258,10 +252,9 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
@ -282,10 +275,9 @@ class ToolApiProviderListToolsApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
@ -308,13 +300,12 @@ class ToolApiProviderUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
@ -350,13 +341,12 @@ class ToolApiProviderDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
@ -377,10 +367,9 @@ class ToolApiProviderGetApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
@ -401,8 +390,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@login_required
@account_initialization_required
def get(self, provider, credential_type):
user = current_user
tenant_id = user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.list_builtin_provider_credentials_schema(
@ -444,9 +432,9 @@ class ToolApiProviderPreviousTestApi(Resource):
parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
return ApiToolManageService.test_api_tool_preview(
current_user.current_tenant_id,
current_tenant_id,
args["provider_name"] or "",
args["tool_name"],
args["credentials"],
@ -462,13 +450,12 @@ class ToolWorkflowProviderCreateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
@ -502,13 +489,12 @@ class ToolWorkflowProviderUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
@ -545,13 +531,12 @@ class ToolWorkflowProviderDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
@ -571,10 +556,9 @@ class ToolWorkflowProviderGetApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
@ -606,10 +590,9 @@ class ToolWorkflowProviderListToolApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args")
@ -631,10 +614,9 @@ class ToolBuiltinListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(
[
@ -653,8 +635,7 @@ class ToolApiListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
tenant_id = user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
[
@ -672,10 +653,9 @@ class ToolWorkflowListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(
[
@ -709,19 +689,18 @@ class ToolPluginOAuthApi(Resource):
provider_name = tool_provider.provider_name
# todo check permission
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider)
if oauth_client_params is None:
raise Forbidden("no oauth available client config found for this tool provider")
oauth_handler = OAuthHandler()
context_id = OAuthProxyService.create_proxy_context(
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
user_id=user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
)
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
authorization_url_response = oauth_handler.get_authorization_url(
@ -800,11 +779,12 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
@login_required
@account_initialization_required
def post(self, provider):
current_user, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
return BuiltinToolManageService.set_default_provider(
tenant_id=current_user.current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
)
@ -819,13 +799,13 @@ class ToolOAuthCustomClient(Resource):
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
args = parser.parse_args()
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
return BuiltinToolManageService.save_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
provider=provider,
client_params=args.get("client_params", {}),
enable_oauth_custom_client=args.get("enable_oauth_custom_client", True),
@ -835,20 +815,18 @@ class ToolOAuthCustomClient(Resource):
@login_required
@account_initialization_required
def get(self, provider):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_custom_oauth_client_params(
tenant_id=current_user.current_tenant_id, provider=provider
)
BuiltinToolManageService.get_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
)
@setup_required
@login_required
@account_initialization_required
def delete(self, provider):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.delete_custom_oauth_client_params(
tenant_id=current_user.current_tenant_id, provider=provider
)
BuiltinToolManageService.delete_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
)
@ -858,9 +836,10 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
tenant_id=current_user.current_tenant_id, provider_name=provider
tenant_id=current_tenant_id, provider_name=provider
)
)
@ -871,7 +850,7 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credential_info(
@ -894,30 +873,37 @@ class ToolProviderMCPApi(Resource):
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
parser.add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30)
parser.add_argument(
"sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300
)
parser.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
parser.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
args = parser.parse_args()
user = current_user
user, tenant_id = current_account_with_tenant()
# Validate server URL
if not is_valid_url(args["server_url"]):
raise ValueError("Server URL is not valid.")
return jsonable_encoder(
MCPToolManageService.create_mcp_provider(
tenant_id=user.current_tenant_id,
# Parse and validate models
configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
# Create provider
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
result = service.create_provider(
tenant_id=tenant_id,
user_id=user.id,
server_url=args["server_url"],
name=args["name"],
icon=args["icon"],
icon_type=args["icon_type"],
icon_background=args["icon_background"],
user_id=user.id,
server_identifier=args["server_identifier"],
timeout=args["timeout"],
sse_read_timeout=args["sse_read_timeout"],
headers=args["headers"],
configuration=configuration,
authentication=authentication,
)
)
return jsonable_encoder(result)
@setup_required
@login_required
@ -931,29 +917,35 @@ class ToolProviderMCPApi(Resource):
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
parser.add_argument("timeout", type=float, required=False, nullable=True, location="json")
parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json")
parser.add_argument("configuration", type=dict, required=False, nullable=True, location="json")
parser.add_argument("authentication", type=dict, required=False, nullable=True, location="json")
args = parser.parse_args()
if not is_valid_url(args["server_url"]):
if "[__HIDDEN__]" in args["server_url"]:
pass
else:
raise ValueError("Server URL is not valid.")
MCPToolManageService.update_mcp_provider(
tenant_id=current_user.current_tenant_id,
provider_id=args["provider_id"],
server_url=args["server_url"],
name=args["name"],
icon=args["icon"],
icon_type=args["icon_type"],
icon_background=args["icon_background"],
server_identifier=args["server_identifier"],
timeout=args.get("timeout"),
sse_read_timeout=args.get("sse_read_timeout"),
headers=args.get("headers"),
)
return {"result": "success"}
configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
_, current_tenant_id = current_account_with_tenant()
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.update_provider(
tenant_id=current_tenant_id,
provider_id=args["provider_id"],
server_url=args["server_url"],
name=args["name"],
icon=args["icon"],
icon_type=args["icon_type"],
icon_background=args["icon_background"],
server_identifier=args["server_identifier"],
headers=args["headers"],
configuration=configuration,
authentication=authentication,
)
return {"result": "success"}
@setup_required
@login_required
@ -962,8 +954,12 @@ class ToolProviderMCPApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
MCPToolManageService.delete_mcp_tool(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"])
return {"result": "success"}
_, current_tenant_id = current_account_with_tenant()
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])
return {"result": "success"}
@console_ns.route("/workspaces/current/tool-provider/mcp/auth")
@ -977,39 +973,44 @@ class ToolMCPAuthApi(Resource):
parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
provider_id = args["provider_id"]
tenant_id = current_user.current_tenant_id
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if not provider:
raise ValueError("provider not found")
try:
with MCPClient(
provider.decrypted_server_url,
provider_id,
tenant_id,
authed=False,
authorization_code=args["authorization_code"],
for_list=True,
headers=provider.decrypted_headers,
timeout=provider.timeout,
sse_read_timeout=provider.sse_read_timeout,
):
MCPToolManageService.update_mcp_provider_credentials(
mcp_provider=provider,
credentials=provider.decrypted_credentials,
authed=True,
)
return {"result": "success"}
_, tenant_id = current_account_with_tenant()
except MCPAuthError:
auth_provider = OAuthClientProvider(provider_id, tenant_id, for_list=True)
return auth(auth_provider, provider.decrypted_server_url, args["authorization_code"])
except MCPError as e:
MCPToolManageService.update_mcp_provider_credentials(
mcp_provider=provider,
credentials={},
authed=False,
)
raise ValueError(f"Failed to connect to MCP server: {e}") from e
with Session(db.engine) as session:
with session.begin():
service = MCPToolManageService(session=session)
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
if not db_provider:
raise ValueError("provider not found")
# Convert to entity
provider_entity = db_provider.to_entity()
server_url = provider_entity.decrypt_server_url()
headers = provider_entity.decrypt_authentication()
# Try to connect without active transaction
try:
# Use MCPClientWithAuthRetry to handle authentication automatically
with MCPClient(
server_url=server_url,
headers=headers,
timeout=provider_entity.timeout,
sse_read_timeout=provider_entity.sse_read_timeout,
):
# Create new transaction for update
with session.begin():
service.update_provider_credentials(
provider=db_provider,
credentials=provider_entity.credentials,
authed=True,
)
return {"result": "success"}
except MCPAuthError as e:
service = MCPToolManageService(session=session)
return auth(provider_entity, service, args.get("authorization_code"))
except MCPError as e:
with session.begin():
service.clear_provider_credentials(provider=db_provider)
raise ValueError(f"Failed to connect to MCP server: {e}") from e
@console_ns.route("/workspaces/current/tool-provider/mcp/tools/<path:provider_id>")
@ -1018,9 +1019,11 @@ class ToolMCPDetailApi(Resource):
@login_required
@account_initialization_required
def get(self, provider_id):
user = current_user
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, user.current_tenant_id)
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
_, tenant_id = current_account_with_tenant()
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
@console_ns.route("/workspaces/current/tools/mcp")
@ -1029,12 +1032,13 @@ class ToolMCPListAllApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
tenant_id = user.current_tenant_id
_, tenant_id = current_account_with_tenant()
tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id)
with Session(db.engine, expire_on_commit=False) as session:
service = MCPToolManageService(session=session)
tools = service.list_providers(tenant_id=tenant_id)
return [tool.to_dict() for tool in tools]
return [tool.to_dict() for tool in tools]
@console_ns.route("/workspaces/current/tool-provider/mcp/update/<path:provider_id>")
@ -1043,12 +1047,14 @@ class ToolMCPUpdateApi(Resource):
@login_required
@account_initialization_required
def get(self, provider_id):
tenant_id = current_user.current_tenant_id
tools = MCPToolManageService.list_mcp_tool_from_remote_server(
tenant_id=tenant_id,
provider_id=provider_id,
)
return jsonable_encoder(tools)
_, tenant_id = current_account_with_tenant()
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
tools = service.list_provider_tools(
tenant_id=tenant_id,
provider_id=provider_id,
)
return jsonable_encoder(tools)
@console_ns.route("/mcp/oauth/callback")
@ -1060,5 +1066,10 @@ class ToolMCPCallbackApi(Resource):
args = parser.parse_args()
state_key = args["state"]
authorization_code = args["code"]
handle_callback(state_key, authorization_code)
# Create service instance for handle_callback
with Session(db.engine) as session, session.begin():
mcp_service = MCPToolManageService(session=session)
handle_callback(state_key, authorization_code, mcp_service)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")

View File

@ -12,8 +12,8 @@ from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.login import current_user
from models.account import Account, AccountStatus
from libs.login import current_account_with_tenant
from models.account import AccountStatus
from models.dataset import RateLimitLog
from models.model import DifySetup
from services.feature_service import FeatureService, LicenseStatus
@ -25,16 +25,13 @@ P = ParamSpec("P")
R = TypeVar("R")
def _current_account() -> Account:
assert isinstance(current_user, Account)
return current_user
def account_initialization_required(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
# check account initialization
account = _current_account()
current_user, _ = current_account_with_tenant()
account = current_user
if account.status == AccountStatus.UNINITIALIZED:
raise AccountNotInitializedError()
@ -80,9 +77,8 @@ def only_edition_self_hosted(view: Callable[P, R]):
def cloud_edition_billing_enabled(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
account = _current_account()
assert account.current_tenant_id is not None
features = FeatureService.get_features(account.current_tenant_id)
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if not features.billing.enabled:
abort(403, "Billing feature is not enabled.")
return view(*args, **kwargs)
@ -94,10 +90,8 @@ def cloud_edition_billing_resource_check(resource: str):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
account = _current_account()
assert account.current_tenant_id is not None
tenant_id = account.current_tenant_id
features = FeatureService.get_features(tenant_id)
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if features.billing.enabled:
members = features.members
apps = features.apps
@ -138,9 +132,8 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
account = _current_account()
assert account.current_tenant_id is not None
features = FeatureService.get_features(account.current_tenant_id)
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if features.billing.enabled:
if resource == "add_segment":
if features.billing.subscription.plan == "sandbox":
@ -163,13 +156,11 @@ def cloud_edition_billing_rate_limit_check(resource: str):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
if resource == "knowledge":
account = _current_account()
assert account.current_tenant_id is not None
tenant_id = account.current_tenant_id
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(tenant_id)
_, current_tenant_id = current_account_with_tenant()
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{tenant_id}"
key = f"rate_limit_{current_tenant_id}"
redis_client.zadd(key, {current_time: current_time})
@ -180,7 +171,7 @@ def cloud_edition_billing_rate_limit_check(resource: str):
if request_count > knowledge_rate_limit.limit:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=tenant_id,
tenant_id=current_tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
@ -200,17 +191,15 @@ def cloud_utm_record(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
with contextlib.suppress(Exception):
account = _current_account()
assert account.current_tenant_id is not None
tenant_id = account.current_tenant_id
features = FeatureService.get_features(tenant_id)
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if features.billing.enabled:
utm_info = request.cookies.get("utm_info")
if utm_info:
utm_info_dict: dict = json.loads(utm_info)
OperationService.record_utm(tenant_id, utm_info_dict)
OperationService.record_utm(current_tenant_id, utm_info_dict)
return view(*args, **kwargs)
@ -289,9 +278,8 @@ def enable_change_email(view: Callable[P, R]):
def is_allow_transfer_owner(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
account = _current_account()
assert account.current_tenant_id is not None
features = FeatureService.get_features(account.current_tenant_id)
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if features.is_allow_transfer_workspace:
return view(*args, **kwargs)
@ -301,12 +289,11 @@ def is_allow_transfer_owner(view: Callable[P, R]):
return decorated
def knowledge_pipeline_publish_enabled(view):
def knowledge_pipeline_publish_enabled(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
account = _current_account()
assert account.current_tenant_id is not None
features = FeatureService.get_features(account.current_tenant_id)
def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if features.knowledge_pipeline.publish_enabled:
return view(*args, **kwargs)
abort(403)

View File

@ -1,5 +1,4 @@
from flask import request
from flask_login import current_user
from flask_restx import marshal, reqparse
from werkzeug.exceptions import NotFound
@ -16,6 +15,7 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from fields.segment_fields import child_chunk_fields, segment_fields
from libs.login import current_account_with_tenant
from models.dataset import Dataset
from services.dataset_service import DatasetService, DocumentService, SegmentService
from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs
@ -66,6 +66,7 @@ class SegmentApi(DatasetApiResource):
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: str, document_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Create single segment."""
# check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
@ -84,7 +85,7 @@ class SegmentApi(DatasetApiResource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
@ -117,6 +118,7 @@ class SegmentApi(DatasetApiResource):
}
)
def get(self, tenant_id: str, dataset_id: str, document_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Get segments."""
# check dataset
page = request.args.get("page", default=1, type=int)
@ -133,7 +135,7 @@ class SegmentApi(DatasetApiResource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
@ -149,7 +151,7 @@ class SegmentApi(DatasetApiResource):
segments, total = SegmentService.get_segments(
document_id=document_id,
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
status_list=args["status"],
keyword=args["keyword"],
page=page,
@ -184,6 +186,7 @@ class DatasetSegmentApi(DatasetApiResource):
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
_, current_tenant_id = current_account_with_tenant()
# check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
@ -195,7 +198,7 @@ class DatasetSegmentApi(DatasetApiResource):
if not document:
raise NotFound("Document not found.")
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
SegmentService.delete_segment(segment, document, dataset)
@ -217,6 +220,7 @@ class DatasetSegmentApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
_, current_tenant_id = current_account_with_tenant()
# check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
@ -232,7 +236,7 @@ class DatasetSegmentApi(DatasetApiResource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
@ -244,7 +248,7 @@ class DatasetSegmentApi(DatasetApiResource):
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
@ -266,6 +270,7 @@ class DatasetSegmentApi(DatasetApiResource):
}
)
def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
_, current_tenant_id = current_account_with_tenant()
# check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
@ -277,7 +282,7 @@ class DatasetSegmentApi(DatasetApiResource):
if not document:
raise NotFound("Document not found.")
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
@ -307,6 +312,7 @@ class ChildChunkApi(DatasetApiResource):
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Create child chunk."""
# check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
@ -319,7 +325,7 @@ class ChildChunkApi(DatasetApiResource):
raise NotFound("Document not found.")
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
@ -328,7 +334,7 @@ class ChildChunkApi(DatasetApiResource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
@ -364,6 +370,7 @@ class ChildChunkApi(DatasetApiResource):
}
)
def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Get child chunks."""
# check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
@ -376,7 +383,7 @@ class ChildChunkApi(DatasetApiResource):
raise NotFound("Document not found.")
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
@ -423,6 +430,7 @@ class DatasetChildChunkApi(DatasetApiResource):
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Delete child chunk."""
# check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
@ -435,7 +443,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Document not found.")
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
@ -444,9 +452,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Document not found.")
# check child chunk
child_chunk = SegmentService.get_child_chunk_by_id(
child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id
)
child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id)
if not child_chunk:
raise NotFound("Child chunk not found.")
@ -483,6 +489,7 @@ class DatasetChildChunkApi(DatasetApiResource):
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Update child chunk."""
# check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
@ -495,7 +502,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Document not found.")
# get segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
@ -504,9 +511,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Segment not found.")
# get child chunk
child_chunk = SegmentService.get_child_chunk_by_id(
child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id
)
child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id)
if not child_chunk:
raise NotFound("Child chunk not found.")

View File

@ -1,10 +1,12 @@
import logging
import queue
import threading
import time
from abc import abstractmethod
from enum import IntEnum, auto
from typing import Any
from cachetools import TTLCache, cachedmethod
from redis.exceptions import RedisError
from sqlalchemy.orm import DeclarativeMeta
@ -45,6 +47,8 @@ class AppQueueManager:
q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
self._q = q
self._stopped_cache: TTLCache[tuple, bool] = TTLCache(maxsize=1, ttl=1)
self._cache_lock = threading.Lock()
def listen(self):
"""
@ -157,6 +161,7 @@ class AppQueueManager:
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
redis_client.setex(stopped_cache_key, 600, 1)
@cachedmethod(lambda self: self._stopped_cache, lock=lambda self: self._cache_lock)
def _is_stopped(self) -> bool:
"""
Check if task is stopped

View File

@ -0,0 +1,341 @@
import json
from datetime import datetime
from enum import StrEnum
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
from pydantic import BaseModel
from configs import dify_config
from core.entities.provider_entities import BasicProviderConfig
from core.file import helpers as file_helpers
from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.encryption import create_provider_encrypter
if TYPE_CHECKING:
from models.tools import MCPToolProvider
# Constants
CLIENT_NAME = "Dify"
CLIENT_URI = "https://github.com/langgenius/dify"
DEFAULT_TOKEN_TYPE = "Bearer"
DEFAULT_EXPIRES_IN = 3600
MASK_CHAR = "*"
MIN_UNMASK_LENGTH = 6
class MCPSupportGrantType(StrEnum):
"""The supported grant types for MCP"""
AUTHORIZATION_CODE = "authorization_code"
CLIENT_CREDENTIALS = "client_credentials"
REFRESH_TOKEN = "refresh_token"
class MCPAuthentication(BaseModel):
client_id: str
client_secret: str | None = None
class MCPConfiguration(BaseModel):
timeout: float = 30
sse_read_timeout: float = 300
class MCPProviderEntity(BaseModel):
"""MCP Provider domain entity for business logic operations"""
# Basic identification
id: str
provider_id: str # server_identifier
name: str
tenant_id: str
user_id: str
# Server connection info
server_url: str # encrypted URL
headers: dict[str, str] # encrypted headers
timeout: float
sse_read_timeout: float
# Authentication related
authed: bool
credentials: dict[str, Any] # encrypted credentials
code_verifier: str | None = None # for OAuth
# Tools and display info
tools: list[dict[str, Any]] # parsed tools list
icon: str | dict[str, str] # parsed icon
# Timestamps
created_at: datetime
updated_at: datetime
@classmethod
def from_db_model(cls, db_provider: "MCPToolProvider") -> "MCPProviderEntity":
"""Create entity from database model with decryption"""
return cls(
id=db_provider.id,
provider_id=db_provider.server_identifier,
name=db_provider.name,
tenant_id=db_provider.tenant_id,
user_id=db_provider.user_id,
server_url=db_provider.server_url,
headers=db_provider.headers,
timeout=db_provider.timeout,
sse_read_timeout=db_provider.sse_read_timeout,
authed=db_provider.authed,
credentials=db_provider.credentials,
tools=db_provider.tool_dict,
icon=db_provider.icon or "",
created_at=db_provider.created_at,
updated_at=db_provider.updated_at,
)
@property
def redirect_url(self) -> str:
"""OAuth redirect URL"""
return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
@property
def client_metadata(self) -> OAuthClientMetadata:
"""Metadata about this OAuth client."""
# Get grant type from credentials
credentials = self.decrypt_credentials()
# Try to get grant_type from different locations
grant_type = credentials.get("grant_type", MCPSupportGrantType.AUTHORIZATION_CODE)
# For nested structure, check if client_information has grant_types
if "client_information" in credentials and isinstance(credentials["client_information"], dict):
client_info = credentials["client_information"]
# If grant_types is specified in client_information, use it to determine grant_type
if "grant_types" in client_info and isinstance(client_info["grant_types"], list):
if "client_credentials" in client_info["grant_types"]:
grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS
elif "authorization_code" in client_info["grant_types"]:
grant_type = MCPSupportGrantType.AUTHORIZATION_CODE
# Configure based on grant type
is_client_credentials = grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS
grant_types = ["refresh_token"]
grant_types.append("client_credentials" if is_client_credentials else "authorization_code")
response_types = [] if is_client_credentials else ["code"]
redirect_uris = [] if is_client_credentials else [self.redirect_url]
return OAuthClientMetadata(
redirect_uris=redirect_uris,
token_endpoint_auth_method="none",
grant_types=grant_types,
response_types=response_types,
client_name=CLIENT_NAME,
client_uri=CLIENT_URI,
)
@property
def provider_icon(self) -> dict[str, str] | str:
"""Get provider icon, handling both dict and string formats"""
if isinstance(self.icon, dict):
return self.icon
try:
return json.loads(self.icon)
except (json.JSONDecodeError, TypeError):
# If not JSON, assume it's a file path
return file_helpers.get_signed_file_url(self.icon)
def to_api_response(self, user_name: str | None = None) -> dict[str, Any]:
"""Convert to API response format"""
response = {
"id": self.id,
"author": user_name or "Anonymous",
"name": self.name,
"icon": self.provider_icon,
"type": ToolProviderType.MCP.value,
"is_team_authorization": self.authed,
"server_url": self.masked_server_url(),
"server_identifier": self.provider_id,
"updated_at": int(self.updated_at.timestamp()),
"label": I18nObject(en_US=self.name, zh_Hans=self.name).to_dict(),
"description": I18nObject(en_US="", zh_Hans="").to_dict(),
}
# Add configuration
response["configuration"] = {
"timeout": str(self.timeout),
"sse_read_timeout": str(self.sse_read_timeout),
}
# Add masked headers
response["masked_headers"] = self.masked_headers()
# Add authentication info if available
masked_creds = self.masked_credentials()
if masked_creds:
response["authentication"] = masked_creds
response["is_dynamic_registration"] = self.credentials.get("is_dynamic_registration", True)
return response
def retrieve_client_information(self) -> OAuthClientInformation | None:
"""OAuth client information if available"""
credentials = self.decrypt_credentials()
if not credentials:
return None
# Check if we have nested client_information structure
if "client_information" in credentials:
# Handle nested structure (Authorization Code flow)
client_info_data = credentials["client_information"]
if isinstance(client_info_data, dict):
return OAuthClientInformation.model_validate(client_info_data)
return None
# Handle flat structure (Client Credentials flow)
if "client_id" not in credentials:
return None
# Build client information from flat structure
client_info = {
"client_id": credentials.get("client_id", ""),
"client_secret": credentials.get("client_secret", ""),
"client_name": credentials.get("client_name", CLIENT_NAME),
}
# Parse JSON fields if they exist
json_fields = ["redirect_uris", "grant_types", "response_types"]
for field in json_fields:
if field in credentials:
client_info[field] = json.loads(credentials[field])
if "scope" in credentials:
client_info["scope"] = credentials["scope"]
return OAuthClientInformation.model_validate(client_info)
def retrieve_tokens(self) -> OAuthTokens | None:
"""OAuth tokens if available"""
if not self.credentials:
return None
credentials = self.decrypt_credentials()
return OAuthTokens(
access_token=credentials.get("access_token", ""),
token_type=credentials.get("token_type", DEFAULT_TOKEN_TYPE),
expires_in=int(credentials.get("expires_in", str(DEFAULT_EXPIRES_IN)) or DEFAULT_EXPIRES_IN),
refresh_token=credentials.get("refresh_token", ""),
)
def masked_server_url(self) -> str:
"""Masked server URL for display"""
parsed = urlparse(self.decrypt_server_url())
if parsed.path and parsed.path != "/":
masked = parsed._replace(path="/******")
return masked.geturl()
return parsed.geturl()
def _mask_value(self, value: str) -> str:
"""Mask a sensitive value for display"""
if len(value) > MIN_UNMASK_LENGTH:
return value[:2] + MASK_CHAR * (len(value) - 4) + value[-2:]
else:
return MASK_CHAR * len(value)
def masked_headers(self) -> dict[str, str]:
"""Masked headers for display"""
return {key: self._mask_value(value) for key, value in self.decrypt_headers().items()}
def masked_credentials(self) -> dict[str, str]:
"""Masked credentials for display"""
credentials = self.decrypt_credentials()
if not credentials:
return {}
masked = {}
# Check if we have nested client_information structure
if "client_information" in credentials and isinstance(credentials["client_information"], dict):
client_info = credentials["client_information"]
# Mask sensitive fields from nested structure
if client_info.get("client_id"):
masked["client_id"] = self._mask_value(client_info["client_id"])
if client_info.get("client_secret"):
masked["client_secret"] = self._mask_value(client_info["client_secret"])
else:
# Handle flat structure
# Mask sensitive fields
sensitive_fields = ["client_id", "client_secret"]
for field in sensitive_fields:
if credentials.get(field):
masked[field] = self._mask_value(credentials[field])
# Include non-sensitive fields (check both flat and nested structures)
if "grant_type" in credentials:
masked["grant_type"] = credentials["grant_type"]
return masked
def decrypt_server_url(self) -> str:
"""Decrypt server URL"""
return encrypter.decrypt_token(self.tenant_id, self.server_url)
def _decrypt_dict(self, data: dict[str, Any]) -> dict[str, Any]:
"""Generic method to decrypt dictionary fields"""
if not data:
return {}
# Only decrypt fields that are actually encrypted
# For nested structures, client_information is not encrypted as a whole
encrypted_fields = []
for key, value in data.items():
# Skip nested objects - they are not encrypted
if isinstance(value, dict):
continue
# Only process string values that might be encrypted
if isinstance(value, str) and value:
encrypted_fields.append(key)
if not encrypted_fields:
return data
# Create dynamic config only for encrypted fields
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in encrypted_fields]
encrypter_instance, _ = create_provider_encrypter(
tenant_id=self.tenant_id,
config=config,
cache=NoOpProviderCredentialCache(),
)
# Decrypt only the encrypted fields
decrypted_data = encrypter_instance.decrypt({k: data[k] for k in encrypted_fields})
# Merge decrypted data with original data (preserving non-encrypted fields)
result = data.copy()
result.update(decrypted_data)
return result
def decrypt_headers(self) -> dict[str, Any]:
"""Decrypt headers"""
return self._decrypt_dict(self.headers)
def decrypt_credentials(self) -> dict[str, Any]:
"""Decrypt credentials"""
return self._decrypt_dict(self.credentials)
def decrypt_authentication(self) -> dict[str, Any]:
"""Decrypt authentication"""
# Option 1: if headers is provided, use it and don't need to get token
headers = self.decrypt_headers()
# Option 2: Add OAuth token if authed and no headers provided
if not self.headers and self.authed:
token = self.retrieve_tokens()
if token:
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
return headers

View File

@ -472,6 +472,9 @@ class ProviderConfiguration(BaseModel):
provider_model_credentials_cache.delete()
self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
else:
# some historical data may have a provider record but not be set as valid
provider_record.is_valid = True
session.commit()
except Exception:

View File

@ -4,13 +4,16 @@ import json
import os
import secrets
import urllib.parse
from typing import TYPE_CHECKING
from urllib.parse import urljoin, urlparse
import httpx
from httpx import ConnectError, HTTPStatusError, RequestError
from pydantic import BaseModel, ValidationError
from core.mcp.auth.auth_provider import OAuthClientProvider
from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType
from core.helper import ssrf_proxy
from core.mcp.types import (
LATEST_PROTOCOL_VERSION,
OAuthClientInformation,
OAuthClientInformationFull,
OAuthClientMetadata,
@ -19,7 +22,9 @@ from core.mcp.types import (
)
from extensions.ext_redis import redis_client
LATEST_PROTOCOL_VERSION = "1.0"
if TYPE_CHECKING:
from services.tools.mcp_tools_manage_service import MCPToolManageService
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
@ -80,7 +85,7 @@ def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
raise ValueError(f"Invalid state parameter: {str(e)}")
def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState:
def handle_callback(state_key: str, authorization_code: str, mcp_service: "MCPToolManageService") -> OAuthCallbackState:
"""Handle the callback from the OAuth provider."""
# Retrieve state data from Redis (state is automatically deleted after retrieval)
full_state_data = _retrieve_redis_state(state_key)
@ -93,30 +98,35 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackSta
full_state_data.code_verifier,
full_state_data.redirect_uri,
)
provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id, for_list=True)
provider.save_tokens(tokens)
# Save tokens using the service layer
mcp_service.save_oauth_data(full_state_data.provider_id, full_state_data.tenant_id, tokens.model_dump(), "tokens")
return full_state_data
def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
"""Check if the server supports OAuth 2.0 Resource Discovery."""
b_scheme, b_netloc, b_path, _, b_query, b_fragment = urlparse(server_url, "", True)
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource{b_path}"
b_scheme, b_netloc, _, _, b_query, b_fragment = urlparse(server_url, "", True)
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource"
if b_query:
url_for_resource_discovery += f"?{b_query}"
if b_fragment:
url_for_resource_discovery += f"#{b_fragment}"
try:
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
response = httpx.get(url_for_resource_discovery, headers=headers)
response = ssrf_proxy.get(url_for_resource_discovery, headers=headers)
if 200 <= response.status_code < 300:
body = response.json()
if "authorization_server_url" in body:
# Support both singular and plural forms
if body.get("authorization_servers"):
return True, body["authorization_servers"][0]
elif body.get("authorization_server_url"):
return True, body["authorization_server_url"][0]
else:
return False, ""
return False, ""
except httpx.RequestError:
except RequestError:
# Not support resource discovery, fall back to well-known OAuth metadata
return False, ""
@ -126,27 +136,37 @@ def discover_oauth_metadata(server_url: str, protocol_version: str | None = None
# First check if the server supports OAuth 2.0 Resource Discovery
support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url)
if support_resource_discovery:
url = oauth_discovery_url
# The oauth_discovery_url is the authorization server base URL
# Try OpenID Connect discovery first (more common), then OAuth 2.0
urls_to_try = [
urljoin(oauth_discovery_url + "/", ".well-known/oauth-authorization-server"),
urljoin(oauth_discovery_url + "/", ".well-known/openid-configuration"),
]
else:
url = urljoin(server_url, "/.well-known/oauth-authorization-server")
urls_to_try = [urljoin(server_url, "/.well-known/oauth-authorization-server")]
try:
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
response = httpx.get(url, headers=headers)
if response.status_code == 404:
return None
if not response.is_success:
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
return OAuthMetadata.model_validate(response.json())
except httpx.RequestError as e:
if isinstance(e, httpx.ConnectError):
response = httpx.get(url)
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
for url in urls_to_try:
try:
response = ssrf_proxy.get(url, headers=headers)
if response.status_code == 404:
return None
continue
if not response.is_success:
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
response.raise_for_status()
return OAuthMetadata.model_validate(response.json())
raise
except (RequestError, HTTPStatusError) as e:
if isinstance(e, ConnectError):
response = ssrf_proxy.get(url)
if response.status_code == 404:
continue # Try next URL
if not response.is_success:
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
return OAuthMetadata.model_validate(response.json())
# For other errors, try next URL
continue
return None # No metadata found
def start_authorization(
@ -213,7 +233,7 @@ def exchange_authorization(
redirect_uri: str,
) -> OAuthTokens:
"""Exchanges an authorization code for an access token."""
grant_type = "authorization_code"
grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value
if metadata:
token_url = metadata.token_endpoint
@ -233,7 +253,7 @@ def exchange_authorization(
if client_information.client_secret:
params["client_secret"] = client_information.client_secret
response = httpx.post(token_url, data=params)
response = ssrf_proxy.post(token_url, data=params)
if not response.is_success:
raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
return OAuthTokens.model_validate(response.json())
@ -246,7 +266,7 @@ def refresh_authorization(
refresh_token: str,
) -> OAuthTokens:
"""Exchange a refresh token for an updated access token."""
grant_type = "refresh_token"
grant_type = MCPSupportGrantType.REFRESH_TOKEN.value
if metadata:
token_url = metadata.token_endpoint
@ -264,12 +284,55 @@ def refresh_authorization(
if client_information.client_secret:
params["client_secret"] = client_information.client_secret
response = httpx.post(token_url, data=params)
response = ssrf_proxy.post(token_url, data=params)
if not response.is_success:
raise ValueError(f"Token refresh failed: HTTP {response.status_code}")
return OAuthTokens.model_validate(response.json())
def client_credentials_flow(
server_url: str,
metadata: OAuthMetadata | None,
client_information: OAuthClientInformation,
scope: str | None = None,
) -> OAuthTokens:
"""Execute Client Credentials Flow to get access token."""
grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
if metadata:
token_url = metadata.token_endpoint
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
else:
token_url = urljoin(server_url, "/token")
# Support both Basic Auth and body parameters for client authentication
headers = {"Content-Type": "application/x-www-form-urlencoded"}
data = {"grant_type": grant_type}
if scope:
data["scope"] = scope
# If client_secret is provided, use Basic Auth (preferred method)
if client_information.client_secret:
credentials = f"{client_information.client_id}:{client_information.client_secret}"
encoded_credentials = base64.b64encode(credentials.encode()).decode()
headers["Authorization"] = f"Basic {encoded_credentials}"
else:
# Fall back to including credentials in the body
data["client_id"] = client_information.client_id
if client_information.client_secret:
data["client_secret"] = client_information.client_secret
response = ssrf_proxy.post(token_url, headers=headers, data=data)
if not response.is_success:
raise ValueError(
f"Client credentials token request failed: HTTP {response.status_code}, Response: {response.text}"
)
return OAuthTokens.model_validate(response.json())
def register_client(
server_url: str,
metadata: OAuthMetadata | None,
@ -283,7 +346,7 @@ def register_client(
else:
registration_url = urljoin(server_url, "/register")
response = httpx.post(
response = ssrf_proxy.post(
registration_url,
json=client_metadata.model_dump(),
headers={"Content-Type": "application/json"},
@ -294,28 +357,85 @@ def register_client(
def auth(
provider: OAuthClientProvider,
server_url: str,
provider: MCPProviderEntity,
mcp_service: "MCPToolManageService",
authorization_code: str | None = None,
state_param: str | None = None,
for_list: bool = False,
) -> dict[str, str]:
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
metadata = discover_oauth_metadata(server_url)
server_url = provider.decrypt_server_url()
server_metadata = discover_oauth_metadata(server_url)
client_metadata = provider.client_metadata
provider_id = provider.id
tenant_id = provider.tenant_id
client_information = provider.retrieve_client_information()
redirect_url = provider.redirect_url
# Determine grant type based on server metadata
if not server_metadata:
raise ValueError("Failed to discover OAuth metadata from server")
supported_grant_types = server_metadata.grant_types_supported or []
# Convert to lowercase for comparison
supported_grant_types_lower = [gt.lower() for gt in supported_grant_types]
# Determine which grant type to use
effective_grant_type = None
if MCPSupportGrantType.AUTHORIZATION_CODE.value in supported_grant_types_lower:
effective_grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value
else:
effective_grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
# Get stored credentials
credentials = provider.decrypt_credentials()
# Handle client registration if needed
client_information = provider.client_information()
if not client_information:
if authorization_code is not None:
raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
# For client credentials flow, we don't need to register client dynamically
if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
# Client should provide client_id and client_secret directly
raise ValueError("Client credentials flow requires client_id and client_secret to be provided")
try:
full_information = register_client(server_url, metadata, provider.client_metadata)
except httpx.RequestError as e:
full_information = register_client(server_url, server_metadata, client_metadata)
except RequestError as e:
raise ValueError(f"Could not register OAuth client: {e}")
provider.save_client_information(full_information)
# Save client information using service layer
mcp_service.save_oauth_data(
provider_id, tenant_id, {"client_information": full_information.model_dump()}, "client_info"
)
client_information = full_information
# Exchange authorization code for tokens
# Handle client credentials flow
if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
# Direct token request without user interaction
try:
scope = credentials.get("scope")
tokens = client_credentials_flow(
server_url,
server_metadata,
client_information,
scope,
)
# Save tokens and grant type
token_data = tokens.model_dump()
token_data["grant_type"] = MCPSupportGrantType.CLIENT_CREDENTIALS.value
mcp_service.save_oauth_data(provider_id, tenant_id, token_data, "tokens")
return {"result": "success"}
except (RequestError, ValueError, KeyError) as e:
# RequestError: HTTP request failed
# ValueError: Invalid response data
# KeyError: Missing required fields in response
raise ValueError(f"Client credentials flow failed: {e}")
# Exchange authorization code for tokens (Authorization Code flow)
if authorization_code is not None:
if not state_param:
raise ValueError("State parameter is required when exchanging authorization code")
@ -335,35 +455,48 @@ def auth(
tokens = exchange_authorization(
server_url,
metadata,
server_metadata,
client_information,
authorization_code,
code_verifier,
redirect_uri,
)
provider.save_tokens(tokens)
# Save tokens using service layer
mcp_service.save_oauth_data(provider_id, tenant_id, tokens.model_dump(), "tokens")
return {"result": "success"}
provider_tokens = provider.tokens()
provider_tokens = provider.retrieve_tokens()
# Handle token refresh or new authorization
if provider_tokens and provider_tokens.refresh_token:
try:
new_tokens = refresh_authorization(server_url, metadata, client_information, provider_tokens.refresh_token)
provider.save_tokens(new_tokens)
new_tokens = refresh_authorization(
server_url, server_metadata, client_information, provider_tokens.refresh_token
)
# Save new tokens using service layer
mcp_service.save_oauth_data(provider_id, tenant_id, new_tokens.model_dump(), "tokens")
return {"result": "success"}
except Exception as e:
except (RequestError, ValueError, KeyError) as e:
# RequestError: HTTP request failed
# ValueError: Invalid response data
# KeyError: Missing required fields in response
raise ValueError(f"Could not refresh OAuth tokens: {e}")
# Start new authorization flow
# Start new authorization flow (only for authorization code flow)
authorization_url, code_verifier = start_authorization(
server_url,
metadata,
server_metadata,
client_information,
provider.redirect_url,
provider.mcp_provider.id,
provider.mcp_provider.tenant_id,
redirect_url,
provider_id,
tenant_id,
)
provider.save_code_verifier(code_verifier)
# Save code verifier using service layer
mcp_service.save_oauth_data(provider_id, tenant_id, {"code_verifier": code_verifier}, "code_verifier")
return {"authorization_url": authorization_url}

View File

@ -1,77 +0,0 @@
from configs import dify_config
from core.mcp.types import (
OAuthClientInformation,
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthTokens,
)
from models.tools import MCPToolProvider
from services.tools.mcp_tools_manage_service import MCPToolManageService
class OAuthClientProvider:
mcp_provider: MCPToolProvider
def __init__(self, provider_id: str, tenant_id: str, for_list: bool = False):
if for_list:
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
else:
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_server_identifier(provider_id, tenant_id)
@property
def redirect_url(self) -> str:
"""The URL to redirect the user agent to after authorization."""
return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
@property
def client_metadata(self) -> OAuthClientMetadata:
"""Metadata about this OAuth client."""
return OAuthClientMetadata(
redirect_uris=[self.redirect_url],
token_endpoint_auth_method="none",
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
client_name="Dify",
client_uri="https://github.com/langgenius/dify",
)
def client_information(self) -> OAuthClientInformation | None:
"""Loads information about this OAuth client."""
client_information = self.mcp_provider.decrypted_credentials.get("client_information", {})
if not client_information:
return None
return OAuthClientInformation.model_validate(client_information)
def save_client_information(self, client_information: OAuthClientInformationFull):
"""Saves client information after dynamic registration."""
MCPToolManageService.update_mcp_provider_credentials(
self.mcp_provider,
{"client_information": client_information.model_dump()},
)
def tokens(self) -> OAuthTokens | None:
"""Loads any existing OAuth tokens for the current session."""
credentials = self.mcp_provider.decrypted_credentials
if not credentials:
return None
return OAuthTokens(
access_token=credentials.get("access_token", ""),
token_type=credentials.get("token_type", "Bearer"),
expires_in=int(credentials.get("expires_in", "3600") or 3600),
refresh_token=credentials.get("refresh_token", ""),
)
def save_tokens(self, tokens: OAuthTokens):
"""Stores new OAuth tokens for the current session."""
# update mcp provider credentials
token_dict = tokens.model_dump()
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True)
def save_code_verifier(self, code_verifier: str):
"""Saves a PKCE code verifier for the current session."""
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier})
def code_verifier(self) -> str:
"""Loads the PKCE code verifier for the current session."""
# get code verifier from mcp provider credentials
return str(self.mcp_provider.decrypted_credentials.get("code_verifier", ""))

182
api/core/mcp/auth_client.py Normal file
View File

@ -0,0 +1,182 @@
"""
MCP Client with Authentication Retry Support
This module provides an enhanced MCPClient that automatically handles
authentication failures and retries operations after refreshing tokens.
"""
import logging
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Optional
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.error import MCPAuthError
from core.mcp.mcp_client import MCPClient
from core.mcp.types import CallToolResult, Tool
if TYPE_CHECKING:
from services.tools.mcp_tools_manage_service import MCPToolManageService
logger = logging.getLogger(__name__)
class MCPClientWithAuthRetry(MCPClient):
"""
An enhanced MCPClient that provides automatic authentication retry.
This class extends MCPClient and intercepts MCPAuthError exceptions
to refresh authentication before retrying failed operations.
"""
def __init__(
self,
server_url: str,
headers: dict[str, str] | None = None,
timeout: float | None = None,
sse_read_timeout: float | None = None,
provider_entity: MCPProviderEntity | None = None,
auth_callback: Callable[[MCPProviderEntity, "MCPToolManageService", Optional[str]], dict[str, str]]
| None = None,
authorization_code: str | None = None,
by_server_id: bool = False,
mcp_service: Optional["MCPToolManageService"] = None,
):
"""
Initialize the MCP client with auth retry capability.
Args:
server_url: The MCP server URL
headers: Optional headers for requests
timeout: Request timeout
sse_read_timeout: SSE read timeout
provider_entity: Provider entity for authentication
auth_callback: Authentication callback function
authorization_code: Optional authorization code for initial auth
by_server_id: Whether to look up provider by server ID
mcp_service: MCP service instance
"""
super().__init__(server_url, headers, timeout, sse_read_timeout)
self.provider_entity = provider_entity
self.auth_callback = auth_callback
self.authorization_code = authorization_code
self.by_server_id = by_server_id
self.mcp_service = mcp_service
self._has_retried = False
def _handle_auth_error(self, error: MCPAuthError) -> None:
"""
Handle authentication error by refreshing tokens.
Args:
error: The authentication error
Raises:
MCPAuthError: If authentication fails or max retries reached
"""
if not self.provider_entity or not self.auth_callback or not self.mcp_service:
raise error
if self._has_retried:
raise error
self._has_retried = True
try:
# Perform authentication
self.auth_callback(self.provider_entity, self.mcp_service, self.authorization_code)
# Retrieve new tokens
self.provider_entity = self.mcp_service.get_provider_entity(
self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id
)
token = self.provider_entity.retrieve_tokens()
if not token:
raise MCPAuthError("Authentication failed - no token received")
# Update headers with new token
self.headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
# Clear authorization code after first use
self.authorization_code = None
except MCPAuthError:
# Re-raise MCPAuthError as is
raise
except Exception as e:
# Catch all exceptions during auth retry
logger.exception("Authentication retry failed")
raise MCPAuthError(f"Authentication retry failed: {e}") from e
def _execute_with_retry(self, func: Callable[..., Any], *args, **kwargs) -> Any:
"""
Execute a function with authentication retry logic.
Args:
func: The function to execute
*args: Positional arguments for the function
**kwargs: Keyword arguments for the function
Returns:
The result of the function call
Raises:
MCPAuthError: If authentication fails after retries
Any other exceptions from the function
"""
try:
return func(*args, **kwargs)
except MCPAuthError as e:
self._handle_auth_error(e)
# Re-initialize the connection with new headers
if self._initialized:
# Clean up existing connection
self._exit_stack.close()
self._session = None
self._initialized = False
# Re-initialize with new headers
self._initialize()
self._initialized = True
return func(*args, **kwargs)
finally:
# Reset retry flag after operation completes
self._has_retried = False
def __enter__(self):
"""Enter the context manager with retry support."""
def initialize_with_retry():
super(MCPClientWithAuthRetry, self).__enter__()
return self
return self._execute_with_retry(initialize_with_retry)
def list_tools(self) -> list[Tool]:
"""
List available tools from the MCP server with auth retry.
Returns:
List of available tools
Raises:
MCPAuthError: If authentication fails after retries
"""
return self._execute_with_retry(super().list_tools)
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
"""
Invoke a tool on the MCP server with auth retry.
Args:
tool_name: Name of the tool to invoke
tool_args: Arguments for the tool
Returns:
Result of the tool invocation
Raises:
MCPAuthError: If authentication fails after retries
"""
return self._execute_with_retry(super().invoke_tool, tool_name, tool_args)

View File

View File

@ -46,7 +46,7 @@ class SSETransport:
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5.0,
sse_read_timeout: float = 5 * 60,
sse_read_timeout: float = 1 * 60,
):
"""Initialize the SSE transport.
@ -255,7 +255,7 @@ def sse_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5.0,
sse_read_timeout: float = 5 * 60,
sse_read_timeout: float = 1 * 60,
) -> Generator[tuple[ReadQueue, WriteQueue], None, None]:
"""
Client transport for SSE.
@ -276,31 +276,34 @@ def sse_client(
read_queue: ReadQueue | None = None
write_queue: WriteQueue | None = None
with ThreadPoolExecutor() as executor:
try:
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
with ssrf_proxy_sse_connect(
url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
) as event_source:
event_source.response.raise_for_status()
executor = ThreadPoolExecutor()
try:
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
with ssrf_proxy_sse_connect(
url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
) as event_source:
event_source.response.raise_for_status()
read_queue, write_queue = transport.connect(executor, client, event_source)
read_queue, write_queue = transport.connect(executor, client, event_source)
yield read_queue, write_queue
yield read_queue, write_queue
except httpx.HTTPStatusError as exc:
if exc.response.status_code == 401:
raise MCPAuthError()
raise MCPConnectionError()
except Exception:
logger.exception("Error connecting to SSE endpoint")
raise
finally:
# Clean up queues
if read_queue:
read_queue.put(None)
if write_queue:
write_queue.put(None)
except httpx.HTTPStatusError as exc:
if exc.response.status_code == 401:
raise MCPAuthError()
raise MCPConnectionError()
except Exception:
logger.exception("Error connecting to SSE endpoint")
raise
finally:
# Clean up queues
if read_queue:
read_queue.put(None)
if write_queue:
write_queue.put(None)
# Shutdown executor without waiting to prevent hanging
executor.shutdown(wait=False)
def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage):

View File

@ -434,45 +434,48 @@ def streamablehttp_client(
server_to_client_queue: ServerToClientQueue = queue.Queue() # For messages FROM server TO client
client_to_server_queue: ClientToServerQueue = queue.Queue() # For messages FROM client TO server
with ThreadPoolExecutor(max_workers=2) as executor:
try:
with create_ssrf_proxy_mcp_http_client(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
) as client:
# Define callbacks that need access to thread pool
def start_get_stream():
"""Start a worker thread to handle server-initiated messages."""
executor.submit(transport.handle_get_stream, client, server_to_client_queue)
executor = ThreadPoolExecutor(max_workers=2)
try:
with create_ssrf_proxy_mcp_http_client(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
) as client:
# Define callbacks that need access to thread pool
def start_get_stream():
"""Start a worker thread to handle server-initiated messages."""
executor.submit(transport.handle_get_stream, client, server_to_client_queue)
# Start the post_writer worker thread
executor.submit(
transport.post_writer,
client,
client_to_server_queue, # Queue for messages FROM client TO server
server_to_client_queue, # Queue for messages FROM server TO client
start_get_stream,
)
# Start the post_writer worker thread
executor.submit(
transport.post_writer,
client,
client_to_server_queue, # Queue for messages FROM client TO server
server_to_client_queue, # Queue for messages FROM server TO client
start_get_stream,
)
try:
yield (
server_to_client_queue, # Queue for receiving messages FROM server
client_to_server_queue, # Queue for sending messages TO server
transport.get_session_id,
)
finally:
if transport.session_id and terminate_on_close:
transport.terminate_session(client)
# Signal threads to stop
client_to_server_queue.put(None)
finally:
# Clear any remaining items and add None sentinel to unblock any waiting threads
try:
while not client_to_server_queue.empty():
client_to_server_queue.get_nowait()
except queue.Empty:
pass
yield (
server_to_client_queue, # Queue for receiving messages FROM server
client_to_server_queue, # Queue for sending messages TO server
transport.get_session_id,
)
finally:
if transport.session_id and terminate_on_close:
transport.terminate_session(client)
client_to_server_queue.put(None)
server_to_client_queue.put(None)
# Signal threads to stop
client_to_server_queue.put(None)
finally:
# Clear any remaining items and add None sentinel to unblock any waiting threads
try:
while not client_to_server_queue.empty():
client_to_server_queue.get_nowait()
except queue.Empty:
pass
client_to_server_queue.put(None)
server_to_client_queue.put(None)
# Shutdown executor without waiting to prevent hanging
executor.shutdown(wait=False)

View File

@ -4,7 +4,7 @@ from typing import Any, Generic, TypeVar
from core.mcp.session.base_session import BaseSession
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestId, RequestParams
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", LATEST_PROTOCOL_VERSION]
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])

View File

@ -7,9 +7,9 @@ from urllib.parse import urlparse
from core.mcp.client.sse_client import sse_client
from core.mcp.client.streamable_client import streamablehttp_client
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.error import MCPConnectionError
from core.mcp.session.client_session import ClientSession
from core.mcp.types import Tool
from core.mcp.types import CallToolResult, Tool
logger = logging.getLogger(__name__)
@ -18,40 +18,18 @@ class MCPClient:
def __init__(
self,
server_url: str,
provider_id: str,
tenant_id: str,
authed: bool = True,
authorization_code: str | None = None,
for_list: bool = False,
headers: dict[str, str] | None = None,
timeout: float | None = None,
sse_read_timeout: float | None = None,
):
# Initialize info
self.provider_id = provider_id
self.tenant_id = tenant_id
self.client_type = "streamable"
self.server_url = server_url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
# Authentication info
self.authed = authed
self.authorization_code = authorization_code
if authed:
from core.mcp.auth.auth_provider import OAuthClientProvider
self.provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=for_list)
self.token = self.provider.tokens()
# Initialize session and client objects
self._session: ClientSession | None = None
self._streams_context: AbstractContextManager[Any] | None = None
self._session_context: ClientSession | None = None
self._exit_stack = ExitStack()
# Whether the client has been initialized
self._initialized = False
def __enter__(self):
@ -85,61 +63,42 @@ class MCPClient:
logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
self.connect_server(streamablehttp_client, "mcp")
def connect_server(
self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str, first_try: bool = True
):
from core.mcp.auth.auth_flow import auth
def connect_server(self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str) -> None:
"""
Connect to the MCP server using streamable http or sse.
Default to streamable http.
Args:
client_factory: The client factory to use(streamablehttp_client or sse_client).
method_name: The method name to use(mcp or sse).
"""
streams_context = client_factory(
url=self.server_url,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
try:
headers = (
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
if self.authed and self.token
else self.headers
)
self._streams_context = client_factory(
url=self.server_url,
headers=headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
if not self._streams_context:
raise MCPConnectionError("Failed to create connection context")
# Use exit_stack to manage context managers properly
if method_name == "mcp":
read_stream, write_stream, _ = self._exit_stack.enter_context(streams_context)
streams = (read_stream, write_stream)
else: # sse_client
streams = self._exit_stack.enter_context(streams_context)
# Use exit_stack to manage context managers properly
if method_name == "mcp":
read_stream, write_stream, _ = self._exit_stack.enter_context(self._streams_context)
streams = (read_stream, write_stream)
else: # sse_client
streams = self._exit_stack.enter_context(self._streams_context)
self._session_context = ClientSession(*streams)
self._session = self._exit_stack.enter_context(self._session_context)
self._session.initialize()
return
except MCPAuthError:
if not self.authed:
raise
try:
auth(self.provider, self.server_url, self.authorization_code)
except Exception as e:
raise ValueError(f"Failed to authenticate: {e}")
self.token = self.provider.tokens()
if first_try:
return self.connect_server(client_factory, method_name, first_try=False)
session_context = ClientSession(*streams)
self._session = self._exit_stack.enter_context(session_context)
self._session.initialize()
def list_tools(self) -> list[Tool]:
"""Connect to an MCP server running with SSE transport"""
# List available tools to verify connection
if not self._initialized or not self._session:
"""List available tools from the MCP server"""
if not self._session:
raise ValueError("Session not initialized.")
response = self._session.list_tools()
tools = response.tools
return tools
return response.tools
def invoke_tool(self, tool_name: str, tool_args: dict):
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
"""Call a tool"""
if not self._initialized or not self._session:
if not self._session:
raise ValueError("Session not initialized.")
return self._session.call_tool(tool_name, tool_args)
@ -153,6 +112,4 @@ class MCPClient:
raise ValueError(f"Error during cleanup: {e}")
finally:
self._session = None
self._session_context = None
self._streams_context = None
self._initialized = False

View File

@ -201,11 +201,14 @@ class BaseSession(
self._receiver_future.result(timeout=5.0) # Wait up to 5 seconds
except TimeoutError:
# If the receiver loop is still running after timeout, we'll force shutdown
pass
# Cancel the future to interrupt the receiver loop
self._receiver_future.cancel()
# Shutdown the executor
if self._executor:
self._executor.shutdown(wait=True)
# Use non-blocking shutdown to prevent hanging
# The receiver thread should have already exited due to the None message in the queue
self._executor.shutdown(wait=False)
def send_request(
self,

View File

@ -284,7 +284,7 @@ class ClientSession(
def complete(
self,
ref: types.ResourceReference | types.PromptReference,
ref: types.ResourceTemplateReference | types.PromptReference,
argument: dict[str, str],
) -> types.CompleteResult:
"""Send a completion/complete request."""

View File

@ -1,13 +1,6 @@
from collections.abc import Callable
from dataclasses import dataclass
from typing import (
Annotated,
Any,
Generic,
Literal,
TypeAlias,
TypeVar,
)
from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar
from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel
from pydantic.networks import AnyUrl, UrlConstraints
@ -33,6 +26,7 @@ for reference.
LATEST_PROTOCOL_VERSION = "2025-03-26"
# Server support 2024-11-05 to allow claude to use.
SERVER_LATEST_PROTOCOL_VERSION = "2024-11-05"
DEFAULT_NEGOTIATED_VERSION = "2025-03-26"
ProgressToken = str | int
Cursor = str
Role = Literal["user", "assistant"]
@ -55,14 +49,22 @@ class RequestParams(BaseModel):
meta: Meta | None = Field(alias="_meta", default=None)
class PaginatedRequestParams(RequestParams):
cursor: Cursor | None = None
"""
An opaque token representing the current pagination position.
If provided, the server should return results starting after this cursor.
"""
class NotificationParams(BaseModel):
class Meta(BaseModel):
model_config = ConfigDict(extra="allow")
meta: Meta | None = Field(alias="_meta", default=None)
"""
This parameter name is reserved by MCP to allow clients and servers to attach
additional metadata to their notifications.
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
@ -79,12 +81,11 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]):
model_config = ConfigDict(extra="allow")
class PaginatedRequest(Request[RequestParamsT, MethodT]):
cursor: Cursor | None = None
"""
An opaque token representing the current pagination position.
If provided, the server should return results starting after this cursor.
"""
class PaginatedRequest(Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]):
"""Base class for paginated requests,
matching the schema's PaginatedRequest interface."""
params: PaginatedRequestParams | None = None
class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
@ -98,13 +99,12 @@ class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
class Result(BaseModel):
"""Base class for JSON-RPC results."""
model_config = ConfigDict(extra="allow")
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
This result property is reserved by the protocol to allow clients and servers to
attach additional metadata to their responses.
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
class PaginatedResult(Result):
@ -186,10 +186,26 @@ class EmptyResult(Result):
"""A response that indicates success but carries no data."""
class Implementation(BaseModel):
"""Describes the name and version of an MCP implementation."""
class BaseMetadata(BaseModel):
"""Base class for entities with name and optional title fields."""
name: str
"""The programmatic name of the entity."""
title: str | None = None
"""
Intended for UI and end-user contexts optimized to be human-readable and easily understood,
even by those unfamiliar with domain-specific terminology.
If not provided, the name should be used for display (except for Tool,
where `annotations.title` should be given precedence over using `name`,
if present).
"""
class Implementation(BaseMetadata):
"""Describes the name and version of an MCP implementation."""
version: str
model_config = ConfigDict(extra="allow")
@ -203,7 +219,7 @@ class RootsCapability(BaseModel):
class SamplingCapability(BaseModel):
"""Capability for logging operations."""
"""Capability for sampling operations."""
model_config = ConfigDict(extra="allow")
@ -252,6 +268,12 @@ class LoggingCapability(BaseModel):
model_config = ConfigDict(extra="allow")
class CompletionsCapability(BaseModel):
"""Capability for completions operations."""
model_config = ConfigDict(extra="allow")
class ServerCapabilities(BaseModel):
"""Capabilities that a server may support."""
@ -265,6 +287,8 @@ class ServerCapabilities(BaseModel):
"""Present if the server offers any resources to read."""
tools: ToolsCapability | None = None
"""Present if the server offers any tools to call."""
completions: CompletionsCapability | None = None
"""Present if the server offers autocompletion suggestions for prompts and resources."""
model_config = ConfigDict(extra="allow")
@ -284,7 +308,7 @@ class InitializeRequest(Request[InitializeRequestParams, Literal["initialize"]])
to begin initialization.
"""
method: Literal["initialize"]
method: Literal["initialize"] = "initialize"
params: InitializeRequestParams
@ -305,7 +329,7 @@ class InitializedNotification(Notification[NotificationParams | None, Literal["n
finished.
"""
method: Literal["notifications/initialized"]
method: Literal["notifications/initialized"] = "notifications/initialized"
params: NotificationParams | None = None
@ -315,7 +339,7 @@ class PingRequest(Request[RequestParams | None, Literal["ping"]]):
still alive.
"""
method: Literal["ping"]
method: Literal["ping"] = "ping"
params: RequestParams | None = None
@ -334,6 +358,11 @@ class ProgressNotificationParams(NotificationParams):
"""
total: float | None = None
"""Total number of items to process (or total progress required), if known."""
message: str | None = None
"""
Message related to progress. This should provide relevant human readable
progress information.
"""
model_config = ConfigDict(extra="allow")
@ -343,15 +372,14 @@ class ProgressNotification(Notification[ProgressNotificationParams, Literal["not
long-running request.
"""
method: Literal["notifications/progress"]
method: Literal["notifications/progress"] = "notifications/progress"
params: ProgressNotificationParams
class ListResourcesRequest(PaginatedRequest[RequestParams | None, Literal["resources/list"]]):
class ListResourcesRequest(PaginatedRequest[Literal["resources/list"]]):
"""Sent from the client to request a list of resources the server has."""
method: Literal["resources/list"]
params: RequestParams | None = None
method: Literal["resources/list"] = "resources/list"
class Annotations(BaseModel):
@ -360,13 +388,11 @@ class Annotations(BaseModel):
model_config = ConfigDict(extra="allow")
class Resource(BaseModel):
class Resource(BaseMetadata):
"""A known resource that the server is capable of reading."""
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
"""The URI of this resource."""
name: str
"""A human-readable name for this resource."""
description: str | None = None
"""A description of what this resource represents."""
mimeType: str | None = None
@ -379,10 +405,15 @@ class Resource(BaseModel):
This can be used by Hosts to display file sizes and estimate context window usage.
"""
annotations: Annotations | None = None
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
class ResourceTemplate(BaseModel):
class ResourceTemplate(BaseMetadata):
"""A template description for resources available on the server."""
uriTemplate: str
@ -390,8 +421,6 @@ class ResourceTemplate(BaseModel):
A URI template (according to RFC 6570) that can be used to construct resource
URIs.
"""
name: str
"""A human-readable name for the type of resource this template refers to."""
description: str | None = None
"""A human-readable description of what this template is for."""
mimeType: str | None = None
@ -400,6 +429,11 @@ class ResourceTemplate(BaseModel):
included if all resources matching this template have the same type.
"""
annotations: Annotations | None = None
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
@ -409,11 +443,10 @@ class ListResourcesResult(PaginatedResult):
resources: list[Resource]
class ListResourceTemplatesRequest(PaginatedRequest[RequestParams | None, Literal["resources/templates/list"]]):
class ListResourceTemplatesRequest(PaginatedRequest[Literal["resources/templates/list"]]):
"""Sent from the client to request a list of resource templates the server has."""
method: Literal["resources/templates/list"]
params: RequestParams | None = None
method: Literal["resources/templates/list"] = "resources/templates/list"
class ListResourceTemplatesResult(PaginatedResult):
@ -436,7 +469,7 @@ class ReadResourceRequestParams(RequestParams):
class ReadResourceRequest(Request[ReadResourceRequestParams, Literal["resources/read"]]):
"""Sent from the client to the server, to read a specific resource URI."""
method: Literal["resources/read"]
method: Literal["resources/read"] = "resources/read"
params: ReadResourceRequestParams
@ -447,6 +480,11 @@ class ResourceContents(BaseModel):
"""The URI of this resource."""
mimeType: str | None = None
"""The MIME type of this resource, if known."""
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
@ -481,7 +519,7 @@ class ResourceListChangedNotification(
of resources it can read from has changed.
"""
method: Literal["notifications/resources/list_changed"]
method: Literal["notifications/resources/list_changed"] = "notifications/resources/list_changed"
params: NotificationParams | None = None
@ -502,7 +540,7 @@ class SubscribeRequest(Request[SubscribeRequestParams, Literal["resources/subscr
whenever a particular resource changes.
"""
method: Literal["resources/subscribe"]
method: Literal["resources/subscribe"] = "resources/subscribe"
params: SubscribeRequestParams
@ -520,7 +558,7 @@ class UnsubscribeRequest(Request[UnsubscribeRequestParams, Literal["resources/un
the server.
"""
method: Literal["resources/unsubscribe"]
method: Literal["resources/unsubscribe"] = "resources/unsubscribe"
params: UnsubscribeRequestParams
@ -543,15 +581,14 @@ class ResourceUpdatedNotification(
changed and may need to be read again.
"""
method: Literal["notifications/resources/updated"]
method: Literal["notifications/resources/updated"] = "notifications/resources/updated"
params: ResourceUpdatedNotificationParams
class ListPromptsRequest(PaginatedRequest[RequestParams | None, Literal["prompts/list"]]):
class ListPromptsRequest(PaginatedRequest[Literal["prompts/list"]]):
"""Sent from the client to request a list of prompts and prompt templates."""
method: Literal["prompts/list"]
params: RequestParams | None = None
method: Literal["prompts/list"] = "prompts/list"
class PromptArgument(BaseModel):
@ -566,15 +603,18 @@ class PromptArgument(BaseModel):
model_config = ConfigDict(extra="allow")
class Prompt(BaseModel):
class Prompt(BaseMetadata):
"""A prompt or prompt template that the server offers."""
name: str
"""The name of the prompt or prompt template."""
description: str | None = None
"""An optional description of what this prompt provides."""
arguments: list[PromptArgument] | None = None
"""A list of arguments to use for templating the prompt."""
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
@ -597,7 +637,7 @@ class GetPromptRequestParams(RequestParams):
class GetPromptRequest(Request[GetPromptRequestParams, Literal["prompts/get"]]):
"""Used by the client to get a prompt provided by the server."""
method: Literal["prompts/get"]
method: Literal["prompts/get"] = "prompts/get"
params: GetPromptRequestParams
@ -608,6 +648,11 @@ class TextContent(BaseModel):
text: str
"""The text content of the message."""
annotations: Annotations | None = None
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
@ -623,6 +668,31 @@ class ImageContent(BaseModel):
image types.
"""
annotations: Annotations | None = None
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
class AudioContent(BaseModel):
"""Audio content for a message."""
type: Literal["audio"]
data: str
"""The base64-encoded audio data."""
mimeType: str
"""
The MIME type of the audio. Different providers may support different
audio types.
"""
annotations: Annotations | None = None
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
@ -630,7 +700,7 @@ class SamplingMessage(BaseModel):
"""Describes a message issued to or received from an LLM API."""
role: Role
content: TextContent | ImageContent
content: TextContent | ImageContent | AudioContent
model_config = ConfigDict(extra="allow")
@ -645,14 +715,36 @@ class EmbeddedResource(BaseModel):
type: Literal["resource"]
resource: TextResourceContents | BlobResourceContents
annotations: Annotations | None = None
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
class ResourceLink(Resource):
"""
A resource that the server is capable of reading, included in a prompt or tool call result.
Note: resource links returned by tools are not guaranteed to appear in the results of `resources/list` requests.
"""
type: Literal["resource_link"]
ContentBlock = TextContent | ImageContent | AudioContent | ResourceLink | EmbeddedResource
"""A content block that can be used in prompts and tool results."""
Content: TypeAlias = ContentBlock
# """DEPRECATED: Content is deprecated, you should use ContentBlock directly."""
class PromptMessage(BaseModel):
"""Describes a message returned as part of a prompt."""
role: Role
content: TextContent | ImageContent | EmbeddedResource
content: ContentBlock
model_config = ConfigDict(extra="allow")
@ -672,15 +764,14 @@ class PromptListChangedNotification(
of prompts it offers has changed.
"""
method: Literal["notifications/prompts/list_changed"]
method: Literal["notifications/prompts/list_changed"] = "notifications/prompts/list_changed"
params: NotificationParams | None = None
class ListToolsRequest(PaginatedRequest[RequestParams | None, Literal["tools/list"]]):
class ListToolsRequest(PaginatedRequest[Literal["tools/list"]]):
"""Sent from the client to request a list of tools the server has."""
method: Literal["tools/list"]
params: RequestParams | None = None
method: Literal["tools/list"] = "tools/list"
class ToolAnnotations(BaseModel):
@ -731,17 +822,25 @@ class ToolAnnotations(BaseModel):
model_config = ConfigDict(extra="allow")
class Tool(BaseModel):
class Tool(BaseMetadata):
"""Definition for a tool the client can call."""
name: str
"""The name of the tool."""
description: str | None = None
"""A human-readable description of the tool."""
inputSchema: dict[str, Any]
"""A JSON Schema object defining the expected parameters for the tool."""
outputSchema: dict[str, Any] | None = None
"""
An optional JSON Schema object defining the structure of the tool's output
returned in the structuredContent field of a CallToolResult.
"""
annotations: ToolAnnotations | None = None
"""Optional additional tool information."""
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
@ -762,14 +861,16 @@ class CallToolRequestParams(RequestParams):
class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]):
"""Used by the client to invoke a tool provided by the server."""
method: Literal["tools/call"]
method: Literal["tools/call"] = "tools/call"
params: CallToolRequestParams
class CallToolResult(Result):
"""The server's response to a tool call."""
content: list[TextContent | ImageContent | EmbeddedResource]
content: list[ContentBlock]
structuredContent: dict[str, Any] | None = None
"""An optional JSON object that represents the structured result of the tool call."""
isError: bool = False
@ -779,7 +880,7 @@ class ToolListChangedNotification(Notification[NotificationParams | None, Litera
of tools it offers has changed.
"""
method: Literal["notifications/tools/list_changed"]
method: Literal["notifications/tools/list_changed"] = "notifications/tools/list_changed"
params: NotificationParams | None = None
@ -797,7 +898,7 @@ class SetLevelRequestParams(RequestParams):
class SetLevelRequest(Request[SetLevelRequestParams, Literal["logging/setLevel"]]):
"""A request from the client to the server, to enable or adjust logging."""
method: Literal["logging/setLevel"]
method: Literal["logging/setLevel"] = "logging/setLevel"
params: SetLevelRequestParams
@ -808,7 +909,7 @@ class LoggingMessageNotificationParams(NotificationParams):
"""The severity of this log message."""
logger: str | None = None
"""An optional name of the logger issuing this message."""
data: Any = None
data: Any
"""
The data to be logged, such as a string message or an object. Any JSON serializable
type is allowed here.
@ -819,7 +920,7 @@ class LoggingMessageNotificationParams(NotificationParams):
class LoggingMessageNotification(Notification[LoggingMessageNotificationParams, Literal["notifications/message"]]):
"""Notification of a log message passed from server to client."""
method: Literal["notifications/message"]
method: Literal["notifications/message"] = "notifications/message"
params: LoggingMessageNotificationParams
@ -914,7 +1015,7 @@ class CreateMessageRequestParams(RequestParams):
class CreateMessageRequest(Request[CreateMessageRequestParams, Literal["sampling/createMessage"]]):
"""A request from the server to sample an LLM via the client."""
method: Literal["sampling/createMessage"]
method: Literal["sampling/createMessage"] = "sampling/createMessage"
params: CreateMessageRequestParams
@ -925,14 +1026,14 @@ class CreateMessageResult(Result):
"""The client's response to a sampling/create_message request from the server."""
role: Role
content: TextContent | ImageContent
content: TextContent | ImageContent | AudioContent
model: str
"""The name of the model that generated the message."""
stopReason: StopReason | None = None
"""The reason why sampling stopped, if known."""
class ResourceReference(BaseModel):
class ResourceTemplateReference(BaseModel):
"""A reference to a resource or resource template definition."""
type: Literal["ref/resource"]
@ -960,18 +1061,28 @@ class CompletionArgument(BaseModel):
model_config = ConfigDict(extra="allow")
class CompletionContext(BaseModel):
"""Additional, optional context for completions."""
arguments: dict[str, str] | None = None
"""Previously-resolved variables in a URI template or prompt."""
model_config = ConfigDict(extra="allow")
class CompleteRequestParams(RequestParams):
"""Parameters for completion requests."""
ref: ResourceReference | PromptReference
ref: ResourceTemplateReference | PromptReference
argument: CompletionArgument
context: CompletionContext | None = None
"""Additional, optional context for completions"""
model_config = ConfigDict(extra="allow")
class CompleteRequest(Request[CompleteRequestParams, Literal["completion/complete"]]):
"""A request from the client to the server, to ask for completion options."""
method: Literal["completion/complete"]
method: Literal["completion/complete"] = "completion/complete"
params: CompleteRequestParams
@ -1010,7 +1121,7 @@ class ListRootsRequest(Request[RequestParams | None, Literal["roots/list"]]):
structure or access specific locations that the client has permission to read from.
"""
method: Literal["roots/list"]
method: Literal["roots/list"] = "roots/list"
params: RequestParams | None = None
@ -1029,6 +1140,11 @@ class Root(BaseModel):
identifier for the root, which may be useful for display purposes or for
referencing the root in other parts of the application.
"""
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
@ -1054,7 +1170,7 @@ class RootsListChangedNotification(
using the ListRootsRequest.
"""
method: Literal["notifications/roots/list_changed"]
method: Literal["notifications/roots/list_changed"] = "notifications/roots/list_changed"
params: NotificationParams | None = None
@ -1074,7 +1190,7 @@ class CancelledNotification(Notification[CancelledNotificationParams, Literal["n
previously-issued request.
"""
method: Literal["notifications/cancelled"]
method: Literal["notifications/cancelled"] = "notifications/cancelled"
params: CancelledNotificationParams

View File

@ -7,7 +7,7 @@ import uuid
from collections import deque
from collections.abc import Sequence
from datetime import datetime
from typing import Final
from typing import Final, cast
from urllib.parse import urljoin
import httpx
@ -199,7 +199,7 @@ def convert_to_trace_id(uuid_v4: str | None) -> int:
raise ValueError("UUID cannot be None")
try:
uuid_obj = uuid.UUID(uuid_v4)
return uuid_obj.int
return cast(int, uuid_obj.int)
except ValueError as e:
raise ValueError(f"Invalid UUID input: {uuid_v4}") from e

View File

@ -13,6 +13,7 @@ class TracingProviderEnum(StrEnum):
OPIK = "opik"
WEAVE = "weave"
ALIYUN = "aliyun"
TENCENT = "tencent"
class BaseTracingConfig(BaseModel):
@ -195,5 +196,32 @@ class AliyunConfig(BaseTracingConfig):
return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
class TencentConfig(BaseTracingConfig):
"""
Tencent APM tracing config
"""
token: str
endpoint: str
service_name: str
@field_validator("token")
@classmethod
def token_validator(cls, v, info: ValidationInfo):
if not v or v.strip() == "":
raise ValueError("Token cannot be empty")
return v
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com")
@field_validator("service_name")
@classmethod
def service_name_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "dify_app")
OPS_FILE_PATH = "ops_trace/"
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"

View File

@ -90,6 +90,7 @@ class SuggestedQuestionTraceInfo(BaseTraceInfo):
class DatasetRetrievalTraceInfo(BaseTraceInfo):
documents: Any = None
error: str | None = None
class ToolTraceInfo(BaseTraceInfo):

View File

@ -120,6 +120,17 @@ class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
"trace_instance": AliyunDataTrace,
}
case TracingProviderEnum.TENCENT:
from core.ops.entities.config_entity import TencentConfig
from core.ops.tencent_trace.tencent_trace import TencentDataTrace
return {
"config_class": TencentConfig,
"secret_keys": ["token"],
"other_keys": ["endpoint", "service_name"],
"trace_instance": TencentDataTrace,
}
case _:
raise KeyError(f"Unsupported tracing provider: {provider}")
@ -723,6 +734,7 @@ class TraceTask:
end_time=timer.get("end"),
metadata=metadata,
message_data=message_data.to_dict(),
error=kwargs.get("error"),
)
return dataset_retrieval_trace_info
@ -889,6 +901,7 @@ class TraceQueueManager:
continue
file_id = uuid4().hex
trace_info = task.execute()
task_data = TaskData(
app_id=task.app_id,
trace_info_type=type(trace_info).__name__,

View File

View File

@ -0,0 +1,337 @@
"""
Tencent APM Trace Client - handles network operations, metrics, and API communication
"""
from __future__ import annotations
import importlib
import logging
import os
import socket
from typing import TYPE_CHECKING
from urllib.parse import urlparse
if TYPE_CHECKING:
from opentelemetry.metrics import Meter
from opentelemetry.metrics._internal.instrument import Histogram
from opentelemetry.sdk.metrics.export import MetricReader
from opentelemetry import trace as trace_api
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.trace import SpanKind
from opentelemetry.util.types import AttributeValue
from configs import dify_config
from .entities.tencent_semconv import LLM_OPERATION_DURATION
from .entities.tencent_trace_entity import SpanData
logger = logging.getLogger(__name__)
class TencentTraceClient:
"""Tencent APM trace client using OpenTelemetry OTLP exporter"""
def __init__(
self,
service_name: str,
endpoint: str,
token: str,
max_queue_size: int = 1000,
schedule_delay_sec: int = 5,
max_export_batch_size: int = 50,
metrics_export_interval_sec: int = 10,
):
self.endpoint = endpoint
self.token = token
self.service_name = service_name
self.metrics_export_interval_sec = metrics_export_interval_sec
self.resource = Resource(
attributes={
ResourceAttributes.SERVICE_NAME: service_name,
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
ResourceAttributes.HOST_NAME: socket.gethostname(),
}
)
# Prepare gRPC endpoint/metadata
grpc_endpoint, insecure, _, _ = self._resolve_grpc_target(endpoint)
headers = (("authorization", f"Bearer {token}"),)
self.exporter = OTLPSpanExporter(
endpoint=grpc_endpoint,
headers=headers,
insecure=insecure,
timeout=30,
)
self.tracer_provider = TracerProvider(resource=self.resource)
self.span_processor = BatchSpanProcessor(
span_exporter=self.exporter,
max_queue_size=max_queue_size,
schedule_delay_millis=schedule_delay_sec * 1000,
max_export_batch_size=max_export_batch_size,
)
self.tracer_provider.add_span_processor(self.span_processor)
self.tracer = self.tracer_provider.get_tracer("dify.tencent_apm")
# Store span contexts for parent-child relationships
self.span_contexts: dict[int, trace_api.SpanContext] = {}
self.meter: Meter | None = None
self.hist_llm_duration: Histogram | None = None
self.metric_reader: MetricReader | None = None
# Metrics exporter and instruments
try:
from opentelemetry import metrics
from opentelemetry.sdk.metrics import Histogram, MeterProvider
from opentelemetry.sdk.metrics.export import AggregationTemporality, PeriodicExportingMetricReader
protocol = os.getenv("OTEL_EXPORTER_OTLP_PROTOCOL", "").strip().lower()
use_http_protobuf = protocol in {"http/protobuf", "http-protobuf"}
use_http_json = protocol in {"http/json", "http-json"}
# Set preferred temporality for histograms to DELTA
preferred_temporality: dict[type, AggregationTemporality] = {Histogram: AggregationTemporality.DELTA}
def _create_metric_exporter(exporter_cls, **kwargs):
"""Create metric exporter with preferred_temporality support"""
try:
return exporter_cls(**kwargs, preferred_temporality=preferred_temporality)
except Exception:
return exporter_cls(**kwargs)
metric_reader = None
if use_http_json:
exporter_cls = None
for mod_path in (
"opentelemetry.exporter.otlp.http.json.metric_exporter",
"opentelemetry.exporter.otlp.json.metric_exporter",
):
try:
mod = importlib.import_module(mod_path)
exporter_cls = getattr(mod, "OTLPMetricExporter", None)
if exporter_cls:
break
except Exception:
continue
if exporter_cls is not None:
metric_exporter = _create_metric_exporter(
exporter_cls,
endpoint=endpoint,
headers={"authorization": f"Bearer {token}"},
)
else:
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
OTLPMetricExporter as HttpMetricExporter,
)
metric_exporter = _create_metric_exporter(
HttpMetricExporter,
endpoint=endpoint,
headers={"authorization": f"Bearer {token}"},
)
metric_reader = PeriodicExportingMetricReader(
metric_exporter, export_interval_millis=self.metrics_export_interval_sec * 1000
)
elif use_http_protobuf:
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
OTLPMetricExporter as HttpMetricExporter,
)
metric_exporter = _create_metric_exporter(
HttpMetricExporter,
endpoint=endpoint,
headers={"authorization": f"Bearer {token}"},
)
metric_reader = PeriodicExportingMetricReader(
metric_exporter, export_interval_millis=self.metrics_export_interval_sec * 1000
)
else:
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import (
OTLPMetricExporter as GrpcMetricExporter,
)
m_grpc_endpoint, m_insecure, _, _ = self._resolve_grpc_target(endpoint)
metric_exporter = _create_metric_exporter(
GrpcMetricExporter,
endpoint=m_grpc_endpoint,
headers={"authorization": f"Bearer {token}"},
insecure=m_insecure,
)
metric_reader = PeriodicExportingMetricReader(
metric_exporter, export_interval_millis=self.metrics_export_interval_sec * 1000
)
if metric_reader is not None:
provider = MeterProvider(resource=self.resource, metric_readers=[metric_reader])
metrics.set_meter_provider(provider)
self.meter = metrics.get_meter("dify-sdk", dify_config.project.version)
self.hist_llm_duration = self.meter.create_histogram(
name=LLM_OPERATION_DURATION,
unit="s",
description="LLM operation duration (seconds)",
)
self.metric_reader = metric_reader
else:
self.meter = None
self.hist_llm_duration = None
self.metric_reader = None
except Exception:
logger.exception("[Tencent APM] Metrics initialization failed; metrics disabled")
self.meter = None
self.hist_llm_duration = None
self.metric_reader = None
def add_span(self, span_data: SpanData) -> None:
"""Create and export span using OpenTelemetry Tracer API"""
try:
self._create_and_export_span(span_data)
logger.debug("[Tencent APM] Created span: %s", span_data.name)
except Exception:
logger.exception("[Tencent APM] Failed to create span: %s", span_data.name)
# Metrics recording API
def record_llm_duration(self, latency_seconds: float, attributes: dict[str, str] | None = None) -> None:
"""Record LLM operation duration histogram in seconds."""
try:
if not hasattr(self, "hist_llm_duration") or self.hist_llm_duration is None:
return
attrs: dict[str, str] = {}
if attributes:
for k, v in attributes.items():
attrs[k] = str(v) if not isinstance(v, (str, int, float, bool)) else v # type: ignore[assignment]
self.hist_llm_duration.record(latency_seconds, attrs) # type: ignore[attr-defined]
except Exception:
logger.debug("[Tencent APM] Failed to record LLM duration", exc_info=True)
def _create_and_export_span(self, span_data: SpanData) -> None:
"""Create span using OpenTelemetry Tracer API"""
try:
parent_context = None
if span_data.parent_span_id and span_data.parent_span_id in self.span_contexts:
parent_context = trace_api.set_span_in_context(
trace_api.NonRecordingSpan(self.span_contexts[span_data.parent_span_id])
)
span = self.tracer.start_span(
name=span_data.name,
context=parent_context,
kind=SpanKind.INTERNAL,
start_time=span_data.start_time,
)
self.span_contexts[span_data.span_id] = span.get_span_context()
if span_data.attributes:
attributes: dict[str, AttributeValue] = {}
for key, value in span_data.attributes.items():
if isinstance(value, (int, float, bool)):
attributes[key] = value
else:
attributes[key] = str(value)
span.set_attributes(attributes)
if span_data.events:
for event in span_data.events:
span.add_event(event.name, event.attributes, event.timestamp)
if span_data.status:
span.set_status(span_data.status)
# Manually end span; do not use context manager to avoid double-end warnings
span.end(end_time=span_data.end_time)
except Exception:
logger.exception("[Tencent APM] Error creating span: %s", span_data.name)
def api_check(self) -> bool:
"""Check API connectivity using socket connection test for gRPC endpoints"""
try:
# Resolve gRPC target consistently with exporters
_, _, host, port = self._resolve_grpc_target(self.endpoint)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(5)
result = sock.connect_ex((host, port))
sock.close()
if result == 0:
logger.info("[Tencent APM] Endpoint %s:%s is accessible", host, port)
return True
else:
logger.warning("[Tencent APM] Endpoint %s:%s is not accessible", host, port)
if host in ["127.0.0.1", "localhost"]:
logger.info("[Tencent APM] Development environment detected, allowing config save")
return True
return False
except Exception:
logger.exception("[Tencent APM] API check failed")
if "127.0.0.1" in self.endpoint or "localhost" in self.endpoint:
return True
return False
def get_project_url(self) -> str:
"""Get project console URL"""
return "https://console.cloud.tencent.com/apm"
def shutdown(self) -> None:
"""Shutdown the client and export remaining spans"""
try:
if self.span_processor:
logger.info("[Tencent APM] Flushing remaining spans before shutdown")
_ = self.span_processor.force_flush()
self.span_processor.shutdown()
if self.tracer_provider:
self.tracer_provider.shutdown()
if self.metric_reader is not None:
try:
self.metric_reader.shutdown() # type: ignore[attr-defined]
except Exception:
pass
except Exception:
logger.exception("[Tencent APM] Error during client shutdown")
@staticmethod
def _resolve_grpc_target(endpoint: str, default_port: int = 4317) -> tuple[str, bool, str, int]:
"""Normalize endpoint to gRPC target and security flag.
Returns:
(grpc_endpoint, insecure, host, port)
"""
try:
if endpoint.startswith(("http://", "https://")):
parsed = urlparse(endpoint)
host = parsed.hostname or "localhost"
port = parsed.port or default_port
insecure = parsed.scheme == "http"
return f"{host}:{port}", insecure, host, port
host = endpoint
port = default_port
if ":" in endpoint:
parts = endpoint.rsplit(":", 1)
host = parts[0] or "localhost"
try:
port = int(parts[1])
except Exception:
port = default_port
insecure = ("localhost" in host) or ("127.0.0.1" in host)
return f"{host}:{port}", insecure, host, port
except Exception:
host, port = "localhost", default_port
return f"{host}:{port}", True, host, port

View File

@ -0,0 +1 @@
# Tencent trace entities module

View File

@ -0,0 +1,73 @@
from enum import Enum
# public
GEN_AI_SESSION_ID = "gen_ai.session.id"
GEN_AI_USER_ID = "gen_ai.user.id"
GEN_AI_USER_NAME = "gen_ai.user.name"
GEN_AI_SPAN_KIND = "gen_ai.span.kind"
GEN_AI_FRAMEWORK = "gen_ai.framework"
GEN_AI_IS_ENTRY = "gen_ai.is_entry" # mark to count the LLM-related traces
# Chain
INPUT_VALUE = "gen_ai.entity.input"
OUTPUT_VALUE = "gen_ai.entity.output"
# Retriever
RETRIEVAL_QUERY = "retrieval.query"
RETRIEVAL_DOCUMENT = "retrieval.document"
# GENERATION
GEN_AI_MODEL_NAME = "gen_ai.response.model"
GEN_AI_PROVIDER = "gen_ai.provider.name"
GEN_AI_USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
GEN_AI_USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
GEN_AI_PROMPT_TEMPLATE_TEMPLATE = "gen_ai.prompt_template.template"
GEN_AI_PROMPT_TEMPLATE_VARIABLE = "gen_ai.prompt_template.variable"
GEN_AI_PROMPT = "gen_ai.prompt"
GEN_AI_COMPLETION = "gen_ai.completion"
GEN_AI_RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason"
# Tool
TOOL_NAME = "tool.name"
TOOL_DESCRIPTION = "tool.description"
TOOL_PARAMETERS = "tool.parameters"
# Instrumentation Library
INSTRUMENTATION_NAME = "dify-sdk"
INSTRUMENTATION_VERSION = "0.1.0"
INSTRUMENTATION_LANGUAGE = "python"
# Metrics
LLM_OPERATION_DURATION = "gen_ai.client.operation.duration"
class GenAISpanKind(Enum):
WORKFLOW = "WORKFLOW" # OpenLLMetry
RETRIEVER = "RETRIEVER" # RAG
GENERATION = "GENERATION" # Langfuse
TOOL = "TOOL" # OpenLLMetry
AGENT = "AGENT" # OpenLLMetry
TASK = "TASK" # OpenLLMetry

View File

@ -0,0 +1,21 @@
from collections.abc import Sequence
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import Status, StatusCode
from pydantic import BaseModel, Field
class SpanData(BaseModel):
model_config = {"arbitrary_types_allowed": True}
trace_id: int = Field(..., description="The unique identifier for the trace.")
parent_span_id: int | None = Field(None, description="The ID of the parent span, if any.")
span_id: int = Field(..., description="The unique identifier for this span.")
name: str = Field(..., description="The name of the span.")
attributes: dict[str, str] = Field(default_factory=dict, description="Attributes associated with the span.")
events: Sequence[Event] = Field(default_factory=list, description="Events recorded in the span.")
links: Sequence[trace_api.Link] = Field(default_factory=list, description="Links to other spans.")
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
start_time: int = Field(..., description="The start time of the span in nanoseconds.")
end_time: int = Field(..., description="The end time of the span in nanoseconds.")

View File

@ -0,0 +1,372 @@
"""
Tencent APM Span Builder - handles all span construction logic
"""
import json
import logging
from datetime import datetime
from opentelemetry.trace import Status, StatusCode
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
MessageTraceInfo,
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.ops.tencent_trace.entities.tencent_semconv import (
GEN_AI_COMPLETION,
GEN_AI_FRAMEWORK,
GEN_AI_IS_ENTRY,
GEN_AI_MODEL_NAME,
GEN_AI_PROMPT,
GEN_AI_PROVIDER,
GEN_AI_RESPONSE_FINISH_REASON,
GEN_AI_SESSION_ID,
GEN_AI_SPAN_KIND,
GEN_AI_USAGE_INPUT_TOKENS,
GEN_AI_USAGE_OUTPUT_TOKENS,
GEN_AI_USAGE_TOTAL_TOKENS,
GEN_AI_USER_ID,
INPUT_VALUE,
OUTPUT_VALUE,
RETRIEVAL_DOCUMENT,
RETRIEVAL_QUERY,
TOOL_DESCRIPTION,
TOOL_NAME,
TOOL_PARAMETERS,
GenAISpanKind,
)
from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
from core.ops.tencent_trace.utils import TencentTraceUtils
from core.rag.models.document import Document
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
logger = logging.getLogger(__name__)
class TencentSpanBuilder:
"""Builder class for constructing different types of spans"""
@staticmethod
def _get_time_nanoseconds(time_value: datetime | None) -> int:
"""Convert datetime to nanoseconds for span creation."""
return TencentTraceUtils.convert_datetime_to_nanoseconds(time_value)
@staticmethod
def build_workflow_spans(
trace_info: WorkflowTraceInfo, trace_id: int, user_id: str, links: list | None = None
) -> list[SpanData]:
"""Build workflow-related spans"""
spans = []
links = links or []
message_span_id = None
workflow_span_id = TencentTraceUtils.convert_to_span_id(trace_info.workflow_run_id, "workflow")
if hasattr(trace_info, "metadata") and trace_info.metadata.get("conversation_id"):
message_span_id = TencentTraceUtils.convert_to_span_id(trace_info.workflow_run_id, "message")
status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
if message_span_id:
message_span = TencentSpanBuilder._build_message_span(
trace_info, trace_id, message_span_id, user_id, status, links
)
spans.append(message_span)
workflow_span = TencentSpanBuilder._build_workflow_span(
trace_info, trace_id, workflow_span_id, message_span_id, user_id, status, links
)
spans.append(workflow_span)
return spans
@staticmethod
def _build_message_span(
trace_info: WorkflowTraceInfo, trace_id: int, message_span_id: int, user_id: str, status: Status, links: list
) -> SpanData:
"""Build message span for chatflow"""
return SpanData(
trace_id=trace_id,
parent_span_id=None,
span_id=message_span_id,
name="message",
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.WORKFLOW.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_IS_ENTRY: "true",
INPUT_VALUE: trace_info.workflow_run_inputs.get("sys.query", ""),
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
},
status=status,
links=links,
)
@staticmethod
def _build_workflow_span(
trace_info: WorkflowTraceInfo,
trace_id: int,
workflow_span_id: int,
message_span_id: int | None,
user_id: str,
status: Status,
links: list,
) -> SpanData:
"""Build workflow span"""
attributes = {
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.WORKFLOW.value,
GEN_AI_FRAMEWORK: "dify",
INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
}
if message_span_id is None:
attributes[GEN_AI_IS_ENTRY] = "true"
return SpanData(
trace_id=trace_id,
parent_span_id=message_span_id,
span_id=workflow_span_id,
name="workflow",
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
attributes=attributes,
status=status,
links=links,
)
@staticmethod
def build_workflow_llm_span(
trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
) -> SpanData:
"""Build LLM span for workflow nodes."""
process_data = node_execution.process_data or {}
outputs = node_execution.outputs or {}
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
span_id=TencentTraceUtils.convert_to_span_id(node_execution.id, "node"),
name="GENERATION",
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_SPAN_KIND: GenAISpanKind.GENERATION.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_MODEL_NAME: process_data.get("model_name", ""),
GEN_AI_PROVIDER: process_data.get("model_provider", ""),
GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)),
GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)),
GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)),
GEN_AI_PROMPT: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
GEN_AI_COMPLETION: str(outputs.get("text", "")),
GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason", ""),
INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
OUTPUT_VALUE: str(outputs.get("text", "")),
},
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
)
@staticmethod
def build_message_span(
trace_info: MessageTraceInfo, trace_id: int, user_id: str, links: list | None = None
) -> SpanData:
"""Build message span."""
links = links or []
status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
return SpanData(
trace_id=trace_id,
parent_span_id=None,
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message"),
name="message",
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.WORKFLOW.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_IS_ENTRY: "true",
INPUT_VALUE: str(trace_info.inputs or ""),
OUTPUT_VALUE: str(trace_info.outputs or ""),
},
status=status,
links=links,
)
@staticmethod
def build_tool_span(trace_info: ToolTraceInfo, trace_id: int, parent_span_id: int) -> SpanData:
"""Build tool span."""
status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
return SpanData(
trace_id=trace_id,
parent_span_id=parent_span_id,
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "tool"),
name=trace_info.tool_name,
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
GEN_AI_FRAMEWORK: "dify",
TOOL_NAME: trace_info.tool_name,
TOOL_DESCRIPTION: "",
TOOL_PARAMETERS: json.dumps(trace_info.tool_parameters, ensure_ascii=False),
INPUT_VALUE: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
OUTPUT_VALUE: str(trace_info.tool_outputs),
},
status=status,
)
@staticmethod
def build_retrieval_span(trace_info: DatasetRetrievalTraceInfo, trace_id: int, parent_span_id: int) -> SpanData:
"""Build dataset retrieval span."""
status = Status(StatusCode.OK)
if getattr(trace_info, "error", None):
status = Status(StatusCode.ERROR, trace_info.error) # type: ignore[arg-type]
documents_data = TencentSpanBuilder._extract_retrieval_documents(trace_info.documents)
return SpanData(
trace_id=trace_id,
parent_span_id=parent_span_id,
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "retrieval"),
name="retrieval",
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
GEN_AI_FRAMEWORK: "dify",
RETRIEVAL_QUERY: str(trace_info.inputs or ""),
RETRIEVAL_DOCUMENT: json.dumps(documents_data, ensure_ascii=False),
INPUT_VALUE: str(trace_info.inputs or ""),
OUTPUT_VALUE: json.dumps(documents_data, ensure_ascii=False),
},
status=status,
)
@staticmethod
def _get_workflow_node_status(node_execution: WorkflowNodeExecution) -> Status:
"""Get workflow node execution status."""
if node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED:
return Status(StatusCode.OK)
elif node_execution.status in [WorkflowNodeExecutionStatus.FAILED, WorkflowNodeExecutionStatus.EXCEPTION]:
return Status(StatusCode.ERROR, str(node_execution.error))
return Status(StatusCode.UNSET)
@staticmethod
def build_workflow_retrieval_span(
trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
) -> SpanData:
"""Build knowledge retrieval span for workflow nodes."""
input_value = ""
if node_execution.inputs:
input_value = str(node_execution.inputs.get("query", ""))
output_value = ""
if node_execution.outputs:
output_value = json.dumps(node_execution.outputs.get("result", []), ensure_ascii=False)
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
span_id=TencentTraceUtils.convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
GEN_AI_FRAMEWORK: "dify",
RETRIEVAL_QUERY: input_value,
RETRIEVAL_DOCUMENT: output_value,
INPUT_VALUE: input_value,
OUTPUT_VALUE: output_value,
},
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
)
@staticmethod
def build_workflow_tool_span(
trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
) -> SpanData:
"""Build tool span for workflow nodes."""
tool_des = {}
if node_execution.metadata:
tool_des = node_execution.metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO, {})
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
span_id=TencentTraceUtils.convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
GEN_AI_FRAMEWORK: "dify",
TOOL_NAME: node_execution.title,
TOOL_DESCRIPTION: json.dumps(tool_des, ensure_ascii=False),
TOOL_PARAMETERS: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
INPUT_VALUE: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
},
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
)
@staticmethod
def build_workflow_task_span(
trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
) -> SpanData:
"""Build generic task span for workflow nodes."""
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
span_id=TencentTraceUtils.convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_SPAN_KIND: GenAISpanKind.TASK.value,
GEN_AI_FRAMEWORK: "dify",
INPUT_VALUE: json.dumps(node_execution.inputs, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
},
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
)
@staticmethod
def _extract_retrieval_documents(documents: list[Document]):
"""Extract documents data for retrieval tracing."""
documents_data = []
for document in documents:
document_data = {
"content": document.page_content,
"metadata": {
"dataset_id": document.metadata.get("dataset_id"),
"doc_id": document.metadata.get("doc_id"),
"document_id": document.metadata.get("document_id"),
},
"score": document.metadata.get("score"),
}
documents_data.append(document_data)
return documents_data

View File

@ -0,0 +1,317 @@
"""
Tencent APM tracing implementation with separated concerns
"""
import logging
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import TencentConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.ops.tencent_trace.client import TencentTraceClient
from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
from core.ops.tencent_trace.span_builder import TencentSpanBuilder
from core.ops.tencent_trace.utils import TencentTraceUtils
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
)
from core.workflow.nodes import NodeType
from extensions.ext_database import db
from models import Account, App, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
class TencentDataTrace(BaseTraceInstance):
"""
Tencent APM trace implementation with single responsibility principle.
Acts as a coordinator that delegates specific tasks to specialized classes.
"""
def __init__(self, tencent_config: TencentConfig):
super().__init__(tencent_config)
self.trace_client = TencentTraceClient(
service_name=tencent_config.service_name,
endpoint=tencent_config.endpoint,
token=tencent_config.token,
metrics_export_interval_sec=5,
)
def trace(self, trace_info: BaseTraceInfo) -> None:
"""Main tracing entry point - coordinates different trace types."""
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
elif isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info)
elif isinstance(trace_info, ModerationTraceInfo):
pass
elif isinstance(trace_info, SuggestedQuestionTraceInfo):
self.suggested_question_trace(trace_info)
elif isinstance(trace_info, DatasetRetrievalTraceInfo):
self.dataset_retrieval_trace(trace_info)
elif isinstance(trace_info, ToolTraceInfo):
self.tool_trace(trace_info)
elif isinstance(trace_info, GenerateNameTraceInfo):
pass
def api_check(self) -> bool:
return self.trace_client.api_check()
def get_project_url(self) -> str:
return self.trace_client.get_project_url()
def workflow_trace(self, trace_info: WorkflowTraceInfo) -> None:
"""Handle workflow tracing by coordinating data retrieval and span construction."""
try:
trace_id = TencentTraceUtils.convert_to_trace_id(trace_info.workflow_run_id)
links = []
if trace_info.trace_id:
links.append(TencentTraceUtils.create_link(trace_info.trace_id))
user_id = self._get_user_id(trace_info)
workflow_spans = TencentSpanBuilder.build_workflow_spans(trace_info, trace_id, str(user_id), links)
for span in workflow_spans:
self.trace_client.add_span(span)
self._process_workflow_nodes(trace_info, trace_id)
except Exception:
logger.exception("[Tencent APM] Failed to process workflow trace")
def message_trace(self, trace_info: MessageTraceInfo) -> None:
"""Handle message tracing."""
try:
trace_id = TencentTraceUtils.convert_to_trace_id(trace_info.message_id)
user_id = self._get_user_id(trace_info)
links = []
if trace_info.trace_id:
links.append(TencentTraceUtils.create_link(trace_info.trace_id))
message_span = TencentSpanBuilder.build_message_span(trace_info, trace_id, str(user_id), links)
self.trace_client.add_span(message_span)
except Exception:
logger.exception("[Tencent APM] Failed to process message trace")
def tool_trace(self, trace_info: ToolTraceInfo) -> None:
"""Handle tool tracing."""
try:
parent_span_id = None
trace_root_id = None
if trace_info.message_id:
parent_span_id = TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message")
trace_root_id = trace_info.message_id
if parent_span_id and trace_root_id:
trace_id = TencentTraceUtils.convert_to_trace_id(trace_root_id)
tool_span = TencentSpanBuilder.build_tool_span(trace_info, trace_id, parent_span_id)
self.trace_client.add_span(tool_span)
except Exception:
logger.exception("[Tencent APM] Failed to process tool trace")
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo) -> None:
"""Handle dataset retrieval tracing."""
try:
parent_span_id = None
trace_root_id = None
if trace_info.message_id:
parent_span_id = TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message")
trace_root_id = trace_info.message_id
if parent_span_id and trace_root_id:
trace_id = TencentTraceUtils.convert_to_trace_id(trace_root_id)
retrieval_span = TencentSpanBuilder.build_retrieval_span(trace_info, trace_id, parent_span_id)
self.trace_client.add_span(retrieval_span)
except Exception:
logger.exception("[Tencent APM] Failed to process dataset retrieval trace")
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo) -> None:
"""Handle suggested question tracing"""
try:
logger.info("[Tencent APM] Processing suggested question trace")
except Exception:
logger.exception("[Tencent APM] Failed to process suggested question trace")
def _process_workflow_nodes(self, trace_info: WorkflowTraceInfo, trace_id: int) -> None:
"""Process workflow node executions."""
try:
workflow_span_id = TencentTraceUtils.convert_to_span_id(trace_info.workflow_run_id, "workflow")
node_executions = self._get_workflow_node_executions(trace_info)
for node_execution in node_executions:
try:
node_span = self._build_workflow_node_span(node_execution, trace_id, trace_info, workflow_span_id)
if node_span:
self.trace_client.add_span(node_span)
if node_execution.node_type == NodeType.LLM:
self._record_llm_metrics(node_execution)
except Exception:
logger.exception("[Tencent APM] Failed to process node execution: %s", node_execution.id)
except Exception:
logger.exception("[Tencent APM] Failed to process workflow nodes")
def _build_workflow_node_span(
self, node_execution: WorkflowNodeExecution, trace_id: int, trace_info: WorkflowTraceInfo, workflow_span_id: int
) -> SpanData | None:
"""Build span for different node types"""
try:
if node_execution.node_type == NodeType.LLM:
return TencentSpanBuilder.build_workflow_llm_span(
trace_id, workflow_span_id, trace_info, node_execution
)
elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
return TencentSpanBuilder.build_workflow_retrieval_span(
trace_id, workflow_span_id, trace_info, node_execution
)
elif node_execution.node_type == NodeType.TOOL:
return TencentSpanBuilder.build_workflow_tool_span(
trace_id, workflow_span_id, trace_info, node_execution
)
else:
# Handle all other node types as generic tasks
return TencentSpanBuilder.build_workflow_task_span(
trace_id, workflow_span_id, trace_info, node_execution
)
except Exception:
logger.debug(
"[Tencent APM] Error building span for node %s: %s",
node_execution.id,
node_execution.node_type,
exc_info=True,
)
return None
def _get_workflow_node_executions(self, trace_info: WorkflowTraceInfo) -> list[WorkflowNodeExecution]:
"""Retrieve workflow node executions from database."""
try:
session_maker = sessionmaker(bind=db.engine)
with Session(db.engine, expire_on_commit=False) as session:
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
app_stmt = select(App).where(App.id == app_id)
app = session.scalar(app_stmt)
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator")
account_stmt = select(Account).where(Account.id == app.created_by)
service_account = session.scalar(account_stmt)
if not service_account:
raise ValueError(f"Creator account not found for app {app_id}")
current_tenant = (
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
)
if not current_tenant:
raise ValueError(f"Current tenant not found for account {service_account.id}")
service_account.set_tenant_id(current_tenant.tenant_id)
repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_maker,
user=service_account,
app_id=trace_info.metadata.get("app_id"),
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
executions = repository.get_by_workflow_run(workflow_run_id=trace_info.workflow_run_id)
return list(executions)
except Exception:
logger.exception("[Tencent APM] Failed to get workflow node executions")
return []
def _get_user_id(self, trace_info: BaseTraceInfo) -> str:
"""Get user ID from trace info."""
try:
tenant_id = None
user_id = None
if isinstance(trace_info, (WorkflowTraceInfo, GenerateNameTraceInfo)):
tenant_id = trace_info.tenant_id
if hasattr(trace_info, "metadata") and trace_info.metadata:
user_id = trace_info.metadata.get("user_id")
if user_id and tenant_id:
stmt = (
select(Account.name)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
.where(Account.id == user_id, TenantAccountJoin.tenant_id == tenant_id)
)
session_maker = sessionmaker(bind=db.engine)
with session_maker() as session:
account_name = session.scalar(stmt)
return account_name or str(user_id)
elif user_id:
return str(user_id)
return "anonymous"
except Exception:
logger.exception("[Tencent APM] Failed to get user ID")
return "unknown"
def _record_llm_metrics(self, node_execution: WorkflowNodeExecution) -> None:
"""Record LLM performance metrics"""
try:
if not hasattr(self.trace_client, "record_llm_duration"):
return
process_data = node_execution.process_data or {}
usage = process_data.get("usage", {})
latency_s = float(usage.get("latency", 0.0))
if latency_s > 0:
attributes = {
"provider": process_data.get("model_provider", ""),
"model": process_data.get("model_name", ""),
"span_kind": "GENERATION",
}
self.trace_client.record_llm_duration(latency_s, attributes)
except Exception:
logger.debug("[Tencent APM] Failed to record LLM metrics")
def __del__(self):
"""Ensure proper cleanup on garbage collection."""
try:
if hasattr(self, "trace_client"):
self.trace_client.shutdown()
except Exception:
pass

View File

@ -0,0 +1,65 @@
"""
Utility functions for Tencent APM tracing
"""
import hashlib
import random
import uuid
from datetime import datetime
from typing import cast
from opentelemetry.trace import Link, SpanContext, TraceFlags
class TencentTraceUtils:
"""Utility class for common tracing operations."""
INVALID_SPAN_ID = 0x0000000000000000
INVALID_TRACE_ID = 0x00000000000000000000000000000000
@staticmethod
def convert_to_trace_id(uuid_v4: str | None) -> int:
try:
uuid_obj = uuid.UUID(uuid_v4) if uuid_v4 else uuid.uuid4()
except Exception as e:
raise ValueError(f"Invalid UUID input: {e}")
return cast(int, uuid_obj.int)
@staticmethod
def convert_to_span_id(uuid_v4: str | None, span_type: str) -> int:
try:
uuid_obj = uuid.UUID(uuid_v4) if uuid_v4 else uuid.uuid4()
except Exception as e:
raise ValueError(f"Invalid UUID input: {e}")
combined_key = f"{uuid_obj.hex}-{span_type}"
hash_bytes = hashlib.sha256(combined_key.encode("utf-8")).digest()
return int.from_bytes(hash_bytes[:8], byteorder="big", signed=False)
@staticmethod
def generate_span_id() -> int:
span_id = random.getrandbits(64)
while span_id == TencentTraceUtils.INVALID_SPAN_ID:
span_id = random.getrandbits(64)
return span_id
@staticmethod
def convert_datetime_to_nanoseconds(start_time: datetime | None) -> int:
if start_time is None:
start_time = datetime.now()
timestamp_in_seconds = start_time.timestamp()
return int(timestamp_in_seconds * 1e9)
@staticmethod
def create_link(trace_id_str: str) -> Link:
try:
trace_id = int(trace_id_str, 16) if len(trace_id_str) == 32 else cast(int, uuid.UUID(trace_id_str).int)
except (ValueError, TypeError):
trace_id = cast(int, uuid.uuid4().int)
span_context = SpanContext(
trace_id=trace_id,
span_id=TencentTraceUtils.INVALID_SPAN_ID,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
)
return Link(span_context)

View File

@ -2,7 +2,7 @@ import inspect
import json
import logging
from collections.abc import Callable, Generator
from typing import Any, TypeVar
from typing import Any, TypeVar, cast
import httpx
from pydantic import BaseModel
@ -31,6 +31,17 @@ from core.plugin.impl.exc import (
)
plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL))
_plugin_daemon_timeout_config = cast(
float | httpx.Timeout | None,
getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 300.0),
)
plugin_daemon_request_timeout: httpx.Timeout | None
if _plugin_daemon_timeout_config is None:
plugin_daemon_request_timeout = None
elif isinstance(_plugin_daemon_timeout_config, httpx.Timeout):
plugin_daemon_request_timeout = _plugin_daemon_timeout_config
else:
plugin_daemon_request_timeout = httpx.Timeout(_plugin_daemon_timeout_config)
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
@ -58,6 +69,7 @@ class BasePluginClient:
"headers": headers,
"params": params,
"files": files,
"timeout": plugin_daemon_request_timeout,
}
if isinstance(prepared_data, dict):
request_kwargs["data"] = prepared_data
@ -116,6 +128,7 @@ class BasePluginClient:
"headers": headers,
"params": params,
"files": files,
"timeout": plugin_daemon_request_timeout,
}
if isinstance(prepared_data, dict):
stream_kwargs["data"] = prepared_data

View File

@ -1,9 +1,24 @@
"""
Weaviate vector database implementation for Dify's RAG system.
This module provides integration with Weaviate vector database for storing and retrieving
document embeddings used in retrieval-augmented generation workflows.
"""
import datetime
import json
import logging
import uuid as _uuid
from typing import Any
from urllib.parse import urlparse
import weaviate # type: ignore
import weaviate
import weaviate.classes.config as wc
from pydantic import BaseModel, model_validator
from weaviate.classes.data import DataObject
from weaviate.classes.init import Auth
from weaviate.classes.query import Filter, MetadataQuery
from weaviate.exceptions import UnexpectedStatusCodeError
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@ -15,265 +30,394 @@ from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class WeaviateConfig(BaseModel):
"""
Configuration model for Weaviate connection settings.
Attributes:
endpoint: Weaviate server endpoint URL
api_key: Optional API key for authentication
batch_size: Number of objects to batch per insert operation
"""
endpoint: str
api_key: str | None = None
batch_size: int = 100
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
def validate_config(cls, values: dict) -> dict:
"""Validates that required configuration values are present."""
if not values["endpoint"]:
raise ValueError("config WEAVIATE_ENDPOINT is required")
return values
class WeaviateVector(BaseVector):
"""
Weaviate vector database implementation for document storage and retrieval.
Handles creation, insertion, deletion, and querying of document embeddings
in a Weaviate collection.
"""
def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
"""
Initializes the Weaviate vector store.
Args:
collection_name: Name of the Weaviate collection
config: Weaviate configuration settings
attributes: List of metadata attributes to store
"""
super().__init__(collection_name)
self._client = self._init_client(config)
self._attributes = attributes
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
auth_config = weaviate.AuthApiKey(api_key=config.api_key or "")
def _init_client(self, config: WeaviateConfig) -> weaviate.WeaviateClient:
"""
Initializes and returns a connected Weaviate client.
weaviate.connect.connection.has_grpc = False # ty: ignore [unresolved-attribute]
Configures both HTTP and gRPC connections with proper authentication.
"""
p = urlparse(config.endpoint)
host = p.hostname or config.endpoint.replace("https://", "").replace("http://", "")
http_secure = p.scheme == "https"
http_port = p.port or (443 if http_secure else 80)
try:
client = weaviate.Client(
url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None
)
except Exception as exc:
raise ConnectionError("Vector database connection error") from exc
grpc_host = host
grpc_secure = http_secure
grpc_port = 443 if grpc_secure else 50051
client.batch.configure(
# `batch_size` takes an `int` value to enable auto-batching
# (`None` is used for manual batching)
batch_size=config.batch_size,
# dynamically update the `batch_size` based on import speed
dynamic=True,
# `timeout_retries` takes an `int` value to retry on time outs
timeout_retries=3,
client = weaviate.connect_to_custom(
http_host=host,
http_port=http_port,
http_secure=http_secure,
grpc_host=grpc_host,
grpc_port=grpc_port,
grpc_secure=grpc_secure,
auth_credentials=Auth.api_key(config.api_key) if config.api_key else None,
)
if not client.is_ready():
raise ConnectionError("Vector database is not ready")
return client
def get_type(self) -> str:
"""Returns the vector database type identifier."""
return VectorType.WEAVIATE
def get_collection_name(self, dataset: Dataset) -> str:
"""
Retrieves or generates the collection name for a dataset.
Uses existing index structure if available, otherwise generates from dataset ID.
"""
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
if not class_prefix.endswith("_Node"):
# original class_prefix
class_prefix += "_Node"
return class_prefix
dataset_id = dataset.id
return Dataset.gen_collection_name_by_id(dataset_id)
def to_index_struct(self):
def to_index_struct(self) -> dict:
"""Returns the index structure dictionary for persistence."""
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
# create collection
"""
Creates a new collection and adds initial documents with embeddings.
"""
self._create_collection()
# create vector
self.add_texts(texts, embeddings)
def _create_collection(self):
"""
Creates the Weaviate collection with required schema if it doesn't exist.
Uses Redis locking to prevent concurrent creation attempts.
"""
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(cache_key):
return
schema = self._default_schema(self._collection_name)
if not self._client.schema.contains(schema):
# create collection
self._client.schema.create_class(schema)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
try:
if not self._client.collections.exists(self._collection_name):
self._client.collections.create(
name=self._collection_name,
properties=[
wc.Property(
name=Field.TEXT_KEY.value,
data_type=wc.DataType.TEXT,
tokenization=wc.Tokenization.WORD,
),
wc.Property(name="document_id", data_type=wc.DataType.TEXT),
wc.Property(name="doc_id", data_type=wc.DataType.TEXT),
wc.Property(name="chunk_index", data_type=wc.DataType.INT),
],
vector_config=wc.Configure.Vectors.self_provided(),
)
self._ensure_properties()
redis_client.set(cache_key, 1, ex=3600)
except Exception as e:
logger.exception("Error creating collection %s", self._collection_name)
raise
def _ensure_properties(self) -> None:
"""
Ensures all required properties exist in the collection schema.
Adds missing properties if the collection exists but lacks them.
"""
if not self._client.collections.exists(self._collection_name):
return
col = self._client.collections.use(self._collection_name)
cfg = col.config.get()
existing = {p.name for p in (cfg.properties or [])}
to_add = []
if "document_id" not in existing:
to_add.append(wc.Property(name="document_id", data_type=wc.DataType.TEXT))
if "doc_id" not in existing:
to_add.append(wc.Property(name="doc_id", data_type=wc.DataType.TEXT))
if "chunk_index" not in existing:
to_add.append(wc.Property(name="chunk_index", data_type=wc.DataType.INT))
for prop in to_add:
try:
col.config.add_property(prop)
except Exception as e:
logger.warning("Could not add property %s: %s", prop.name, e)
def _get_uuids(self, documents: list[Document]) -> list[str]:
"""
Generates deterministic UUIDs for documents based on their content.
Uses UUID5 with URL namespace to ensure consistent IDs for identical content.
"""
URL_NAMESPACE = _uuid.UUID("6ba7b811-9dad-11d1-80b4-00c04fd430c8")
uuids = []
for doc in documents:
uuid_val = _uuid.uuid5(URL_NAMESPACE, doc.page_content)
uuids.append(str(uuid_val))
return uuids
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
"""
Adds documents with their embeddings to the collection.
Batches insertions for efficiency and returns the list of inserted object IDs.
"""
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
ids = []
col = self._client.collections.use(self._collection_name)
objs: list[DataObject] = []
ids_out: list[str] = []
with self._client.batch as batch:
for i, text in enumerate(texts):
data_properties = {Field.TEXT_KEY: text}
if metadatas is not None:
# metadata maybe None
for key, val in (metadatas[i] or {}).items():
data_properties[key] = self._json_serializable(val)
for i, text in enumerate(texts):
props: dict[str, Any] = {Field.TEXT_KEY.value: text}
meta = metadatas[i] or {}
for k, v in meta.items():
props[k] = self._json_serializable(v)
batch.add_data_object(
data_object=data_properties,
class_name=self._collection_name,
uuid=uuids[i],
vector=embeddings[i] if embeddings else None,
candidate = uuids[i] if uuids else None
uid = candidate if (candidate and self._is_uuid(candidate)) else str(_uuid.uuid4())
ids_out.append(uid)
vec_payload = None
if embeddings and i < len(embeddings) and embeddings[i]:
vec_payload = {"default": embeddings[i]}
objs.append(
DataObject(
uuid=uid,
properties=props, # type: ignore[arg-type] # mypy incorrectly infers DataObject signature
vector=vec_payload,
)
ids.append(uuids[i])
return ids
)
def delete_by_metadata_field(self, key: str, value: str):
# check whether the index already exists
schema = self._default_schema(self._collection_name)
if self._client.schema.contains(schema):
where_filter = {"operator": "Equal", "path": [key], "valueText": value}
batch_size = max(1, int(dify_config.WEAVIATE_BATCH_SIZE or 100))
with col.batch.dynamic() as batch:
for obj in objs:
batch.add_object(properties=obj.properties, uuid=obj.uuid, vector=obj.vector)
self._client.batch.delete_objects(class_name=self._collection_name, where=where_filter, output="minimal")
return ids_out
def _is_uuid(self, val: str) -> bool:
"""Validates whether a string is a valid UUID format."""
try:
_uuid.UUID(str(val))
return True
except Exception:
return False
def delete_by_metadata_field(self, key: str, value: str) -> None:
"""Deletes all objects matching a specific metadata field value."""
if not self._client.collections.exists(self._collection_name):
return
col = self._client.collections.use(self._collection_name)
col.data.delete_many(where=Filter.by_property(key).equal(value))
def delete(self):
# check whether the index already exists
schema = self._default_schema(self._collection_name)
if self._client.schema.contains(schema):
self._client.schema.delete_class(self._collection_name)
"""Deletes the entire collection from Weaviate."""
if self._client.collections.exists(self._collection_name):
self._client.collections.delete(self._collection_name)
def text_exists(self, id: str) -> bool:
collection_name = self._collection_name
schema = self._default_schema(self._collection_name)
# check whether the index already exists
if not self._client.schema.contains(schema):
"""Checks if a document with the given doc_id exists in the collection."""
if not self._client.collections.exists(self._collection_name):
return False
result = (
self._client.query.get(collection_name)
.with_additional(["id"])
.with_where(
{
"path": ["doc_id"],
"operator": "Equal",
"valueText": id,
}
)
.with_limit(1)
.do()
col = self._client.collections.use(self._collection_name)
res = col.query.fetch_objects(
filters=Filter.by_property("doc_id").equal(id),
limit=1,
return_properties=["doc_id"],
)
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
return len(res.objects) > 0
entries = result["data"]["Get"][collection_name]
if len(entries) == 0:
return False
def delete_by_ids(self, ids: list[str]) -> None:
"""
Deletes objects by their UUID identifiers.
return True
Silently ignores 404 errors for non-existent IDs.
"""
if not self._client.collections.exists(self._collection_name):
return
def delete_by_ids(self, ids: list[str]):
# check whether the index already exists
schema = self._default_schema(self._collection_name)
if self._client.schema.contains(schema):
for uuid in ids:
try:
self._client.data_object.delete(
class_name=self._collection_name,
uuid=uuid,
)
except weaviate.UnexpectedStatusCodeException as e:
# tolerate not found error
if e.status_code != 404:
raise e
col = self._client.collections.use(self._collection_name)
for uid in ids:
try:
col.data.delete_by_id(uid)
except UnexpectedStatusCodeError as e:
if getattr(e, "status_code", None) != 404:
raise
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""Look up similar documents by embedding vector in Weaviate."""
collection_name = self._collection_name
properties = self._attributes
properties.append(Field.TEXT_KEY)
query_obj = self._client.query.get(collection_name, properties)
"""
Performs vector similarity search using the provided query vector.
vector = {"vector": query_vector}
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
operands = []
for document_id_filter in document_ids_filter:
operands.append({"path": ["document_id"], "operator": "Equal", "valueText": document_id_filter})
where_filter = {"operator": "Or", "operands": operands}
query_obj = query_obj.with_where(where_filter)
result = (
query_obj.with_near_vector(vector)
.with_limit(kwargs.get("top_k", 4))
.with_additional(["vector", "distance"])
.do()
Filters by document IDs if provided and applies score threshold.
Returns documents sorted by relevance score.
"""
if not self._client.collections.exists(self._collection_name):
return []
col = self._client.collections.use(self._collection_name)
props = list({*self._attributes, "document_id", Field.TEXT_KEY.value})
where = None
doc_ids = kwargs.get("document_ids_filter") or []
if doc_ids:
ors = [Filter.by_property("document_id").equal(x) for x in doc_ids]
where = ors[0]
for f in ors[1:]:
where = where | f
top_k = int(kwargs.get("top_k", 4))
score_threshold = float(kwargs.get("score_threshold") or 0.0)
res = col.query.near_vector(
near_vector=query_vector,
limit=top_k,
return_properties=props,
return_metadata=MetadataQuery(distance=True),
include_vector=False,
filters=where,
target_vector="default",
)
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs_and_scores = []
for res in result["data"]["Get"][collection_name]:
text = res.pop(Field.TEXT_KEY)
score = 1 - res["_additional"]["distance"]
docs_and_scores.append((Document(page_content=text, metadata=res), score))
docs: list[Document] = []
for obj in res.objects:
properties = dict(obj.properties or {})
text = properties.pop(Field.TEXT_KEY.value, "")
distance = (obj.metadata.distance if obj.metadata else None) or 1.0
score = 1.0 - distance
docs = []
for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
# check score threshold
if score >= score_threshold:
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
if score > score_threshold:
properties["score"] = score
docs.append(Document(page_content=text, metadata=properties))
docs.sort(key=lambda d: d.metadata.get("score", 0.0), reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""Return docs using BM25F.
Args:
query: Text to look up documents similar to.
Returns:
List of Documents most similar to the query.
"""
collection_name = self._collection_name
content: dict[str, Any] = {"concepts": [query]}
properties = self._attributes
properties.append(Field.TEXT_KEY)
if kwargs.get("search_distance"):
content["certainty"] = kwargs.get("search_distance")
query_obj = self._client.query.get(collection_name, properties)
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
operands = []
for document_id_filter in document_ids_filter:
operands.append({"path": ["document_id"], "operator": "Equal", "valueText": document_id_filter})
where_filter = {"operator": "Or", "operands": operands}
query_obj = query_obj.with_where(where_filter)
query_obj = query_obj.with_additional(["vector"])
properties = ["text"]
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 4)).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs = []
for res in result["data"]["Get"][collection_name]:
text = res.pop(Field.TEXT_KEY)
additional = res.pop("_additional")
docs.append(Document(page_content=text, vector=additional["vector"], metadata=res))
Performs BM25 full-text search on document content.
Filters by document IDs if provided and returns matching documents with vectors.
"""
if not self._client.collections.exists(self._collection_name):
return []
col = self._client.collections.use(self._collection_name)
props = list({*self._attributes, Field.TEXT_KEY.value})
where = None
doc_ids = kwargs.get("document_ids_filter") or []
if doc_ids:
ors = [Filter.by_property("document_id").equal(x) for x in doc_ids]
where = ors[0]
for f in ors[1:]:
where = where | f
top_k = int(kwargs.get("top_k", 4))
res = col.query.bm25(
query=query,
query_properties=[Field.TEXT_KEY.value],
limit=top_k,
return_properties=props,
include_vector=True,
filters=where,
)
docs: list[Document] = []
for obj in res.objects:
properties = dict(obj.properties or {})
text = properties.pop(Field.TEXT_KEY.value, "")
vec = obj.vector
if isinstance(vec, dict):
vec = vec.get("default") or next(iter(vec.values()), None)
docs.append(Document(page_content=text, vector=vec, metadata=properties))
return docs
def _default_schema(self, index_name: str):
return {
"class": index_name,
"properties": [
{
"name": "text",
"dataType": ["text"],
}
],
}
def _json_serializable(self, value: Any):
def _json_serializable(self, value: Any) -> Any:
"""Converts values to JSON-serializable format, handling datetime objects."""
if isinstance(value, datetime.datetime):
return value.isoformat()
return value
class WeaviateVectorFactory(AbstractVectorFactory):
"""Factory class for creating WeaviateVector instances."""
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> WeaviateVector:
"""
Initializes a WeaviateVector instance for the given dataset.
Uses existing collection name from dataset index structure or generates a new one.
Updates dataset index structure if not already set.
"""
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
@ -281,7 +425,6 @@ class WeaviateVectorFactory(AbstractVectorFactory):
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
return WeaviateVector(
collection_name=collection_name,
config=WeaviateConfig(

View File

@ -43,8 +43,7 @@ class CacheEmbedding(Embeddings):
else:
embedding_queue_indices.append(i)
# release database connection, because embedding may take a long time
db.session.close()
# NOTE: avoid closing the shared scoped session here; downstream code may still have pending work
if embedding_queue_indices:
embedding_queue_texts = [texts[i] for i in embedding_queue_indices]

View File

@ -217,3 +217,16 @@ class Tool(ABC):
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.JSON, message=ToolInvokeMessage.JsonMessage(json_object=object)
)
def create_variable_message(
self, variable_name: str, variable_value: Any, stream: bool = False
) -> ToolInvokeMessage:
"""
create a variable message
"""
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.VARIABLE,
message=ToolInvokeMessage.VariableMessage(
variable_name=variable_name, variable_value=variable_value, stream=stream
),
)

View File

@ -4,6 +4,7 @@ from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool import ToolParameter
from core.tools.entities.common_entities import I18nObject
@ -44,10 +45,14 @@ class ToolProviderApiEntity(BaseModel):
server_url: str | None = Field(default="", description="The server url of the tool")
updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
server_identifier: str | None = Field(default="", description="The server identifier of the MCP tool")
timeout: float | None = Field(default=30.0, description="The timeout of the MCP tool")
sse_read_timeout: float | None = Field(default=300.0, description="The SSE read timeout of the MCP tool")
masked_headers: dict[str, str] | None = Field(default=None, description="The masked headers of the MCP tool")
original_headers: dict[str, str] | None = Field(default=None, description="The original headers of the MCP tool")
authentication: MCPAuthentication | None = Field(default=None, description="The OAuth config of the MCP tool")
is_dynamic_registration: bool = Field(default=True, description="Whether the MCP tool is dynamically registered")
configuration: MCPConfiguration | None = Field(
default=None, description="The timeout and sse_read_timeout of the MCP tool"
)
@field_validator("tools", mode="before")
@classmethod
@ -70,8 +75,15 @@ class ToolProviderApiEntity(BaseModel):
if self.type == ToolProviderType.MCP:
optional_fields.update(self.optional_field("updated_at", self.updated_at))
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
optional_fields.update(self.optional_field("timeout", self.timeout))
optional_fields.update(self.optional_field("sse_read_timeout", self.sse_read_timeout))
optional_fields.update(
self.optional_field(
"configuration", self.configuration.model_dump() if self.configuration else MCPConfiguration()
)
)
optional_fields.update(
self.optional_field("authentication", self.authentication.model_dump() if self.authentication else None)
)
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
optional_fields.update(self.optional_field("original_headers", self.original_headers))
return {

View File

@ -1,6 +1,6 @@
import json
from typing import Any, Self
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.types import Tool as RemoteMCPTool
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
@ -52,18 +52,25 @@ class MCPToolProviderController(ToolProviderController):
"""
from db provider
"""
tools = []
tools_data = json.loads(db_provider.tools)
remote_mcp_tools = [RemoteMCPTool.model_validate(tool) for tool in tools_data]
user = db_provider.load_user()
# Convert to entity first
provider_entity = db_provider.to_entity()
return cls.from_entity(provider_entity)
@classmethod
def from_entity(cls, entity: MCPProviderEntity) -> Self:
"""
create a MCPToolProviderController from a MCPProviderEntity
"""
remote_mcp_tools = [RemoteMCPTool(**tool) for tool in entity.tools]
tools = [
ToolEntity(
identity=ToolIdentity(
author=user.name if user else "Anonymous",
author="Anonymous", # Tool level author is not stored
name=remote_mcp_tool.name,
label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name),
provider=db_provider.server_identifier,
icon=db_provider.icon,
provider=entity.provider_id,
icon=entity.icon if isinstance(entity.icon, str) else "",
),
parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema),
description=ToolDescription(
@ -72,31 +79,32 @@ class MCPToolProviderController(ToolProviderController):
),
llm=remote_mcp_tool.description or "",
),
output_schema=remote_mcp_tool.outputSchema or {},
has_runtime_parameters=len(remote_mcp_tool.inputSchema) > 0,
)
for remote_mcp_tool in remote_mcp_tools
]
if not db_provider.icon:
if not entity.icon:
raise ValueError("Database provider icon is required")
return cls(
entity=ToolProviderEntityWithPlugin(
identity=ToolProviderIdentity(
author=user.name if user else "Anonymous",
name=db_provider.name,
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
author="Anonymous", # Provider level author is not stored in entity
name=entity.name,
label=I18nObject(en_US=entity.name, zh_Hans=entity.name),
description=I18nObject(en_US="", zh_Hans=""),
icon=db_provider.icon,
icon=entity.icon if isinstance(entity.icon, str) else "",
),
plugin_id=None,
credentials_schema=[],
tools=tools,
),
provider_id=db_provider.server_identifier or "",
tenant_id=db_provider.tenant_id or "",
server_url=db_provider.decrypted_server_url,
headers=db_provider.decrypted_headers or {},
timeout=db_provider.timeout,
sse_read_timeout=db_provider.sse_read_timeout,
provider_id=entity.provider_id,
tenant_id=entity.tenant_id,
server_url=entity.server_url,
headers=entity.headers,
timeout=entity.timeout,
sse_read_timeout=entity.sse_read_timeout,
)
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):

View File

@ -3,12 +3,14 @@ import json
from collections.abc import Generator
from typing import Any
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.mcp_client import MCPClient
from core.mcp.types import ImageContent, TextContent
from core.mcp.auth.auth_flow import auth
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPConnectionError
from core.mcp.types import CallToolResult, ImageContent, TextContent
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
from core.tools.errors import ToolInvokeError
class MCPTool(Tool):
@ -44,40 +46,32 @@ class MCPTool(Tool):
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
from core.tools.errors import ToolInvokeError
try:
with MCPClient(
self.server_url,
self.provider_id,
self.tenant_id,
authed=True,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
) as mcp_client:
tool_parameters = self._handle_none_parameter(tool_parameters)
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPAuthError as e:
raise ToolInvokeError("Please auth the tool first") from e
except MCPConnectionError as e:
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
except Exception as e:
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
result = self.invoke_remote_mcp_tool(tool_parameters)
# handle dify tool output
for content in result.content:
if isinstance(content, TextContent):
yield from self._process_text_content(content)
elif isinstance(content, ImageContent):
yield self._process_image_content(content)
# handle MCP structured output
if self.entity.output_schema and result.structuredContent:
for k, v in result.structuredContent.items():
yield self.create_variable_message(k, v)
def _process_text_content(self, content: TextContent) -> Generator[ToolInvokeMessage, None, None]:
"""Process text content and yield appropriate messages."""
try:
content_json = json.loads(content.text)
yield from self._process_json_content(content_json)
except json.JSONDecodeError:
yield self.create_text_message(content.text)
# Check if content looks like JSON before attempting to parse
text = content.text.strip()
if text and text[0] in ("{", "[") and text[-1] in ("}", "]"):
try:
content_json = json.loads(text)
yield from self._process_json_content(content_json)
return
except json.JSONDecodeError:
pass
# If not JSON or parsing failed, treat as plain text
yield self.create_text_message(content.text)
def _process_json_content(self, content_json: Any) -> Generator[ToolInvokeMessage, None, None]:
"""Process JSON content based on its type."""
@ -126,3 +120,76 @@ class MCPTool(Tool):
for key, value in parameter.items()
if value is not None and not (isinstance(value, str) and value.strip() == "")
}
def invoke_remote_mcp_tool(self, tool_parameters: dict[str, Any]) -> CallToolResult:
headers = self.headers.copy() if self.headers else {}
tool_parameters = self._handle_none_parameter(tool_parameters)
# Get provider entity to access tokens
# Get MCP service from invoke parameters or create new one
provider_entity = None
mcp_service = None
# Check if mcp_service is passed in tool_parameters
if "_mcp_service" in tool_parameters:
mcp_service = tool_parameters.pop("_mcp_service")
if mcp_service:
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
headers = provider_entity.decrypt_headers()
# Try to get existing token and add to headers
if not headers:
tokens = provider_entity.retrieve_tokens()
if tokens and tokens.access_token:
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
# Use MCPClientWithAuthRetry to handle authentication automatically
try:
with MCPClientWithAuthRetry(
server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url,
headers=headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
provider_entity=provider_entity,
auth_callback=auth if mcp_service else None,
mcp_service=mcp_service,
) as mcp_client:
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPConnectionError as e:
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
except (ValueError, TypeError, KeyError) as e:
# Catch specific exceptions that might occur during tool invocation
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
else:
# Fallback to creating service with database session
from sqlalchemy.orm import Session
from extensions.ext_database import db
from services.tools.mcp_tools_manage_service import MCPToolManageService
with Session(db.engine, expire_on_commit=False) as session:
mcp_service = MCPToolManageService(session=session)
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
headers = provider_entity.decrypt_headers()
# Try to get existing token and add to headers
if not headers:
tokens = provider_entity.retrieve_tokens()
if tokens and tokens.access_token:
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
# Use MCPClientWithAuthRetry to handle authentication automatically
try:
with MCPClientWithAuthRetry(
server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url,
headers=headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
provider_entity=provider_entity,
auth_callback=auth if mcp_service else None,
mcp_service=mcp_service,
) as mcp_client:
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPConnectionError as e:
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
except Exception as e:
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e

View File

@ -14,17 +14,32 @@ from sqlalchemy.orm import Session
from yarl import URL
import contexts
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.mcp_tool.tool import MCPTool
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool
from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.workflow.entities.variable_pool import VariablePool
from extensions.ext_database import db
from models.provider_ids import ToolProviderID
from services.enterprise.plugin_manager_service import PluginCredentialType
from services.tools.mcp_tools_manage_service import MCPToolManageService
if TYPE_CHECKING:
from core.workflow.nodes.tool.entities import ToolEntity
from configs import dify_config
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import is_filtered
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool import Tool
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.builtin_tool.tool import BuiltinTool
@ -40,21 +55,11 @@ from core.tools.entities.tool_entities import (
ToolProviderType,
)
from core.tools.errors import ToolProviderNotFoundError
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.mcp_tool.tool import MCPTool
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.provider_ids import ToolProviderID
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
from services.enterprise.plugin_manager_service import PluginCredentialType
from services.tools.mcp_tools_manage_service import MCPToolManageService
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
from services.tools.tools_transform_service import ToolTransformService
if TYPE_CHECKING:
@ -718,7 +723,9 @@ class ToolManager:
)
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
if "mcp" in filters:
mcp_providers = MCPToolManageService.retrieve_mcp_tools(tenant_id, for_list=True)
with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session)
mcp_providers = mcp_service.list_providers(tenant_id=tenant_id, for_list=True)
for mcp_provider in mcp_providers:
result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider
@ -773,17 +780,12 @@ class ToolManager:
:return: the provider controller, the credentials
"""
provider: MCPToolProvider | None = (
db.session.query(MCPToolProvider)
.where(
MCPToolProvider.server_identifier == provider_id,
MCPToolProvider.tenant_id == tenant_id,
)
.first()
)
if provider is None:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session)
try:
provider = mcp_service.get_provider(server_identifier=provider_id, tenant_id=tenant_id)
except ValueError:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
controller = MCPToolProviderController.from_db(provider)
@ -921,16 +923,15 @@ class ToolManager:
@classmethod
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str] | str:
try:
mcp_provider: MCPToolProvider | None = (
db.session.query(MCPToolProvider)
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
.first()
)
if mcp_provider is None:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
return mcp_provider.provider_icon
with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session)
try:
mcp_provider = mcp_service.get_provider_entity(
provider_id=provider_id, tenant_id=tenant_id, by_server_id=True
)
return mcp_provider.provider_icon
except ValueError:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}

View File

@ -1,10 +1,11 @@
import logging
import time as time_module
from datetime import datetime
from typing import Any
from typing import Any, cast
from pydantic import BaseModel
from sqlalchemy import update
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session
from configs import dify_config
@ -283,7 +284,7 @@ def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]
# Build and execute the update statement
stmt = update(Provider).where(*where_conditions).values(**update_values)
result = session.execute(stmt)
result = cast(CursorResult, session.execute(stmt))
rows_affected = result.rowcount
logger.debug(

View File

@ -64,7 +64,10 @@ def build_from_mapping(
config: FileUploadConfig | None = None,
strict_type_validation: bool = False,
) -> File:
transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method"))
transfer_method_value = mapping.get("transfer_method")
if not transfer_method_value:
raise ValueError("transfer_method is required in file mapping")
transfer_method = FileTransferMethod.value_of(transfer_method_value)
build_functions: dict[FileTransferMethod, Callable] = {
FileTransferMethod.LOCAL_FILE: _build_from_local_file,
@ -104,6 +107,8 @@ def build_from_mappings(
) -> Sequence[File]:
# TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query.
# Implement batch processing to reduce database load when handling multiple files.
# Filter out None/empty mappings to avoid errors
valid_mappings = [m for m in mappings if m and m.get("transfer_method")]
files = [
build_from_mapping(
mapping=mapping,
@ -111,7 +116,7 @@ def build_from_mappings(
config=config,
strict_type_validation=strict_type_validation,
)
for mapping in mappings
for mapping in valid_mappings
]
if (

View File

@ -13,6 +13,15 @@ from models.model import EndUser
#: A proxy for the current user. If no user is logged in, this will be an
#: anonymous user
current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user()))
def current_account_with_tenant():
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
assert current_user.current_tenant_id is not None, "The tenant information should be loaded."
return current_user, current_user.current_tenant_id
from typing import ParamSpec, TypeVar
P = ParamSpec("P")

View File

@ -1,16 +1,13 @@
import json
from collections.abc import Mapping
from datetime import datetime
from decimal import Decimal
from typing import TYPE_CHECKING, Any, cast
from urllib.parse import urlparse
import sqlalchemy as sa
from deprecated import deprecated
from sqlalchemy import ForeignKey, String, func
from sqlalchemy.orm import Mapped, mapped_column
from core.helper import encrypter
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
@ -21,7 +18,7 @@ from .model import Account, App, Tenant
from .types import StringUUID
if TYPE_CHECKING:
from core.mcp.types import Tool as MCPTool
from core.entities.mcp_provider import MCPProviderEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
@ -331,126 +328,36 @@ class MCPToolProvider(TypeBase):
def load_user(self) -> Account | None:
return db.session.query(Account).where(Account.id == self.user_id).first()
@property
def tenant(self) -> Tenant | None:
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
@property
def credentials(self) -> dict[str, Any]:
if not self.encrypted_credentials:
return {}
try:
return cast(dict[str, Any], json.loads(self.encrypted_credentials)) or {}
except json.JSONDecodeError:
return {}
@property
def mcp_tools(self) -> list["MCPTool"]:
from core.mcp.types import Tool as MCPTool
return [MCPTool.model_validate(tool) for tool in json.loads(self.tools)]
@property
def provider_icon(self) -> Mapping[str, str] | str:
from core.file import helpers as file_helpers
assert self.icon
try:
return json.loads(self.icon)
except json.JSONDecodeError:
return file_helpers.get_signed_file_url(self.icon)
@property
def decrypted_server_url(self) -> str:
return encrypter.decrypt_token(self.tenant_id, self.server_url)
@property
def decrypted_headers(self) -> dict[str, Any]:
"""Get decrypted headers for MCP server requests."""
from core.entities.provider_entities import BasicProviderConfig
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.tools.utils.encryption import create_provider_encrypter
try:
if not self.encrypted_headers:
return {}
headers_data = json.loads(self.encrypted_headers)
# Create dynamic config for all headers as SECRET_INPUT
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers_data]
encrypter_instance, _ = create_provider_encrypter(
tenant_id=self.tenant_id,
config=config,
cache=NoOpProviderCredentialCache(),
)
result = encrypter_instance.decrypt(headers_data)
return result
return json.loads(self.encrypted_credentials)
except Exception:
return {}
@property
def masked_headers(self) -> dict[str, Any]:
"""Get masked headers for frontend display."""
from core.entities.provider_entities import BasicProviderConfig
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.tools.utils.encryption import create_provider_encrypter
def headers(self) -> dict[str, Any]:
if self.encrypted_headers is None:
return {}
try:
if not self.encrypted_headers:
return {}
headers_data = json.loads(self.encrypted_headers)
# Create dynamic config for all headers as SECRET_INPUT
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers_data]
encrypter_instance, _ = create_provider_encrypter(
tenant_id=self.tenant_id,
config=config,
cache=NoOpProviderCredentialCache(),
)
# First decrypt, then mask
decrypted_headers = encrypter_instance.decrypt(headers_data)
result = encrypter_instance.mask_tool_credentials(decrypted_headers)
return result
return json.loads(self.encrypted_headers)
except Exception:
return {}
@property
def masked_server_url(self) -> str:
def mask_url(url: str, mask_char: str = "*") -> str:
"""
mask the url to a simple string
"""
parsed = urlparse(url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
def tool_dict(self) -> list[dict[str, Any]]:
try:
return json.loads(self.tools) if self.tools else []
except (json.JSONDecodeError, TypeError):
return []
if parsed.path and parsed.path != "/":
return f"{base_url}/{mask_char * 6}"
else:
return base_url
def to_entity(self) -> "MCPProviderEntity":
"""Convert to domain entity"""
from core.entities.mcp_provider import MCPProviderEntity
return mask_url(self.decrypted_server_url)
@property
def decrypted_credentials(self) -> dict[str, Any]:
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.utils.encryption import create_provider_encrypter
provider_controller = MCPToolProviderController.from_db(self)
encrypter, _ = create_provider_encrypter(
tenant_id=self.tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
cache=NoOpProviderCredentialCache(),
)
return encrypter.decrypt(self.credentials)
return MCPProviderEntity.from_db_model(self)
class ToolModelInvoke(TypeBase):

View File

@ -13,7 +13,7 @@ dependencies = [
"celery~=5.5.2",
"chardet~=5.1.0",
"flask~=3.1.2",
"flask-compress~=1.17",
"flask-compress>=1.17,<1.18",
"flask-cors~=6.0.0",
"flask-login~=0.6.3",
"flask-migrate~=4.0.7",
@ -88,7 +88,6 @@ dependencies = [
"sendgrid~=6.12.3",
"flask-restx~=1.3.0",
"packaging~=23.2",
"gevent-websocket>=0.10.1",
]
# Before adding new dependency, consider place it in
# alphabet order (a-z) and suitable group.
@ -217,7 +216,7 @@ vdb = [
"tidb-vector==0.0.9",
"upstash-vector==0.6.0",
"volcengine-compat~=1.0.0",
"weaviate-client~=3.24.0",
"weaviate-client>=4.0.0,<5.0.0",
"xinference-client~=1.2.2",
"mo-vector~=0.1.13",
"mysql-connector-python>=9.3.0",

View File

@ -7,8 +7,10 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
from collections.abc import Sequence
from datetime import datetime
from typing import cast
from sqlalchemy import asc, delete, desc, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, sessionmaker
from models.workflow import WorkflowNodeExecutionModel
@ -181,7 +183,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
# Delete the batch
delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
result = session.execute(delete_stmt)
result = cast(CursorResult, session.execute(delete_stmt))
session.commit()
total_deleted += result.rowcount
@ -228,7 +230,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
# Delete the batch
delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
result = session.execute(delete_stmt)
result = cast(CursorResult, session.execute(delete_stmt))
session.commit()
total_deleted += result.rowcount
@ -285,6 +287,6 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
with self._session_maker() as session:
stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
result = session.execute(stmt)
result = cast(CursorResult, session.execute(stmt))
session.commit()
return result.rowcount

View File

@ -22,8 +22,10 @@ Implementation Notes:
import logging
from collections.abc import Sequence
from datetime import datetime
from typing import cast
from sqlalchemy import delete, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, sessionmaker
from libs.infinite_scroll_pagination import InfiniteScrollPagination
@ -150,7 +152,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
with self._session_maker() as session:
stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids))
result = session.execute(stmt)
result = cast(CursorResult, session.execute(stmt))
session.commit()
deleted_count = result.rowcount
@ -186,7 +188,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
# Delete the batch
delete_stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids))
result = session.execute(delete_stmt)
result = cast(CursorResult, session.execute(delete_stmt))
session.commit()
batch_deleted = result.rowcount

View File

@ -8,8 +8,7 @@ from werkzeug.exceptions import NotFound
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models.account import Account
from libs.login import current_account_with_tenant
from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation
from services.feature_service import FeatureService
from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task
@ -24,10 +23,10 @@ class AppAnnotationService:
@classmethod
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
# get app info
assert isinstance(current_user, Account)
current_user, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
@ -63,12 +62,12 @@ class AppAnnotationService:
db.session.commit()
# if annotation reply is enabled , add annotation to index
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
assert current_user.current_tenant_id is not None
assert current_tenant_id is not None
if annotation_setting:
add_annotation_to_index_task.delay(
annotation.id,
args["question"],
current_user.current_tenant_id,
current_tenant_id,
app_id,
annotation_setting.collection_binding_id,
)
@ -86,13 +85,12 @@ class AppAnnotationService:
enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
# send batch add segments task
redis_client.setnx(enable_app_annotation_job_key, "waiting")
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
current_user, current_tenant_id = current_account_with_tenant()
enable_annotation_reply_task.delay(
str(job_id),
app_id,
current_user.id,
current_user.current_tenant_id,
current_tenant_id,
args["score_threshold"],
args["embedding_provider_name"],
args["embedding_model_name"],
@ -101,8 +99,7 @@ class AppAnnotationService:
@classmethod
def disable_app_annotation(cls, app_id: str):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_, current_tenant_id = current_account_with_tenant()
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
cache_result = redis_client.get(disable_app_annotation_key)
if cache_result is not None:
@ -113,17 +110,16 @@ class AppAnnotationService:
disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}"
# send batch add segments task
redis_client.setnx(disable_app_annotation_job_key, "waiting")
disable_annotation_reply_task.delay(str(job_id), app_id, current_user.current_tenant_id)
disable_annotation_reply_task.delay(str(job_id), app_id, current_tenant_id)
return {"job_id": job_id, "job_status": "waiting"}
@classmethod
def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
@ -153,11 +149,10 @@ class AppAnnotationService:
@classmethod
def export_annotation_list_by_app_id(cls, app_id: str):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
@ -174,11 +169,10 @@ class AppAnnotationService:
@classmethod
def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
current_user, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
@ -196,7 +190,7 @@ class AppAnnotationService:
add_annotation_to_index_task.delay(
annotation.id,
args["question"],
current_user.current_tenant_id,
current_tenant_id,
app_id,
annotation_setting.collection_binding_id,
)
@ -205,11 +199,10 @@ class AppAnnotationService:
@classmethod
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
@ -234,7 +227,7 @@ class AppAnnotationService:
update_annotation_to_index_task.delay(
annotation.id,
annotation.question,
current_user.current_tenant_id,
current_tenant_id,
app_id,
app_annotation_setting.collection_binding_id,
)
@ -244,11 +237,10 @@ class AppAnnotationService:
@classmethod
def delete_app_annotation(cls, app_id: str, annotation_id: str):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
@ -277,17 +269,16 @@ class AppAnnotationService:
if app_annotation_setting:
delete_annotation_index_task.delay(
annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id
annotation.id, app_id, current_tenant_id, app_annotation_setting.collection_binding_id
)
@classmethod
def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
@ -317,7 +308,7 @@ class AppAnnotationService:
for annotation, annotation_setting in annotations_to_delete:
if annotation_setting:
delete_annotation_index_task.delay(
annotation.id, app_id, current_user.current_tenant_id, annotation_setting.collection_binding_id
annotation.id, app_id, current_tenant_id, annotation_setting.collection_binding_id
)
# Step 4: Bulk delete annotations in a single query
@ -333,11 +324,10 @@ class AppAnnotationService:
@classmethod
def batch_import_app_annotations(cls, app_id, file: FileStorage):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
current_user, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
@ -354,7 +344,7 @@ class AppAnnotationService:
if len(result) == 0:
raise ValueError("The CSV file is empty.")
# check annotation limit
features = FeatureService.get_features(current_user.current_tenant_id)
features = FeatureService.get_features(current_tenant_id)
if features.billing.enabled:
annotation_quota_limit = features.annotation_quota_limit
if annotation_quota_limit.limit < len(result) + annotation_quota_limit.size:
@ -364,21 +354,18 @@ class AppAnnotationService:
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
# send batch add segments task
redis_client.setnx(indexing_cache_key, "waiting")
batch_import_annotations_task.delay(
str(job_id), result, app_id, current_user.current_tenant_id, current_user.id
)
batch_import_annotations_task.delay(str(job_id), result, app_id, current_tenant_id, current_user.id)
except Exception as e:
return {"error_msg": str(e)}
return {"job_id": job_id, "job_status": "waiting"}
@classmethod
def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_, current_tenant_id = current_account_with_tenant()
# get app info
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
@ -445,12 +432,11 @@ class AppAnnotationService:
@classmethod
def get_app_annotation_setting_by_app_id(cls, app_id: str):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_, current_tenant_id = current_account_with_tenant()
# get app info
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
@ -481,12 +467,11 @@ class AppAnnotationService:
@classmethod
def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
current_user, current_tenant_id = current_account_with_tenant()
# get app info
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
@ -531,11 +516,10 @@ class AppAnnotationService:
@classmethod
def clear_all_annotations(cls, app_id: str):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
@ -558,7 +542,7 @@ class AppAnnotationService:
# if annotation reply is enabled, delete annotation index
if app_annotation_setting:
delete_annotation_index_task.delay(
annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id
annotation.id, app_id, current_tenant_id, app_annotation_setting.collection_binding_id
)
db.session.delete(annotation)

View File

@ -3,7 +3,6 @@ import time
from collections.abc import Mapping
from typing import Any
from flask_login import current_user
from sqlalchemy.orm import Session
from configs import dify_config
@ -18,6 +17,7 @@ from core.tools.entities.tool_entities import CredentialType
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.login import current_account_with_tenant
from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
from models.provider_ids import DatasourceProviderID
from services.plugin.plugin_service import PluginService
@ -93,6 +93,8 @@ class DatasourceProviderService:
"""
get credential by id
"""
current_user, _ = current_account_with_tenant()
with Session(db.engine) as session:
if credential_id:
datasource_provider = (
@ -157,6 +159,8 @@ class DatasourceProviderService:
"""
get all datasource credentials by provider
"""
current_user, _ = current_account_with_tenant()
with Session(db.engine) as session:
datasource_providers = (
session.query(DatasourceProvider)
@ -604,6 +608,8 @@ class DatasourceProviderService:
"""
provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id
current_user, _ = current_account_with_tenant()
with Session(db.engine) as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}"
with redis_client.lock(lock, timeout=20):
@ -901,6 +907,8 @@ class DatasourceProviderService:
"""
update datasource credentials.
"""
current_user, _ = current_account_with_tenant()
with Session(db.engine) as session:
datasource_provider = (
session.query(DatasourceProvider)

View File

@ -102,6 +102,15 @@ class OpsService:
except Exception:
new_decrypt_tracing_config.update({"project_url": "https://arms.console.aliyun.com/"})
if tracing_provider == "tencent" and (
"project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
):
try:
project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
new_decrypt_tracing_config.update({"project_url": project_url})
except Exception:
new_decrypt_tracing_config.update({"project_url": "https://console.cloud.tencent.com/apm"})
trace_config_data.tracing_config = new_decrypt_tracing_config
return trace_config_data.to_dict()
@ -144,7 +153,7 @@ class OpsService:
project_url = f"{tracing_config.get('host')}/project/{project_key}"
except Exception:
project_url = None
elif tracing_provider in ("langsmith", "opik"):
elif tracing_provider in ("langsmith", "opik", "tencent"):
try:
project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider)
except Exception:

View File

@ -1,86 +1,83 @@
import hashlib
import json
import logging
from collections.abc import Callable
from datetime import datetime
from typing import Any
from sqlalchemy import or_
from sqlalchemy import or_, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.utils.encryption import ProviderConfigEncrypter
from extensions.ext_database import db
from models.tools import MCPToolProvider
from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
# Constants
UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]"
CLIENT_NAME = "Dify"
EMPTY_TOOLS_JSON = "[]"
EMPTY_CREDENTIALS_JSON = "{}"
class MCPToolManageService:
"""
Service class for managing mcp tools.
"""
"""Service class for managing MCP tools and providers."""
@staticmethod
def _encrypt_headers(headers: dict[str, str], tenant_id: str) -> dict[str, str]:
def __init__(self, session: Session):
self._session = session
# ========== Provider CRUD Operations ==========
def get_provider(
self, *, provider_id: str | None = None, server_identifier: str | None = None, tenant_id: str
) -> MCPToolProvider:
"""
Encrypt headers using ProviderConfigEncrypter with all headers as SECRET_INPUT.
Get MCP provider by ID or server identifier.
Args:
headers: Dictionary of headers to encrypt
tenant_id: Tenant ID for encryption
provider_id: Provider ID (UUID)
server_identifier: Server identifier
tenant_id: Tenant ID
Returns:
Dictionary with all headers encrypted
MCPToolProvider instance
Raises:
ValueError: If provider not found
"""
if not headers:
return {}
if server_identifier:
stmt = select(MCPToolProvider).where(
MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier
)
else:
stmt = select(MCPToolProvider).where(
MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id
)
from core.entities.provider_entities import BasicProviderConfig
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.tools.utils.encryption import create_provider_encrypter
# Create dynamic config for all headers as SECRET_INPUT
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers]
encrypter_instance, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=config,
cache=NoOpProviderCredentialCache(),
)
return encrypter_instance.encrypt(headers)
@staticmethod
def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider:
res = (
db.session.query(MCPToolProvider)
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id)
.first()
)
if not res:
provider = self._session.scalar(stmt)
if not provider:
raise ValueError("MCP tool not found")
return res
return provider
@staticmethod
def get_mcp_provider_by_server_identifier(server_identifier: str, tenant_id: str) -> MCPToolProvider:
res = (
db.session.query(MCPToolProvider)
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier)
.first()
)
if not res:
raise ValueError("MCP tool not found")
return res
def get_provider_entity(self, provider_id: str, tenant_id: str, by_server_id: bool = False) -> MCPProviderEntity:
"""Get provider entity by ID or server identifier."""
if by_server_id:
db_provider = self.get_provider(server_identifier=provider_id, tenant_id=tenant_id)
else:
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
return db_provider.to_entity()
@staticmethod
def create_mcp_provider(
def create_provider(
self,
*,
tenant_id: str,
name: str,
server_url: str,
@ -89,37 +86,27 @@ class MCPToolManageService:
icon_type: str,
icon_background: str,
server_identifier: str,
timeout: float,
sse_read_timeout: float,
configuration: MCPConfiguration,
authentication: MCPAuthentication | None = None,
headers: dict[str, str] | None = None,
) -> ToolProviderApiEntity:
"""Create a new MCP provider."""
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
existing_provider = (
db.session.query(MCPToolProvider)
.where(
MCPToolProvider.tenant_id == tenant_id,
or_(
MCPToolProvider.name == name,
MCPToolProvider.server_url_hash == server_url_hash,
MCPToolProvider.server_identifier == server_identifier,
),
)
.first()
)
if existing_provider:
if existing_provider.name == name:
raise ValueError(f"MCP tool {name} already exists")
if existing_provider.server_url_hash == server_url_hash:
raise ValueError(f"MCP tool {server_url} already exists")
if existing_provider.server_identifier == server_identifier:
raise ValueError(f"MCP tool {server_identifier} already exists")
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
# Encrypt headers
encrypted_headers = None
if headers:
encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id)
encrypted_headers = json.dumps(encrypted_headers_dict)
# Check for existing provider
self._check_provider_exists(tenant_id, name, server_url_hash, server_identifier)
# Encrypt sensitive data
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
encrypted_headers = self._prepare_encrypted_dict(headers, tenant_id) if headers else None
if authentication is not None:
# Build the full credentials structure with encrypted client_id and client_secret
encrypted_credentials = self._build_and_encrypt_credentials(
authentication.client_id, authentication.client_secret, tenant_id
)
else:
encrypted_credentials = None
# Create provider
mcp_tool = MCPToolProvider(
tenant_id=tenant_id,
name=name,
@ -127,93 +114,23 @@ class MCPToolManageService:
server_url_hash=server_url_hash,
user_id=user_id,
authed=False,
tools="[]",
icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon,
tools=EMPTY_TOOLS_JSON,
icon=self._prepare_icon(icon, icon_type, icon_background),
server_identifier=server_identifier,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
timeout=configuration.timeout,
sse_read_timeout=configuration.sse_read_timeout,
encrypted_headers=encrypted_headers,
)
db.session.add(mcp_tool)
db.session.commit()
return ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
@staticmethod
def retrieve_mcp_tools(tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
mcp_providers = (
db.session.query(MCPToolProvider)
.where(MCPToolProvider.tenant_id == tenant_id)
.order_by(MCPToolProvider.name)
.all()
)
return [
ToolTransformService.mcp_provider_to_user_provider(mcp_provider, for_list=for_list)
for mcp_provider in mcp_providers
]
@classmethod
def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
server_url = mcp_provider.decrypted_server_url
authed = mcp_provider.authed
headers = mcp_provider.decrypted_headers
timeout = mcp_provider.timeout
sse_read_timeout = mcp_provider.sse_read_timeout
try:
with MCPClient(
server_url,
provider_id,
tenant_id,
authed=authed,
for_list=True,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
) as mcp_client:
tools = mcp_client.list_tools()
except MCPAuthError:
raise ValueError("Please auth the tool first")
except MCPError as e:
raise ValueError(f"Failed to connect to MCP server: {e}")
try:
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
mcp_provider.authed = True
mcp_provider.updated_at = datetime.now()
db.session.commit()
except Exception:
db.session.rollback()
raise
user = mcp_provider.load_user()
if not mcp_provider.icon:
raise ValueError("MCP provider icon is required")
return ToolProviderApiEntity(
id=mcp_provider.id,
name=mcp_provider.name,
tools=ToolTransformService.mcp_tool_to_user_tool(mcp_provider, tools),
type=ToolProviderType.MCP,
icon=mcp_provider.icon,
author=user.name if user else "Anonymous",
server_url=mcp_provider.masked_server_url,
updated_at=int(mcp_provider.updated_at.timestamp()),
description=I18nObject(en_US="", zh_Hans=""),
label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name),
plugin_unique_identifier=mcp_provider.server_identifier,
encrypted_credentials=encrypted_credentials,
)
@classmethod
def delete_mcp_tool(cls, tenant_id: str, provider_id: str):
mcp_tool = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
self._session.add(mcp_tool)
self._session.flush()
mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
return mcp_providers
db.session.delete(mcp_tool)
db.session.commit()
@classmethod
def update_mcp_provider(
cls,
def update_provider(
self,
*,
tenant_id: str,
provider_id: str,
name: str,
@ -222,31 +139,36 @@ class MCPToolManageService:
icon_type: str,
icon_background: str,
server_identifier: str,
timeout: float | None = None,
sse_read_timeout: float | None = None,
headers: dict[str, str] | None = None,
):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
configuration: MCPConfiguration,
authentication: MCPAuthentication | None = None,
) -> None:
"""Update an MCP provider."""
mcp_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
reconnect_result = None
encrypted_server_url = None
server_url_hash = None
# Handle server URL update
if UNCHANGED_SERVER_URL_PLACEHOLDER not in server_url:
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
if server_url_hash != mcp_provider.server_url_hash:
reconnect_result = cls._re_connect_mcp_provider(server_url, provider_id, tenant_id)
reconnect_result = self._reconnect_provider(
server_url=server_url,
provider=mcp_provider,
)
try:
# Update basic fields
mcp_provider.updated_at = datetime.now()
mcp_provider.name = name
mcp_provider.icon = (
json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon
)
mcp_provider.icon = self._prepare_icon(icon, icon_type, icon_background)
mcp_provider.server_identifier = server_identifier
# Update server URL if changed
if encrypted_server_url is not None and server_url_hash is not None:
mcp_provider.server_url = encrypted_server_url
mcp_provider.server_url_hash = server_url_hash
@ -256,95 +178,363 @@ class MCPToolManageService:
mcp_provider.tools = reconnect_result["tools"]
mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"]
if timeout is not None:
mcp_provider.timeout = timeout
if sse_read_timeout is not None:
mcp_provider.sse_read_timeout = sse_read_timeout
# Update optional fields
if configuration.timeout is not None:
mcp_provider.timeout = configuration.timeout
if configuration.sse_read_timeout is not None:
mcp_provider.sse_read_timeout = configuration.sse_read_timeout
if headers is not None:
# Merge masked headers from frontend with existing real values
if headers:
# existing decrypted and masked headers
existing_decrypted = mcp_provider.decrypted_headers
existing_masked = mcp_provider.masked_headers
# Build final headers: if value equals masked existing, keep original decrypted value
final_headers: dict[str, str] = {}
for key, incoming_value in headers.items():
if (
key in existing_masked
and key in existing_decrypted
and isinstance(incoming_value, str)
and incoming_value == existing_masked.get(key)
):
# unchanged, use original decrypted value
final_headers[key] = str(existing_decrypted[key])
else:
final_headers[key] = incoming_value
encrypted_headers_dict = MCPToolManageService._encrypt_headers(final_headers, tenant_id)
mcp_provider.encrypted_headers = json.dumps(encrypted_headers_dict)
# Build headers preserving unchanged masked values
final_headers = self._merge_headers_with_masked(incoming_headers=headers, mcp_provider=mcp_provider)
encrypted_headers_dict = self._prepare_encrypted_dict(final_headers, tenant_id)
mcp_provider.encrypted_headers = encrypted_headers_dict
else:
# Explicitly clear headers if empty dict passed
# Clear headers if empty dict passed
mcp_provider.encrypted_headers = None
db.session.commit()
# Update credentials if provided
if authentication is not None:
# Merge with existing credentials to handle masked values
(
final_client_id,
final_client_secret,
) = self._merge_credentials_with_masked(
authentication.client_id, authentication.client_secret, mcp_provider
)
# Build and encrypt new credentials
encrypted_credentials = self._build_and_encrypt_credentials(
final_client_id, final_client_secret, tenant_id
)
mcp_provider.encrypted_credentials = encrypted_credentials
self._session.commit()
except IntegrityError as e:
db.session.rollback()
error_msg = str(e.orig)
if "unique_mcp_provider_name" in error_msg:
raise ValueError(f"MCP tool {name} already exists")
if "unique_mcp_provider_server_url" in error_msg:
raise ValueError(f"MCP tool {server_url} already exists")
if "unique_mcp_provider_server_identifier" in error_msg:
raise ValueError(f"MCP tool {server_identifier} already exists")
raise
except Exception:
db.session.rollback()
self._session.rollback()
self._handle_integrity_error(e, name, server_url, server_identifier)
except (ValueError, AttributeError, TypeError) as e:
# Catch specific exceptions that might occur during update
# ValueError: invalid data provided
# AttributeError: missing required attributes
# TypeError: type conversion errors
self._session.rollback()
raise
@classmethod
def update_mcp_provider_credentials(
cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
):
provider_controller = MCPToolProviderController.from_db(mcp_provider)
def delete_provider(self, *, tenant_id: str, provider_id: str) -> None:
"""Delete an MCP provider."""
mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
self._session.delete(mcp_tool)
self._session.commit()
def list_providers(self, *, tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
"""List all MCP providers for a tenant."""
stmt = select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id).order_by(MCPToolProvider.name)
mcp_providers = self._session.scalars(stmt).all()
return [
ToolTransformService.mcp_provider_to_user_provider(provider, for_list=for_list)
for provider in mcp_providers
]
# ========== Tool Operations ==========
def list_provider_tools(self, *, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
"""List tools from remote MCP server."""
from core.mcp.auth.auth_flow import auth
# Load provider and convert to entity
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
provider_entity = db_provider.to_entity()
# Verify authentication
if not provider_entity.authed:
raise ValueError("Please auth the tool first")
# Prepare headers with auth token
headers = self._prepare_auth_headers(provider_entity)
# Retrieve tools from remote server
server_url = provider_entity.decrypt_server_url()
try:
tools = self._retrieve_remote_mcp_tools(
server_url, headers, provider_entity, lambda p, s, c: auth(p, self, c)
)
except MCPError as e:
raise ValueError(f"Failed to connect to MCP server: {e}")
# Update database with retrieved tools
db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
db_provider.authed = True
db_provider.updated_at = datetime.now()
self._session.flush()
# Build API response
return self._build_tool_provider_response(db_provider, provider_entity, tools)
# ========== OAuth and Credentials Operations ==========
def update_provider_credentials(
self, *, provider: MCPToolProvider, credentials: dict[str, Any], authed: bool | None = None
) -> None:
"""
Update provider credentials with encryption.
Args:
provider: Provider instance
credentials: Credentials to save
authed: Whether provider is authenticated (None means keep current state)
"""
from core.tools.mcp_tool.provider import MCPToolProviderController
# Encrypt new credentials
provider_controller = MCPToolProviderController.from_db(provider)
tool_configuration = ProviderConfigEncrypter(
tenant_id=mcp_provider.tenant_id,
config=list(provider_controller.get_credentials_schema()), # ty: ignore [invalid-argument-type]
tenant_id=provider.tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_config_cache=NoOpProviderCredentialCache(),
)
credentials = tool_configuration.encrypt(credentials)
mcp_provider.updated_at = datetime.now()
mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **credentials})
mcp_provider.authed = authed
if not authed:
mcp_provider.tools = "[]"
db.session.commit()
encrypted_credentials = tool_configuration.encrypt(credentials)
@classmethod
def _re_connect_mcp_provider(cls, server_url: str, provider_id: str, tenant_id: str):
# Get the existing provider to access headers and timeout settings
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
headers = mcp_provider.decrypted_headers
timeout = mcp_provider.timeout
sse_read_timeout = mcp_provider.sse_read_timeout
# Update provider
provider.updated_at = datetime.now()
provider.encrypted_credentials = json.dumps({**provider.credentials, **encrypted_credentials})
if authed is not None:
provider.authed = authed
if not authed:
provider.tools = EMPTY_TOOLS_JSON
self._session.flush()
def save_oauth_data(self, provider_id: str, tenant_id: str, data: dict[str, Any], data_type: str = "mixed") -> None:
"""
Save OAuth-related data (tokens, client info, code verifier).
Args:
provider_id: Provider ID
tenant_id: Tenant ID
data: Data to save (tokens, client info, or code verifier)
data_type: Type of data ('tokens', 'client_info', 'code_verifier', 'mixed')
"""
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
# Determine if this makes the provider authenticated
authed = data_type == "tokens" or (data_type == "mixed" and "access_token" in data) or None
self.update_provider_credentials(provider=db_provider, credentials=data, authed=authed)
def clear_provider_credentials(self, *, provider: MCPToolProvider) -> None:
"""Clear all credentials for a provider."""
provider.tools = EMPTY_TOOLS_JSON
provider.encrypted_credentials = EMPTY_CREDENTIALS_JSON
provider.updated_at = datetime.now()
provider.authed = False
self._session.commit()
# ========== Private Helper Methods ==========
def _check_provider_exists(self, tenant_id: str, name: str, server_url_hash: str, server_identifier: str) -> None:
"""Check if provider with same attributes already exists."""
stmt = select(MCPToolProvider).where(
MCPToolProvider.tenant_id == tenant_id,
or_(
MCPToolProvider.name == name,
MCPToolProvider.server_url_hash == server_url_hash,
MCPToolProvider.server_identifier == server_identifier,
),
)
existing_provider = self._session.scalar(stmt)
if existing_provider:
if existing_provider.name == name:
raise ValueError(f"MCP tool {name} already exists")
if existing_provider.server_url_hash == server_url_hash:
raise ValueError("MCP tool with this server URL already exists")
if existing_provider.server_identifier == server_identifier:
raise ValueError(f"MCP tool {server_identifier} already exists")
def _prepare_icon(self, icon: str, icon_type: str, icon_background: str) -> str:
"""Prepare icon data for storage."""
if icon_type == "emoji":
return json.dumps({"content": icon, "background": icon_background})
return icon
def _encrypt_dict_fields(self, data: dict[str, Any], secret_fields: list[str], tenant_id: str) -> str:
"""Encrypt specified fields in a dictionary.
Args:
data: Dictionary containing data to encrypt
secret_fields: List of field names to encrypt
tenant_id: Tenant ID for encryption
Returns:
JSON string of encrypted data
"""
from core.entities.provider_entities import BasicProviderConfig
from core.tools.utils.encryption import create_provider_encrypter
# Create config for secret fields
config = [
BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=field) for field in secret_fields
]
encrypter_instance, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=config,
cache=NoOpProviderCredentialCache(),
)
encrypted_data = encrypter_instance.encrypt(data)
return json.dumps(encrypted_data)
def _prepare_encrypted_dict(self, headers: dict[str, str], tenant_id: str) -> str:
"""Encrypt headers and prepare for storage."""
# All headers are treated as secret
return self._encrypt_dict_fields(headers, list(headers.keys()), tenant_id)
def _prepare_auth_headers(self, provider_entity: MCPProviderEntity) -> dict[str, str]:
"""Prepare headers with OAuth token if available."""
headers = provider_entity.decrypt_headers()
tokens = provider_entity.retrieve_tokens()
if tokens:
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
return headers
def _retrieve_remote_mcp_tools(
self,
server_url: str,
headers: dict[str, str],
provider_entity: MCPProviderEntity,
auth_callback: Callable[[MCPProviderEntity, "MCPToolManageService", str | None], dict[str, str]],
):
"""Retrieve tools from remote MCP server."""
with MCPClientWithAuthRetry(
server_url,
headers=headers,
timeout=provider_entity.timeout,
sse_read_timeout=provider_entity.sse_read_timeout,
provider_entity=provider_entity,
auth_callback=auth_callback,
mcp_service=self,
) as mcp_client:
return mcp_client.list_tools()
def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> dict[str, Any]:
"""Attempt to reconnect to MCP provider with new server URL."""
from core.mcp.auth.auth_flow import auth
provider_entity = provider.to_entity()
headers = provider_entity.headers
try:
with MCPClient(
server_url,
provider_id,
tenant_id,
authed=False,
for_list=True,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
) as mcp_client:
tools = mcp_client.list_tools()
return {
"authed": True,
"tools": json.dumps([tool.model_dump() for tool in tools]),
"encrypted_credentials": "{}",
}
tools = self._retrieve_remote_mcp_tools(
server_url, headers, provider_entity, lambda p, s, c: auth(p, self, c)
)
return {
"authed": True,
"tools": json.dumps([tool.model_dump() for tool in tools]),
"encrypted_credentials": EMPTY_CREDENTIALS_JSON,
}
except MCPAuthError:
return {"authed": False, "tools": "[]", "encrypted_credentials": "{}"}
return {"authed": False, "tools": EMPTY_TOOLS_JSON, "encrypted_credentials": EMPTY_CREDENTIALS_JSON}
except MCPError as e:
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
def _build_tool_provider_response(
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
) -> ToolProviderApiEntity:
"""Build API response for tool provider."""
user = db_provider.load_user()
response = provider_entity.to_api_response(
user_name=user.name if user else None,
)
response["tools"] = ToolTransformService.mcp_tool_to_user_tool(db_provider, tools)
response["plugin_unique_identifier"] = provider_entity.provider_id
return ToolProviderApiEntity(**response)
def _handle_integrity_error(
self, error: IntegrityError, name: str, server_url: str, server_identifier: str
) -> None:
"""Handle database integrity errors with user-friendly messages."""
error_msg = str(error.orig)
if "unique_mcp_provider_name" in error_msg:
raise ValueError(f"MCP tool {name} already exists")
if "unique_mcp_provider_server_url" in error_msg:
raise ValueError(f"MCP tool {server_url} already exists")
if "unique_mcp_provider_server_identifier" in error_msg:
raise ValueError(f"MCP tool {server_identifier} already exists")
raise
def _merge_headers_with_masked(
self, incoming_headers: dict[str, str], mcp_provider: MCPToolProvider
) -> dict[str, str]:
"""Merge incoming headers with existing ones, preserving unchanged masked values.
Args:
incoming_headers: Headers from frontend (may contain masked values)
mcp_provider: The MCP provider instance
Returns:
Final headers dict with proper values (original for unchanged masked, new for changed)
"""
mcp_provider_entity = mcp_provider.to_entity()
existing_decrypted = mcp_provider_entity.decrypt_headers()
existing_masked = mcp_provider_entity.masked_headers()
return {
key: (str(existing_decrypted[key]) if key in existing_masked and value == existing_masked[key] else value)
for key, value in incoming_headers.items()
if key in existing_decrypted or value != existing_masked.get(key)
}
def _merge_credentials_with_masked(
self,
client_id: str,
client_secret: str | None,
mcp_provider: MCPToolProvider,
) -> tuple[
str,
str | None,
]:
"""Merge incoming credentials with existing ones, preserving unchanged masked values.
Args:
client_id: Client ID from frontend (may be masked)
client_secret: Client secret from frontend (may be masked)
mcp_provider: The MCP provider instance
Returns:
Tuple of (final_client_id, final_client_secret)
"""
mcp_provider_entity = mcp_provider.to_entity()
existing_decrypted = mcp_provider_entity.decrypt_credentials()
existing_masked = mcp_provider_entity.masked_credentials()
# Check if client_id is masked and unchanged
final_client_id = client_id
if existing_masked.get("client_id") and client_id == existing_masked["client_id"]:
# Use existing decrypted value
final_client_id = existing_decrypted.get("client_id", client_id)
# Check if client_secret is masked and unchanged
final_client_secret = client_secret
if existing_masked.get("client_secret") and client_secret == existing_masked["client_secret"]:
# Use existing decrypted value
final_client_secret = existing_decrypted.get("client_secret", client_secret)
return final_client_id, final_client_secret
def _build_and_encrypt_credentials(self, client_id: str, client_secret: str | None, tenant_id: str) -> str:
"""Build credentials and encrypt sensitive fields."""
# Create a flat structure with all credential data
credentials_data = {
"client_id": client_id,
"client_secret": client_secret,
"client_name": CLIENT_NAME,
"is_dynamic_registration": False,
}
# Only client_id and client_secret need encryption
secret_fields = ["client_id", "client_secret"] if client_secret else ["client_id"]
return self._encrypt_dict_fields(credentials_data, secret_fields, tenant_id)

View File

@ -6,6 +6,7 @@ from typing import Any, Union
from yarl import URL
from configs import dify_config
from core.entities.mcp_provider import MCPConfiguration
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.mcp.types import Tool as MCPTool
from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity
@ -233,27 +234,27 @@ class ToolTransformService:
@staticmethod
def mcp_provider_to_user_provider(db_provider: MCPToolProvider, for_list: bool = False) -> ToolProviderApiEntity:
# Convert to entity and use its API response method
provider_entity = db_provider.to_entity()
user = db_provider.load_user()
return ToolProviderApiEntity(
id=db_provider.server_identifier if not for_list else db_provider.id,
author=user.name if user else "Anonymous",
name=db_provider.name,
icon=db_provider.provider_icon,
type=ToolProviderType.MCP,
is_team_authorization=db_provider.authed,
server_url=db_provider.masked_server_url,
tools=ToolTransformService.mcp_tool_to_user_tool(
db_provider, [MCPTool.model_validate(tool) for tool in json.loads(db_provider.tools)]
),
updated_at=int(db_provider.updated_at.timestamp()),
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
description=I18nObject(en_US="", zh_Hans=""),
server_identifier=db_provider.server_identifier,
timeout=db_provider.timeout,
sse_read_timeout=db_provider.sse_read_timeout,
masked_headers=db_provider.masked_headers,
original_headers=db_provider.decrypted_headers,
response = provider_entity.to_api_response(user_name=user.name if user else None)
# Add additional fields specific to the transform
response["id"] = db_provider.server_identifier if not for_list else db_provider.id
response["tools"] = ToolTransformService.mcp_tool_to_user_tool(
db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
)
response["server_identifier"] = db_provider.server_identifier
# Convert configuration dict to MCPConfiguration object
if "configuration" in response and isinstance(response["configuration"], dict):
response["configuration"] = MCPConfiguration(
timeout=float(response["configuration"]["timeout"]),
sse_read_timeout=float(response["configuration"]["sse_read_timeout"]),
)
return ToolProviderApiEntity(**response)
@staticmethod
def mcp_tool_to_user_tool(mcp_provider: MCPToolProvider, tools: list[MCPTool]) -> list[ToolApiEntity]:
@ -266,6 +267,7 @@ class ToolTransformService:
description=I18nObject(en_US=tool.description or "", zh_Hans=tool.description or ""),
parameters=ToolTransformService.convert_mcp_schema_to_parameter(tool.inputSchema),
labels=[],
output_schema=tool.outputSchema or {},
)
for tool in tools
]
@ -412,7 +414,7 @@ class ToolTransformService:
)
@staticmethod
def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]:
def convert_mcp_schema_to_parameter(schema: dict[str, Any]) -> list["ToolParameter"]:
"""
Convert MCP JSON schema to tool parameters
@ -421,7 +423,7 @@ class ToolTransformService:
"""
def create_parameter(
name: str, description: str, param_type: str, required: bool, input_schema: dict | None = None
name: str, description: str, param_type: str, required: bool, input_schema: dict[str, Any] | None = None
) -> ToolParameter:
"""Create a ToolParameter instance with given attributes"""
input_schema_dict: dict[str, Any] = {"input_schema": input_schema} if input_schema else {}
@ -436,7 +438,9 @@ class ToolTransformService:
**input_schema_dict,
)
def process_properties(props: dict, required: list, prefix: str = "") -> list[ToolParameter]:
def process_properties(
props: dict[str, dict[str, Any]], required: list[str], prefix: str = ""
) -> list[ToolParameter]:
"""Process properties recursively"""
TYPE_MAPPING = {"integer": "number", "float": "number"}
COMPLEX_TYPES = ["array", "object"]

View File

@ -8,7 +8,6 @@ import click
import pandas as pd
from celery import shared_task
from sqlalchemy import func
from sqlalchemy.orm import Session
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
@ -50,54 +49,48 @@ def batch_create_segment_to_index_task(
indexing_cache_key = f"segment_batch_import_{job_id}"
try:
with Session(db.engine) as session:
dataset = session.get(Dataset, dataset_id)
if not dataset:
raise ValueError("Dataset not exist.")
dataset = db.session.get(Dataset, dataset_id)
if not dataset:
raise ValueError("Dataset not exist.")
dataset_document = session.get(Document, document_id)
if not dataset_document:
raise ValueError("Document not exist.")
dataset_document = db.session.get(Document, document_id)
if not dataset_document:
raise ValueError("Document not exist.")
if (
not dataset_document.enabled
or dataset_document.archived
or dataset_document.indexing_status != "completed"
):
raise ValueError("Document is not available.")
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
raise ValueError("Document is not available.")
upload_file = session.get(UploadFile, upload_file_id)
if not upload_file:
raise ValueError("UploadFile not found.")
upload_file = db.session.get(UploadFile, upload_file_id)
if not upload_file:
raise ValueError("UploadFile not found.")
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
# FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
storage.download(upload_file.key, file_path)
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
storage.download(upload_file.key, file_path)
# Skip the first row
df = pd.read_csv(file_path)
content = []
for _, row in df.iterrows():
if dataset_document.doc_form == "qa_model":
data = {"content": row.iloc[0], "answer": row.iloc[1]}
else:
data = {"content": row.iloc[0]}
content.append(data)
if len(content) == 0:
raise ValueError("The CSV file is empty.")
df = pd.read_csv(file_path)
content = []
for _, row in df.iterrows():
if dataset_document.doc_form == "qa_model":
data = {"content": row.iloc[0], "answer": row.iloc[1]}
else:
data = {"content": row.iloc[0]}
content.append(data)
if len(content) == 0:
raise ValueError("The CSV file is empty.")
document_segments = []
embedding_model = None
if dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
document_segments = []
embedding_model = None
if dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
word_count_change = 0
if embedding_model:
tokens_list = embedding_model.get_text_embedding_num_tokens(
@ -105,6 +98,7 @@ def batch_create_segment_to_index_task(
)
else:
tokens_list = [0] * len(content)
for segment, tokens in zip(content, tokens_list):
content = segment["content"]
doc_id = str(uuid.uuid4())
@ -135,11 +129,11 @@ def batch_create_segment_to_index_task(
word_count_change += segment_document.word_count
db.session.add(segment_document)
document_segments.append(segment_document)
# update document word count
assert dataset_document.word_count is not None
dataset_document.word_count += word_count_change
db.session.add(dataset_document)
# add index to db
VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
db.session.commit()
redis_client.setex(indexing_cache_key, 600, "completed")

View File

@ -25,9 +25,7 @@ class TestAnnotationService:
patch("services.annotation_service.enable_annotation_reply_task") as mock_enable_task,
patch("services.annotation_service.disable_annotation_reply_task") as mock_disable_task,
patch("services.annotation_service.batch_import_annotations_task") as mock_batch_import_task,
patch(
"services.annotation_service.current_user", create_autospec(Account, instance=True)
) as mock_current_user,
patch("services.annotation_service.current_account_with_tenant") as mock_current_account_with_tenant,
):
# Setup default mock returns
mock_account_feature_service.get_features.return_value.billing.enabled = False
@ -38,6 +36,9 @@ class TestAnnotationService:
mock_disable_task.delay.return_value = None
mock_batch_import_task.delay.return_value = None
# Create mock user that will be returned by current_account_with_tenant
mock_user = create_autospec(Account, instance=True)
yield {
"account_feature_service": mock_account_feature_service,
"feature_service": mock_feature_service,
@ -47,7 +48,8 @@ class TestAnnotationService:
"enable_task": mock_enable_task,
"disable_task": mock_disable_task,
"batch_import_task": mock_batch_import_task,
"current_user": mock_current_user,
"current_account_with_tenant": mock_current_account_with_tenant,
"current_user": mock_user,
}
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
@ -107,6 +109,11 @@ class TestAnnotationService:
"""
mock_external_service_dependencies["current_user"].id = account_id
mock_external_service_dependencies["current_user"].current_tenant_id = tenant_id
# Configure current_account_with_tenant to return (user, tenant_id)
mock_external_service_dependencies["current_account_with_tenant"].return_value = (
mock_external_service_dependencies["current_user"],
tenant_id,
)
def _create_test_conversation(self, app, account, fake):
"""

View File

@ -794,16 +794,12 @@ class TestWorkflowAppService:
new_email = "changed@example.com"
account.email = new_email
db_session_with_containers.commit()
assert account.email == new_email
# Results for new email, is expected to be the same as the original email
result_with_new_email = service.get_paginate_workflow_app_logs(
session=db_session_with_containers,
app_model=app,
created_by_account=new_email,
page=1,
limit=20
session=db_session_with_containers, app_model=app, created_by_account=new_email, page=1, limit=20
)
assert result_with_new_email["total"] == 3
assert all(log.created_by_role == CreatorUserRole.ACCOUNT for log in result_with_new_email["data"])
@ -1087,15 +1083,15 @@ class TestWorkflowAppService:
assert len(result_no_session["data"]) == 0
# Test with account email that doesn't exist
result_no_account = service.get_paginate_workflow_app_logs(
session=db_session_with_containers,
app_model=app,
created_by_account="nonexistent@example.com",
page=1,
limit=20,
)
assert result_no_account["total"] == 0
assert len(result_no_account["data"]) == 0
with pytest.raises(ValueError) as exc_info:
service.get_paginate_workflow_app_logs(
session=db_session_with_containers,
app_model=app,
created_by_account="nonexistent@example.com",
page=1,
limit=20,
)
assert "Account not found" in str(exc_info.value)
def test_get_paginate_workflow_app_logs_with_complex_query_combinations(
self, db_session_with_containers, mock_external_service_dependencies

View File

@ -20,12 +20,21 @@ class TestMCPToolManageService:
patch("services.tools.mcp_tools_manage_service.ToolTransformService") as mock_tool_transform_service,
):
# Setup default mock returns
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
mock_encrypter.encrypt_token.return_value = "encrypted_server_url"
mock_tool_transform_service.mcp_provider_to_user_provider.return_value = {
"id": "test_id",
"name": "test_name",
"type": ToolProviderType.MCP,
}
mock_tool_transform_service.mcp_provider_to_user_provider.return_value = ToolProviderApiEntity(
id="test_id",
author="test_author",
name="test_name",
type=ToolProviderType.MCP,
description=I18nObject(en_US="Test Description", zh_Hans="测试描述"),
icon={"type": "emoji", "content": "🤖"},
label=I18nObject(en_US="Test Label", zh_Hans="测试标签"),
labels=[],
tools=[],
)
yield {
"encrypter": mock_encrypter,
@ -104,9 +113,9 @@ class TestMCPToolManageService:
mcp_provider = MCPToolProvider(
tenant_id=tenant_id,
name=fake.company(),
server_identifier=fake.uuid4(),
server_identifier=str(fake.uuid4()),
server_url="encrypted_server_url",
server_url_hash=fake.sha256(),
server_url_hash=str(fake.sha256()),
user_id=user_id,
authed=False,
tools="[]",
@ -144,7 +153,10 @@ class TestMCPToolManageService:
)
# Act: Execute the method under test
result = MCPToolManageService.get_mcp_provider_by_provider_id(mcp_provider.id, tenant.id)
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service.get_provider(provider_id=mcp_provider.id, tenant_id=tenant.id)
# Assert: Verify the expected outcomes
assert result is not None
@ -154,8 +166,6 @@ class TestMCPToolManageService:
assert result.user_id == account.id
# Verify database state
from extensions.ext_database import db
db.session.refresh(result)
assert result.id is not None
assert result.server_identifier == mcp_provider.server_identifier
@ -177,11 +187,14 @@ class TestMCPToolManageService:
db_session_with_containers, mock_external_service_dependencies
)
non_existent_id = fake.uuid4()
non_existent_id = str(fake.uuid4())
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="MCP tool not found"):
MCPToolManageService.get_mcp_provider_by_provider_id(non_existent_id, tenant.id)
service.get_provider(provider_id=non_existent_id, tenant_id=tenant.id)
def test_get_mcp_provider_by_provider_id_tenant_isolation(
self, db_session_with_containers, mock_external_service_dependencies
@ -210,8 +223,11 @@ class TestMCPToolManageService:
)
# Act & Assert: Verify tenant isolation
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="MCP tool not found"):
MCPToolManageService.get_mcp_provider_by_provider_id(mcp_provider1.id, tenant2.id)
service.get_provider(provider_id=mcp_provider1.id, tenant_id=tenant2.id)
def test_get_mcp_provider_by_server_identifier_success(
self, db_session_with_containers, mock_external_service_dependencies
@ -235,7 +251,10 @@ class TestMCPToolManageService:
)
# Act: Execute the method under test
result = MCPToolManageService.get_mcp_provider_by_server_identifier(mcp_provider.server_identifier, tenant.id)
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service.get_provider(server_identifier=mcp_provider.server_identifier, tenant_id=tenant.id)
# Assert: Verify the expected outcomes
assert result is not None
@ -245,8 +264,6 @@ class TestMCPToolManageService:
assert result.user_id == account.id
# Verify database state
from extensions.ext_database import db
db.session.refresh(result)
assert result.id is not None
assert result.name == mcp_provider.name
@ -268,11 +285,14 @@ class TestMCPToolManageService:
db_session_with_containers, mock_external_service_dependencies
)
non_existent_identifier = fake.uuid4()
non_existent_identifier = str(fake.uuid4())
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="MCP tool not found"):
MCPToolManageService.get_mcp_provider_by_server_identifier(non_existent_identifier, tenant.id)
service.get_provider(server_identifier=non_existent_identifier, tenant_id=tenant.id)
def test_get_mcp_provider_by_server_identifier_tenant_isolation(
self, db_session_with_containers, mock_external_service_dependencies
@ -301,8 +321,11 @@ class TestMCPToolManageService:
)
# Act & Assert: Verify tenant isolation
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="MCP tool not found"):
MCPToolManageService.get_mcp_provider_by_server_identifier(mcp_provider1.server_identifier, tenant2.id)
service.get_provider(server_identifier=mcp_provider1.server_identifier, tenant_id=tenant2.id)
def test_create_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
@ -322,15 +345,30 @@ class TestMCPToolManageService:
)
# Setup mocks for provider creation
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
mock_external_service_dependencies["encrypter"].encrypt_token.return_value = "encrypted_server_url"
mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.return_value = {
"id": "new_provider_id",
"name": "Test MCP Provider",
"type": ToolProviderType.MCP,
}
mock_external_service_dependencies[
"tool_transform_service"
].mcp_provider_to_user_provider.return_value = ToolProviderApiEntity(
id="new_provider_id",
author=account.name,
name="Test MCP Provider",
type=ToolProviderType.MCP,
description=I18nObject(en_US="Test MCP Provider Description", zh_Hans="测试MCP提供者描述"),
icon={"type": "emoji", "content": "🤖"},
label=I18nObject(en_US="Test MCP Provider", zh_Hans="测试MCP提供者"),
labels=[],
tools=[],
)
# Act: Execute the method under test
result = MCPToolManageService.create_mcp_provider(
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider",
server_url="https://example.com/mcp",
@ -339,14 +377,16 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#FF6B6B",
server_identifier="test_identifier_123",
timeout=30.0,
sse_read_timeout=300.0,
configuration=MCPConfiguration(
timeout=30.0,
sse_read_timeout=300.0,
),
)
# Assert: Verify the expected outcomes
assert result is not None
assert result["name"] == "Test MCP Provider"
assert result["type"] == ToolProviderType.MCP
assert result.name == "Test MCP Provider"
assert result.type == ToolProviderType.MCP
# Verify database state
from extensions.ext_database import db
@ -386,7 +426,11 @@ class TestMCPToolManageService:
)
# Create first provider
MCPToolManageService.create_mcp_provider(
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider",
server_url="https://example1.com/mcp",
@ -395,13 +439,15 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#FF6B6B",
server_identifier="test_identifier_1",
timeout=30.0,
sse_read_timeout=300.0,
configuration=MCPConfiguration(
timeout=30.0,
sse_read_timeout=300.0,
),
)
# Act & Assert: Verify proper error handling for duplicate name
with pytest.raises(ValueError, match="MCP tool Test MCP Provider already exists"):
MCPToolManageService.create_mcp_provider(
service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider", # Duplicate name
server_url="https://example2.com/mcp",
@ -410,8 +456,10 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#4ECDC4",
server_identifier="test_identifier_2",
timeout=45.0,
sse_read_timeout=400.0,
configuration=MCPConfiguration(
timeout=45.0,
sse_read_timeout=400.0,
),
)
def test_create_mcp_provider_duplicate_server_url(
@ -432,7 +480,11 @@ class TestMCPToolManageService:
)
# Create first provider
MCPToolManageService.create_mcp_provider(
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider 1",
server_url="https://example.com/mcp",
@ -441,13 +493,15 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#FF6B6B",
server_identifier="test_identifier_1",
timeout=30.0,
sse_read_timeout=300.0,
configuration=MCPConfiguration(
timeout=30.0,
sse_read_timeout=300.0,
),
)
# Act & Assert: Verify proper error handling for duplicate server URL
with pytest.raises(ValueError, match="MCP tool https://example.com/mcp already exists"):
MCPToolManageService.create_mcp_provider(
with pytest.raises(ValueError, match="MCP tool with this server URL already exists"):
service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider 2",
server_url="https://example.com/mcp", # Duplicate URL
@ -456,8 +510,10 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#4ECDC4",
server_identifier="test_identifier_2",
timeout=45.0,
sse_read_timeout=400.0,
configuration=MCPConfiguration(
timeout=45.0,
sse_read_timeout=400.0,
),
)
def test_create_mcp_provider_duplicate_server_identifier(
@ -478,7 +534,11 @@ class TestMCPToolManageService:
)
# Create first provider
MCPToolManageService.create_mcp_provider(
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider 1",
server_url="https://example1.com/mcp",
@ -487,13 +547,15 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#FF6B6B",
server_identifier="test_identifier_123",
timeout=30.0,
sse_read_timeout=300.0,
configuration=MCPConfiguration(
timeout=30.0,
sse_read_timeout=300.0,
),
)
# Act & Assert: Verify proper error handling for duplicate server identifier
with pytest.raises(ValueError, match="MCP tool test_identifier_123 already exists"):
MCPToolManageService.create_mcp_provider(
service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider 2",
server_url="https://example2.com/mcp",
@ -502,8 +564,10 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#4ECDC4",
server_identifier="test_identifier_123", # Duplicate identifier
timeout=45.0,
sse_read_timeout=400.0,
configuration=MCPConfiguration(
timeout=45.0,
sse_read_timeout=400.0,
),
)
def test_retrieve_mcp_tools_success(self, db_session_with_containers, mock_external_service_dependencies):
@ -543,23 +607,59 @@ class TestMCPToolManageService:
db.session.commit()
# Setup mock for transformation service
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.side_effect = [
{"id": provider1.id, "name": provider1.name, "type": ToolProviderType.MCP},
{"id": provider2.id, "name": provider2.name, "type": ToolProviderType.MCP},
{"id": provider3.id, "name": provider3.name, "type": ToolProviderType.MCP},
ToolProviderApiEntity(
id=provider1.id,
author=account.name,
name=provider1.name,
type=ToolProviderType.MCP,
description=I18nObject(en_US="Alpha Provider Description", zh_Hans="Alpha提供者描述"),
icon={"type": "emoji", "content": "🅰️"},
label=I18nObject(en_US=provider1.name, zh_Hans=provider1.name),
labels=[],
tools=[],
),
ToolProviderApiEntity(
id=provider2.id,
author=account.name,
name=provider2.name,
type=ToolProviderType.MCP,
description=I18nObject(en_US="Beta Provider Description", zh_Hans="Beta提供者描述"),
icon={"type": "emoji", "content": "🅱️"},
label=I18nObject(en_US=provider2.name, zh_Hans=provider2.name),
labels=[],
tools=[],
),
ToolProviderApiEntity(
id=provider3.id,
author=account.name,
name=provider3.name,
type=ToolProviderType.MCP,
description=I18nObject(en_US="Gamma Provider Description", zh_Hans="Gamma提供者描述"),
icon={"type": "emoji", "content": "Γ"},
label=I18nObject(en_US=provider3.name, zh_Hans=provider3.name),
labels=[],
tools=[],
),
]
# Act: Execute the method under test
result = MCPToolManageService.retrieve_mcp_tools(tenant.id, for_list=True)
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service.list_providers(tenant_id=tenant.id, for_list=True)
# Assert: Verify the expected outcomes
assert result is not None
assert len(result) == 3
# Verify correct ordering by name
assert result[0]["name"] == "Alpha Provider"
assert result[1]["name"] == "Beta Provider"
assert result[2]["name"] == "Gamma Provider"
assert result[0].name == "Alpha Provider"
assert result[1].name == "Beta Provider"
assert result[2].name == "Gamma Provider"
# Verify mock interactions
assert (
@ -584,7 +684,10 @@ class TestMCPToolManageService:
# No MCP providers created for this tenant
# Act: Execute the method under test
result = MCPToolManageService.retrieve_mcp_tools(tenant.id, for_list=False)
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service.list_providers(tenant_id=tenant.id, for_list=False)
# Assert: Verify the expected outcomes
assert result is not None
@ -624,20 +727,46 @@ class TestMCPToolManageService:
)
# Setup mock for transformation service
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.side_effect = [
{"id": provider1.id, "name": provider1.name, "type": ToolProviderType.MCP},
{"id": provider2.id, "name": provider2.name, "type": ToolProviderType.MCP},
ToolProviderApiEntity(
id=provider1.id,
author=account1.name,
name=provider1.name,
type=ToolProviderType.MCP,
description=I18nObject(en_US="Provider 1 Description", zh_Hans="提供者1描述"),
icon={"type": "emoji", "content": "1"},
label=I18nObject(en_US=provider1.name, zh_Hans=provider1.name),
labels=[],
tools=[],
),
ToolProviderApiEntity(
id=provider2.id,
author=account2.name,
name=provider2.name,
type=ToolProviderType.MCP,
description=I18nObject(en_US="Provider 2 Description", zh_Hans="提供者2描述"),
icon={"type": "emoji", "content": "2"},
label=I18nObject(en_US=provider2.name, zh_Hans=provider2.name),
labels=[],
tools=[],
),
]
# Act: Execute the method under test for both tenants
result1 = MCPToolManageService.retrieve_mcp_tools(tenant1.id, for_list=True)
result2 = MCPToolManageService.retrieve_mcp_tools(tenant2.id, for_list=True)
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result1 = service.list_providers(tenant_id=tenant1.id, for_list=True)
result2 = service.list_providers(tenant_id=tenant2.id, for_list=True)
# Assert: Verify tenant isolation
assert len(result1) == 1
assert len(result2) == 1
assert result1[0]["id"] == provider1.id
assert result2[0]["id"] == provider2.id
assert result1[0].id == provider1.id
assert result2[0].id == provider2.id
def test_list_mcp_tool_from_remote_server_success(
self, db_session_with_containers, mock_external_service_dependencies
@ -661,17 +790,20 @@ class TestMCPToolManageService:
mcp_provider = self._create_test_mcp_provider(
db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id
)
mcp_provider.server_url = "encrypted_server_url"
mcp_provider.authed = False
# Use a valid base64 encoded string to avoid decryption errors
import base64
mcp_provider.server_url = base64.b64encode(b"encrypted_server_url").decode()
mcp_provider.authed = True # Provider must be authenticated to list tools
mcp_provider.tools = "[]"
from extensions.ext_database import db
db.session.commit()
# Mock the decrypted_server_url property to avoid encryption issues
with patch("models.tools.encrypter") as mock_encrypter:
mock_encrypter.decrypt_token.return_value = "https://example.com/mcp"
# Mock the decryption process at the rsa level to avoid key file issues
with patch("libs.rsa.decrypt") as mock_decrypt:
mock_decrypt.return_value = "https://example.com/mcp"
# Mock MCPClient and its context manager
mock_tools = [
@ -683,13 +815,16 @@ class TestMCPToolManageService:
)(),
]
with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client:
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
# Setup mock client
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.return_value = mock_tools
# Act: Execute the method under test
result = MCPToolManageService.list_mcp_tool_from_remote_server(tenant.id, mcp_provider.id)
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id)
# Assert: Verify the expected outcomes
assert result is not None
@ -705,16 +840,8 @@ class TestMCPToolManageService:
assert mcp_provider.updated_at is not None
# Verify mock interactions
mock_mcp_client.assert_called_once_with(
"https://example.com/mcp",
mcp_provider.id,
tenant.id,
authed=False,
for_list=True,
headers={},
timeout=30.0,
sse_read_timeout=300.0,
)
# MCPClientWithAuthRetry is called with different parameters
mock_mcp_client.assert_called_once()
def test_list_mcp_tool_from_remote_server_auth_error(
self, db_session_with_containers, mock_external_service_dependencies
@ -737,7 +864,10 @@ class TestMCPToolManageService:
mcp_provider = self._create_test_mcp_provider(
db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id
)
mcp_provider.server_url = "encrypted_server_url"
# Use a valid base64 encoded string to avoid decryption errors
import base64
mcp_provider.server_url = base64.b64encode(b"encrypted_server_url").decode()
mcp_provider.authed = False
mcp_provider.tools = "[]"
@ -745,20 +875,23 @@ class TestMCPToolManageService:
db.session.commit()
# Mock the decrypted_server_url property to avoid encryption issues
with patch("models.tools.encrypter") as mock_encrypter:
mock_encrypter.decrypt_token.return_value = "https://example.com/mcp"
# Mock the decryption process at the rsa level to avoid key file issues
with patch("libs.rsa.decrypt") as mock_decrypt:
mock_decrypt.return_value = "https://example.com/mcp"
# Mock MCPClient to raise authentication error
with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client:
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
from core.mcp.error import MCPAuthError
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required")
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="Please auth the tool first"):
MCPToolManageService.list_mcp_tool_from_remote_server(tenant.id, mcp_provider.id)
service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id)
# Verify database state was not changed
db.session.refresh(mcp_provider)
@ -786,32 +919,38 @@ class TestMCPToolManageService:
mcp_provider = self._create_test_mcp_provider(
db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id
)
mcp_provider.server_url = "encrypted_server_url"
mcp_provider.authed = False
# Use a valid base64 encoded string to avoid decryption errors
import base64
mcp_provider.server_url = base64.b64encode(b"encrypted_server_url").decode()
mcp_provider.authed = True # Provider must be authenticated to test connection errors
mcp_provider.tools = "[]"
from extensions.ext_database import db
db.session.commit()
# Mock the decrypted_server_url property to avoid encryption issues
with patch("models.tools.encrypter") as mock_encrypter:
mock_encrypter.decrypt_token.return_value = "https://example.com/mcp"
# Mock the decryption process at the rsa level to avoid key file issues
with patch("libs.rsa.decrypt") as mock_decrypt:
mock_decrypt.return_value = "https://example.com/mcp"
# Mock MCPClient to raise connection error
with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client:
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
from core.mcp.error import MCPError
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.side_effect = MCPError("Connection failed")
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="Failed to connect to MCP server: Connection failed"):
MCPToolManageService.list_mcp_tool_from_remote_server(tenant.id, mcp_provider.id)
service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id)
# Verify database state was not changed
db.session.refresh(mcp_provider)
assert mcp_provider.authed is False
assert mcp_provider.authed is True # Provider remains authenticated
assert mcp_provider.tools == "[]"
def test_delete_mcp_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
@ -840,7 +979,8 @@ class TestMCPToolManageService:
assert db.session.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() is not None
# Act: Execute the method under test
MCPToolManageService.delete_mcp_tool(tenant.id, mcp_provider.id)
service = MCPToolManageService(db.session())
service.delete_provider(tenant_id=tenant.id, provider_id=mcp_provider.id)
# Assert: Verify the expected outcomes
# Provider should be deleted from database
@ -862,11 +1002,14 @@ class TestMCPToolManageService:
db_session_with_containers, mock_external_service_dependencies
)
non_existent_id = fake.uuid4()
non_existent_id = str(fake.uuid4())
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="MCP tool not found"):
MCPToolManageService.delete_mcp_tool(tenant.id, non_existent_id)
service.delete_provider(tenant_id=tenant.id, provider_id=non_existent_id)
def test_delete_mcp_tool_tenant_isolation(self, db_session_with_containers, mock_external_service_dependencies):
"""
@ -893,8 +1036,11 @@ class TestMCPToolManageService:
)
# Act & Assert: Verify tenant isolation
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="MCP tool not found"):
MCPToolManageService.delete_mcp_tool(tenant2.id, mcp_provider1.id)
service.delete_provider(tenant_id=tenant2.id, provider_id=mcp_provider1.id)
# Verify provider still exists in tenant1
from extensions.ext_database import db
@ -929,7 +1075,10 @@ class TestMCPToolManageService:
db.session.commit()
# Act: Execute the method under test
MCPToolManageService.update_mcp_provider(
from core.entities.mcp_provider import MCPConfiguration
service = MCPToolManageService(db.session())
service.update_provider(
tenant_id=tenant.id,
provider_id=mcp_provider.id,
name="Updated MCP Provider",
@ -938,8 +1087,10 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#4ECDC4",
server_identifier="updated_identifier_123",
timeout=45.0,
sse_read_timeout=400.0,
configuration=MCPConfiguration(
timeout=45.0,
sse_read_timeout=400.0,
),
)
# Assert: Verify the expected outcomes
@ -953,7 +1104,7 @@ class TestMCPToolManageService:
# Verify icon was updated
import json
icon_data = json.loads(mcp_provider.icon)
icon_data = json.loads(mcp_provider.icon or "{}")
assert icon_data["content"] == "🚀"
assert icon_data["background"] == "#4ECDC4"
@ -985,7 +1136,7 @@ class TestMCPToolManageService:
db.session.commit()
# Mock the reconnection method
with patch.object(MCPToolManageService, "_re_connect_mcp_provider") as mock_reconnect:
with patch.object(MCPToolManageService, "_reconnect_provider") as mock_reconnect:
mock_reconnect.return_value = {
"authed": True,
"tools": '[{"name": "test_tool"}]',
@ -993,7 +1144,11 @@ class TestMCPToolManageService:
}
# Act: Execute the method under test
MCPToolManageService.update_mcp_provider(
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service.update_provider(
tenant_id=tenant.id,
provider_id=mcp_provider.id,
name="Updated MCP Provider",
@ -1002,8 +1157,10 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#4ECDC4",
server_identifier="updated_identifier_123",
timeout=45.0,
sse_read_timeout=400.0,
configuration=MCPConfiguration(
timeout=45.0,
sse_read_timeout=400.0,
),
)
# Assert: Verify the expected outcomes
@ -1015,7 +1172,10 @@ class TestMCPToolManageService:
assert mcp_provider.updated_at is not None
# Verify reconnection was called
mock_reconnect.assert_called_once_with("https://new-example.com/mcp", mcp_provider.id, tenant.id)
mock_reconnect.assert_called_once_with(
server_url="https://new-example.com/mcp",
provider=mcp_provider,
)
def test_update_mcp_provider_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies):
"""
@ -1048,8 +1208,12 @@ class TestMCPToolManageService:
db.session.commit()
# Act & Assert: Verify proper error handling for duplicate name
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="MCP tool First Provider already exists"):
MCPToolManageService.update_mcp_provider(
service.update_provider(
tenant_id=tenant.id,
provider_id=provider2.id,
name="First Provider", # Duplicate name
@ -1058,8 +1222,10 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#4ECDC4",
server_identifier="unique_identifier",
timeout=45.0,
sse_read_timeout=400.0,
configuration=MCPConfiguration(
timeout=45.0,
sse_read_timeout=400.0,
),
)
def test_update_mcp_provider_credentials_success(
@ -1094,19 +1260,22 @@ class TestMCPToolManageService:
# Mock the provider controller and encryption
with (
patch("services.tools.mcp_tools_manage_service.MCPToolProviderController") as mock_controller,
patch("services.tools.mcp_tools_manage_service.ProviderConfigEncrypter") as mock_encrypter,
patch("core.tools.mcp_tool.provider.MCPToolProviderController") as mock_controller,
patch("core.tools.utils.encryption.ProviderConfigEncrypter") as mock_encrypter,
):
# Setup mocks
mock_controller_instance = mock_controller._from_db.return_value
mock_controller_instance = mock_controller.from_db.return_value
mock_controller_instance.get_credentials_schema.return_value = []
mock_encrypter_instance = mock_encrypter.return_value
mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"}
# Act: Execute the method under test
MCPToolManageService.update_mcp_provider_credentials(
mcp_provider=mcp_provider, credentials={"new_key": "new_value"}, authed=True
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service.update_provider_credentials(
provider=mcp_provider, credentials={"new_key": "new_value"}, authed=True
)
# Assert: Verify the expected outcomes
@ -1117,7 +1286,7 @@ class TestMCPToolManageService:
# Verify credentials were encrypted and merged
import json
credentials = json.loads(mcp_provider.encrypted_credentials)
credentials = json.loads(mcp_provider.encrypted_credentials or "{}")
assert "existing_key" in credentials
assert "new_key" in credentials
@ -1152,19 +1321,22 @@ class TestMCPToolManageService:
# Mock the provider controller and encryption
with (
patch("services.tools.mcp_tools_manage_service.MCPToolProviderController") as mock_controller,
patch("services.tools.mcp_tools_manage_service.ProviderConfigEncrypter") as mock_encrypter,
patch("core.tools.mcp_tool.provider.MCPToolProviderController") as mock_controller,
patch("core.tools.utils.encryption.ProviderConfigEncrypter") as mock_encrypter,
):
# Setup mocks
mock_controller_instance = mock_controller._from_db.return_value
mock_controller_instance = mock_controller.from_db.return_value
mock_controller_instance.get_credentials_schema.return_value = []
mock_encrypter_instance = mock_encrypter.return_value
mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"}
# Act: Execute the method under test
MCPToolManageService.update_mcp_provider_credentials(
mcp_provider=mcp_provider, credentials={"new_key": "new_value"}, authed=False
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service.update_provider_credentials(
provider=mcp_provider, credentials={"new_key": "new_value"}, authed=False
)
# Assert: Verify the expected outcomes
@ -1199,14 +1371,18 @@ class TestMCPToolManageService:
type("MockTool", (), {"model_dump": lambda self: {"name": "test_tool_2", "description": "Test tool 2"}})(),
]
with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client:
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
# Setup mock client
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.return_value = mock_tools
# Act: Execute the method under test
result = MCPToolManageService._re_connect_mcp_provider(
"https://example.com/mcp", mcp_provider.id, tenant.id
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service._reconnect_provider(
server_url="https://example.com/mcp",
provider=mcp_provider,
)
# Assert: Verify the expected outcomes
@ -1224,16 +1400,8 @@ class TestMCPToolManageService:
assert tools_data[1]["name"] == "test_tool_2"
# Verify mock interactions
mock_mcp_client.assert_called_once_with(
"https://example.com/mcp",
mcp_provider.id,
tenant.id,
authed=False,
for_list=True,
headers={},
timeout=30.0,
sse_read_timeout=300.0,
)
provider_entity = mcp_provider.to_entity()
mock_mcp_client.assert_called_once()
def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies):
"""
@ -1256,15 +1424,19 @@ class TestMCPToolManageService:
)
# Mock MCPClient to raise authentication error
with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client:
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
from core.mcp.error import MCPAuthError
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required")
# Act: Execute the method under test
result = MCPToolManageService._re_connect_mcp_provider(
"https://example.com/mcp", mcp_provider.id, tenant.id
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service._reconnect_provider(
server_url="https://example.com/mcp",
provider=mcp_provider,
)
# Assert: Verify the expected outcomes
@ -1295,12 +1467,18 @@ class TestMCPToolManageService:
)
# Mock MCPClient to raise connection error
with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client:
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
from core.mcp.error import MCPError
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.side_effect = MCPError("Connection failed")
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="Failed to re-connect MCP server: Connection failed"):
MCPToolManageService._re_connect_mcp_provider("https://example.com/mcp", mcp_provider.id, tenant.id)
service._reconnect_provider(
server_url="https://example.com/mcp",
provider=mcp_provider,
)

View File

@ -0,0 +1,401 @@
"""
TestContainers-based integration tests for mail_owner_transfer_task.
This module provides comprehensive integration tests for the mail owner transfer tasks
using TestContainers to ensure real email service integration and proper functionality
testing with actual database and service dependencies.
"""
import logging
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from libs.email_i18n import EmailType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from tasks.mail_owner_transfer_task import (
send_new_owner_transfer_notify_email_task,
send_old_owner_transfer_notify_email_task,
send_owner_transfer_confirm_task,
)
logger = logging.getLogger(__name__)
class TestMailOwnerTransferTask:
"""Integration tests for mail owner transfer tasks using testcontainers."""
@pytest.fixture
def mock_mail_dependencies(self):
"""Mock setup for mail service dependencies."""
with (
patch("tasks.mail_owner_transfer_task.mail") as mock_mail,
patch("tasks.mail_owner_transfer_task.get_email_i18n_service") as mock_get_email_service,
):
# Setup mock mail service
mock_mail.is_inited.return_value = True
# Setup mock email service
mock_email_service = MagicMock()
mock_get_email_service.return_value = mock_email_service
yield {
"mail": mock_mail,
"email_service": mock_email_service,
"get_email_service": mock_get_email_service,
}
def _create_test_account_and_tenant(self, db_session_with_containers):
"""
Helper method to create test account and tenant for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
Returns:
tuple: (account, tenant) - Created account and tenant instances
"""
fake = Faker()
# Create account
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant
tenant = Tenant(
name=fake.company(),
status="normal",
)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
current=True,
)
db_session_with_containers.add(join)
db_session_with_containers.commit()
return account, tenant
def test_send_owner_transfer_confirm_task_success(self, db_session_with_containers, mock_mail_dependencies):
"""
Test successful owner transfer confirmation email sending.
This test verifies:
- Proper email service initialization check
- Correct email service method calls with right parameters
- Email template context is properly constructed
"""
# Arrange: Create test data
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
test_language = "en-US"
test_email = account.email
test_code = "123456"
test_workspace = tenant.name
# Act: Execute the task
send_owner_transfer_confirm_task(
language=test_language,
to=test_email,
code=test_code,
workspace=test_workspace,
)
# Assert: Verify the expected outcomes
mock_mail_dependencies["mail"].is_inited.assert_called_once()
mock_mail_dependencies["get_email_service"].assert_called_once()
# Verify email service was called with correct parameters
mock_mail_dependencies["email_service"].send_email.assert_called_once()
call_args = mock_mail_dependencies["email_service"].send_email.call_args
assert call_args[1]["email_type"] == EmailType.OWNER_TRANSFER_CONFIRM
assert call_args[1]["language_code"] == test_language
assert call_args[1]["to"] == test_email
assert call_args[1]["template_context"]["to"] == test_email
assert call_args[1]["template_context"]["code"] == test_code
assert call_args[1]["template_context"]["WorkspaceName"] == test_workspace
def test_send_owner_transfer_confirm_task_mail_not_initialized(
self, db_session_with_containers, mock_mail_dependencies
):
"""
Test owner transfer confirmation email when mail service is not initialized.
This test verifies:
- Early return when mail service is not initialized
- No email service calls are made
- No exceptions are raised
"""
# Arrange: Set mail service as not initialized
mock_mail_dependencies["mail"].is_inited.return_value = False
test_language = "en-US"
test_email = "test@example.com"
test_code = "123456"
test_workspace = "Test Workspace"
# Act: Execute the task
send_owner_transfer_confirm_task(
language=test_language,
to=test_email,
code=test_code,
workspace=test_workspace,
)
# Assert: Verify no email service calls were made
mock_mail_dependencies["get_email_service"].assert_not_called()
mock_mail_dependencies["email_service"].send_email.assert_not_called()
def test_send_owner_transfer_confirm_task_exception_handling(
self, db_session_with_containers, mock_mail_dependencies
):
"""
Test exception handling in owner transfer confirmation email.
This test verifies:
- Exceptions are properly caught and logged
- No exceptions are propagated to caller
- Email service calls are attempted
- Error logging works correctly
"""
# Arrange: Setup email service to raise exception
mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error")
test_language = "en-US"
test_email = "test@example.com"
test_code = "123456"
test_workspace = "Test Workspace"
# Act & Assert: Verify no exception is raised
try:
send_owner_transfer_confirm_task(
language=test_language,
to=test_email,
code=test_code,
workspace=test_workspace,
)
except Exception as e:
pytest.fail(f"Task should not raise exceptions, but raised: {e}")
# Verify email service was called despite the exception
mock_mail_dependencies["email_service"].send_email.assert_called_once()
def test_send_old_owner_transfer_notify_email_task_success(
self, db_session_with_containers, mock_mail_dependencies
):
"""
Test successful old owner transfer notification email sending.
This test verifies:
- Proper email service initialization check
- Correct email service method calls with right parameters
- Email template context includes new owner email
"""
# Arrange: Create test data
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
test_language = "en-US"
test_email = account.email
test_workspace = tenant.name
test_new_owner_email = "newowner@example.com"
# Act: Execute the task
send_old_owner_transfer_notify_email_task(
language=test_language,
to=test_email,
workspace=test_workspace,
new_owner_email=test_new_owner_email,
)
# Assert: Verify the expected outcomes
mock_mail_dependencies["mail"].is_inited.assert_called_once()
mock_mail_dependencies["get_email_service"].assert_called_once()
# Verify email service was called with correct parameters
mock_mail_dependencies["email_service"].send_email.assert_called_once()
call_args = mock_mail_dependencies["email_service"].send_email.call_args
assert call_args[1]["email_type"] == EmailType.OWNER_TRANSFER_OLD_NOTIFY
assert call_args[1]["language_code"] == test_language
assert call_args[1]["to"] == test_email
assert call_args[1]["template_context"]["to"] == test_email
assert call_args[1]["template_context"]["WorkspaceName"] == test_workspace
assert call_args[1]["template_context"]["NewOwnerEmail"] == test_new_owner_email
def test_send_old_owner_transfer_notify_email_task_mail_not_initialized(
self, db_session_with_containers, mock_mail_dependencies
):
"""
Test old owner transfer notification email when mail service is not initialized.
This test verifies:
- Early return when mail service is not initialized
- No email service calls are made
- No exceptions are raised
"""
# Arrange: Set mail service as not initialized
mock_mail_dependencies["mail"].is_inited.return_value = False
test_language = "en-US"
test_email = "test@example.com"
test_workspace = "Test Workspace"
test_new_owner_email = "newowner@example.com"
# Act: Execute the task
send_old_owner_transfer_notify_email_task(
language=test_language,
to=test_email,
workspace=test_workspace,
new_owner_email=test_new_owner_email,
)
# Assert: Verify no email service calls were made
mock_mail_dependencies["get_email_service"].assert_not_called()
mock_mail_dependencies["email_service"].send_email.assert_not_called()
def test_send_old_owner_transfer_notify_email_task_exception_handling(
self, db_session_with_containers, mock_mail_dependencies
):
"""
Test exception handling in old owner transfer notification email.
This test verifies:
- Exceptions are properly caught and logged
- No exceptions are propagated to caller
- Email service calls are attempted
- Error logging works correctly
"""
# Arrange: Setup email service to raise exception
mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error")
test_language = "en-US"
test_email = "test@example.com"
test_workspace = "Test Workspace"
test_new_owner_email = "newowner@example.com"
# Act & Assert: Verify no exception is raised
try:
send_old_owner_transfer_notify_email_task(
language=test_language,
to=test_email,
workspace=test_workspace,
new_owner_email=test_new_owner_email,
)
except Exception as e:
pytest.fail(f"Task should not raise exceptions, but raised: {e}")
# Verify email service was called despite the exception
mock_mail_dependencies["email_service"].send_email.assert_called_once()
def test_send_new_owner_transfer_notify_email_task_success(
self, db_session_with_containers, mock_mail_dependencies
):
"""
Test successful new owner transfer notification email sending.
This test verifies:
- Proper email service initialization check
- Correct email service method calls with right parameters
- Email template context is properly constructed
"""
# Arrange: Create test data
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
test_language = "en-US"
test_email = account.email
test_workspace = tenant.name
# Act: Execute the task
send_new_owner_transfer_notify_email_task(
language=test_language,
to=test_email,
workspace=test_workspace,
)
# Assert: Verify the expected outcomes
mock_mail_dependencies["mail"].is_inited.assert_called_once()
mock_mail_dependencies["get_email_service"].assert_called_once()
# Verify email service was called with correct parameters
mock_mail_dependencies["email_service"].send_email.assert_called_once()
call_args = mock_mail_dependencies["email_service"].send_email.call_args
assert call_args[1]["email_type"] == EmailType.OWNER_TRANSFER_NEW_NOTIFY
assert call_args[1]["language_code"] == test_language
assert call_args[1]["to"] == test_email
assert call_args[1]["template_context"]["to"] == test_email
assert call_args[1]["template_context"]["WorkspaceName"] == test_workspace
def test_send_new_owner_transfer_notify_email_task_mail_not_initialized(
self, db_session_with_containers, mock_mail_dependencies
):
"""
Test new owner transfer notification email when mail service is not initialized.
This test verifies:
- Early return when mail service is not initialized
- No email service calls are made
- No exceptions are raised
"""
# Arrange: Set mail service as not initialized
mock_mail_dependencies["mail"].is_inited.return_value = False
test_language = "en-US"
test_email = "test@example.com"
test_workspace = "Test Workspace"
# Act: Execute the task
send_new_owner_transfer_notify_email_task(
language=test_language,
to=test_email,
workspace=test_workspace,
)
# Assert: Verify no email service calls were made
mock_mail_dependencies["get_email_service"].assert_not_called()
mock_mail_dependencies["email_service"].send_email.assert_not_called()
def test_send_new_owner_transfer_notify_email_task_exception_handling(
self, db_session_with_containers, mock_mail_dependencies
):
"""
Test exception handling in new owner transfer notification email.
This test verifies:
- Exceptions are properly caught and logged
- No exceptions are propagated to caller
- Email service calls are attempted
- Error logging works correctly
"""
# Arrange: Setup email service to raise exception
mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error")
test_language = "en-US"
test_email = "test@example.com"
test_workspace = "Test Workspace"
# Act & Assert: Verify no exception is raised
try:
send_new_owner_transfer_notify_email_task(
language=test_language,
to=test_email,
workspace=test_workspace,
)
except Exception as e:
pytest.fail(f"Task should not raise exceptions, but raised: {e}")
# Verify email service was called despite the exception
mock_mail_dependencies["email_service"].send_email.assert_called_once()

View File

@ -60,7 +60,7 @@ class TestAccountInitialization:
return "success"
# Act
with patch("controllers.console.wraps._current_account", return_value=mock_user):
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_user, "tenant123")):
result = protected_view()
# Assert
@ -77,7 +77,7 @@ class TestAccountInitialization:
return "success"
# Act & Assert
with patch("controllers.console.wraps._current_account", return_value=mock_user):
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_user, "tenant123")):
with pytest.raises(AccountNotInitializedError):
protected_view()
@ -163,7 +163,9 @@ class TestBillingResourceLimits:
return "member_added"
# Act
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch(
"controllers.console.wraps.current_account_with_tenant", return_value=(MockUser("test_user"), "tenant123")
):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
result = add_member()
@ -185,7 +187,10 @@ class TestBillingResourceLimits:
# Act & Assert
with app.test_request_context():
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch(
"controllers.console.wraps.current_account_with_tenant",
return_value=(MockUser("test_user"), "tenant123"),
):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
with pytest.raises(Exception) as exc_info:
add_member()
@ -207,7 +212,10 @@ class TestBillingResourceLimits:
# Test 1: Should reject when source is datasets
with app.test_request_context("/?source=datasets"):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch(
"controllers.console.wraps.current_account_with_tenant",
return_value=(MockUser("test_user"), "tenant123"),
):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
with pytest.raises(Exception) as exc_info:
upload_document()
@ -215,7 +223,10 @@ class TestBillingResourceLimits:
# Test 2: Should allow when source is not datasets
with app.test_request_context("/?source=other"):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch(
"controllers.console.wraps.current_account_with_tenant",
return_value=(MockUser("test_user"), "tenant123"),
):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
result = upload_document()
assert result == "document_uploaded"
@ -239,7 +250,9 @@ class TestRateLimiting:
return "knowledge_success"
# Act
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch(
"controllers.console.wraps.current_account_with_tenant", return_value=(MockUser("test_user"), "tenant123")
):
with patch(
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
):
@ -271,7 +284,10 @@ class TestRateLimiting:
# Act & Assert
with app.test_request_context():
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch(
"controllers.console.wraps.current_account_with_tenant",
return_value=(MockUser("test_user"), "tenant123"),
):
with patch(
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
):

View File

@ -0,0 +1,720 @@
"""Unit tests for MCP OAuth authentication flow."""
from unittest.mock import Mock, patch
import pytest
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.auth.auth_flow import (
OAUTH_STATE_EXPIRY_SECONDS,
OAUTH_STATE_REDIS_KEY_PREFIX,
OAuthCallbackState,
_create_secure_redis_state,
_retrieve_redis_state,
auth,
check_support_resource_discovery,
discover_oauth_metadata,
exchange_authorization,
generate_pkce_challenge,
handle_callback,
refresh_authorization,
register_client,
start_authorization,
)
from core.mcp.types import (
OAuthClientInformation,
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthMetadata,
OAuthTokens,
)
class TestPKCEGeneration:
"""Test PKCE challenge generation."""
def test_generate_pkce_challenge(self):
"""Test PKCE challenge and verifier generation."""
code_verifier, code_challenge = generate_pkce_challenge()
# Verify format - should be URL-safe base64 without padding
assert "=" not in code_verifier
assert "+" not in code_verifier
assert "/" not in code_verifier
assert "=" not in code_challenge
assert "+" not in code_challenge
assert "/" not in code_challenge
# Verify length
assert len(code_verifier) > 40 # Should be around 54 characters
assert len(code_challenge) > 40 # Should be around 43 characters
def test_generate_pkce_challenge_uniqueness(self):
"""Test that PKCE generation produces unique values."""
results = set()
for _ in range(10):
code_verifier, code_challenge = generate_pkce_challenge()
results.add((code_verifier, code_challenge))
# All should be unique
assert len(results) == 10
class TestRedisStateManagement:
"""Test Redis state management functions."""
@patch("core.mcp.auth.auth_flow.redis_client")
def test_create_secure_redis_state(self, mock_redis):
"""Test creating secure Redis state."""
state_data = OAuthCallbackState(
provider_id="test-provider",
tenant_id="test-tenant",
server_url="https://example.com",
metadata=None,
client_information=OAuthClientInformation(client_id="test-client"),
code_verifier="test-verifier",
redirect_uri="https://redirect.example.com",
)
state_key = _create_secure_redis_state(state_data)
# Verify state key format
assert len(state_key) > 20 # Should be a secure random token
# Verify Redis call
mock_redis.setex.assert_called_once()
call_args = mock_redis.setex.call_args
assert call_args[0][0].startswith(OAUTH_STATE_REDIS_KEY_PREFIX)
assert call_args[0][1] == OAUTH_STATE_EXPIRY_SECONDS
assert state_data.model_dump_json() in call_args[0][2]
@patch("core.mcp.auth.auth_flow.redis_client")
def test_retrieve_redis_state_success(self, mock_redis):
"""Test retrieving state from Redis."""
state_data = OAuthCallbackState(
provider_id="test-provider",
tenant_id="test-tenant",
server_url="https://example.com",
metadata=None,
client_information=OAuthClientInformation(client_id="test-client"),
code_verifier="test-verifier",
redirect_uri="https://redirect.example.com",
)
mock_redis.get.return_value = state_data.model_dump_json()
result = _retrieve_redis_state("test-state-key")
# Verify result
assert result.provider_id == "test-provider"
assert result.tenant_id == "test-tenant"
assert result.server_url == "https://example.com"
# Verify Redis calls
mock_redis.get.assert_called_once_with(f"{OAUTH_STATE_REDIS_KEY_PREFIX}test-state-key")
mock_redis.delete.assert_called_once_with(f"{OAUTH_STATE_REDIS_KEY_PREFIX}test-state-key")
@patch("core.mcp.auth.auth_flow.redis_client")
def test_retrieve_redis_state_not_found(self, mock_redis):
"""Test retrieving non-existent state from Redis."""
mock_redis.get.return_value = None
with pytest.raises(ValueError) as exc_info:
_retrieve_redis_state("nonexistent-key")
assert "State parameter has expired or does not exist" in str(exc_info.value)
@patch("core.mcp.auth.auth_flow.redis_client")
def test_retrieve_redis_state_invalid_json(self, mock_redis):
"""Test retrieving invalid JSON state from Redis."""
mock_redis.get.return_value = '{"invalid": json}'
with pytest.raises(ValueError) as exc_info:
_retrieve_redis_state("test-key")
assert "Invalid state parameter" in str(exc_info.value)
# State should still be deleted
mock_redis.delete.assert_called_once()
class TestOAuthDiscovery:
"""Test OAuth discovery functions."""
@patch("core.helper.ssrf_proxy.get")
def test_check_support_resource_discovery_success(self, mock_get):
"""Test successful resource discovery check."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"authorization_server_url": ["https://auth.example.com"]}
mock_get.return_value = mock_response
supported, auth_url = check_support_resource_discovery("https://api.example.com/endpoint")
assert supported is True
assert auth_url == "https://auth.example.com"
mock_get.assert_called_once_with(
"https://api.example.com/.well-known/oauth-protected-resource",
headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"},
)
@patch("core.helper.ssrf_proxy.get")
def test_check_support_resource_discovery_not_supported(self, mock_get):
"""Test resource discovery not supported."""
mock_response = Mock()
mock_response.status_code = 404
mock_get.return_value = mock_response
supported, auth_url = check_support_resource_discovery("https://api.example.com")
assert supported is False
assert auth_url == ""
@patch("core.helper.ssrf_proxy.get")
def test_check_support_resource_discovery_with_query_fragment(self, mock_get):
"""Test resource discovery with query and fragment."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"authorization_server_url": ["https://auth.example.com"]}
mock_get.return_value = mock_response
supported, auth_url = check_support_resource_discovery("https://api.example.com/path?query=1#fragment")
assert supported is True
assert auth_url == "https://auth.example.com"
mock_get.assert_called_once_with(
"https://api.example.com/.well-known/oauth-protected-resource?query=1#fragment",
headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"},
)
@patch("core.helper.ssrf_proxy.get")
def test_discover_oauth_metadata_with_resource_discovery(self, mock_get):
"""Test OAuth metadata discovery with resource discovery support."""
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
mock_check.return_value = (True, "https://auth.example.com")
mock_response = Mock()
mock_response.status_code = 200
mock_response.is_success = True
mock_response.json.return_value = {
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token",
"response_types_supported": ["code"],
}
mock_get.return_value = mock_response
metadata = discover_oauth_metadata("https://api.example.com")
assert metadata is not None
assert metadata.authorization_endpoint == "https://auth.example.com/authorize"
assert metadata.token_endpoint == "https://auth.example.com/token"
mock_get.assert_called_once_with(
"https://auth.example.com/.well-known/oauth-authorization-server",
headers={"MCP-Protocol-Version": "2025-03-26"},
)
@patch("core.helper.ssrf_proxy.get")
def test_discover_oauth_metadata_without_resource_discovery(self, mock_get):
"""Test OAuth metadata discovery without resource discovery."""
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
mock_check.return_value = (False, "")
mock_response = Mock()
mock_response.status_code = 200
mock_response.is_success = True
mock_response.json.return_value = {
"authorization_endpoint": "https://api.example.com/oauth/authorize",
"token_endpoint": "https://api.example.com/oauth/token",
"response_types_supported": ["code"],
}
mock_get.return_value = mock_response
metadata = discover_oauth_metadata("https://api.example.com")
assert metadata is not None
assert metadata.authorization_endpoint == "https://api.example.com/oauth/authorize"
mock_get.assert_called_once_with(
"https://api.example.com/.well-known/oauth-authorization-server",
headers={"MCP-Protocol-Version": "2025-03-26"},
)
@patch("core.helper.ssrf_proxy.get")
def test_discover_oauth_metadata_not_found(self, mock_get):
"""Test OAuth metadata discovery when not found."""
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
mock_check.return_value = (False, "")
mock_response = Mock()
mock_response.status_code = 404
mock_get.return_value = mock_response
metadata = discover_oauth_metadata("https://api.example.com")
assert metadata is None
class TestAuthorizationFlow:
"""Test authorization flow functions."""
@patch("core.mcp.auth.auth_flow._create_secure_redis_state")
def test_start_authorization_with_metadata(self, mock_create_state):
"""Test starting authorization with metadata."""
mock_create_state.return_value = "secure-state-key"
metadata = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
code_challenge_methods_supported=["S256"],
)
client_info = OAuthClientInformation(client_id="test-client-id")
auth_url, code_verifier = start_authorization(
"https://api.example.com",
metadata,
client_info,
"https://redirect.example.com",
"provider-id",
"tenant-id",
)
# Verify URL format
assert auth_url.startswith("https://auth.example.com/authorize?")
assert "response_type=code" in auth_url
assert "client_id=test-client-id" in auth_url
assert "code_challenge=" in auth_url
assert "code_challenge_method=S256" in auth_url
assert "redirect_uri=https%3A%2F%2Fredirect.example.com" in auth_url
assert "state=secure-state-key" in auth_url
# Verify code verifier
assert len(code_verifier) > 40
# Verify state was stored
mock_create_state.assert_called_once()
state_data = mock_create_state.call_args[0][0]
assert state_data.provider_id == "provider-id"
assert state_data.tenant_id == "tenant-id"
assert state_data.code_verifier == code_verifier
def test_start_authorization_without_metadata(self):
"""Test starting authorization without metadata."""
with patch("core.mcp.auth.auth_flow._create_secure_redis_state") as mock_create_state:
mock_create_state.return_value = "secure-state-key"
client_info = OAuthClientInformation(client_id="test-client-id")
auth_url, code_verifier = start_authorization(
"https://api.example.com",
None,
client_info,
"https://redirect.example.com",
"provider-id",
"tenant-id",
)
# Should use default authorization endpoint
assert auth_url.startswith("https://api.example.com/authorize?")
def test_start_authorization_invalid_metadata(self):
"""Test starting authorization with invalid metadata."""
metadata = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["token"], # No "code" support
code_challenge_methods_supported=["plain"], # No "S256" support
)
client_info = OAuthClientInformation(client_id="test-client-id")
with pytest.raises(ValueError) as exc_info:
start_authorization(
"https://api.example.com",
metadata,
client_info,
"https://redirect.example.com",
"provider-id",
"tenant-id",
)
assert "does not support response type code" in str(exc_info.value)
@patch("core.helper.ssrf_proxy.post")
def test_exchange_authorization_success(self, mock_post):
"""Test successful authorization code exchange."""
mock_response = Mock()
mock_response.is_success = True
mock_response.json.return_value = {
"access_token": "new-access-token",
"token_type": "Bearer",
"expires_in": 3600,
"refresh_token": "new-refresh-token",
}
mock_post.return_value = mock_response
metadata = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
client_info = OAuthClientInformation(client_id="test-client-id", client_secret="test-secret")
tokens = exchange_authorization(
"https://api.example.com",
metadata,
client_info,
"auth-code-123",
"code-verifier-xyz",
"https://redirect.example.com",
)
assert tokens.access_token == "new-access-token"
assert tokens.token_type == "Bearer"
assert tokens.expires_in == 3600
assert tokens.refresh_token == "new-refresh-token"
# Verify request
mock_post.assert_called_once_with(
"https://auth.example.com/token",
data={
"grant_type": "authorization_code",
"client_id": "test-client-id",
"client_secret": "test-secret",
"code": "auth-code-123",
"code_verifier": "code-verifier-xyz",
"redirect_uri": "https://redirect.example.com",
},
)
@patch("core.helper.ssrf_proxy.post")
def test_exchange_authorization_failure(self, mock_post):
"""Test failed authorization code exchange."""
mock_response = Mock()
mock_response.is_success = False
mock_response.status_code = 400
mock_post.return_value = mock_response
client_info = OAuthClientInformation(client_id="test-client-id")
with pytest.raises(ValueError) as exc_info:
exchange_authorization(
"https://api.example.com",
None,
client_info,
"invalid-code",
"code-verifier",
"https://redirect.example.com",
)
assert "Token exchange failed: HTTP 400" in str(exc_info.value)
@patch("core.helper.ssrf_proxy.post")
def test_refresh_authorization_success(self, mock_post):
"""Test successful token refresh."""
mock_response = Mock()
mock_response.is_success = True
mock_response.json.return_value = {
"access_token": "refreshed-access-token",
"token_type": "Bearer",
"expires_in": 3600,
"refresh_token": "new-refresh-token",
}
mock_post.return_value = mock_response
metadata = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["refresh_token"],
)
client_info = OAuthClientInformation(client_id="test-client-id")
tokens = refresh_authorization("https://api.example.com", metadata, client_info, "old-refresh-token")
assert tokens.access_token == "refreshed-access-token"
assert tokens.refresh_token == "new-refresh-token"
# Verify request
mock_post.assert_called_once_with(
"https://auth.example.com/token",
data={
"grant_type": "refresh_token",
"client_id": "test-client-id",
"refresh_token": "old-refresh-token",
},
)
@patch("core.helper.ssrf_proxy.post")
def test_register_client_success(self, mock_post):
"""Test successful client registration."""
mock_response = Mock()
mock_response.is_success = True
mock_response.json.return_value = {
"client_id": "new-client-id",
"client_secret": "new-client-secret",
"client_name": "Dify",
"redirect_uris": ["https://redirect.example.com"],
}
mock_post.return_value = mock_response
metadata = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
registration_endpoint="https://auth.example.com/register",
response_types_supported=["code"],
)
client_metadata = OAuthClientMetadata(
client_name="Dify",
redirect_uris=["https://redirect.example.com"],
grant_types=["authorization_code"],
response_types=["code"],
)
client_info = register_client("https://api.example.com", metadata, client_metadata)
assert isinstance(client_info, OAuthClientInformationFull)
assert client_info.client_id == "new-client-id"
assert client_info.client_secret == "new-client-secret"
# Verify request
mock_post.assert_called_once_with(
"https://auth.example.com/register",
json=client_metadata.model_dump(),
headers={"Content-Type": "application/json"},
)
def test_register_client_no_endpoint(self):
"""Test client registration when no endpoint available."""
metadata = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
registration_endpoint=None,
response_types_supported=["code"],
)
client_metadata = OAuthClientMetadata(client_name="Dify", redirect_uris=["https://redirect.example.com"])
with pytest.raises(ValueError) as exc_info:
register_client("https://api.example.com", metadata, client_metadata)
assert "does not support dynamic client registration" in str(exc_info.value)
class TestCallbackHandling:
"""Test OAuth callback handling."""
@patch("core.mcp.auth.auth_flow._retrieve_redis_state")
@patch("core.mcp.auth.auth_flow.exchange_authorization")
def test_handle_callback_success(self, mock_exchange, mock_retrieve_state):
"""Test successful callback handling."""
# Setup state
state_data = OAuthCallbackState(
provider_id="test-provider",
tenant_id="test-tenant",
server_url="https://api.example.com",
metadata=None,
client_information=OAuthClientInformation(client_id="test-client"),
code_verifier="test-verifier",
redirect_uri="https://redirect.example.com",
)
mock_retrieve_state.return_value = state_data
# Setup token exchange
tokens = OAuthTokens(
access_token="new-token",
token_type="Bearer",
expires_in=3600,
)
mock_exchange.return_value = tokens
# Setup service
mock_service = Mock()
result = handle_callback("state-key", "auth-code", mock_service)
assert result == state_data
# Verify calls
mock_retrieve_state.assert_called_once_with("state-key")
mock_exchange.assert_called_once_with(
"https://api.example.com",
None,
state_data.client_information,
"auth-code",
"test-verifier",
"https://redirect.example.com",
)
mock_service.save_oauth_data.assert_called_once_with(
"test-provider", "test-tenant", tokens.model_dump(), "tokens"
)
class TestAuthOrchestration:
"""Test the main auth orchestration function."""
@pytest.fixture
def mock_provider(self):
"""Create a mock provider entity."""
provider = Mock(spec=MCPProviderEntity)
provider.id = "provider-id"
provider.tenant_id = "tenant-id"
provider.decrypt_server_url.return_value = "https://api.example.com"
provider.client_metadata = OAuthClientMetadata(
client_name="Dify",
redirect_uris=["https://redirect.example.com"],
)
provider.redirect_url = "https://redirect.example.com"
provider.retrieve_client_information.return_value = None
provider.retrieve_tokens.return_value = None
return provider
@pytest.fixture
def mock_service(self):
"""Create a mock MCP service."""
return Mock()
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
@patch("core.mcp.auth.auth_flow.register_client")
@patch("core.mcp.auth.auth_flow.start_authorization")
def test_auth_new_registration(self, mock_start_auth, mock_register, mock_discover, mock_provider, mock_service):
"""Test auth flow for new client registration."""
# Setup
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
mock_register.return_value = OAuthClientInformationFull(
client_id="new-client-id",
client_name="Dify",
redirect_uris=["https://redirect.example.com"],
)
mock_start_auth.return_value = ("https://auth.example.com/authorize?...", "code-verifier")
result = auth(mock_provider, mock_service)
assert result == {"authorization_url": "https://auth.example.com/authorize?..."}
# Verify calls
mock_register.assert_called_once()
mock_service.save_oauth_data.assert_any_call(
"provider-id",
"tenant-id",
{"client_information": mock_register.return_value.model_dump()},
"client_info",
)
mock_service.save_oauth_data.assert_any_call(
"provider-id", "tenant-id", {"code_verifier": "code-verifier"}, "code_verifier"
)
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
@patch("core.mcp.auth.auth_flow._retrieve_redis_state")
@patch("core.mcp.auth.auth_flow.exchange_authorization")
def test_auth_exchange_code(self, mock_exchange, mock_retrieve_state, mock_discover, mock_provider, mock_service):
"""Test auth flow for exchanging authorization code."""
# Setup metadata discovery
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
# Setup existing client
mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")
# Setup state retrieval
state_data = OAuthCallbackState(
provider_id="provider-id",
tenant_id="tenant-id",
server_url="https://api.example.com",
metadata=None,
client_information=OAuthClientInformation(client_id="existing-client"),
code_verifier="test-verifier",
redirect_uri="https://redirect.example.com",
)
mock_retrieve_state.return_value = state_data
# Setup token exchange
tokens = OAuthTokens(access_token="new-token", token_type="Bearer", expires_in=3600)
mock_exchange.return_value = tokens
result = auth(mock_provider, mock_service, authorization_code="auth-code", state_param="state-key")
assert result == {"result": "success"}
# Verify token save
mock_service.save_oauth_data.assert_called_with("provider-id", "tenant-id", tokens.model_dump(), "tokens")
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
def test_auth_exchange_code_without_state(self, mock_discover, mock_provider, mock_service):
"""Test auth flow fails when exchanging code without state."""
# Setup metadata discovery
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")
with pytest.raises(ValueError) as exc_info:
auth(mock_provider, mock_service, authorization_code="auth-code")
assert "State parameter is required" in str(exc_info.value)
@patch("core.mcp.auth.auth_flow.refresh_authorization")
def test_auth_refresh_token(self, mock_refresh, mock_provider, mock_service):
"""Test auth flow for refreshing tokens."""
# Setup existing client and tokens
mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")
mock_provider.retrieve_tokens.return_value = OAuthTokens(
access_token="old-token",
token_type="Bearer",
expires_in=0,
refresh_token="refresh-token",
)
# Setup refresh
new_tokens = OAuthTokens(
access_token="refreshed-token",
token_type="Bearer",
expires_in=3600,
refresh_token="new-refresh-token",
)
mock_refresh.return_value = new_tokens
with patch("core.mcp.auth.auth_flow.discover_oauth_metadata") as mock_discover:
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
result = auth(mock_provider, mock_service)
assert result == {"result": "success"}
# Verify refresh was called
mock_refresh.assert_called_once()
mock_service.save_oauth_data.assert_called_with(
"provider-id", "tenant-id", new_tokens.model_dump(), "tokens"
)
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
def test_auth_registration_fails_with_code(self, mock_discover, mock_provider, mock_service):
"""Test auth fails when no client info exists but code is provided."""
# Setup metadata discovery
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
mock_provider.retrieve_client_information.return_value = None
with pytest.raises(ValueError) as exc_info:
auth(mock_provider, mock_service, authorization_code="auth-code")
assert "Existing OAuth client information is required" in str(exc_info.value)

View File

@ -0,0 +1,420 @@
"""Unit tests for MCP auth client with retry logic."""
from types import TracebackType
from unittest.mock import Mock, patch
import pytest
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPAuthError
from core.mcp.mcp_client import MCPClient
from core.mcp.types import CallToolResult, TextContent, Tool, ToolAnnotations
class TestMCPClientWithAuthRetry:
"""Test suite for MCPClientWithAuthRetry."""
@pytest.fixture
def mock_provider_entity(self):
"""Create a mock provider entity."""
provider = Mock(spec=MCPProviderEntity)
provider.id = "test-provider-id"
provider.tenant_id = "test-tenant-id"
provider.retrieve_tokens.return_value = Mock(
access_token="test-token", token_type="Bearer", expires_in=3600, refresh_token=None
)
return provider
@pytest.fixture
def mock_mcp_service(self):
"""Create a mock MCP service."""
service = Mock()
service.get_provider_entity.return_value = Mock(
retrieve_tokens=lambda: Mock(
access_token="new-test-token", token_type="Bearer", expires_in=3600, refresh_token=None
)
)
return service
@pytest.fixture
def auth_callback(self):
"""Create a mock auth callback."""
return Mock()
def test_init(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test client initialization."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
headers={"Authorization": "Bearer test"},
timeout=30.0,
sse_read_timeout=60.0,
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
authorization_code="test-auth-code",
by_server_id=True,
mcp_service=mock_mcp_service,
)
assert client.server_url == "http://test.example.com"
assert client.headers == {"Authorization": "Bearer test"}
assert client.timeout == 30.0
assert client.sse_read_timeout == 60.0
assert client.provider_entity == mock_provider_entity
assert client.auth_callback == auth_callback
assert client.authorization_code == "test-auth-code"
assert client.by_server_id is True
assert client.mcp_service == mock_mcp_service
assert client._has_retried is False
# In inheritance design, we don't have _client attribute
assert hasattr(client, "_session") # Inherited from MCPClient
def test_inheritance_structure(self):
"""Test that MCPClientWithAuthRetry properly inherits from MCPClient."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
headers={"Authorization": "Bearer test"},
)
# Verify inheritance
assert isinstance(client, MCPClient)
# Verify inherited attributes are accessible
assert hasattr(client, "server_url")
assert hasattr(client, "headers")
assert hasattr(client, "_session")
assert hasattr(client, "_exit_stack")
assert hasattr(client, "_initialized")
def test_handle_auth_error_no_retry_components(self):
"""Test auth error handling when retry components are missing."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
error = MCPAuthError("Auth failed")
with pytest.raises(MCPAuthError) as exc_info:
client._handle_auth_error(error)
assert exc_info.value == error
def test_handle_auth_error_already_retried(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test auth error handling when already retried."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
mcp_service=mock_mcp_service,
)
client._has_retried = True
error = MCPAuthError("Auth failed")
with pytest.raises(MCPAuthError) as exc_info:
client._handle_auth_error(error)
assert exc_info.value == error
auth_callback.assert_not_called()
def test_handle_auth_error_successful_refresh(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test successful auth refresh on error."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
authorization_code="test-code",
by_server_id=True,
mcp_service=mock_mcp_service,
)
# Configure mocks
new_provider = Mock(spec=MCPProviderEntity)
new_provider.id = "test-provider-id"
new_provider.tenant_id = "test-tenant-id"
new_provider.retrieve_tokens.return_value = Mock(
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
)
mock_mcp_service.get_provider_entity.return_value = new_provider
error = MCPAuthError("Auth failed")
client._handle_auth_error(error)
# Verify auth flow
auth_callback.assert_called_once_with(mock_provider_entity, mock_mcp_service, "test-code")
mock_mcp_service.get_provider_entity.assert_called_once_with(
"test-provider-id", "test-tenant-id", by_server_id=True
)
assert client.headers["Authorization"] == "Bearer new-token"
assert client.authorization_code is None # Should be cleared after use
assert client._has_retried is True
def test_handle_auth_error_refresh_fails(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test auth refresh failure."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
mcp_service=mock_mcp_service,
)
auth_callback.side_effect = Exception("Auth callback failed")
error = MCPAuthError("Original auth failed")
with pytest.raises(MCPAuthError) as exc_info:
client._handle_auth_error(error)
assert "Authentication retry failed" in str(exc_info.value)
def test_handle_auth_error_no_token_received(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test auth refresh when no token is received."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
mcp_service=mock_mcp_service,
)
# Configure mock to return no token
new_provider = Mock(spec=MCPProviderEntity)
new_provider.retrieve_tokens.return_value = None
mock_mcp_service.get_provider_entity.return_value = new_provider
error = MCPAuthError("Auth failed")
with pytest.raises(MCPAuthError) as exc_info:
client._handle_auth_error(error)
assert "no token received" in str(exc_info.value)
def test_execute_with_retry_success(self):
"""Test successful execution without retry."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
mock_func = Mock(return_value="success")
result = client._execute_with_retry(mock_func, "arg1", kwarg1="value1")
assert result == "success"
mock_func.assert_called_once_with("arg1", kwarg1="value1")
assert client._has_retried is False
def test_execute_with_retry_auth_error_then_success(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test execution with auth error followed by successful retry."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
mcp_service=mock_mcp_service,
)
# Configure new provider with token
new_provider = Mock(spec=MCPProviderEntity)
new_provider.retrieve_tokens.return_value = Mock(
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
)
mock_mcp_service.get_provider_entity.return_value = new_provider
# Mock function that fails first, then succeeds
mock_func = Mock(side_effect=[MCPAuthError("Auth failed"), "success"])
# Mock the exit stack and session cleanup
with (
patch.object(client, "_exit_stack") as mock_exit_stack,
patch.object(client, "_session") as mock_session,
patch.object(client, "_initialize") as mock_initialize,
):
client._initialized = True
result = client._execute_with_retry(mock_func, "arg1", kwarg1="value1")
assert result == "success"
assert mock_func.call_count == 2
mock_func.assert_called_with("arg1", kwarg1="value1")
auth_callback.assert_called_once()
mock_exit_stack.close.assert_called_once()
mock_initialize.assert_called_once()
assert client._has_retried is False # Reset after completion
def test_execute_with_retry_non_auth_error(self):
"""Test execution with non-auth error (no retry)."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
mock_func = Mock(side_effect=ValueError("Some other error"))
with pytest.raises(ValueError) as exc_info:
client._execute_with_retry(mock_func)
assert str(exc_info.value) == "Some other error"
mock_func.assert_called_once()
def test_context_manager_enter(self):
"""Test context manager enter."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
with patch.object(client, "_initialize") as mock_initialize:
result = client.__enter__()
assert result == client
assert client._initialized is True
mock_initialize.assert_called_once()
def test_context_manager_enter_with_auth_error(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test context manager enter with auth error and retry."""
# Configure new provider with token
new_provider = Mock(spec=MCPProviderEntity)
new_provider.retrieve_tokens.return_value = Mock(
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
)
mock_mcp_service.get_provider_entity.return_value = new_provider
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
mcp_service=mock_mcp_service,
)
# Mock parent class __enter__ to raise auth error first, then succeed
with patch.object(MCPClient, "__enter__") as mock_parent_enter:
mock_parent_enter.side_effect = [MCPAuthError("Auth failed"), client]
result = client.__enter__()
assert result == client
assert mock_parent_enter.call_count == 2
auth_callback.assert_called_once()
def test_context_manager_exit(self):
"""Test context manager exit."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
with patch.object(client, "cleanup") as mock_cleanup:
exc_type: type[BaseException] | None = None
exc_val: BaseException | None = None
exc_tb: TracebackType | None = None
client.__exit__(exc_type, exc_val, exc_tb)
mock_cleanup.assert_called_once()
def test_list_tools_not_initialized(self):
"""Test list_tools when client not initialized."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
with pytest.raises(ValueError) as exc_info:
client.list_tools()
assert "Session not initialized" in str(exc_info.value)
def test_list_tools_success(self):
"""Test successful list_tools call."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
expected_tools = [
Tool(
name="test-tool",
description="A test tool",
inputSchema={"type": "object", "properties": {}},
annotations=ToolAnnotations(title="Test Tool"),
)
]
# Mock the parent class list_tools method
with patch.object(MCPClient, "list_tools", return_value=expected_tools):
result = client.list_tools()
assert result == expected_tools
def test_list_tools_with_auth_retry(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test list_tools with auth retry."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
mcp_service=mock_mcp_service,
)
# Configure new provider with token
new_provider = Mock(spec=MCPProviderEntity)
new_provider.retrieve_tokens.return_value = Mock(
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
)
mock_mcp_service.get_provider_entity.return_value = new_provider
expected_tools = [Tool(name="test-tool", description="A test tool", inputSchema={})]
# Mock parent class list_tools to raise auth error first, then succeed
with patch.object(MCPClient, "list_tools") as mock_list_tools:
mock_list_tools.side_effect = [MCPAuthError("Auth failed"), expected_tools]
result = client.list_tools()
assert result == expected_tools
assert mock_list_tools.call_count == 2
auth_callback.assert_called_once()
def test_invoke_tool_not_initialized(self):
"""Test invoke_tool when client not initialized."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
with pytest.raises(ValueError) as exc_info:
client.invoke_tool("test-tool", {"arg": "value"})
assert "Session not initialized" in str(exc_info.value)
def test_invoke_tool_success(self):
"""Test successful invoke_tool call."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
expected_result = CallToolResult(
content=[TextContent(type="text", text="Tool executed successfully")], isError=False
)
# Mock the parent class invoke_tool method
with patch.object(MCPClient, "invoke_tool", return_value=expected_result) as mock_invoke:
result = client.invoke_tool("test-tool", {"arg": "value"})
assert result == expected_result
mock_invoke.assert_called_once_with("test-tool", {"arg": "value"})
def test_invoke_tool_with_auth_retry(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test invoke_tool with auth retry."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
mcp_service=mock_mcp_service,
)
# Configure new provider with token
new_provider = Mock(spec=MCPProviderEntity)
new_provider.retrieve_tokens.return_value = Mock(
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
)
mock_mcp_service.get_provider_entity.return_value = new_provider
expected_result = CallToolResult(content=[TextContent(type="text", text="Success")], isError=False)
# Mock parent class invoke_tool to raise auth error first, then succeed
with patch.object(MCPClient, "invoke_tool") as mock_invoke_tool:
mock_invoke_tool.side_effect = [MCPAuthError("Auth failed"), expected_result]
result = client.invoke_tool("test-tool", {"arg": "value"})
assert result == expected_result
assert mock_invoke_tool.call_count == 2
mock_invoke_tool.assert_called_with("test-tool", {"arg": "value"})
auth_callback.assert_called_once()
def test_cleanup(self):
"""Test cleanup method."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
# Mock the parent class cleanup method
with patch.object(MCPClient, "cleanup") as mock_cleanup:
client.cleanup()
mock_cleanup.assert_called_once()
def test_cleanup_no_client(self):
"""Test cleanup when no client exists."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
# Should not raise
client.cleanup()
# Since MCPClientWithAuthRetry inherits from MCPClient,
# it doesn't have a _client attribute. The test should just
# verify that cleanup can be called without error.
assert not hasattr(client, "_client")

View File

@ -0,0 +1,239 @@
"""Unit tests for MCP entities module."""
from unittest.mock import Mock
from core.mcp.entities import (
SUPPORTED_PROTOCOL_VERSIONS,
LifespanContextT,
RequestContext,
SessionT,
)
from core.mcp.session.base_session import BaseSession
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestParams
class TestProtocolVersions:
"""Test protocol version constants."""
def test_supported_protocol_versions(self):
"""Test supported protocol versions list."""
assert isinstance(SUPPORTED_PROTOCOL_VERSIONS, list)
assert len(SUPPORTED_PROTOCOL_VERSIONS) >= 3
assert "2024-11-05" in SUPPORTED_PROTOCOL_VERSIONS
assert "2025-03-26" in SUPPORTED_PROTOCOL_VERSIONS
assert LATEST_PROTOCOL_VERSION in SUPPORTED_PROTOCOL_VERSIONS
def test_latest_protocol_version_is_supported(self):
"""Test that latest protocol version is in supported versions."""
assert LATEST_PROTOCOL_VERSION in SUPPORTED_PROTOCOL_VERSIONS
class TestRequestContext:
"""Test RequestContext dataclass."""
def test_request_context_creation(self):
"""Test creating a RequestContext instance."""
mock_session = Mock(spec=BaseSession)
mock_lifespan = {"key": "value"}
mock_meta = RequestParams.Meta(progressToken="test-token")
context = RequestContext(
request_id="test-request-123",
meta=mock_meta,
session=mock_session,
lifespan_context=mock_lifespan,
)
assert context.request_id == "test-request-123"
assert context.meta == mock_meta
assert context.session == mock_session
assert context.lifespan_context == mock_lifespan
def test_request_context_with_none_meta(self):
"""Test creating RequestContext with None meta."""
mock_session = Mock(spec=BaseSession)
context = RequestContext(
request_id=42, # Can be int or string
meta=None,
session=mock_session,
lifespan_context=None,
)
assert context.request_id == 42
assert context.meta is None
assert context.session == mock_session
assert context.lifespan_context is None
def test_request_context_attributes(self):
"""Test RequestContext attributes are accessible."""
mock_session = Mock(spec=BaseSession)
context = RequestContext(
request_id="test-123",
meta=None,
session=mock_session,
lifespan_context=None,
)
# Verify attributes are accessible
assert hasattr(context, "request_id")
assert hasattr(context, "meta")
assert hasattr(context, "session")
assert hasattr(context, "lifespan_context")
# Verify values
assert context.request_id == "test-123"
assert context.meta is None
assert context.session == mock_session
assert context.lifespan_context is None
def test_request_context_generic_typing(self):
"""Test RequestContext with different generic types."""
# Create a mock session with specific type
mock_session = Mock(spec=BaseSession)
# Create context with string lifespan context
context_str = RequestContext[BaseSession, str](
request_id="test-1",
meta=None,
session=mock_session,
lifespan_context="string-context",
)
assert isinstance(context_str.lifespan_context, str)
# Create context with dict lifespan context
context_dict = RequestContext[BaseSession, dict](
request_id="test-2",
meta=None,
session=mock_session,
lifespan_context={"key": "value"},
)
assert isinstance(context_dict.lifespan_context, dict)
# Create context with custom object lifespan context
class CustomLifespan:
def __init__(self, data):
self.data = data
custom_lifespan = CustomLifespan("test-data")
context_custom = RequestContext[BaseSession, CustomLifespan](
request_id="test-3",
meta=None,
session=mock_session,
lifespan_context=custom_lifespan,
)
assert isinstance(context_custom.lifespan_context, CustomLifespan)
assert context_custom.lifespan_context.data == "test-data"
def test_request_context_with_progress_meta(self):
"""Test RequestContext with progress metadata."""
mock_session = Mock(spec=BaseSession)
progress_meta = RequestParams.Meta(progressToken="progress-123")
context = RequestContext(
request_id="req-456",
meta=progress_meta,
session=mock_session,
lifespan_context=None,
)
assert context.meta is not None
assert context.meta.progressToken == "progress-123"
def test_request_context_equality(self):
"""Test RequestContext equality comparison."""
mock_session1 = Mock(spec=BaseSession)
mock_session2 = Mock(spec=BaseSession)
context1 = RequestContext(
request_id="test-123",
meta=None,
session=mock_session1,
lifespan_context="context",
)
context2 = RequestContext(
request_id="test-123",
meta=None,
session=mock_session1,
lifespan_context="context",
)
context3 = RequestContext(
request_id="test-456",
meta=None,
session=mock_session1,
lifespan_context="context",
)
# Same values should be equal
assert context1 == context2
# Different request_id should not be equal
assert context1 != context3
# Different session should not be equal
context4 = RequestContext(
request_id="test-123",
meta=None,
session=mock_session2,
lifespan_context="context",
)
assert context1 != context4
def test_request_context_repr(self):
"""Test RequestContext string representation."""
mock_session = Mock(spec=BaseSession)
mock_session.__repr__ = Mock(return_value="<MockSession>")
context = RequestContext(
request_id="test-123",
meta=None,
session=mock_session,
lifespan_context={"data": "test"},
)
repr_str = repr(context)
assert "RequestContext" in repr_str
assert "test-123" in repr_str
assert "MockSession" in repr_str
class TestTypeVariables:
"""Test type variables defined in the module."""
def test_session_type_var(self):
"""Test SessionT type variable."""
# Create a custom session class
class CustomSession(BaseSession):
pass
# Use in generic context
def process_session(session: SessionT) -> SessionT:
return session
mock_session = Mock(spec=CustomSession)
result = process_session(mock_session)
assert result == mock_session
def test_lifespan_context_type_var(self):
"""Test LifespanContextT type variable."""
# Use in generic context
def process_lifespan(context: LifespanContextT) -> LifespanContextT:
return context
# Test with different types
str_context = "string-context"
assert process_lifespan(str_context) == str_context
dict_context = {"key": "value"}
assert process_lifespan(dict_context) == dict_context
class CustomContext:
pass
custom_context = CustomContext()
assert process_lifespan(custom_context) == custom_context

View File

@ -0,0 +1,205 @@
"""Unit tests for MCP error classes."""
import pytest
from core.mcp.error import MCPAuthError, MCPConnectionError, MCPError
class TestMCPError:
"""Test MCPError base exception class."""
def test_mcp_error_creation(self):
"""Test creating MCPError instance."""
error = MCPError("Test error message")
assert str(error) == "Test error message"
assert isinstance(error, Exception)
def test_mcp_error_inheritance(self):
"""Test MCPError inherits from Exception."""
error = MCPError()
assert isinstance(error, Exception)
assert type(error).__name__ == "MCPError"
def test_mcp_error_with_empty_message(self):
"""Test MCPError with empty message."""
error = MCPError()
assert str(error) == ""
def test_mcp_error_raise(self):
"""Test raising MCPError."""
with pytest.raises(MCPError) as exc_info:
raise MCPError("Something went wrong")
assert str(exc_info.value) == "Something went wrong"
class TestMCPConnectionError:
"""Test MCPConnectionError exception class."""
def test_mcp_connection_error_creation(self):
"""Test creating MCPConnectionError instance."""
error = MCPConnectionError("Connection failed")
assert str(error) == "Connection failed"
assert isinstance(error, MCPError)
assert isinstance(error, Exception)
def test_mcp_connection_error_inheritance(self):
"""Test MCPConnectionError inheritance chain."""
error = MCPConnectionError()
assert isinstance(error, MCPConnectionError)
assert isinstance(error, MCPError)
assert isinstance(error, Exception)
def test_mcp_connection_error_raise(self):
"""Test raising MCPConnectionError."""
with pytest.raises(MCPConnectionError) as exc_info:
raise MCPConnectionError("Unable to connect to server")
assert str(exc_info.value) == "Unable to connect to server"
def test_mcp_connection_error_catch_as_mcp_error(self):
"""Test catching MCPConnectionError as MCPError."""
with pytest.raises(MCPError) as exc_info:
raise MCPConnectionError("Connection issue")
assert isinstance(exc_info.value, MCPConnectionError)
assert str(exc_info.value) == "Connection issue"
class TestMCPAuthError:
"""Test MCPAuthError exception class."""
def test_mcp_auth_error_creation(self):
"""Test creating MCPAuthError instance."""
error = MCPAuthError("Authentication failed")
assert str(error) == "Authentication failed"
assert isinstance(error, MCPConnectionError)
assert isinstance(error, MCPError)
assert isinstance(error, Exception)
def test_mcp_auth_error_inheritance(self):
"""Test MCPAuthError inheritance chain."""
error = MCPAuthError()
assert isinstance(error, MCPAuthError)
assert isinstance(error, MCPConnectionError)
assert isinstance(error, MCPError)
assert isinstance(error, Exception)
def test_mcp_auth_error_raise(self):
"""Test raising MCPAuthError."""
with pytest.raises(MCPAuthError) as exc_info:
raise MCPAuthError("Invalid credentials")
assert str(exc_info.value) == "Invalid credentials"
def test_mcp_auth_error_catch_hierarchy(self):
"""Test catching MCPAuthError at different levels."""
# Catch as MCPAuthError
with pytest.raises(MCPAuthError) as exc_info:
raise MCPAuthError("Auth specific error")
assert str(exc_info.value) == "Auth specific error"
# Catch as MCPConnectionError
with pytest.raises(MCPConnectionError) as exc_info:
raise MCPAuthError("Auth connection error")
assert isinstance(exc_info.value, MCPAuthError)
assert str(exc_info.value) == "Auth connection error"
# Catch as MCPError
with pytest.raises(MCPError) as exc_info:
raise MCPAuthError("Auth base error")
assert isinstance(exc_info.value, MCPAuthError)
assert str(exc_info.value) == "Auth base error"
class TestErrorHierarchy:
"""Test the complete error hierarchy."""
def test_exception_hierarchy(self):
"""Test the complete exception hierarchy."""
# Create instances
base_error = MCPError("base")
connection_error = MCPConnectionError("connection")
auth_error = MCPAuthError("auth")
# Test type relationships
assert not isinstance(base_error, MCPConnectionError)
assert not isinstance(base_error, MCPAuthError)
assert isinstance(connection_error, MCPError)
assert not isinstance(connection_error, MCPAuthError)
assert isinstance(auth_error, MCPError)
assert isinstance(auth_error, MCPConnectionError)
def test_error_handling_patterns(self):
"""Test common error handling patterns."""
def raise_auth_error():
raise MCPAuthError("401 Unauthorized")
def raise_connection_error():
raise MCPConnectionError("Connection timeout")
def raise_base_error():
raise MCPError("Generic error")
# Pattern 1: Catch specific errors first
errors_caught = []
for error_func in [raise_auth_error, raise_connection_error, raise_base_error]:
try:
error_func()
except MCPAuthError:
errors_caught.append("auth")
except MCPConnectionError:
errors_caught.append("connection")
except MCPError:
errors_caught.append("base")
assert errors_caught == ["auth", "connection", "base"]
# Pattern 2: Catch all as base error
for error_func in [raise_auth_error, raise_connection_error, raise_base_error]:
with pytest.raises(MCPError) as exc_info:
error_func()
assert isinstance(exc_info.value, MCPError)
def test_error_with_cause(self):
"""Test errors with cause (chained exceptions)."""
original_error = ValueError("Original error")
def raise_chained_error():
try:
raise original_error
except ValueError as e:
raise MCPConnectionError("Connection failed") from e
with pytest.raises(MCPConnectionError) as exc_info:
raise_chained_error()
assert str(exc_info.value) == "Connection failed"
assert exc_info.value.__cause__ == original_error
def test_error_comparison(self):
"""Test error instance comparison."""
error1 = MCPError("Test message")
error2 = MCPError("Test message")
error3 = MCPError("Different message")
# Errors are not equal even with same message (different instances)
assert error1 != error2
assert error1 != error3
# But they have the same type
assert type(error1) == type(error2) == type(error3)
def test_error_representation(self):
"""Test error string representation."""
base_error = MCPError("Base error message")
connection_error = MCPConnectionError("Connection error message")
auth_error = MCPAuthError("Auth error message")
assert repr(base_error) == "MCPError('Base error message')"
assert repr(connection_error) == "MCPConnectionError('Connection error message')"
assert repr(auth_error) == "MCPAuthError('Auth error message')"

View File

@ -0,0 +1,382 @@
"""Unit tests for MCP client."""
from contextlib import ExitStack
from types import TracebackType
from unittest.mock import Mock, patch
import pytest
from core.mcp.error import MCPConnectionError
from core.mcp.mcp_client import MCPClient
from core.mcp.types import CallToolResult, ListToolsResult, TextContent, Tool, ToolAnnotations
class TestMCPClient:
"""Test suite for MCPClient."""
def test_init(self):
"""Test client initialization."""
client = MCPClient(
server_url="http://test.example.com/mcp",
headers={"Authorization": "Bearer test"},
timeout=30.0,
sse_read_timeout=60.0,
)
assert client.server_url == "http://test.example.com/mcp"
assert client.headers == {"Authorization": "Bearer test"}
assert client.timeout == 30.0
assert client.sse_read_timeout == 60.0
assert client._session is None
assert isinstance(client._exit_stack, ExitStack)
assert client._initialized is False
def test_init_defaults(self):
"""Test client initialization with defaults."""
client = MCPClient(server_url="http://test.example.com")
assert client.server_url == "http://test.example.com"
assert client.headers == {}
assert client.timeout is None
assert client.sse_read_timeout is None
@patch("core.mcp.mcp_client.streamablehttp_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_initialize_with_mcp_url(self, mock_client_session, mock_streamable_client):
"""Test initialization with MCP URL."""
# Setup mocks
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_client_context = Mock()
mock_streamable_client.return_value.__enter__.return_value = (
mock_read_stream,
mock_write_stream,
mock_client_context,
)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
client = MCPClient(server_url="http://test.example.com/mcp")
client._initialize()
# Verify streamable client was called
mock_streamable_client.assert_called_once_with(
url="http://test.example.com/mcp",
headers={},
timeout=None,
sse_read_timeout=None,
)
# Verify session was created
mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream)
mock_session.initialize.assert_called_once()
assert client._session == mock_session
@patch("core.mcp.mcp_client.sse_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_initialize_with_sse_url(self, mock_client_session, mock_sse_client):
"""Test initialization with SSE URL."""
# Setup mocks
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_sse_client.return_value.__enter__.return_value = (mock_read_stream, mock_write_stream)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
client = MCPClient(server_url="http://test.example.com/sse")
client._initialize()
# Verify SSE client was called
mock_sse_client.assert_called_once_with(
url="http://test.example.com/sse",
headers={},
timeout=None,
sse_read_timeout=None,
)
# Verify session was created
mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream)
mock_session.initialize.assert_called_once()
assert client._session == mock_session
@patch("core.mcp.mcp_client.sse_client")
@patch("core.mcp.mcp_client.streamablehttp_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_initialize_with_unknown_method_fallback_to_sse(
self, mock_client_session, mock_streamable_client, mock_sse_client
):
"""Test initialization with unknown method falls back to SSE."""
# Setup mocks
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_sse_client.return_value.__enter__.return_value = (mock_read_stream, mock_write_stream)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
client = MCPClient(server_url="http://test.example.com/unknown")
client._initialize()
# Verify SSE client was tried
mock_sse_client.assert_called_once()
mock_streamable_client.assert_not_called()
# Verify session was created
assert client._session == mock_session
@patch("core.mcp.mcp_client.sse_client")
@patch("core.mcp.mcp_client.streamablehttp_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_initialize_fallback_from_sse_to_mcp(self, mock_client_session, mock_streamable_client, mock_sse_client):
"""Test initialization falls back from SSE to MCP on connection error."""
# Setup SSE to fail
mock_sse_client.side_effect = MCPConnectionError("SSE connection failed")
# Setup MCP to succeed
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_client_context = Mock()
mock_streamable_client.return_value.__enter__.return_value = (
mock_read_stream,
mock_write_stream,
mock_client_context,
)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
client = MCPClient(server_url="http://test.example.com/unknown")
client._initialize()
# Verify both were tried
mock_sse_client.assert_called_once()
mock_streamable_client.assert_called_once()
# Verify session was created with MCP
assert client._session == mock_session
@patch("core.mcp.mcp_client.streamablehttp_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_connect_server_mcp(self, mock_client_session, mock_streamable_client):
"""Test connect_server with MCP method."""
# Setup mocks
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_client_context = Mock()
mock_streamable_client.return_value.__enter__.return_value = (
mock_read_stream,
mock_write_stream,
mock_client_context,
)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
client = MCPClient(server_url="http://test.example.com")
client.connect_server(mock_streamable_client, "mcp")
# Verify correct streams were passed
mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream)
mock_session.initialize.assert_called_once()
@patch("core.mcp.mcp_client.sse_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_connect_server_sse(self, mock_client_session, mock_sse_client):
"""Test connect_server with SSE method."""
# Setup mocks
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_sse_client.return_value.__enter__.return_value = (mock_read_stream, mock_write_stream)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
client = MCPClient(server_url="http://test.example.com")
client.connect_server(mock_sse_client, "sse")
# Verify correct streams were passed
mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream)
mock_session.initialize.assert_called_once()
def test_context_manager_enter(self):
"""Test context manager enter."""
client = MCPClient(server_url="http://test.example.com")
with patch.object(client, "_initialize") as mock_initialize:
result = client.__enter__()
assert result == client
assert client._initialized is True
mock_initialize.assert_called_once()
def test_context_manager_exit(self):
"""Test context manager exit."""
client = MCPClient(server_url="http://test.example.com")
with patch.object(client, "cleanup") as mock_cleanup:
exc_type: type[BaseException] | None = None
exc_val: BaseException | None = None
exc_tb: TracebackType | None = None
client.__exit__(exc_type, exc_val, exc_tb)
mock_cleanup.assert_called_once()
def test_list_tools_not_initialized(self):
"""Test list_tools when session not initialized."""
client = MCPClient(server_url="http://test.example.com")
with pytest.raises(ValueError) as exc_info:
client.list_tools()
assert "Session not initialized" in str(exc_info.value)
def test_list_tools_success(self):
"""Test successful list_tools call."""
client = MCPClient(server_url="http://test.example.com")
# Setup mock session
mock_session = Mock()
expected_tools = [
Tool(
name="test-tool",
description="A test tool",
inputSchema={"type": "object", "properties": {}},
annotations=ToolAnnotations(title="Test Tool"),
)
]
mock_session.list_tools.return_value = ListToolsResult(tools=expected_tools)
client._session = mock_session
result = client.list_tools()
assert result == expected_tools
mock_session.list_tools.assert_called_once()
def test_invoke_tool_not_initialized(self):
"""Test invoke_tool when session not initialized."""
client = MCPClient(server_url="http://test.example.com")
with pytest.raises(ValueError) as exc_info:
client.invoke_tool("test-tool", {"arg": "value"})
assert "Session not initialized" in str(exc_info.value)
def test_invoke_tool_success(self):
"""Test successful invoke_tool call."""
client = MCPClient(server_url="http://test.example.com")
# Setup mock session
mock_session = Mock()
expected_result = CallToolResult(
content=[TextContent(type="text", text="Tool executed successfully")],
isError=False,
)
mock_session.call_tool.return_value = expected_result
client._session = mock_session
result = client.invoke_tool("test-tool", {"arg": "value"})
assert result == expected_result
mock_session.call_tool.assert_called_once_with("test-tool", {"arg": "value"})
def test_cleanup(self):
"""Test cleanup method."""
client = MCPClient(server_url="http://test.example.com")
mock_exit_stack = Mock(spec=ExitStack)
client._exit_stack = mock_exit_stack
client._session = Mock()
client._initialized = True
client.cleanup()
mock_exit_stack.close.assert_called_once()
assert client._session is None
assert client._initialized is False
def test_cleanup_with_error(self):
"""Test cleanup method with error."""
client = MCPClient(server_url="http://test.example.com")
mock_exit_stack = Mock(spec=ExitStack)
mock_exit_stack.close.side_effect = Exception("Cleanup error")
client._exit_stack = mock_exit_stack
client._session = Mock()
client._initialized = True
with pytest.raises(ValueError) as exc_info:
client.cleanup()
assert "Error during cleanup: Cleanup error" in str(exc_info.value)
assert client._session is None
assert client._initialized is False
@patch("core.mcp.mcp_client.streamablehttp_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_full_context_manager_flow(self, mock_client_session, mock_streamable_client):
"""Test full context manager flow."""
# Setup mocks
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_client_context = Mock()
mock_streamable_client.return_value.__enter__.return_value = (
mock_read_stream,
mock_write_stream,
mock_client_context,
)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
expected_tools = [Tool(name="test-tool", description="Test", inputSchema={})]
mock_session.list_tools.return_value = ListToolsResult(tools=expected_tools)
with MCPClient(server_url="http://test.example.com/mcp") as client:
assert client._initialized is True
assert client._session == mock_session
# Test tool operations
tools = client.list_tools()
assert tools == expected_tools
# After exit, should be cleaned up
assert client._initialized is False
assert client._session is None
def test_headers_passed_to_clients(self):
"""Test that headers are properly passed to underlying clients."""
custom_headers = {
"Authorization": "Bearer test-token",
"X-Custom-Header": "test-value",
}
with patch("core.mcp.mcp_client.streamablehttp_client") as mock_streamable_client:
with patch("core.mcp.mcp_client.ClientSession") as mock_client_session:
# Setup mocks
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_client_context = Mock()
mock_streamable_client.return_value.__enter__.return_value = (
mock_read_stream,
mock_write_stream,
mock_client_context,
)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
client = MCPClient(
server_url="http://test.example.com/mcp",
headers=custom_headers,
timeout=30.0,
sse_read_timeout=60.0,
)
client._initialize()
# Verify headers were passed
mock_streamable_client.assert_called_once_with(
url="http://test.example.com/mcp",
headers=custom_headers,
timeout=30.0,
sse_read_timeout=60.0,
)

View File

@ -0,0 +1,492 @@
"""Unit tests for MCP types module."""
import pytest
from pydantic import ValidationError
from core.mcp.types import (
INTERNAL_ERROR,
INVALID_PARAMS,
INVALID_REQUEST,
LATEST_PROTOCOL_VERSION,
METHOD_NOT_FOUND,
PARSE_ERROR,
SERVER_LATEST_PROTOCOL_VERSION,
Annotations,
CallToolRequest,
CallToolRequestParams,
CallToolResult,
ClientCapabilities,
CompleteRequest,
CompleteRequestParams,
CompleteResult,
Completion,
CompletionArgument,
CompletionContext,
ErrorData,
ImageContent,
Implementation,
InitializeRequest,
InitializeRequestParams,
InitializeResult,
JSONRPCError,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
ListToolsRequest,
ListToolsResult,
OAuthClientInformation,
OAuthClientMetadata,
OAuthMetadata,
OAuthTokens,
PingRequest,
ProgressNotification,
ProgressNotificationParams,
PromptReference,
RequestParams,
ResourceTemplateReference,
Result,
ServerCapabilities,
TextContent,
Tool,
ToolAnnotations,
)
class TestConstants:
"""Test module constants."""
def test_protocol_versions(self):
"""Test protocol version constants."""
assert LATEST_PROTOCOL_VERSION == "2025-03-26"
assert SERVER_LATEST_PROTOCOL_VERSION == "2024-11-05"
def test_error_codes(self):
"""Test JSON-RPC error code constants."""
assert PARSE_ERROR == -32700
assert INVALID_REQUEST == -32600
assert METHOD_NOT_FOUND == -32601
assert INVALID_PARAMS == -32602
assert INTERNAL_ERROR == -32603
class TestRequestParams:
"""Test RequestParams and related classes."""
def test_request_params_basic(self):
"""Test basic RequestParams creation."""
params = RequestParams()
assert params.meta is None
def test_request_params_with_meta(self):
"""Test RequestParams with meta."""
meta = RequestParams.Meta(progressToken="test-token")
params = RequestParams(_meta=meta)
assert params.meta is not None
assert params.meta.progressToken == "test-token"
def test_request_params_meta_extra_fields(self):
"""Test RequestParams.Meta allows extra fields."""
meta = RequestParams.Meta(progressToken="token", customField="value")
assert meta.progressToken == "token"
assert meta.customField == "value" # type: ignore
def test_request_params_serialization(self):
"""Test RequestParams serialization with _meta alias."""
meta = RequestParams.Meta(progressToken="test")
params = RequestParams(_meta=meta)
# Model dump should use the alias
dumped = params.model_dump(by_alias=True)
assert "_meta" in dumped
assert dumped["_meta"] is not None
assert dumped["_meta"]["progressToken"] == "test"
class TestJSONRPCMessages:
"""Test JSON-RPC message types."""
def test_jsonrpc_request(self):
"""Test JSONRPCRequest creation and validation."""
request = JSONRPCRequest(jsonrpc="2.0", id="test-123", method="test_method", params={"key": "value"})
assert request.jsonrpc == "2.0"
assert request.id == "test-123"
assert request.method == "test_method"
assert request.params == {"key": "value"}
def test_jsonrpc_request_numeric_id(self):
"""Test JSONRPCRequest with numeric ID."""
request = JSONRPCRequest(jsonrpc="2.0", id=123, method="test", params=None)
assert request.id == 123
def test_jsonrpc_notification(self):
"""Test JSONRPCNotification creation."""
notification = JSONRPCNotification(jsonrpc="2.0", method="notification_method", params={"data": "test"})
assert notification.jsonrpc == "2.0"
assert notification.method == "notification_method"
assert not hasattr(notification, "id") # Notifications don't have ID
def test_jsonrpc_response(self):
"""Test JSONRPCResponse creation."""
response = JSONRPCResponse(jsonrpc="2.0", id="req-123", result={"success": True})
assert response.jsonrpc == "2.0"
assert response.id == "req-123"
assert response.result == {"success": True}
def test_jsonrpc_error(self):
"""Test JSONRPCError creation."""
error_data = ErrorData(code=INVALID_PARAMS, message="Invalid parameters", data={"field": "missing"})
error = JSONRPCError(jsonrpc="2.0", id="req-123", error=error_data)
assert error.jsonrpc == "2.0"
assert error.id == "req-123"
assert error.error.code == INVALID_PARAMS
assert error.error.message == "Invalid parameters"
assert error.error.data == {"field": "missing"}
def test_jsonrpc_message_parsing(self):
"""Test JSONRPCMessage parsing different message types."""
# Parse request
request_json = '{"jsonrpc": "2.0", "id": 1, "method": "test", "params": null}'
msg = JSONRPCMessage.model_validate_json(request_json)
assert isinstance(msg.root, JSONRPCRequest)
# Parse response
response_json = '{"jsonrpc": "2.0", "id": 1, "result": {"data": "test"}}'
msg = JSONRPCMessage.model_validate_json(response_json)
assert isinstance(msg.root, JSONRPCResponse)
# Parse error
error_json = '{"jsonrpc": "2.0", "id": 1, "error": {"code": -32600, "message": "Invalid Request"}}'
msg = JSONRPCMessage.model_validate_json(error_json)
assert isinstance(msg.root, JSONRPCError)
class TestCapabilities:
"""Test capability classes."""
def test_client_capabilities(self):
"""Test ClientCapabilities creation."""
caps = ClientCapabilities(
experimental={"feature": {"enabled": True}},
sampling={"model_config": {"extra": "allow"}},
roots={"listChanged": True},
)
assert caps.experimental == {"feature": {"enabled": True}}
assert caps.sampling is not None
assert caps.roots.listChanged is True # type: ignore
def test_server_capabilities(self):
"""Test ServerCapabilities creation."""
caps = ServerCapabilities(
tools={"listChanged": True},
resources={"subscribe": True, "listChanged": False},
prompts={"listChanged": True},
logging={},
completions={},
)
assert caps.tools.listChanged is True # type: ignore
assert caps.resources.subscribe is True # type: ignore
assert caps.resources.listChanged is False # type: ignore
class TestInitialization:
"""Test initialization request/response types."""
def test_initialize_request(self):
"""Test InitializeRequest creation."""
client_info = Implementation(name="test-client", version="1.0.0")
capabilities = ClientCapabilities()
params = InitializeRequestParams(
protocolVersion=LATEST_PROTOCOL_VERSION, capabilities=capabilities, clientInfo=client_info
)
request = InitializeRequest(params=params)
assert request.method == "initialize"
assert request.params.protocolVersion == LATEST_PROTOCOL_VERSION
assert request.params.clientInfo.name == "test-client"
def test_initialize_result(self):
"""Test InitializeResult creation."""
server_info = Implementation(name="test-server", version="1.0.0")
capabilities = ServerCapabilities()
result = InitializeResult(
protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=capabilities,
serverInfo=server_info,
instructions="Welcome to test server",
)
assert result.protocolVersion == LATEST_PROTOCOL_VERSION
assert result.serverInfo.name == "test-server"
assert result.instructions == "Welcome to test server"
class TestTools:
"""Test tool-related types."""
def test_tool_creation(self):
"""Test Tool creation with all fields."""
tool = Tool(
name="test_tool",
title="Test Tool",
description="A tool for testing",
inputSchema={"type": "object", "properties": {"input": {"type": "string"}}, "required": ["input"]},
outputSchema={"type": "object", "properties": {"result": {"type": "string"}}},
annotations=ToolAnnotations(
title="Test Tool", readOnlyHint=False, destructiveHint=False, idempotentHint=True
),
)
assert tool.name == "test_tool"
assert tool.title == "Test Tool"
assert tool.description == "A tool for testing"
assert tool.inputSchema["properties"]["input"]["type"] == "string"
assert tool.annotations.idempotentHint is True
def test_call_tool_request(self):
"""Test CallToolRequest creation."""
params = CallToolRequestParams(name="test_tool", arguments={"input": "test value"})
request = CallToolRequest(params=params)
assert request.method == "tools/call"
assert request.params.name == "test_tool"
assert request.params.arguments == {"input": "test value"}
def test_call_tool_result(self):
"""Test CallToolResult creation."""
result = CallToolResult(
content=[TextContent(type="text", text="Tool executed successfully")],
structuredContent={"status": "success", "data": "test"},
isError=False,
)
assert len(result.content) == 1
assert result.content[0].text == "Tool executed successfully" # type: ignore
assert result.structuredContent == {"status": "success", "data": "test"}
assert result.isError is False
def test_list_tools_request(self):
"""Test ListToolsRequest creation."""
request = ListToolsRequest()
assert request.method == "tools/list"
def test_list_tools_result(self):
"""Test ListToolsResult creation."""
tool1 = Tool(name="tool1", inputSchema={})
tool2 = Tool(name="tool2", inputSchema={})
result = ListToolsResult(tools=[tool1, tool2])
assert len(result.tools) == 2
assert result.tools[0].name == "tool1"
assert result.tools[1].name == "tool2"
class TestContent:
"""Test content types."""
def test_text_content(self):
"""Test TextContent creation."""
annotations = Annotations(audience=["user"], priority=0.8)
content = TextContent(type="text", text="Hello, world!", annotations=annotations)
assert content.type == "text"
assert content.text == "Hello, world!"
assert content.annotations is not None
assert content.annotations.priority == 0.8
def test_image_content(self):
"""Test ImageContent creation."""
content = ImageContent(type="image", data="base64encodeddata", mimeType="image/png")
assert content.type == "image"
assert content.data == "base64encodeddata"
assert content.mimeType == "image/png"
class TestOAuth:
"""Test OAuth-related types."""
def test_oauth_client_metadata(self):
"""Test OAuthClientMetadata creation."""
metadata = OAuthClientMetadata(
client_name="Test Client",
redirect_uris=["https://example.com/callback"],
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
token_endpoint_auth_method="none",
client_uri="https://example.com",
scope="read write",
)
assert metadata.client_name == "Test Client"
assert len(metadata.redirect_uris) == 1
assert "authorization_code" in metadata.grant_types
def test_oauth_client_information(self):
"""Test OAuthClientInformation creation."""
info = OAuthClientInformation(client_id="test-client-id", client_secret="test-secret")
assert info.client_id == "test-client-id"
assert info.client_secret == "test-secret"
def test_oauth_client_information_without_secret(self):
"""Test OAuthClientInformation without secret."""
info = OAuthClientInformation(client_id="public-client")
assert info.client_id == "public-client"
assert info.client_secret is None
def test_oauth_tokens(self):
"""Test OAuthTokens creation."""
tokens = OAuthTokens(
access_token="access-token-123",
token_type="Bearer",
expires_in=3600,
refresh_token="refresh-token-456",
scope="read write",
)
assert tokens.access_token == "access-token-123"
assert tokens.token_type == "Bearer"
assert tokens.expires_in == 3600
assert tokens.refresh_token == "refresh-token-456"
assert tokens.scope == "read write"
def test_oauth_metadata(self):
"""Test OAuthMetadata creation."""
metadata = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
registration_endpoint="https://auth.example.com/register",
response_types_supported=["code", "token"],
grant_types_supported=["authorization_code", "refresh_token"],
code_challenge_methods_supported=["plain", "S256"],
)
assert metadata.authorization_endpoint == "https://auth.example.com/authorize"
assert "code" in metadata.response_types_supported
assert "S256" in metadata.code_challenge_methods_supported
class TestNotifications:
"""Test notification types."""
def test_progress_notification(self):
"""Test ProgressNotification creation."""
params = ProgressNotificationParams(
progressToken="progress-123", progress=50.0, total=100.0, message="Processing... 50%"
)
notification = ProgressNotification(params=params)
assert notification.method == "notifications/progress"
assert notification.params.progressToken == "progress-123"
assert notification.params.progress == 50.0
assert notification.params.total == 100.0
assert notification.params.message == "Processing... 50%"
def test_ping_request(self):
"""Test PingRequest creation."""
request = PingRequest()
assert request.method == "ping"
assert request.params is None
class TestCompletion:
"""Test completion-related types."""
def test_completion_context(self):
"""Test CompletionContext creation."""
context = CompletionContext(arguments={"template_var": "value"})
assert context.arguments == {"template_var": "value"}
def test_resource_template_reference(self):
"""Test ResourceTemplateReference creation."""
ref = ResourceTemplateReference(type="ref/resource", uri="file:///path/to/{filename}")
assert ref.type == "ref/resource"
assert ref.uri == "file:///path/to/{filename}"
def test_prompt_reference(self):
"""Test PromptReference creation."""
ref = PromptReference(type="ref/prompt", name="test_prompt")
assert ref.type == "ref/prompt"
assert ref.name == "test_prompt"
def test_complete_request(self):
"""Test CompleteRequest creation."""
ref = PromptReference(type="ref/prompt", name="test_prompt")
arg = CompletionArgument(name="arg1", value="val")
params = CompleteRequestParams(ref=ref, argument=arg, context=CompletionContext(arguments={"key": "value"}))
request = CompleteRequest(params=params)
assert request.method == "completion/complete"
assert request.params.ref.name == "test_prompt" # type: ignore
assert request.params.argument.name == "arg1"
def test_complete_result(self):
"""Test CompleteResult creation."""
completion = Completion(values=["option1", "option2", "option3"], total=10, hasMore=True)
result = CompleteResult(completion=completion)
assert len(result.completion.values) == 3
assert result.completion.total == 10
assert result.completion.hasMore is True
class TestValidation:
"""Test validation of various types."""
def test_invalid_jsonrpc_version(self):
"""Test invalid JSON-RPC version validation."""
with pytest.raises(ValidationError):
JSONRPCRequest(
jsonrpc="1.0", # Invalid version
id=1,
method="test",
)
def test_tool_annotations_validation(self):
"""Test ToolAnnotations with invalid values."""
# Valid annotations
annotations = ToolAnnotations(
title="Test", readOnlyHint=True, destructiveHint=False, idempotentHint=True, openWorldHint=False
)
assert annotations.title == "Test"
def test_extra_fields_allowed(self):
"""Test that extra fields are allowed in models."""
# Most models should allow extra fields
tool = Tool(
name="test",
inputSchema={},
customField="allowed", # type: ignore
)
assert tool.customField == "allowed" # type: ignore
def test_result_meta_alias(self):
"""Test Result model with _meta alias."""
# Create with the field name (not alias)
result = Result(_meta={"key": "value"})
# Verify the field is set correctly
assert result.meta == {"key": "value"}
# Dump with alias
dumped = result.model_dump(by_alias=True)
assert "_meta" in dumped
assert dumped["_meta"] == {"key": "value"}

View File

@ -0,0 +1,355 @@
"""Unit tests for MCP utils module."""
import json
from collections.abc import Generator
from unittest.mock import MagicMock, Mock, patch
import httpx
import httpx_sse
import pytest
from core.mcp.utils import (
STATUS_FORCELIST,
create_mcp_error_response,
create_ssrf_proxy_mcp_http_client,
ssrf_proxy_sse_connect,
)
class TestConstants:
"""Test module constants."""
def test_status_forcelist(self):
"""Test STATUS_FORCELIST contains expected HTTP status codes."""
assert STATUS_FORCELIST == [429, 500, 502, 503, 504]
assert 429 in STATUS_FORCELIST # Too Many Requests
assert 500 in STATUS_FORCELIST # Internal Server Error
assert 502 in STATUS_FORCELIST # Bad Gateway
assert 503 in STATUS_FORCELIST # Service Unavailable
assert 504 in STATUS_FORCELIST # Gateway Timeout
class TestCreateSSRFProxyMCPHTTPClient:
"""Test create_ssrf_proxy_mcp_http_client function."""
@patch("core.mcp.utils.dify_config")
def test_create_client_with_all_url_proxy(self, mock_config):
"""Test client creation with SSRF_PROXY_ALL_URL configured."""
mock_config.SSRF_PROXY_ALL_URL = "http://proxy.example.com:8080"
mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = True
client = create_ssrf_proxy_mcp_http_client(
headers={"Authorization": "Bearer token"}, timeout=httpx.Timeout(30.0)
)
assert isinstance(client, httpx.Client)
assert client.headers["Authorization"] == "Bearer token"
assert client.timeout.connect == 30.0
assert client.follow_redirects is True
# Clean up
client.close()
@patch("core.mcp.utils.dify_config")
def test_create_client_with_http_https_proxies(self, mock_config):
"""Test client creation with separate HTTP/HTTPS proxies."""
mock_config.SSRF_PROXY_ALL_URL = None
mock_config.SSRF_PROXY_HTTP_URL = "http://http-proxy.example.com:8080"
mock_config.SSRF_PROXY_HTTPS_URL = "http://https-proxy.example.com:8443"
mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = False
client = create_ssrf_proxy_mcp_http_client()
assert isinstance(client, httpx.Client)
assert client.follow_redirects is True
# Clean up
client.close()
@patch("core.mcp.utils.dify_config")
def test_create_client_without_proxy(self, mock_config):
"""Test client creation without proxy configuration."""
mock_config.SSRF_PROXY_ALL_URL = None
mock_config.SSRF_PROXY_HTTP_URL = None
mock_config.SSRF_PROXY_HTTPS_URL = None
mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = True
headers = {"X-Custom-Header": "value"}
timeout = httpx.Timeout(timeout=30.0, connect=5.0, read=10.0, write=30.0)
client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout)
assert isinstance(client, httpx.Client)
assert client.headers["X-Custom-Header"] == "value"
assert client.timeout.connect == 5.0
assert client.timeout.read == 10.0
assert client.follow_redirects is True
# Clean up
client.close()
@patch("core.mcp.utils.dify_config")
def test_create_client_default_params(self, mock_config):
"""Test client creation with default parameters."""
mock_config.SSRF_PROXY_ALL_URL = None
mock_config.SSRF_PROXY_HTTP_URL = None
mock_config.SSRF_PROXY_HTTPS_URL = None
mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = True
client = create_ssrf_proxy_mcp_http_client()
assert isinstance(client, httpx.Client)
# httpx.Client adds default headers, so we just check it's a Headers object
assert isinstance(client.headers, httpx.Headers)
# When no timeout is provided, httpx uses its default timeout
assert client.timeout is not None
# Clean up
client.close()
class TestSSRFProxySSEConnect:
"""Test ssrf_proxy_sse_connect function."""
@patch("core.mcp.utils.connect_sse")
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
def test_sse_connect_with_provided_client(self, mock_create_client, mock_connect_sse):
"""Test SSE connection with pre-configured client."""
# Setup mocks
mock_client = Mock(spec=httpx.Client)
mock_event_source = Mock(spec=httpx_sse.EventSource)
mock_context = MagicMock()
mock_context.__enter__.return_value = mock_event_source
mock_connect_sse.return_value = mock_context
# Call with provided client
result = ssrf_proxy_sse_connect(
"http://example.com/sse", client=mock_client, method="POST", headers={"Authorization": "Bearer token"}
)
# Verify client creation was not called
mock_create_client.assert_not_called()
# Verify connect_sse was called correctly
mock_connect_sse.assert_called_once_with(
mock_client, "POST", "http://example.com/sse", headers={"Authorization": "Bearer token"}
)
# Verify result
assert result == mock_context
@patch("core.mcp.utils.connect_sse")
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
@patch("core.mcp.utils.dify_config")
def test_sse_connect_without_client(self, mock_config, mock_create_client, mock_connect_sse):
"""Test SSE connection without pre-configured client."""
# Setup config
mock_config.SSRF_DEFAULT_TIME_OUT = 30.0
mock_config.SSRF_DEFAULT_CONNECT_TIME_OUT = 10.0
mock_config.SSRF_DEFAULT_READ_TIME_OUT = 60.0
mock_config.SSRF_DEFAULT_WRITE_TIME_OUT = 30.0
# Setup mocks
mock_client = Mock(spec=httpx.Client)
mock_create_client.return_value = mock_client
mock_event_source = Mock(spec=httpx_sse.EventSource)
mock_context = MagicMock()
mock_context.__enter__.return_value = mock_event_source
mock_connect_sse.return_value = mock_context
# Call without client
result = ssrf_proxy_sse_connect("http://example.com/sse", headers={"X-Custom": "value"})
# Verify client was created
mock_create_client.assert_called_once()
call_args = mock_create_client.call_args
assert call_args[1]["headers"] == {"X-Custom": "value"}
timeout = call_args[1]["timeout"]
# httpx.Timeout object has these attributes
assert isinstance(timeout, httpx.Timeout)
assert timeout.connect == 10.0
assert timeout.read == 60.0
assert timeout.write == 30.0
# Verify connect_sse was called
mock_connect_sse.assert_called_once_with(
mock_client,
"GET", # Default method
"http://example.com/sse",
)
# Verify result
assert result == mock_context
@patch("core.mcp.utils.connect_sse")
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
def test_sse_connect_with_custom_timeout(self, mock_create_client, mock_connect_sse):
"""Test SSE connection with custom timeout."""
# Setup mocks
mock_client = Mock(spec=httpx.Client)
mock_create_client.return_value = mock_client
mock_event_source = Mock(spec=httpx_sse.EventSource)
mock_context = MagicMock()
mock_context.__enter__.return_value = mock_event_source
mock_connect_sse.return_value = mock_context
custom_timeout = httpx.Timeout(timeout=60.0, read=120.0)
# Call with custom timeout
result = ssrf_proxy_sse_connect("http://example.com/sse", timeout=custom_timeout)
# Verify client was created with custom timeout
mock_create_client.assert_called_once()
call_args = mock_create_client.call_args
assert call_args[1]["timeout"] == custom_timeout
# Verify result
assert result == mock_context
@patch("core.mcp.utils.connect_sse")
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
def test_sse_connect_error_cleanup(self, mock_create_client, mock_connect_sse):
"""Test SSE connection cleans up client on error."""
# Setup mocks
mock_client = Mock(spec=httpx.Client)
mock_create_client.return_value = mock_client
# Make connect_sse raise an exception
mock_connect_sse.side_effect = httpx.ConnectError("Connection failed")
# Call should raise the exception
with pytest.raises(httpx.ConnectError):
ssrf_proxy_sse_connect("http://example.com/sse")
# Verify client was cleaned up
mock_client.close.assert_called_once()
@patch("core.mcp.utils.connect_sse")
def test_sse_connect_error_no_cleanup_with_provided_client(self, mock_connect_sse):
"""Test SSE connection doesn't clean up provided client on error."""
# Setup mocks
mock_client = Mock(spec=httpx.Client)
# Make connect_sse raise an exception
mock_connect_sse.side_effect = httpx.ConnectError("Connection failed")
# Call should raise the exception
with pytest.raises(httpx.ConnectError):
ssrf_proxy_sse_connect("http://example.com/sse", client=mock_client)
# Verify client was NOT cleaned up (because it was provided)
mock_client.close.assert_not_called()
class TestCreateMCPErrorResponse:
"""Test create_mcp_error_response function."""
def test_create_error_response_basic(self):
"""Test creating basic error response."""
generator = create_mcp_error_response(request_id="req-123", code=-32600, message="Invalid Request")
# Generator should yield bytes
assert isinstance(generator, Generator)
# Get the response
response_bytes = next(generator)
assert isinstance(response_bytes, bytes)
# Parse the response
response_str = response_bytes.decode("utf-8")
response_json = json.loads(response_str)
assert response_json["jsonrpc"] == "2.0"
assert response_json["id"] == "req-123"
assert response_json["error"]["code"] == -32600
assert response_json["error"]["message"] == "Invalid Request"
assert response_json["error"]["data"] is None
# Generator should be exhausted
with pytest.raises(StopIteration):
next(generator)
def test_create_error_response_with_data(self):
"""Test creating error response with additional data."""
error_data = {"field": "username", "reason": "required"}
generator = create_mcp_error_response(
request_id=456, # Numeric ID
code=-32602,
message="Invalid params",
data=error_data,
)
response_bytes = next(generator)
response_json = json.loads(response_bytes.decode("utf-8"))
assert response_json["id"] == 456
assert response_json["error"]["code"] == -32602
assert response_json["error"]["message"] == "Invalid params"
assert response_json["error"]["data"] == error_data
def test_create_error_response_without_request_id(self):
"""Test creating error response without request ID."""
generator = create_mcp_error_response(request_id=None, code=-32700, message="Parse error")
response_bytes = next(generator)
response_json = json.loads(response_bytes.decode("utf-8"))
# Should default to ID 1
assert response_json["id"] == 1
assert response_json["error"]["code"] == -32700
assert response_json["error"]["message"] == "Parse error"
def test_create_error_response_with_complex_data(self):
"""Test creating error response with complex error data."""
complex_data = {
"errors": [{"field": "name", "message": "Too short"}, {"field": "email", "message": "Invalid format"}],
"timestamp": "2024-01-01T00:00:00Z",
}
generator = create_mcp_error_response(
request_id="complex-req", code=-32602, message="Validation failed", data=complex_data
)
response_bytes = next(generator)
response_json = json.loads(response_bytes.decode("utf-8"))
assert response_json["error"]["data"] == complex_data
assert len(response_json["error"]["data"]["errors"]) == 2
def test_create_error_response_encoding(self):
"""Test error response with non-ASCII characters."""
generator = create_mcp_error_response(
request_id="unicode-req",
code=-32603,
message="内部错误", # Chinese characters
data={"details": "エラー詳細"}, # Japanese characters
)
response_bytes = next(generator)
# Should be valid UTF-8
response_str = response_bytes.decode("utf-8")
response_json = json.loads(response_str)
assert response_json["error"]["message"] == "内部错误"
assert response_json["error"]["data"]["details"] == "エラー詳細"
def test_create_error_response_yields_once(self):
"""Test that error response generator yields exactly once."""
generator = create_mcp_error_response(request_id="test", code=-32600, message="Test")
# First yield should work
first_yield = next(generator)
assert isinstance(first_yield, bytes)
# Second yield should raise StopIteration
with pytest.raises(StopIteration):
next(generator)
# Subsequent calls should also raise
with pytest.raises(StopIteration):
next(generator)

View File

@ -180,6 +180,25 @@ class TestMCPToolTransform:
# Set tools data with null description
mock_provider_full.tools = '[{"name": "tool1", "description": null, "inputSchema": {}}]'
# Mock the to_entity and to_api_response methods
mock_entity = Mock()
mock_entity.to_api_response.return_value = {
"name": "Test MCP Provider",
"type": ToolProviderType.MCP,
"is_team_authorization": True,
"server_url": "https://*****.com/mcp",
"provider_icon": "icon.png",
"masked_headers": {"Authorization": "Bearer *****"},
"updated_at": 1234567890,
"labels": [],
"author": "Test User",
"description": I18nObject(en_US="Test MCP Provider Description", zh_Hans="Test MCP Provider Description"),
"icon": "icon.png",
"label": I18nObject(en_US="Test MCP Provider", zh_Hans="Test MCP Provider"),
"masked_credentials": {},
}
mock_provider_full.to_entity.return_value = mock_entity
# Call the method with for_list=True
result = ToolTransformService.mcp_provider_to_user_provider(mock_provider_full, for_list=True)
@ -198,6 +217,27 @@ class TestMCPToolTransform:
# Set tools data with description
mock_provider_full.tools = '[{"name": "tool1", "description": "Tool description", "inputSchema": {}}]'
# Mock the to_entity and to_api_response methods
mock_entity = Mock()
mock_entity.to_api_response.return_value = {
"name": "Test MCP Provider",
"type": ToolProviderType.MCP,
"is_team_authorization": True,
"server_url": "https://*****.com/mcp",
"provider_icon": "icon.png",
"masked_headers": {"Authorization": "Bearer *****"},
"updated_at": 1234567890,
"labels": [],
"configuration": {"timeout": "30", "sse_read_timeout": "300"},
"original_headers": {"Authorization": "Bearer secret-token"},
"author": "Test User",
"description": I18nObject(en_US="Test MCP Provider Description", zh_Hans="Test MCP Provider Description"),
"icon": "icon.png",
"label": I18nObject(en_US="Test MCP Provider", zh_Hans="Test MCP Provider"),
"masked_credentials": {},
}
mock_provider_full.to_entity.return_value = mock_entity
# Call the method with for_list=False
result = ToolTransformService.mcp_provider_to_user_provider(mock_provider_full, for_list=False)
@ -205,8 +245,9 @@ class TestMCPToolTransform:
assert isinstance(result, ToolProviderApiEntity)
assert result.id == "server-identifier-456" # Should use server_identifier when for_list=False
assert result.server_identifier == "server-identifier-456"
assert result.timeout == 30
assert result.sse_read_timeout == 300
assert result.configuration is not None
assert result.configuration.timeout == 30
assert result.configuration.sse_read_timeout == 300
assert result.original_headers == {"Authorization": "Bearer secret-token"}
assert len(result.tools) == 1
assert result.tools[0].description.en_US == "Tool description"

File diff suppressed because it is too large Load Diff

View File

@ -329,7 +329,7 @@ services:
# The Weaviate vector store.
weaviate:
image: semitechnologies/weaviate:1.19.0
image: semitechnologies/weaviate:1.27.0
profiles:
- ""
- weaviate

View File

@ -181,7 +181,7 @@ services:
# The Weaviate vector store.
weaviate:
image: semitechnologies/weaviate:1.19.0
image: semitechnologies/weaviate:1.27.0
profiles:
- ""
- weaviate
@ -206,6 +206,7 @@ services:
AUTHORIZATION_ADMINLIST_USERS: ${WEAVIATE_AUTHORIZATION_ADMINLIST_USERS:-hello@dify.ai}
ports:
- "${EXPOSE_WEAVIATE_PORT:-8080}:8080"
- "${EXPOSE_WEAVIATE_GRPC_PORT:-50051}:50051"
networks:
# create a network between sandbox, api and ssrf_proxy, and can not access outside.

View File

@ -0,0 +1,9 @@
services:
api:
volumes:
- ../api/core/rag/datasource/vdb/weaviate/weaviate_vector.py:/app/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py:ro
command: >
sh -c "
pip install --no-cache-dir 'weaviate>=4.0.0' &&
/bin/bash /entrypoint.sh
"

View File

@ -936,7 +936,7 @@ services:
# The Weaviate vector store.
weaviate:
image: semitechnologies/weaviate:1.19.0
image: semitechnologies/weaviate:1.27.0
profiles:
- ""
- weaviate

View File

@ -5,7 +5,7 @@ import { useTranslation } from 'react-i18next'
import { useBoolean } from 'ahooks'
import TracingIcon from './tracing-icon'
import ProviderPanel from './provider-panel'
import type { AliyunConfig, ArizeConfig, LangFuseConfig, LangSmithConfig, OpikConfig, PhoenixConfig, WeaveConfig } from './type'
import type { AliyunConfig, ArizeConfig, LangFuseConfig, LangSmithConfig, OpikConfig, PhoenixConfig, TencentConfig, WeaveConfig } from './type'
import { TracingProvider } from './type'
import ProviderConfigModal from './provider-config-modal'
import Indicator from '@/app/components/header/indicator'
@ -30,7 +30,8 @@ export type PopupProps = {
opikConfig: OpikConfig | null
weaveConfig: WeaveConfig | null
aliyunConfig: AliyunConfig | null
onConfigUpdated: (provider: TracingProvider, payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig) => void
tencentConfig: TencentConfig | null
onConfigUpdated: (provider: TracingProvider, payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | TencentConfig) => void
onConfigRemoved: (provider: TracingProvider) => void
}
@ -48,6 +49,7 @@ const ConfigPopup: FC<PopupProps> = ({
opikConfig,
weaveConfig,
aliyunConfig,
tencentConfig,
onConfigUpdated,
onConfigRemoved,
}) => {
@ -81,8 +83,8 @@ const ConfigPopup: FC<PopupProps> = ({
hideConfigModal()
}, [currentProvider, hideConfigModal, onConfigRemoved])
const providerAllConfigured = arizeConfig && phoenixConfig && langSmithConfig && langFuseConfig && opikConfig && weaveConfig && aliyunConfig
const providerAllNotConfigured = !arizeConfig && !phoenixConfig && !langSmithConfig && !langFuseConfig && !opikConfig && !weaveConfig && !aliyunConfig
const providerAllConfigured = arizeConfig && phoenixConfig && langSmithConfig && langFuseConfig && opikConfig && weaveConfig && aliyunConfig && tencentConfig
const providerAllNotConfigured = !arizeConfig && !phoenixConfig && !langSmithConfig && !langFuseConfig && !opikConfig && !weaveConfig && !aliyunConfig && !tencentConfig
const switchContent = (
<Switch
@ -182,6 +184,19 @@ const ConfigPopup: FC<PopupProps> = ({
key="aliyun-provider-panel"
/>
)
const tencentPanel = (
<ProviderPanel
type={TracingProvider.tencent}
readOnly={readOnly}
config={tencentConfig}
hasConfigured={!!tencentConfig}
onConfig={handleOnConfig(TracingProvider.tencent)}
isChosen={chosenProvider === TracingProvider.tencent}
onChoose={handleOnChoose(TracingProvider.tencent)}
key="tencent-provider-panel"
/>
)
const configuredProviderPanel = () => {
const configuredPanels: JSX.Element[] = []
@ -206,6 +221,9 @@ const ConfigPopup: FC<PopupProps> = ({
if (aliyunConfig)
configuredPanels.push(aliyunPanel)
if (tencentConfig)
configuredPanels.push(tencentPanel)
return configuredPanels
}
@ -233,6 +251,9 @@ const ConfigPopup: FC<PopupProps> = ({
if (!aliyunConfig)
notConfiguredPanels.push(aliyunPanel)
if (!tencentConfig)
notConfiguredPanels.push(tencentPanel)
return notConfiguredPanels
}
@ -249,6 +270,8 @@ const ConfigPopup: FC<PopupProps> = ({
return opikConfig
if (currentProvider === TracingProvider.aliyun)
return aliyunConfig
if (currentProvider === TracingProvider.tencent)
return tencentConfig
return weaveConfig
}
@ -297,6 +320,7 @@ const ConfigPopup: FC<PopupProps> = ({
{arizePanel}
{phoenixPanel}
{aliyunPanel}
{tencentPanel}
</div>
</>
)

View File

@ -8,4 +8,5 @@ export const docURL = {
[TracingProvider.opik]: 'https://www.comet.com/docs/opik/tracing/integrations/dify#setup-instructions',
[TracingProvider.weave]: 'https://weave-docs.wandb.ai/',
[TracingProvider.aliyun]: 'https://help.aliyun.com/zh/arms/tracing-analysis/untitled-document-1750672984680',
[TracingProvider.tencent]: 'https://cloud.tencent.com/document/product/248/116531',
}

View File

@ -8,12 +8,12 @@ import {
import { useTranslation } from 'react-i18next'
import { usePathname } from 'next/navigation'
import { useBoolean } from 'ahooks'
import type { AliyunConfig, ArizeConfig, LangFuseConfig, LangSmithConfig, OpikConfig, PhoenixConfig, WeaveConfig } from './type'
import type { AliyunConfig, ArizeConfig, LangFuseConfig, LangSmithConfig, OpikConfig, PhoenixConfig, TencentConfig, WeaveConfig } from './type'
import { TracingProvider } from './type'
import TracingIcon from './tracing-icon'
import ConfigButton from './config-button'
import cn from '@/utils/classnames'
import { AliyunIcon, ArizeIcon, LangfuseIcon, LangsmithIcon, OpikIcon, PhoenixIcon, WeaveIcon } from '@/app/components/base/icons/src/public/tracing'
import { AliyunIcon, ArizeIcon, LangfuseIcon, LangsmithIcon, OpikIcon, PhoenixIcon, TencentIcon, WeaveIcon } from '@/app/components/base/icons/src/public/tracing'
import Indicator from '@/app/components/header/indicator'
import { fetchTracingConfig as doFetchTracingConfig, fetchTracingStatus, updateTracingStatus } from '@/service/apps'
import type { TracingStatus } from '@/models/app'
@ -71,6 +71,7 @@ const Panel: FC = () => {
[TracingProvider.opik]: OpikIcon,
[TracingProvider.weave]: WeaveIcon,
[TracingProvider.aliyun]: AliyunIcon,
[TracingProvider.tencent]: TencentIcon,
}
const InUseProviderIcon = inUseTracingProvider ? providerIconMap[inUseTracingProvider] : undefined
@ -81,7 +82,8 @@ const Panel: FC = () => {
const [opikConfig, setOpikConfig] = useState<OpikConfig | null>(null)
const [weaveConfig, setWeaveConfig] = useState<WeaveConfig | null>(null)
const [aliyunConfig, setAliyunConfig] = useState<AliyunConfig | null>(null)
const hasConfiguredTracing = !!(langSmithConfig || langFuseConfig || opikConfig || weaveConfig || arizeConfig || phoenixConfig || aliyunConfig)
const [tencentConfig, setTencentConfig] = useState<TencentConfig | null>(null)
const hasConfiguredTracing = !!(langSmithConfig || langFuseConfig || opikConfig || weaveConfig || arizeConfig || phoenixConfig || aliyunConfig || tencentConfig)
const fetchTracingConfig = async () => {
const getArizeConfig = async () => {
@ -119,6 +121,11 @@ const Panel: FC = () => {
if (!aliyunHasNotConfig)
setAliyunConfig(aliyunConfig as AliyunConfig)
}
const getTencentConfig = async () => {
const { tracing_config: tencentConfig, has_not_configured: tencentHasNotConfig } = await doFetchTracingConfig({ appId, provider: TracingProvider.tencent })
if (!tencentHasNotConfig)
setTencentConfig(tencentConfig as TencentConfig)
}
Promise.all([
getArizeConfig(),
getPhoenixConfig(),
@ -127,6 +134,7 @@ const Panel: FC = () => {
getOpikConfig(),
getWeaveConfig(),
getAliyunConfig(),
getTencentConfig(),
])
}
@ -147,6 +155,8 @@ const Panel: FC = () => {
setWeaveConfig(tracing_config as WeaveConfig)
else if (provider === TracingProvider.aliyun)
setAliyunConfig(tracing_config as AliyunConfig)
else if (provider === TracingProvider.tencent)
setTencentConfig(tracing_config as TencentConfig)
}
const handleTracingConfigRemoved = (provider: TracingProvider) => {
@ -164,6 +174,8 @@ const Panel: FC = () => {
setWeaveConfig(null)
else if (provider === TracingProvider.aliyun)
setAliyunConfig(null)
else if (provider === TracingProvider.tencent)
setTencentConfig(null)
if (provider === inUseTracingProvider) {
handleTracingStatusChange({
enabled: false,
@ -209,6 +221,7 @@ const Panel: FC = () => {
opikConfig={opikConfig}
weaveConfig={weaveConfig}
aliyunConfig={aliyunConfig}
tencentConfig={tencentConfig}
onConfigUpdated={handleTracingConfigUpdated}
onConfigRemoved={handleTracingConfigRemoved}
>
@ -245,6 +258,7 @@ const Panel: FC = () => {
opikConfig={opikConfig}
weaveConfig={weaveConfig}
aliyunConfig={aliyunConfig}
tencentConfig={tencentConfig}
onConfigUpdated={handleTracingConfigUpdated}
onConfigRemoved={handleTracingConfigRemoved}
>

View File

@ -4,7 +4,7 @@ import React, { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useBoolean } from 'ahooks'
import Field from './field'
import type { AliyunConfig, ArizeConfig, LangFuseConfig, LangSmithConfig, OpikConfig, PhoenixConfig, WeaveConfig } from './type'
import type { AliyunConfig, ArizeConfig, LangFuseConfig, LangSmithConfig, OpikConfig, PhoenixConfig, TencentConfig, WeaveConfig } from './type'
import { TracingProvider } from './type'
import { docURL } from './config'
import {
@ -22,10 +22,10 @@ import Divider from '@/app/components/base/divider'
type Props = {
appId: string
type: TracingProvider
payload?: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | null
payload?: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | TencentConfig | null
onRemoved: () => void
onCancel: () => void
onSaved: (payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig) => void
onSaved: (payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | TencentConfig) => void
onChosen: (provider: TracingProvider) => void
}
@ -77,6 +77,12 @@ const aliyunConfigTemplate = {
endpoint: '',
}
const tencentConfigTemplate = {
token: '',
endpoint: '',
service_name: '',
}
const ProviderConfigModal: FC<Props> = ({
appId,
type,
@ -90,7 +96,7 @@ const ProviderConfigModal: FC<Props> = ({
const isEdit = !!payload
const isAdd = !isEdit
const [isSaving, setIsSaving] = useState(false)
const [config, setConfig] = useState<ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig>((() => {
const [config, setConfig] = useState<ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | TencentConfig>((() => {
if (isEdit)
return payload
@ -112,6 +118,9 @@ const ProviderConfigModal: FC<Props> = ({
else if (type === TracingProvider.aliyun)
return aliyunConfigTemplate
else if (type === TracingProvider.tencent)
return tencentConfigTemplate
return weaveConfigTemplate
})())
const [isShowRemoveConfirm, {
@ -202,6 +211,16 @@ const ProviderConfigModal: FC<Props> = ({
errorMessage = t('common.errorMsg.fieldRequired', { field: 'Endpoint' })
}
if (type === TracingProvider.tencent) {
const postData = config as TencentConfig
if (!errorMessage && !postData.token)
errorMessage = t('common.errorMsg.fieldRequired', { field: 'Token' })
if (!errorMessage && !postData.endpoint)
errorMessage = t('common.errorMsg.fieldRequired', { field: 'Endpoint' })
if (!errorMessage && !postData.service_name)
errorMessage = t('common.errorMsg.fieldRequired', { field: 'Service Name' })
}
return errorMessage
}, [config, t, type])
const handleSave = useCallback(async () => {
@ -338,6 +357,34 @@ const ProviderConfigModal: FC<Props> = ({
/>
</>
)}
{type === TracingProvider.tencent && (
<>
<Field
label='Token'
labelClassName='!text-sm'
isRequired
value={(config as TencentConfig).token}
onChange={handleConfigChange('token')}
placeholder={t(`${I18N_PREFIX}.placeholder`, { key: 'Token' })!}
/>
<Field
label='Endpoint'
labelClassName='!text-sm'
isRequired
value={(config as TencentConfig).endpoint}
onChange={handleConfigChange('endpoint')}
placeholder='https://your-region.cls.tencentcs.com'
/>
<Field
label='Service Name'
labelClassName='!text-sm'
isRequired
value={(config as TencentConfig).service_name}
onChange={handleConfigChange('service_name')}
placeholder='dify_app'
/>
</>
)}
{type === TracingProvider.weave && (
<>
<Field

View File

@ -7,7 +7,7 @@ import {
import { useTranslation } from 'react-i18next'
import { TracingProvider } from './type'
import cn from '@/utils/classnames'
import { AliyunIconBig, ArizeIconBig, LangfuseIconBig, LangsmithIconBig, OpikIconBig, PhoenixIconBig, WeaveIconBig } from '@/app/components/base/icons/src/public/tracing'
import { AliyunIconBig, ArizeIconBig, LangfuseIconBig, LangsmithIconBig, OpikIconBig, PhoenixIconBig, TencentIconBig, WeaveIconBig } from '@/app/components/base/icons/src/public/tracing'
import { Eye as View } from '@/app/components/base/icons/src/vender/solid/general'
const I18N_PREFIX = 'app.tracing'
@ -31,6 +31,7 @@ const getIcon = (type: TracingProvider) => {
[TracingProvider.opik]: OpikIconBig,
[TracingProvider.weave]: WeaveIconBig,
[TracingProvider.aliyun]: AliyunIconBig,
[TracingProvider.tencent]: TencentIconBig,
})[type]
}

View File

@ -6,6 +6,7 @@ export enum TracingProvider {
opik = 'opik',
weave = 'weave',
aliyun = 'aliyun',
tencent = 'tencent',
}
export type ArizeConfig = {
@ -53,3 +54,9 @@ export type AliyunConfig = {
license_key: string
endpoint: string
}
export type TencentConfig = {
token: string
endpoint: string
service_name: string
}

Some files were not shown because too many files have changed in this diff Show More