mirror of https://github.com/langgenius/dify.git (synced 2026-04-29 04:26:30 +08:00)

commit e7a575a33c
Merge branch 'main' into feat/mcp-06-18
@@ -434,6 +434,9 @@ CODE_EXECUTION_SSL_VERIFY=True
 CODE_EXECUTION_POOL_MAX_CONNECTIONS=100
 CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS=20
 CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY=5.0
+CODE_EXECUTION_CONNECT_TIMEOUT=10
+CODE_EXECUTION_READ_TIMEOUT=60
+CODE_EXECUTION_WRITE_TIMEOUT=10
 CODE_MAX_NUMBER=9223372036854775807
 CODE_MIN_NUMBER=-9223372036854775808
 CODE_MAX_STRING_LENGTH=400000
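The three new variables split the code-execution HTTP timeout by phase (connect/read/write) alongside the existing pool limits. A minimal sketch of how settings of this shape typically map onto an httpx client; only the env keys and their defaults come from the diff, the client wiring itself is an assumption for illustration:

import httpx

limits = httpx.Limits(
    max_connections=100,           # CODE_EXECUTION_POOL_MAX_CONNECTIONS
    max_keepalive_connections=20,  # CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS
    keepalive_expiry=5.0,          # CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY
)
timeout = httpx.Timeout(
    connect=10.0,  # CODE_EXECUTION_CONNECT_TIMEOUT
    read=60.0,     # CODE_EXECUTION_READ_TIMEOUT
    write=10.0,    # CODE_EXECUTION_WRITE_TIMEOUT
    pool=10.0,     # assumed: the diff adds no separate pool-timeout variable
)
client = httpx.Client(limits=limits, timeout=timeout)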
@@ -548,7 +548,7 @@ class UpdateConfig(BaseSettings):


 class WorkflowVariableTruncationConfig(BaseSettings):
     WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE: PositiveInt = Field(
-        # 100KB
+        # 1000 KiB
         1024_000,
         description="Maximum size for variable to trigger final truncation.",
     )
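The comment fix matters: 1024_000 bytes is 1000 KiB (roughly 1 MB), not 100 KB. As a toy illustration of how such a cap is applied; the helper below is hypothetical, only the constant comes from the diff:

# Hypothetical sketch, not Dify's actual truncation logic.
WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE = 1024_000  # 1000 KiB

def maybe_truncate(serialized: bytes) -> bytes:
    # Variables above the threshold are cut down before final storage.
    if len(serialized) > WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE:
        return serialized[:WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE]
    return serialized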
@@ -145,7 +145,7 @@ class DatabaseConfig(BaseSettings):
         default="postgresql",
     )

-    @computed_field  # type: ignore[misc]
+    @computed_field  # type: ignore[prop-decorator]
     @property
     def SQLALCHEMY_DATABASE_URI(self) -> str:
         db_extras = (

@@ -198,7 +198,7 @@ class DatabaseConfig(BaseSettings):
         default=os.cpu_count() or 1,
     )

-    @computed_field  # type: ignore[misc]
+    @computed_field  # type: ignore[prop-decorator]
     @property
     def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
         # Parse DB_EXTRAS for 'options'
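Both hunks swap the blanket `[misc]` ignore for the specific error code that newer mypy releases (1.11+) emit when `@computed_field` is stacked on `@property`. A minimal reproduction of the pattern; class and field names here are illustrative, not from the diff:

from pydantic import computed_field
from pydantic_settings import BaseSettings

class ExampleConfig(BaseSettings):
    DB_HOST: str = "localhost"
    DB_PORT: int = 5432

    @computed_field  # type: ignore[prop-decorator]
    @property
    def DATABASE_URI(self) -> str:
        # Derived value; mypy flags the decorator stacking as
        # "decorated property not supported", hence the targeted ignore.
        return f"postgresql://{self.DB_HOST}:{self.DB_PORT}"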
@@ -56,11 +56,15 @@ else:
 }
 DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)

+# console
 COOKIE_NAME_ACCESS_TOKEN = "access_token"
 COOKIE_NAME_REFRESH_TOKEN = "refresh_token"
-COOKIE_NAME_PASSPORT = "passport"
 COOKIE_NAME_CSRF_TOKEN = "csrf_token"

+# webapp
+COOKIE_NAME_WEBAPP_ACCESS_TOKEN = "webapp_access_token"
+COOKIE_NAME_PASSPORT = "passport"
+
 HEADER_NAME_CSRF_TOKEN = "X-CSRF-Token"
 HEADER_NAME_APP_CODE = "X-App-Code"
 HEADER_NAME_PASSPORT = "X-App-Passport"

@@ -31,3 +31,9 @@ def supported_language(lang):

    error = f"{lang} is not a valid language."
    raise ValueError(error)
+
+
+def get_valid_language(lang: str | None) -> str:
+    if lang and lang in languages:
+        return lang
+    return languages[0]
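Expected behavior of the new helper, sketched as usage. The concrete fallback depends on the `languages` list, which this hunk does not show; "en-US" is assumed from Dify's defaults:

get_valid_language("zh-Hans")  # -> "zh-Hans" (passes the membership check)
get_valid_language("xx-XX")    # -> languages[0] (unknown code falls back)
get_valid_language(None)       # -> languages[0] (missing value falls back)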
@@ -24,7 +24,7 @@ except ImportError:
     )
 else:
     warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2)
     magic = None  # type: ignore[assignment]

 from pydantic import BaseModel

@@ -4,7 +4,7 @@ from flask_restx import Resource, reqparse

 import services
 from configs import dify_config
-from constants.languages import languages
+from constants.languages import get_valid_language
 from controllers.console import console_ns
 from controllers.console.auth.error import (
     AuthenticationFailedError,

@@ -29,8 +29,6 @@ from libs.token import (
     clear_access_token_from_cookie,
     clear_csrf_token_from_cookie,
     clear_refresh_token_from_cookie,
-    extract_access_token,
-    extract_csrf_token,
     set_access_token_to_cookie,
     set_csrf_token_to_cookie,
     set_refresh_token_to_cookie,

@@ -206,10 +204,12 @@ class EmailCodeLoginApi(Resource):
             .add_argument("email", type=str, required=True, location="json")
             .add_argument("code", type=str, required=True, location="json")
             .add_argument("token", type=str, required=True, location="json")
+            .add_argument("language", type=str, required=False, location="json")
         )
         args = parser.parse_args()

         user_email = args["email"]
+        language = args["language"]

         token_data = AccountService.get_email_code_login_data(args["token"])
         if token_data is None:

@@ -243,7 +243,9 @@ class EmailCodeLoginApi(Resource):
         if account is None:
             try:
                 account = AccountService.create_account_and_tenant(
-                    email=user_email, name=user_email, interface_language=languages[0]
+                    email=user_email,
+                    name=user_email,
+                    interface_language=get_valid_language(language),
                 )
             except WorkSpaceNotAllowedCreateError:
                 raise NotAllowedCreateWorkspace()
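With the optional `language` argument wired through to `create_account_and_tenant`, a code-login request can now carry the user's UI locale. An illustrative JSON body; all values are invented:

{
    "email": "user@example.com",
    "code": "123456",
    "token": "<token from the send-code step>",
    "language": "zh-Hans"
}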
@@ -286,13 +288,3 @@ class RefreshTokenApi(Resource):
             return response
         except Exception as e:
             return {"result": "fail", "message": str(e)}, 401
-
-
-# this api helps frontend to check whether user is authenticated
-# TODO: remove in the future. frontend should redirect to login page by catching 401 status
-@console_ns.route("/login/status")
-class LoginStatus(Resource):
-    def get(self):
-        token = extract_access_token(request)
-        csrf_token = extract_csrf_token(request)
-        return {"logged_in": bool(token) and bool(csrf_token)}

@@ -22,7 +22,7 @@ from core.errors.error import (
 from core.model_runtime.errors.invoke import InvokeError
 from core.workflow.graph_engine.manager import GraphEngineManager
 from libs import helper
-from libs.login import current_user as current_user_
+from libs.login import current_account_with_tenant
 from models.model import AppMode, InstalledApp
 from services.app_generate_service import AppGenerateService
 from services.errors.llm import InvokeRateLimitError

@@ -31,8 +31,6 @@ from .. import console_ns

 logger = logging.getLogger(__name__)

-current_user = current_user_._get_current_object()  # type: ignore
-

 @console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
 class InstalledAppWorkflowRunApi(InstalledAppResource):

@@ -40,6 +38,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
         """
         Run workflow
         """
+        current_user, _ = current_account_with_tenant()
         app_model = installed_app.app
         if not app_model:
             raise NotWorkflowAppError()

@@ -53,7 +52,6 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
             .add_argument("files", type=list, required=False, location="json")
         )
         args = parser.parse_args()
-        assert current_user is not None
         try:
             response = AppGenerateService.generate(
                 app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True

@@ -89,7 +87,6 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
         app_mode = AppMode.value_of(app_model.mode)
         if app_mode != AppMode.WORKFLOW:
             raise NotWorkflowAppError()
-        assert current_user is not None

         # Stop using both mechanisms for backward compatibility
         # Legacy stop flag mechanism (without user check)

@@ -74,12 +74,17 @@ class SetupApi(Resource):
             .add_argument("email", type=email, required=True, location="json")
             .add_argument("name", type=StrLen(30), required=True, location="json")
             .add_argument("password", type=valid_password, required=True, location="json")
+            .add_argument("language", type=str, required=False, location="json")
         )
         args = parser.parse_args()

         # setup
         RegisterService.setup(
-            email=args["email"], name=args["name"], password=args["password"], ip_address=extract_remote_ip(request)
+            email=args["email"],
+            name=args["name"],
+            password=args["password"],
+            ip_address=extract_remote_ip(request),
+            language=args["language"],
         )

         return {"result": "success"}, 201

@@ -193,15 +193,16 @@ class MCPAppApi(Resource):
         except ValidationError as e:
             raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")

-    def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str, session: Session) -> EndUser | None:
-        """Get end user from existing session - optimized query"""
-        return (
-            session.query(EndUser)
-            .where(EndUser.tenant_id == tenant_id)
-            .where(EndUser.session_id == mcp_server_id)
-            .where(EndUser.type == "mcp")
-            .first()
-        )
+    def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None:
+        """Get end user - manages its own database session"""
+        with Session(db.engine, expire_on_commit=False) as session, session.begin():
+            return (
+                session.query(EndUser)
+                .where(EndUser.tenant_id == tenant_id)
+                .where(EndUser.session_id == mcp_server_id)
+                .where(EndUser.type == "mcp")
+                .first()
+            )

     def _create_end_user(
         self, client_name: str, tenant_id: str, app_id: str, mcp_server_id: str, session: Session

@@ -229,7 +230,7 @@ class MCPAppApi(Resource):
         request_id: Union[int, str],
     ) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError | None:
         """Handle MCP request and return response"""
-        end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id, session)
+        end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id)

         if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest):
             client_info = mcp_request.root.params.clientInfo
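`_retrieve_end_user` now opens and closes its own SQLAlchemy session instead of borrowing the caller's, and `expire_on_commit=False` is what keeps the returned `EndUser` usable afterwards. A toy contrast of that choice; model and engine names follow the diff, everything else is illustrative:

from sqlalchemy.orm import Session

# session.begin() commits on clean exit and rolls back on exception.
with Session(db.engine, expire_on_commit=False) as session, session.begin():
    end_user = session.query(EndUser).where(EndUser.type == "mcp").first()

# Because expire_on_commit=False, the attributes of end_user stay loaded
# here. With the default (True), the commit would expire them, and any
# access after the session closes would raise DetachedInstanceError.
if end_user is not None:
    print(end_user.session_id)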
@@ -17,8 +17,8 @@ from libs.helper import email
 from libs.passport import PassportService
 from libs.password import valid_password
 from libs.token import (
-    clear_access_token_from_cookie,
-    extract_access_token,
+    clear_webapp_access_token_from_cookie,
+    extract_webapp_access_token,
 )
 from services.account_service import AccountService
 from services.app_service import AppService

@@ -81,7 +81,7 @@ class LoginStatusApi(Resource):
     )
     def get(self):
         app_code = request.args.get("app_code")
-        token = extract_access_token(request)
+        token = extract_webapp_access_token(request)
         if not app_code:
             return {
                 "logged_in": bool(token),

@@ -128,7 +128,7 @@ class LogoutApi(Resource):
         response = make_response({"result": "success"})
         # enterprise SSO sets same site to None in https deployment
         # so we need to logout by calling api
-        clear_access_token_from_cookie(response, samesite="None")
+        clear_webapp_access_token_from_cookie(response, samesite="None")
         return response

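The web-app endpoints now go through dedicated `*_webapp_*` helpers, matching the new `COOKIE_NAME_WEBAPP_ACCESS_TOKEN` constant above, so a console login no longer satisfies a web-app auth check and vice versa. A plausible shape for the renamed extractor; the real implementation lives in libs/token and is not part of this diff:

# Sketch only: assumed cookie-based lookup mirroring the renamed constants.
COOKIE_NAME_ACCESS_TOKEN = "access_token"                # console session
COOKIE_NAME_WEBAPP_ACCESS_TOKEN = "webapp_access_token"  # web app session

def extract_webapp_access_token(request):
    # Reads only the web-app cookie, keeping the two sessions isolated.
    return request.cookies.get(COOKIE_NAME_WEBAPP_ACCESS_TOKEN)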
@@ -12,10 +12,8 @@ from controllers.web import web_ns
 from controllers.web.error import WebAppAuthRequiredError
 from extensions.ext_database import db
 from libs.passport import PassportService
-from libs.token import extract_access_token
+from libs.token import extract_webapp_access_token
 from models.model import App, EndUser, Site
-from services.app_service import AppService
-from services.enterprise.enterprise_service import EnterpriseService
 from services.feature_service import FeatureService
 from services.webapp_auth_service import WebAppAuthService, WebAppAuthType

@@ -37,23 +35,18 @@ class PassportResource(Resource):
         system_features = FeatureService.get_system_features()
         app_code = request.headers.get(HEADER_NAME_APP_CODE)
         user_id = request.args.get("user_id")
-        access_token = extract_access_token(request)
+        access_token = extract_webapp_access_token(request)

         if app_code is None:
             raise Unauthorized("X-App-Code header is missing.")
-        app_id = AppService.get_app_id_by_code(app_code)
-        # exchange token for enterprise logined web user
-        enterprise_user_decoded = decode_enterprise_webapp_user_id(access_token)
-        if enterprise_user_decoded:
-            # a web user has already logged in, exchange a token for this app without redirecting to the login page
-            return exchange_token_for_existing_web_user(
-                app_code=app_code, enterprise_user_decoded=enterprise_user_decoded
-            )

         if system_features.webapp_auth.enabled:
-            app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id)
-            if not app_settings or not app_settings.access_mode == "public":
-                raise WebAppAuthRequiredError()
+            enterprise_user_decoded = decode_enterprise_webapp_user_id(access_token)
+            app_auth_type = WebAppAuthService.get_app_auth_type(app_code=app_code)
+            if app_auth_type != WebAppAuthType.PUBLIC:
+                if not enterprise_user_decoded:
+                    raise WebAppAuthRequiredError()
+                return exchange_token_for_existing_web_user(
+                    app_code=app_code, enterprise_user_decoded=enterprise_user_decoded, auth_type=app_auth_type
+                )

         # get site from db and check if it is normal
         site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal"))
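The restructure moves the SSO token exchange inside the `webapp_auth.enabled` branch and resolves the app's auth type once, passing it down instead of looking it up again in the callee. A condensed restatement of the new control flow; this is a commented sketch of the logic above, not additional code:

# Sketch: decision order in PassportResource.get after this change.
if system_features.webapp_auth.enabled:
    decoded = decode_enterprise_webapp_user_id(access_token)
    auth_type = WebAppAuthService.get_app_auth_type(app_code=app_code)
    if auth_type != WebAppAuthType.PUBLIC:
        if not decoded:
            raise WebAppAuthRequiredError()  # protected app, no SSO session
        # SSO session present: exchange for an app-scoped token, reusing
        # the already-resolved auth_type rather than querying it twice.
        return exchange_token_for_existing_web_user(
            app_code=app_code, enterprise_user_decoded=decoded, auth_type=auth_type
        )
# Public apps (or webapp auth disabled) fall through to the site lookup.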
@@ -124,7 +117,7 @@ def decode_enterprise_webapp_user_id(jwt_token: str | None):
     return decoded


-def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict):
+def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict, auth_type: WebAppAuthType):
     """
     Exchange a token for an existing web user session.
     """

@@ -145,13 +138,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
     if not app_model or app_model.status != "normal" or not app_model.enable_site:
         raise NotFound()

-    app_auth_type = WebAppAuthService.get_app_auth_type(app_code=app_code)
-
-    if app_auth_type == WebAppAuthType.PUBLIC:
+    if auth_type == WebAppAuthType.PUBLIC:
         return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded)
-    elif app_auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external":
+    elif auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external":
         raise WebAppAuthRequiredError("Please login as external user.")
-    elif app_auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal":
+    elif auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal":
         raise WebAppAuthRequiredError("Please login as internal user.")

     end_user = None

@@ -211,8 +211,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             user=user,
             stream=streaming,
         )
-        # FIXME: Type hinting issue here, ignore it for now, will fix it later
-        return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)  # type: ignore
+        return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

     def _generate_worker(
         self,

@@ -255,7 +255,7 @@ class PipelineGenerator(BaseAppGenerator):
             json_text = json.dumps(text)
             upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id)
             features = FeatureService.get_features(dataset.tenant_id)
-            if features.billing.subscription.plan == "sandbox":
+            if features.billing.enabled and features.billing.subscription.plan == "sandbox":
                 tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
                 tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"

@@ -89,7 +89,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
-                response_chunk.update(sub_stream_response.to_ignore_detail_dict())  # ty: ignore [unresolved-attribute]
+                response_chunk.update(sub_stream_response.to_ignore_detail_dict())
             else:
                 response_chunk.update(sub_stream_response.model_dump(mode="json"))
             yield response_chunk

@@ -98,7 +98,7 @@ class RateLimit:
         else:
             return RateLimitGenerator(
                 rate_limit=self,
-                generator=generator,  # ty: ignore [invalid-argument-type]
+                generator=generator,
                 request_id=request_id,
             )

@@ -49,7 +49,7 @@ class BasedGenerateTaskPipeline:
         if isinstance(e, InvokeAuthorizationError):
             err = InvokeAuthorizationError("Incorrect API key provided")
         elif isinstance(e, InvokeError | ValueError):
-            err = e  # ty: ignore [invalid-assignment]
+            err = e
         else:
             description = getattr(e, "description", None)
             err = Exception(description if description is not None else str(e))

@@ -1868,7 +1868,7 @@ class ProviderConfigurations(BaseModel):
         if "/" not in key:
             key = str(ModelProviderID(key))

-        return self.configurations.get(key, default)  # type: ignore
+        return self.configurations.get(key, default)


 class ProviderModelBundle(BaseModel):

@@ -20,7 +20,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz
     else:
         # Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
         # FIXME: mypy does not support the type of spec.loader
-        spec = importlib.util.spec_from_file_location(module_name, py_file_path)  # type: ignore
+        spec = importlib.util.spec_from_file_location(module_name, py_file_path)  # type: ignore[assignment]
        if not spec or not spec.loader:
            raise Exception(f"Failed to load module {module_name} from {py_file_path!r}")
        if use_lazy_loader:
@@ -49,62 +49,80 @@ class IndexingRunner:
         self.storage = storage
         self.model_manager = ModelManager()

+    def _handle_indexing_error(self, document_id: str, error: Exception) -> None:
+        """Handle indexing errors by updating document status."""
+        logger.exception("consume document failed")
+        document = db.session.get(DatasetDocument, document_id)
+        if document:
+            document.indexing_status = "error"
+            error_message = getattr(error, "description", str(error))
+            document.error = str(error_message)
+            document.stopped_at = naive_utc_now()
+            db.session.commit()
+
     def run(self, dataset_documents: list[DatasetDocument]):
         """Run the indexing process."""
         for dataset_document in dataset_documents:
+            document_id = dataset_document.id
             try:
+                # Re-query the document to ensure it's bound to the current session
+                requeried_document = db.session.get(DatasetDocument, document_id)
+                if not requeried_document:
+                    logger.warning("Document not found, skipping document id: %s", document_id)
+                    continue
+
                 # get dataset
-                dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
+                dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()

                 if not dataset:
                     raise ValueError("no dataset found")
                 # get the process rule
                 stmt = select(DatasetProcessRule).where(
-                    DatasetProcessRule.id == dataset_document.dataset_process_rule_id
+                    DatasetProcessRule.id == requeried_document.dataset_process_rule_id
                 )
                 processing_rule = db.session.scalar(stmt)
                 if not processing_rule:
                     raise ValueError("no process rule found")
-                index_type = dataset_document.doc_form
+                index_type = requeried_document.doc_form
                 index_processor = IndexProcessorFactory(index_type).init_index_processor()
                 # extract
-                text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
+                text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())

                 # transform
                 documents = self._transform(
-                    index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict()
+                    index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict()
                 )
                 # save segment
-                self._load_segments(dataset, dataset_document, documents)
+                self._load_segments(dataset, requeried_document, documents)

                 # load
                 self._load(
                     index_processor=index_processor,
                     dataset=dataset,
-                    dataset_document=dataset_document,
+                    dataset_document=requeried_document,
                     documents=documents,
                 )
             except DocumentIsPausedError:
-                raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}")
+                raise DocumentIsPausedError(f"Document paused, document id: {document_id}")
             except ProviderTokenNotInitError as e:
-                dataset_document.indexing_status = "error"
-                dataset_document.error = str(e.description)
-                dataset_document.stopped_at = naive_utc_now()
-                db.session.commit()
+                self._handle_indexing_error(document_id, e)
             except ObjectDeletedError:
-                logger.warning("Document deleted, document id: %s", dataset_document.id)
+                logger.warning("Document deleted, document id: %s", document_id)
             except Exception as e:
-                logger.exception("consume document failed")
-                dataset_document.indexing_status = "error"
-                dataset_document.error = str(e)
-                dataset_document.stopped_at = naive_utc_now()
-                db.session.commit()
+                self._handle_indexing_error(document_id, e)

     def run_in_splitting_status(self, dataset_document: DatasetDocument):
         """Run the indexing process when the index_status is splitting."""
+        document_id = dataset_document.id
         try:
+            # Re-query the document to ensure it's bound to the current session
+            requeried_document = db.session.get(DatasetDocument, document_id)
+            if not requeried_document:
+                logger.warning("Document not found: %s", document_id)
+                return
+
             # get dataset
-            dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
+            dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()

             if not dataset:
                 raise ValueError("no dataset found")
@@ -112,57 +130,60 @@ class IndexingRunner:
             # get exist document_segment list and delete
             document_segments = (
                 db.session.query(DocumentSegment)
-                .filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
+                .filter_by(dataset_id=dataset.id, document_id=requeried_document.id)
                 .all()
             )

             for document_segment in document_segments:
                 db.session.delete(document_segment)
-                if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX:
                     # delete child chunks
                     db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
             db.session.commit()
             # get the process rule
-            stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
+            stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == requeried_document.dataset_process_rule_id)
             processing_rule = db.session.scalar(stmt)
             if not processing_rule:
                 raise ValueError("no process rule found")

-            index_type = dataset_document.doc_form
+            index_type = requeried_document.doc_form
             index_processor = IndexProcessorFactory(index_type).init_index_processor()
             # extract
-            text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
+            text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())

             # transform
             documents = self._transform(
-                index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict()
+                index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict()
             )
             # save segment
-            self._load_segments(dataset, dataset_document, documents)
+            self._load_segments(dataset, requeried_document, documents)

             # load
             self._load(
-                index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
+                index_processor=index_processor,
+                dataset=dataset,
+                dataset_document=requeried_document,
+                documents=documents,
             )
         except DocumentIsPausedError:
-            raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}")
+            raise DocumentIsPausedError(f"Document paused, document id: {document_id}")
         except ProviderTokenNotInitError as e:
-            dataset_document.indexing_status = "error"
-            dataset_document.error = str(e.description)
-            dataset_document.stopped_at = naive_utc_now()
-            db.session.commit()
+            self._handle_indexing_error(document_id, e)
         except Exception as e:
-            logger.exception("consume document failed")
-            dataset_document.indexing_status = "error"
-            dataset_document.error = str(e)
-            dataset_document.stopped_at = naive_utc_now()
-            db.session.commit()
+            self._handle_indexing_error(document_id, e)

     def run_in_indexing_status(self, dataset_document: DatasetDocument):
         """Run the indexing process when the index_status is indexing."""
+        document_id = dataset_document.id
         try:
+            # Re-query the document to ensure it's bound to the current session
+            requeried_document = db.session.get(DatasetDocument, document_id)
+            if not requeried_document:
+                logger.warning("Document not found: %s", document_id)
+                return
+
             # get dataset
-            dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
+            dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()

             if not dataset:
                 raise ValueError("no dataset found")
@@ -170,7 +191,7 @@ class IndexingRunner:
             # get exist document_segment list and delete
             document_segments = (
                 db.session.query(DocumentSegment)
-                .filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
+                .filter_by(dataset_id=dataset.id, document_id=requeried_document.id)
                 .all()
             )

@@ -188,7 +209,7 @@ class IndexingRunner:
                     "dataset_id": document_segment.dataset_id,
                 },
             )
-            if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+            if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX:
                 child_chunks = document_segment.get_child_chunks()
                 if child_chunks:
                     child_documents = []

@@ -206,24 +227,20 @@ class IndexingRunner:
                     document.children = child_documents
                 documents.append(document)
             # build index
-            index_type = dataset_document.doc_form
+            index_type = requeried_document.doc_form
             index_processor = IndexProcessorFactory(index_type).init_index_processor()
             self._load(
-                index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
+                index_processor=index_processor,
+                dataset=dataset,
+                dataset_document=requeried_document,
+                documents=documents,
             )
         except DocumentIsPausedError:
-            raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}")
+            raise DocumentIsPausedError(f"Document paused, document id: {document_id}")
         except ProviderTokenNotInitError as e:
-            dataset_document.indexing_status = "error"
-            dataset_document.error = str(e.description)
-            dataset_document.stopped_at = naive_utc_now()
-            db.session.commit()
+            self._handle_indexing_error(document_id, e)
         except Exception as e:
-            logger.exception("consume document failed")
-            dataset_document.indexing_status = "error"
-            dataset_document.error = str(e)
-            dataset_document.stopped_at = naive_utc_now()
-            db.session.commit()
+            self._handle_indexing_error(document_id, e)

     def indexing_estimate(
         self,

@@ -398,7 +415,6 @@ class IndexingRunner:
             document_id=dataset_document.id,
             after_indexing_status="splitting",
             extra_update_params={
-                DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs),
                 DatasetDocument.parsing_completed_at: naive_utc_now(),
             },
         )

@@ -738,6 +754,7 @@ class IndexingRunner:
             extra_update_params={
                 DatasetDocument.cleaning_completed_at: cur_time,
                 DatasetDocument.splitting_completed_at: cur_time,
+                DatasetDocument.word_count: sum(len(doc.page_content) for doc in documents),
             },
         )

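The IndexingRunner refactor does two things: it centralizes the duplicated error bookkeeping in `_handle_indexing_error`, and it guards against stale ORM instances, since documents handed in from another session can be detached or deleted by the time the worker runs, each entry point now re-fetches by primary key. The pattern in isolation; `process`, `do_indexing`, and `handle_indexing_error` are hypothetical stand-ins:

def process(document_id: str) -> None:
    # Re-fetch by primary key to bind the row to the current session.
    doc = db.session.get(DatasetDocument, document_id)
    if doc is None:
        logger.warning("Document not found, skipping document id: %s", document_id)
        return
    try:
        do_indexing(doc)  # stand-in for extract / transform / load
    except Exception as exc:
        # The centralized handler re-fetches by id itself, so it still
        # works even if `doc` went stale during the failed run.
        handle_indexing_error(document_id, exc)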
@@ -100,7 +100,7 @@ class LLMGenerator:
         return name

     @classmethod
-    def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
+    def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str) -> Sequence[str]:
         output_parser = SuggestedQuestionsAfterAnswerOutputParser()
         format_instructions = output_parser.get_format_instructions()

@@ -119,6 +119,8 @@ class LLMGenerator:

         prompt_messages = [UserPromptMessage(content=prompt)]

+        questions: Sequence[str] = []
+
         try:
             response: LLMResult = model_instance.invoke_llm(
                 prompt_messages=list(prompt_messages),

@@ -1,17 +1,26 @@
 import json
+import logging
 import re
+from collections.abc import Sequence

 from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT

+logger = logging.getLogger(__name__)
+
+
 class SuggestedQuestionsAfterAnswerOutputParser:
     def get_format_instructions(self) -> str:
         return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT

-    def parse(self, text: str):
+    def parse(self, text: str) -> Sequence[str]:
         action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL)
+        questions: list[str] = []
         if action_match is not None:
-            json_obj = json.loads(action_match.group(0).strip())
-        else:
-            json_obj = []
-        return json_obj
+            try:
+                json_obj = json.loads(action_match.group(0).strip())
+            except json.JSONDecodeError as exc:
+                logger.warning("Failed to decode suggested questions payload: %s", exc)
+            else:
+                if isinstance(json_obj, list):
+                    questions = [question for question in json_obj if isinstance(question, str)]
+        return questions
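The hardened parser never raises on malformed model output and always returns a list of strings. Its behavior as a usage sketch, with inputs invented for illustration:

parser = SuggestedQuestionsAfterAnswerOutputParser()

parser.parse('Sure: ["What is X?", "How does Y work?"]')
# -> ["What is X?", "How does Y work?"]

parser.parse('["ok", 42, {"not": "a string"}]')
# -> ["ok"]  (non-string entries are filtered out)

parser.parse("no JSON array here")
# -> []  (no bracketed payload found)

parser.parse("see [not, valid, json] above")
# -> []  (JSONDecodeError is logged as a warning, not raised)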
@@ -2,7 +2,7 @@ import logging
 import os
 from datetime import datetime, timedelta

-from langfuse import Langfuse  # type: ignore
+from langfuse import Langfuse
 from sqlalchemy.orm import sessionmaker

 from core.ops.base_trace_instance import BaseTraceInstance

@@ -76,7 +76,7 @@ class PluginParameter(BaseModel):
     auto_generate: PluginParameterAutoGenerate | None = None
     template: PluginParameterTemplate | None = None
     required: bool = False
-    default: Union[float, int, str] | None = None
+    default: Union[float, int, str, bool] | None = None
     min: Union[float, int] | None = None
     max: Union[float, int] | None = None
     precision: int | None = None

@@ -180,7 +180,7 @@ class BasePluginClient:
        Make a request to the plugin daemon inner API and return the response as a model.
        """
        response = self._request(method, path, headers, data, params, files)
-        return type_(**response.json())  # type: ignore
+        return type_(**response.json())  # type: ignore[return-value]

    def _request_with_plugin_daemon_response(
        self,

@@ -40,7 +40,7 @@ class PluginDaemonBadRequestError(PluginDaemonClientSideError):
     description: str = "Bad Request"


-class PluginInvokeError(PluginDaemonClientSideError):
+class PluginInvokeError(PluginDaemonClientSideError, ValueError):
     description: str = "Invoke Error"

     def _get_error_object(self) -> Mapping:
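Adding `ValueError` as a second base changes exception routing: any existing `except ValueError` handler now also catches plugin invocation failures, which lines up with the `isinstance(e, InvokeError | ValueError)` branch touched in BasedGenerateTaskPipeline above. A toy demonstration of the effect; the classes are simplified stand-ins for the real ones:

class PluginDaemonClientSideError(Exception):
    pass

# Mirrors the change: the error is both a client-side error and a ValueError.
class PluginInvokeError(PluginDaemonClientSideError, ValueError):
    description: str = "Invoke Error"

try:
    raise PluginInvokeError("tool call failed")
except ValueError as e:
    # Before the change, this branch would not have run for PluginInvokeError.
    print(f"caught as ValueError: {e}")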
@@ -72,6 +72,19 @@ default_retrieval_model: dict[str, Any] = {
 class DatasetRetrieval:
     def __init__(self, application_generate_entity=None):
         self.application_generate_entity = application_generate_entity
+        self._llm_usage = LLMUsage.empty_usage()
+
+    @property
+    def llm_usage(self) -> LLMUsage:
+        return self._llm_usage.model_copy()
+
+    def _record_usage(self, usage: LLMUsage | None) -> None:
+        if usage is None or usage.total_tokens <= 0:
+            return
+        if self._llm_usage.total_tokens == 0:
+            self._llm_usage = usage
+        else:
+            self._llm_usage = self._llm_usage.plus(usage)

     def retrieve(
         self,
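`_record_usage` keeps a running tally of the LLM calls made during routing and metadata filtering, and the `llm_usage` property hands out a defensive copy so callers cannot mutate the accumulator. A usage sketch; the two `*_call_usage` values are invented placeholders:

retrieval = DatasetRetrieval()

# Internally, each router call below returns (dataset_id, usage) and feeds
# the usage into _record_usage; empty usages (total_tokens == 0) are skipped.
retrieval._record_usage(LLMUsage.empty_usage())  # ignored: no tokens
retrieval._record_usage(first_call_usage)        # stored as-is
retrieval._record_usage(second_call_usage)       # merged via LLMUsage.plus()

total = retrieval.llm_usage  # model_copy(): callers can't mutate the tally
print(total.total_tokens)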
@@ -312,15 +325,18 @@ class DatasetRetrieval:
             )
             tools.append(message_tool)
         dataset_id = None
+        router_usage = LLMUsage.empty_usage()
         if planning_strategy == PlanningStrategy.REACT_ROUTER:
             react_multi_dataset_router = ReactMultiDatasetRouter()
-            dataset_id = react_multi_dataset_router.invoke(
+            dataset_id, router_usage = react_multi_dataset_router.invoke(
                 query, tools, model_config, model_instance, user_id, tenant_id
             )

         elif planning_strategy == PlanningStrategy.ROUTER:
             function_call_router = FunctionCallMultiDatasetRouter()
-            dataset_id = function_call_router.invoke(query, tools, model_config, model_instance)
+            dataset_id, router_usage = function_call_router.invoke(query, tools, model_config, model_instance)
+
+        self._record_usage(router_usage)

         if dataset_id:
             # get retrieval model config

@@ -983,7 +999,8 @@ class DatasetRetrieval:
         )

         # handle invoke result
-        result_text, _ = self._handle_invoke_result(invoke_result=invoke_result)
+        result_text, usage = self._handle_invoke_result(invoke_result=invoke_result)
+        self._record_usage(usage)

         result_text_json = parse_and_check_json_markdown(result_text, [])
         automatic_metadata_filters = []

@@ -2,7 +2,7 @@ from typing import Union

 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.model_manager import ModelInstance
-from core.model_runtime.entities.llm_entities import LLMResult
+from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
 from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage


@@ -13,15 +13,15 @@ class FunctionCallMultiDatasetRouter:
         dataset_tools: list[PromptMessageTool],
         model_config: ModelConfigWithCredentialsEntity,
         model_instance: ModelInstance,
-    ) -> Union[str, None]:
+    ) -> tuple[Union[str, None], LLMUsage]:
         """Given input, decided what to do.
         Returns:
             Action specifying what tool to use.
         """
         if len(dataset_tools) == 0:
-            return None
+            return None, LLMUsage.empty_usage()
         elif len(dataset_tools) == 1:
-            return dataset_tools[0].name
+            return dataset_tools[0].name, LLMUsage.empty_usage()

         try:
             prompt_messages = [

@@ -34,9 +34,10 @@ class FunctionCallMultiDatasetRouter:
                 stream=False,
                 model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
             )
+            usage = result.usage or LLMUsage.empty_usage()
             if result.message.tool_calls:
                 # get retrieval model config
-                return result.message.tool_calls[0].function.name
-            return None
+                return result.message.tool_calls[0].function.name, usage
+            return None, usage
         except Exception:
-            return None
+            return None, LLMUsage.empty_usage()

@@ -58,15 +58,15 @@ class ReactMultiDatasetRouter:
         model_instance: ModelInstance,
         user_id: str,
         tenant_id: str,
-    ) -> Union[str, None]:
+    ) -> tuple[Union[str, None], LLMUsage]:
         """Given input, decided what to do.
         Returns:
             Action specifying what tool to use.
         """
         if len(dataset_tools) == 0:
-            return None
+            return None, LLMUsage.empty_usage()
         elif len(dataset_tools) == 1:
-            return dataset_tools[0].name
+            return dataset_tools[0].name, LLMUsage.empty_usage()

         try:
             return self._react_invoke(

@@ -78,7 +78,7 @@ class ReactMultiDatasetRouter:
                 tenant_id=tenant_id,
             )
         except Exception:
-            return None
+            return None, LLMUsage.empty_usage()

     def _react_invoke(
         self,

@@ -91,7 +91,7 @@ class ReactMultiDatasetRouter:
         prefix: str = PREFIX,
         suffix: str = SUFFIX,
         format_instructions: str = FORMAT_INSTRUCTIONS,
-    ) -> Union[str, None]:
+    ) -> tuple[Union[str, None], LLMUsage]:
         prompt: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
         if model_config.mode == "chat":
             prompt = self.create_chat_prompt(

@@ -120,7 +120,7 @@ class ReactMultiDatasetRouter:
             memory=None,
             model_config=model_config,
         )
-        result_text, _ = self._invoke_llm(
+        result_text, usage = self._invoke_llm(
             completion_param=model_config.parameters,
             model_instance=model_instance,
             prompt_messages=prompt_messages,

@@ -131,8 +131,8 @@ class ReactMultiDatasetRouter:
         output_parser = StructuredChatOutputParser()
         react_decision = output_parser.parse(result_text)
         if isinstance(react_decision, ReactAction):
-            return react_decision.tool
-        return None
+            return react_decision.tool, usage
+        return None, usage

     def _invoke_llm(
         self,
@@ -74,7 +74,7 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
         tenant_id = extract_tenant_id(user)
         if not tenant_id:
             raise ValueError("User must have a tenant_id or current_tenant_id")
-        self._tenant_id = tenant_id  # type: ignore[assignment]  # We've already checked tenant_id is not None
+        self._tenant_id = tenant_id

         # Store app context
         self._app_id = app_id

@@ -81,7 +81,7 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
         tenant_id = extract_tenant_id(user)
         if not tenant_id:
             raise ValueError("User must have a tenant_id or current_tenant_id")
-        self._tenant_id = tenant_id  # type: ignore[assignment]  # We've already checked tenant_id is not None
+        self._tenant_id = tenant_id

         # Store app context
         self._app_id = app_id

@@ -60,7 +60,7 @@ class DifyCoreRepositoryFactory:

         try:
             repository_class = import_string(class_path)
-            return repository_class(  # type: ignore[no-any-return]
+            return repository_class(
                 session_factory=session_factory,
                 user=user,
                 app_id=app_id,

@@ -96,7 +96,7 @@ class DifyCoreRepositoryFactory:

         try:
             repository_class = import_string(class_path)
-            return repository_class(  # type: ignore[no-any-return]
+            return repository_class(
                 session_factory=session_factory,
                 user=user,
                 app_id=app_id,

@@ -157,7 +157,7 @@ class BuiltinToolProviderController(ToolProviderController):
        """
        returns the tool that the provider can provide
        """
-        return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)  # type: ignore
+        return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)

     @property
     def need_credentials(self) -> bool:

@@ -43,7 +43,7 @@ class TTSTool(BuiltinTool):
             content_text=tool_parameters.get("text"),  # type: ignore
             user=user_id,
             tenant_id=self.runtime.tenant_id,
-            voice=voice,  # type: ignore
+            voice=voice,
         )
         buffer = io.BytesIO()
         for chunk in tts:

@@ -34,6 +34,7 @@ class LocaltimeToTimestampTool(BuiltinTool):

         yield self.create_text_message(f"{timestamp}")

+    # TODO: this method's type is messy
     @staticmethod
     def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None:
         try:

@@ -48,6 +48,6 @@ class TimezoneConversionTool(BuiltinTool):
             datetime_with_tz = input_timezone.localize(local_time)
             # timezone convert
             converted_datetime = datetime_with_tz.astimezone(output_timezone)
-            return converted_datetime.strftime(format=time_format)  # type: ignore
+            return converted_datetime.strftime(time_format)
         except Exception as e:
             raise ToolInvokeError(str(e))

@@ -113,7 +113,7 @@ class MCPToolProviderController(ToolProviderController):
         """
         pass

-    def get_tool(self, tool_name: str) -> MCPTool:  # type: ignore
+    def get_tool(self, tool_name: str) -> MCPTool:
         """
         return tool with given name
         """

@@ -136,7 +136,7 @@ class MCPToolProviderController(ToolProviderController):
             sse_read_timeout=self.sse_read_timeout,
         )

-    def get_tools(self) -> list[MCPTool]:  # type: ignore
+    def get_tools(self) -> list[MCPTool]:
         """
         get all tools
         """

@@ -26,7 +26,7 @@ class ToolLabelManager:
         labels = cls.filter_tool_labels(labels)

         if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
-            provider_id = controller.provider_id  # ty: ignore [unresolved-attribute]
+            provider_id = controller.provider_id
         else:
             raise ValueError("Unsupported tool type")

@@ -51,7 +51,7 @@ class ToolLabelManager:
        Get tool labels
        """
        if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
-            provider_id = controller.provider_id  # ty: ignore [unresolved-attribute]
+            provider_id = controller.provider_id
        elif isinstance(controller, BuiltinToolProviderController):
            return controller.tool_labels
        else:

@@ -85,7 +85,7 @@ class ToolLabelManager:
         provider_ids = []
         for controller in tool_providers:
             assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
|
assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
|
||||||
provider_ids.append(controller.provider_id) # ty: ignore [unresolved-attribute]
|
provider_ids.append(controller.provider_id)
|
||||||
|
|
||||||
labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all()
|
labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all()
|
||||||
|
|
||||||
|
|||||||
@ -331,7 +331,8 @@ class ToolManager:
|
|||||||
workflow_provider_stmt = select(WorkflowToolProvider).where(
|
workflow_provider_stmt = select(WorkflowToolProvider).where(
|
||||||
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id
|
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id
|
||||||
)
|
)
|
||||||
workflow_provider = db.session.scalar(workflow_provider_stmt)
|
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||||
|
workflow_provider = session.scalar(workflow_provider_stmt)
|
||||||
|
|
||||||
if workflow_provider is None:
|
if workflow_provider is None:
|
||||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||||
|
|||||||
@ -193,18 +193,18 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||||||
DatasetDocument.enabled == True,
|
DatasetDocument.enabled == True,
|
||||||
DatasetDocument.archived == False,
|
DatasetDocument.archived == False,
|
||||||
)
|
)
|
||||||
document = db.session.scalar(dataset_document_stmt) # type: ignore
|
document = db.session.scalar(dataset_document_stmt)
|
||||||
if dataset and document:
|
if dataset and document:
|
||||||
source = RetrievalSourceMetadata(
|
source = RetrievalSourceMetadata(
|
||||||
dataset_id=dataset.id,
|
dataset_id=dataset.id,
|
||||||
dataset_name=dataset.name,
|
dataset_name=dataset.name,
|
||||||
document_id=document.id, # type: ignore
|
document_id=document.id,
|
||||||
document_name=document.name, # type: ignore
|
document_name=document.name,
|
||||||
data_source_type=document.data_source_type, # type: ignore
|
data_source_type=document.data_source_type,
|
||||||
segment_id=segment.id,
|
segment_id=segment.id,
|
||||||
retriever_from=self.retriever_from,
|
retriever_from=self.retriever_from,
|
||||||
score=record.score or 0.0,
|
score=record.score or 0.0,
|
||||||
doc_metadata=document.doc_metadata, # type: ignore
|
doc_metadata=document.doc_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.retriever_from == "dev":
|
if self.retriever_from == "dev":
|
||||||
|
|||||||
@ -62,6 +62,11 @@ class ApiBasedToolSchemaParser:
|
|||||||
root = root[ref]
|
root = root[ref]
|
||||||
interface["operation"]["parameters"][i] = root
|
interface["operation"]["parameters"][i] = root
|
||||||
for parameter in interface["operation"]["parameters"]:
|
for parameter in interface["operation"]["parameters"]:
|
||||||
|
# Handle complex type defaults that are not supported by PluginParameter
|
||||||
|
default_value = None
|
||||||
|
if "schema" in parameter and "default" in parameter["schema"]:
|
||||||
|
default_value = ApiBasedToolSchemaParser._sanitize_default_value(parameter["schema"]["default"])
|
||||||
|
|
||||||
tool_parameter = ToolParameter(
|
tool_parameter = ToolParameter(
|
||||||
name=parameter["name"],
|
name=parameter["name"],
|
||||||
label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
|
label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
|
||||||
@ -72,9 +77,7 @@ class ApiBasedToolSchemaParser:
|
|||||||
required=parameter.get("required", False),
|
required=parameter.get("required", False),
|
||||||
form=ToolParameter.ToolParameterForm.LLM,
|
form=ToolParameter.ToolParameterForm.LLM,
|
||||||
llm_description=parameter.get("description"),
|
llm_description=parameter.get("description"),
|
||||||
default=parameter["schema"]["default"]
|
default=default_value,
|
||||||
if "schema" in parameter and "default" in parameter["schema"]
|
|
||||||
else None,
|
|
||||||
placeholder=I18nObject(
|
placeholder=I18nObject(
|
||||||
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
|
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
|
||||||
),
|
),
|
||||||
@ -134,6 +137,11 @@ class ApiBasedToolSchemaParser:
|
|||||||
required = body_schema.get("required", [])
|
required = body_schema.get("required", [])
|
||||||
properties = body_schema.get("properties", {})
|
properties = body_schema.get("properties", {})
|
||||||
for name, property in properties.items():
|
for name, property in properties.items():
|
||||||
|
# Handle complex type defaults that are not supported by PluginParameter
|
||||||
|
default_value = ApiBasedToolSchemaParser._sanitize_default_value(
|
||||||
|
property.get("default", None)
|
||||||
|
)
|
||||||
|
|
||||||
tool = ToolParameter(
|
tool = ToolParameter(
|
||||||
name=name,
|
name=name,
|
||||||
label=I18nObject(en_US=name, zh_Hans=name),
|
label=I18nObject(en_US=name, zh_Hans=name),
|
||||||
@ -144,12 +152,11 @@ class ApiBasedToolSchemaParser:
|
|||||||
required=name in required,
|
required=name in required,
|
||||||
form=ToolParameter.ToolParameterForm.LLM,
|
form=ToolParameter.ToolParameterForm.LLM,
|
||||||
llm_description=property.get("description", ""),
|
llm_description=property.get("description", ""),
|
||||||
default=property.get("default", None),
|
default=default_value,
|
||||||
placeholder=I18nObject(
|
placeholder=I18nObject(
|
||||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# check if there is a type
|
# check if there is a type
|
||||||
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
|
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
|
||||||
if typ:
|
if typ:
|
||||||
@ -197,6 +204,22 @@ class ApiBasedToolSchemaParser:
|
|||||||
|
|
||||||
return bundles
|
return bundles
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _sanitize_default_value(value):
|
||||||
|
"""
|
||||||
|
Sanitize default values for PluginParameter compatibility.
|
||||||
|
Complex types (list, dict) are converted to None to avoid validation errors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: The default value from OpenAPI schema
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None for complex types (list, dict), otherwise the original value
|
||||||
|
"""
|
||||||
|
if isinstance(value, (list, dict)):
|
||||||
|
return None
|
||||||
|
return value
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
|
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
|
||||||
parameter = parameter or {}
|
parameter = parameter or {}
|
||||||
@ -217,7 +240,11 @@ class ApiBasedToolSchemaParser:
|
|||||||
return ToolParameter.ToolParameterType.STRING
|
return ToolParameter.ToolParameterType.STRING
|
||||||
elif typ == "array":
|
elif typ == "array":
|
||||||
items = parameter.get("items") or parameter.get("schema", {}).get("items")
|
items = parameter.get("items") or parameter.get("schema", {}).get("items")
|
||||||
return ToolParameter.ToolParameterType.FILES if items and items.get("format") == "binary" else None
|
if items and items.get("format") == "binary":
|
||||||
|
return ToolParameter.ToolParameterType.FILES
|
||||||
|
else:
|
||||||
|
# For regular arrays, return ARRAY type instead of None
|
||||||
|
return ToolParameter.ToolParameterType.ARRAY
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@ -6,8 +6,8 @@ from typing import Any, cast
|
|||||||
from urllib.parse import unquote
|
from urllib.parse import unquote
|
||||||
|
|
||||||
import chardet
|
import chardet
|
||||||
import cloudscraper # type: ignore
|
import cloudscraper
|
||||||
from readabilipy import simple_json_from_html_string # type: ignore
|
from readabilipy import simple_json_from_html_string
|
||||||
|
|
||||||
from core.helper import ssrf_proxy
|
from core.helper import ssrf_proxy
|
||||||
from core.rag.extractor import extract_processor
|
from core.rag.extractor import extract_processor
|
||||||
@ -63,8 +63,8 @@ def get_url(url: str, user_agent: str | None = None) -> str:
|
|||||||
response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
|
response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
|
||||||
elif response.status_code == 403:
|
elif response.status_code == 403:
|
||||||
scraper = cloudscraper.create_scraper()
|
scraper = cloudscraper.create_scraper()
|
||||||
scraper.perform_request = ssrf_proxy.make_request # type: ignore
|
scraper.perform_request = ssrf_proxy.make_request
|
||||||
response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) # type: ignore
|
response = scraper.get(url, headers=headers, timeout=(120, 300))
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
return f"URL returned status code {response.status_code}."
|
return f"URL returned status code {response.status_code}."
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from functools import lru_cache
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import yaml # type: ignore
|
import yaml
|
||||||
from yaml import YAMLError
|
from yaml import YAMLError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||||
@ -20,6 +21,7 @@ from core.tools.entities.tool_entities import (
|
|||||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
||||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from models.account import Account
|
||||||
from models.model import App, AppMode
|
from models.model import App, AppMode
|
||||||
from models.tools import WorkflowToolProvider
|
from models.tools import WorkflowToolProvider
|
||||||
from models.workflow import Workflow
|
from models.workflow import Workflow
|
||||||
@ -44,29 +46,34 @@ class WorkflowToolProviderController(ToolProviderController):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
|
def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
|
||||||
app = db_provider.app
|
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||||
|
provider = session.get(WorkflowToolProvider, db_provider.id) if db_provider.id else None
|
||||||
|
if not provider:
|
||||||
|
raise ValueError("workflow provider not found")
|
||||||
|
app = session.get(App, provider.app_id)
|
||||||
|
if not app:
|
||||||
|
raise ValueError("app not found")
|
||||||
|
|
||||||
if not app:
|
user = session.get(Account, provider.user_id) if provider.user_id else None
|
||||||
raise ValueError("app not found")
|
|
||||||
|
|
||||||
controller = WorkflowToolProviderController(
|
controller = WorkflowToolProviderController(
|
||||||
entity=ToolProviderEntity(
|
entity=ToolProviderEntity(
|
||||||
identity=ToolProviderIdentity(
|
identity=ToolProviderIdentity(
|
||||||
author=db_provider.user.name if db_provider.user_id and db_provider.user else "",
|
author=user.name if user else "",
|
||||||
name=db_provider.label,
|
name=provider.label,
|
||||||
label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label),
|
label=I18nObject(en_US=provider.label, zh_Hans=provider.label),
|
||||||
description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
|
description=I18nObject(en_US=provider.description, zh_Hans=provider.description),
|
||||||
icon=db_provider.icon,
|
icon=provider.icon,
|
||||||
|
),
|
||||||
|
credentials_schema=[],
|
||||||
|
plugin_id=None,
|
||||||
),
|
),
|
||||||
credentials_schema=[],
|
provider_id=provider.id or "",
|
||||||
plugin_id=None,
|
)
|
||||||
),
|
|
||||||
provider_id=db_provider.id or "",
|
|
||||||
)
|
|
||||||
|
|
||||||
# init tools
|
controller.tools = [
|
||||||
|
controller._get_db_provider_tool(provider, app, session=session, user=user),
|
||||||
controller.tools = [controller._get_db_provider_tool(db_provider, app)]
|
]
|
||||||
|
|
||||||
return controller
|
return controller
|
||||||
|
|
||||||
@ -74,7 +81,14 @@ class WorkflowToolProviderController(ToolProviderController):
|
|||||||
def provider_type(self) -> ToolProviderType:
|
def provider_type(self) -> ToolProviderType:
|
||||||
return ToolProviderType.WORKFLOW
|
return ToolProviderType.WORKFLOW
|
||||||
|
|
||||||
def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool:
|
def _get_db_provider_tool(
|
||||||
|
self,
|
||||||
|
db_provider: WorkflowToolProvider,
|
||||||
|
app: App,
|
||||||
|
*,
|
||||||
|
session: Session,
|
||||||
|
user: Account | None = None,
|
||||||
|
) -> WorkflowTool:
|
||||||
"""
|
"""
|
||||||
get db provider tool
|
get db provider tool
|
||||||
:param db_provider: the db provider
|
:param db_provider: the db provider
|
||||||
@ -82,7 +96,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
|||||||
:return: the tool
|
:return: the tool
|
||||||
"""
|
"""
|
||||||
workflow: Workflow | None = (
|
workflow: Workflow | None = (
|
||||||
db.session.query(Workflow)
|
session.query(Workflow)
|
||||||
.where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
|
.where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
@ -99,9 +113,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
|||||||
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
|
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
|
||||||
|
|
||||||
def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
|
def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
|
||||||
return next(filter(lambda x: x.variable == variable_name, variables), None) # type: ignore
|
return next(filter(lambda x: x.variable == variable_name, variables), None)
|
||||||
|
|
||||||
user = db_provider.user
|
|
||||||
|
|
||||||
workflow_tool_parameters = []
|
workflow_tool_parameters = []
|
||||||
for parameter in parameters:
|
for parameter in parameters:
|
||||||
@ -187,22 +199,25 @@ class WorkflowToolProviderController(ToolProviderController):
|
|||||||
if self.tools is not None:
|
if self.tools is not None:
|
||||||
return self.tools
|
return self.tools
|
||||||
|
|
||||||
db_providers: WorkflowToolProvider | None = (
|
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||||
db.session.query(WorkflowToolProvider)
|
db_provider: WorkflowToolProvider | None = (
|
||||||
.where(
|
session.query(WorkflowToolProvider)
|
||||||
WorkflowToolProvider.tenant_id == tenant_id,
|
.where(
|
||||||
WorkflowToolProvider.app_id == self.provider_id,
|
WorkflowToolProvider.tenant_id == tenant_id,
|
||||||
|
WorkflowToolProvider.app_id == self.provider_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
)
|
)
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not db_providers:
|
if not db_provider:
|
||||||
return []
|
return []
|
||||||
if not db_providers.app:
|
|
||||||
raise ValueError("app not found")
|
|
||||||
|
|
||||||
app = db_providers.app
|
app = session.get(App, db_provider.app_id)
|
||||||
self.tools = [self._get_db_provider_tool(db_providers, app)]
|
if not app:
|
||||||
|
raise ValueError("app not found")
|
||||||
|
|
||||||
|
user = session.get(Account, db_provider.user_id) if db_provider.user_id else None
|
||||||
|
self.tools = [self._get_db_provider_tool(db_provider, app, session=session, user=user)]
|
||||||
|
|
||||||
return self.tools
|
return self.tools
|
||||||
|
|
||||||
|
|||||||
@ -1,12 +1,14 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
|
|
||||||
from flask import has_request_context
|
from flask import has_request_context
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
|
||||||
from core.tools.__base.tool import Tool
|
from core.tools.__base.tool import Tool
|
||||||
from core.tools.__base.tool_runtime import ToolRuntime
|
from core.tools.__base.tool_runtime import ToolRuntime
|
||||||
from core.tools.entities.tool_entities import (
|
from core.tools.entities.tool_entities import (
|
||||||
@ -48,6 +50,7 @@ class WorkflowTool(Tool):
|
|||||||
self.workflow_entities = workflow_entities
|
self.workflow_entities = workflow_entities
|
||||||
self.workflow_call_depth = workflow_call_depth
|
self.workflow_call_depth = workflow_call_depth
|
||||||
self.label = label
|
self.label = label
|
||||||
|
self._latest_usage = LLMUsage.empty_usage()
|
||||||
|
|
||||||
super().__init__(entity=entity, runtime=runtime)
|
super().__init__(entity=entity, runtime=runtime)
|
||||||
|
|
||||||
@ -83,10 +86,11 @@ class WorkflowTool(Tool):
|
|||||||
assert self.runtime.invoke_from is not None
|
assert self.runtime.invoke_from is not None
|
||||||
|
|
||||||
user = self._resolve_user(user_id=user_id)
|
user = self._resolve_user(user_id=user_id)
|
||||||
|
|
||||||
if user is None:
|
if user is None:
|
||||||
raise ToolInvokeError("User not found")
|
raise ToolInvokeError("User not found")
|
||||||
|
|
||||||
|
self._latest_usage = LLMUsage.empty_usage()
|
||||||
|
|
||||||
result = generator.generate(
|
result = generator.generate(
|
||||||
app_model=app,
|
app_model=app,
|
||||||
workflow=workflow,
|
workflow=workflow,
|
||||||
@ -110,9 +114,68 @@ class WorkflowTool(Tool):
|
|||||||
for file in files:
|
for file in files:
|
||||||
yield self.create_file_message(file) # type: ignore
|
yield self.create_file_message(file) # type: ignore
|
||||||
|
|
||||||
|
self._latest_usage = self._derive_usage_from_result(data)
|
||||||
|
|
||||||
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
|
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
|
||||||
yield self.create_json_message(outputs)
|
yield self.create_json_message(outputs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def latest_usage(self) -> LLMUsage:
|
||||||
|
return self._latest_usage
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _derive_usage_from_result(cls, data: Mapping[str, Any]) -> LLMUsage:
|
||||||
|
usage_dict = cls._extract_usage_dict(data)
|
||||||
|
if usage_dict is not None:
|
||||||
|
return LLMUsage.from_metadata(cast(LLMUsageMetadata, dict(usage_dict)))
|
||||||
|
|
||||||
|
total_tokens = data.get("total_tokens")
|
||||||
|
total_price = data.get("total_price")
|
||||||
|
if total_tokens is None and total_price is None:
|
||||||
|
return LLMUsage.empty_usage()
|
||||||
|
|
||||||
|
usage_metadata: dict[str, Any] = {}
|
||||||
|
if total_tokens is not None:
|
||||||
|
try:
|
||||||
|
usage_metadata["total_tokens"] = int(str(total_tokens))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
pass
|
||||||
|
if total_price is not None:
|
||||||
|
usage_metadata["total_price"] = str(total_price)
|
||||||
|
currency = data.get("currency")
|
||||||
|
if currency is not None:
|
||||||
|
usage_metadata["currency"] = currency
|
||||||
|
|
||||||
|
if not usage_metadata:
|
||||||
|
return LLMUsage.empty_usage()
|
||||||
|
|
||||||
|
return LLMUsage.from_metadata(cast(LLMUsageMetadata, usage_metadata))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _extract_usage_dict(cls, payload: Mapping[str, Any]) -> Mapping[str, Any] | None:
|
||||||
|
usage_candidate = payload.get("usage")
|
||||||
|
if isinstance(usage_candidate, Mapping):
|
||||||
|
return usage_candidate
|
||||||
|
|
||||||
|
metadata_candidate = payload.get("metadata")
|
||||||
|
if isinstance(metadata_candidate, Mapping):
|
||||||
|
usage_candidate = metadata_candidate.get("usage")
|
||||||
|
if isinstance(usage_candidate, Mapping):
|
||||||
|
return usage_candidate
|
||||||
|
|
||||||
|
for value in payload.values():
|
||||||
|
if isinstance(value, Mapping):
|
||||||
|
found = cls._extract_usage_dict(value)
|
||||||
|
if found is not None:
|
||||||
|
return found
|
||||||
|
elif isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
|
||||||
|
for item in value:
|
||||||
|
if isinstance(item, Mapping):
|
||||||
|
found = cls._extract_usage_dict(item)
|
||||||
|
if found is not None:
|
||||||
|
return found
|
||||||
|
return None
|
||||||
|
|
||||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
|
def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
|
||||||
"""
|
"""
|
||||||
fork a new tool with metadata
|
fork a new tool with metadata
|
||||||
@ -179,16 +242,17 @@ class WorkflowTool(Tool):
|
|||||||
"""
|
"""
|
||||||
get the workflow by app id and version
|
get the workflow by app id and version
|
||||||
"""
|
"""
|
||||||
if not version:
|
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||||
workflow = (
|
if not version:
|
||||||
db.session.query(Workflow)
|
stmt = (
|
||||||
.where(Workflow.app_id == app_id, Workflow.version != Workflow.VERSION_DRAFT)
|
select(Workflow)
|
||||||
.order_by(Workflow.created_at.desc())
|
.where(Workflow.app_id == app_id, Workflow.version != Workflow.VERSION_DRAFT)
|
||||||
.first()
|
.order_by(Workflow.created_at.desc())
|
||||||
)
|
)
|
||||||
else:
|
workflow = session.scalars(stmt).first()
|
||||||
stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
|
else:
|
||||||
workflow = db.session.scalar(stmt)
|
stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
|
||||||
|
workflow = session.scalar(stmt)
|
||||||
|
|
||||||
if not workflow:
|
if not workflow:
|
||||||
raise ValueError("workflow not found or not published")
|
raise ValueError("workflow not found or not published")
|
||||||
@ -200,7 +264,8 @@ class WorkflowTool(Tool):
|
|||||||
get the app by app id
|
get the app by app id
|
||||||
"""
|
"""
|
||||||
stmt = select(App).where(App.id == app_id)
|
stmt = select(App).where(App.id == app_id)
|
||||||
app = db.session.scalar(stmt)
|
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||||
|
app = session.scalar(stmt)
|
||||||
if not app:
|
if not app:
|
||||||
raise ValueError("app not found")
|
raise ValueError("app not found")
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from .types import SegmentType
|
|||||||
|
|
||||||
class SegmentGroup(Segment):
|
class SegmentGroup(Segment):
|
||||||
value_type: SegmentType = SegmentType.GROUP
|
value_type: SegmentType = SegmentType.GROUP
|
||||||
value: list[Segment] = None # type: ignore
|
value: list[Segment]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def text(self):
|
def text(self):
|
||||||
|
|||||||
@ -19,7 +19,7 @@ class Segment(BaseModel):
|
|||||||
model_config = ConfigDict(frozen=True)
|
model_config = ConfigDict(frozen=True)
|
||||||
|
|
||||||
value_type: SegmentType
|
value_type: SegmentType
|
||||||
value: Any = None
|
value: Any
|
||||||
|
|
||||||
@field_validator("value_type")
|
@field_validator("value_type")
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -74,12 +74,12 @@ class NoneSegment(Segment):
|
|||||||
|
|
||||||
class StringSegment(Segment):
|
class StringSegment(Segment):
|
||||||
value_type: SegmentType = SegmentType.STRING
|
value_type: SegmentType = SegmentType.STRING
|
||||||
value: str = None # type: ignore
|
value: str
|
||||||
|
|
||||||
|
|
||||||
class FloatSegment(Segment):
|
class FloatSegment(Segment):
|
||||||
value_type: SegmentType = SegmentType.FLOAT
|
value_type: SegmentType = SegmentType.FLOAT
|
||||||
value: float = None # type: ignore
|
value: float
|
||||||
# NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
|
# NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
|
||||||
# The following tests cannot pass.
|
# The following tests cannot pass.
|
||||||
#
|
#
|
||||||
@ -98,12 +98,12 @@ class FloatSegment(Segment):
|
|||||||
|
|
||||||
class IntegerSegment(Segment):
|
class IntegerSegment(Segment):
|
||||||
value_type: SegmentType = SegmentType.INTEGER
|
value_type: SegmentType = SegmentType.INTEGER
|
||||||
value: int = None # type: ignore
|
value: int
|
||||||
|
|
||||||
|
|
||||||
class ObjectSegment(Segment):
|
class ObjectSegment(Segment):
|
||||||
value_type: SegmentType = SegmentType.OBJECT
|
value_type: SegmentType = SegmentType.OBJECT
|
||||||
value: Mapping[str, Any] = None # type: ignore
|
value: Mapping[str, Any]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def text(self) -> str:
|
def text(self) -> str:
|
||||||
@ -136,7 +136,7 @@ class ArraySegment(Segment):
|
|||||||
|
|
||||||
class FileSegment(Segment):
|
class FileSegment(Segment):
|
||||||
value_type: SegmentType = SegmentType.FILE
|
value_type: SegmentType = SegmentType.FILE
|
||||||
value: File = None # type: ignore
|
value: File
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def markdown(self) -> str:
|
def markdown(self) -> str:
|
||||||
@ -153,17 +153,17 @@ class FileSegment(Segment):
|
|||||||
|
|
||||||
class BooleanSegment(Segment):
|
class BooleanSegment(Segment):
|
||||||
value_type: SegmentType = SegmentType.BOOLEAN
|
value_type: SegmentType = SegmentType.BOOLEAN
|
||||||
value: bool = None # type: ignore
|
value: bool
|
||||||
|
|
||||||
|
|
||||||
class ArrayAnySegment(ArraySegment):
|
class ArrayAnySegment(ArraySegment):
|
||||||
value_type: SegmentType = SegmentType.ARRAY_ANY
|
value_type: SegmentType = SegmentType.ARRAY_ANY
|
||||||
value: Sequence[Any] = None # type: ignore
|
value: Sequence[Any]
|
||||||
|
|
||||||
|
|
||||||
class ArrayStringSegment(ArraySegment):
|
class ArrayStringSegment(ArraySegment):
|
||||||
value_type: SegmentType = SegmentType.ARRAY_STRING
|
value_type: SegmentType = SegmentType.ARRAY_STRING
|
||||||
value: Sequence[str] = None # type: ignore
|
value: Sequence[str]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def text(self) -> str:
|
def text(self) -> str:
|
||||||
@ -175,17 +175,17 @@ class ArrayStringSegment(ArraySegment):
|
|||||||
|
|
||||||
class ArrayNumberSegment(ArraySegment):
|
class ArrayNumberSegment(ArraySegment):
|
||||||
value_type: SegmentType = SegmentType.ARRAY_NUMBER
|
value_type: SegmentType = SegmentType.ARRAY_NUMBER
|
||||||
value: Sequence[float | int] = None # type: ignore
|
value: Sequence[float | int]
|
||||||
|
|
||||||
|
|
||||||
class ArrayObjectSegment(ArraySegment):
|
class ArrayObjectSegment(ArraySegment):
|
||||||
value_type: SegmentType = SegmentType.ARRAY_OBJECT
|
value_type: SegmentType = SegmentType.ARRAY_OBJECT
|
||||||
value: Sequence[Mapping[str, Any]] = None # type: ignore
|
value: Sequence[Mapping[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
class ArrayFileSegment(ArraySegment):
|
class ArrayFileSegment(ArraySegment):
|
||||||
value_type: SegmentType = SegmentType.ARRAY_FILE
|
value_type: SegmentType = SegmentType.ARRAY_FILE
|
||||||
value: Sequence[File] = None # type: ignore
|
value: Sequence[File]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def markdown(self) -> str:
|
def markdown(self) -> str:
|
||||||
@ -205,7 +205,7 @@ class ArrayFileSegment(ArraySegment):
|
|||||||
|
|
||||||
class ArrayBooleanSegment(ArraySegment):
|
class ArrayBooleanSegment(ArraySegment):
|
||||||
value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
|
value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
|
||||||
value: Sequence[bool] = None # type: ignore
|
value: Sequence[bool]
|
||||||
|
|
||||||
|
|
||||||
def get_segment_discriminator(v: Any) -> SegmentType | None:
|
def get_segment_discriminator(v: Any) -> SegmentType | None:
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from ..runtime.graph_runtime_state import GraphRuntimeState
|
||||||
|
from ..runtime.variable_pool import VariablePool
|
||||||
from .agent import AgentNodeStrategyInit
|
from .agent import AgentNodeStrategyInit
|
||||||
from .graph_init_params import GraphInitParams
|
from .graph_init_params import GraphInitParams
|
||||||
from .workflow_execution import WorkflowExecution
|
from .workflow_execution import WorkflowExecution
|
||||||
@ -6,6 +8,8 @@ from .workflow_node_execution import WorkflowNodeExecution
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"AgentNodeStrategyInit",
|
"AgentNodeStrategyInit",
|
||||||
"GraphInitParams",
|
"GraphInitParams",
|
||||||
|
"GraphRuntimeState",
|
||||||
|
"VariablePool",
|
||||||
"WorkflowExecution",
|
"WorkflowExecution",
|
||||||
"WorkflowNodeExecution",
|
"WorkflowNodeExecution",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -3,11 +3,12 @@ from collections import defaultdict
|
|||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Protocol, cast, final
|
from typing import Protocol, cast, final
|
||||||
|
|
||||||
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
|
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from libs.typing import is_str, is_str_dict
|
from libs.typing import is_str, is_str_dict
|
||||||
|
|
||||||
from .edge import Edge
|
from .edge import Edge
|
||||||
|
from .validation import get_graph_validator
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -201,6 +202,17 @@ class Graph:
|
|||||||
|
|
||||||
return GraphBuilder(graph_cls=cls)
|
return GraphBuilder(graph_cls=cls)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None:
|
||||||
|
"""
|
||||||
|
Promote nodes configured with FAIL_BRANCH error strategy to branch execution type.
|
||||||
|
|
||||||
|
:param nodes: mapping of node ID to node instance
|
||||||
|
"""
|
||||||
|
for node in nodes.values():
|
||||||
|
if node.error_strategy == ErrorStrategy.FAIL_BRANCH:
|
||||||
|
node.execution_type = NodeExecutionType.BRANCH
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _mark_inactive_root_branches(
|
def _mark_inactive_root_branches(
|
||||||
cls,
|
cls,
|
||||||
@ -307,6 +319,9 @@ class Graph:
|
|||||||
# Create node instances
|
# Create node instances
|
||||||
nodes = cls._create_node_instances(node_configs_map, node_factory)
|
nodes = cls._create_node_instances(node_configs_map, node_factory)
|
||||||
|
|
||||||
|
# Promote fail-branch nodes to branch execution type at graph level
|
||||||
|
cls._promote_fail_branch_nodes(nodes)
|
||||||
|
|
||||||
# Get root node instance
|
# Get root node instance
|
||||||
root_node = nodes[root_node_id]
|
root_node = nodes[root_node_id]
|
||||||
|
|
||||||
@ -314,7 +329,7 @@ class Graph:
|
|||||||
cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)
|
cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)
|
||||||
|
|
||||||
# Create and return the graph
|
# Create and return the graph
|
||||||
return cls(
|
graph = cls(
|
||||||
nodes=nodes,
|
nodes=nodes,
|
||||||
edges=edges,
|
edges=edges,
|
||||||
in_edges=in_edges,
|
in_edges=in_edges,
|
||||||
@ -322,6 +337,11 @@ class Graph:
|
|||||||
root_node=root_node,
|
root_node=root_node,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Validate the graph structure using built-in validators
|
||||||
|
get_graph_validator().validate(graph)
|
||||||
|
|
||||||
|
return graph
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def node_ids(self) -> list[str]:
|
def node_ids(self) -> list[str]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
125
api/core/workflow/graph/validation.py
Normal file
125
api/core/workflow/graph/validation.py
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Protocol
|
||||||
|
|
||||||
|
from core.workflow.enums import NodeExecutionType, NodeType
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .graph import Graph
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class GraphValidationIssue:
|
||||||
|
"""Immutable value object describing a single validation issue."""
|
||||||
|
|
||||||
|
code: str
|
||||||
|
message: str
|
||||||
|
node_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class GraphValidationError(ValueError):
|
||||||
|
"""Raised when graph validation fails."""
|
||||||
|
|
||||||
|
def __init__(self, issues: Sequence[GraphValidationIssue]) -> None:
|
||||||
|
if not issues:
|
||||||
|
raise ValueError("GraphValidationError requires at least one issue.")
|
||||||
|
self.issues: tuple[GraphValidationIssue, ...] = tuple(issues)
|
||||||
|
message = "; ".join(f"[{issue.code}] {issue.message}" for issue in self.issues)
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class GraphValidationRule(Protocol):
|
||||||
|
"""Protocol that individual validation rules must satisfy."""
|
||||||
|
|
||||||
|
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
|
||||||
|
"""Validate the provided graph and return any discovered issues."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class _EdgeEndpointValidator:
|
||||||
|
"""Ensures all edges reference existing nodes."""
|
||||||
|
|
||||||
|
missing_node_code: str = "MISSING_NODE"
|
||||||
|
|
||||||
|
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
|
||||||
|
issues: list[GraphValidationIssue] = []
|
||||||
|
for edge in graph.edges.values():
|
||||||
|
if edge.tail not in graph.nodes:
|
||||||
|
issues.append(
|
||||||
|
GraphValidationIssue(
|
||||||
|
code=self.missing_node_code,
|
||||||
|
message=f"Edge {edge.id} references unknown source node '{edge.tail}'.",
|
||||||
|
node_id=edge.tail,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if edge.head not in graph.nodes:
|
||||||
|
issues.append(
|
||||||
|
GraphValidationIssue(
|
||||||
|
code=self.missing_node_code,
|
||||||
|
message=f"Edge {edge.id} references unknown target node '{edge.head}'.",
|
||||||
|
node_id=edge.head,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return issues
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class _RootNodeValidator:
|
||||||
|
"""Validates root node invariants."""
|
||||||
|
|
||||||
|
invalid_root_code: str = "INVALID_ROOT"
|
||||||
|
container_entry_types: tuple[NodeType, ...] = (NodeType.ITERATION_START, NodeType.LOOP_START)
|
||||||
|
|
||||||
|
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
|
||||||
|
root_node = graph.root_node
|
||||||
|
issues: list[GraphValidationIssue] = []
|
||||||
|
if root_node.id not in graph.nodes:
|
||||||
|
issues.append(
|
||||||
|
GraphValidationIssue(
|
||||||
|
code=self.invalid_root_code,
|
||||||
|
message=f"Root node '{root_node.id}' is missing from the node registry.",
|
||||||
|
node_id=root_node.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return issues
|
||||||
|
|
||||||
|
node_type = getattr(root_node, "node_type", None)
|
||||||
|
if root_node.execution_type != NodeExecutionType.ROOT and node_type not in self.container_entry_types:
|
||||||
|
issues.append(
|
||||||
|
GraphValidationIssue(
|
||||||
|
code=self.invalid_root_code,
|
||||||
|
message=f"Root node '{root_node.id}' must declare execution type 'root'.",
|
||||||
|
node_id=root_node.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return issues
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class GraphValidator:
|
||||||
|
"""Coordinates execution of graph validation rules."""
|
||||||
|
|
||||||
|
rules: tuple[GraphValidationRule, ...]
|
||||||
|
|
||||||
|
def validate(self, graph: Graph) -> None:
|
||||||
|
"""Validate the graph against all configured rules."""
|
||||||
|
issues: list[GraphValidationIssue] = []
|
||||||
|
for rule in self.rules:
|
||||||
|
issues.extend(rule.validate(graph))
|
||||||
|
|
||||||
|
if issues:
|
||||||
|
raise GraphValidationError(issues)
|
||||||
|
|
||||||
|
|
||||||
|
_DEFAULT_RULES: tuple[GraphValidationRule, ...] = (
|
||||||
|
_EdgeEndpointValidator(),
|
||||||
|
_RootNodeValidator(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_graph_validator() -> GraphValidator:
|
||||||
|
"""Construct the validator composed of default rules."""
|
||||||
|
return GraphValidator(_DEFAULT_RULES)
|
||||||
@ -26,8 +26,8 @@ class AgentNodeData(BaseNodeData):
|
|||||||
|
|
||||||
|
|
||||||
class ParamsAutoGenerated(IntEnum):
|
class ParamsAutoGenerated(IntEnum):
|
||||||
CLOSE = auto()
|
CLOSE = 0
|
||||||
OPEN = auto()
|
OPEN = 1
|
||||||
|
|
||||||
|
|
||||||
class AgentOldVersionModelFeatures(StrEnum):
|
class AgentOldVersionModelFeatures(StrEnum):
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
|
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
|
||||||
|
from .usage_tracking_mixin import LLMUsageTrackingMixin
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseIterationNodeData",
|
"BaseIterationNodeData",
|
||||||
@ -6,4 +7,5 @@ __all__ = [
|
|||||||
"BaseLoopNodeData",
|
"BaseLoopNodeData",
|
||||||
"BaseLoopState",
|
"BaseLoopState",
|
||||||
"BaseNodeData",
|
"BaseNodeData",
|
||||||
|
"LLMUsageTrackingMixin",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
from builtins import type as type_
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
@ -58,10 +59,9 @@ class DefaultValue(BaseModel):
|
|||||||
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
|
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
|
def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
|
||||||
"""Unified array type validation"""
|
"""Unified array type validation"""
|
||||||
# FIXME, type ignore here for do not find the reason mypy complain, if find the root cause, please fix it
|
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
|
||||||
return isinstance(value, list) and all(isinstance(x, element_type) for x in value) # type: ignore
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _convert_number(value: str) -> float:
|
def _convert_number(value: str) -> float:
|
||||||
|
|||||||
28
api/core/workflow/nodes/base/usage_tracking_mixin.py
Normal file
28
api/core/workflow/nodes/base/usage_tracking_mixin.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
|
from core.workflow.runtime import GraphRuntimeState
|
||||||
|
|
||||||
|
|
||||||
|
class LLMUsageTrackingMixin:
|
||||||
|
"""Provides shared helpers for merging and recording LLM usage within workflow nodes."""
|
||||||
|
|
||||||
|
graph_runtime_state: GraphRuntimeState
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _merge_usage(current: LLMUsage, new_usage: LLMUsage | None) -> LLMUsage:
|
||||||
|
"""Return a combined usage snapshot, preserving zero-value inputs."""
|
||||||
|
if new_usage is None or new_usage.total_tokens <= 0:
|
||||||
|
return current
|
||||||
|
if current.total_tokens == 0:
|
||||||
|
return new_usage
|
||||||
|
return current.plus(new_usage)
|
||||||
|
|
||||||
|
def _accumulate_usage(self, usage: LLMUsage) -> None:
|
||||||
|
"""Push usage into the graph runtime accumulator for downstream reporting."""
|
||||||
|
if usage.total_tokens <= 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
current_usage = self.graph_runtime_state.llm_usage
|
||||||
|
if current_usage.total_tokens == 0:
|
||||||
|
self.graph_runtime_state.llm_usage = usage.model_copy()
|
||||||
|
else:
|
||||||
|
self.graph_runtime_state.llm_usage = current_usage.plus(usage)
|
||||||
@ -10,10 +10,10 @@ from typing import Any
|
|||||||
import chardet
|
import chardet
|
||||||
import docx
|
import docx
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import pypandoc # type: ignore
|
import pypandoc
|
||||||
import pypdfium2 # type: ignore
|
import pypdfium2
|
||||||
import webvtt # type: ignore
|
import webvtt
|
||||||
import yaml # type: ignore
|
import yaml
|
||||||
from docx.document import Document
|
from docx.document import Document
|
||||||
from docx.oxml.table import CT_Tbl
|
from docx.oxml.table import CT_Tbl
|
||||||
from docx.oxml.text.paragraph import CT_P
|
from docx.oxml.text.paragraph import CT_P
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, NewType, cast
|
|||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
from typing_extensions import TypeIs
|
from typing_extensions import TypeIs
|
||||||
|
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
from core.variables import IntegerVariable, NoneSegment
|
from core.variables import IntegerVariable, NoneSegment
|
||||||
from core.variables.segments import ArrayAnySegment, ArraySegment
|
from core.variables.segments import ArrayAnySegment, ArraySegment
|
||||||
from core.variables.variables import VariableUnion
|
from core.variables.variables import VariableUnion
|
||||||
@ -34,6 +35,7 @@ from core.workflow.node_events import (
|
|||||||
NodeRunResult,
|
NodeRunResult,
|
||||||
StreamCompletedEvent,
|
StreamCompletedEvent,
|
||||||
)
|
)
|
||||||
|
from core.workflow.nodes.base import LLMUsageTrackingMixin
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
||||||
@ -58,7 +60,7 @@ logger = logging.getLogger(__name__)
|
|||||||
EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)
|
EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)
|
||||||
|
|
||||||
|
|
||||||
class IterationNode(Node):
|
class IterationNode(LLMUsageTrackingMixin, Node):
|
||||||
"""
|
"""
|
||||||
Iteration Node.
|
Iteration Node.
|
||||||
"""
|
"""
|
||||||
@ -118,6 +120,7 @@ class IterationNode(Node):
|
|||||||
started_at = naive_utc_now()
|
started_at = naive_utc_now()
|
||||||
iter_run_map: dict[str, float] = {}
|
iter_run_map: dict[str, float] = {}
|
||||||
outputs: list[object] = []
|
outputs: list[object] = []
|
||||||
|
usage_accumulator = [LLMUsage.empty_usage()]
|
||||||
|
|
||||||
yield IterationStartedEvent(
|
yield IterationStartedEvent(
|
||||||
start_at=started_at,
|
start_at=started_at,
|
||||||
@ -130,22 +133,27 @@ class IterationNode(Node):
|
|||||||
iterator_list_value=iterator_list_value,
|
iterator_list_value=iterator_list_value,
|
||||||
outputs=outputs,
|
outputs=outputs,
|
||||||
iter_run_map=iter_run_map,
|
iter_run_map=iter_run_map,
|
||||||
|
usage_accumulator=usage_accumulator,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._accumulate_usage(usage_accumulator[0])
|
||||||
yield from self._handle_iteration_success(
|
yield from self._handle_iteration_success(
|
||||||
started_at=started_at,
|
started_at=started_at,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
outputs=outputs,
|
outputs=outputs,
|
||||||
iterator_list_value=iterator_list_value,
|
iterator_list_value=iterator_list_value,
|
||||||
iter_run_map=iter_run_map,
|
iter_run_map=iter_run_map,
|
||||||
|
usage=usage_accumulator[0],
|
||||||
)
|
)
|
||||||
except IterationNodeError as e:
|
except IterationNodeError as e:
|
||||||
|
self._accumulate_usage(usage_accumulator[0])
|
||||||
yield from self._handle_iteration_failure(
|
yield from self._handle_iteration_failure(
|
||||||
started_at=started_at,
|
started_at=started_at,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
outputs=outputs,
|
outputs=outputs,
|
||||||
iterator_list_value=iterator_list_value,
|
iterator_list_value=iterator_list_value,
|
||||||
iter_run_map=iter_run_map,
|
iter_run_map=iter_run_map,
|
||||||
|
usage=usage_accumulator[0],
|
||||||
error=e,
|
error=e,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -196,6 +204,7 @@ class IterationNode(Node):
|
|||||||
iterator_list_value: Sequence[object],
|
iterator_list_value: Sequence[object],
|
||||||
outputs: list[object],
|
outputs: list[object],
|
||||||
iter_run_map: dict[str, float],
|
iter_run_map: dict[str, float],
|
||||||
|
usage_accumulator: list[LLMUsage],
|
||||||
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
|
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
|
||||||
if self._node_data.is_parallel:
|
if self._node_data.is_parallel:
|
||||||
# Parallel mode execution
|
# Parallel mode execution
|
||||||
@ -203,6 +212,7 @@ class IterationNode(Node):
|
|||||||
iterator_list_value=iterator_list_value,
|
iterator_list_value=iterator_list_value,
|
||||||
outputs=outputs,
|
outputs=outputs,
|
||||||
iter_run_map=iter_run_map,
|
iter_run_map=iter_run_map,
|
||||||
|
usage_accumulator=usage_accumulator,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Sequential mode execution
|
# Sequential mode execution
|
||||||
@ -228,6 +238,9 @@ class IterationNode(Node):
|
|||||||
|
|
||||||
# Update the total tokens from this iteration
|
# Update the total tokens from this iteration
|
||||||
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
|
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
|
||||||
|
usage_accumulator[0] = self._merge_usage(
|
||||||
|
usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage
|
||||||
|
)
|
||||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||||
|
|
||||||
def _execute_parallel_iterations(
|
def _execute_parallel_iterations(
|
||||||
@ -235,6 +248,7 @@ class IterationNode(Node):
|
|||||||
iterator_list_value: Sequence[object],
|
iterator_list_value: Sequence[object],
|
||||||
outputs: list[object],
|
outputs: list[object],
|
||||||
iter_run_map: dict[str, float],
|
iter_run_map: dict[str, float],
|
||||||
|
usage_accumulator: list[LLMUsage],
|
||||||
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
|
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
|
||||||
# Initialize outputs list with None values to maintain order
|
# Initialize outputs list with None values to maintain order
|
||||||
outputs.extend([None] * len(iterator_list_value))
|
outputs.extend([None] * len(iterator_list_value))
|
||||||
@ -245,7 +259,16 @@ class IterationNode(Node):
|
|||||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
# Submit all iteration tasks
|
# Submit all iteration tasks
|
||||||
future_to_index: dict[
|
future_to_index: dict[
|
||||||
Future[tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]],
|
Future[
|
||||||
|
tuple[
|
||||||
|
datetime,
|
||||||
|
list[GraphNodeEventBase],
|
||||||
|
object | None,
|
||||||
|
int,
|
||||||
|
dict[str, VariableUnion],
|
||||||
|
LLMUsage,
|
||||||
|
]
|
||||||
|
],
|
||||||
int,
|
int,
|
||||||
] = {}
|
] = {}
|
||||||
for index, item in enumerate(iterator_list_value):
|
for index, item in enumerate(iterator_list_value):
|
||||||
@ -264,7 +287,14 @@ class IterationNode(Node):
|
|||||||
index = future_to_index[future]
|
index = future_to_index[future]
|
||||||
try:
|
try:
|
||||||
result = future.result()
|
result = future.result()
|
||||||
iter_start_at, events, output_value, tokens_used, conversation_snapshot = result
|
(
|
||||||
|
iter_start_at,
|
||||||
|
events,
|
||||||
|
output_value,
|
||||||
|
tokens_used,
|
||||||
|
conversation_snapshot,
|
||||||
|
iteration_usage,
|
||||||
|
) = result
|
||||||
|
|
||||||
# Update outputs at the correct index
|
# Update outputs at the correct index
|
||||||
outputs[index] = output_value
|
outputs[index] = output_value
|
||||||
@ -276,6 +306,8 @@ class IterationNode(Node):
|
|||||||
self.graph_runtime_state.total_tokens += tokens_used
|
self.graph_runtime_state.total_tokens += tokens_used
|
||||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||||
|
|
||||||
|
usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage)
|
||||||
|
|
||||||
# Sync conversation variables after iteration completion
|
# Sync conversation variables after iteration completion
|
||||||
self._sync_conversation_variables_from_snapshot(conversation_snapshot)
|
self._sync_conversation_variables_from_snapshot(conversation_snapshot)
|
||||||
|
|
||||||
@ -303,7 +335,7 @@ class IterationNode(Node):
|
|||||||
item: object,
|
item: object,
|
||||||
flask_app: Flask,
|
flask_app: Flask,
|
||||||
context_vars: contextvars.Context,
|
context_vars: contextvars.Context,
|
||||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]:
|
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion], LLMUsage]:
|
||||||
"""Execute a single iteration in parallel mode and return results."""
|
"""Execute a single iteration in parallel mode and return results."""
|
||||||
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
|
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
|
||||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
@ -332,6 +364,7 @@ class IterationNode(Node):
|
|||||||
output_value,
|
output_value,
|
||||||
graph_engine.graph_runtime_state.total_tokens,
|
graph_engine.graph_runtime_state.total_tokens,
|
||||||
conversation_snapshot,
|
conversation_snapshot,
|
||||||
|
graph_engine.graph_runtime_state.llm_usage,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_iteration_success(
|
def _handle_iteration_success(
|
||||||
@ -341,6 +374,8 @@ class IterationNode(Node):
|
|||||||
outputs: list[object],
|
outputs: list[object],
|
||||||
iterator_list_value: Sequence[object],
|
iterator_list_value: Sequence[object],
|
||||||
iter_run_map: dict[str, float],
|
iter_run_map: dict[str, float],
|
||||||
|
*,
|
||||||
|
usage: LLMUsage,
|
||||||
) -> Generator[NodeEventBase, None, None]:
|
) -> Generator[NodeEventBase, None, None]:
|
||||||
# Flatten the list of lists if all outputs are lists
|
# Flatten the list of lists if all outputs are lists
|
||||||
flattened_outputs = self._flatten_outputs_if_needed(outputs)
|
flattened_outputs = self._flatten_outputs_if_needed(outputs)
|
||||||
@ -351,7 +386,9 @@ class IterationNode(Node):
|
|||||||
outputs={"output": flattened_outputs},
|
outputs={"output": flattened_outputs},
|
||||||
steps=len(iterator_list_value),
|
steps=len(iterator_list_value),
|
||||||
metadata={
|
metadata={
|
||||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||||
|
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||||
|
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||||
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@ -362,8 +399,11 @@ class IterationNode(Node):
|
|||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
outputs={"output": flattened_outputs},
|
outputs={"output": flattened_outputs},
|
||||||
metadata={
|
metadata={
|
||||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||||
|
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||||
|
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||||
},
|
},
|
||||||
|
llm_usage=usage,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -400,6 +440,8 @@ class IterationNode(Node):
|
|||||||
outputs: list[object],
|
outputs: list[object],
|
||||||
iterator_list_value: Sequence[object],
|
iterator_list_value: Sequence[object],
|
||||||
iter_run_map: dict[str, float],
|
iter_run_map: dict[str, float],
|
||||||
|
*,
|
||||||
|
usage: LLMUsage,
|
||||||
error: IterationNodeError,
|
error: IterationNodeError,
|
||||||
) -> Generator[NodeEventBase, None, None]:
|
) -> Generator[NodeEventBase, None, None]:
|
||||||
# Flatten the list of lists if all outputs are lists (even in failure case)
|
# Flatten the list of lists if all outputs are lists (even in failure case)
|
||||||
@ -411,7 +453,9 @@ class IterationNode(Node):
|
|||||||
outputs={"output": flattened_outputs},
|
outputs={"output": flattened_outputs},
|
||||||
steps=len(iterator_list_value),
|
steps=len(iterator_list_value),
|
||||||
metadata={
|
metadata={
|
||||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||||
|
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||||
|
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||||
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
||||||
},
|
},
|
||||||
error=str(error),
|
error=str(error),
|
||||||
@ -420,6 +464,12 @@ class IterationNode(Node):
|
|||||||
node_run_result=NodeRunResult(
|
node_run_result=NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.FAILED,
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
error=str(error),
|
error=str(error),
|
||||||
|
metadata={
|
||||||
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||||
|
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||||
|
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||||
|
},
|
||||||
|
llm_usage=usage,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -15,14 +15,11 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
|
|||||||
from core.entities.agent_entities import PlanningStrategy
|
from core.entities.agent_entities import PlanningStrategy
|
||||||
from core.entities.model_entities import ModelStatus
|
from core.entities.model_entities import ModelStatus
|
||||||
from core.model_manager import ModelInstance, ModelManager
|
from core.model_manager import ModelInstance, ModelManager
|
||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
PromptMessageRole,
|
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||||
)
|
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||||
from core.model_runtime.entities.model_entities import (
|
|
||||||
ModelFeature,
|
|
||||||
ModelType,
|
|
||||||
)
|
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.prompt.simple_prompt_transform import ModelMode
|
from core.prompt.simple_prompt_transform import ModelMode
|
||||||
from core.rag.datasource.retrieval_service import RetrievalService
|
from core.rag.datasource.retrieval_service import RetrievalService
|
||||||
from core.rag.entities.metadata_entities import Condition, MetadataCondition
|
from core.rag.entities.metadata_entities import Condition, MetadataCondition
|
||||||
@ -33,8 +30,14 @@ from core.variables import (
|
|||||||
)
|
)
|
||||||
from core.variables.segments import ArrayObjectSegment
|
from core.variables.segments import ArrayObjectSegment
|
||||||
from core.workflow.entities import GraphInitParams
|
from core.workflow.entities import GraphInitParams
|
||||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
from core.workflow.enums import (
|
||||||
|
ErrorStrategy,
|
||||||
|
NodeType,
|
||||||
|
WorkflowNodeExecutionMetadataKey,
|
||||||
|
WorkflowNodeExecutionStatus,
|
||||||
|
)
|
||||||
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
|
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
|
||||||
|
from core.workflow.nodes.base import LLMUsageTrackingMixin
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
|
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
|
||||||
@ -80,7 +83,7 @@ default_retrieval_model = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeRetrievalNode(Node):
|
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node):
|
||||||
node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
||||||
|
|
||||||
_node_data: KnowledgeRetrievalNodeData
|
_node_data: KnowledgeRetrievalNodeData
|
||||||
@ -141,7 +144,7 @@ class KnowledgeRetrievalNode(Node):
|
|||||||
def version(cls):
|
def version(cls):
|
||||||
return "1"
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult: # type: ignore
|
def _run(self) -> NodeRunResult:
|
||||||
# extract variables
|
# extract variables
|
||||||
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
|
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
|
||||||
if not isinstance(variable, StringSegment):
|
if not isinstance(variable, StringSegment):
|
||||||
@ -182,14 +185,21 @@ class KnowledgeRetrievalNode(Node):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# retrieve knowledge
|
# retrieve knowledge
|
||||||
|
usage = LLMUsage.empty_usage()
|
||||||
try:
|
try:
|
||||||
results = self._fetch_dataset_retriever(node_data=self._node_data, query=query)
|
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, query=query)
|
||||||
outputs = {"result": ArrayObjectSegment(value=results)}
|
outputs = {"result": ArrayObjectSegment(value=results)}
|
||||||
return NodeRunResult(
|
return NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
inputs=variables,
|
inputs=variables,
|
||||||
process_data={},
|
process_data={"usage": jsonable_encoder(usage)},
|
||||||
outputs=outputs, # type: ignore
|
outputs=outputs, # type: ignore
|
||||||
|
metadata={
|
||||||
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||||
|
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||||
|
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||||
|
},
|
||||||
|
llm_usage=usage,
|
||||||
)
|
)
|
||||||
|
|
||||||
except KnowledgeRetrievalNodeError as e:
|
except KnowledgeRetrievalNodeError as e:
|
||||||
@ -199,6 +209,7 @@ class KnowledgeRetrievalNode(Node):
|
|||||||
inputs=variables,
|
inputs=variables,
|
||||||
error=str(e),
|
error=str(e),
|
||||||
error_type=type(e).__name__,
|
error_type=type(e).__name__,
|
||||||
|
llm_usage=usage,
|
||||||
)
|
)
|
||||||
# Temporary handle all exceptions from DatasetRetrieval class here.
|
# Temporary handle all exceptions from DatasetRetrieval class here.
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -207,11 +218,15 @@ class KnowledgeRetrievalNode(Node):
|
|||||||
inputs=variables,
|
inputs=variables,
|
||||||
error=str(e),
|
error=str(e),
|
||||||
error_type=type(e).__name__,
|
error_type=type(e).__name__,
|
||||||
|
llm_usage=usage,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
db.session.close()
|
db.session.close()
|
||||||
|
|
||||||
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]:
|
def _fetch_dataset_retriever(
|
||||||
|
self, node_data: KnowledgeRetrievalNodeData, query: str
|
||||||
|
) -> tuple[list[dict[str, Any]], LLMUsage]:
|
||||||
|
usage = LLMUsage.empty_usage()
|
||||||
available_datasets = []
|
available_datasets = []
|
||||||
dataset_ids = node_data.dataset_ids
|
dataset_ids = node_data.dataset_ids
|
||||||
|
|
||||||
@ -245,9 +260,10 @@ class KnowledgeRetrievalNode(Node):
|
|||||||
if not dataset:
|
if not dataset:
|
||||||
continue
|
continue
|
||||||
available_datasets.append(dataset)
|
available_datasets.append(dataset)
|
||||||
metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition(
|
metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition(
|
||||||
[dataset.id for dataset in available_datasets], query, node_data
|
[dataset.id for dataset in available_datasets], query, node_data
|
||||||
)
|
)
|
||||||
|
usage = self._merge_usage(usage, metadata_usage)
|
||||||
all_documents = []
|
all_documents = []
|
||||||
dataset_retrieval = DatasetRetrieval()
|
dataset_retrieval = DatasetRetrieval()
|
||||||
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
||||||
@ -330,6 +346,8 @@ class KnowledgeRetrievalNode(Node):
|
|||||||
metadata_filter_document_ids=metadata_filter_document_ids,
|
metadata_filter_document_ids=metadata_filter_document_ids,
|
||||||
metadata_condition=metadata_condition,
|
metadata_condition=metadata_condition,
|
||||||
)
|
)
|
||||||
|
usage = self._merge_usage(usage, dataset_retrieval.llm_usage)
|
||||||
|
|
||||||
dify_documents = [item for item in all_documents if item.provider == "dify"]
|
dify_documents = [item for item in all_documents if item.provider == "dify"]
|
||||||
external_documents = [item for item in all_documents if item.provider == "external"]
|
external_documents = [item for item in all_documents if item.provider == "external"]
|
||||||
retrieval_resource_list = []
|
retrieval_resource_list = []
|
||||||
@ -406,11 +424,12 @@ class KnowledgeRetrievalNode(Node):
|
|||||||
)
|
)
|
||||||
for position, item in enumerate(retrieval_resource_list, start=1):
|
for position, item in enumerate(retrieval_resource_list, start=1):
|
||||||
item["metadata"]["position"] = position
|
item["metadata"]["position"] = position
|
||||||
return retrieval_resource_list
|
return retrieval_resource_list, usage
|
||||||
|
|
||||||
def _get_metadata_filter_condition(
|
def _get_metadata_filter_condition(
|
||||||
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
||||||
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]:
|
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]:
|
||||||
|
usage = LLMUsage.empty_usage()
|
||||||
document_query = db.session.query(Document).where(
|
document_query = db.session.query(Document).where(
|
||||||
Document.dataset_id.in_(dataset_ids),
|
Document.dataset_id.in_(dataset_ids),
|
||||||
Document.indexing_status == "completed",
|
Document.indexing_status == "completed",
|
||||||
@ -420,9 +439,12 @@ class KnowledgeRetrievalNode(Node):
|
|||||||
filters: list[Any] = []
|
filters: list[Any] = []
|
||||||
metadata_condition = None
|
metadata_condition = None
|
||||||
if node_data.metadata_filtering_mode == "disabled":
|
if node_data.metadata_filtering_mode == "disabled":
|
||||||
return None, None
|
return None, None, usage
|
||||||
elif node_data.metadata_filtering_mode == "automatic":
|
elif node_data.metadata_filtering_mode == "automatic":
|
||||||
automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data)
|
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
|
||||||
|
dataset_ids, query, node_data
|
||||||
|
)
|
||||||
|
usage = self._merge_usage(usage, automatic_usage)
|
||||||
if automatic_metadata_filters:
|
if automatic_metadata_filters:
|
||||||
conditions = []
|
conditions = []
|
||||||
for sequence, filter in enumerate(automatic_metadata_filters):
|
for sequence, filter in enumerate(automatic_metadata_filters):
|
||||||
@ -443,7 +465,7 @@ class KnowledgeRetrievalNode(Node):
|
|||||||
metadata_condition = MetadataCondition(
|
metadata_condition = MetadataCondition(
|
||||||
logical_operator=node_data.metadata_filtering_conditions.logical_operator
|
logical_operator=node_data.metadata_filtering_conditions.logical_operator
|
||||||
if node_data.metadata_filtering_conditions
|
if node_data.metadata_filtering_conditions
|
||||||
else "or", # type: ignore
|
else "or",
|
||||||
conditions=conditions,
|
conditions=conditions,
|
||||||
)
|
)
|
||||||
elif node_data.metadata_filtering_mode == "manual":
|
elif node_data.metadata_filtering_mode == "manual":
|
||||||
@ -457,10 +479,10 @@ class KnowledgeRetrievalNode(Node):
|
|||||||
expected_value = self.graph_runtime_state.variable_pool.convert_template(
|
expected_value = self.graph_runtime_state.variable_pool.convert_template(
|
||||||
expected_value
|
expected_value
|
||||||
).value[0]
|
).value[0]
|
||||||
if expected_value.value_type in {"number", "integer", "float"}: # type: ignore
|
if expected_value.value_type in {"number", "integer", "float"}:
|
||||||
expected_value = expected_value.value # type: ignore
|
expected_value = expected_value.value
|
||||||
elif expected_value.value_type == "string": # type: ignore
|
elif expected_value.value_type == "string":
|
||||||
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore
|
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid expected metadata value type")
|
raise ValueError("Invalid expected metadata value type")
|
||||||
conditions.append(
|
conditions.append(
|
||||||
@ -487,7 +509,7 @@ class KnowledgeRetrievalNode(Node):
|
|||||||
if (
|
if (
|
||||||
node_data.metadata_filtering_conditions
|
node_data.metadata_filtering_conditions
|
||||||
and node_data.metadata_filtering_conditions.logical_operator == "and"
|
and node_data.metadata_filtering_conditions.logical_operator == "and"
|
||||||
): # type: ignore
|
):
|
||||||
document_query = document_query.where(and_(*filters))
|
document_query = document_query.where(and_(*filters))
|
||||||
else:
|
else:
|
||||||
document_query = document_query.where(or_(*filters))
|
document_query = document_query.where(or_(*filters))
|
||||||
@ -496,11 +518,12 @@ class KnowledgeRetrievalNode(Node):
|
|||||||
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
|
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
|
||||||
for document in documents:
|
for document in documents:
|
||||||
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
|
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
|
||||||
return metadata_filter_document_ids, metadata_condition
|
return metadata_filter_document_ids, metadata_condition, usage
|
||||||
|
|
||||||
def _automatic_metadata_filter_func(
|
def _automatic_metadata_filter_func(
|
||||||
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
||||||
) -> list[dict[str, Any]]:
|
) -> tuple[list[dict[str, Any]], LLMUsage]:
|
||||||
|
usage = LLMUsage.empty_usage()
|
||||||
# get all metadata field
|
# get all metadata field
|
||||||
stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
|
stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
|
||||||
metadata_fields = db.session.scalars(stmt).all()
|
metadata_fields = db.session.scalars(stmt).all()
|
||||||
@ -548,6 +571,7 @@ class KnowledgeRetrievalNode(Node):
|
|||||||
for event in generator:
|
for event in generator:
|
||||||
if isinstance(event, ModelInvokeCompletedEvent):
|
if isinstance(event, ModelInvokeCompletedEvent):
|
||||||
result_text = event.text
|
result_text = event.text
|
||||||
|
usage = self._merge_usage(usage, event.usage)
|
||||||
break
|
break
|
||||||
|
|
||||||
result_text_json = parse_and_check_json_markdown(result_text, [])
|
result_text_json = parse_and_check_json_markdown(result_text, [])
|
||||||
@ -564,8 +588,8 @@ class KnowledgeRetrievalNode(Node):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
return []
|
return [], usage
|
||||||
return automatic_metadata_filters
|
return automatic_metadata_filters, usage
|
||||||
|
|
||||||
def _process_metadata_filter_func(
|
def _process_metadata_filter_func(
|
||||||
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]
|
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]
|
||||||
|
|||||||
@ -441,10 +441,14 @@ class LLMNode(Node):
|
|||||||
usage = LLMUsage.empty_usage()
|
usage = LLMUsage.empty_usage()
|
||||||
finish_reason = None
|
finish_reason = None
|
||||||
full_text_buffer = io.StringIO()
|
full_text_buffer = io.StringIO()
|
||||||
|
collected_structured_output = None # Collect structured_output from streaming chunks
|
||||||
# Consume the invoke result and handle generator exception
|
# Consume the invoke result and handle generator exception
|
||||||
try:
|
try:
|
||||||
for result in invoke_result:
|
for result in invoke_result:
|
||||||
if isinstance(result, LLMResultChunkWithStructuredOutput):
|
if isinstance(result, LLMResultChunkWithStructuredOutput):
|
||||||
|
# Collect structured_output from the chunk
|
||||||
|
if result.structured_output is not None:
|
||||||
|
collected_structured_output = dict(result.structured_output)
|
||||||
yield result
|
yield result
|
||||||
if isinstance(result, LLMResultChunk):
|
if isinstance(result, LLMResultChunk):
|
||||||
contents = result.delta.message.content
|
contents = result.delta.message.content
|
||||||
@ -492,6 +496,8 @@ class LLMNode(Node):
|
|||||||
finish_reason=finish_reason,
|
finish_reason=finish_reason,
|
||||||
# Reasoning content for workflow variables and downstream nodes
|
# Reasoning content for workflow variables and downstream nodes
|
||||||
reasoning_content=reasoning_content,
|
reasoning_content=reasoning_content,
|
||||||
|
# Pass structured output if collected from streaming chunks
|
||||||
|
structured_output=collected_structured_output,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@ -5,6 +5,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||||
|
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
from core.variables import Segment, SegmentType
|
from core.variables import Segment, SegmentType
|
||||||
from core.workflow.enums import (
|
from core.workflow.enums import (
|
||||||
ErrorStrategy,
|
ErrorStrategy,
|
||||||
@ -27,6 +28,7 @@ from core.workflow.node_events import (
|
|||||||
NodeRunResult,
|
NodeRunResult,
|
||||||
StreamCompletedEvent,
|
StreamCompletedEvent,
|
||||||
)
|
)
|
||||||
|
from core.workflow.nodes.base import LLMUsageTrackingMixin
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
|
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
|
||||||
@ -40,7 +42,7 @@ if TYPE_CHECKING:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LoopNode(Node):
|
class LoopNode(LLMUsageTrackingMixin, Node):
|
||||||
"""
|
"""
|
||||||
Loop Node.
|
Loop Node.
|
||||||
"""
|
"""
|
||||||
@ -108,7 +110,7 @@ class LoopNode(Node):
|
|||||||
raise ValueError(f"Invalid value for loop variable {loop_variable.label}")
|
raise ValueError(f"Invalid value for loop variable {loop_variable.label}")
|
||||||
variable_selector = [self._node_id, loop_variable.label]
|
variable_selector = [self._node_id, loop_variable.label]
|
||||||
variable = segment_to_variable(segment=processed_segment, selector=variable_selector)
|
variable = segment_to_variable(segment=processed_segment, selector=variable_selector)
|
||||||
self.graph_runtime_state.variable_pool.add(variable_selector, variable)
|
self.graph_runtime_state.variable_pool.add(variable_selector, variable.value)
|
||||||
loop_variable_selectors[loop_variable.label] = variable_selector
|
loop_variable_selectors[loop_variable.label] = variable_selector
|
||||||
inputs[loop_variable.label] = processed_segment.value
|
inputs[loop_variable.label] = processed_segment.value
|
||||||
|
|
||||||
@ -117,6 +119,7 @@ class LoopNode(Node):
|
|||||||
|
|
||||||
loop_duration_map: dict[str, float] = {}
|
loop_duration_map: dict[str, float] = {}
|
||||||
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
|
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
|
||||||
|
loop_usage = LLMUsage.empty_usage()
|
||||||
|
|
||||||
# Start Loop event
|
# Start Loop event
|
||||||
yield LoopStartedEvent(
|
yield LoopStartedEvent(
|
||||||
@ -163,6 +166,9 @@ class LoopNode(Node):
|
|||||||
# Update the total tokens from this iteration
|
# Update the total tokens from this iteration
|
||||||
cost_tokens += graph_engine.graph_runtime_state.total_tokens
|
cost_tokens += graph_engine.graph_runtime_state.total_tokens
|
||||||
|
|
||||||
|
# Accumulate usage from the sub-graph execution
|
||||||
|
loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage)
|
||||||
|
|
||||||
# Collect loop variable values after iteration
|
# Collect loop variable values after iteration
|
||||||
single_loop_variable = {}
|
single_loop_variable = {}
|
||||||
for key, selector in loop_variable_selectors.items():
|
for key, selector in loop_variable_selectors.items():
|
||||||
@ -189,6 +195,7 @@ class LoopNode(Node):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.graph_runtime_state.total_tokens += cost_tokens
|
self.graph_runtime_state.total_tokens += cost_tokens
|
||||||
|
self._accumulate_usage(loop_usage)
|
||||||
# Loop completed successfully
|
# Loop completed successfully
|
||||||
yield LoopSucceededEvent(
|
yield LoopSucceededEvent(
|
||||||
start_at=start_at,
|
start_at=start_at,
|
||||||
@ -196,7 +203,9 @@ class LoopNode(Node):
|
|||||||
outputs=self._node_data.outputs,
|
outputs=self._node_data.outputs,
|
||||||
steps=loop_count,
|
steps=loop_count,
|
||||||
metadata={
|
metadata={
|
||||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: cost_tokens,
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
|
||||||
|
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
|
||||||
|
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
|
||||||
"completed_reason": "loop_break" if reach_break_condition else "loop_completed",
|
"completed_reason": "loop_break" if reach_break_condition else "loop_completed",
|
||||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||||
@ -207,22 +216,28 @@ class LoopNode(Node):
|
|||||||
node_run_result=NodeRunResult(
|
node_run_result=NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
metadata={
|
metadata={
|
||||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
|
||||||
|
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
|
||||||
|
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
|
||||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||||
},
|
},
|
||||||
outputs=self._node_data.outputs,
|
outputs=self._node_data.outputs,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
|
llm_usage=loop_usage,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
self._accumulate_usage(loop_usage)
|
||||||
yield LoopFailedEvent(
|
yield LoopFailedEvent(
|
||||||
start_at=start_at,
|
start_at=start_at,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
steps=loop_count,
|
steps=loop_count,
|
||||||
metadata={
|
metadata={
|
||||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
|
||||||
|
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
|
||||||
|
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
|
||||||
"completed_reason": "error",
|
"completed_reason": "error",
|
||||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||||
@ -235,10 +250,13 @@ class LoopNode(Node):
|
|||||||
status=WorkflowNodeExecutionStatus.FAILED,
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
error=str(e),
|
error=str(e),
|
||||||
metadata={
|
metadata={
|
||||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
|
||||||
|
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
|
||||||
|
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
|
||||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||||
},
|
},
|
||||||
|
llm_usage=loop_usage,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, final
|
|||||||
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
|
from core.workflow.enums import NodeType
|
||||||
from core.workflow.graph import NodeFactory
|
from core.workflow.graph import NodeFactory
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from libs.typing import is_str, is_str_dict
|
from libs.typing import is_str, is_str_dict
|
||||||
@ -82,8 +82,4 @@ class DifyNodeFactory(NodeFactory):
|
|||||||
raise ValueError(f"Node {node_id} missing data information")
|
raise ValueError(f"Node {node_id} missing data information")
|
||||||
node_instance.init_node_data(node_data)
|
node_instance.init_node_data(node_data)
|
||||||
|
|
||||||
# If node has fail branch, change execution type to branch
|
|
||||||
if node_instance.error_strategy == ErrorStrategy.FAIL_BRANCH:
|
|
||||||
node_instance.execution_type = NodeExecutionType.BRANCH
|
|
||||||
|
|
||||||
return node_instance
|
return node_instance
|
||||||
|
|||||||
@ -747,7 +747,7 @@ class ParameterExtractorNode(Node):
|
|||||||
if model_mode == ModelMode.CHAT:
|
if model_mode == ModelMode.CHAT:
|
||||||
system_prompt_messages = ChatModelMessage(
|
system_prompt_messages = ChatModelMessage(
|
||||||
role=PromptMessageRole.SYSTEM,
|
role=PromptMessageRole.SYSTEM,
|
||||||
text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str).replace("{{instructions}}", instruction),
|
text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction),
|
||||||
)
|
)
|
||||||
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
|
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
|
||||||
return [system_prompt_messages, user_prompt_message]
|
return [system_prompt_messages, user_prompt_message]
|
||||||
|
|||||||
@ -135,7 +135,7 @@ Here are the chat histories between human and assistant, inside <histories></his
|
|||||||
### Instructions:
|
### Instructions:
|
||||||
Some extra information are provided below, you should always follow the instructions as possible as you can.
|
Some extra information are provided below, you should always follow the instructions as possible as you can.
|
||||||
<instructions>
|
<instructions>
|
||||||
{{instructions}}
|
{instructions}
|
||||||
</instructions>
|
</instructions>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@ -6,10 +6,13 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||||
from core.file import File, FileTransferMethod
|
from core.file import File, FileTransferMethod
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
|
from core.tools.__base.tool import Tool
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||||
from core.tools.errors import ToolInvokeError
|
from core.tools.errors import ToolInvokeError
|
||||||
from core.tools.tool_engine import ToolEngine
|
from core.tools.tool_engine import ToolEngine
|
||||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||||
|
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||||
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
|
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
|
||||||
from core.variables.variables import ArrayAnyVariable
|
from core.variables.variables import ArrayAnyVariable
|
||||||
from core.workflow.enums import (
|
from core.workflow.enums import (
|
||||||
@ -136,13 +139,14 @@ class ToolNode(Node):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# convert tool messages
|
# convert tool messages
|
||||||
yield from self._transform_message(
|
_ = yield from self._transform_message(
|
||||||
messages=message_stream,
|
messages=message_stream,
|
||||||
tool_info=tool_info,
|
tool_info=tool_info,
|
||||||
parameters_for_log=parameters_for_log,
|
parameters_for_log=parameters_for_log,
|
||||||
user_id=self.user_id,
|
user_id=self.user_id,
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
node_id=self._node_id,
|
node_id=self._node_id,
|
||||||
|
tool_runtime=tool_runtime,
|
||||||
)
|
)
|
||||||
except ToolInvokeError as e:
|
except ToolInvokeError as e:
|
||||||
yield StreamCompletedEvent(
|
yield StreamCompletedEvent(
|
||||||
@ -236,7 +240,8 @@ class ToolNode(Node):
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
node_id: str,
|
node_id: str,
|
||||||
) -> Generator:
|
tool_runtime: Tool,
|
||||||
|
) -> Generator[NodeEventBase, None, LLMUsage]:
|
||||||
"""
|
"""
|
||||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||||
"""
|
"""
|
||||||
@ -424,17 +429,34 @@ class ToolNode(Node):
|
|||||||
is_final=True,
|
is_final=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
usage = self._extract_tool_usage(tool_runtime)
|
||||||
|
|
||||||
|
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
|
||||||
|
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
||||||
|
}
|
||||||
|
if usage.total_tokens > 0:
|
||||||
|
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
|
||||||
|
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
|
||||||
|
metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency
|
||||||
|
|
||||||
yield StreamCompletedEvent(
|
yield StreamCompletedEvent(
|
||||||
node_run_result=NodeRunResult(
|
node_run_result=NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
|
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
|
||||||
metadata={
|
metadata=metadata,
|
||||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
|
||||||
},
|
|
||||||
inputs=parameters_for_log,
|
inputs=parameters_for_log,
|
||||||
|
llm_usage=usage,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return usage
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage:
|
||||||
|
if isinstance(tool_runtime, WorkflowTool):
|
||||||
|
return tool_runtime.latest_usage
|
||||||
|
return LLMUsage.empty_usage()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _extract_variable_selector_to_variable_mapping(
|
def _extract_variable_selector_to_variable_mapping(
|
||||||
cls,
|
cls,
|
||||||
|
|||||||
@ -260,7 +260,7 @@ class VariablePool(BaseModel):
|
|||||||
# This ensures that we can keep the id of the system variables intact.
|
# This ensures that we can keep the id of the system variables intact.
|
||||||
if self._has(selector):
|
if self._has(selector):
|
||||||
continue
|
continue
|
||||||
self.add(selector, value) # type: ignore
|
self.add(selector, value)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def empty(cls) -> "VariablePool":
|
def empty(cls) -> "VariablePool":
|
||||||
|
|||||||
@ -32,7 +32,8 @@ if [[ "${MODE}" == "worker" ]]; then
|
|||||||
|
|
||||||
exec celery -A celery_entrypoint.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \
|
exec celery -A celery_entrypoint.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \
|
||||||
--max-tasks-per-child ${MAX_TASKS_PER_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \
|
--max-tasks-per-child ${MAX_TASKS_PER_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \
|
||||||
-Q ${CELERY_QUEUES:-dataset,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation}
|
-Q ${CELERY_QUEUES:-dataset,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation} \
|
||||||
|
--prefetch-multiplier=1
|
||||||
|
|
||||||
elif [[ "${MODE}" == "beat" ]]; then
|
elif [[ "${MODE}" == "beat" ]]; then
|
||||||
exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}
|
exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}
|
||||||
|
|||||||
@ -6,8 +6,8 @@ from tasks.clean_dataset_task import clean_dataset_task
|
|||||||
@dataset_was_deleted.connect
|
@dataset_was_deleted.connect
|
||||||
def handle(sender: Dataset, **kwargs):
|
def handle(sender: Dataset, **kwargs):
|
||||||
dataset = sender
|
dataset = sender
|
||||||
assert dataset.doc_form
|
if not dataset.doc_form or not dataset.indexing_technique:
|
||||||
assert dataset.indexing_technique
|
return
|
||||||
clean_dataset_task.delay(
|
clean_dataset_task.delay(
|
||||||
dataset.id,
|
dataset.id,
|
||||||
dataset.tenant_id,
|
dataset.tenant_id,
|
||||||
|
|||||||
@ -8,6 +8,6 @@ def handle(sender, **kwargs):
|
|||||||
dataset_id = kwargs.get("dataset_id")
|
dataset_id = kwargs.get("dataset_id")
|
||||||
doc_form = kwargs.get("doc_form")
|
doc_form = kwargs.get("doc_form")
|
||||||
file_id = kwargs.get("file_id")
|
file_id = kwargs.get("file_id")
|
||||||
assert dataset_id is not None
|
if not dataset_id or not doc_form:
|
||||||
assert doc_form is not None
|
return
|
||||||
clean_document_task.delay(document_id, dataset_id, doc_form, file_id)
|
clean_document_task.delay(document_id, dataset_id, doc_form, file_id)
|
||||||
|
|||||||
@ -1,7 +1,12 @@
|
|||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from constants import HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN
|
from constants import HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN, HEADER_NAME_PASSPORT
|
||||||
from dify_app import DifyApp
|
from dify_app import DifyApp
|
||||||
|
|
||||||
|
BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEADER_NAME_PASSPORT)
|
||||||
|
SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization")
|
||||||
|
AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN)
|
||||||
|
FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
|
||||||
|
|
||||||
|
|
||||||
def init_app(app: DifyApp):
|
def init_app(app: DifyApp):
|
||||||
# register blueprint routers
|
# register blueprint routers
|
||||||
@ -17,7 +22,7 @@ def init_app(app: DifyApp):
|
|||||||
|
|
||||||
CORS(
|
CORS(
|
||||||
service_api_bp,
|
service_api_bp,
|
||||||
allow_headers=["Content-Type", "Authorization", HEADER_NAME_APP_CODE],
|
allow_headers=list(SERVICE_API_HEADERS),
|
||||||
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||||
)
|
)
|
||||||
app.register_blueprint(service_api_bp)
|
app.register_blueprint(service_api_bp)
|
||||||
@ -26,7 +31,7 @@ def init_app(app: DifyApp):
|
|||||||
web_bp,
|
web_bp,
|
||||||
resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
|
resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
|
||||||
supports_credentials=True,
|
supports_credentials=True,
|
||||||
allow_headers=["Content-Type", "Authorization", HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN],
|
allow_headers=list(AUTHENTICATED_HEADERS),
|
||||||
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||||
expose_headers=["X-Version", "X-Env"],
|
expose_headers=["X-Version", "X-Env"],
|
||||||
)
|
)
|
||||||
@ -36,7 +41,7 @@ def init_app(app: DifyApp):
|
|||||||
console_app_bp,
|
console_app_bp,
|
||||||
resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
|
resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
|
||||||
supports_credentials=True,
|
supports_credentials=True,
|
||||||
allow_headers=["Content-Type", "Authorization", HEADER_NAME_CSRF_TOKEN],
|
allow_headers=list(AUTHENTICATED_HEADERS),
|
||||||
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||||
expose_headers=["X-Version", "X-Env"],
|
expose_headers=["X-Version", "X-Env"],
|
||||||
)
|
)
|
||||||
@ -44,7 +49,7 @@ def init_app(app: DifyApp):
|
|||||||
|
|
||||||
CORS(
|
CORS(
|
||||||
files_bp,
|
files_bp,
|
||||||
allow_headers=["Content-Type", HEADER_NAME_CSRF_TOKEN],
|
allow_headers=list(FILES_HEADERS),
|
||||||
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||||
)
|
)
|
||||||
app.register_blueprint(files_bp)
|
app.register_blueprint(files_bp)
|
||||||
|
|||||||
@ -7,7 +7,7 @@ def is_enabled() -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def init_app(app: DifyApp):
|
def init_app(app: DifyApp):
|
||||||
from flask_compress import Compress # type: ignore
|
from flask_compress import Compress
|
||||||
|
|
||||||
compress = Compress()
|
compress = Compress()
|
||||||
compress.init_app(app)
|
compress.init_app(app)
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
import flask_login # type: ignore
|
import flask_login
|
||||||
from flask import Response, request
|
from flask import Response, request
|
||||||
from flask_login import user_loaded_from_request, user_logged_in
|
from flask_login import user_loaded_from_request, user_logged_in
|
||||||
from werkzeug.exceptions import NotFound, Unauthorized
|
from werkzeug.exceptions import NotFound, Unauthorized
|
||||||
|
|||||||
@ -2,7 +2,7 @@ from dify_app import DifyApp
|
|||||||
|
|
||||||
|
|
||||||
def init_app(app: DifyApp):
|
def init_app(app: DifyApp):
|
||||||
import flask_migrate # type: ignore
|
import flask_migrate
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
|||||||
@ -103,7 +103,7 @@ def init_app(app: DifyApp):
|
|||||||
def shutdown_tracer():
|
def shutdown_tracer():
|
||||||
provider = trace.get_tracer_provider()
|
provider = trace.get_tracer_provider()
|
||||||
if hasattr(provider, "force_flush"):
|
if hasattr(provider, "force_flush"):
|
||||||
provider.force_flush() # ty: ignore [call-non-callable]
|
provider.force_flush()
|
||||||
|
|
||||||
class ExceptionLoggingHandler(logging.Handler):
|
class ExceptionLoggingHandler(logging.Handler):
|
||||||
"""Custom logging handler that creates spans for logging.exception() calls"""
|
"""Custom logging handler that creates spans for logging.exception() calls"""
|
||||||
|
|||||||
@ -6,4 +6,4 @@ def init_app(app: DifyApp):
|
|||||||
if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED:
|
if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED:
|
||||||
from werkzeug.middleware.proxy_fix import ProxyFix
|
from werkzeug.middleware.proxy_fix import ProxyFix
|
||||||
|
|
||||||
app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1) # type: ignore
|
app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1) # type: ignore[method-assign]
|
||||||
|
|||||||
@ -5,7 +5,7 @@ from dify_app import DifyApp
|
|||||||
def init_app(app: DifyApp):
|
def init_app(app: DifyApp):
|
||||||
if dify_config.SENTRY_DSN:
|
if dify_config.SENTRY_DSN:
|
||||||
import sentry_sdk
|
import sentry_sdk
|
||||||
from langfuse import parse_error # type: ignore
|
from langfuse import parse_error
|
||||||
from sentry_sdk.integrations.celery import CeleryIntegration
|
from sentry_sdk.integrations.celery import CeleryIntegration
|
||||||
from sentry_sdk.integrations.flask import FlaskIntegration
|
from sentry_sdk.integrations.flask import FlaskIntegration
|
||||||
from werkzeug.exceptions import HTTPException
|
from werkzeug.exceptions import HTTPException
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import posixpath
|
import posixpath
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
|
||||||
import oss2 as aliyun_s3 # type: ignore
|
import oss2 as aliyun_s3
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from extensions.storage.base_storage import BaseStorage
|
from extensions.storage.base_storage import BaseStorage
|
||||||
|
|||||||
@ -2,9 +2,9 @@ import base64
|
|||||||
import hashlib
|
import hashlib
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
|
||||||
from baidubce.auth.bce_credentials import BceCredentials # type: ignore
|
from baidubce.auth.bce_credentials import BceCredentials
|
||||||
from baidubce.bce_client_configuration import BceClientConfiguration # type: ignore
|
from baidubce.bce_client_configuration import BceClientConfiguration
|
||||||
from baidubce.services.bos.bos_client import BosClient # type: ignore
|
from baidubce.services.bos.bos_client import BosClient
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from extensions.storage.base_storage import BaseStorage
|
from extensions.storage.base_storage import BaseStorage
|
||||||
|
|||||||
@ -11,7 +11,7 @@ from collections.abc import Generator
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import clickzetta # type: ignore[import]
|
import clickzetta
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
|
|
||||||
from extensions.storage.base_storage import BaseStorage
|
from extensions.storage.base_storage import BaseStorage
|
||||||
|
|||||||
@ -34,7 +34,7 @@ class VolumePermissionManager:
|
|||||||
# Support two initialization methods: connection object or configuration dictionary
|
# Support two initialization methods: connection object or configuration dictionary
|
||||||
if isinstance(connection_or_config, dict):
|
if isinstance(connection_or_config, dict):
|
||||||
# Create connection from configuration dictionary
|
# Create connection from configuration dictionary
|
||||||
import clickzetta # type: ignore[import-untyped]
|
import clickzetta
|
||||||
|
|
||||||
config = connection_or_config
|
config = connection_or_config
|
||||||
self._connection = clickzetta.connect(
|
self._connection = clickzetta.connect(
|
||||||
|
|||||||
@ -3,7 +3,7 @@ import io
|
|||||||
import json
|
import json
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
|
||||||
from google.cloud import storage as google_cloud_storage # type: ignore
|
from google.cloud import storage as google_cloud_storage
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from extensions.storage.base_storage import BaseStorage
|
from extensions.storage.base_storage import BaseStorage
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
|
||||||
from obs import ObsClient # type: ignore
|
from obs import ObsClient
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from extensions.storage.base_storage import BaseStorage
|
from extensions.storage.base_storage import BaseStorage
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
|
||||||
import boto3 # type: ignore
|
import boto3
|
||||||
from botocore.exceptions import ClientError # type: ignore
|
from botocore.exceptions import ClientError
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from extensions.storage.base_storage import BaseStorage
|
from extensions.storage.base_storage import BaseStorage
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
|
||||||
from qcloud_cos import CosConfig, CosS3Client # type: ignore
|
from qcloud_cos import CosConfig, CosS3Client
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from extensions.storage.base_storage import BaseStorage
|
from extensions.storage.base_storage import BaseStorage
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
|
||||||
import tos # type: ignore
|
import tos
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from extensions.storage.base_storage import BaseStorage
|
from extensions.storage.base_storage import BaseStorage
|
||||||
|
|||||||
@ -146,6 +146,6 @@ class ExternalApi(Api):
|
|||||||
kwargs["doc"] = dify_config.SWAGGER_UI_PATH if dify_config.SWAGGER_UI_ENABLED else False
|
kwargs["doc"] = dify_config.SWAGGER_UI_PATH if dify_config.SWAGGER_UI_ENABLED else False
|
||||||
|
|
||||||
# manual separate call on construction and init_app to ensure configs in kwargs effective
|
# manual separate call on construction and init_app to ensure configs in kwargs effective
|
||||||
super().__init__(app=None, *args, **kwargs) # type: ignore
|
super().__init__(app=None, *args, **kwargs)
|
||||||
self.init_app(app, **kwargs)
|
self.init_app(app, **kwargs)
|
||||||
register_external_error_handlers(self)
|
register_external_error_handlers(self)
|
||||||
|
|||||||
@ -23,7 +23,7 @@ from hashlib import sha1
|
|||||||
|
|
||||||
import Crypto.Hash.SHA1
|
import Crypto.Hash.SHA1
|
||||||
import Crypto.Util.number
|
import Crypto.Util.number
|
||||||
import gmpy2 # type: ignore
|
import gmpy2
|
||||||
from Crypto import Random
|
from Crypto import Random
|
||||||
from Crypto.Signature.pss import MGF1
|
from Crypto.Signature.pss import MGF1
|
||||||
from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes
|
from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes
|
||||||
@ -136,7 +136,7 @@ class PKCS1OAepCipher:
|
|||||||
# Step 3a (OS2IP)
|
# Step 3a (OS2IP)
|
||||||
em_int = bytes_to_long(em)
|
em_int = bytes_to_long(em)
|
||||||
# Step 3b (RSAEP)
|
# Step 3b (RSAEP)
|
||||||
m_int = gmpy2.powmod(em_int, self._key.e, self._key.n) # ty: ignore [unresolved-attribute]
|
m_int = gmpy2.powmod(em_int, self._key.e, self._key.n)
|
||||||
# Step 3c (I2OSP)
|
# Step 3c (I2OSP)
|
||||||
c = long_to_bytes(m_int, k)
|
c = long_to_bytes(m_int, k)
|
||||||
return c
|
return c
|
||||||
@ -169,7 +169,7 @@ class PKCS1OAepCipher:
|
|||||||
ct_int = bytes_to_long(ciphertext)
|
ct_int = bytes_to_long(ciphertext)
|
||||||
# Step 2b (RSADP)
|
# Step 2b (RSADP)
|
||||||
# m_int = self._key._decrypt(ct_int)
|
# m_int = self._key._decrypt(ct_int)
|
||||||
m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n) # ty: ignore [unresolved-attribute]
|
m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n)
|
||||||
# Complete step 2c (I2OSP)
|
# Complete step 2c (I2OSP)
|
||||||
em = long_to_bytes(m_int, k)
|
em = long_to_bytes(m_int, k)
|
||||||
# Step 3a
|
# Step 3a
|
||||||
@ -191,12 +191,12 @@ class PKCS1OAepCipher:
|
|||||||
# Step 3g
|
# Step 3g
|
||||||
one_pos = hLen + db[hLen:].find(b"\x01")
|
one_pos = hLen + db[hLen:].find(b"\x01")
|
||||||
lHash1 = db[:hLen]
|
lHash1 = db[:hLen]
|
||||||
invalid = bord(y) | int(one_pos < hLen) # type: ignore
|
invalid = bord(y) | int(one_pos < hLen) # type: ignore[arg-type]
|
||||||
hash_compare = strxor(lHash1, lHash)
|
hash_compare = strxor(lHash1, lHash)
|
||||||
for x in hash_compare:
|
for x in hash_compare:
|
||||||
invalid |= bord(x) # type: ignore
|
invalid |= bord(x) # type: ignore[arg-type]
|
||||||
for x in db[hLen:one_pos]:
|
for x in db[hLen:one_pos]:
|
||||||
invalid |= bord(x) # type: ignore
|
invalid |= bord(x) # type: ignore[arg-type]
|
||||||
if invalid != 0:
|
if invalid != 0:
|
||||||
raise ValueError("Incorrect decryption.")
|
raise ValueError("Incorrect decryption.")
|
||||||
# Step 4
|
# Step 4
|
||||||
|
|||||||
@ -81,6 +81,8 @@ class AvatarUrlField(fields.Raw):
|
|||||||
from models import Account
|
from models import Account
|
||||||
|
|
||||||
if isinstance(obj, Account) and obj.avatar is not None:
|
if isinstance(obj, Account) and obj.avatar is not None:
|
||||||
|
if obj.avatar.startswith(("http://", "https://")):
|
||||||
|
return obj.avatar
|
||||||
return file_helpers.get_signed_file_url(obj.avatar)
|
return file_helpers.get_signed_file_url(obj.avatar)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from functools import wraps
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from flask import current_app, g, has_request_context, request
|
from flask import current_app, g, has_request_context, request
|
||||||
from flask_login.config import EXEMPT_METHODS # type: ignore
|
from flask_login.config import EXEMPT_METHODS
|
||||||
from werkzeug.local import LocalProxy
|
from werkzeug.local import LocalProxy
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
@ -87,7 +87,7 @@ def _get_user() -> EndUser | Account | None:
|
|||||||
if "_login_user" not in g:
|
if "_login_user" not in g:
|
||||||
current_app.login_manager._load_user() # type: ignore
|
current_app.login_manager._load_user() # type: ignore
|
||||||
|
|
||||||
return g._login_user # type: ignore
|
return g._login_user
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
import sendgrid # type: ignore
|
import sendgrid
|
||||||
from python_http_client.exceptions import ForbiddenError, UnauthorizedError
|
from python_http_client.exceptions import ForbiddenError, UnauthorizedError
|
||||||
from sendgrid.helpers.mail import Content, Email, Mail, To # type: ignore
|
from sendgrid.helpers.mail import Content, Email, Mail, To
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@ -12,6 +12,7 @@ from constants import (
|
|||||||
COOKIE_NAME_CSRF_TOKEN,
|
COOKIE_NAME_CSRF_TOKEN,
|
||||||
COOKIE_NAME_PASSPORT,
|
COOKIE_NAME_PASSPORT,
|
||||||
COOKIE_NAME_REFRESH_TOKEN,
|
COOKIE_NAME_REFRESH_TOKEN,
|
||||||
|
COOKIE_NAME_WEBAPP_ACCESS_TOKEN,
|
||||||
HEADER_NAME_CSRF_TOKEN,
|
HEADER_NAME_CSRF_TOKEN,
|
||||||
HEADER_NAME_PASSPORT,
|
HEADER_NAME_PASSPORT,
|
||||||
)
|
)
|
||||||
@ -81,6 +82,14 @@ def extract_access_token(request: Request) -> str | None:
|
|||||||
return _try_extract_from_cookie(request) or _try_extract_from_header(request)
|
return _try_extract_from_cookie(request) or _try_extract_from_header(request)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_webapp_access_token(request: Request) -> str | None:
|
||||||
|
"""
|
||||||
|
Try to extract webapp access token from cookie, then header.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return request.cookies.get(_real_cookie_name(COOKIE_NAME_WEBAPP_ACCESS_TOKEN)) or _try_extract_from_header(request)
|
||||||
|
|
||||||
|
|
||||||
def extract_webapp_passport(app_code: str, request: Request) -> str | None:
|
def extract_webapp_passport(app_code: str, request: Request) -> str | None:
|
||||||
"""
|
"""
|
||||||
Try to extract app token from header or params.
|
Try to extract app token from header or params.
|
||||||
@ -155,6 +164,10 @@ def clear_access_token_from_cookie(response: Response, samesite: str = "Lax"):
|
|||||||
_clear_cookie(response, COOKIE_NAME_ACCESS_TOKEN, samesite)
|
_clear_cookie(response, COOKIE_NAME_ACCESS_TOKEN, samesite)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_webapp_access_token_from_cookie(response: Response, samesite: str = "Lax"):
|
||||||
|
_clear_cookie(response, COOKIE_NAME_WEBAPP_ACCESS_TOKEN, samesite)
|
||||||
|
|
||||||
|
|
||||||
def clear_refresh_token_from_cookie(response: Response):
|
def clear_refresh_token_from_cookie(response: Response):
|
||||||
_clear_cookie(response, COOKIE_NAME_REFRESH_TOKEN)
|
_clear_cookie(response, COOKIE_NAME_REFRESH_TOKEN)
|
||||||
|
|
||||||
|
@@ -22,55 +22,6 @@ def upgrade():
         batch_op.add_column(sa.Column('app_mode', sa.String(length=255), nullable=True))
         batch_op.create_index('message_app_mode_idx', ['app_mode'], unique=False)

-    conn = op.get_bind()
-
-    # Strategy: Update in batches to minimize lock time
-    # For large tables (millions of rows), this prevents long-running transactions
-    batch_size = 10000
-
-    print("Starting backfill of app_mode from conversations...")
-
-    # Use a more efficient UPDATE with JOIN
-    # This query updates messages.app_mode from conversations.mode
-    # Using string formatting for LIMIT since it's a constant
-    update_query = f"""
-        UPDATE messages m
-        SET app_mode = c.mode
-        FROM conversations c
-        WHERE m.conversation_id = c.id
-        AND m.app_mode IS NULL
-        AND m.id IN (
-            SELECT id FROM messages
-            WHERE app_mode IS NULL
-            LIMIT {batch_size}
-        )
-    """
-
-    # Execute batched updates
-    total_updated = 0
-    iteration = 0
-    while True:
-        iteration += 1
-        result = conn.execute(sa.text(update_query))
-
-        # Check if result is None or has no rowcount
-        if result is None:
-            print("Warning: Query returned None, stopping backfill")
-            break
-
-        rows_updated = result.rowcount if hasattr(result, 'rowcount') else 0
-        total_updated += rows_updated
-
-        if rows_updated == 0:
-            break
-
-        print(f"Iteration {iteration}: Updated {rows_updated} messages (total: {total_updated})")
-
-        # For very large tables, add a small delay to reduce load
-        # Uncomment if needed: import time; time.sleep(0.1)
-
-    print(f"Backfill completed. Total messages updated: {total_updated}")
-
     # ### end Alembic commands ###

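Review note: the deleted block was a batched backfill: each UPDATE touches at most batch_size rows chosen by a LIMIT subquery, and the loop stops once rowcount reaches zero. For reference, the same pattern as a standalone script (a sketch assuming a SQLAlchemy 2.x engine against the same schema; not part of this commit):

import sqlalchemy as sa

def backfill_app_mode(engine: sa.Engine, batch_size: int = 10000) -> int:
    # Same UPDATE ... FROM join as the removed migration code, with the
    # batch size passed as a bound parameter instead of an f-string.
    query = sa.text(
        """
        UPDATE messages m
        SET app_mode = c.mode
        FROM conversations c
        WHERE m.conversation_id = c.id
          AND m.app_mode IS NULL
          AND m.id IN (
              SELECT id FROM messages WHERE app_mode IS NULL LIMIT :batch_size
          )
        """
    )
    total = 0
    while True:
        # One transaction per batch keeps row locks short-lived.
        with engine.begin() as conn:
            rows = conn.execute(query, {"batch_size": batch_size}).rowcount
        total += rows
        if rows == 0:
            return total
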
@@ -0,0 +1,36 @@
+"""remove-builtin-template-user
+
+Revision ID: ae662b25d9bc
+Revises: d98acf217d43
+Create Date: 2025-10-21 14:30:28.566192
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = 'ae662b25d9bc'
+down_revision = 'd98acf217d43'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+
+    with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
+        batch_op.drop_column('updated_by')
+        batch_op.drop_column('created_by')
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('created_by', sa.UUID(), autoincrement=False, nullable=False))
+        batch_op.add_column(sa.Column('updated_by', sa.UUID(), autoincrement=False, nullable=True))
+
+    # ### end Alembic commands ###

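Review note: applying or reverting this revision programmatically uses the standard Alembic command API (a sketch; assumes an alembic.ini in the working directory):

from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")
command.upgrade(cfg, "ae662b25d9bc")    # drops created_by / updated_by
command.downgrade(cfg, "d98acf217d43")  # re-adds them; note created_by comes
                                        # back NOT NULL, which fails on a
                                        # non-empty table without a backfill
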
@@ -5,7 +5,7 @@ from datetime import datetime
 from typing import Any, Optional

 import sqlalchemy as sa
-from flask_login import UserMixin  # type: ignore[import-untyped]
+from flask_login import UserMixin
 from sqlalchemy import DateTime, String, func, select
 from sqlalchemy.orm import Mapped, Session, mapped_column
 from typing_extensions import deprecated

@@ -1239,15 +1239,6 @@ class PipelineBuiltInTemplate(Base):  # type: ignore[name-defined]
     language = mapped_column(db.String(255), nullable=False)
     created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
     updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
-    created_by = mapped_column(StringUUID, nullable=False)
-    updated_by = mapped_column(StringUUID, nullable=True)
-
-    @property
-    def created_user_name(self):
-        account = db.session.query(Account).where(Account.id == self.created_by).first()
-        if account:
-            return account.name
-        return ""


 class PipelineCustomizedTemplate(Base):  # type: ignore[name-defined]

@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast

 import sqlalchemy as sa
 from flask import request
-from flask_login import UserMixin  # type: ignore[import-untyped]
+from flask_login import UserMixin
 from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
 from sqlalchemy.orm import Mapped, Session, mapped_column

@@ -219,7 +219,7 @@ class WorkflowToolProvider(TypeBase):
         sa.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"),
     )

-    id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
+    id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
     # name of the workflow provider
     name: Mapped[str] = mapped_column(String(255), nullable=False)
     # label of the workflow provider

@@ -1,6 +1,6 @@
 [project]
 name = "dify-api"
-version = "1.9.1"
+version = "1.9.2"
 requires-python = ">=3.11,<3.13"

 dependencies = [

@@ -16,7 +16,25 @@
     "opentelemetry.instrumentation.requests",
     "opentelemetry.instrumentation.sqlalchemy",
     "opentelemetry.instrumentation.redis",
-    "opentelemetry.instrumentation.httpx"
+    "langfuse",
+    "cloudscraper",
+    "readabilipy",
+    "pypandoc",
+    "pypdfium2",
+    "webvtt",
+    "flask_compress",
+    "oss2",
+    "baidubce.auth.bce_credentials",
+    "baidubce.bce_client_configuration",
+    "baidubce.services.bos.bos_client",
+    "clickzetta",
+    "google.cloud",
+    "obs",
+    "qcloud_cos",
+    "tos",
+    "gmpy2",
+    "sendgrid",
+    "sendgrid.helpers.mail"
     ],
     "reportUnknownMemberType": "hint",
     "reportUnknownParameterType": "hint",

@@ -28,7 +46,8 @@
     "reportUnnecessaryComparison": "hint",
     "reportUnnecessaryIsInstance": "hint",
     "reportUntypedFunctionDecorator": "hint",
+    "reportUnnecessaryTypeIgnoreComment": "hint",
     "reportAttributeAccessIssue": "hint",
     "pythonVersion": "3.11",
     "pythonPlatform": "All"

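Review note: the new reportUnnecessaryTypeIgnoreComment hint is what makes the # type: ignore removals elsewhere in this diff visible: once an import resolves cleanly, a leftover suppression is itself flagged. A minimal illustration, assuming PyYAML stubs are installed:

# pyright now reports the suppression below as unnecessary (at "hint"
# severity per this config), since yaml imports cleanly with stubs present.
import yaml  # type: ignore[import-untyped]

data = yaml.safe_load("a: 1")
print(data["a"])
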
@@ -48,7 +48,7 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):

         try:
             repository_class = import_string(class_path)
-            return repository_class(session_maker=session_maker)  # type: ignore[no-any-return]
+            return repository_class(session_maker=session_maker)
         except (ImportError, Exception) as e:
             raise RepositoryImportError(
                 f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}"

@@ -77,6 +77,6 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):

         try:
             repository_class = import_string(class_path)
-            return repository_class(session_maker=session_maker)  # type: ignore[no-any-return]
+            return repository_class(session_maker=session_maker)
         except (ImportError, Exception) as e:
             raise RepositoryImportError(f"Failed to create APIWorkflowRunRepository from '{class_path}': {e}") from e

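Review note: both factory methods share one pattern: resolve a dotted path with Werkzeug's import_string, instantiate the class with a session_maker, and wrap any failure. A stripped-down sketch of that pattern (names are illustrative, not the project's API):

from werkzeug.utils import import_string

def build_repository(class_path: str, session_maker):
    try:
        repository_class = import_string(class_path)  # e.g. "pkg.module.RepoClass"
        return repository_class(session_maker=session_maker)
    except Exception as e:
        raise RuntimeError(f"Failed to create repository from '{class_path}': {e}") from e
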
@@ -13,7 +13,7 @@ from sqlalchemy.orm import Session
 from werkzeug.exceptions import Unauthorized

 from configs import dify_config
-from constants.languages import language_timezone_mapping, languages
+from constants.languages import get_valid_language, language_timezone_mapping
 from events.tenant_event import tenant_was_created
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client, redis_fallback

@ -1259,7 +1259,7 @@ class RegisterService:
|
|||||||
return f"member_invite:token:{token}"
|
return f"member_invite:token:{token}"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setup(cls, email: str, name: str, password: str, ip_address: str):
|
def setup(cls, email: str, name: str, password: str, ip_address: str, language: str):
|
||||||
"""
|
"""
|
||||||
Setup dify
|
Setup dify
|
||||||
|
|
||||||
@ -1269,11 +1269,10 @@ class RegisterService:
|
|||||||
:param ip_address: ip address
|
:param ip_address: ip address
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Register
|
|
||||||
account = AccountService.create_account(
|
account = AccountService.create_account(
|
||||||
email=email,
|
email=email,
|
||||||
name=name,
|
name=name,
|
||||||
interface_language=languages[0],
|
interface_language=get_valid_language(language),
|
||||||
password=password,
|
password=password,
|
||||||
is_setup=True,
|
is_setup=True,
|
||||||
)
|
)
|
||||||
@ -1315,7 +1314,7 @@ class RegisterService:
|
|||||||
account = AccountService.create_account(
|
account = AccountService.create_account(
|
||||||
email=email,
|
email=email,
|
||||||
name=name,
|
name=name,
|
||||||
interface_language=language or languages[0],
|
interface_language=get_valid_language(language),
|
||||||
password=password,
|
password=password,
|
||||||
is_setup=is_setup,
|
is_setup=is_setup,
|
||||||
)
|
)
|
||||||
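Review note: get_valid_language itself is not shown in this diff; a plausible shape, assuming it falls back to the project default instead of raising (an assumption, not the confirmed implementation):

# Hypothetical sketch of constants/languages.get_valid_language.
def get_valid_language(language: str | None) -> str:
    if language and language in languages:
        return language
    return languages[0]  # project default
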
@@ -7,7 +7,7 @@ from enum import StrEnum
 from urllib.parse import urlparse
 from uuid import uuid4

-import yaml  # type: ignore
+import yaml
 from Crypto.Cipher import AES
 from Crypto.Util.Padding import pad, unpad
 from packaging import version

@@ -563,7 +563,7 @@ class AppDslService:
         else:
             cls._append_model_config_export_data(export_data, app_model)

-        return yaml.dump(export_data, allow_unicode=True)  # type: ignore
+        return yaml.dump(export_data, allow_unicode=True)

     @classmethod
     def _append_workflow_export_data(

Some files were not shown because too many files have changed in this diff.