mirror of https://github.com/langgenius/dify.git
commit b48a7c7cda
Merge branch 'feat/mcp-06-18' into deploy/dev
@@ -437,6 +437,9 @@ CODE_EXECUTION_SSL_VERIFY=True
+CODE_EXECUTION_POOL_MAX_CONNECTIONS=100
+CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS=20
+CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY=5.0
 CODE_EXECUTION_CONNECT_TIMEOUT=10
 CODE_EXECUTION_READ_TIMEOUT=60
 CODE_EXECUTION_WRITE_TIMEOUT=10
 CODE_MAX_NUMBER=9223372036854775807
 CODE_MIN_NUMBER=-9223372036854775808
 CODE_MAX_STRING_LENGTH=400000
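Note: the three new POOL_* settings suggest an httpx-style connection pool for the code-execution client. A minimal sketch of how such values typically map onto httpx; the exact wiring inside Dify's client is an assumption here:

    import os

    import httpx

    # Hypothetical wiring; Dify's actual client construction may differ.
    limits = httpx.Limits(
        max_connections=int(os.getenv("CODE_EXECUTION_POOL_MAX_CONNECTIONS", "100")),
        max_keepalive_connections=int(os.getenv("CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS", "20")),
        keepalive_expiry=float(os.getenv("CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY", "5.0")),
    )
    timeout = httpx.Timeout(
        connect=float(os.getenv("CODE_EXECUTION_CONNECT_TIMEOUT", "10")),
        read=float(os.getenv("CODE_EXECUTION_READ_TIMEOUT", "60")),
        write=float(os.getenv("CODE_EXECUTION_WRITE_TIMEOUT", "10")),
        pool=None,
    )
    client = httpx.Client(limits=limits, timeout=timeout)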
@@ -548,7 +548,7 @@ class UpdateConfig(BaseSettings):

 class WorkflowVariableTruncationConfig(BaseSettings):
     WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE: PositiveInt = Field(
-        # 100KB
+        # 1000 KiB
         1024_000,
         description="Maximum size for variable to trigger final truncation.",
     )
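Note on the comment fix: underscores in Python integer literals are purely visual, so 1024_000 is 1,024,000 bytes, which is 1000 KiB rather than the 100 KB the old comment claimed:

    size = 1024_000      # digit separators do not change the value: 1024000
    print(size / 1024)   # 1000.0 -> 1000 KiB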
@@ -145,7 +145,7 @@ class DatabaseConfig(BaseSettings):
         default="postgresql",
     )

-    @computed_field  # type: ignore[misc]
+    @computed_field  # type: ignore[prop-decorator]
     @property
     def SQLALCHEMY_DATABASE_URI(self) -> str:
         db_extras = (
@@ -198,7 +198,7 @@ class DatabaseConfig(BaseSettings):
         default=os.cpu_count() or 1,
     )

-    @computed_field  # type: ignore[misc]
+    @computed_field  # type: ignore[prop-decorator]
     @property
     def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
         # Parse DB_EXTRAS for 'options'
@@ -56,11 +56,15 @@ else:
 }
 DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)

+# console
 COOKIE_NAME_ACCESS_TOKEN = "access_token"
 COOKIE_NAME_REFRESH_TOKEN = "refresh_token"
-COOKIE_NAME_PASSPORT = "passport"
 COOKIE_NAME_CSRF_TOKEN = "csrf_token"

+# webapp
+COOKIE_NAME_WEBAPP_ACCESS_TOKEN = "webapp_access_token"
+COOKIE_NAME_PASSPORT = "passport"
+
 HEADER_NAME_CSRF_TOKEN = "X-CSRF-Token"
 HEADER_NAME_APP_CODE = "X-App-Code"
 HEADER_NAME_PASSPORT = "X-App-Passport"
@@ -31,3 +31,9 @@ def supported_language(lang):

     error = f"{lang} is not a valid language."
     raise ValueError(error)
+
+
+def get_valid_language(lang: str | None) -> str:
+    if lang and lang in languages:
+        return lang
+    return languages[0]
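Behaviour of the new helper, assuming an illustrative subset of Dify's `languages` list:

    languages = ["en-US", "zh-Hans", "ja-JP"]  # illustrative subset

    def get_valid_language(lang: str | None) -> str:
        if lang and lang in languages:
            return lang
        return languages[0]

    print(get_valid_language("ja-JP"))  # ja-JP: known codes pass through
    print(get_valid_language("xx-XX"))  # en-US: unknown codes fall back
    print(get_valid_language(None))     # en-US: missing values fall back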
@@ -24,7 +24,7 @@ except ImportError:
        )
    else:
        warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2)
-        magic = None  # type: ignore
+        magic = None  # type: ignore[assignment]

 from pydantic import BaseModel

@@ -4,7 +4,7 @@ from flask_restx import Resource, reqparse

 import services
 from configs import dify_config
-from constants.languages import languages
+from constants.languages import get_valid_language
 from controllers.console import console_ns
 from controllers.console.auth.error import (
     AuthenticationFailedError,
@@ -207,10 +207,12 @@ class EmailCodeLoginApi(Resource):
             .add_argument("email", type=str, required=True, location="json")
             .add_argument("code", type=str, required=True, location="json")
             .add_argument("token", type=str, required=True, location="json")
+            .add_argument("language", type=str, required=False, location="json")
         )
         args = parser.parse_args()

         user_email = args["email"]
+        language = args["language"]

         token_data = AccountService.get_email_code_login_data(args["token"])
         if token_data is None:
@@ -244,7 +246,9 @@ class EmailCodeLoginApi(Resource):
         if account is None:
             try:
                 account = AccountService.create_account_and_tenant(
-                    email=user_email, name=user_email, interface_language=languages[0]
+                    email=user_email,
+                    name=user_email,
+                    interface_language=get_valid_language(language),
                 )
             except WorkSpaceNotAllowedCreateError:
                 raise NotAllowedCreateWorkspace()
@@ -22,7 +22,7 @@ from core.errors.error import (
 from core.model_runtime.errors.invoke import InvokeError
 from core.workflow.graph_engine.manager import GraphEngineManager
 from libs import helper
-from libs.login import current_user as current_user_
+from libs.login import current_account_with_tenant
 from models.model import AppMode, InstalledApp
 from services.app_generate_service import AppGenerateService
 from services.errors.llm import InvokeRateLimitError
@@ -31,8 +31,6 @@ from .. import console_ns

 logger = logging.getLogger(__name__)

-current_user = current_user_._get_current_object()  # type: ignore
-

 @console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
 class InstalledAppWorkflowRunApi(InstalledAppResource):
@@ -40,6 +38,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
         """
         Run workflow
         """
+        current_user, _ = current_account_with_tenant()
         app_model = installed_app.app
         if not app_model:
             raise NotWorkflowAppError()
@@ -53,7 +52,6 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
             .add_argument("files", type=list, required=False, location="json")
         )
         args = parser.parse_args()
-        assert current_user is not None
         try:
             response = AppGenerateService.generate(
                 app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
@@ -89,7 +87,6 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
         app_mode = AppMode.value_of(app_model.mode)
         if app_mode != AppMode.WORKFLOW:
             raise NotWorkflowAppError()
-        assert current_user is not None

         # Stop using both mechanisms for backward compatibility
         # Legacy stop flag mechanism (without user check)
@@ -74,12 +74,17 @@ class SetupApi(Resource):
             .add_argument("email", type=email, required=True, location="json")
             .add_argument("name", type=StrLen(30), required=True, location="json")
             .add_argument("password", type=valid_password, required=True, location="json")
+            .add_argument("language", type=str, required=False, location="json")
         )
         args = parser.parse_args()

         # setup
         RegisterService.setup(
-            email=args["email"], name=args["name"], password=args["password"], ip_address=extract_remote_ip(request)
+            email=args["email"],
+            name=args["name"],
+            password=args["password"],
+            ip_address=extract_remote_ip(request),
+            language=args["language"],
         )

         return {"result": "success"}, 201
@@ -30,7 +30,7 @@ from models.provider_ids import ToolProviderID
 from services.plugin.oauth_service import OAuthProxyService
 from services.tools.api_tools_manage_service import ApiToolManageService
 from services.tools.builtin_tools_manage_service import BuiltinToolManageService
-from services.tools.mcp_tools_manage_service import MCPToolManageService
+from services.tools.mcp_tools_manage_service import MCPToolManageService, OAuthDataType
 from services.tools.tool_labels_service import ToolLabelsService
 from services.tools.tools_manage_service import ToolCommonService
 from services.tools.tools_transform_service import ToolTransformService
@@ -897,10 +897,6 @@ class ToolProviderMCPApi(Resource):
         args = parser.parse_args()
         user, tenant_id = current_account_with_tenant()

-        # Validate server URL
-        if not is_valid_url(args["server_url"]):
-            raise ValueError("Server URL is not valid.")
-
         # Parse and validate models
         configuration = MCPConfiguration.model_validate(args["configuration"])
         authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
@@ -941,15 +937,21 @@ class ToolProviderMCPApi(Resource):
             .add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
         )
         args = parser.parse_args()
-        if not is_valid_url(args["server_url"]):
-            if "[__HIDDEN__]" in args["server_url"]:
-                pass
-            else:
-                raise ValueError("Server URL is not valid.")
         configuration = MCPConfiguration.model_validate(args["configuration"])
         authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
+        _, current_tenant_id = current_account_with_tenant()
+
+        # Step 1: Validate server URL change if needed (includes URL format validation and network operation)
+        validation_result = None
+        with Session(db.engine) as session:
+            service = MCPToolManageService(session=session)
+            validation_result = service.validate_server_url_change(
+                tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"]
+            )
+        # No need to check for errors here, exceptions will be raised directly
+
+        # Step 2: Perform database update in a transaction
+        with Session(db.engine) as session, session.begin():
+            service = MCPToolManageService(session=session)
+            service.update_provider(
@@ -964,6 +966,7 @@ class ToolProviderMCPApi(Resource):
                 headers=args["headers"],
                 configuration=configuration,
                 authentication=authentication,
+                validation_result=validation_result,
             )
         return {"result": "success"}

@@ -998,47 +1001,49 @@ class ToolMCPAuthApi(Resource):
         provider_id = args["provider_id"]
         _, tenant_id = current_account_with_tenant()

-        with Session(db.engine) as session:
-            with session.begin():
-                service = MCPToolManageService(session=session)
-                db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
-                if not db_provider:
-                    raise ValueError("provider not found")
+        with Session(db.engine) as session, session.begin():
+            service = MCPToolManageService(session=session)
+            db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
+            if not db_provider:
+                raise ValueError("provider not found")

-                # Convert to entity
-                provider_entity = db_provider.to_entity()
-                server_url = provider_entity.decrypt_server_url()
-                headers = provider_entity.decrypt_authentication()
+            # Convert to entity
+            provider_entity = db_provider.to_entity()
+            server_url = provider_entity.decrypt_server_url()
+            headers = provider_entity.decrypt_authentication()

-            # Try to connect without active transaction
-            try:
-                # Use MCPClientWithAuthRetry to handle authentication automatically
-                with MCPClient(
-                    server_url=server_url,
-                    headers=headers,
-                    timeout=provider_entity.timeout,
-                    sse_read_timeout=provider_entity.sse_read_timeout,
-                ):
-                    # Create new transaction for update
-                    with session.begin():
-                        service.update_provider_credentials(
-                            provider=db_provider,
-                            credentials=provider_entity.credentials,
-                            authed=True,
-                        )
-                    return {"result": "success"}
-            except MCPAuthError as e:
-                service = MCPToolManageService(session=session)
-                try:
-                    return auth(provider_entity, service, args.get("authorization_code"))
-                except MCPRefreshTokenError as e:
-                    with session.begin():
-                        service.clear_provider_credentials(provider=db_provider)
-                    raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
-            except MCPError as e:
-                with session.begin():
-                    service.clear_provider_credentials(provider=db_provider)
-                raise ValueError(f"Failed to connect to MCP server: {e}") from e
+        # Try to connect without active transaction
+        try:
+            # Use MCPClientWithAuthRetry to handle authentication automatically
+            with MCPClient(
+                server_url=server_url,
+                headers=headers,
+                timeout=provider_entity.timeout,
+                sse_read_timeout=provider_entity.sse_read_timeout,
+            ):
+                # Create new transaction for update
+                with session.begin():
+                    service.update_provider_credentials(
+                        provider=db_provider,
+                        credentials=provider_entity.credentials,
+                        authed=True,
+                    )
+                return {"result": "success"}
+        except MCPAuthError as e:
+            service = MCPToolManageService(session=session)
+            try:
+                auth_result = auth(provider_entity, args.get("authorization_code"))
+                with session.begin():
+                    response = service.execute_auth_actions(auth_result)
+                return response
+            except MCPRefreshTokenError as e:
+                with session.begin():
+                    service.clear_provider_credentials(provider=db_provider)
+                raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
+        except MCPError as e:
+            with session.begin():
+                service.clear_provider_credentials(provider=db_provider)
+            raise ValueError(f"Failed to connect to MCP server: {e}") from e


 @console_ns.route("/workspaces/current/tool-provider/mcp/tools/<path:provider_id>")
@@ -1048,7 +1053,7 @@ class ToolMCPDetailApi(Resource):
     @account_initialization_required
     def get(self, provider_id):
         _, tenant_id = current_account_with_tenant()
-        with Session(db.engine) as session:
+        with Session(db.engine) as session, session.begin():
             service = MCPToolManageService(session=session)
             provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
             return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
@@ -1062,7 +1067,7 @@ class ToolMCPListAllApi(Resource):
     def get(self):
         _, tenant_id = current_account_with_tenant()

-        with Session(db.engine, expire_on_commit=False) as session:
+        with Session(db.engine) as session, session.begin():
             service = MCPToolManageService(session=session)
             tools = service.list_providers(tenant_id=tenant_id)

@@ -1100,6 +1105,11 @@ class ToolMCPCallbackApi(Resource):
         # Create service instance for handle_callback
         with Session(db.engine) as session, session.begin():
             mcp_service = MCPToolManageService(session=session)
-            handle_callback(state_key, authorization_code, mcp_service)
+            # handle_callback now returns state data and tokens
+            state_data, tokens = handle_callback(state_key, authorization_code)
+            # Save tokens using the service layer
+            mcp_service.save_oauth_data(
+                state_data.provider_id, state_data.tenant_id, tokens.model_dump(), OAuthDataType.TOKENS
+            )

         return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
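Note: several handlers above converge on the same SQLAlchemy idiom. `with Session(db.engine) as session, session.begin():` opens the session and a transaction in one statement; the transaction commits on clean exit and rolls back on exception, which is what replaces the nested session/begin pairs. A self-contained sketch:

    from sqlalchemy import create_engine, text
    from sqlalchemy.orm import Session

    engine = create_engine("sqlite://")  # stand-in for db.engine

    with Session(engine) as session, session.begin():
        session.execute(text("CREATE TABLE t (id INTEGER)"))
        session.execute(text("INSERT INTO t VALUES (1)"))
    # committed here; an exception inside the block would have rolled back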
@@ -193,15 +193,16 @@ class MCPAppApi(Resource):
         except ValidationError as e:
             raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")

-    def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str, session: Session) -> EndUser | None:
-        """Get end user from existing session - optimized query"""
-        return (
-            session.query(EndUser)
-            .where(EndUser.tenant_id == tenant_id)
-            .where(EndUser.session_id == mcp_server_id)
-            .where(EndUser.type == "mcp")
-            .first()
-        )
+    def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None:
+        """Get end user - manages its own database session"""
+        with Session(db.engine, expire_on_commit=False) as session, session.begin():
+            return (
+                session.query(EndUser)
+                .where(EndUser.tenant_id == tenant_id)
+                .where(EndUser.session_id == mcp_server_id)
+                .where(EndUser.type == "mcp")
+                .first()
+            )

     def _create_end_user(
         self, client_name: str, tenant_id: str, app_id: str, mcp_server_id: str, session: Session
@@ -229,7 +230,7 @@ class MCPAppApi(Resource):
         request_id: Union[int, str],
     ) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError | None:
         """Handle MCP request and return response"""
-        end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id, session)
+        end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id)

         if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest):
             client_info = mcp_request.root.params.clientInfo
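Note: `expire_on_commit=False` is what makes returning the EndUser after the session closes safe; otherwise the commit would expire its attributes and any later access would try to refresh from a closed session. A generic sketch (the model below is illustrative, not Dify's):

    from sqlalchemy import create_engine, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class User(Base):
        __tablename__ = "users"
        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str]

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session, session.begin():
        session.add(User(id=1, name="alice"))

    with Session(engine, expire_on_commit=False) as session, session.begin():
        user = session.scalar(select(User).where(User.id == 1))

    # Safe: attributes were not expired at commit, so no lazy refresh is needed.
    print(user.name)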
@@ -17,8 +17,8 @@ from libs.helper import email
 from libs.passport import PassportService
 from libs.password import valid_password
 from libs.token import (
-    clear_access_token_from_cookie,
-    extract_access_token,
+    clear_webapp_access_token_from_cookie,
+    extract_webapp_access_token,
 )
 from services.account_service import AccountService
 from services.app_service import AppService
@@ -81,7 +81,7 @@ class LoginStatusApi(Resource):
     )
     def get(self):
         app_code = request.args.get("app_code")
-        token = extract_access_token(request)
+        token = extract_webapp_access_token(request)
         if not app_code:
             return {
                 "logged_in": bool(token),
@@ -128,7 +128,7 @@ class LogoutApi(Resource):
         response = make_response({"result": "success"})
         # enterprise SSO sets same site to None in https deployment
         # so we need to logout by calling api
-        clear_access_token_from_cookie(response, samesite="None")
+        clear_webapp_access_token_from_cookie(response, samesite="None")
         return response
@@ -12,10 +12,8 @@ from controllers.web import web_ns
 from controllers.web.error import WebAppAuthRequiredError
 from extensions.ext_database import db
 from libs.passport import PassportService
-from libs.token import extract_access_token
+from libs.token import extract_webapp_access_token
 from models.model import App, EndUser, Site
 from services.app_service import AppService
-from services.enterprise.enterprise_service import EnterpriseService
 from services.feature_service import FeatureService
 from services.webapp_auth_service import WebAppAuthService, WebAppAuthType

@@ -37,23 +35,18 @@ class PassportResource(Resource):
         system_features = FeatureService.get_system_features()
         app_code = request.headers.get(HEADER_NAME_APP_CODE)
         user_id = request.args.get("user_id")
-        access_token = extract_access_token(request)
-
+        access_token = extract_webapp_access_token(request)
         if app_code is None:
             raise Unauthorized("X-App-Code header is missing.")
         app_id = AppService.get_app_id_by_code(app_code)
-        # exchange token for enterprise logined web user
-        enterprise_user_decoded = decode_enterprise_webapp_user_id(access_token)
-        if enterprise_user_decoded:
-            # a web user has already logged in, exchange a token for this app without redirecting to the login page
-            return exchange_token_for_existing_web_user(
-                app_code=app_code, enterprise_user_decoded=enterprise_user_decoded
-            )
-
         if system_features.webapp_auth.enabled:
-            app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id)
-            if not app_settings or not app_settings.access_mode == "public":
-                raise WebAppAuthRequiredError()
+            enterprise_user_decoded = decode_enterprise_webapp_user_id(access_token)
+            app_auth_type = WebAppAuthService.get_app_auth_type(app_code=app_code)
+            if app_auth_type != WebAppAuthType.PUBLIC:
+                if not enterprise_user_decoded:
+                    raise WebAppAuthRequiredError()
+                return exchange_token_for_existing_web_user(
+                    app_code=app_code, enterprise_user_decoded=enterprise_user_decoded, auth_type=app_auth_type
+                )

         # get site from db and check if it is normal
         site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal"))
@@ -124,7 +117,7 @@ def decode_enterprise_webapp_user_id(jwt_token: str | None):
     return decoded


-def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict):
+def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict, auth_type: WebAppAuthType):
     """
     Exchange a token for an existing web user session.
     """
@@ -145,13 +138,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
     if not app_model or app_model.status != "normal" or not app_model.enable_site:
         raise NotFound()

-    app_auth_type = WebAppAuthService.get_app_auth_type(app_code=app_code)
-
-    if app_auth_type == WebAppAuthType.PUBLIC:
+    if auth_type == WebAppAuthType.PUBLIC:
         return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded)
-    elif app_auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external":
+    elif auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external":
         raise WebAppAuthRequiredError("Please login as external user.")
-    elif app_auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal":
+    elif auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal":
         raise WebAppAuthRequiredError("Please login as internal user.")

     end_user = None
@@ -211,8 +211,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             user=user,
             stream=streaming,
         )
-        # FIXME: Type hinting issue here, ignore it for now, will fix it later
-        return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)  # type: ignore
+        return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

     def _generate_worker(
         self,
@@ -255,7 +255,7 @@ class PipelineGenerator(BaseAppGenerator):
             json_text = json.dumps(text)
             upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id)
             features = FeatureService.get_features(dataset.tenant_id)
-            if features.billing.subscription.plan == "sandbox":
+            if features.billing.enabled and features.billing.subscription.plan == "sandbox":
                 tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
                 tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"
@@ -89,7 +89,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
-                response_chunk.update(sub_stream_response.to_ignore_detail_dict())  # ty: ignore [unresolved-attribute]
+                response_chunk.update(sub_stream_response.to_ignore_detail_dict())
             else:
                 response_chunk.update(sub_stream_response.model_dump(mode="json"))
             yield response_chunk
@@ -98,7 +98,7 @@ class RateLimit:
         else:
             return RateLimitGenerator(
                 rate_limit=self,
-                generator=generator,  # ty: ignore [invalid-argument-type]
+                generator=generator,
                 request_id=request_id,
             )
@@ -49,7 +49,7 @@ class BasedGenerateTaskPipeline:
         if isinstance(e, InvokeAuthorizationError):
             err = InvokeAuthorizationError("Incorrect API key provided")
         elif isinstance(e, InvokeError | ValueError):
-            err = e  # ty: ignore [invalid-assignment]
+            err = e
         else:
             description = getattr(e, "description", None)
             err = Exception(description if description is not None else str(e))
@@ -1868,7 +1868,7 @@ class ProviderConfigurations(BaseModel):
         if "/" not in key:
             key = str(ModelProviderID(key))

-        return self.configurations.get(key, default)  # type: ignore
+        return self.configurations.get(key, default)


 class ProviderModelBundle(BaseModel):
@@ -20,7 +20,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz
     else:
         # Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
         # FIXME: mypy does not support the type of spec.loader
-        spec = importlib.util.spec_from_file_location(module_name, py_file_path)  # type: ignore
+        spec = importlib.util.spec_from_file_location(module_name, py_file_path)  # type: ignore[assignment]
         if not spec or not spec.loader:
             raise Exception(f"Failed to load module {module_name} from {py_file_path!r}")
         if use_lazy_loader:
@@ -49,62 +49,80 @@ class IndexingRunner:
         self.storage = storage
         self.model_manager = ModelManager()

+    def _handle_indexing_error(self, document_id: str, error: Exception) -> None:
+        """Handle indexing errors by updating document status."""
+        logger.exception("consume document failed")
+        document = db.session.get(DatasetDocument, document_id)
+        if document:
+            document.indexing_status = "error"
+            error_message = getattr(error, "description", str(error))
+            document.error = str(error_message)
+            document.stopped_at = naive_utc_now()
+            db.session.commit()
+
     def run(self, dataset_documents: list[DatasetDocument]):
         """Run the indexing process."""
         for dataset_document in dataset_documents:
+            document_id = dataset_document.id
             try:
+                # Re-query the document to ensure it's bound to the current session
+                requeried_document = db.session.get(DatasetDocument, document_id)
+                if not requeried_document:
+                    logger.warning("Document not found, skipping document id: %s", document_id)
+                    continue
+
                 # get dataset
-                dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
+                dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()

                 if not dataset:
                     raise ValueError("no dataset found")
                 # get the process rule
                 stmt = select(DatasetProcessRule).where(
-                    DatasetProcessRule.id == dataset_document.dataset_process_rule_id
+                    DatasetProcessRule.id == requeried_document.dataset_process_rule_id
                 )
                 processing_rule = db.session.scalar(stmt)
                 if not processing_rule:
                     raise ValueError("no process rule found")
-                index_type = dataset_document.doc_form
+                index_type = requeried_document.doc_form
                 index_processor = IndexProcessorFactory(index_type).init_index_processor()
                 # extract
-                text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
+                text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())

                 # transform
                 documents = self._transform(
-                    index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict()
+                    index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict()
                 )
                 # save segment
-                self._load_segments(dataset, dataset_document, documents)
+                self._load_segments(dataset, requeried_document, documents)

                 # load
                 self._load(
                     index_processor=index_processor,
                     dataset=dataset,
-                    dataset_document=dataset_document,
+                    dataset_document=requeried_document,
                     documents=documents,
                 )
             except DocumentIsPausedError:
-                raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}")
+                raise DocumentIsPausedError(f"Document paused, document id: {document_id}")
             except ProviderTokenNotInitError as e:
-                dataset_document.indexing_status = "error"
-                dataset_document.error = str(e.description)
-                dataset_document.stopped_at = naive_utc_now()
-                db.session.commit()
+                self._handle_indexing_error(document_id, e)
             except ObjectDeletedError:
-                logger.warning("Document deleted, document id: %s", dataset_document.id)
+                logger.warning("Document deleted, document id: %s", document_id)
             except Exception as e:
-                logger.exception("consume document failed")
-                dataset_document.indexing_status = "error"
-                dataset_document.error = str(e)
-                dataset_document.stopped_at = naive_utc_now()
-                db.session.commit()
+                self._handle_indexing_error(document_id, e)

     def run_in_splitting_status(self, dataset_document: DatasetDocument):
         """Run the indexing process when the index_status is splitting."""
+        document_id = dataset_document.id
         try:
+            # Re-query the document to ensure it's bound to the current session
+            requeried_document = db.session.get(DatasetDocument, document_id)
+            if not requeried_document:
+                logger.warning("Document not found: %s", document_id)
+                return
+
            # get dataset
-            dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
+            dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()

             if not dataset:
                 raise ValueError("no dataset found")
@@ -112,57 +130,60 @@ class IndexingRunner:
             # get exist document_segment list and delete
             document_segments = (
                 db.session.query(DocumentSegment)
-                .filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
+                .filter_by(dataset_id=dataset.id, document_id=requeried_document.id)
                 .all()
             )

             for document_segment in document_segments:
                 db.session.delete(document_segment)
-                if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX:
                     # delete child chunks
                     db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
             db.session.commit()
             # get the process rule
-            stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
+            stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == requeried_document.dataset_process_rule_id)
             processing_rule = db.session.scalar(stmt)
             if not processing_rule:
                 raise ValueError("no process rule found")

-            index_type = dataset_document.doc_form
+            index_type = requeried_document.doc_form
             index_processor = IndexProcessorFactory(index_type).init_index_processor()
             # extract
-            text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
+            text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())

             # transform
             documents = self._transform(
-                index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict()
+                index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict()
             )
             # save segment
-            self._load_segments(dataset, dataset_document, documents)
+            self._load_segments(dataset, requeried_document, documents)

             # load
             self._load(
-                index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
+                index_processor=index_processor,
+                dataset=dataset,
+                dataset_document=requeried_document,
+                documents=documents,
             )
         except DocumentIsPausedError:
-            raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}")
+            raise DocumentIsPausedError(f"Document paused, document id: {document_id}")
         except ProviderTokenNotInitError as e:
-            dataset_document.indexing_status = "error"
-            dataset_document.error = str(e.description)
-            dataset_document.stopped_at = naive_utc_now()
-            db.session.commit()
+            self._handle_indexing_error(document_id, e)
         except Exception as e:
-            logger.exception("consume document failed")
-            dataset_document.indexing_status = "error"
-            dataset_document.error = str(e)
-            dataset_document.stopped_at = naive_utc_now()
-            db.session.commit()
+            self._handle_indexing_error(document_id, e)

     def run_in_indexing_status(self, dataset_document: DatasetDocument):
         """Run the indexing process when the index_status is indexing."""
+        document_id = dataset_document.id
         try:
+            # Re-query the document to ensure it's bound to the current session
+            requeried_document = db.session.get(DatasetDocument, document_id)
+            if not requeried_document:
+                logger.warning("Document not found: %s", document_id)
+                return
+
             # get dataset
-            dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
+            dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()

             if not dataset:
                 raise ValueError("no dataset found")
@@ -170,7 +191,7 @@ class IndexingRunner:
             # get exist document_segment list and delete
             document_segments = (
                 db.session.query(DocumentSegment)
-                .filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
+                .filter_by(dataset_id=dataset.id, document_id=requeried_document.id)
                 .all()
             )

@@ -188,7 +209,7 @@ class IndexingRunner:
                         "dataset_id": document_segment.dataset_id,
                     },
                 )
-                if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX:
                     child_chunks = document_segment.get_child_chunks()
                     if child_chunks:
                         child_documents = []
@@ -206,24 +227,20 @@ class IndexingRunner:
                         document.children = child_documents
                     documents.append(document)
             # build index
-            index_type = dataset_document.doc_form
+            index_type = requeried_document.doc_form
             index_processor = IndexProcessorFactory(index_type).init_index_processor()
             self._load(
-                index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
+                index_processor=index_processor,
+                dataset=dataset,
+                dataset_document=requeried_document,
+                documents=documents,
             )
         except DocumentIsPausedError:
-            raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}")
+            raise DocumentIsPausedError(f"Document paused, document id: {document_id}")
         except ProviderTokenNotInitError as e:
-            dataset_document.indexing_status = "error"
-            dataset_document.error = str(e.description)
-            dataset_document.stopped_at = naive_utc_now()
-            db.session.commit()
+            self._handle_indexing_error(document_id, e)
         except Exception as e:
-            logger.exception("consume document failed")
-            dataset_document.indexing_status = "error"
-            dataset_document.error = str(e)
-            dataset_document.stopped_at = naive_utc_now()
-            db.session.commit()
+            self._handle_indexing_error(document_id, e)

     def indexing_estimate(
         self,
@@ -398,7 +415,6 @@ class IndexingRunner:
             document_id=dataset_document.id,
             after_indexing_status="splitting",
             extra_update_params={
-                DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs),
                 DatasetDocument.parsing_completed_at: naive_utc_now(),
             },
         )
@@ -738,6 +754,7 @@ class IndexingRunner:
             extra_update_params={
                 DatasetDocument.cleaning_completed_at: cur_time,
                 DatasetDocument.splitting_completed_at: cur_time,
+                DatasetDocument.word_count: sum(len(doc.page_content) for doc in documents),
             },
         )
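Note: the refactor above replaces four repeated error-update blocks with one helper that re-queries by primary key, so even a detached or deleted ORM instance cannot break failure handling. The control flow, distilled into a runnable sketch (`process` and `handle_indexing_error` stand in for Dify's methods):

    def process(document_id: str) -> None:
        raise RuntimeError(f"boom while indexing {document_id}")

    def handle_indexing_error(document_id: str, error: Exception) -> None:
        # stand-in for IndexingRunner._handle_indexing_error: re-fetch the row
        # by id, mark its status "error", store the message, commit
        print(f"document {document_id}: error: {error}")

    for document_id in ["doc-1", "doc-2"]:
        try:
            process(document_id)
        except Exception as exc:
            handle_indexing_error(document_id, exc)  # one shared failure path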
@@ -103,7 +103,7 @@ class LLMGenerator:
         return name

     @classmethod
-    def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
+    def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str) -> Sequence[str]:
         output_parser = SuggestedQuestionsAfterAnswerOutputParser()
         format_instructions = output_parser.get_format_instructions()

@@ -122,6 +122,8 @@ class LLMGenerator:

         prompt_messages = [UserPromptMessage(content=prompt)]

+        questions: Sequence[str] = []
+
         try:
             response: LLMResult = model_instance.invoke_llm(
                 prompt_messages=list(prompt_messages),
@@ -1,17 +1,26 @@
 import json
+import logging
 import re
+from collections.abc import Sequence

 from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT

+logger = logging.getLogger(__name__)
+

 class SuggestedQuestionsAfterAnswerOutputParser:
     def get_format_instructions(self) -> str:
         return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT

-    def parse(self, text: str):
+    def parse(self, text: str) -> Sequence[str]:
         action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL)
+        questions: list[str] = []
         if action_match is not None:
-            json_obj = json.loads(action_match.group(0).strip())
-        else:
-            json_obj = []
-        return json_obj
+            try:
+                json_obj = json.loads(action_match.group(0).strip())
+            except json.JSONDecodeError as exc:
+                logger.warning("Failed to decode suggested questions payload: %s", exc)
+            else:
+                if isinstance(json_obj, list):
+                    questions = [question for question in json_obj if isinstance(question, str)]
+        return questions
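Observable behaviour of the hardened parser (illustrative inputs):

    parser = SuggestedQuestionsAfterAnswerOutputParser()

    parser.parse('Sure: ["a?", "b?"]')  # -> ["a?", "b?"]
    parser.parse('["ok?", 42, null]')   # -> ["ok?"]  non-strings are dropped
    parser.parse("no list here")        # -> []       no bracket match
    parser.parse('["broken]')           # -> []       decode failure is logged, not raised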
@@ -4,14 +4,14 @@ import json
 import os
 import secrets
 import urllib.parse
-from typing import TYPE_CHECKING
 from urllib.parse import urljoin, urlparse

 from httpx import ConnectError, HTTPStatusError, RequestError
-from pydantic import BaseModel, ValidationError
+from pydantic import ValidationError

 from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType
 from core.helper import ssrf_proxy
+from core.mcp.entities import AuthAction, AuthActionType, AuthResult, OAuthCallbackState
 from core.mcp.error import MCPRefreshTokenError
 from core.mcp.types import (
     LATEST_PROTOCOL_VERSION,
@@ -23,23 +23,10 @@ from core.mcp.types import (
 )
 from extensions.ext_redis import redis_client

-if TYPE_CHECKING:
-    from services.tools.mcp_tools_manage_service import MCPToolManageService
-
 OAUTH_STATE_EXPIRY_SECONDS = 5 * 60  # 5 minutes expiry
 OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"


-class OAuthCallbackState(BaseModel):
-    provider_id: str
-    tenant_id: str
-    server_url: str
-    metadata: OAuthMetadata | None = None
-    client_information: OAuthClientInformation
-    code_verifier: str
-    redirect_uri: str
-
-
 def generate_pkce_challenge() -> tuple[str, str]:
     """Generate PKCE challenge and verifier."""
     code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
@@ -86,8 +73,13 @@ def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
         raise ValueError(f"Invalid state parameter: {str(e)}")


-def handle_callback(state_key: str, authorization_code: str, mcp_service: "MCPToolManageService") -> OAuthCallbackState:
-    """Handle the callback from the OAuth provider."""
+def handle_callback(state_key: str, authorization_code: str) -> tuple[OAuthCallbackState, OAuthTokens]:
+    """
+    Handle the callback from the OAuth provider.
+
+    Returns:
+        A tuple of (callback_state, tokens) that can be used by the caller to save data.
+    """
     # Retrieve state data from Redis (state is automatically deleted after retrieval)
     full_state_data = _retrieve_redis_state(state_key)

@@ -100,10 +92,7 @@ def handle_callback(state_key: str, authorization_code: str, mcp_service: "MCPTo
         full_state_data.redirect_uri,
     )

-    # Save tokens using the service layer
-    mcp_service.save_oauth_data(full_state_data.provider_id, full_state_data.tenant_id, tokens.model_dump(), "tokens")
-
-    return full_state_data
+    return full_state_data, tokens


 def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
@@ -361,11 +350,24 @@ def register_client(

 def auth(
     provider: MCPProviderEntity,
-    mcp_service: "MCPToolManageService",
     authorization_code: str | None = None,
     state_param: str | None = None,
-) -> dict[str, str]:
-    """Orchestrates the full auth flow with a server using secure Redis state storage."""
+) -> AuthResult:
+    """
+    Orchestrates the full auth flow with a server using secure Redis state storage.
+
+    This function performs only network operations and returns actions that need
+    to be performed by the caller (such as saving data to database).
+
+    Args:
+        provider: The MCP provider entity
+        authorization_code: Optional authorization code from OAuth callback
+        state_param: Optional state parameter from OAuth callback
+
+    Returns:
+        AuthResult containing actions to be performed and response data
+    """
+    actions: list[AuthAction] = []
     server_url = provider.decrypt_server_url()
     server_metadata = discover_oauth_metadata(server_url)
     client_metadata = provider.client_metadata
@@ -407,9 +409,14 @@ def auth(
         except RequestError as e:
             raise ValueError(f"Could not register OAuth client: {e}")

-        # Save client information using service layer
-        mcp_service.save_oauth_data(
-            provider_id, tenant_id, {"client_information": full_information.model_dump()}, "client_info"
+        # Return action to save client information
+        actions.append(
+            AuthAction(
+                action_type=AuthActionType.SAVE_CLIENT_INFO,
+                data={"client_information": full_information.model_dump()},
+                provider_id=provider_id,
+                tenant_id=tenant_id,
+            )
         )

         client_information = full_information
@@ -426,12 +433,20 @@ def auth(
                 scope,
             )

-            # Save tokens and grant type
+            # Return action to save tokens and grant type
             token_data = tokens.model_dump()
             token_data["grant_type"] = MCPSupportGrantType.CLIENT_CREDENTIALS.value
-            mcp_service.save_oauth_data(provider_id, tenant_id, token_data, "tokens")
-
-            return {"result": "success"}
+            actions.append(
+                AuthAction(
+                    action_type=AuthActionType.SAVE_TOKENS,
+                    data=token_data,
+                    provider_id=provider_id,
+                    tenant_id=tenant_id,
+                )
+            )
+
+            return AuthResult(actions=actions, response={"result": "success"})
         except (RequestError, ValueError, KeyError) as e:
             # RequestError: HTTP request failed
             # ValueError: Invalid response data
@@ -465,10 +480,17 @@ def auth(
                 redirect_uri,
             )

-            # Save tokens using service layer
-            mcp_service.save_oauth_data(provider_id, tenant_id, tokens.model_dump(), "tokens")
+            # Return action to save tokens
+            actions.append(
+                AuthAction(
+                    action_type=AuthActionType.SAVE_TOKENS,
+                    data=tokens.model_dump(),
+                    provider_id=provider_id,
+                    tenant_id=tenant_id,
+                )
+            )

-            return {"result": "success"}
+            return AuthResult(actions=actions, response={"result": "success"})

     provider_tokens = provider.retrieve_tokens()

@@ -479,10 +501,17 @@ def auth(
                 server_url, server_metadata, client_information, provider_tokens.refresh_token
             )

-            # Save new tokens using service layer
-            mcp_service.save_oauth_data(provider_id, tenant_id, new_tokens.model_dump(), "tokens")
+            # Return action to save new tokens
+            actions.append(
+                AuthAction(
+                    action_type=AuthActionType.SAVE_TOKENS,
+                    data=new_tokens.model_dump(),
+                    provider_id=provider_id,
+                    tenant_id=tenant_id,
+                )
+            )

-            return {"result": "success"}
+            return AuthResult(actions=actions, response={"result": "success"})
         except (RequestError, ValueError, KeyError) as e:
             # RequestError: HTTP request failed
             # ValueError: Invalid response data
@@ -499,7 +528,14 @@ def auth(
         tenant_id,
     )

-    # Save code verifier using service layer
-    mcp_service.save_oauth_data(provider_id, tenant_id, {"code_verifier": code_verifier}, "code_verifier")
+    # Return action to save code verifier
+    actions.append(
+        AuthAction(
+            action_type=AuthActionType.SAVE_CODE_VERIFIER,
+            data={"code_verifier": code_verifier},
+            provider_id=provider_id,
+            tenant_id=tenant_id,
+        )
+    )

-    return {"authorization_url": authorization_url}
+    return AuthResult(actions=actions, response={"authorization_url": authorization_url})
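Note: with persistence stripped out, the callback flow is now a pure function of Redis state plus the provider's token endpoint, and the caller decides what to store. Shape of the new call, mirroring the ToolMCPCallbackApi controller above:

    # inside the callback controller, within its own transaction:
    state_data, tokens = handle_callback(state_key, authorization_code)
    mcp_service.save_oauth_data(
        state_data.provider_id, state_data.tenant_id, tokens.model_dump(), OAuthDataType.TOKENS
    )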
@@ -7,15 +7,15 @@ authentication failures and retries operations after refreshing tokens.

 import logging
-from collections.abc import Callable
-from typing import TYPE_CHECKING, Any, Optional
+from typing import Any
+
+from sqlalchemy.orm import Session

 from core.entities.mcp_provider import MCPProviderEntity
 from core.mcp.error import MCPAuthError
 from core.mcp.mcp_client import MCPClient
 from core.mcp.types import CallToolResult, Tool

-if TYPE_CHECKING:
-    from services.tools.mcp_tools_manage_service import MCPToolManageService
+from extensions.ext_database import db

 logger = logging.getLogger(__name__)

@@ -26,6 +26,9 @@ class MCPClientWithAuthRetry(MCPClient):

     This class extends MCPClient and intercepts MCPAuthError exceptions
     to refresh authentication before retrying failed operations.
+
+    Note: This class uses lazy session creation - database sessions are only
+    created when authentication retry is actually needed, not on every request.
     """

     def __init__(
@@ -35,11 +38,8 @@ class MCPClientWithAuthRetry(MCPClient):
         timeout: float | None = None,
         sse_read_timeout: float | None = None,
         provider_entity: MCPProviderEntity | None = None,
-        auth_callback: Callable[[MCPProviderEntity, "MCPToolManageService", Optional[str]], dict[str, str]]
-        | None = None,
         authorization_code: str | None = None,
         by_server_id: bool = False,
-        mcp_service: Optional["MCPToolManageService"] = None,
     ):
         """
         Initialize the MCP client with auth retry capability.
@@ -50,31 +50,30 @@ class MCPClientWithAuthRetry(MCPClient):
             timeout: Request timeout
             sse_read_timeout: SSE read timeout
             provider_entity: Provider entity for authentication
-            auth_callback: Authentication callback function
             authorization_code: Optional authorization code for initial auth
             by_server_id: Whether to look up provider by server ID
-            mcp_service: MCP service instance
         """
         super().__init__(server_url, headers, timeout, sse_read_timeout)

         self.provider_entity = provider_entity
-        self.auth_callback = auth_callback
         self.authorization_code = authorization_code
         self.by_server_id = by_server_id
-        self.mcp_service = mcp_service
         self._has_retried = False

     def _handle_auth_error(self, error: MCPAuthError) -> None:
         """
         Handle authentication error by refreshing tokens.

+        This method creates a short-lived database session only when authentication
+        retry is needed, minimizing database connection hold time.
+
         Args:
             error: The authentication error

         Raises:
             MCPAuthError: If authentication fails or max retries reached
         """
-        if not self.provider_entity or not self.auth_callback or not self.mcp_service:
+        if not self.provider_entity:
             raise error
         if self._has_retried:
             raise error
@@ -82,13 +81,23 @@ class MCPClientWithAuthRetry(MCPClient):
         self._has_retried = True

         try:
-            # Perform authentication
-            self.auth_callback(self.provider_entity, self.mcp_service, self.authorization_code)
+            # Create a temporary session only for auth retry
+            # This session is short-lived and only exists during the auth operation
+            from services.tools.mcp_tools_manage_service import MCPToolManageService

-            # Retrieve new tokens
-            self.provider_entity = self.mcp_service.get_provider_entity(
-                self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id
-            )
+            with Session(db.engine) as session, session.begin():
+                mcp_service = MCPToolManageService(session=session)
+
+                # Perform authentication using the service's auth method
+                mcp_service.auth_with_actions(self.provider_entity, self.authorization_code)
+
+                # Retrieve new tokens
+                self.provider_entity = mcp_service.get_provider_entity(
+                    self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id
+                )
+
+            # Session is closed here, before we update headers
             token = self.provider_entity.retrieve_tokens()
             if not token:
                 raise MCPAuthError("Authentication failed - no token received")
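Note: the retry guard is a single boolean, so each client instance refreshes credentials at most once before re-raising. That control flow, distilled into a runnable sketch (names are illustrative):

    class RetryOnce:
        def __init__(self) -> None:
            self._has_retried = False

        def call(self, op):
            try:
                return op()
            except PermissionError:
                if self._has_retried:
                    raise                    # second failure propagates
                self._has_retried = True
                self.refresh_credentials()   # short-lived DB session opens here
                return op()                  # exactly one retry

        def refresh_credentials(self) -> None:
            pass  # stand-in for the MCPToolManageService auth round-trip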
@@ -1,8 +1,11 @@
 from dataclasses import dataclass
+from enum import StrEnum
 from typing import Any, Generic, TypeVar

+from pydantic import BaseModel
+
 from core.mcp.session.base_session import BaseSession
-from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestId, RequestParams
+from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAuthMetadata, RequestId, RequestParams

 SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]

@@ -17,3 +20,41 @@ class RequestContext(Generic[SessionT, LifespanContextT]):
     meta: RequestParams.Meta | None
     session: SessionT
     lifespan_context: LifespanContextT
+
+
+class AuthActionType(StrEnum):
+    """Types of actions that can be performed during auth flow."""
+
+    SAVE_CLIENT_INFO = "save_client_info"
+    SAVE_TOKENS = "save_tokens"
+    SAVE_CODE_VERIFIER = "save_code_verifier"
+    START_AUTHORIZATION = "start_authorization"
+    SUCCESS = "success"
+
+
+class AuthAction(BaseModel):
+    """Represents an action that needs to be performed as a result of auth flow."""
+
+    action_type: AuthActionType
+    data: dict[str, Any]
+    provider_id: str | None = None
+    tenant_id: str | None = None
+
+
+class AuthResult(BaseModel):
+    """Result of auth function containing actions to be performed and response data."""
+
+    actions: list[AuthAction]
+    response: dict[str, str]
+
+
+class OAuthCallbackState(BaseModel):
+    """State data stored in Redis during OAuth callback flow."""
+
+    provider_id: str
+    tenant_id: str
+    server_url: str
+    metadata: OAuthMetadata | None = None
+    client_information: OAuthClientInformation
+    code_verifier: str
+    redirect_uri: str
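Note: these entities make the OAuth flow's side effects explicit: the network layer returns an AuthResult, and the persistence layer replays its actions in a transaction (execute_auth_actions in the service). An illustrative round trip; the payload values are made up:

    result = AuthResult(
        actions=[
            AuthAction(
                action_type=AuthActionType.SAVE_TOKENS,
                data={"access_token": "..."},  # illustrative payload
                provider_id="provider-1",
                tenant_id="tenant-1",
            )
        ],
        response={"result": "success"},
    )

    for action in result.actions:
        if action.action_type == AuthActionType.SAVE_TOKENS:
            pass  # persistence layer would call save_oauth_data(..., OAuthDataType.TOKENS)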
@@ -2,7 +2,7 @@ import logging
 import os
 from datetime import datetime, timedelta

-from langfuse import Langfuse  # type: ignore
+from langfuse import Langfuse
 from sqlalchemy.orm import sessionmaker

 from core.ops.base_trace_instance import BaseTraceInstance
@@ -76,7 +76,7 @@ class PluginParameter(BaseModel):
     auto_generate: PluginParameterAutoGenerate | None = None
     template: PluginParameterTemplate | None = None
     required: bool = False
-    default: Union[float, int, str] | None = None
+    default: Union[float, int, str, bool] | None = None
     min: Union[float, int] | None = None
     max: Union[float, int] | None = None
     precision: int | None = None
@@ -180,7 +180,7 @@ class BasePluginClient:
         Make a request to the plugin daemon inner API and return the response as a model.
         """
         response = self._request(method, path, headers, data, params, files)
-        return type_(**response.json())  # type: ignore
+        return type_(**response.json())  # type: ignore[return-value]

     def _request_with_plugin_daemon_response(
         self,
@@ -40,7 +40,7 @@ class PluginDaemonBadRequestError(PluginDaemonClientSideError):
     description: str = "Bad Request"


-class PluginInvokeError(PluginDaemonClientSideError):
+class PluginInvokeError(PluginDaemonClientSideError, ValueError):
     description: str = "Invoke Error"

     def _get_error_object(self) -> Mapping:
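Note: adding ValueError as a second base class means existing `except ValueError` handlers now also catch plugin invoke failures. A minimal demonstration of that effect (class names here are illustrative):

    class ClientSideError(Exception):
        pass

    class InvokeError(ClientSideError, ValueError):  # mirrors the new bases
        pass

    try:
        raise InvokeError("plugin invocation failed")
    except ValueError as e:
        print("caught via ValueError:", e)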
@@ -72,6 +72,19 @@ default_retrieval_model: dict[str, Any] = {
 class DatasetRetrieval:
     def __init__(self, application_generate_entity=None):
         self.application_generate_entity = application_generate_entity
+        self._llm_usage = LLMUsage.empty_usage()
+
+    @property
+    def llm_usage(self) -> LLMUsage:
+        return self._llm_usage.model_copy()
+
+    def _record_usage(self, usage: LLMUsage | None) -> None:
+        if usage is None or usage.total_tokens <= 0:
+            return
+        if self._llm_usage.total_tokens == 0:
+            self._llm_usage = usage
+        else:
+            self._llm_usage = self._llm_usage.plus(usage)

     def retrieve(
         self,
@@ -312,15 +325,18 @@ class DatasetRetrieval:
             )
             tools.append(message_tool)
         dataset_id = None
+        router_usage = LLMUsage.empty_usage()
         if planning_strategy == PlanningStrategy.REACT_ROUTER:
             react_multi_dataset_router = ReactMultiDatasetRouter()
-            dataset_id = react_multi_dataset_router.invoke(
+            dataset_id, router_usage = react_multi_dataset_router.invoke(
                 query, tools, model_config, model_instance, user_id, tenant_id
             )

         elif planning_strategy == PlanningStrategy.ROUTER:
             function_call_router = FunctionCallMultiDatasetRouter()
-            dataset_id = function_call_router.invoke(query, tools, model_config, model_instance)
+            dataset_id, router_usage = function_call_router.invoke(query, tools, model_config, model_instance)
+
+        self._record_usage(router_usage)

         if dataset_id:
             # get retrieval model config
@@ -983,7 +999,8 @@ class DatasetRetrieval:
             )

             # handle invoke result
-            result_text, _ = self._handle_invoke_result(invoke_result=invoke_result)
+            result_text, usage = self._handle_invoke_result(invoke_result=invoke_result)
+            self._record_usage(usage)

             result_text_json = parse_and_check_json_markdown(result_text, [])
             automatic_metadata_filters = []
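Note: `_record_usage` folds successive results into a running total, skipping empty ones. A stub with the same arithmetic (`LLMUsage` is Dify's pydantic model; `plus()` adds the counters):

    from dataclasses import dataclass

    @dataclass
    class Usage:  # stand-in for LLMUsage
        total_tokens: int = 0

        def plus(self, other: "Usage") -> "Usage":
            return Usage(self.total_tokens + other.total_tokens)

    total = Usage()
    for step in (Usage(120), Usage(0), Usage(35)):  # zero-token steps are ignored
        if step.total_tokens > 0:
            total = step if total.total_tokens == 0 else total.plus(step)
    print(total.total_tokens)  # 155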
@@ -2,7 +2,7 @@ from typing import Union

 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.model_manager import ModelInstance
-from core.model_runtime.entities.llm_entities import LLMResult
+from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
 from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage


@@ -13,15 +13,15 @@ class FunctionCallMultiDatasetRouter:
         dataset_tools: list[PromptMessageTool],
         model_config: ModelConfigWithCredentialsEntity,
         model_instance: ModelInstance,
-    ) -> Union[str, None]:
+    ) -> tuple[Union[str, None], LLMUsage]:
         """Given input, decided what to do.
         Returns:
             Action specifying what tool to use.
         """
         if len(dataset_tools) == 0:
-            return None
+            return None, LLMUsage.empty_usage()
         elif len(dataset_tools) == 1:
-            return dataset_tools[0].name
+            return dataset_tools[0].name, LLMUsage.empty_usage()

         try:
             prompt_messages = [
@@ -34,9 +34,10 @@ class FunctionCallMultiDatasetRouter:
                 stream=False,
                 model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
             )
+            usage = result.usage or LLMUsage.empty_usage()
             if result.message.tool_calls:
                 # get retrieval model config
-                return result.message.tool_calls[0].function.name
-            return None
+                return result.message.tool_calls[0].function.name, usage
+            return None, usage
         except Exception:
-            return None
+            return None, LLMUsage.empty_usage()
@@ -58,15 +58,15 @@ class ReactMultiDatasetRouter:
         model_instance: ModelInstance,
         user_id: str,
         tenant_id: str,
-    ) -> Union[str, None]:
+    ) -> tuple[Union[str, None], LLMUsage]:
         """Given input, decided what to do.
         Returns:
             Action specifying what tool to use.
         """
         if len(dataset_tools) == 0:
-            return None
+            return None, LLMUsage.empty_usage()
         elif len(dataset_tools) == 1:
-            return dataset_tools[0].name
+            return dataset_tools[0].name, LLMUsage.empty_usage()

         try:
             return self._react_invoke(

@@ -78,7 +78,7 @@ class ReactMultiDatasetRouter:
                 tenant_id=tenant_id,
             )
         except Exception:
-            return None
+            return None, LLMUsage.empty_usage()

     def _react_invoke(
         self,

@@ -91,7 +91,7 @@ class ReactMultiDatasetRouter:
         prefix: str = PREFIX,
         suffix: str = SUFFIX,
         format_instructions: str = FORMAT_INSTRUCTIONS,
-    ) -> Union[str, None]:
+    ) -> tuple[Union[str, None], LLMUsage]:
         prompt: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
         if model_config.mode == "chat":
             prompt = self.create_chat_prompt(

@@ -120,7 +120,7 @@ class ReactMultiDatasetRouter:
             memory=None,
             model_config=model_config,
         )
-        result_text, _ = self._invoke_llm(
+        result_text, usage = self._invoke_llm(
            completion_param=model_config.parameters,
            model_instance=model_instance,
            prompt_messages=prompt_messages,

@@ -131,8 +131,8 @@ class ReactMultiDatasetRouter:
         output_parser = StructuredChatOutputParser()
         react_decision = output_parser.parse(result_text)
         if isinstance(react_decision, ReactAction):
-            return react_decision.tool
-        return None
+            return react_decision.tool, usage
+        return None, usage

     def _invoke_llm(
         self,
@@ -74,7 +74,7 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
         tenant_id = extract_tenant_id(user)
         if not tenant_id:
             raise ValueError("User must have a tenant_id or current_tenant_id")
-        self._tenant_id = tenant_id  # type: ignore[assignment]  # We've already checked tenant_id is not None
+        self._tenant_id = tenant_id

         # Store app context
         self._app_id = app_id

@@ -81,7 +81,7 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
         tenant_id = extract_tenant_id(user)
         if not tenant_id:
             raise ValueError("User must have a tenant_id or current_tenant_id")
-        self._tenant_id = tenant_id  # type: ignore[assignment]  # We've already checked tenant_id is not None
+        self._tenant_id = tenant_id

         # Store app context
         self._app_id = app_id
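The `# type: ignore[assignment]` removed in these two hunks was never needed once the `None` check precedes the assignment: after `if not tenant_id: raise ...`, a type checker narrows `str | None` to `str`. A minimal, self-contained sketch of that narrowing (the `store` helper is illustrative, not from the codebase):

def store(tenant_id: str | None) -> str:
    if not tenant_id:
        raise ValueError("User must have a tenant_id or current_tenant_id")
    narrowed: str = tenant_id  # accepted without a cast or ignore comment
    return narrowed

print(store("tenant-123"))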
@@ -60,7 +60,7 @@ class DifyCoreRepositoryFactory:

         try:
             repository_class = import_string(class_path)
-            return repository_class(  # type: ignore[no-any-return]
+            return repository_class(
                 session_factory=session_factory,
                 user=user,
                 app_id=app_id,

@@ -96,7 +96,7 @@ class DifyCoreRepositoryFactory:

         try:
             repository_class = import_string(class_path)
-            return repository_class(  # type: ignore[no-any-return]
+            return repository_class(
                 session_factory=session_factory,
                 user=user,
                 app_id=app_id,
@@ -157,7 +157,7 @@ class BuiltinToolProviderController(ToolProviderController):
         """
         returns the tool that the provider can provide
         """
-        return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)  # type: ignore
+        return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)

     @property
     def need_credentials(self) -> bool:
@@ -43,7 +43,7 @@ class TTSTool(BuiltinTool):
             content_text=tool_parameters.get("text"),  # type: ignore
             user=user_id,
             tenant_id=self.runtime.tenant_id,
-            voice=voice,  # type: ignore
+            voice=voice,
         )
         buffer = io.BytesIO()
         for chunk in tts:

@@ -34,6 +34,7 @@ class LocaltimeToTimestampTool(BuiltinTool):

         yield self.create_text_message(f"{timestamp}")

+    # TODO: this method's type is messy
     @staticmethod
     def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None:
         try:

@@ -48,6 +48,6 @@ class TimezoneConversionTool(BuiltinTool):
             datetime_with_tz = input_timezone.localize(local_time)
             # timezone convert
             converted_datetime = datetime_with_tz.astimezone(output_timezone)
-            return converted_datetime.strftime(format=time_format)  # type: ignore
+            return converted_datetime.strftime(time_format)
         except Exception as e:
             raise ToolInvokeError(str(e))
@@ -113,7 +113,7 @@ class MCPToolProviderController(ToolProviderController):
         """
         pass

-    def get_tool(self, tool_name: str) -> MCPTool:  # type: ignore
+    def get_tool(self, tool_name: str) -> MCPTool:
         """
         return tool with given name
         """

@@ -136,7 +136,7 @@ class MCPToolProviderController(ToolProviderController):
             sse_read_timeout=self.sse_read_timeout,
         )

-    def get_tools(self) -> list[MCPTool]:  # type: ignore
+    def get_tools(self) -> list[MCPTool]:
        """
        get all tools
        """

@@ -3,7 +3,6 @@ import json
 from collections.abc import Generator
 from typing import Any

-from core.mcp.auth.auth_flow import auth
 from core.mcp.auth_client import MCPClientWithAuthRetry
 from core.mcp.error import MCPConnectionError
 from core.mcp.types import CallToolResult, ImageContent, TextContent
@@ -125,71 +124,39 @@ class MCPTool(Tool):
         headers = self.headers.copy() if self.headers else {}
         tool_parameters = self._handle_none_parameter(tool_parameters)

-        # Get provider entity to access tokens
-        # Get MCP service from invoke parameters or create new one
-        provider_entity = None
-        mcp_service = None
-        # Check if mcp_service is passed in tool_parameters
-        if "_mcp_service" in tool_parameters:
-            mcp_service = tool_parameters.pop("_mcp_service")
-            if mcp_service:
-                provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
-                headers = provider_entity.decrypt_headers()
-                # Try to get existing token and add to headers
-                if not headers:
-                    tokens = provider_entity.retrieve_tokens()
-                    if tokens and tokens.access_token:
-                        headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
+        from sqlalchemy.orm import Session
+
+        from extensions.ext_database import db
+        from services.tools.mcp_tools_manage_service import MCPToolManageService
+
+        # Step 1: Load provider entity and credentials in a short-lived session
+        # This minimizes database connection hold time
+        with Session(db.engine, expire_on_commit=False) as session:
+            mcp_service = MCPToolManageService(session=session)
+            provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)

-            # Use MCPClientWithAuthRetry to handle authentication automatically
-            try:
-                with MCPClientWithAuthRetry(
-                    server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url,
-                    headers=headers,
-                    timeout=self.timeout,
-                    sse_read_timeout=self.sse_read_timeout,
-                    provider_entity=provider_entity,
-                    auth_callback=auth if mcp_service else None,
-                    mcp_service=mcp_service,
-                ) as mcp_client:
-                    return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
-            except MCPConnectionError as e:
-                raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
-            except (ValueError, TypeError, KeyError) as e:
-                # Catch specific exceptions that might occur during tool invocation
-                raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
-        else:
-            # Fallback to creating service with database session
-            from sqlalchemy.orm import Session
+            # Decrypt and prepare all credentials before closing session
+            server_url = provider_entity.decrypt_server_url()
+            headers = provider_entity.decrypt_headers()

-            from extensions.ext_database import db
-            from services.tools.mcp_tools_manage_service import MCPToolManageService
+            # Try to get existing token and add to headers
+            if not headers:
+                tokens = provider_entity.retrieve_tokens()
+                if tokens and tokens.access_token:
+                    headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"

-            with Session(db.engine, expire_on_commit=False) as session:
-                mcp_service = MCPToolManageService(session=session)
-                provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
-                headers = provider_entity.decrypt_headers()
-                # Try to get existing token and add to headers
-                if not headers:
-                    tokens = provider_entity.retrieve_tokens()
-                    if tokens and tokens.access_token:
-                        headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
-
-                # Use MCPClientWithAuthRetry to handle authentication automatically
-                try:
-                    with MCPClientWithAuthRetry(
-                        server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url,
-                        headers=headers,
-                        timeout=self.timeout,
-                        sse_read_timeout=self.sse_read_timeout,
-                        provider_entity=provider_entity,
-                        auth_callback=auth if mcp_service else None,
-                        mcp_service=mcp_service,
-                    ) as mcp_client:
-                        return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
-                except MCPConnectionError as e:
-                    raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
-                except Exception as e:
-                    raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
+        # Step 2: Session is now closed, perform network operations without holding database connection
+        # MCPClientWithAuthRetry will create a new session lazily only if auth retry is needed
+        try:
+            with MCPClientWithAuthRetry(
+                server_url=server_url,
+                headers=headers,
+                timeout=self.timeout,
+                sse_read_timeout=self.sse_read_timeout,
+                provider_entity=provider_entity,
+            ) as mcp_client:
+                return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
+        except MCPConnectionError as e:
+            raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
+        except Exception as e:
+            raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
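The rewrite above splits credential loading from network I/O so a database connection is never held across a slow MCP call. A hedged sketch of the same two-step shape, with stand-in names (`short_session` and the printed calls are illustrative only):

from contextlib import contextmanager

@contextmanager
def short_session():
    # stand-in for Session(db.engine, expire_on_commit=False)
    print("session open")
    yield {"server_url": "https://example.com/mcp", "headers": {}}
    print("session closed")

# Step 1: copy everything needed out of the session while it is open.
with short_session() as row:
    server_url = row["server_url"]
    headers = dict(row["headers"])

# Step 2: the session is already closed before the slow network call happens.
print("calling", server_url, "with headers", headers)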
@@ -26,7 +26,7 @@ class ToolLabelManager:
         labels = cls.filter_tool_labels(labels)

         if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
-            provider_id = controller.provider_id  # ty: ignore [unresolved-attribute]
+            provider_id = controller.provider_id
         else:
             raise ValueError("Unsupported tool type")

@@ -51,7 +51,7 @@ class ToolLabelManager:
         Get tool labels
         """
         if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
-            provider_id = controller.provider_id  # ty: ignore [unresolved-attribute]
+            provider_id = controller.provider_id
         elif isinstance(controller, BuiltinToolProviderController):
             return controller.tool_labels
         else:

@@ -85,7 +85,7 @@ class ToolLabelManager:
         provider_ids = []
         for controller in tool_providers:
             assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
-            provider_ids.append(controller.provider_id)  # ty: ignore [unresolved-attribute]
+            provider_ids.append(controller.provider_id)

         labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all()
@@ -331,7 +331,8 @@ class ToolManager:
             workflow_provider_stmt = select(WorkflowToolProvider).where(
                 WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id
             )
-            workflow_provider = db.session.scalar(workflow_provider_stmt)
+            with Session(db.engine, expire_on_commit=False) as session, session.begin():
+                workflow_provider = session.scalar(workflow_provider_stmt)

             if workflow_provider is None:
                 raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
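`expire_on_commit=False` is what lets the loaded `workflow_provider` still be read after the `session.begin()` block commits and the session closes. A runnable sketch under the assumption of SQLAlchemy 2.x (the throwaway SQLite table is illustrative):

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

engine = create_engine("sqlite://")  # in-memory scratch database
with engine.begin() as conn:
    conn.execute(text("CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT)"))
    conn.execute(text("INSERT INTO t (name) VALUES ('demo')"))

with Session(engine, expire_on_commit=False) as session, session.begin():
    name = session.execute(text("SELECT name FROM t")).scalar_one()

print(name)  # still readable after the session is gone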
@@ -193,18 +193,18 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
                     DatasetDocument.enabled == True,
                     DatasetDocument.archived == False,
                 )
-                document = db.session.scalar(dataset_document_stmt)  # type: ignore
+                document = db.session.scalar(dataset_document_stmt)
                 if dataset and document:
                     source = RetrievalSourceMetadata(
                         dataset_id=dataset.id,
                         dataset_name=dataset.name,
-                        document_id=document.id,  # type: ignore
-                        document_name=document.name,  # type: ignore
-                        data_source_type=document.data_source_type,  # type: ignore
+                        document_id=document.id,
+                        document_name=document.name,
+                        data_source_type=document.data_source_type,
                         segment_id=segment.id,
                         retriever_from=self.retriever_from,
                         score=record.score or 0.0,
-                        doc_metadata=document.doc_metadata,  # type: ignore
+                        doc_metadata=document.doc_metadata,
                     )

                     if self.retriever_from == "dev":
@@ -62,6 +62,11 @@ class ApiBasedToolSchemaParser:
                     root = root[ref]
                 interface["operation"]["parameters"][i] = root
             for parameter in interface["operation"]["parameters"]:
+                # Handle complex type defaults that are not supported by PluginParameter
+                default_value = None
+                if "schema" in parameter and "default" in parameter["schema"]:
+                    default_value = ApiBasedToolSchemaParser._sanitize_default_value(parameter["schema"]["default"])
+
                 tool_parameter = ToolParameter(
                     name=parameter["name"],
                     label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),

@@ -72,9 +77,7 @@ class ApiBasedToolSchemaParser:
                     required=parameter.get("required", False),
                     form=ToolParameter.ToolParameterForm.LLM,
                     llm_description=parameter.get("description"),
-                    default=parameter["schema"]["default"]
-                    if "schema" in parameter and "default" in parameter["schema"]
-                    else None,
+                    default=default_value,
                     placeholder=I18nObject(
                         en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
                     ),

@@ -134,6 +137,11 @@ class ApiBasedToolSchemaParser:
             required = body_schema.get("required", [])
             properties = body_schema.get("properties", {})
             for name, property in properties.items():
+                # Handle complex type defaults that are not supported by PluginParameter
+                default_value = ApiBasedToolSchemaParser._sanitize_default_value(
+                    property.get("default", None)
+                )
+
                 tool = ToolParameter(
                     name=name,
                     label=I18nObject(en_US=name, zh_Hans=name),

@@ -144,12 +152,11 @@ class ApiBasedToolSchemaParser:
                     required=name in required,
                     form=ToolParameter.ToolParameterForm.LLM,
                     llm_description=property.get("description", ""),
-                    default=property.get("default", None),
+                    default=default_value,
                     placeholder=I18nObject(
                         en_US=property.get("description", ""), zh_Hans=property.get("description", "")
                     ),
                 )

                 # check if there is a type
                 typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
                 if typ:

@@ -197,6 +204,22 @@ class ApiBasedToolSchemaParser:

         return bundles

+    @staticmethod
+    def _sanitize_default_value(value):
+        """
+        Sanitize default values for PluginParameter compatibility.
+        Complex types (list, dict) are converted to None to avoid validation errors.
+
+        Args:
+            value: The default value from OpenAPI schema
+
+        Returns:
+            None for complex types (list, dict), otherwise the original value
+        """
+        if isinstance(value, (list, dict)):
+            return None
+        return value
+
     @staticmethod
     def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
         parameter = parameter or {}

@@ -217,7 +240,11 @@ class ApiBasedToolSchemaParser:
             return ToolParameter.ToolParameterType.STRING
         elif typ == "array":
             items = parameter.get("items") or parameter.get("schema", {}).get("items")
-            return ToolParameter.ToolParameterType.FILES if items and items.get("format") == "binary" else None
+            if items and items.get("format") == "binary":
+                return ToolParameter.ToolParameterType.FILES
+            else:
+                # For regular arrays, return ARRAY type instead of None
+                return ToolParameter.ToolParameterType.ARRAY
         else:
             return None
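Restated as a tiny runnable check, the sanitizer added above passes scalars through and collapses list/dict defaults to None, because PluginParameter cannot validate complex default values:

def sanitize(value):
    return None if isinstance(value, (list, dict)) else value

assert sanitize("celsius") == "celsius"
assert sanitize(3) == 3
assert sanitize(["a", "b"]) is None
assert sanitize({"k": "v"}) is None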
@@ -6,8 +6,8 @@ from typing import Any, cast
 from urllib.parse import unquote

 import chardet
-import cloudscraper  # type: ignore
-from readabilipy import simple_json_from_html_string  # type: ignore
+import cloudscraper
+from readabilipy import simple_json_from_html_string

 from core.helper import ssrf_proxy
 from core.rag.extractor import extract_processor

@@ -63,8 +63,8 @@ def get_url(url: str, user_agent: str | None = None) -> str:
         response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
     elif response.status_code == 403:
         scraper = cloudscraper.create_scraper()
-        scraper.perform_request = ssrf_proxy.make_request  # type: ignore
-        response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))  # type: ignore
+        scraper.perform_request = ssrf_proxy.make_request
+        response = scraper.get(url, headers=headers, timeout=(120, 300))

     if response.status_code != 200:
         return f"URL returned status code {response.status_code}."
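The 403 branch above swaps in cloudscraper only when the plain fetch is challenged. A hedged sketch of that fallback shape, assuming cloudscraper's requests-style API (the URL and timeouts here are illustrative):

import cloudscraper
import requests

def fetch(url: str) -> requests.Response:
    response = requests.get(url, timeout=(10, 30))
    if response.status_code == 403:  # likely an anti-bot challenge page
        scraper = cloudscraper.create_scraper()  # a requests.Session subclass
        response = scraper.get(url, timeout=(10, 30))
    return response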
@@ -3,7 +3,7 @@ from functools import lru_cache
 from pathlib import Path
 from typing import Any

-import yaml  # type: ignore
+import yaml
 from yaml import YAMLError

 logger = logging.getLogger(__name__)
@@ -1,6 +1,7 @@
 from collections.abc import Mapping

 from pydantic import Field
+from sqlalchemy.orm import Session

 from core.app.app_config.entities import VariableEntity, VariableEntityType
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager

@@ -20,6 +21,7 @@ from core.tools.entities.tool_entities import (
 from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
 from core.tools.workflow_as_tool.tool import WorkflowTool
 from extensions.ext_database import db
+from models.account import Account
 from models.model import App, AppMode
 from models.tools import WorkflowToolProvider
 from models.workflow import Workflow
@@ -44,29 +46,34 @@ class WorkflowToolProviderController(ToolProviderController):

     @classmethod
     def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
-        app = db_provider.app
+        with Session(db.engine, expire_on_commit=False) as session, session.begin():
+            provider = session.get(WorkflowToolProvider, db_provider.id) if db_provider.id else None
+            if not provider:
+                raise ValueError("workflow provider not found")
+            app = session.get(App, provider.app_id)
+            if not app:
+                raise ValueError("app not found")

-        if not app:
-            raise ValueError("app not found")
+            user = session.get(Account, provider.user_id) if provider.user_id else None

-        controller = WorkflowToolProviderController(
-            entity=ToolProviderEntity(
-                identity=ToolProviderIdentity(
-                    author=db_provider.user.name if db_provider.user_id and db_provider.user else "",
-                    name=db_provider.label,
-                    label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label),
-                    description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
-                    icon=db_provider.icon,
+            controller = WorkflowToolProviderController(
+                entity=ToolProviderEntity(
+                    identity=ToolProviderIdentity(
+                        author=user.name if user else "",
+                        name=provider.label,
+                        label=I18nObject(en_US=provider.label, zh_Hans=provider.label),
+                        description=I18nObject(en_US=provider.description, zh_Hans=provider.description),
+                        icon=provider.icon,
+                    ),
+                    credentials_schema=[],
+                    plugin_id=None,
                 ),
-                credentials_schema=[],
-                plugin_id=None,
-            ),
-            provider_id=db_provider.id or "",
-        )
+                provider_id=provider.id or "",
+            )

-        # init tools
+            # init tools

-        controller.tools = [controller._get_db_provider_tool(db_provider, app)]
+            controller.tools = [
+                controller._get_db_provider_tool(provider, app, session=session, user=user),
+            ]

-        return controller
+            return controller
@@ -74,7 +81,14 @@ class WorkflowToolProviderController(ToolProviderController):
     def provider_type(self) -> ToolProviderType:
         return ToolProviderType.WORKFLOW

-    def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool:
+    def _get_db_provider_tool(
+        self,
+        db_provider: WorkflowToolProvider,
+        app: App,
+        *,
+        session: Session,
+        user: Account | None = None,
+    ) -> WorkflowTool:
         """
         get db provider tool
         :param db_provider: the db provider

@@ -82,7 +96,7 @@ class WorkflowToolProviderController(ToolProviderController):
         :return: the tool
         """
         workflow: Workflow | None = (
-            db.session.query(Workflow)
+            session.query(Workflow)
             .where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
             .first()
         )

@@ -99,9 +113,7 @@ class WorkflowToolProviderController(ToolProviderController):
         variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)

         def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
-            return next(filter(lambda x: x.variable == variable_name, variables), None)  # type: ignore
-
-        user = db_provider.user
+            return next(filter(lambda x: x.variable == variable_name, variables), None)

         workflow_tool_parameters = []
         for parameter in parameters:
@@ -187,22 +199,25 @@ class WorkflowToolProviderController(ToolProviderController):
         if self.tools is not None:
             return self.tools

-        db_providers: WorkflowToolProvider | None = (
-            db.session.query(WorkflowToolProvider)
-            .where(
-                WorkflowToolProvider.tenant_id == tenant_id,
-                WorkflowToolProvider.app_id == self.provider_id,
+        with Session(db.engine, expire_on_commit=False) as session, session.begin():
+            db_provider: WorkflowToolProvider | None = (
+                session.query(WorkflowToolProvider)
+                .where(
+                    WorkflowToolProvider.tenant_id == tenant_id,
+                    WorkflowToolProvider.app_id == self.provider_id,
+                )
+                .first()
             )
-            .first()
-        )

-        if not db_providers:
-            return []
-        if not db_providers.app:
-            raise ValueError("app not found")
+            if not db_provider:
+                return []

-        app = db_providers.app
-        self.tools = [self._get_db_provider_tool(db_providers, app)]
+            app = session.get(App, db_provider.app_id)
+            if not app:
+                raise ValueError("app not found")
+
+            user = session.get(Account, db_provider.user_id) if db_provider.user_id else None
+            self.tools = [self._get_db_provider_tool(db_provider, app, session=session, user=user)]

         return self.tools
@@ -1,12 +1,14 @@
 import json
 import logging
-from collections.abc import Generator
-from typing import Any
+from collections.abc import Generator, Mapping, Sequence
+from typing import Any, cast

 from flask import has_request_context
 from sqlalchemy import select
 from sqlalchemy.orm import Session

 from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
+from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
 from core.tools.__base.tool import Tool
 from core.tools.__base.tool_runtime import ToolRuntime
 from core.tools.entities.tool_entities import (

@@ -48,6 +50,7 @@ class WorkflowTool(Tool):
         self.workflow_entities = workflow_entities
         self.workflow_call_depth = workflow_call_depth
         self.label = label
+        self._latest_usage = LLMUsage.empty_usage()

         super().__init__(entity=entity, runtime=runtime)

@@ -83,10 +86,11 @@ class WorkflowTool(Tool):
         assert self.runtime.invoke_from is not None

         user = self._resolve_user(user_id=user_id)

         if user is None:
             raise ToolInvokeError("User not found")

+        self._latest_usage = LLMUsage.empty_usage()
         result = generator.generate(
             app_model=app,
             workflow=workflow,
@@ -110,9 +114,68 @@ class WorkflowTool(Tool):
             for file in files:
                 yield self.create_file_message(file)  # type: ignore

+            self._latest_usage = self._derive_usage_from_result(data)
+
             yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
             yield self.create_json_message(outputs)

+    @property
+    def latest_usage(self) -> LLMUsage:
+        return self._latest_usage
+
+    @classmethod
+    def _derive_usage_from_result(cls, data: Mapping[str, Any]) -> LLMUsage:
+        usage_dict = cls._extract_usage_dict(data)
+        if usage_dict is not None:
+            return LLMUsage.from_metadata(cast(LLMUsageMetadata, dict(usage_dict)))
+
+        total_tokens = data.get("total_tokens")
+        total_price = data.get("total_price")
+        if total_tokens is None and total_price is None:
+            return LLMUsage.empty_usage()
+
+        usage_metadata: dict[str, Any] = {}
+        if total_tokens is not None:
+            try:
+                usage_metadata["total_tokens"] = int(str(total_tokens))
+            except (TypeError, ValueError):
+                pass
+        if total_price is not None:
+            usage_metadata["total_price"] = str(total_price)
+            currency = data.get("currency")
+            if currency is not None:
+                usage_metadata["currency"] = currency
+
+        if not usage_metadata:
+            return LLMUsage.empty_usage()
+
+        return LLMUsage.from_metadata(cast(LLMUsageMetadata, usage_metadata))
+
+    @classmethod
+    def _extract_usage_dict(cls, payload: Mapping[str, Any]) -> Mapping[str, Any] | None:
+        usage_candidate = payload.get("usage")
+        if isinstance(usage_candidate, Mapping):
+            return usage_candidate
+
+        metadata_candidate = payload.get("metadata")
+        if isinstance(metadata_candidate, Mapping):
+            usage_candidate = metadata_candidate.get("usage")
+            if isinstance(usage_candidate, Mapping):
+                return usage_candidate
+
+        for value in payload.values():
+            if isinstance(value, Mapping):
+                found = cls._extract_usage_dict(value)
+                if found is not None:
+                    return found
+            elif isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
+                for item in value:
+                    if isinstance(item, Mapping):
+                        found = cls._extract_usage_dict(item)
+                        if found is not None:
+                            return found
+        return None
+
     def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
         """
         fork a new tool with metadata
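The `_extract_usage_dict` walk added above returns the first mapping found under a `usage` key, wherever it sits in the workflow result payload. A self-contained sketch of the same search order (top level, then `metadata`, then nested values and sequences):

from collections.abc import Mapping, Sequence

def find_usage(payload):
    usage = payload.get("usage")
    if isinstance(usage, Mapping):
        return usage
    meta = payload.get("metadata")
    if isinstance(meta, Mapping) and isinstance(meta.get("usage"), Mapping):
        return meta["usage"]
    for value in payload.values():
        if isinstance(value, Mapping):
            found = find_usage(value)
            if found is not None:
                return found
        elif isinstance(value, Sequence) and not isinstance(value, (str, bytes)):
            for item in value:
                if isinstance(item, Mapping):
                    found = find_usage(item)
                    if found is not None:
                        return found
    return None

data = {"outputs": {"steps": [{"metadata": {"usage": {"total_tokens": 42}}}]}}
print(find_usage(data))  # {'total_tokens': 42}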
@@ -179,16 +242,17 @@ class WorkflowTool(Tool):
         """
         get the workflow by app id and version
         """
-        if not version:
-            workflow = (
-                db.session.query(Workflow)
-                .where(Workflow.app_id == app_id, Workflow.version != Workflow.VERSION_DRAFT)
-                .order_by(Workflow.created_at.desc())
-                .first()
-            )
-        else:
-            stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
-            workflow = db.session.scalar(stmt)
+        with Session(db.engine, expire_on_commit=False) as session, session.begin():
+            if not version:
+                stmt = (
+                    select(Workflow)
+                    .where(Workflow.app_id == app_id, Workflow.version != Workflow.VERSION_DRAFT)
+                    .order_by(Workflow.created_at.desc())
+                )
+                workflow = session.scalars(stmt).first()
+            else:
+                stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
+                workflow = session.scalar(stmt)

         if not workflow:
             raise ValueError("workflow not found or not published")

@@ -200,7 +264,8 @@ class WorkflowTool(Tool):
         get the app by app id
         """
         stmt = select(App).where(App.id == app_id)
-        app = db.session.scalar(stmt)
+        with Session(db.engine, expire_on_commit=False) as session, session.begin():
+            app = session.scalar(stmt)
         if not app:
             raise ValueError("app not found")
@@ -4,7 +4,7 @@ from .types import SegmentType

 class SegmentGroup(Segment):
     value_type: SegmentType = SegmentType.GROUP
-    value: list[Segment] = None  # type: ignore
+    value: list[Segment]

     @property
     def text(self):

@@ -19,7 +19,7 @@ class Segment(BaseModel):
     model_config = ConfigDict(frozen=True)

     value_type: SegmentType
-    value: Any = None
+    value: Any

     @field_validator("value_type")
     @classmethod

@@ -74,12 +74,12 @@ class NoneSegment(Segment):

 class StringSegment(Segment):
     value_type: SegmentType = SegmentType.STRING
-    value: str = None  # type: ignore
+    value: str


 class FloatSegment(Segment):
     value_type: SegmentType = SegmentType.FLOAT
-    value: float = None  # type: ignore
+    value: float
     # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
     # The following tests cannot pass.
     #

@@ -98,12 +98,12 @@ class FloatSegment(Segment):

 class IntegerSegment(Segment):
     value_type: SegmentType = SegmentType.INTEGER
-    value: int = None  # type: ignore
+    value: int


 class ObjectSegment(Segment):
     value_type: SegmentType = SegmentType.OBJECT
-    value: Mapping[str, Any] = None  # type: ignore
+    value: Mapping[str, Any]

     @property
     def text(self) -> str:

@@ -136,7 +136,7 @@ class ArraySegment(Segment):

 class FileSegment(Segment):
     value_type: SegmentType = SegmentType.FILE
-    value: File = None  # type: ignore
+    value: File

     @property
     def markdown(self) -> str:

@@ -153,17 +153,17 @@ class FileSegment(Segment):

 class BooleanSegment(Segment):
     value_type: SegmentType = SegmentType.BOOLEAN
-    value: bool = None  # type: ignore
+    value: bool


 class ArrayAnySegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_ANY
-    value: Sequence[Any] = None  # type: ignore
+    value: Sequence[Any]


 class ArrayStringSegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_STRING
-    value: Sequence[str] = None  # type: ignore
+    value: Sequence[str]

     @property
     def text(self) -> str:

@@ -175,17 +175,17 @@ class ArrayStringSegment(ArraySegment):

 class ArrayNumberSegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_NUMBER
-    value: Sequence[float | int] = None  # type: ignore
+    value: Sequence[float | int]


 class ArrayObjectSegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_OBJECT
-    value: Sequence[Mapping[str, Any]] = None  # type: ignore
+    value: Sequence[Mapping[str, Any]]


 class ArrayFileSegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_FILE
-    value: Sequence[File] = None  # type: ignore
+    value: Sequence[File]

     @property
     def markdown(self) -> str:

@@ -247,7 +247,7 @@ class VersionedMemorySegment(Segment):

 class ArrayBooleanSegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
-    value: Sequence[bool] = None  # type: ignore
+    value: Sequence[bool]


 def get_segment_discriminator(v: Any) -> SegmentType | None:
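Dropping the `= None  # type: ignore` defaults makes `value` a required field on every segment. A minimal pydantic sketch of the behavioral difference, assuming pydantic v2 (defaults are not validated by default, so the old style let None slip through a `str` annotation):

from pydantic import BaseModel, ValidationError

class Loose(BaseModel):
    value: str = None  # type: ignore  # old style

class Strict(BaseModel):
    value: str  # new style

print(Loose().value)  # None, despite the str annotation
try:
    Strict()
except ValidationError as exc:
    print("missing value rejected:", exc.error_count(), "error")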
@@ -1,3 +1,5 @@
+from ..runtime.graph_runtime_state import GraphRuntimeState
+from ..runtime.variable_pool import VariablePool
 from .agent import AgentNodeStrategyInit
 from .graph_init_params import GraphInitParams
 from .workflow_execution import WorkflowExecution

@@ -6,6 +8,8 @@ from .workflow_node_execution import WorkflowNodeExecution
 __all__ = [
     "AgentNodeStrategyInit",
     "GraphInitParams",
+    "GraphRuntimeState",
+    "VariablePool",
     "WorkflowExecution",
     "WorkflowNodeExecution",
 ]
@@ -3,11 +3,12 @@ from collections import defaultdict
 from collections.abc import Mapping, Sequence
 from typing import Protocol, cast, final

-from core.workflow.enums import NodeExecutionType, NodeState, NodeType
+from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
 from core.workflow.nodes.base.node import Node
 from libs.typing import is_str, is_str_dict

 from .edge import Edge
+from .validation import get_graph_validator

 logger = logging.getLogger(__name__)

@@ -201,6 +202,17 @@ class Graph:

         return GraphBuilder(graph_cls=cls)

+    @classmethod
+    def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None:
+        """
+        Promote nodes configured with FAIL_BRANCH error strategy to branch execution type.
+
+        :param nodes: mapping of node ID to node instance
+        """
+        for node in nodes.values():
+            if node.error_strategy == ErrorStrategy.FAIL_BRANCH:
+                node.execution_type = NodeExecutionType.BRANCH
+
     @classmethod
     def _mark_inactive_root_branches(
         cls,

@@ -307,6 +319,9 @@ class Graph:
         # Create node instances
         nodes = cls._create_node_instances(node_configs_map, node_factory)

+        # Promote fail-branch nodes to branch execution type at graph level
+        cls._promote_fail_branch_nodes(nodes)
+
         # Get root node instance
         root_node = nodes[root_node_id]

@@ -314,7 +329,7 @@ class Graph:
         cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)

         # Create and return the graph
-        return cls(
+        graph = cls(
             nodes=nodes,
             edges=edges,
             in_edges=in_edges,

@@ -322,6 +337,11 @@ class Graph:
             root_node=root_node,
         )

+        # Validate the graph structure using built-in validators
+        get_graph_validator().validate(graph)
+
+        return graph
+
     @property
     def node_ids(self) -> list[str]:
         """
@@ -0,0 +1,125 @@
+from __future__ import annotations
+
+from collections.abc import Sequence
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Protocol
+
+from core.workflow.enums import NodeExecutionType, NodeType
+
+if TYPE_CHECKING:
+    from .graph import Graph
+
+
+@dataclass(frozen=True, slots=True)
+class GraphValidationIssue:
+    """Immutable value object describing a single validation issue."""
+
+    code: str
+    message: str
+    node_id: str | None = None
+
+
+class GraphValidationError(ValueError):
+    """Raised when graph validation fails."""
+
+    def __init__(self, issues: Sequence[GraphValidationIssue]) -> None:
+        if not issues:
+            raise ValueError("GraphValidationError requires at least one issue.")
+        self.issues: tuple[GraphValidationIssue, ...] = tuple(issues)
+        message = "; ".join(f"[{issue.code}] {issue.message}" for issue in self.issues)
+        super().__init__(message)
+
+
+class GraphValidationRule(Protocol):
+    """Protocol that individual validation rules must satisfy."""
+
+    def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
+        """Validate the provided graph and return any discovered issues."""
+        ...
+
+
+@dataclass(frozen=True, slots=True)
+class _EdgeEndpointValidator:
+    """Ensures all edges reference existing nodes."""
+
+    missing_node_code: str = "MISSING_NODE"
+
+    def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
+        issues: list[GraphValidationIssue] = []
+        for edge in graph.edges.values():
+            if edge.tail not in graph.nodes:
+                issues.append(
+                    GraphValidationIssue(
+                        code=self.missing_node_code,
+                        message=f"Edge {edge.id} references unknown source node '{edge.tail}'.",
+                        node_id=edge.tail,
+                    )
+                )
+            if edge.head not in graph.nodes:
+                issues.append(
+                    GraphValidationIssue(
+                        code=self.missing_node_code,
+                        message=f"Edge {edge.id} references unknown target node '{edge.head}'.",
+                        node_id=edge.head,
+                    )
+                )
+        return issues
+
+
+@dataclass(frozen=True, slots=True)
+class _RootNodeValidator:
+    """Validates root node invariants."""
+
+    invalid_root_code: str = "INVALID_ROOT"
+    container_entry_types: tuple[NodeType, ...] = (NodeType.ITERATION_START, NodeType.LOOP_START)
+
+    def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
+        root_node = graph.root_node
+        issues: list[GraphValidationIssue] = []
+        if root_node.id not in graph.nodes:
+            issues.append(
+                GraphValidationIssue(
+                    code=self.invalid_root_code,
+                    message=f"Root node '{root_node.id}' is missing from the node registry.",
+                    node_id=root_node.id,
+                )
+            )
+            return issues
+
+        node_type = getattr(root_node, "node_type", None)
+        if root_node.execution_type != NodeExecutionType.ROOT and node_type not in self.container_entry_types:
+            issues.append(
+                GraphValidationIssue(
+                    code=self.invalid_root_code,
+                    message=f"Root node '{root_node.id}' must declare execution type 'root'.",
+                    node_id=root_node.id,
+                )
+            )
+        return issues
+
+
+@dataclass(frozen=True, slots=True)
+class GraphValidator:
+    """Coordinates execution of graph validation rules."""
+
+    rules: tuple[GraphValidationRule, ...]
+
+    def validate(self, graph: Graph) -> None:
+        """Validate the graph against all configured rules."""
+        issues: list[GraphValidationIssue] = []
+        for rule in self.rules:
+            issues.extend(rule.validate(graph))
+
+        if issues:
+            raise GraphValidationError(issues)
+
+
+_DEFAULT_RULES: tuple[GraphValidationRule, ...] = (
+    _EdgeEndpointValidator(),
+    _RootNodeValidator(),
+)
+
+
+def get_graph_validator() -> GraphValidator:
+    """Construct the validator composed of default rules."""
+    return GraphValidator(_DEFAULT_RULES)
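Because `GraphValidationRule` is a Protocol, rules are satisfied structurally: anything with a matching `validate` method can be composed in, no subclassing required. A self-contained sketch of that composition style with stand-in types (the dict-shaped graph and `NoIsolatedNodes` rule are illustrative, not from this file):

from dataclasses import dataclass
from typing import Protocol, Sequence

class Rule(Protocol):
    def validate(self, graph: dict) -> Sequence[str]: ...

@dataclass(frozen=True)
class NoIsolatedNodes:
    def validate(self, graph: dict) -> Sequence[str]:
        touched = {n for edge in graph["edges"] for n in edge}
        return [f"isolated node: {n}" for n in graph["nodes"] if n not in touched]

def run_rules(rules: Sequence[Rule], graph: dict) -> list[str]:
    issues: list[str] = []
    for rule in rules:
        issues.extend(rule.validate(graph))
    return issues

graph = {"nodes": ["start", "llm", "orphan"], "edges": [("start", "llm")]}
print(run_rules([NoIsolatedNodes()], graph))  # ['isolated node: orphan']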
@@ -26,8 +26,8 @@ class AgentNodeData(BaseNodeData):


 class ParamsAutoGenerated(IntEnum):
-    CLOSE = auto()
-    OPEN = auto()
+    CLOSE = 0
+    OPEN = 1


 class AgentOldVersionModelFeatures(StrEnum):
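The explicit values matter because `enum.auto()` starts counting at 1, so the old definition gave CLOSE the value 1 and OPEN the value 2; pinning 0 and 1 restores the intended off/on pairing for any persisted values. A runnable demonstration:

from enum import IntEnum, auto

class Old(IntEnum):
    CLOSE = auto()
    OPEN = auto()

class New(IntEnum):
    CLOSE = 0
    OPEN = 1

print(Old.CLOSE.value, Old.OPEN.value)  # 1 2
print(New.CLOSE.value, New.OPEN.value)  # 0 1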
@@ -1,4 +1,5 @@
 from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
+from .usage_tracking_mixin import LLMUsageTrackingMixin

 __all__ = [
     "BaseIterationNodeData",

@@ -6,4 +7,5 @@ __all__ = [
     "BaseLoopNodeData",
     "BaseLoopState",
     "BaseNodeData",
+    "LLMUsageTrackingMixin",
 ]
@@ -1,5 +1,6 @@
 import json
 from abc import ABC
+from builtins import type as type_
 from collections.abc import Sequence
 from enum import StrEnum
 from typing import Any, Union

@@ -58,10 +59,9 @@ class DefaultValue(BaseModel):
             raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")

     @staticmethod
-    def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
+    def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
         """Unified array type validation"""
-        # FIXME, type ignore here for do not find the reason mypy complain, if find the root cause, please fix it
-        return isinstance(value, list) and all(isinstance(x, element_type) for x in value)  # type: ignore
+        return isinstance(value, list) and all(isinstance(x, element_type) for x in value)

     @staticmethod
     def _convert_number(value: str) -> float:
@@ -0,0 +1,28 @@
+from core.model_runtime.entities.llm_entities import LLMUsage
+from core.workflow.runtime import GraphRuntimeState
+
+
+class LLMUsageTrackingMixin:
+    """Provides shared helpers for merging and recording LLM usage within workflow nodes."""
+
+    graph_runtime_state: GraphRuntimeState
+
+    @staticmethod
+    def _merge_usage(current: LLMUsage, new_usage: LLMUsage | None) -> LLMUsage:
+        """Return a combined usage snapshot, preserving zero-value inputs."""
+        if new_usage is None or new_usage.total_tokens <= 0:
+            return current
+        if current.total_tokens == 0:
+            return new_usage
+        return current.plus(new_usage)
+
+    def _accumulate_usage(self, usage: LLMUsage) -> None:
+        """Push usage into the graph runtime accumulator for downstream reporting."""
+        if usage.total_tokens <= 0:
+            return
+
+        current_usage = self.graph_runtime_state.llm_usage
+        if current_usage.total_tokens == 0:
+            self.graph_runtime_state.llm_usage = usage.model_copy()
+        else:
+            self.graph_runtime_state.llm_usage = current_usage.plus(usage)
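The merge rule in this mixin keeps an existing snapshot when the incoming one is empty and otherwise adds them, so zero-token updates never clobber real totals. A runnable restatement with a stand-in usage type (`Usage` mimics only the parts of LLMUsage the rule touches):

from dataclasses import dataclass

@dataclass
class Usage:
    total_tokens: int = 0

    def plus(self, other: "Usage") -> "Usage":
        return Usage(self.total_tokens + other.total_tokens)

def merge(current: Usage, new: Usage | None) -> Usage:
    if new is None or new.total_tokens <= 0:
        return current
    return new if current.total_tokens == 0 else current.plus(new)

assert merge(Usage(10), None).total_tokens == 10
assert merge(Usage(0), Usage(3)).total_tokens == 3
assert merge(Usage(3), Usage(4)).total_tokens == 7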
@@ -10,10 +10,10 @@ from typing import Any
 import chardet
 import docx
 import pandas as pd
-import pypandoc  # type: ignore
-import pypdfium2  # type: ignore
-import webvtt  # type: ignore
-import yaml  # type: ignore
+import pypandoc
+import pypdfium2
+import webvtt
+import yaml
 from docx.document import Document
 from docx.oxml.table import CT_Tbl
 from docx.oxml.text.paragraph import CT_P
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, NewType, cast
 from flask import Flask, current_app
 from typing_extensions import TypeIs

+from core.model_runtime.entities.llm_entities import LLMUsage
 from core.variables import IntegerVariable, NoneSegment
 from core.variables.segments import ArrayAnySegment, ArraySegment
 from core.variables.variables import VariableUnion

@@ -34,6 +35,7 @@ from core.workflow.node_events import (
     NodeRunResult,
     StreamCompletedEvent,
 )
+from core.workflow.nodes.base import LLMUsageTrackingMixin
 from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
 from core.workflow.nodes.base.node import Node
 from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData

@@ -58,7 +60,7 @@ logger = logging.getLogger(__name__)
 EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)


-class IterationNode(Node):
+class IterationNode(LLMUsageTrackingMixin, Node):
     """
     Iteration Node.
     """
@@ -118,6 +120,7 @@ class IterationNode(Node):
         started_at = naive_utc_now()
         iter_run_map: dict[str, float] = {}
         outputs: list[object] = []
+        usage_accumulator = [LLMUsage.empty_usage()]

         yield IterationStartedEvent(
             start_at=started_at,

@@ -130,22 +133,27 @@ class IterationNode(Node):
                 iterator_list_value=iterator_list_value,
                 outputs=outputs,
                 iter_run_map=iter_run_map,
+                usage_accumulator=usage_accumulator,
             )

+            self._accumulate_usage(usage_accumulator[0])
             yield from self._handle_iteration_success(
                 started_at=started_at,
                 inputs=inputs,
                 outputs=outputs,
                 iterator_list_value=iterator_list_value,
                 iter_run_map=iter_run_map,
+                usage=usage_accumulator[0],
             )
         except IterationNodeError as e:
+            self._accumulate_usage(usage_accumulator[0])
             yield from self._handle_iteration_failure(
                 started_at=started_at,
                 inputs=inputs,
                 outputs=outputs,
                 iterator_list_value=iterator_list_value,
                 iter_run_map=iter_run_map,
+                usage=usage_accumulator[0],
                 error=e,
             )

@@ -196,6 +204,7 @@ class IterationNode(Node):
         iterator_list_value: Sequence[object],
         outputs: list[object],
         iter_run_map: dict[str, float],
+        usage_accumulator: list[LLMUsage],
     ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
         if self._node_data.is_parallel:
             # Parallel mode execution

@@ -203,6 +212,7 @@ class IterationNode(Node):
                 iterator_list_value=iterator_list_value,
                 outputs=outputs,
                 iter_run_map=iter_run_map,
+                usage_accumulator=usage_accumulator,
             )
         else:
             # Sequential mode execution

@@ -228,6 +238,9 @@ class IterationNode(Node):

             # Update the total tokens from this iteration
             self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
+            usage_accumulator[0] = self._merge_usage(
+                usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage
+            )
             iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()

     def _execute_parallel_iterations(
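`usage_accumulator` is a one-element list rather than a bare value so that helper generators can mutate index 0 in place and the caller observes the running total without a return-value channel. A minimal sketch of that idiom:

def worker(acc: list[int], amount: int) -> None:
    acc[0] += amount  # visible to the caller through the shared list cell

acc = [0]
for amount in (3, 4, 5):
    worker(acc, amount)
print(acc[0])  # 12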
@@ -235,6 +248,7 @@ class IterationNode(Node):
         iterator_list_value: Sequence[object],
         outputs: list[object],
         iter_run_map: dict[str, float],
+        usage_accumulator: list[LLMUsage],
     ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
         # Initialize outputs list with None values to maintain order
         outputs.extend([None] * len(iterator_list_value))

@@ -245,7 +259,16 @@ class IterationNode(Node):
         with ThreadPoolExecutor(max_workers=max_workers) as executor:
             # Submit all iteration tasks
             future_to_index: dict[
-                Future[tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]],
+                Future[
+                    tuple[
+                        datetime,
+                        list[GraphNodeEventBase],
+                        object | None,
+                        int,
+                        dict[str, VariableUnion],
+                        LLMUsage,
+                    ]
+                ],
                 int,
             ] = {}
             for index, item in enumerate(iterator_list_value):

@@ -264,7 +287,14 @@ class IterationNode(Node):
                 index = future_to_index[future]
                 try:
                     result = future.result()
-                    iter_start_at, events, output_value, tokens_used, conversation_snapshot = result
+                    (
+                        iter_start_at,
+                        events,
+                        output_value,
+                        tokens_used,
+                        conversation_snapshot,
+                        iteration_usage,
+                    ) = result

                     # Update outputs at the correct index
                     outputs[index] = output_value

@@ -276,6 +306,8 @@ class IterationNode(Node):
                     self.graph_runtime_state.total_tokens += tokens_used
                     iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()

+                    usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage)
+
                     # Sync conversation variables after iteration completion
                     self._sync_conversation_variables_from_snapshot(conversation_snapshot)
@@ -303,7 +335,7 @@ class IterationNode(Node):
         item: object,
         flask_app: Flask,
         context_vars: contextvars.Context,
-    ) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]:
+    ) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion], LLMUsage]:
         """Execute a single iteration in parallel mode and return results."""
         with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
             iter_start_at = datetime.now(UTC).replace(tzinfo=None)

@@ -332,6 +364,7 @@ class IterationNode(Node):
                 output_value,
                 graph_engine.graph_runtime_state.total_tokens,
                 conversation_snapshot,
+                graph_engine.graph_runtime_state.llm_usage,
             )

     def _handle_iteration_success(
@@ -341,6 +374,8 @@ class IterationNode(Node):
         outputs: list[object],
         iterator_list_value: Sequence[object],
         iter_run_map: dict[str, float],
+        *,
+        usage: LLMUsage,
     ) -> Generator[NodeEventBase, None, None]:
         # Flatten the list of lists if all outputs are lists
         flattened_outputs = self._flatten_outputs_if_needed(outputs)

@@ -351,7 +386,9 @@ class IterationNode(Node):
             outputs={"output": flattened_outputs},
             steps=len(iterator_list_value),
             metadata={
-                WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
+                WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
+                WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
+                WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
                 WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
             },
         )

@@ -362,8 +399,11 @@ class IterationNode(Node):
                 status=WorkflowNodeExecutionStatus.SUCCEEDED,
                 outputs={"output": flattened_outputs},
                 metadata={
-                    WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
+                    WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
+                    WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
+                    WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
                 },
+                llm_usage=usage,
             )
         )
@@ -400,6 +440,8 @@ class IterationNode(Node):
         outputs: list[object],
         iterator_list_value: Sequence[object],
         iter_run_map: dict[str, float],
+        *,
+        usage: LLMUsage,
         error: IterationNodeError,
     ) -> Generator[NodeEventBase, None, None]:
         # Flatten the list of lists if all outputs are lists (even in failure case)

@@ -411,7 +453,9 @@ class IterationNode(Node):
             outputs={"output": flattened_outputs},
             steps=len(iterator_list_value),
             metadata={
-                WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
+                WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
+                WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
+                WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
                 WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
             },
             error=str(error),

@@ -420,6 +464,12 @@ class IterationNode(Node):
             node_run_result=NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED,
                 error=str(error),
+                metadata={
+                    WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
+                    WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
+                    WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
+                },
+                llm_usage=usage,
             )
         )
@@ -15,14 +15,11 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.agent_entities import PlanningStrategy
 from core.entities.model_entities import ModelStatus
 from core.model_manager import ModelInstance, ModelManager
-from core.model_runtime.entities.message_entities import (
-    PromptMessageRole,
-)
-from core.model_runtime.entities.model_entities import (
-    ModelFeature,
-    ModelType,
-)
+from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities.message_entities import PromptMessageRole
+from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.simple_prompt_transform import ModelMode
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.entities.metadata_entities import Condition, MetadataCondition

@@ -33,8 +30,14 @@ from core.variables import (
 )
 from core.variables.segments import ArrayObjectSegment
 from core.workflow.entities import GraphInitParams
-from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
+from core.workflow.enums import (
+    ErrorStrategy,
+    NodeType,
+    WorkflowNodeExecutionMetadataKey,
+    WorkflowNodeExecutionStatus,
+)
 from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
+from core.workflow.nodes.base import LLMUsageTrackingMixin
 from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
 from core.workflow.nodes.base.node import Node
 from core.workflow.nodes.knowledge_retrieval.template_prompts import (

@@ -80,7 +83,7 @@ default_retrieval_model = {
 }


-class KnowledgeRetrievalNode(Node):
+class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node):
     node_type = NodeType.KNOWLEDGE_RETRIEVAL

     _node_data: KnowledgeRetrievalNodeData

@@ -141,7 +144,7 @@ class KnowledgeRetrievalNode(Node):
     def version(cls):
         return "1"

-    def _run(self) -> NodeRunResult:  # type: ignore
+    def _run(self) -> NodeRunResult:
         # extract variables
         variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
         if not isinstance(variable, StringSegment):
@@ -182,14 +185,21 @@ class KnowledgeRetrievalNode(Node):
        )

        # retrieve knowledge
        usage = LLMUsage.empty_usage()
        try:
            results = self._fetch_dataset_retriever(node_data=self._node_data, query=query)
            results, usage = self._fetch_dataset_retriever(node_data=self._node_data, query=query)
            outputs = {"result": ArrayObjectSegment(value=results)}
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs=variables,
                process_data={},
                process_data={"usage": jsonable_encoder(usage)},
                outputs=outputs,  # type: ignore
                metadata={
                    WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                    WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
                    WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
                },
                llm_usage=usage,
            )

        except KnowledgeRetrievalNodeError as e:

@@ -199,6 +209,7 @@ class KnowledgeRetrievalNode(Node):
                inputs=variables,
                error=str(e),
                error_type=type(e).__name__,
                llm_usage=usage,
            )
        # Temporary handle all exceptions from DatasetRetrieval class here.
        except Exception as e:

@@ -207,11 +218,15 @@ class KnowledgeRetrievalNode(Node):
                inputs=variables,
                error=str(e),
                error_type=type(e).__name__,
                llm_usage=usage,
            )
        finally:
            db.session.close()

    def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]:
    def _fetch_dataset_retriever(
        self, node_data: KnowledgeRetrievalNodeData, query: str
    ) -> tuple[list[dict[str, Any]], LLMUsage]:
        usage = LLMUsage.empty_usage()
        available_datasets = []
        dataset_ids = node_data.dataset_ids

@@ -245,9 +260,10 @@ class KnowledgeRetrievalNode(Node):
            if not dataset:
                continue
            available_datasets.append(dataset)
        metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition(
        metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition(
            [dataset.id for dataset in available_datasets], query, node_data
        )
        usage = self._merge_usage(usage, metadata_usage)
        all_documents = []
        dataset_retrieval = DatasetRetrieval()
        if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:

@@ -330,6 +346,8 @@ class KnowledgeRetrievalNode(Node):
                metadata_filter_document_ids=metadata_filter_document_ids,
                metadata_condition=metadata_condition,
            )
        usage = self._merge_usage(usage, dataset_retrieval.llm_usage)

        dify_documents = [item for item in all_documents if item.provider == "dify"]
        external_documents = [item for item in all_documents if item.provider == "external"]
        retrieval_resource_list = []

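For context, `_merge_usage` comes from the `LLMUsageTrackingMixin` that this commit mixes into the node classes; its implementation is not part of this diff. A minimal sketch of what such a merge could look like, using a stand-in dataclass rather than Dify's actual `LLMUsage` (the field names here mirror only the attributes this diff reads, and the merge semantics are an assumption):

from dataclasses import dataclass
from decimal import Decimal

@dataclass
class Usage:  # stand-in for LLMUsage; not the real class
    total_tokens: int = 0
    total_price: Decimal = Decimal(0)
    currency: str = "USD"

def merge_usage(acc: Usage, delta: Usage) -> Usage:
    # Accumulate token counts and cost across sequential LLM calls;
    # an empty delta leaves the accumulator unchanged.
    if delta.total_tokens == 0:
        return acc
    return Usage(
        total_tokens=acc.total_tokens + delta.total_tokens,
        total_price=acc.total_price + delta.total_price,
        currency=delta.currency or acc.currency,
    )

With something of this shape, the repeated `usage = self._merge_usage(usage, ...)` calls fold metadata-filter, retrieval, and sub-graph usage into one running total.
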
@@ -406,11 +424,12 @@ class KnowledgeRetrievalNode(Node):
        )
        for position, item in enumerate(retrieval_resource_list, start=1):
            item["metadata"]["position"] = position
        return retrieval_resource_list
        return retrieval_resource_list, usage

    def _get_metadata_filter_condition(
        self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
    ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]:
    ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]:
        usage = LLMUsage.empty_usage()
        document_query = db.session.query(Document).where(
            Document.dataset_id.in_(dataset_ids),
            Document.indexing_status == "completed",

@@ -420,9 +439,12 @@ class KnowledgeRetrievalNode(Node):
        filters: list[Any] = []
        metadata_condition = None
        if node_data.metadata_filtering_mode == "disabled":
            return None, None
            return None, None, usage
        elif node_data.metadata_filtering_mode == "automatic":
            automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data)
            automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
                dataset_ids, query, node_data
            )
            usage = self._merge_usage(usage, automatic_usage)
            if automatic_metadata_filters:
                conditions = []
                for sequence, filter in enumerate(automatic_metadata_filters):

@@ -443,7 +465,7 @@ class KnowledgeRetrievalNode(Node):
                metadata_condition = MetadataCondition(
                    logical_operator=node_data.metadata_filtering_conditions.logical_operator
                    if node_data.metadata_filtering_conditions
                    else "or",  # type: ignore
                    else "or",
                    conditions=conditions,
                )
        elif node_data.metadata_filtering_mode == "manual":

@@ -457,10 +479,10 @@ class KnowledgeRetrievalNode(Node):
                    expected_value = self.graph_runtime_state.variable_pool.convert_template(
                        expected_value
                    ).value[0]
                    if expected_value.value_type in {"number", "integer", "float"}:  # type: ignore
                        expected_value = expected_value.value  # type: ignore
                    elif expected_value.value_type == "string":  # type: ignore
                        expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()  # type: ignore
                    if expected_value.value_type in {"number", "integer", "float"}:
                        expected_value = expected_value.value
                    elif expected_value.value_type == "string":
                        expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
                    else:
                        raise ValueError("Invalid expected metadata value type")
                    conditions.append(

@@ -487,7 +509,7 @@ class KnowledgeRetrievalNode(Node):
            if (
                node_data.metadata_filtering_conditions
                and node_data.metadata_filtering_conditions.logical_operator == "and"
            ):  # type: ignore
            ):
                document_query = document_query.where(and_(*filters))
            else:
                document_query = document_query.where(or_(*filters))

@@ -496,11 +518,12 @@ class KnowledgeRetrievalNode(Node):
        metadata_filter_document_ids = defaultdict(list) if documents else None  # type: ignore
        for document in documents:
            metadata_filter_document_ids[document.dataset_id].append(document.id)  # type: ignore
        return metadata_filter_document_ids, metadata_condition
        return metadata_filter_document_ids, metadata_condition, usage

    def _automatic_metadata_filter_func(
        self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
    ) -> list[dict[str, Any]]:
    ) -> tuple[list[dict[str, Any]], LLMUsage]:
        usage = LLMUsage.empty_usage()
        # get all metadata field
        stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
        metadata_fields = db.session.scalars(stmt).all()

@@ -548,6 +571,7 @@ class KnowledgeRetrievalNode(Node):
        for event in generator:
            if isinstance(event, ModelInvokeCompletedEvent):
                result_text = event.text
                usage = self._merge_usage(usage, event.usage)
                break

        result_text_json = parse_and_check_json_markdown(result_text, [])

@@ -564,8 +588,8 @@ class KnowledgeRetrievalNode(Node):
                        }
                    )
        except Exception:
            return []
        return automatic_metadata_filters
            return [], usage
        return automatic_metadata_filters, usage

    def _process_metadata_filter_func(
        self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]

@@ -452,10 +452,14 @@ class LLMNode(Node):
        usage = LLMUsage.empty_usage()
        finish_reason = None
        full_text_buffer = io.StringIO()
        collected_structured_output = None  # Collect structured_output from streaming chunks
        # Consume the invoke result and handle generator exception
        try:
            for result in invoke_result:
                if isinstance(result, LLMResultChunkWithStructuredOutput):
                    # Collect structured_output from the chunk
                    if result.structured_output is not None:
                        collected_structured_output = dict(result.structured_output)
                yield result
                if isinstance(result, LLMResultChunk):
                    contents = result.delta.message.content

@@ -503,6 +507,8 @@ class LLMNode(Node):
            finish_reason=finish_reason,
            # Reasoning content for workflow variables and downstream nodes
            reasoning_content=reasoning_content,
            # Pass structured output if collected from streaming chunks
            structured_output=collected_structured_output,
        )

    @staticmethod

@@ -5,6 +5,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, cast

from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import Segment, SegmentType
from core.workflow.enums import (
    ErrorStrategy,

@@ -27,6 +28,7 @@ from core.workflow.node_events import (
    NodeRunResult,
    StreamCompletedEvent,
)
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData

@@ -40,7 +42,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)


class LoopNode(Node):
class LoopNode(LLMUsageTrackingMixin, Node):
    """
    Loop Node.
    """

@@ -108,7 +110,7 @@ class LoopNode(Node):
                raise ValueError(f"Invalid value for loop variable {loop_variable.label}")
            variable_selector = [self._node_id, loop_variable.label]
            variable = segment_to_variable(segment=processed_segment, selector=variable_selector)
            self.graph_runtime_state.variable_pool.add(variable_selector, variable)
            self.graph_runtime_state.variable_pool.add(variable_selector, variable.value)
            loop_variable_selectors[loop_variable.label] = variable_selector
            inputs[loop_variable.label] = processed_segment.value

@@ -117,6 +119,7 @@ class LoopNode(Node):

        loop_duration_map: dict[str, float] = {}
        single_loop_variable_map: dict[str, dict[str, Any]] = {}  # single loop variable output
        loop_usage = LLMUsage.empty_usage()

        # Start Loop event
        yield LoopStartedEvent(

@@ -163,6 +166,9 @@ class LoopNode(Node):
                # Update the total tokens from this iteration
                cost_tokens += graph_engine.graph_runtime_state.total_tokens

                # Accumulate usage from the sub-graph execution
                loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage)

                # Collect loop variable values after iteration
                single_loop_variable = {}
                for key, selector in loop_variable_selectors.items():

@@ -189,6 +195,7 @@ class LoopNode(Node):
            )

            self.graph_runtime_state.total_tokens += cost_tokens
            self._accumulate_usage(loop_usage)
            # Loop completed successfully
            yield LoopSucceededEvent(
                start_at=start_at,

@@ -196,7 +203,9 @@ class LoopNode(Node):
                outputs=self._node_data.outputs,
                steps=loop_count,
                metadata={
                    WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: cost_tokens,
                    WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
                    WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
                    WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
                    "completed_reason": "loop_break" if reach_break_condition else "loop_completed",
                    WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
                    WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,

@@ -207,22 +216,28 @@ class LoopNode(Node):
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
                    metadata={
                        WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
                        WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
                        WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
                        WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
                        WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
                        WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
                    },
                    outputs=self._node_data.outputs,
                    inputs=inputs,
                    llm_usage=loop_usage,
                )
            )

        except Exception as e:
            self._accumulate_usage(loop_usage)
            yield LoopFailedEvent(
                start_at=start_at,
                inputs=inputs,
                steps=loop_count,
                metadata={
                    WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
                    WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
                    WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
                    WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
                    "completed_reason": "error",
                    WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
                    WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,

@@ -235,10 +250,13 @@ class LoopNode(Node):
                    status=WorkflowNodeExecutionStatus.FAILED,
                    error=str(e),
                    metadata={
                        WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
                        WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
                        WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
                        WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
                        WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
                        WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
                    },
                    llm_usage=loop_usage,
                )
            )

@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, final

from typing_extensions import override

from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.enums import NodeType
from core.workflow.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from libs.typing import is_str, is_str_dict

@@ -82,8 +82,4 @@ class DifyNodeFactory(NodeFactory):
            raise ValueError(f"Node {node_id} missing data information")
        node_instance.init_node_data(node_data)

        # If node has fail branch, change execution type to branch
        if node_instance.error_strategy == ErrorStrategy.FAIL_BRANCH:
            node_instance.execution_type = NodeExecutionType.BRANCH

        return node_instance

@@ -747,7 +747,7 @@ class ParameterExtractorNode(Node):
        if model_mode == ModelMode.CHAT:
            system_prompt_messages = ChatModelMessage(
                role=PromptMessageRole.SYSTEM,
                text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str).replace("{{instructions}}", instruction),
                text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction),
            )
            user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
            return [system_prompt_messages, user_prompt_message]

@@ -135,7 +135,7 @@ Here are the chat histories between human and assistant, inside <histories></histories
### Instructions:
Some extra information are provided below, you should always follow the instructions as possible as you can.
<instructions>
{{instructions}}
{instructions}
</instructions>
"""

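The two hunks above appear to fix a template/substitution mismatch: with `str.format`, doubled braces are an escape, so a `{{instructions}}` placeholder survives `.format(histories=...)` as literal `{instructions}` and the later `.replace("{{instructions}}", ...)` never matches anything. A quick demonstration of the brace semantics:

double = "<instructions>\n{{instructions}}\n</instructions>"
single = "<instructions>\n{instructions}\n</instructions>"

# Doubled braces are an escape: the placeholder is NOT substituted.
print(double.format(instructions="extract the user's name"))
# -> {instructions}

# Single braces are a real placeholder.
print(single.format(instructions="extract the user's name"))
# -> extract the user's name

Switching the template to single braces and passing `instructions=` through `.format()` makes both substitutions happen in one call.
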
@@ -6,10 +6,13 @@ from sqlalchemy.orm import Session

from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file import File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.tools.workflow_as_tool.tool import WorkflowTool
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.enums import (

@@ -136,13 +139,14 @@ class ToolNode(Node):

        try:
            # convert tool messages
            yield from self._transform_message(
            _ = yield from self._transform_message(
                messages=message_stream,
                tool_info=tool_info,
                parameters_for_log=parameters_for_log,
                user_id=self.user_id,
                tenant_id=self.tenant_id,
                node_id=self._node_id,
                tool_runtime=tool_runtime,
            )
        except ToolInvokeError as e:
            yield StreamCompletedEvent(

@@ -236,7 +240,8 @@ class ToolNode(Node):
        user_id: str,
        tenant_id: str,
        node_id: str,
    ) -> Generator:
        tool_runtime: Tool,
    ) -> Generator[NodeEventBase, None, LLMUsage]:
        """
        Convert ToolInvokeMessages into tuple[plain_text, files]
        """

@@ -424,17 +429,34 @@ class ToolNode(Node):
            is_final=True,
        )

        usage = self._extract_tool_usage(tool_runtime)

        metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
            WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
        }
        if usage.total_tokens > 0:
            metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
            metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
            metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency

        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
                metadata={
                    WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
                },
                metadata=metadata,
                inputs=parameters_for_log,
                llm_usage=usage,
            )
        )

        return usage

    @staticmethod
    def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage:
        if isinstance(tool_runtime, WorkflowTool):
            return tool_runtime.latest_usage
        return LLMUsage.empty_usage()

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,

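The `_ = yield from self._transform_message(...)` change above works because `_transform_message` now ends with `return usage`: in Python, a generator's `return` value rides on `StopIteration` and is captured by `yield from` in the delegating generator. A self-contained illustration of that mechanism:

from collections.abc import Generator

def produce() -> Generator[str, None, int]:
    yield "chunk-1"
    yield "chunk-2"
    return 42  # carried by StopIteration, never yielded to the consumer

def consume() -> Generator[str, None, None]:
    result = yield from produce()  # re-yields every chunk, then captures 42
    yield f"return value was {result}"

print(list(consume()))  # ['chunk-1', 'chunk-2', 'return value was 42']

So the node can keep streaming events to the engine while the final usage figure flows back to the caller through the same generator.
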
@@ -277,7 +277,7 @@ class VariablePool(BaseModel):
            # This ensures that we can keep the id of the system variables intact.
            if self._has(selector):
                continue
            self.add(selector, value)  # type: ignore
            self.add(selector, value)

    @classmethod
    def empty(cls) -> "VariablePool":

@@ -32,7 +32,8 @@ if [[ "${MODE}" == "worker" ]]; then

  exec celery -A celery_entrypoint.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \
    --max-tasks-per-child ${MAX_TASKS_PER_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \
    -Q ${CELERY_QUEUES:-dataset,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation}
    -Q ${CELERY_QUEUES:-dataset,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation} \
    --prefetch-multiplier=1

elif [[ "${MODE}" == "beat" ]]; then
  exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}

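For reference, `--prefetch-multiplier=1` limits each worker to reserving a single task at a time instead of Celery's default of four per concurrency slot, which keeps long-running tasks from being hoarded by one busy worker at the cost of extra broker round-trips. A sketch of the equivalent app-level setting (the entrypoint above configures it via the CLI instead; the app name here is illustrative):

from celery import Celery

app = Celery("dify")
# Equivalent to the CLI flag --prefetch-multiplier=1: each worker
# process prefetches one task at a time from the broker.
app.conf.worker_prefetch_multiplier = 1
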
@@ -6,8 +6,8 @@ from tasks.clean_dataset_task import clean_dataset_task
@dataset_was_deleted.connect
def handle(sender: Dataset, **kwargs):
    dataset = sender
    assert dataset.doc_form
    assert dataset.indexing_technique
    if not dataset.doc_form or not dataset.indexing_technique:
        return
    clean_dataset_task.delay(
        dataset.id,
        dataset.tenant_id,

@@ -8,6 +8,6 @@ def handle(sender, **kwargs):
    dataset_id = kwargs.get("dataset_id")
    doc_form = kwargs.get("doc_form")
    file_id = kwargs.get("file_id")
    assert dataset_id is not None
    assert doc_form is not None
    if not dataset_id or not doc_form:
        return
    clean_document_task.delay(document_id, dataset_id, doc_form, file_id)

@@ -2,6 +2,11 @@ from configs import dify_config
from constants import HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN, HEADER_NAME_PASSPORT
from dify_app import DifyApp

BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEADER_NAME_PASSPORT)
SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization")
AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN)
FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)


def init_app(app: DifyApp):
    # register blueprint routers

@@ -17,7 +22,7 @@ def init_app(app: DifyApp):

    CORS(
        service_api_bp,
        allow_headers=["Content-Type", "Authorization", HEADER_NAME_APP_CODE],
        allow_headers=list(SERVICE_API_HEADERS),
        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
    )
    app.register_blueprint(service_api_bp)

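The four header tuples compose by splatting the previous tuple, so each surface adds only what it needs on top of the base set. Evaluated with the header values these constants carry:

HEADER_NAME_APP_CODE = "X-App-Code"      # values as defined in constants
HEADER_NAME_PASSPORT = "X-App-Passport"
HEADER_NAME_CSRF_TOKEN = "X-CSRF-Token"

BASE_CORS_HEADERS = ("Content-Type", HEADER_NAME_APP_CODE, HEADER_NAME_PASSPORT)
SERVICE_API_HEADERS = (*BASE_CORS_HEADERS, "Authorization")
AUTHENTICATED_HEADERS = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN)
FILES_HEADERS = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)

# The service API adds Authorization; the console additionally requires
# the CSRF header; the files blueprint needs CSRF but not Authorization.
print(AUTHENTICATED_HEADERS)
# ('Content-Type', 'X-App-Code', 'X-App-Passport', 'Authorization', 'X-CSRF-Token')
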
@@ -27,11 +32,11 @@ def init_app(app: DifyApp):
        resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
        supports_credentials=True,
        allow_headers=[
            "Content-Type",
            "Authorization",
            HEADER_NAME_APP_CODE,
            HEADER_NAME_CSRF_TOKEN,
            HEADER_NAME_PASSPORT
            "Content-Type",
            "Authorization",
            HEADER_NAME_APP_CODE,
            HEADER_NAME_CSRF_TOKEN,
            HEADER_NAME_PASSPORT,
        ],
        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
        expose_headers=["X-Version", "X-Env"],

@@ -42,7 +47,7 @@ def init_app(app: DifyApp):
        console_app_bp,
        resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
        supports_credentials=True,
        allow_headers=["Content-Type", "Authorization", HEADER_NAME_CSRF_TOKEN],
        allow_headers=list(AUTHENTICATED_HEADERS),
        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
        expose_headers=["X-Version", "X-Env"],
    )

@@ -50,7 +55,7 @@ def init_app(app: DifyApp):

    CORS(
        files_bp,
        allow_headers=["Content-Type", HEADER_NAME_CSRF_TOKEN],
        allow_headers=list(FILES_HEADERS),
        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
    )
    app.register_blueprint(files_bp)

@@ -7,7 +7,7 @@ def is_enabled() -> bool:


def init_app(app: DifyApp):
    from flask_compress import Compress  # type: ignore
    from flask_compress import Compress

    compress = Compress()
    compress.init_app(app)

@@ -1,6 +1,6 @@
import json

import flask_login  # type: ignore
import flask_login
from flask import Response, request
from flask_login import user_loaded_from_request, user_logged_in
from werkzeug.exceptions import NotFound, Unauthorized

@@ -2,7 +2,7 @@ from dify_app import DifyApp


def init_app(app: DifyApp):
    import flask_migrate  # type: ignore
    import flask_migrate

    from extensions.ext_database import db

@@ -103,7 +103,7 @@ def init_app(app: DifyApp):
    def shutdown_tracer():
        provider = trace.get_tracer_provider()
        if hasattr(provider, "force_flush"):
            provider.force_flush()  # ty: ignore [call-non-callable]
            provider.force_flush()

    class ExceptionLoggingHandler(logging.Handler):
        """Custom logging handler that creates spans for logging.exception() calls"""

@@ -6,4 +6,4 @@ def init_app(app: DifyApp):
    if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED:
        from werkzeug.middleware.proxy_fix import ProxyFix

        app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1)  # type: ignore
        app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1)  # type: ignore[method-assign]

@@ -5,7 +5,7 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
    if dify_config.SENTRY_DSN:
        import sentry_sdk
        from langfuse import parse_error  # type: ignore
        from langfuse import parse_error
        from sentry_sdk.integrations.celery import CeleryIntegration
        from sentry_sdk.integrations.flask import FlaskIntegration
        from werkzeug.exceptions import HTTPException

@@ -1,7 +1,7 @@
import posixpath
from collections.abc import Generator

import oss2 as aliyun_s3  # type: ignore
import oss2 as aliyun_s3

from configs import dify_config
from extensions.storage.base_storage import BaseStorage

@@ -2,9 +2,9 @@ import base64
import hashlib
from collections.abc import Generator

from baidubce.auth.bce_credentials import BceCredentials  # type: ignore
from baidubce.bce_client_configuration import BceClientConfiguration  # type: ignore
from baidubce.services.bos.bos_client import BosClient  # type: ignore
from baidubce.auth.bce_credentials import BceCredentials
from baidubce.bce_client_configuration import BceClientConfiguration
from baidubce.services.bos.bos_client import BosClient

from configs import dify_config
from extensions.storage.base_storage import BaseStorage

@@ -11,7 +11,7 @@ from collections.abc import Generator
from io import BytesIO
from pathlib import Path

import clickzetta  # type: ignore[import]
import clickzetta
from pydantic import BaseModel, model_validator

from extensions.storage.base_storage import BaseStorage

@@ -34,7 +34,7 @@ class VolumePermissionManager:
        # Support two initialization methods: connection object or configuration dictionary
        if isinstance(connection_or_config, dict):
            # Create connection from configuration dictionary
            import clickzetta  # type: ignore[import-untyped]
            import clickzetta

            config = connection_or_config
            self._connection = clickzetta.connect(

@@ -3,7 +3,7 @@ import io
import json
from collections.abc import Generator

from google.cloud import storage as google_cloud_storage  # type: ignore
from google.cloud import storage as google_cloud_storage

from configs import dify_config
from extensions.storage.base_storage import BaseStorage

@@ -1,6 +1,6 @@
from collections.abc import Generator

from obs import ObsClient  # type: ignore
from obs import ObsClient

from configs import dify_config
from extensions.storage.base_storage import BaseStorage

@@ -1,7 +1,7 @@
from collections.abc import Generator

import boto3  # type: ignore
from botocore.exceptions import ClientError  # type: ignore
import boto3
from botocore.exceptions import ClientError

from configs import dify_config
from extensions.storage.base_storage import BaseStorage

@@ -1,6 +1,6 @@
from collections.abc import Generator

from qcloud_cos import CosConfig, CosS3Client  # type: ignore
from qcloud_cos import CosConfig, CosS3Client

from configs import dify_config
from extensions.storage.base_storage import BaseStorage

@@ -1,6 +1,6 @@
from collections.abc import Generator

import tos  # type: ignore
import tos

from configs import dify_config
from extensions.storage.base_storage import BaseStorage

@@ -146,6 +146,6 @@ class ExternalApi(Api):
        kwargs["doc"] = dify_config.SWAGGER_UI_PATH if dify_config.SWAGGER_UI_ENABLED else False

        # manual separate call on construction and init_app to ensure configs in kwargs effective
        super().__init__(app=None, *args, **kwargs)  # type: ignore
        super().__init__(app=None, *args, **kwargs)
        self.init_app(app, **kwargs)
        register_external_error_handlers(self)

@@ -23,7 +23,7 @@ from hashlib import sha1

import Crypto.Hash.SHA1
import Crypto.Util.number
import gmpy2  # type: ignore
import gmpy2
from Crypto import Random
from Crypto.Signature.pss import MGF1
from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes

@@ -136,7 +136,7 @@ class PKCS1OAepCipher:
        # Step 3a (OS2IP)
        em_int = bytes_to_long(em)
        # Step 3b (RSAEP)
        m_int = gmpy2.powmod(em_int, self._key.e, self._key.n)  # ty: ignore [unresolved-attribute]
        m_int = gmpy2.powmod(em_int, self._key.e, self._key.n)
        # Step 3c (I2OSP)
        c = long_to_bytes(m_int, k)
        return c

@@ -169,7 +169,7 @@ class PKCS1OAepCipher:
        ct_int = bytes_to_long(ciphertext)
        # Step 2b (RSADP)
        # m_int = self._key._decrypt(ct_int)
        m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n)  # ty: ignore [unresolved-attribute]
        m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n)
        # Complete step 2c (I2OSP)
        em = long_to_bytes(m_int, k)
        # Step 3a

@@ -191,12 +191,12 @@ class PKCS1OAepCipher:
        # Step 3g
        one_pos = hLen + db[hLen:].find(b"\x01")
        lHash1 = db[:hLen]
        invalid = bord(y) | int(one_pos < hLen)  # type: ignore
        invalid = bord(y) | int(one_pos < hLen)  # type: ignore[arg-type]
        hash_compare = strxor(lHash1, lHash)
        for x in hash_compare:
            invalid |= bord(x)  # type: ignore
            invalid |= bord(x)  # type: ignore[arg-type]
        for x in db[hLen:one_pos]:
            invalid |= bord(x)  # type: ignore
            invalid |= bord(x)  # type: ignore[arg-type]
        if invalid != 0:
            raise ValueError("Incorrect decryption.")
        # Step 4

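The `gmpy2.powmod` calls above are plain modular exponentiation, i.e. the RSAEP/RSADP primitives the comments name. A toy round trip with textbook-sized numbers (nowhere near real key sizes, purely to show the symmetry of the two steps):

import gmpy2

p, q, e = 61, 53, 17
n = p * q                    # 3233
phi = (p - 1) * (q - 1)      # 3120
d = pow(e, -1, phi)          # 2753, modular inverse of e (Python 3.8+)

m = 65
c = int(gmpy2.powmod(m, e, n))           # RSAEP: c = m^e mod n
assert int(gmpy2.powmod(c, d, n)) == m   # RSADP: m = c^d mod n
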
@@ -81,6 +81,8 @@ class AvatarUrlField(fields.Raw):
        from models import Account

        if isinstance(obj, Account) and obj.avatar is not None:
            if obj.avatar.startswith(("http://", "https://")):
                return obj.avatar
            return file_helpers.get_signed_file_url(obj.avatar)
        return None

@@ -3,7 +3,7 @@ from functools import wraps
from typing import Any

from flask import current_app, g, has_request_context, request
from flask_login.config import EXEMPT_METHODS  # type: ignore
from flask_login.config import EXEMPT_METHODS
from werkzeug.local import LocalProxy

from configs import dify_config

@@ -87,7 +87,7 @@ def _get_user() -> EndUser | Account | None:
        if "_login_user" not in g:
            current_app.login_manager._load_user()  # type: ignore

        return g._login_user  # type: ignore
        return g._login_user

    return None

@@ -1,8 +1,8 @@
import logging

import sendgrid  # type: ignore
import sendgrid
from python_http_client.exceptions import ForbiddenError, UnauthorizedError
from sendgrid.helpers.mail import Content, Email, Mail, To  # type: ignore
from sendgrid.helpers.mail import Content, Email, Mail, To

logger = logging.getLogger(__name__)

@@ -12,6 +12,7 @@ from constants import (
    COOKIE_NAME_CSRF_TOKEN,
    COOKIE_NAME_PASSPORT,
    COOKIE_NAME_REFRESH_TOKEN,
    COOKIE_NAME_WEBAPP_ACCESS_TOKEN,
    HEADER_NAME_CSRF_TOKEN,
    HEADER_NAME_PASSPORT,
)

@@ -81,6 +82,14 @@ def extract_access_token(request: Request) -> str | None:
    return _try_extract_from_cookie(request) or _try_extract_from_header(request)


def extract_webapp_access_token(request: Request) -> str | None:
    """
    Try to extract webapp access token from cookie, then header.
    """

    return request.cookies.get(_real_cookie_name(COOKIE_NAME_WEBAPP_ACCESS_TOKEN)) or _try_extract_from_header(request)


def extract_webapp_passport(app_code: str, request: Request) -> str | None:
    """
    Try to extract app token from header or params.

@@ -155,6 +164,10 @@ def clear_access_token_from_cookie(response: Response, samesite: str = "Lax"):
    _clear_cookie(response, COOKIE_NAME_ACCESS_TOKEN, samesite)


def clear_webapp_access_token_from_cookie(response: Response, samesite: str = "Lax"):
    _clear_cookie(response, COOKIE_NAME_WEBAPP_ACCESS_TOKEN, samesite)


def clear_refresh_token_from_cookie(response: Response):
    _clear_cookie(response, COOKIE_NAME_REFRESH_TOKEN)

@@ -0,0 +1,36 @@
"""remove-builtin-template-user

Revision ID: ae662b25d9bc
Revises: d98acf217d43
Create Date: 2025-10-21 14:30:28.566192

"""
from alembic import op
import models as models
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = 'ae662b25d9bc'
down_revision = 'd98acf217d43'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###

    with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
        batch_op.drop_column('updated_by')
        batch_op.drop_column('created_by')

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
        batch_op.add_column(sa.Column('created_by', sa.UUID(), autoincrement=False, nullable=False))
        batch_op.add_column(sa.Column('updated_by', sa.UUID(), autoincrement=False, nullable=True))

    # ### end Alembic commands ###

@@ -5,7 +5,7 @@ from datetime import datetime
from typing import Any, Optional

import sqlalchemy as sa
from flask_login import UserMixin  # type: ignore[import-untyped]
from flask_login import UserMixin
from sqlalchemy import DateTime, String, func, select
from sqlalchemy.orm import Mapped, Session, mapped_column
from typing_extensions import deprecated

@@ -1239,15 +1239,6 @@ class PipelineBuiltInTemplate(Base):  # type: ignore[name-defined]
    language = mapped_column(db.String(255), nullable=False)
    created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
    created_by = mapped_column(StringUUID, nullable=False)
    updated_by = mapped_column(StringUUID, nullable=True)

    @property
    def created_user_name(self):
        account = db.session.query(Account).where(Account.id == self.created_by).first()
        if account:
            return account.name
        return ""


class PipelineCustomizedTemplate(Base):  # type: ignore[name-defined]

@@ -219,7 +219,7 @@ class WorkflowToolProvider(TypeBase):
        sa.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"),
    )

    id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
    id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
    # name of the workflow provider
    name: Mapped[str] = mapped_column(String(255), nullable=False)
    # label of the workflow provider

@@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.9.1"
version = "1.9.2"
requires-python = ">=3.11,<3.13"

dependencies = [

@@ -16,7 +16,25 @@
    "opentelemetry.instrumentation.requests",
    "opentelemetry.instrumentation.sqlalchemy",
    "opentelemetry.instrumentation.redis",
    "opentelemetry.instrumentation.httpx"
    "langfuse",
    "cloudscraper",
    "readabilipy",
    "pypandoc",
    "pypdfium2",
    "webvtt",
    "flask_compress",
    "oss2",
    "baidubce.auth.bce_credentials",
    "baidubce.bce_client_configuration",
    "baidubce.services.bos.bos_client",
    "clickzetta",
    "google.cloud",
    "obs",
    "qcloud_cos",
    "tos",
    "gmpy2",
    "sendgrid",
    "sendgrid.helpers.mail"
],
"reportUnknownMemberType": "hint",
"reportUnknownParameterType": "hint",

@@ -28,7 +46,7 @@
"reportUnnecessaryComparison": "hint",
"reportUnnecessaryIsInstance": "hint",
"reportUntypedFunctionDecorator": "hint",

"reportUnnecessaryTypeIgnoreComment": "hint",
"reportAttributeAccessIssue": "hint",
"pythonVersion": "3.11",
"pythonPlatform": "All"
Some files were not shown because too many files have changed in this diff.