Merge branch 'main' into feat/mcp-06-18

This commit is contained in:
Novice 2025-10-23 17:01:25 +08:00
commit e7a575a33c
No known key found for this signature in database
GPG Key ID: EE3F68E3105DAAAB
585 changed files with 31247 additions and 7723 deletions

View File

@ -434,6 +434,9 @@ CODE_EXECUTION_SSL_VERIFY=True
CODE_EXECUTION_POOL_MAX_CONNECTIONS=100 CODE_EXECUTION_POOL_MAX_CONNECTIONS=100
CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS=20 CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS=20
CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY=5.0 CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY=5.0
CODE_EXECUTION_CONNECT_TIMEOUT=10
CODE_EXECUTION_READ_TIMEOUT=60
CODE_EXECUTION_WRITE_TIMEOUT=10
CODE_MAX_NUMBER=9223372036854775807 CODE_MAX_NUMBER=9223372036854775807
CODE_MIN_NUMBER=-9223372036854775808 CODE_MIN_NUMBER=-9223372036854775808
CODE_MAX_STRING_LENGTH=400000 CODE_MAX_STRING_LENGTH=400000

View File

@ -548,7 +548,7 @@ class UpdateConfig(BaseSettings):
class WorkflowVariableTruncationConfig(BaseSettings): class WorkflowVariableTruncationConfig(BaseSettings):
WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE: PositiveInt = Field( WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE: PositiveInt = Field(
# 100KB # 1000 KiB
1024_000, 1024_000,
description="Maximum size for variable to trigger final truncation.", description="Maximum size for variable to trigger final truncation.",
) )

View File

@ -145,7 +145,7 @@ class DatabaseConfig(BaseSettings):
default="postgresql", default="postgresql",
) )
@computed_field # type: ignore[misc] @computed_field # type: ignore[prop-decorator]
@property @property
def SQLALCHEMY_DATABASE_URI(self) -> str: def SQLALCHEMY_DATABASE_URI(self) -> str:
db_extras = ( db_extras = (
@ -198,7 +198,7 @@ class DatabaseConfig(BaseSettings):
default=os.cpu_count() or 1, default=os.cpu_count() or 1,
) )
@computed_field # type: ignore[misc] @computed_field # type: ignore[prop-decorator]
@property @property
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]: def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
# Parse DB_EXTRAS for 'options' # Parse DB_EXTRAS for 'options'

View File

@ -56,11 +56,15 @@ else:
} }
DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions) DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)
# console
COOKIE_NAME_ACCESS_TOKEN = "access_token" COOKIE_NAME_ACCESS_TOKEN = "access_token"
COOKIE_NAME_REFRESH_TOKEN = "refresh_token" COOKIE_NAME_REFRESH_TOKEN = "refresh_token"
COOKIE_NAME_PASSPORT = "passport"
COOKIE_NAME_CSRF_TOKEN = "csrf_token" COOKIE_NAME_CSRF_TOKEN = "csrf_token"
# webapp
COOKIE_NAME_WEBAPP_ACCESS_TOKEN = "webapp_access_token"
COOKIE_NAME_PASSPORT = "passport"
HEADER_NAME_CSRF_TOKEN = "X-CSRF-Token" HEADER_NAME_CSRF_TOKEN = "X-CSRF-Token"
HEADER_NAME_APP_CODE = "X-App-Code" HEADER_NAME_APP_CODE = "X-App-Code"
HEADER_NAME_PASSPORT = "X-App-Passport" HEADER_NAME_PASSPORT = "X-App-Passport"

View File

@ -31,3 +31,9 @@ def supported_language(lang):
error = f"{lang} is not a valid language." error = f"{lang} is not a valid language."
raise ValueError(error) raise ValueError(error)
def get_valid_language(lang: str | None) -> str:
    """Return ``lang`` when it is a supported language, otherwise the default.

    The default is the first entry of the module-level ``languages`` list.
    ``None`` and the empty string are treated as "not provided" and also
    resolve to the default.
    """
    return lang if lang and lang in languages else languages[0]

View File

@ -24,7 +24,7 @@ except ImportError:
) )
else: else:
warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2) warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2)
magic = None # type: ignore magic = None # type: ignore[assignment]
from pydantic import BaseModel from pydantic import BaseModel

View File

@ -4,7 +4,7 @@ from flask_restx import Resource, reqparse
import services import services
from configs import dify_config from configs import dify_config
from constants.languages import languages from constants.languages import get_valid_language
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.auth.error import ( from controllers.console.auth.error import (
AuthenticationFailedError, AuthenticationFailedError,
@ -29,8 +29,6 @@ from libs.token import (
clear_access_token_from_cookie, clear_access_token_from_cookie,
clear_csrf_token_from_cookie, clear_csrf_token_from_cookie,
clear_refresh_token_from_cookie, clear_refresh_token_from_cookie,
extract_access_token,
extract_csrf_token,
set_access_token_to_cookie, set_access_token_to_cookie,
set_csrf_token_to_cookie, set_csrf_token_to_cookie,
set_refresh_token_to_cookie, set_refresh_token_to_cookie,
@ -206,10 +204,12 @@ class EmailCodeLoginApi(Resource):
.add_argument("email", type=str, required=True, location="json") .add_argument("email", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json") .add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, location="json") .add_argument("token", type=str, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
) )
args = parser.parse_args() args = parser.parse_args()
user_email = args["email"] user_email = args["email"]
language = args["language"]
token_data = AccountService.get_email_code_login_data(args["token"]) token_data = AccountService.get_email_code_login_data(args["token"])
if token_data is None: if token_data is None:
@ -243,7 +243,9 @@ class EmailCodeLoginApi(Resource):
if account is None: if account is None:
try: try:
account = AccountService.create_account_and_tenant( account = AccountService.create_account_and_tenant(
email=user_email, name=user_email, interface_language=languages[0] email=user_email,
name=user_email,
interface_language=get_valid_language(language),
) )
except WorkSpaceNotAllowedCreateError: except WorkSpaceNotAllowedCreateError:
raise NotAllowedCreateWorkspace() raise NotAllowedCreateWorkspace()
@ -286,13 +288,3 @@ class RefreshTokenApi(Resource):
return response return response
except Exception as e: except Exception as e:
return {"result": "fail", "message": str(e)}, 401 return {"result": "fail", "message": str(e)}, 401
# this api helps frontend to check whether user is authenticated
# TODO: remove in the future. frontend should redirect to login page by catching 401 status
@console_ns.route("/login/status")
class LoginStatus(Resource):
def get(self):
token = extract_access_token(request)
csrf_token = extract_csrf_token(request)
return {"logged_in": bool(token) and bool(csrf_token)}

View File

@ -22,7 +22,7 @@ from core.errors.error import (
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager from core.workflow.graph_engine.manager import GraphEngineManager
from libs import helper from libs import helper
from libs.login import current_user as current_user_ from libs.login import current_account_with_tenant
from models.model import AppMode, InstalledApp from models.model import AppMode, InstalledApp
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError from services.errors.llm import InvokeRateLimitError
@ -31,8 +31,6 @@ from .. import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
current_user = current_user_._get_current_object() # type: ignore
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run") @console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
class InstalledAppWorkflowRunApi(InstalledAppResource): class InstalledAppWorkflowRunApi(InstalledAppResource):
@ -40,6 +38,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
""" """
Run workflow Run workflow
""" """
current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
if not app_model: if not app_model:
raise NotWorkflowAppError() raise NotWorkflowAppError()
@ -53,7 +52,6 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
.add_argument("files", type=list, required=False, location="json") .add_argument("files", type=list, required=False, location="json")
) )
args = parser.parse_args() args = parser.parse_args()
assert current_user is not None
try: try:
response = AppGenerateService.generate( response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
@ -89,7 +87,6 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW: if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError() raise NotWorkflowAppError()
assert current_user is not None
# Stop using both mechanisms for backward compatibility # Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check) # Legacy stop flag mechanism (without user check)

View File

@ -74,12 +74,17 @@ class SetupApi(Resource):
.add_argument("email", type=email, required=True, location="json") .add_argument("email", type=email, required=True, location="json")
.add_argument("name", type=StrLen(30), required=True, location="json") .add_argument("name", type=StrLen(30), required=True, location="json")
.add_argument("password", type=valid_password, required=True, location="json") .add_argument("password", type=valid_password, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
) )
args = parser.parse_args() args = parser.parse_args()
# setup # setup
RegisterService.setup( RegisterService.setup(
email=args["email"], name=args["name"], password=args["password"], ip_address=extract_remote_ip(request) email=args["email"],
name=args["name"],
password=args["password"],
ip_address=extract_remote_ip(request),
language=args["language"],
) )
return {"result": "success"}, 201 return {"result": "success"}, 201

View File

@ -193,15 +193,16 @@ class MCPAppApi(Resource):
except ValidationError as e: except ValidationError as e:
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}") raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str, session: Session) -> EndUser | None: def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None:
"""Get end user from existing session - optimized query""" """Get end user - manages its own database session"""
return ( with Session(db.engine, expire_on_commit=False) as session, session.begin():
session.query(EndUser) return (
.where(EndUser.tenant_id == tenant_id) session.query(EndUser)
.where(EndUser.session_id == mcp_server_id) .where(EndUser.tenant_id == tenant_id)
.where(EndUser.type == "mcp") .where(EndUser.session_id == mcp_server_id)
.first() .where(EndUser.type == "mcp")
) .first()
)
def _create_end_user( def _create_end_user(
self, client_name: str, tenant_id: str, app_id: str, mcp_server_id: str, session: Session self, client_name: str, tenant_id: str, app_id: str, mcp_server_id: str, session: Session
@ -229,7 +230,7 @@ class MCPAppApi(Resource):
request_id: Union[int, str], request_id: Union[int, str],
) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError | None: ) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError | None:
"""Handle MCP request and return response""" """Handle MCP request and return response"""
end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id, session) end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id)
if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest): if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest):
client_info = mcp_request.root.params.clientInfo client_info = mcp_request.root.params.clientInfo

View File

@ -17,8 +17,8 @@ from libs.helper import email
from libs.passport import PassportService from libs.passport import PassportService
from libs.password import valid_password from libs.password import valid_password
from libs.token import ( from libs.token import (
clear_access_token_from_cookie, clear_webapp_access_token_from_cookie,
extract_access_token, extract_webapp_access_token,
) )
from services.account_service import AccountService from services.account_service import AccountService
from services.app_service import AppService from services.app_service import AppService
@ -81,7 +81,7 @@ class LoginStatusApi(Resource):
) )
def get(self): def get(self):
app_code = request.args.get("app_code") app_code = request.args.get("app_code")
token = extract_access_token(request) token = extract_webapp_access_token(request)
if not app_code: if not app_code:
return { return {
"logged_in": bool(token), "logged_in": bool(token),
@ -128,7 +128,7 @@ class LogoutApi(Resource):
response = make_response({"result": "success"}) response = make_response({"result": "success"})
# enterprise SSO sets same site to None in https deployment # enterprise SSO sets same site to None in https deployment
# so we need to logout by calling api # so we need to logout by calling api
clear_access_token_from_cookie(response, samesite="None") clear_webapp_access_token_from_cookie(response, samesite="None")
return response return response

View File

@ -12,10 +12,8 @@ from controllers.web import web_ns
from controllers.web.error import WebAppAuthRequiredError from controllers.web.error import WebAppAuthRequiredError
from extensions.ext_database import db from extensions.ext_database import db
from libs.passport import PassportService from libs.passport import PassportService
from libs.token import extract_access_token from libs.token import extract_webapp_access_token
from models.model import App, EndUser, Site from models.model import App, EndUser, Site
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService, WebAppAuthType from services.webapp_auth_service import WebAppAuthService, WebAppAuthType
@ -37,23 +35,18 @@ class PassportResource(Resource):
system_features = FeatureService.get_system_features() system_features = FeatureService.get_system_features()
app_code = request.headers.get(HEADER_NAME_APP_CODE) app_code = request.headers.get(HEADER_NAME_APP_CODE)
user_id = request.args.get("user_id") user_id = request.args.get("user_id")
access_token = extract_access_token(request) access_token = extract_webapp_access_token(request)
if app_code is None: if app_code is None:
raise Unauthorized("X-App-Code header is missing.") raise Unauthorized("X-App-Code header is missing.")
app_id = AppService.get_app_id_by_code(app_code)
# exchange token for enterprise logined web user
enterprise_user_decoded = decode_enterprise_webapp_user_id(access_token)
if enterprise_user_decoded:
# a web user has already logged in, exchange a token for this app without redirecting to the login page
return exchange_token_for_existing_web_user(
app_code=app_code, enterprise_user_decoded=enterprise_user_decoded
)
if system_features.webapp_auth.enabled: if system_features.webapp_auth.enabled:
app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id) enterprise_user_decoded = decode_enterprise_webapp_user_id(access_token)
if not app_settings or not app_settings.access_mode == "public": app_auth_type = WebAppAuthService.get_app_auth_type(app_code=app_code)
raise WebAppAuthRequiredError() if app_auth_type != WebAppAuthType.PUBLIC:
if not enterprise_user_decoded:
raise WebAppAuthRequiredError()
return exchange_token_for_existing_web_user(
app_code=app_code, enterprise_user_decoded=enterprise_user_decoded, auth_type=app_auth_type
)
# get site from db and check if it is normal # get site from db and check if it is normal
site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal")) site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal"))
@ -124,7 +117,7 @@ def decode_enterprise_webapp_user_id(jwt_token: str | None):
return decoded return decoded
def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict): def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict, auth_type: WebAppAuthType):
""" """
Exchange a token for an existing web user session. Exchange a token for an existing web user session.
""" """
@ -145,13 +138,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
if not app_model or app_model.status != "normal" or not app_model.enable_site: if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound() raise NotFound()
app_auth_type = WebAppAuthService.get_app_auth_type(app_code=app_code) if auth_type == WebAppAuthType.PUBLIC:
if app_auth_type == WebAppAuthType.PUBLIC:
return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded) return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded)
elif app_auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external": elif auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external":
raise WebAppAuthRequiredError("Please login as external user.") raise WebAppAuthRequiredError("Please login as external user.")
elif app_auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal": elif auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal":
raise WebAppAuthRequiredError("Please login as internal user.") raise WebAppAuthRequiredError("Please login as internal user.")
end_user = None end_user = None

View File

@ -211,8 +211,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
user=user, user=user,
stream=streaming, stream=streaming,
) )
# FIXME: Type hinting issue here, ignore it for now, will fix it later return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) # type: ignore
def _generate_worker( def _generate_worker(
self, self,

View File

@ -255,7 +255,7 @@ class PipelineGenerator(BaseAppGenerator):
json_text = json.dumps(text) json_text = json.dumps(text)
upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id) upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id)
features = FeatureService.get_features(dataset.tenant_id) features = FeatureService.get_features(dataset.tenant_id)
if features.billing.subscription.plan == "sandbox": if features.billing.enabled and features.billing.subscription.plan == "sandbox":
tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}" tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}" tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"

View File

@ -89,7 +89,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err) data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data) response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute] response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else: else:
response_chunk.update(sub_stream_response.model_dump(mode="json")) response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk yield response_chunk

View File

@ -98,7 +98,7 @@ class RateLimit:
else: else:
return RateLimitGenerator( return RateLimitGenerator(
rate_limit=self, rate_limit=self,
generator=generator, # ty: ignore [invalid-argument-type] generator=generator,
request_id=request_id, request_id=request_id,
) )

View File

@ -49,7 +49,7 @@ class BasedGenerateTaskPipeline:
if isinstance(e, InvokeAuthorizationError): if isinstance(e, InvokeAuthorizationError):
err = InvokeAuthorizationError("Incorrect API key provided") err = InvokeAuthorizationError("Incorrect API key provided")
elif isinstance(e, InvokeError | ValueError): elif isinstance(e, InvokeError | ValueError):
err = e # ty: ignore [invalid-assignment] err = e
else: else:
description = getattr(e, "description", None) description = getattr(e, "description", None)
err = Exception(description if description is not None else str(e)) err = Exception(description if description is not None else str(e))

View File

@ -1868,7 +1868,7 @@ class ProviderConfigurations(BaseModel):
if "/" not in key: if "/" not in key:
key = str(ModelProviderID(key)) key = str(ModelProviderID(key))
return self.configurations.get(key, default) # type: ignore return self.configurations.get(key, default)
class ProviderModelBundle(BaseModel): class ProviderModelBundle(BaseModel):

View File

@ -20,7 +20,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz
else: else:
# Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly # Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
# FIXME: mypy does not support the type of spec.loader # FIXME: mypy does not support the type of spec.loader
spec = importlib.util.spec_from_file_location(module_name, py_file_path) # type: ignore spec = importlib.util.spec_from_file_location(module_name, py_file_path) # type: ignore[assignment]
if not spec or not spec.loader: if not spec or not spec.loader:
raise Exception(f"Failed to load module {module_name} from {py_file_path!r}") raise Exception(f"Failed to load module {module_name} from {py_file_path!r}")
if use_lazy_loader: if use_lazy_loader:

View File

@ -49,62 +49,80 @@ class IndexingRunner:
self.storage = storage self.storage = storage
self.model_manager = ModelManager() self.model_manager = ModelManager()
def _handle_indexing_error(self, document_id: str, error: Exception) -> None:
    """Mark the document identified by ``document_id`` as failed.

    Logs the active exception, re-loads the document through ``db.session``
    and, if it still exists, sets ``indexing_status`` to ``"error"``, stores
    a message describing ``error``, stamps ``stopped_at`` and commits.

    :param document_id: primary key of the ``DatasetDocument`` to update.
    :param error: the exception that aborted indexing.
    """
    logger.exception("consume document failed")
    document = db.session.get(DatasetDocument, document_id)
    if not document:
        # Nothing to update: the row is gone (or was never persisted).
        return

    # Prefer the exception's ``description`` attribute, but only when it holds
    # a value. An attribute that exists but is ``None`` must fall back to
    # ``str(error)``; otherwise the stored message would be the text "None".
    description = getattr(error, "description", None)
    document.indexing_status = "error"
    document.error = str(description if description is not None else error)
    document.stopped_at = naive_utc_now()
    db.session.commit()
def run(self, dataset_documents: list[DatasetDocument]): def run(self, dataset_documents: list[DatasetDocument]):
"""Run the indexing process.""" """Run the indexing process."""
for dataset_document in dataset_documents: for dataset_document in dataset_documents:
document_id = dataset_document.id
try: try:
# Re-query the document to ensure it's bound to the current session
requeried_document = db.session.get(DatasetDocument, document_id)
if not requeried_document:
logger.warning("Document not found, skipping document id: %s", document_id)
continue
# get dataset # get dataset
dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()
if not dataset: if not dataset:
raise ValueError("no dataset found") raise ValueError("no dataset found")
# get the process rule # get the process rule
stmt = select(DatasetProcessRule).where( stmt = select(DatasetProcessRule).where(
DatasetProcessRule.id == dataset_document.dataset_process_rule_id DatasetProcessRule.id == requeried_document.dataset_process_rule_id
) )
processing_rule = db.session.scalar(stmt) processing_rule = db.session.scalar(stmt)
if not processing_rule: if not processing_rule:
raise ValueError("no process rule found") raise ValueError("no process rule found")
index_type = dataset_document.doc_form index_type = requeried_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor() index_processor = IndexProcessorFactory(index_type).init_index_processor()
# extract # extract
text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict()) text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
# transform # transform
documents = self._transform( documents = self._transform(
index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict() index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict()
) )
# save segment # save segment
self._load_segments(dataset, dataset_document, documents) self._load_segments(dataset, requeried_document, documents)
# load # load
self._load( self._load(
index_processor=index_processor, index_processor=index_processor,
dataset=dataset, dataset=dataset,
dataset_document=dataset_document, dataset_document=requeried_document,
documents=documents, documents=documents,
) )
except DocumentIsPausedError: except DocumentIsPausedError:
raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}") raise DocumentIsPausedError(f"Document paused, document id: {document_id}")
except ProviderTokenNotInitError as e: except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error" self._handle_indexing_error(document_id, e)
dataset_document.error = str(e.description)
dataset_document.stopped_at = naive_utc_now()
db.session.commit()
except ObjectDeletedError: except ObjectDeletedError:
logger.warning("Document deleted, document id: %s", dataset_document.id) logger.warning("Document deleted, document id: %s", document_id)
except Exception as e: except Exception as e:
logger.exception("consume document failed") self._handle_indexing_error(document_id, e)
dataset_document.indexing_status = "error"
dataset_document.error = str(e)
dataset_document.stopped_at = naive_utc_now()
db.session.commit()
def run_in_splitting_status(self, dataset_document: DatasetDocument): def run_in_splitting_status(self, dataset_document: DatasetDocument):
"""Run the indexing process when the index_status is splitting.""" """Run the indexing process when the index_status is splitting."""
document_id = dataset_document.id
try: try:
# Re-query the document to ensure it's bound to the current session
requeried_document = db.session.get(DatasetDocument, document_id)
if not requeried_document:
logger.warning("Document not found: %s", document_id)
return
# get dataset # get dataset
dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()
if not dataset: if not dataset:
raise ValueError("no dataset found") raise ValueError("no dataset found")
@ -112,57 +130,60 @@ class IndexingRunner:
# get exist document_segment list and delete # get exist document_segment list and delete
document_segments = ( document_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter_by(dataset_id=dataset.id, document_id=dataset_document.id) .filter_by(dataset_id=dataset.id, document_id=requeried_document.id)
.all() .all()
) )
for document_segment in document_segments: for document_segment in document_segments:
db.session.delete(document_segment) db.session.delete(document_segment)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# delete child chunks # delete child chunks
db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete() db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit() db.session.commit()
# get the process rule # get the process rule
stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == requeried_document.dataset_process_rule_id)
processing_rule = db.session.scalar(stmt) processing_rule = db.session.scalar(stmt)
if not processing_rule: if not processing_rule:
raise ValueError("no process rule found") raise ValueError("no process rule found")
index_type = dataset_document.doc_form index_type = requeried_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor() index_processor = IndexProcessorFactory(index_type).init_index_processor()
# extract # extract
text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict()) text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
# transform # transform
documents = self._transform( documents = self._transform(
index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict() index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict()
) )
# save segment # save segment
self._load_segments(dataset, dataset_document, documents) self._load_segments(dataset, requeried_document, documents)
# load # load
self._load( self._load(
index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents index_processor=index_processor,
dataset=dataset,
dataset_document=requeried_document,
documents=documents,
) )
except DocumentIsPausedError: except DocumentIsPausedError:
raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}") raise DocumentIsPausedError(f"Document paused, document id: {document_id}")
except ProviderTokenNotInitError as e: except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error" self._handle_indexing_error(document_id, e)
dataset_document.error = str(e.description)
dataset_document.stopped_at = naive_utc_now()
db.session.commit()
except Exception as e: except Exception as e:
logger.exception("consume document failed") self._handle_indexing_error(document_id, e)
dataset_document.indexing_status = "error"
dataset_document.error = str(e)
dataset_document.stopped_at = naive_utc_now()
db.session.commit()
def run_in_indexing_status(self, dataset_document: DatasetDocument): def run_in_indexing_status(self, dataset_document: DatasetDocument):
"""Run the indexing process when the index_status is indexing.""" """Run the indexing process when the index_status is indexing."""
document_id = dataset_document.id
try: try:
# Re-query the document to ensure it's bound to the current session
requeried_document = db.session.get(DatasetDocument, document_id)
if not requeried_document:
logger.warning("Document not found: %s", document_id)
return
# get dataset # get dataset
dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()
if not dataset: if not dataset:
raise ValueError("no dataset found") raise ValueError("no dataset found")
@ -170,7 +191,7 @@ class IndexingRunner:
# get exist document_segment list and delete # get exist document_segment list and delete
document_segments = ( document_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter_by(dataset_id=dataset.id, document_id=dataset_document.id) .filter_by(dataset_id=dataset.id, document_id=requeried_document.id)
.all() .all()
) )
@ -188,7 +209,7 @@ class IndexingRunner:
"dataset_id": document_segment.dataset_id, "dataset_id": document_segment.dataset_id,
}, },
) )
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = document_segment.get_child_chunks() child_chunks = document_segment.get_child_chunks()
if child_chunks: if child_chunks:
child_documents = [] child_documents = []
@ -206,24 +227,20 @@ class IndexingRunner:
document.children = child_documents document.children = child_documents
documents.append(document) documents.append(document)
# build index # build index
index_type = dataset_document.doc_form index_type = requeried_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor() index_processor = IndexProcessorFactory(index_type).init_index_processor()
self._load( self._load(
index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents index_processor=index_processor,
dataset=dataset,
dataset_document=requeried_document,
documents=documents,
) )
except DocumentIsPausedError: except DocumentIsPausedError:
raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}") raise DocumentIsPausedError(f"Document paused, document id: {document_id}")
except ProviderTokenNotInitError as e: except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error" self._handle_indexing_error(document_id, e)
dataset_document.error = str(e.description)
dataset_document.stopped_at = naive_utc_now()
db.session.commit()
except Exception as e: except Exception as e:
logger.exception("consume document failed") self._handle_indexing_error(document_id, e)
dataset_document.indexing_status = "error"
dataset_document.error = str(e)
dataset_document.stopped_at = naive_utc_now()
db.session.commit()
def indexing_estimate( def indexing_estimate(
self, self,
@ -398,7 +415,6 @@ class IndexingRunner:
document_id=dataset_document.id, document_id=dataset_document.id,
after_indexing_status="splitting", after_indexing_status="splitting",
extra_update_params={ extra_update_params={
DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs),
DatasetDocument.parsing_completed_at: naive_utc_now(), DatasetDocument.parsing_completed_at: naive_utc_now(),
}, },
) )
@ -738,6 +754,7 @@ class IndexingRunner:
extra_update_params={ extra_update_params={
DatasetDocument.cleaning_completed_at: cur_time, DatasetDocument.cleaning_completed_at: cur_time,
DatasetDocument.splitting_completed_at: cur_time, DatasetDocument.splitting_completed_at: cur_time,
DatasetDocument.word_count: sum(len(doc.page_content) for doc in documents),
}, },
) )

View File

@ -100,7 +100,7 @@ class LLMGenerator:
return name return name
@classmethod @classmethod
def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str): def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str) -> Sequence[str]:
output_parser = SuggestedQuestionsAfterAnswerOutputParser() output_parser = SuggestedQuestionsAfterAnswerOutputParser()
format_instructions = output_parser.get_format_instructions() format_instructions = output_parser.get_format_instructions()
@ -119,6 +119,8 @@ class LLMGenerator:
prompt_messages = [UserPromptMessage(content=prompt)] prompt_messages = [UserPromptMessage(content=prompt)]
questions: Sequence[str] = []
try: try:
response: LLMResult = model_instance.invoke_llm( response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), prompt_messages=list(prompt_messages),

View File

@ -1,17 +1,26 @@
import json import json
import logging
import re import re
from collections.abc import Sequence
from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
logger = logging.getLogger(__name__)
class SuggestedQuestionsAfterAnswerOutputParser: class SuggestedQuestionsAfterAnswerOutputParser:
def get_format_instructions(self) -> str: def get_format_instructions(self) -> str:
return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
def parse(self, text: str): def parse(self, text: str) -> Sequence[str]:
action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL) action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL)
questions: list[str] = []
if action_match is not None: if action_match is not None:
json_obj = json.loads(action_match.group(0).strip()) try:
else: json_obj = json.loads(action_match.group(0).strip())
json_obj = [] except json.JSONDecodeError as exc:
return json_obj logger.warning("Failed to decode suggested questions payload: %s", exc)
else:
if isinstance(json_obj, list):
questions = [question for question in json_obj if isinstance(question, str)]
return questions

View File

@ -2,7 +2,7 @@ import logging
import os import os
from datetime import datetime, timedelta from datetime import datetime, timedelta
from langfuse import Langfuse # type: ignore from langfuse import Langfuse
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance from core.ops.base_trace_instance import BaseTraceInstance

View File

@ -76,7 +76,7 @@ class PluginParameter(BaseModel):
auto_generate: PluginParameterAutoGenerate | None = None auto_generate: PluginParameterAutoGenerate | None = None
template: PluginParameterTemplate | None = None template: PluginParameterTemplate | None = None
required: bool = False required: bool = False
default: Union[float, int, str] | None = None default: Union[float, int, str, bool] | None = None
min: Union[float, int] | None = None min: Union[float, int] | None = None
max: Union[float, int] | None = None max: Union[float, int] | None = None
precision: int | None = None precision: int | None = None

View File

@ -180,7 +180,7 @@ class BasePluginClient:
Make a request to the plugin daemon inner API and return the response as a model. Make a request to the plugin daemon inner API and return the response as a model.
""" """
response = self._request(method, path, headers, data, params, files) response = self._request(method, path, headers, data, params, files)
return type_(**response.json()) # type: ignore return type_(**response.json()) # type: ignore[return-value]
def _request_with_plugin_daemon_response( def _request_with_plugin_daemon_response(
self, self,

View File

@ -40,7 +40,7 @@ class PluginDaemonBadRequestError(PluginDaemonClientSideError):
description: str = "Bad Request" description: str = "Bad Request"
class PluginInvokeError(PluginDaemonClientSideError): class PluginInvokeError(PluginDaemonClientSideError, ValueError):
description: str = "Invoke Error" description: str = "Invoke Error"
def _get_error_object(self) -> Mapping: def _get_error_object(self) -> Mapping:

View File

@ -72,6 +72,19 @@ default_retrieval_model: dict[str, Any] = {
class DatasetRetrieval: class DatasetRetrieval:
def __init__(self, application_generate_entity=None): def __init__(self, application_generate_entity=None):
self.application_generate_entity = application_generate_entity self.application_generate_entity = application_generate_entity
self._llm_usage = LLMUsage.empty_usage()
@property
def llm_usage(self) -> LLMUsage:
    """Return a copy of the LLM usage accumulated on this instance.

    The value is produced by ``model_copy()`` on the internally stored
    ``_llm_usage`` rather than handing out that stored object itself.
    """
    usage_snapshot = self._llm_usage.model_copy()
    return usage_snapshot
def _record_usage(self, usage: LLMUsage | None) -> None:
    """Fold the usage of one LLM call into the running total on this instance.

    Args:
        usage: Usage reported by an LLM invocation. ``None`` and reports whose
            ``total_tokens`` is not positive are ignored.
    """
    if usage is None:
        return
    if usage.total_tokens <= 0:
        return

    # While nothing has been counted yet the incoming usage is stored as-is;
    # afterwards it is combined with the running total through ``plus``.
    nothing_recorded_yet = self._llm_usage.total_tokens == 0
    self._llm_usage = usage if nothing_recorded_yet else self._llm_usage.plus(usage)
def retrieve( def retrieve(
self, self,
@ -312,15 +325,18 @@ class DatasetRetrieval:
) )
tools.append(message_tool) tools.append(message_tool)
dataset_id = None dataset_id = None
router_usage = LLMUsage.empty_usage()
if planning_strategy == PlanningStrategy.REACT_ROUTER: if planning_strategy == PlanningStrategy.REACT_ROUTER:
react_multi_dataset_router = ReactMultiDatasetRouter() react_multi_dataset_router = ReactMultiDatasetRouter()
dataset_id = react_multi_dataset_router.invoke( dataset_id, router_usage = react_multi_dataset_router.invoke(
query, tools, model_config, model_instance, user_id, tenant_id query, tools, model_config, model_instance, user_id, tenant_id
) )
elif planning_strategy == PlanningStrategy.ROUTER: elif planning_strategy == PlanningStrategy.ROUTER:
function_call_router = FunctionCallMultiDatasetRouter() function_call_router = FunctionCallMultiDatasetRouter()
dataset_id = function_call_router.invoke(query, tools, model_config, model_instance) dataset_id, router_usage = function_call_router.invoke(query, tools, model_config, model_instance)
self._record_usage(router_usage)
if dataset_id: if dataset_id:
# get retrieval model config # get retrieval model config
@ -983,7 +999,8 @@ class DatasetRetrieval:
) )
# handle invoke result # handle invoke result
result_text, _ = self._handle_invoke_result(invoke_result=invoke_result) result_text, usage = self._handle_invoke_result(invoke_result=invoke_result)
self._record_usage(usage)
result_text_json = parse_and_check_json_markdown(result_text, []) result_text_json = parse_and_check_json_markdown(result_text, [])
automatic_metadata_filters = [] automatic_metadata_filters = []

View File

@ -2,7 +2,7 @@ from typing import Union
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage
@ -13,15 +13,15 @@ class FunctionCallMultiDatasetRouter:
dataset_tools: list[PromptMessageTool], dataset_tools: list[PromptMessageTool],
model_config: ModelConfigWithCredentialsEntity, model_config: ModelConfigWithCredentialsEntity,
model_instance: ModelInstance, model_instance: ModelInstance,
) -> Union[str, None]: ) -> tuple[Union[str, None], LLMUsage]:
"""Given input, decided what to do. """Given input, decided what to do.
Returns: Returns:
Action specifying what tool to use. Action specifying what tool to use.
""" """
if len(dataset_tools) == 0: if len(dataset_tools) == 0:
return None return None, LLMUsage.empty_usage()
elif len(dataset_tools) == 1: elif len(dataset_tools) == 1:
return dataset_tools[0].name return dataset_tools[0].name, LLMUsage.empty_usage()
try: try:
prompt_messages = [ prompt_messages = [
@ -34,9 +34,10 @@ class FunctionCallMultiDatasetRouter:
stream=False, stream=False,
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500}, model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
) )
usage = result.usage or LLMUsage.empty_usage()
if result.message.tool_calls: if result.message.tool_calls:
# get retrieval model config # get retrieval model config
return result.message.tool_calls[0].function.name return result.message.tool_calls[0].function.name, usage
return None return None, usage
except Exception: except Exception:
return None return None, LLMUsage.empty_usage()

View File

@ -58,15 +58,15 @@ class ReactMultiDatasetRouter:
model_instance: ModelInstance, model_instance: ModelInstance,
user_id: str, user_id: str,
tenant_id: str, tenant_id: str,
) -> Union[str, None]: ) -> tuple[Union[str, None], LLMUsage]:
"""Given input, decided what to do. """Given input, decided what to do.
Returns: Returns:
Action specifying what tool to use. Action specifying what tool to use.
""" """
if len(dataset_tools) == 0: if len(dataset_tools) == 0:
return None return None, LLMUsage.empty_usage()
elif len(dataset_tools) == 1: elif len(dataset_tools) == 1:
return dataset_tools[0].name return dataset_tools[0].name, LLMUsage.empty_usage()
try: try:
return self._react_invoke( return self._react_invoke(
@ -78,7 +78,7 @@ class ReactMultiDatasetRouter:
tenant_id=tenant_id, tenant_id=tenant_id,
) )
except Exception: except Exception:
return None return None, LLMUsage.empty_usage()
def _react_invoke( def _react_invoke(
self, self,
@ -91,7 +91,7 @@ class ReactMultiDatasetRouter:
prefix: str = PREFIX, prefix: str = PREFIX,
suffix: str = SUFFIX, suffix: str = SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS, format_instructions: str = FORMAT_INSTRUCTIONS,
) -> Union[str, None]: ) -> tuple[Union[str, None], LLMUsage]:
prompt: Union[list[ChatModelMessage], CompletionModelPromptTemplate] prompt: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
if model_config.mode == "chat": if model_config.mode == "chat":
prompt = self.create_chat_prompt( prompt = self.create_chat_prompt(
@ -120,7 +120,7 @@ class ReactMultiDatasetRouter:
memory=None, memory=None,
model_config=model_config, model_config=model_config,
) )
result_text, _ = self._invoke_llm( result_text, usage = self._invoke_llm(
completion_param=model_config.parameters, completion_param=model_config.parameters,
model_instance=model_instance, model_instance=model_instance,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
@ -131,8 +131,8 @@ class ReactMultiDatasetRouter:
output_parser = StructuredChatOutputParser() output_parser = StructuredChatOutputParser()
react_decision = output_parser.parse(result_text) react_decision = output_parser.parse(result_text)
if isinstance(react_decision, ReactAction): if isinstance(react_decision, ReactAction):
return react_decision.tool return react_decision.tool, usage
return None return None, usage
def _invoke_llm( def _invoke_llm(
self, self,

View File

@ -74,7 +74,7 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
tenant_id = extract_tenant_id(user) tenant_id = extract_tenant_id(user)
if not tenant_id: if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id") raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id # type: ignore[assignment] # We've already checked tenant_id is not None self._tenant_id = tenant_id
# Store app context # Store app context
self._app_id = app_id self._app_id = app_id

View File

@ -81,7 +81,7 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
tenant_id = extract_tenant_id(user) tenant_id = extract_tenant_id(user)
if not tenant_id: if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id") raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id # type: ignore[assignment] # We've already checked tenant_id is not None self._tenant_id = tenant_id
# Store app context # Store app context
self._app_id = app_id self._app_id = app_id

View File

@ -60,7 +60,7 @@ class DifyCoreRepositoryFactory:
try: try:
repository_class = import_string(class_path) repository_class = import_string(class_path)
return repository_class( # type: ignore[no-any-return] return repository_class(
session_factory=session_factory, session_factory=session_factory,
user=user, user=user,
app_id=app_id, app_id=app_id,
@ -96,7 +96,7 @@ class DifyCoreRepositoryFactory:
try: try:
repository_class = import_string(class_path) repository_class = import_string(class_path)
return repository_class( # type: ignore[no-any-return] return repository_class(
session_factory=session_factory, session_factory=session_factory,
user=user, user=user,
app_id=app_id, app_id=app_id,

View File

@ -157,7 +157,7 @@ class BuiltinToolProviderController(ToolProviderController):
""" """
returns the tool that the provider can provide returns the tool that the provider can provide
""" """
return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None) # type: ignore return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)
@property @property
def need_credentials(self) -> bool: def need_credentials(self) -> bool:

View File

@ -43,7 +43,7 @@ class TTSTool(BuiltinTool):
content_text=tool_parameters.get("text"), # type: ignore content_text=tool_parameters.get("text"), # type: ignore
user=user_id, user=user_id,
tenant_id=self.runtime.tenant_id, tenant_id=self.runtime.tenant_id,
voice=voice, # type: ignore voice=voice,
) )
buffer = io.BytesIO() buffer = io.BytesIO()
for chunk in tts: for chunk in tts:

View File

@ -34,6 +34,7 @@ class LocaltimeToTimestampTool(BuiltinTool):
yield self.create_text_message(f"{timestamp}") yield self.create_text_message(f"{timestamp}")
# TODO: this method's type is messy
@staticmethod @staticmethod
def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None: def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None:
try: try:

View File

@ -48,6 +48,6 @@ class TimezoneConversionTool(BuiltinTool):
datetime_with_tz = input_timezone.localize(local_time) datetime_with_tz = input_timezone.localize(local_time)
# timezone convert # timezone convert
converted_datetime = datetime_with_tz.astimezone(output_timezone) converted_datetime = datetime_with_tz.astimezone(output_timezone)
return converted_datetime.strftime(format=time_format) # type: ignore return converted_datetime.strftime(time_format)
except Exception as e: except Exception as e:
raise ToolInvokeError(str(e)) raise ToolInvokeError(str(e))

View File

@ -113,7 +113,7 @@ class MCPToolProviderController(ToolProviderController):
""" """
pass pass
def get_tool(self, tool_name: str) -> MCPTool: # type: ignore def get_tool(self, tool_name: str) -> MCPTool:
""" """
return tool with given name return tool with given name
""" """
@ -136,7 +136,7 @@ class MCPToolProviderController(ToolProviderController):
sse_read_timeout=self.sse_read_timeout, sse_read_timeout=self.sse_read_timeout,
) )
def get_tools(self) -> list[MCPTool]: # type: ignore def get_tools(self) -> list[MCPTool]:
""" """
get all tools get all tools
""" """

View File

@ -26,7 +26,7 @@ class ToolLabelManager:
labels = cls.filter_tool_labels(labels) labels = cls.filter_tool_labels(labels)
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
provider_id = controller.provider_id # ty: ignore [unresolved-attribute] provider_id = controller.provider_id
else: else:
raise ValueError("Unsupported tool type") raise ValueError("Unsupported tool type")
@ -51,7 +51,7 @@ class ToolLabelManager:
Get tool labels Get tool labels
""" """
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
provider_id = controller.provider_id # ty: ignore [unresolved-attribute] provider_id = controller.provider_id
elif isinstance(controller, BuiltinToolProviderController): elif isinstance(controller, BuiltinToolProviderController):
return controller.tool_labels return controller.tool_labels
else: else:
@ -85,7 +85,7 @@ class ToolLabelManager:
provider_ids = [] provider_ids = []
for controller in tool_providers: for controller in tool_providers:
assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController) assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
provider_ids.append(controller.provider_id) # ty: ignore [unresolved-attribute] provider_ids.append(controller.provider_id)
labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all() labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all()

View File

@ -331,7 +331,8 @@ class ToolManager:
workflow_provider_stmt = select(WorkflowToolProvider).where( workflow_provider_stmt = select(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id
) )
workflow_provider = db.session.scalar(workflow_provider_stmt) with Session(db.engine, expire_on_commit=False) as session, session.begin():
workflow_provider = session.scalar(workflow_provider_stmt)
if workflow_provider is None: if workflow_provider is None:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

View File

@ -193,18 +193,18 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
DatasetDocument.enabled == True, DatasetDocument.enabled == True,
DatasetDocument.archived == False, DatasetDocument.archived == False,
) )
document = db.session.scalar(dataset_document_stmt) # type: ignore document = db.session.scalar(dataset_document_stmt)
if dataset and document: if dataset and document:
source = RetrievalSourceMetadata( source = RetrievalSourceMetadata(
dataset_id=dataset.id, dataset_id=dataset.id,
dataset_name=dataset.name, dataset_name=dataset.name,
document_id=document.id, # type: ignore document_id=document.id,
document_name=document.name, # type: ignore document_name=document.name,
data_source_type=document.data_source_type, # type: ignore data_source_type=document.data_source_type,
segment_id=segment.id, segment_id=segment.id,
retriever_from=self.retriever_from, retriever_from=self.retriever_from,
score=record.score or 0.0, score=record.score or 0.0,
doc_metadata=document.doc_metadata, # type: ignore doc_metadata=document.doc_metadata,
) )
if self.retriever_from == "dev": if self.retriever_from == "dev":

View File

@ -62,6 +62,11 @@ class ApiBasedToolSchemaParser:
root = root[ref] root = root[ref]
interface["operation"]["parameters"][i] = root interface["operation"]["parameters"][i] = root
for parameter in interface["operation"]["parameters"]: for parameter in interface["operation"]["parameters"]:
# Handle complex type defaults that are not supported by PluginParameter
default_value = None
if "schema" in parameter and "default" in parameter["schema"]:
default_value = ApiBasedToolSchemaParser._sanitize_default_value(parameter["schema"]["default"])
tool_parameter = ToolParameter( tool_parameter = ToolParameter(
name=parameter["name"], name=parameter["name"],
label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]), label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
@ -72,9 +77,7 @@ class ApiBasedToolSchemaParser:
required=parameter.get("required", False), required=parameter.get("required", False),
form=ToolParameter.ToolParameterForm.LLM, form=ToolParameter.ToolParameterForm.LLM,
llm_description=parameter.get("description"), llm_description=parameter.get("description"),
default=parameter["schema"]["default"] default=default_value,
if "schema" in parameter and "default" in parameter["schema"]
else None,
placeholder=I18nObject( placeholder=I18nObject(
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "") en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
), ),
@ -134,6 +137,11 @@ class ApiBasedToolSchemaParser:
required = body_schema.get("required", []) required = body_schema.get("required", [])
properties = body_schema.get("properties", {}) properties = body_schema.get("properties", {})
for name, property in properties.items(): for name, property in properties.items():
# Handle complex type defaults that are not supported by PluginParameter
default_value = ApiBasedToolSchemaParser._sanitize_default_value(
property.get("default", None)
)
tool = ToolParameter( tool = ToolParameter(
name=name, name=name,
label=I18nObject(en_US=name, zh_Hans=name), label=I18nObject(en_US=name, zh_Hans=name),
@ -144,12 +152,11 @@ class ApiBasedToolSchemaParser:
required=name in required, required=name in required,
form=ToolParameter.ToolParameterForm.LLM, form=ToolParameter.ToolParameterForm.LLM,
llm_description=property.get("description", ""), llm_description=property.get("description", ""),
default=property.get("default", None), default=default_value,
placeholder=I18nObject( placeholder=I18nObject(
en_US=property.get("description", ""), zh_Hans=property.get("description", "") en_US=property.get("description", ""), zh_Hans=property.get("description", "")
), ),
) )
# check if there is a type # check if there is a type
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property) typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
if typ: if typ:
@ -197,6 +204,22 @@ class ApiBasedToolSchemaParser:
return bundles return bundles
@staticmethod
def _sanitize_default_value(value):
"""
Sanitize default values for PluginParameter compatibility.
Complex types (list, dict) are converted to None to avoid validation errors.
Args:
value: The default value from OpenAPI schema
Returns:
None for complex types (list, dict), otherwise the original value
"""
if isinstance(value, (list, dict)):
return None
return value
@staticmethod @staticmethod
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None: def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
parameter = parameter or {} parameter = parameter or {}
@ -217,7 +240,11 @@ class ApiBasedToolSchemaParser:
return ToolParameter.ToolParameterType.STRING return ToolParameter.ToolParameterType.STRING
elif typ == "array": elif typ == "array":
items = parameter.get("items") or parameter.get("schema", {}).get("items") items = parameter.get("items") or parameter.get("schema", {}).get("items")
return ToolParameter.ToolParameterType.FILES if items and items.get("format") == "binary" else None if items and items.get("format") == "binary":
return ToolParameter.ToolParameterType.FILES
else:
# For regular arrays, return ARRAY type instead of None
return ToolParameter.ToolParameterType.ARRAY
else: else:
return None return None

View File

@ -6,8 +6,8 @@ from typing import Any, cast
from urllib.parse import unquote from urllib.parse import unquote
import chardet import chardet
import cloudscraper # type: ignore import cloudscraper
from readabilipy import simple_json_from_html_string # type: ignore from readabilipy import simple_json_from_html_string
from core.helper import ssrf_proxy from core.helper import ssrf_proxy
from core.rag.extractor import extract_processor from core.rag.extractor import extract_processor
@ -63,8 +63,8 @@ def get_url(url: str, user_agent: str | None = None) -> str:
response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
elif response.status_code == 403: elif response.status_code == 403:
scraper = cloudscraper.create_scraper() scraper = cloudscraper.create_scraper()
scraper.perform_request = ssrf_proxy.make_request # type: ignore scraper.perform_request = ssrf_proxy.make_request
response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) # type: ignore response = scraper.get(url, headers=headers, timeout=(120, 300))
if response.status_code != 200: if response.status_code != 200:
return f"URL returned status code {response.status_code}." return f"URL returned status code {response.status_code}."

View File

@ -3,7 +3,7 @@ from functools import lru_cache
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
import yaml # type: ignore import yaml
from yaml import YAMLError from yaml import YAMLError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -1,6 +1,7 @@
from collections.abc import Mapping from collections.abc import Mapping
from pydantic import Field from pydantic import Field
from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntity, VariableEntityType from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
@ -20,6 +21,7 @@ from core.tools.entities.tool_entities import (
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.tool import WorkflowTool from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account
from models.model import App, AppMode from models.model import App, AppMode
from models.tools import WorkflowToolProvider from models.tools import WorkflowToolProvider
from models.workflow import Workflow from models.workflow import Workflow
@ -44,29 +46,34 @@ class WorkflowToolProviderController(ToolProviderController):
@classmethod @classmethod
def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController": def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
app = db_provider.app with Session(db.engine, expire_on_commit=False) as session, session.begin():
provider = session.get(WorkflowToolProvider, db_provider.id) if db_provider.id else None
if not provider:
raise ValueError("workflow provider not found")
app = session.get(App, provider.app_id)
if not app:
raise ValueError("app not found")
if not app: user = session.get(Account, provider.user_id) if provider.user_id else None
raise ValueError("app not found")
controller = WorkflowToolProviderController( controller = WorkflowToolProviderController(
entity=ToolProviderEntity( entity=ToolProviderEntity(
identity=ToolProviderIdentity( identity=ToolProviderIdentity(
author=db_provider.user.name if db_provider.user_id and db_provider.user else "", author=user.name if user else "",
name=db_provider.label, name=provider.label,
label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label), label=I18nObject(en_US=provider.label, zh_Hans=provider.label),
description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description), description=I18nObject(en_US=provider.description, zh_Hans=provider.description),
icon=db_provider.icon, icon=provider.icon,
),
credentials_schema=[],
plugin_id=None,
), ),
credentials_schema=[], provider_id=provider.id or "",
plugin_id=None, )
),
provider_id=db_provider.id or "",
)
# init tools controller.tools = [
controller._get_db_provider_tool(provider, app, session=session, user=user),
controller.tools = [controller._get_db_provider_tool(db_provider, app)] ]
return controller return controller
@ -74,7 +81,14 @@ class WorkflowToolProviderController(ToolProviderController):
def provider_type(self) -> ToolProviderType: def provider_type(self) -> ToolProviderType:
return ToolProviderType.WORKFLOW return ToolProviderType.WORKFLOW
def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool: def _get_db_provider_tool(
self,
db_provider: WorkflowToolProvider,
app: App,
*,
session: Session,
user: Account | None = None,
) -> WorkflowTool:
""" """
get db provider tool get db provider tool
:param db_provider: the db provider :param db_provider: the db provider
@ -82,7 +96,7 @@ class WorkflowToolProviderController(ToolProviderController):
:return: the tool :return: the tool
""" """
workflow: Workflow | None = ( workflow: Workflow | None = (
db.session.query(Workflow) session.query(Workflow)
.where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version) .where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
.first() .first()
) )
@ -99,9 +113,7 @@ class WorkflowToolProviderController(ToolProviderController):
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph) variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
def fetch_workflow_variable(variable_name: str) -> VariableEntity | None: def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
return next(filter(lambda x: x.variable == variable_name, variables), None) # type: ignore return next(filter(lambda x: x.variable == variable_name, variables), None)
user = db_provider.user
workflow_tool_parameters = [] workflow_tool_parameters = []
for parameter in parameters: for parameter in parameters:
@ -187,22 +199,25 @@ class WorkflowToolProviderController(ToolProviderController):
if self.tools is not None: if self.tools is not None:
return self.tools return self.tools
db_providers: WorkflowToolProvider | None = ( with Session(db.engine, expire_on_commit=False) as session, session.begin():
db.session.query(WorkflowToolProvider) db_provider: WorkflowToolProvider | None = (
.where( session.query(WorkflowToolProvider)
WorkflowToolProvider.tenant_id == tenant_id, .where(
WorkflowToolProvider.app_id == self.provider_id, WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.app_id == self.provider_id,
)
.first()
) )
.first()
)
if not db_providers: if not db_provider:
return [] return []
if not db_providers.app:
raise ValueError("app not found")
app = db_providers.app app = session.get(App, db_provider.app_id)
self.tools = [self._get_db_provider_tool(db_providers, app)] if not app:
raise ValueError("app not found")
user = session.get(Account, db_provider.user_id) if db_provider.user_id else None
self.tools = [self._get_db_provider_tool(db_provider, app, session=session, user=user)]
return self.tools return self.tools

View File

@ -1,12 +1,14 @@
import json import json
import logging import logging
from collections.abc import Generator from collections.abc import Generator, Mapping, Sequence
from typing import Any from typing import Any, cast
from flask import has_request_context from flask import has_request_context
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.tools.__base.tool import Tool from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
@ -48,6 +50,7 @@ class WorkflowTool(Tool):
self.workflow_entities = workflow_entities self.workflow_entities = workflow_entities
self.workflow_call_depth = workflow_call_depth self.workflow_call_depth = workflow_call_depth
self.label = label self.label = label
self._latest_usage = LLMUsage.empty_usage()
super().__init__(entity=entity, runtime=runtime) super().__init__(entity=entity, runtime=runtime)
@ -83,10 +86,11 @@ class WorkflowTool(Tool):
assert self.runtime.invoke_from is not None assert self.runtime.invoke_from is not None
user = self._resolve_user(user_id=user_id) user = self._resolve_user(user_id=user_id)
if user is None: if user is None:
raise ToolInvokeError("User not found") raise ToolInvokeError("User not found")
self._latest_usage = LLMUsage.empty_usage()
result = generator.generate( result = generator.generate(
app_model=app, app_model=app,
workflow=workflow, workflow=workflow,
@ -110,9 +114,68 @@ class WorkflowTool(Tool):
for file in files: for file in files:
yield self.create_file_message(file) # type: ignore yield self.create_file_message(file) # type: ignore
self._latest_usage = self._derive_usage_from_result(data)
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False)) yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
yield self.create_json_message(outputs) yield self.create_json_message(outputs)
@property
def latest_usage(self) -> LLMUsage:
return self._latest_usage
@classmethod
def _derive_usage_from_result(cls, data: Mapping[str, Any]) -> LLMUsage:
    """
    Build an ``LLMUsage`` from a workflow result payload.

    A usage mapping located by ``_extract_usage_dict`` takes precedence. Without
    one, the top-level ``total_tokens`` / ``total_price`` / ``currency`` keys are
    used as a fallback.

    :param data: result payload to inspect
    :return: the derived usage, or an empty usage when nothing usable is present
    """
    nested_usage = cls._extract_usage_dict(data)
    if nested_usage is not None:
        return LLMUsage.from_metadata(cast(LLMUsageMetadata, dict(nested_usage)))

    raw_tokens = data.get("total_tokens")
    raw_price = data.get("total_price")
    if raw_tokens is None and raw_price is None:
        return LLMUsage.empty_usage()

    fallback: dict[str, Any] = {}
    if raw_tokens is not None:
        try:
            fallback["total_tokens"] = int(str(raw_tokens))
        except (TypeError, ValueError):
            # Best effort: a token count that does not parse as an integer is left out.
            pass
    if raw_price is not None:
        fallback["total_price"] = str(raw_price)
    raw_currency = data.get("currency")
    if raw_currency is not None:
        fallback["currency"] = raw_currency

    if fallback:
        return LLMUsage.from_metadata(cast(LLMUsageMetadata, fallback))
    return LLMUsage.empty_usage()
@classmethod
def _extract_usage_dict(cls, payload: Mapping[str, Any]) -> Mapping[str, Any] | None:
usage_candidate = payload.get("usage")
if isinstance(usage_candidate, Mapping):
return usage_candidate
metadata_candidate = payload.get("metadata")
if isinstance(metadata_candidate, Mapping):
usage_candidate = metadata_candidate.get("usage")
if isinstance(usage_candidate, Mapping):
return usage_candidate
for value in payload.values():
if isinstance(value, Mapping):
found = cls._extract_usage_dict(value)
if found is not None:
return found
elif isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
for item in value:
if isinstance(item, Mapping):
found = cls._extract_usage_dict(item)
if found is not None:
return found
return None
def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool": def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
""" """
fork a new tool with metadata fork a new tool with metadata
@ -179,16 +242,17 @@ class WorkflowTool(Tool):
""" """
get the workflow by app id and version get the workflow by app id and version
""" """
if not version: with Session(db.engine, expire_on_commit=False) as session, session.begin():
workflow = ( if not version:
db.session.query(Workflow) stmt = (
.where(Workflow.app_id == app_id, Workflow.version != Workflow.VERSION_DRAFT) select(Workflow)
.order_by(Workflow.created_at.desc()) .where(Workflow.app_id == app_id, Workflow.version != Workflow.VERSION_DRAFT)
.first() .order_by(Workflow.created_at.desc())
) )
else: workflow = session.scalars(stmt).first()
stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version) else:
workflow = db.session.scalar(stmt) stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
workflow = session.scalar(stmt)
if not workflow: if not workflow:
raise ValueError("workflow not found or not published") raise ValueError("workflow not found or not published")
@ -200,7 +264,8 @@ class WorkflowTool(Tool):
get the app by app id get the app by app id
""" """
stmt = select(App).where(App.id == app_id) stmt = select(App).where(App.id == app_id)
app = db.session.scalar(stmt) with Session(db.engine, expire_on_commit=False) as session, session.begin():
app = session.scalar(stmt)
if not app: if not app:
raise ValueError("app not found") raise ValueError("app not found")

View File

@ -4,7 +4,7 @@ from .types import SegmentType
class SegmentGroup(Segment): class SegmentGroup(Segment):
value_type: SegmentType = SegmentType.GROUP value_type: SegmentType = SegmentType.GROUP
value: list[Segment] = None # type: ignore value: list[Segment]
@property @property
def text(self): def text(self):

View File

@ -19,7 +19,7 @@ class Segment(BaseModel):
model_config = ConfigDict(frozen=True) model_config = ConfigDict(frozen=True)
value_type: SegmentType value_type: SegmentType
value: Any = None value: Any
@field_validator("value_type") @field_validator("value_type")
@classmethod @classmethod
@ -74,12 +74,12 @@ class NoneSegment(Segment):
class StringSegment(Segment): class StringSegment(Segment):
value_type: SegmentType = SegmentType.STRING value_type: SegmentType = SegmentType.STRING
value: str = None # type: ignore value: str
class FloatSegment(Segment): class FloatSegment(Segment):
value_type: SegmentType = SegmentType.FLOAT value_type: SegmentType = SegmentType.FLOAT
value: float = None # type: ignore value: float
# NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems. # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
# The following tests cannot pass. # The following tests cannot pass.
# #
@ -98,12 +98,12 @@ class FloatSegment(Segment):
class IntegerSegment(Segment): class IntegerSegment(Segment):
value_type: SegmentType = SegmentType.INTEGER value_type: SegmentType = SegmentType.INTEGER
value: int = None # type: ignore value: int
class ObjectSegment(Segment): class ObjectSegment(Segment):
value_type: SegmentType = SegmentType.OBJECT value_type: SegmentType = SegmentType.OBJECT
value: Mapping[str, Any] = None # type: ignore value: Mapping[str, Any]
@property @property
def text(self) -> str: def text(self) -> str:
@ -136,7 +136,7 @@ class ArraySegment(Segment):
class FileSegment(Segment): class FileSegment(Segment):
value_type: SegmentType = SegmentType.FILE value_type: SegmentType = SegmentType.FILE
value: File = None # type: ignore value: File
@property @property
def markdown(self) -> str: def markdown(self) -> str:
@ -153,17 +153,17 @@ class FileSegment(Segment):
class BooleanSegment(Segment): class BooleanSegment(Segment):
value_type: SegmentType = SegmentType.BOOLEAN value_type: SegmentType = SegmentType.BOOLEAN
value: bool = None # type: ignore value: bool
class ArrayAnySegment(ArraySegment): class ArrayAnySegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_ANY value_type: SegmentType = SegmentType.ARRAY_ANY
value: Sequence[Any] = None # type: ignore value: Sequence[Any]
class ArrayStringSegment(ArraySegment): class ArrayStringSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_STRING value_type: SegmentType = SegmentType.ARRAY_STRING
value: Sequence[str] = None # type: ignore value: Sequence[str]
@property @property
def text(self) -> str: def text(self) -> str:
@ -175,17 +175,17 @@ class ArrayStringSegment(ArraySegment):
class ArrayNumberSegment(ArraySegment): class ArrayNumberSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_NUMBER value_type: SegmentType = SegmentType.ARRAY_NUMBER
value: Sequence[float | int] = None # type: ignore value: Sequence[float | int]
class ArrayObjectSegment(ArraySegment): class ArrayObjectSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_OBJECT value_type: SegmentType = SegmentType.ARRAY_OBJECT
value: Sequence[Mapping[str, Any]] = None # type: ignore value: Sequence[Mapping[str, Any]]
class ArrayFileSegment(ArraySegment): class ArrayFileSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_FILE value_type: SegmentType = SegmentType.ARRAY_FILE
value: Sequence[File] = None # type: ignore value: Sequence[File]
@property @property
def markdown(self) -> str: def markdown(self) -> str:
@ -205,7 +205,7 @@ class ArrayFileSegment(ArraySegment):
class ArrayBooleanSegment(ArraySegment): class ArrayBooleanSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_BOOLEAN value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
value: Sequence[bool] = None # type: ignore value: Sequence[bool]
def get_segment_discriminator(v: Any) -> SegmentType | None: def get_segment_discriminator(v: Any) -> SegmentType | None:

View File

@ -1,3 +1,5 @@
from ..runtime.graph_runtime_state import GraphRuntimeState
from ..runtime.variable_pool import VariablePool
from .agent import AgentNodeStrategyInit from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams from .graph_init_params import GraphInitParams
from .workflow_execution import WorkflowExecution from .workflow_execution import WorkflowExecution
@ -6,6 +8,8 @@ from .workflow_node_execution import WorkflowNodeExecution
__all__ = [ __all__ = [
"AgentNodeStrategyInit", "AgentNodeStrategyInit",
"GraphInitParams", "GraphInitParams",
"GraphRuntimeState",
"VariablePool",
"WorkflowExecution", "WorkflowExecution",
"WorkflowNodeExecution", "WorkflowNodeExecution",
] ]

View File

@ -3,11 +3,12 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Protocol, cast, final from typing import Protocol, cast, final
from core.workflow.enums import NodeExecutionType, NodeState, NodeType from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.node import Node
from libs.typing import is_str, is_str_dict from libs.typing import is_str, is_str_dict
from .edge import Edge from .edge import Edge
from .validation import get_graph_validator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -201,6 +202,17 @@ class Graph:
return GraphBuilder(graph_cls=cls) return GraphBuilder(graph_cls=cls)
@classmethod
def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None:
    """
    Switch every node whose error strategy is FAIL_BRANCH to the BRANCH execution type.

    Nodes are updated in place; all other nodes are left untouched.

    :param nodes: mapping of node ID to node instance
    """
    for candidate in nodes.values():
        if candidate.error_strategy != ErrorStrategy.FAIL_BRANCH:
            continue
        candidate.execution_type = NodeExecutionType.BRANCH
@classmethod @classmethod
def _mark_inactive_root_branches( def _mark_inactive_root_branches(
cls, cls,
@ -307,6 +319,9 @@ class Graph:
# Create node instances # Create node instances
nodes = cls._create_node_instances(node_configs_map, node_factory) nodes = cls._create_node_instances(node_configs_map, node_factory)
# Promote fail-branch nodes to branch execution type at graph level
cls._promote_fail_branch_nodes(nodes)
# Get root node instance # Get root node instance
root_node = nodes[root_node_id] root_node = nodes[root_node_id]
@ -314,7 +329,7 @@ class Graph:
cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id) cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)
# Create and return the graph # Create and return the graph
return cls( graph = cls(
nodes=nodes, nodes=nodes,
edges=edges, edges=edges,
in_edges=in_edges, in_edges=in_edges,
@ -322,6 +337,11 @@ class Graph:
root_node=root_node, root_node=root_node,
) )
# Validate the graph structure using built-in validators
get_graph_validator().validate(graph)
return graph
@property @property
def node_ids(self) -> list[str]: def node_ids(self) -> list[str]:
""" """

View File

@ -0,0 +1,125 @@
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Protocol
from core.workflow.enums import NodeExecutionType, NodeType
if TYPE_CHECKING:
from .graph import Graph
@dataclass(frozen=True, slots=True)
class GraphValidationIssue:
"""Immutable value object describing a single validation issue."""
code: str
message: str
node_id: str | None = None
class GraphValidationError(ValueError):
    """
    Error carrying every issue found while validating a graph.

    The issues are kept on ``issues`` as a tuple, and the exception message joins
    them as ``[code] message`` entries separated by ``"; "``.
    """

    def __init__(self, issues: Sequence[GraphValidationIssue]) -> None:
        # An error without issues would be meaningless, so reject it up front.
        if not issues:
            raise ValueError("GraphValidationError requires at least one issue.")

        self.issues: tuple[GraphValidationIssue, ...] = tuple(issues)
        rendered = [f"[{issue.code}] {issue.message}" for issue in self.issues]
        super().__init__("; ".join(rendered))
class GraphValidationRule(Protocol):
    """Structural interface implemented by each graph validation rule."""

    def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
        """Inspect ``graph`` and return the issues this rule discovered."""
@dataclass(frozen=True, slots=True)
class _EdgeEndpointValidator:
    """Rule that reports every edge whose tail or head is not a registered node."""

    # Issue code attached to every problem this rule reports.
    missing_node_code: str = "MISSING_NODE"

    def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
        """Return one issue per dangling edge endpoint, tail before head for each edge."""
        known_nodes = graph.nodes
        problems: list[GraphValidationIssue] = []

        def report(text: str, endpoint_id: str) -> None:
            problems.append(
                GraphValidationIssue(
                    code=self.missing_node_code,
                    message=text,
                    node_id=endpoint_id,
                )
            )

        for edge in graph.edges.values():
            if edge.tail not in known_nodes:
                report(f"Edge {edge.id} references unknown source node '{edge.tail}'.", edge.tail)
            if edge.head not in known_nodes:
                report(f"Edge {edge.id} references unknown target node '{edge.head}'.", edge.head)
        return problems
@dataclass(frozen=True, slots=True)
class _RootNodeValidator:
    """Rule that checks the graph's root node is registered and is a valid entry point."""

    # Issue code attached to every problem this rule reports.
    invalid_root_code: str = "INVALID_ROOT"
    # Node types accepted as a root even when their execution type is not ROOT.
    container_entry_types: tuple[NodeType, ...] = (NodeType.ITERATION_START, NodeType.LOOP_START)

    def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
        """Return at most one issue describing why the root node is invalid."""
        root_node = graph.root_node

        # A root that is absent from the registry is reported on its own; no further checks run.
        if root_node.id not in graph.nodes:
            return [
                GraphValidationIssue(
                    code=self.invalid_root_code,
                    message=f"Root node '{root_node.id}' is missing from the node registry.",
                    node_id=root_node.id,
                )
            ]

        declared_type = getattr(root_node, "node_type", None)
        if root_node.execution_type == NodeExecutionType.ROOT or declared_type in self.container_entry_types:
            return []

        return [
            GraphValidationIssue(
                code=self.invalid_root_code,
                message=f"Root node '{root_node.id}' must declare execution type 'root'.",
                node_id=root_node.id,
            )
        ]
@dataclass(frozen=True, slots=True)
class GraphValidator:
"""Coordinates execution of graph validation rules."""
rules: tuple[GraphValidationRule, ...]
def validate(self, graph: Graph) -> None:
"""Validate the graph against all configured rules."""
issues: list[GraphValidationIssue] = []
for rule in self.rules:
issues.extend(rule.validate(graph))
if issues:
raise GraphValidationError(issues)
# Built-in validation rules handed to GraphValidator by get_graph_validator();
# the validator applies them in the order listed here.
_DEFAULT_RULES: tuple[GraphValidationRule, ...] = (
_EdgeEndpointValidator(),
_RootNodeValidator(),
)
def get_graph_validator() -> GraphValidator:
    """Return a new ``GraphValidator`` configured with the module's default rules."""
    validator = GraphValidator(rules=_DEFAULT_RULES)
    return validator

View File

@ -26,8 +26,8 @@ class AgentNodeData(BaseNodeData):
class ParamsAutoGenerated(IntEnum): class ParamsAutoGenerated(IntEnum):
CLOSE = auto() CLOSE = 0
OPEN = auto() OPEN = 1
class AgentOldVersionModelFeatures(StrEnum): class AgentOldVersionModelFeatures(StrEnum):

View File

@ -1,4 +1,5 @@
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
from .usage_tracking_mixin import LLMUsageTrackingMixin
__all__ = [ __all__ = [
"BaseIterationNodeData", "BaseIterationNodeData",
@ -6,4 +7,5 @@ __all__ = [
"BaseLoopNodeData", "BaseLoopNodeData",
"BaseLoopState", "BaseLoopState",
"BaseNodeData", "BaseNodeData",
"LLMUsageTrackingMixin",
] ]

View File

@ -1,5 +1,6 @@
import json import json
from abc import ABC from abc import ABC
from builtins import type as type_
from collections.abc import Sequence from collections.abc import Sequence
from enum import StrEnum from enum import StrEnum
from typing import Any, Union from typing import Any, Union
@ -58,10 +59,9 @@ class DefaultValue(BaseModel):
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}") raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
@staticmethod @staticmethod
def _validate_array(value: Any, element_type: DefaultValueType) -> bool: def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
"""Unified array type validation""" """Unified array type validation"""
# FIXME, type ignore here for do not find the reason mypy complain, if find the root cause, please fix it return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
return isinstance(value, list) and all(isinstance(x, element_type) for x in value) # type: ignore
@staticmethod @staticmethod
def _convert_number(value: str) -> float: def _convert_number(value: str) -> float:

View File

@ -0,0 +1,28 @@
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.runtime import GraphRuntimeState
class LLMUsageTrackingMixin:
    """Helpers for workflow nodes that combine LLM usage and report it to the graph runtime state."""

    # Provided by the host node class; ``_accumulate_usage`` reads and writes its ``llm_usage``.
    graph_runtime_state: GraphRuntimeState

    @staticmethod
    def _merge_usage(current: LLMUsage, new_usage: LLMUsage | None) -> LLMUsage:
        """
        Combine ``new_usage`` into ``current`` and return the result.

        A missing ``new_usage``, or one with no positive token count, leaves
        ``current`` unchanged. When ``current`` has zero tokens, ``new_usage``
        itself is returned (not a copy); otherwise the two are summed via ``plus``.
        """
        if new_usage is None:
            return current
        if new_usage.total_tokens <= 0:
            return current
        return new_usage if current.total_tokens == 0 else current.plus(new_usage)

    def _accumulate_usage(self, usage: LLMUsage) -> None:
        """
        Add ``usage`` to ``graph_runtime_state.llm_usage``.

        Usage without a positive token count is ignored. If the stored usage has
        zero tokens it is replaced by a copy of ``usage``; otherwise the two are
        summed via ``plus``.
        """
        if usage.total_tokens <= 0:
            return

        runtime_state = self.graph_runtime_state
        recorded = runtime_state.llm_usage
        runtime_state.llm_usage = usage.model_copy() if recorded.total_tokens == 0 else recorded.plus(usage)

View File

@ -10,10 +10,10 @@ from typing import Any
import chardet import chardet
import docx import docx
import pandas as pd import pandas as pd
import pypandoc # type: ignore import pypandoc
import pypdfium2 # type: ignore import pypdfium2
import webvtt # type: ignore import webvtt
import yaml # type: ignore import yaml
from docx.document import Document from docx.document import Document
from docx.oxml.table import CT_Tbl from docx.oxml.table import CT_Tbl
from docx.oxml.text.paragraph import CT_P from docx.oxml.text.paragraph import CT_P

View File

@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, NewType, cast
from flask import Flask, current_app from flask import Flask, current_app
from typing_extensions import TypeIs from typing_extensions import TypeIs
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import IntegerVariable, NoneSegment from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.variables import VariableUnion from core.variables.variables import VariableUnion
@ -34,6 +35,7 @@ from core.workflow.node_events import (
NodeRunResult, NodeRunResult,
StreamCompletedEvent, StreamCompletedEvent,
) )
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
@ -58,7 +60,7 @@ logger = logging.getLogger(__name__)
EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment) EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)
class IterationNode(Node): class IterationNode(LLMUsageTrackingMixin, Node):
""" """
Iteration Node. Iteration Node.
""" """
@ -118,6 +120,7 @@ class IterationNode(Node):
started_at = naive_utc_now() started_at = naive_utc_now()
iter_run_map: dict[str, float] = {} iter_run_map: dict[str, float] = {}
outputs: list[object] = [] outputs: list[object] = []
usage_accumulator = [LLMUsage.empty_usage()]
yield IterationStartedEvent( yield IterationStartedEvent(
start_at=started_at, start_at=started_at,
@ -130,22 +133,27 @@ class IterationNode(Node):
iterator_list_value=iterator_list_value, iterator_list_value=iterator_list_value,
outputs=outputs, outputs=outputs,
iter_run_map=iter_run_map, iter_run_map=iter_run_map,
usage_accumulator=usage_accumulator,
) )
self._accumulate_usage(usage_accumulator[0])
yield from self._handle_iteration_success( yield from self._handle_iteration_success(
started_at=started_at, started_at=started_at,
inputs=inputs, inputs=inputs,
outputs=outputs, outputs=outputs,
iterator_list_value=iterator_list_value, iterator_list_value=iterator_list_value,
iter_run_map=iter_run_map, iter_run_map=iter_run_map,
usage=usage_accumulator[0],
) )
except IterationNodeError as e: except IterationNodeError as e:
self._accumulate_usage(usage_accumulator[0])
yield from self._handle_iteration_failure( yield from self._handle_iteration_failure(
started_at=started_at, started_at=started_at,
inputs=inputs, inputs=inputs,
outputs=outputs, outputs=outputs,
iterator_list_value=iterator_list_value, iterator_list_value=iterator_list_value,
iter_run_map=iter_run_map, iter_run_map=iter_run_map,
usage=usage_accumulator[0],
error=e, error=e,
) )
@ -196,6 +204,7 @@ class IterationNode(Node):
iterator_list_value: Sequence[object], iterator_list_value: Sequence[object],
outputs: list[object], outputs: list[object],
iter_run_map: dict[str, float], iter_run_map: dict[str, float],
usage_accumulator: list[LLMUsage],
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
if self._node_data.is_parallel: if self._node_data.is_parallel:
# Parallel mode execution # Parallel mode execution
@ -203,6 +212,7 @@ class IterationNode(Node):
iterator_list_value=iterator_list_value, iterator_list_value=iterator_list_value,
outputs=outputs, outputs=outputs,
iter_run_map=iter_run_map, iter_run_map=iter_run_map,
usage_accumulator=usage_accumulator,
) )
else: else:
# Sequential mode execution # Sequential mode execution
@ -228,6 +238,9 @@ class IterationNode(Node):
# Update the total tokens from this iteration # Update the total tokens from this iteration
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
usage_accumulator[0] = self._merge_usage(
usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage
)
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
def _execute_parallel_iterations( def _execute_parallel_iterations(
@ -235,6 +248,7 @@ class IterationNode(Node):
iterator_list_value: Sequence[object], iterator_list_value: Sequence[object],
outputs: list[object], outputs: list[object],
iter_run_map: dict[str, float], iter_run_map: dict[str, float],
usage_accumulator: list[LLMUsage],
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
# Initialize outputs list with None values to maintain order # Initialize outputs list with None values to maintain order
outputs.extend([None] * len(iterator_list_value)) outputs.extend([None] * len(iterator_list_value))
@ -245,7 +259,16 @@ class IterationNode(Node):
with ThreadPoolExecutor(max_workers=max_workers) as executor: with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all iteration tasks # Submit all iteration tasks
future_to_index: dict[ future_to_index: dict[
Future[tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]], Future[
tuple[
datetime,
list[GraphNodeEventBase],
object | None,
int,
dict[str, VariableUnion],
LLMUsage,
]
],
int, int,
] = {} ] = {}
for index, item in enumerate(iterator_list_value): for index, item in enumerate(iterator_list_value):
@ -264,7 +287,14 @@ class IterationNode(Node):
index = future_to_index[future] index = future_to_index[future]
try: try:
result = future.result() result = future.result()
iter_start_at, events, output_value, tokens_used, conversation_snapshot = result (
iter_start_at,
events,
output_value,
tokens_used,
conversation_snapshot,
iteration_usage,
) = result
# Update outputs at the correct index # Update outputs at the correct index
outputs[index] = output_value outputs[index] = output_value
@ -276,6 +306,8 @@ class IterationNode(Node):
self.graph_runtime_state.total_tokens += tokens_used self.graph_runtime_state.total_tokens += tokens_used
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage)
# Sync conversation variables after iteration completion # Sync conversation variables after iteration completion
self._sync_conversation_variables_from_snapshot(conversation_snapshot) self._sync_conversation_variables_from_snapshot(conversation_snapshot)
@ -303,7 +335,7 @@ class IterationNode(Node):
item: object, item: object,
flask_app: Flask, flask_app: Flask,
context_vars: contextvars.Context, context_vars: contextvars.Context,
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]: ) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion], LLMUsage]:
"""Execute a single iteration in parallel mode and return results.""" """Execute a single iteration in parallel mode and return results."""
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars): with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
iter_start_at = datetime.now(UTC).replace(tzinfo=None) iter_start_at = datetime.now(UTC).replace(tzinfo=None)
@ -332,6 +364,7 @@ class IterationNode(Node):
output_value, output_value,
graph_engine.graph_runtime_state.total_tokens, graph_engine.graph_runtime_state.total_tokens,
conversation_snapshot, conversation_snapshot,
graph_engine.graph_runtime_state.llm_usage,
) )
def _handle_iteration_success( def _handle_iteration_success(
@ -341,6 +374,8 @@ class IterationNode(Node):
outputs: list[object], outputs: list[object],
iterator_list_value: Sequence[object], iterator_list_value: Sequence[object],
iter_run_map: dict[str, float], iter_run_map: dict[str, float],
*,
usage: LLMUsage,
) -> Generator[NodeEventBase, None, None]: ) -> Generator[NodeEventBase, None, None]:
# Flatten the list of lists if all outputs are lists # Flatten the list of lists if all outputs are lists
flattened_outputs = self._flatten_outputs_if_needed(outputs) flattened_outputs = self._flatten_outputs_if_needed(outputs)
@ -351,7 +386,9 @@ class IterationNode(Node):
outputs={"output": flattened_outputs}, outputs={"output": flattened_outputs},
steps=len(iterator_list_value), steps=len(iterator_list_value),
metadata={ metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
}, },
) )
@ -362,8 +399,11 @@ class IterationNode(Node):
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": flattened_outputs}, outputs={"output": flattened_outputs},
metadata={ metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
}, },
llm_usage=usage,
) )
) )
@ -400,6 +440,8 @@ class IterationNode(Node):
outputs: list[object], outputs: list[object],
iterator_list_value: Sequence[object], iterator_list_value: Sequence[object],
iter_run_map: dict[str, float], iter_run_map: dict[str, float],
*,
usage: LLMUsage,
error: IterationNodeError, error: IterationNodeError,
) -> Generator[NodeEventBase, None, None]: ) -> Generator[NodeEventBase, None, None]:
# Flatten the list of lists if all outputs are lists (even in failure case) # Flatten the list of lists if all outputs are lists (even in failure case)
@ -411,7 +453,9 @@ class IterationNode(Node):
outputs={"output": flattened_outputs}, outputs={"output": flattened_outputs},
steps=len(iterator_list_value), steps=len(iterator_list_value),
metadata={ metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
}, },
error=str(error), error=str(error),
@ -420,6 +464,12 @@ class IterationNode(Node):
node_run_result=NodeRunResult( node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
error=str(error), error=str(error),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage,
) )
) )

View File

@ -15,14 +15,11 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.entities.agent_entities import PlanningStrategy from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.llm_entities import LLMUsage
PromptMessageRole, from core.model_runtime.entities.message_entities import PromptMessageRole
) from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.entities.model_entities import (
ModelFeature,
ModelType,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.simple_prompt_transform import ModelMode from core.prompt.simple_prompt_transform import ModelMode
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.metadata_entities import Condition, MetadataCondition from core.rag.entities.metadata_entities import Condition, MetadataCondition
@ -33,8 +30,14 @@ from core.variables import (
) )
from core.variables.segments import ArrayObjectSegment from core.variables.segments import ArrayObjectSegment
from core.workflow.entities import GraphInitParams from core.workflow.entities import GraphInitParams
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus from core.workflow.enums import (
ErrorStrategy,
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.node import Node
from core.workflow.nodes.knowledge_retrieval.template_prompts import ( from core.workflow.nodes.knowledge_retrieval.template_prompts import (
@ -80,7 +83,7 @@ default_retrieval_model = {
} }
class KnowledgeRetrievalNode(Node): class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node):
node_type = NodeType.KNOWLEDGE_RETRIEVAL node_type = NodeType.KNOWLEDGE_RETRIEVAL
_node_data: KnowledgeRetrievalNodeData _node_data: KnowledgeRetrievalNodeData
@ -141,7 +144,7 @@ class KnowledgeRetrievalNode(Node):
def version(cls): def version(cls):
return "1" return "1"
def _run(self) -> NodeRunResult: # type: ignore def _run(self) -> NodeRunResult:
# extract variables # extract variables
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector) variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
if not isinstance(variable, StringSegment): if not isinstance(variable, StringSegment):
@ -182,14 +185,21 @@ class KnowledgeRetrievalNode(Node):
) )
# retrieve knowledge # retrieve knowledge
usage = LLMUsage.empty_usage()
try: try:
results = self._fetch_dataset_retriever(node_data=self._node_data, query=query) results, usage = self._fetch_dataset_retriever(node_data=self._node_data, query=query)
outputs = {"result": ArrayObjectSegment(value=results)} outputs = {"result": ArrayObjectSegment(value=results)}
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables, inputs=variables,
process_data={}, process_data={"usage": jsonable_encoder(usage)},
outputs=outputs, # type: ignore outputs=outputs, # type: ignore
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage,
) )
except KnowledgeRetrievalNodeError as e: except KnowledgeRetrievalNodeError as e:
@ -199,6 +209,7 @@ class KnowledgeRetrievalNode(Node):
inputs=variables, inputs=variables,
error=str(e), error=str(e),
error_type=type(e).__name__, error_type=type(e).__name__,
llm_usage=usage,
) )
# Temporary handle all exceptions from DatasetRetrieval class here. # Temporary handle all exceptions from DatasetRetrieval class here.
except Exception as e: except Exception as e:
@ -207,11 +218,15 @@ class KnowledgeRetrievalNode(Node):
inputs=variables, inputs=variables,
error=str(e), error=str(e),
error_type=type(e).__name__, error_type=type(e).__name__,
llm_usage=usage,
) )
finally: finally:
db.session.close() db.session.close()
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]: def _fetch_dataset_retriever(
self, node_data: KnowledgeRetrievalNodeData, query: str
) -> tuple[list[dict[str, Any]], LLMUsage]:
usage = LLMUsage.empty_usage()
available_datasets = [] available_datasets = []
dataset_ids = node_data.dataset_ids dataset_ids = node_data.dataset_ids
@ -245,9 +260,10 @@ class KnowledgeRetrievalNode(Node):
if not dataset: if not dataset:
continue continue
available_datasets.append(dataset) available_datasets.append(dataset)
metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition( metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition(
[dataset.id for dataset in available_datasets], query, node_data [dataset.id for dataset in available_datasets], query, node_data
) )
usage = self._merge_usage(usage, metadata_usage)
all_documents = [] all_documents = []
dataset_retrieval = DatasetRetrieval() dataset_retrieval = DatasetRetrieval()
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
@ -330,6 +346,8 @@ class KnowledgeRetrievalNode(Node):
metadata_filter_document_ids=metadata_filter_document_ids, metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition, metadata_condition=metadata_condition,
) )
usage = self._merge_usage(usage, dataset_retrieval.llm_usage)
dify_documents = [item for item in all_documents if item.provider == "dify"] dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"] external_documents = [item for item in all_documents if item.provider == "external"]
retrieval_resource_list = [] retrieval_resource_list = []
@ -406,11 +424,12 @@ class KnowledgeRetrievalNode(Node):
) )
for position, item in enumerate(retrieval_resource_list, start=1): for position, item in enumerate(retrieval_resource_list, start=1):
item["metadata"]["position"] = position item["metadata"]["position"] = position
return retrieval_resource_list return retrieval_resource_list, usage
def _get_metadata_filter_condition( def _get_metadata_filter_condition(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]: ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]:
usage = LLMUsage.empty_usage()
document_query = db.session.query(Document).where( document_query = db.session.query(Document).where(
Document.dataset_id.in_(dataset_ids), Document.dataset_id.in_(dataset_ids),
Document.indexing_status == "completed", Document.indexing_status == "completed",
@ -420,9 +439,12 @@ class KnowledgeRetrievalNode(Node):
filters: list[Any] = [] filters: list[Any] = []
metadata_condition = None metadata_condition = None
if node_data.metadata_filtering_mode == "disabled": if node_data.metadata_filtering_mode == "disabled":
return None, None return None, None, usage
elif node_data.metadata_filtering_mode == "automatic": elif node_data.metadata_filtering_mode == "automatic":
automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data) automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
dataset_ids, query, node_data
)
usage = self._merge_usage(usage, automatic_usage)
if automatic_metadata_filters: if automatic_metadata_filters:
conditions = [] conditions = []
for sequence, filter in enumerate(automatic_metadata_filters): for sequence, filter in enumerate(automatic_metadata_filters):
@ -443,7 +465,7 @@ class KnowledgeRetrievalNode(Node):
metadata_condition = MetadataCondition( metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator logical_operator=node_data.metadata_filtering_conditions.logical_operator
if node_data.metadata_filtering_conditions if node_data.metadata_filtering_conditions
else "or", # type: ignore else "or",
conditions=conditions, conditions=conditions,
) )
elif node_data.metadata_filtering_mode == "manual": elif node_data.metadata_filtering_mode == "manual":
@ -457,10 +479,10 @@ class KnowledgeRetrievalNode(Node):
expected_value = self.graph_runtime_state.variable_pool.convert_template( expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value expected_value
).value[0] ).value[0]
if expected_value.value_type in {"number", "integer", "float"}: # type: ignore if expected_value.value_type in {"number", "integer", "float"}:
expected_value = expected_value.value # type: ignore expected_value = expected_value.value
elif expected_value.value_type == "string": # type: ignore elif expected_value.value_type == "string":
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
else: else:
raise ValueError("Invalid expected metadata value type") raise ValueError("Invalid expected metadata value type")
conditions.append( conditions.append(
@ -487,7 +509,7 @@ class KnowledgeRetrievalNode(Node):
if ( if (
node_data.metadata_filtering_conditions node_data.metadata_filtering_conditions
and node_data.metadata_filtering_conditions.logical_operator == "and" and node_data.metadata_filtering_conditions.logical_operator == "and"
): # type: ignore ):
document_query = document_query.where(and_(*filters)) document_query = document_query.where(and_(*filters))
else: else:
document_query = document_query.where(or_(*filters)) document_query = document_query.where(or_(*filters))
@ -496,11 +518,12 @@ class KnowledgeRetrievalNode(Node):
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
for document in documents: for document in documents:
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
return metadata_filter_document_ids, metadata_condition return metadata_filter_document_ids, metadata_condition, usage
def _automatic_metadata_filter_func( def _automatic_metadata_filter_func(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> list[dict[str, Any]]: ) -> tuple[list[dict[str, Any]], LLMUsage]:
usage = LLMUsage.empty_usage()
# get all metadata field # get all metadata field
stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)) stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
metadata_fields = db.session.scalars(stmt).all() metadata_fields = db.session.scalars(stmt).all()
@ -548,6 +571,7 @@ class KnowledgeRetrievalNode(Node):
for event in generator: for event in generator:
if isinstance(event, ModelInvokeCompletedEvent): if isinstance(event, ModelInvokeCompletedEvent):
result_text = event.text result_text = event.text
usage = self._merge_usage(usage, event.usage)
break break
result_text_json = parse_and_check_json_markdown(result_text, []) result_text_json = parse_and_check_json_markdown(result_text, [])
@ -564,8 +588,8 @@ class KnowledgeRetrievalNode(Node):
} }
) )
except Exception: except Exception:
return [] return [], usage
return automatic_metadata_filters return automatic_metadata_filters, usage
def _process_metadata_filter_func( def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any] self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]

View File

@ -441,10 +441,14 @@ class LLMNode(Node):
usage = LLMUsage.empty_usage() usage = LLMUsage.empty_usage()
finish_reason = None finish_reason = None
full_text_buffer = io.StringIO() full_text_buffer = io.StringIO()
collected_structured_output = None # Collect structured_output from streaming chunks
# Consume the invoke result and handle generator exception # Consume the invoke result and handle generator exception
try: try:
for result in invoke_result: for result in invoke_result:
if isinstance(result, LLMResultChunkWithStructuredOutput): if isinstance(result, LLMResultChunkWithStructuredOutput):
# Collect structured_output from the chunk
if result.structured_output is not None:
collected_structured_output = dict(result.structured_output)
yield result yield result
if isinstance(result, LLMResultChunk): if isinstance(result, LLMResultChunk):
contents = result.delta.message.content contents = result.delta.message.content
@ -492,6 +496,8 @@ class LLMNode(Node):
finish_reason=finish_reason, finish_reason=finish_reason,
# Reasoning content for workflow variables and downstream nodes # Reasoning content for workflow variables and downstream nodes
reasoning_content=reasoning_content, reasoning_content=reasoning_content,
# Pass structured output if collected from streaming chunks
structured_output=collected_structured_output,
) )
@staticmethod @staticmethod

View File

@ -5,6 +5,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, cast from typing import TYPE_CHECKING, Any, Literal, cast
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import Segment, SegmentType from core.variables import Segment, SegmentType
from core.workflow.enums import ( from core.workflow.enums import (
ErrorStrategy, ErrorStrategy,
@ -27,6 +28,7 @@ from core.workflow.node_events import (
NodeRunResult, NodeRunResult,
StreamCompletedEvent, StreamCompletedEvent,
) )
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
@ -40,7 +42,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LoopNode(Node): class LoopNode(LLMUsageTrackingMixin, Node):
""" """
Loop Node. Loop Node.
""" """
@ -108,7 +110,7 @@ class LoopNode(Node):
raise ValueError(f"Invalid value for loop variable {loop_variable.label}") raise ValueError(f"Invalid value for loop variable {loop_variable.label}")
variable_selector = [self._node_id, loop_variable.label] variable_selector = [self._node_id, loop_variable.label]
variable = segment_to_variable(segment=processed_segment, selector=variable_selector) variable = segment_to_variable(segment=processed_segment, selector=variable_selector)
self.graph_runtime_state.variable_pool.add(variable_selector, variable) self.graph_runtime_state.variable_pool.add(variable_selector, variable.value)
loop_variable_selectors[loop_variable.label] = variable_selector loop_variable_selectors[loop_variable.label] = variable_selector
inputs[loop_variable.label] = processed_segment.value inputs[loop_variable.label] = processed_segment.value
@ -117,6 +119,7 @@ class LoopNode(Node):
loop_duration_map: dict[str, float] = {} loop_duration_map: dict[str, float] = {}
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
loop_usage = LLMUsage.empty_usage()
# Start Loop event # Start Loop event
yield LoopStartedEvent( yield LoopStartedEvent(
@ -163,6 +166,9 @@ class LoopNode(Node):
# Update the total tokens from this iteration # Update the total tokens from this iteration
cost_tokens += graph_engine.graph_runtime_state.total_tokens cost_tokens += graph_engine.graph_runtime_state.total_tokens
# Accumulate usage from the sub-graph execution
loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage)
# Collect loop variable values after iteration # Collect loop variable values after iteration
single_loop_variable = {} single_loop_variable = {}
for key, selector in loop_variable_selectors.items(): for key, selector in loop_variable_selectors.items():
@ -189,6 +195,7 @@ class LoopNode(Node):
) )
self.graph_runtime_state.total_tokens += cost_tokens self.graph_runtime_state.total_tokens += cost_tokens
self._accumulate_usage(loop_usage)
# Loop completed successfully # Loop completed successfully
yield LoopSucceededEvent( yield LoopSucceededEvent(
start_at=start_at, start_at=start_at,
@ -196,7 +203,9 @@ class LoopNode(Node):
outputs=self._node_data.outputs, outputs=self._node_data.outputs,
steps=loop_count, steps=loop_count,
metadata={ metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: cost_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
"completed_reason": "loop_break" if reach_break_condition else "loop_completed", "completed_reason": "loop_break" if reach_break_condition else "loop_completed",
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
@ -207,22 +216,28 @@ class LoopNode(Node):
node_run_result=NodeRunResult( node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
metadata={ metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
}, },
outputs=self._node_data.outputs, outputs=self._node_data.outputs,
inputs=inputs, inputs=inputs,
llm_usage=loop_usage,
) )
) )
except Exception as e: except Exception as e:
self._accumulate_usage(loop_usage)
yield LoopFailedEvent( yield LoopFailedEvent(
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
steps=loop_count, steps=loop_count,
metadata={ metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
"completed_reason": "error", "completed_reason": "error",
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
@ -235,10 +250,13 @@ class LoopNode(Node):
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
error=str(e), error=str(e),
metadata={ metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
}, },
llm_usage=loop_usage,
) )
) )

View File

@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, final
from typing_extensions import override from typing_extensions import override
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType from core.workflow.enums import NodeType
from core.workflow.graph import NodeFactory from core.workflow.graph import NodeFactory
from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.node import Node
from libs.typing import is_str, is_str_dict from libs.typing import is_str, is_str_dict
@ -82,8 +82,4 @@ class DifyNodeFactory(NodeFactory):
raise ValueError(f"Node {node_id} missing data information") raise ValueError(f"Node {node_id} missing data information")
node_instance.init_node_data(node_data) node_instance.init_node_data(node_data)
# If node has fail branch, change execution type to branch
if node_instance.error_strategy == ErrorStrategy.FAIL_BRANCH:
node_instance.execution_type = NodeExecutionType.BRANCH
return node_instance return node_instance

View File

@ -747,7 +747,7 @@ class ParameterExtractorNode(Node):
if model_mode == ModelMode.CHAT: if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage( system_prompt_messages = ChatModelMessage(
role=PromptMessageRole.SYSTEM, role=PromptMessageRole.SYSTEM,
text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str).replace("{{instructions}}", instruction), text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction),
) )
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
return [system_prompt_messages, user_prompt_message] return [system_prompt_messages, user_prompt_message]

View File

@ -135,7 +135,7 @@ Here are the chat histories between human and assistant, inside <histories></his
### Instructions: ### Instructions:
Some extra information are provided below, you should always follow the instructions as possible as you can. Some extra information are provided below, you should always follow the instructions as possible as you can.
<instructions> <instructions>
{{instructions}} {instructions}
</instructions> </instructions>
""" """

View File

@ -6,10 +6,13 @@ from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file import File, FileTransferMethod from core.file import File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.tools.workflow_as_tool.tool import WorkflowTool
from core.variables.segments import ArrayAnySegment, ArrayFileSegment from core.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.variables.variables import ArrayAnyVariable from core.variables.variables import ArrayAnyVariable
from core.workflow.enums import ( from core.workflow.enums import (
@ -136,13 +139,14 @@ class ToolNode(Node):
try: try:
# convert tool messages # convert tool messages
yield from self._transform_message( _ = yield from self._transform_message(
messages=message_stream, messages=message_stream,
tool_info=tool_info, tool_info=tool_info,
parameters_for_log=parameters_for_log, parameters_for_log=parameters_for_log,
user_id=self.user_id, user_id=self.user_id,
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
node_id=self._node_id, node_id=self._node_id,
tool_runtime=tool_runtime,
) )
except ToolInvokeError as e: except ToolInvokeError as e:
yield StreamCompletedEvent( yield StreamCompletedEvent(
@ -236,7 +240,8 @@ class ToolNode(Node):
user_id: str, user_id: str,
tenant_id: str, tenant_id: str,
node_id: str, node_id: str,
) -> Generator: tool_runtime: Tool,
) -> Generator[NodeEventBase, None, LLMUsage]:
""" """
Convert ToolInvokeMessages into tuple[plain_text, files] Convert ToolInvokeMessages into tuple[plain_text, files]
""" """
@ -424,17 +429,34 @@ class ToolNode(Node):
is_final=True, is_final=True,
) )
usage = self._extract_tool_usage(tool_runtime)
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
}
if usage.total_tokens > 0:
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency
yield StreamCompletedEvent( yield StreamCompletedEvent(
node_run_result=NodeRunResult( node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables}, outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
metadata={ metadata=metadata,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
},
inputs=parameters_for_log, inputs=parameters_for_log,
llm_usage=usage,
) )
) )
return usage
@staticmethod
def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage:
if isinstance(tool_runtime, WorkflowTool):
return tool_runtime.latest_usage
return LLMUsage.empty_usage()
@classmethod @classmethod
def _extract_variable_selector_to_variable_mapping( def _extract_variable_selector_to_variable_mapping(
cls, cls,

View File

@ -260,7 +260,7 @@ class VariablePool(BaseModel):
# This ensures that we can keep the id of the system variables intact. # This ensures that we can keep the id of the system variables intact.
if self._has(selector): if self._has(selector):
continue continue
self.add(selector, value) # type: ignore self.add(selector, value)
@classmethod @classmethod
def empty(cls) -> "VariablePool": def empty(cls) -> "VariablePool":

View File

@ -32,7 +32,8 @@ if [[ "${MODE}" == "worker" ]]; then
exec celery -A celery_entrypoint.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \ exec celery -A celery_entrypoint.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \
--max-tasks-per-child ${MAX_TASKS_PER_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \ --max-tasks-per-child ${MAX_TASKS_PER_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \
-Q ${CELERY_QUEUES:-dataset,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation} -Q ${CELERY_QUEUES:-dataset,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation} \
--prefetch-multiplier=1
elif [[ "${MODE}" == "beat" ]]; then elif [[ "${MODE}" == "beat" ]]; then
exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO} exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}

View File

@ -6,8 +6,8 @@ from tasks.clean_dataset_task import clean_dataset_task
@dataset_was_deleted.connect @dataset_was_deleted.connect
def handle(sender: Dataset, **kwargs): def handle(sender: Dataset, **kwargs):
dataset = sender dataset = sender
assert dataset.doc_form if not dataset.doc_form or not dataset.indexing_technique:
assert dataset.indexing_technique return
clean_dataset_task.delay( clean_dataset_task.delay(
dataset.id, dataset.id,
dataset.tenant_id, dataset.tenant_id,

View File

@ -8,6 +8,6 @@ def handle(sender, **kwargs):
dataset_id = kwargs.get("dataset_id") dataset_id = kwargs.get("dataset_id")
doc_form = kwargs.get("doc_form") doc_form = kwargs.get("doc_form")
file_id = kwargs.get("file_id") file_id = kwargs.get("file_id")
assert dataset_id is not None if not dataset_id or not doc_form:
assert doc_form is not None return
clean_document_task.delay(document_id, dataset_id, doc_form, file_id) clean_document_task.delay(document_id, dataset_id, doc_form, file_id)

View File

@ -1,7 +1,12 @@
from configs import dify_config from configs import dify_config
from constants import HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN from constants import HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN, HEADER_NAME_PASSPORT
from dify_app import DifyApp from dify_app import DifyApp
BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEADER_NAME_PASSPORT)
SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization")
AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN)
FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
def init_app(app: DifyApp): def init_app(app: DifyApp):
# register blueprint routers # register blueprint routers
@ -17,7 +22,7 @@ def init_app(app: DifyApp):
CORS( CORS(
service_api_bp, service_api_bp,
allow_headers=["Content-Type", "Authorization", HEADER_NAME_APP_CODE], allow_headers=list(SERVICE_API_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
) )
app.register_blueprint(service_api_bp) app.register_blueprint(service_api_bp)
@ -26,7 +31,7 @@ def init_app(app: DifyApp):
web_bp, web_bp,
resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}}, resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
supports_credentials=True, supports_credentials=True,
allow_headers=["Content-Type", "Authorization", HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN], allow_headers=list(AUTHENTICATED_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"], expose_headers=["X-Version", "X-Env"],
) )
@ -36,7 +41,7 @@ def init_app(app: DifyApp):
console_app_bp, console_app_bp,
resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}}, resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
supports_credentials=True, supports_credentials=True,
allow_headers=["Content-Type", "Authorization", HEADER_NAME_CSRF_TOKEN], allow_headers=list(AUTHENTICATED_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"], expose_headers=["X-Version", "X-Env"],
) )
@ -44,7 +49,7 @@ def init_app(app: DifyApp):
CORS( CORS(
files_bp, files_bp,
allow_headers=["Content-Type", HEADER_NAME_CSRF_TOKEN], allow_headers=list(FILES_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
) )
app.register_blueprint(files_bp) app.register_blueprint(files_bp)

View File

@ -7,7 +7,7 @@ def is_enabled() -> bool:
def init_app(app: DifyApp): def init_app(app: DifyApp):
from flask_compress import Compress # type: ignore from flask_compress import Compress
compress = Compress() compress = Compress()
compress.init_app(app) compress.init_app(app)

View File

@ -1,6 +1,6 @@
import json import json
import flask_login # type: ignore import flask_login
from flask import Response, request from flask import Response, request
from flask_login import user_loaded_from_request, user_logged_in from flask_login import user_loaded_from_request, user_logged_in
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import NotFound, Unauthorized

View File

@ -2,7 +2,7 @@ from dify_app import DifyApp
def init_app(app: DifyApp): def init_app(app: DifyApp):
import flask_migrate # type: ignore import flask_migrate
from extensions.ext_database import db from extensions.ext_database import db

View File

@ -103,7 +103,7 @@ def init_app(app: DifyApp):
def shutdown_tracer(): def shutdown_tracer():
provider = trace.get_tracer_provider() provider = trace.get_tracer_provider()
if hasattr(provider, "force_flush"): if hasattr(provider, "force_flush"):
provider.force_flush() # ty: ignore [call-non-callable] provider.force_flush()
class ExceptionLoggingHandler(logging.Handler): class ExceptionLoggingHandler(logging.Handler):
"""Custom logging handler that creates spans for logging.exception() calls""" """Custom logging handler that creates spans for logging.exception() calls"""

View File

@ -6,4 +6,4 @@ def init_app(app: DifyApp):
if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED: if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED:
from werkzeug.middleware.proxy_fix import ProxyFix from werkzeug.middleware.proxy_fix import ProxyFix
app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1) # type: ignore app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1) # type: ignore[method-assign]

View File

@ -5,7 +5,7 @@ from dify_app import DifyApp
def init_app(app: DifyApp): def init_app(app: DifyApp):
if dify_config.SENTRY_DSN: if dify_config.SENTRY_DSN:
import sentry_sdk import sentry_sdk
from langfuse import parse_error # type: ignore from langfuse import parse_error
from sentry_sdk.integrations.celery import CeleryIntegration from sentry_sdk.integrations.celery import CeleryIntegration
from sentry_sdk.integrations.flask import FlaskIntegration from sentry_sdk.integrations.flask import FlaskIntegration
from werkzeug.exceptions import HTTPException from werkzeug.exceptions import HTTPException

View File

@ -1,7 +1,7 @@
import posixpath import posixpath
from collections.abc import Generator from collections.abc import Generator
import oss2 as aliyun_s3 # type: ignore import oss2 as aliyun_s3
from configs import dify_config from configs import dify_config
from extensions.storage.base_storage import BaseStorage from extensions.storage.base_storage import BaseStorage

View File

@ -2,9 +2,9 @@ import base64
import hashlib import hashlib
from collections.abc import Generator from collections.abc import Generator
from baidubce.auth.bce_credentials import BceCredentials # type: ignore from baidubce.auth.bce_credentials import BceCredentials
from baidubce.bce_client_configuration import BceClientConfiguration # type: ignore from baidubce.bce_client_configuration import BceClientConfiguration
from baidubce.services.bos.bos_client import BosClient # type: ignore from baidubce.services.bos.bos_client import BosClient
from configs import dify_config from configs import dify_config
from extensions.storage.base_storage import BaseStorage from extensions.storage.base_storage import BaseStorage

View File

@ -11,7 +11,7 @@ from collections.abc import Generator
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
import clickzetta # type: ignore[import] import clickzetta
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
from extensions.storage.base_storage import BaseStorage from extensions.storage.base_storage import BaseStorage

View File

@ -34,7 +34,7 @@ class VolumePermissionManager:
# Support two initialization methods: connection object or configuration dictionary # Support two initialization methods: connection object or configuration dictionary
if isinstance(connection_or_config, dict): if isinstance(connection_or_config, dict):
# Create connection from configuration dictionary # Create connection from configuration dictionary
import clickzetta # type: ignore[import-untyped] import clickzetta
config = connection_or_config config = connection_or_config
self._connection = clickzetta.connect( self._connection = clickzetta.connect(

View File

@ -3,7 +3,7 @@ import io
import json import json
from collections.abc import Generator from collections.abc import Generator
from google.cloud import storage as google_cloud_storage # type: ignore from google.cloud import storage as google_cloud_storage
from configs import dify_config from configs import dify_config
from extensions.storage.base_storage import BaseStorage from extensions.storage.base_storage import BaseStorage

View File

@ -1,6 +1,6 @@
from collections.abc import Generator from collections.abc import Generator
from obs import ObsClient # type: ignore from obs import ObsClient
from configs import dify_config from configs import dify_config
from extensions.storage.base_storage import BaseStorage from extensions.storage.base_storage import BaseStorage

View File

@ -1,7 +1,7 @@
from collections.abc import Generator from collections.abc import Generator
import boto3 # type: ignore import boto3
from botocore.exceptions import ClientError # type: ignore from botocore.exceptions import ClientError
from configs import dify_config from configs import dify_config
from extensions.storage.base_storage import BaseStorage from extensions.storage.base_storage import BaseStorage

View File

@ -1,6 +1,6 @@
from collections.abc import Generator from collections.abc import Generator
from qcloud_cos import CosConfig, CosS3Client # type: ignore from qcloud_cos import CosConfig, CosS3Client
from configs import dify_config from configs import dify_config
from extensions.storage.base_storage import BaseStorage from extensions.storage.base_storage import BaseStorage

View File

@ -1,6 +1,6 @@
from collections.abc import Generator from collections.abc import Generator
import tos # type: ignore import tos
from configs import dify_config from configs import dify_config
from extensions.storage.base_storage import BaseStorage from extensions.storage.base_storage import BaseStorage

View File

@ -146,6 +146,6 @@ class ExternalApi(Api):
kwargs["doc"] = dify_config.SWAGGER_UI_PATH if dify_config.SWAGGER_UI_ENABLED else False kwargs["doc"] = dify_config.SWAGGER_UI_PATH if dify_config.SWAGGER_UI_ENABLED else False
# manual separate call on construction and init_app to ensure configs in kwargs effective # manual separate call on construction and init_app to ensure configs in kwargs effective
super().__init__(app=None, *args, **kwargs) # type: ignore super().__init__(app=None, *args, **kwargs)
self.init_app(app, **kwargs) self.init_app(app, **kwargs)
register_external_error_handlers(self) register_external_error_handlers(self)

View File

@ -23,7 +23,7 @@ from hashlib import sha1
import Crypto.Hash.SHA1 import Crypto.Hash.SHA1
import Crypto.Util.number import Crypto.Util.number
import gmpy2 # type: ignore import gmpy2
from Crypto import Random from Crypto import Random
from Crypto.Signature.pss import MGF1 from Crypto.Signature.pss import MGF1
from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes
@ -136,7 +136,7 @@ class PKCS1OAepCipher:
# Step 3a (OS2IP) # Step 3a (OS2IP)
em_int = bytes_to_long(em) em_int = bytes_to_long(em)
# Step 3b (RSAEP) # Step 3b (RSAEP)
m_int = gmpy2.powmod(em_int, self._key.e, self._key.n) # ty: ignore [unresolved-attribute] m_int = gmpy2.powmod(em_int, self._key.e, self._key.n)
# Step 3c (I2OSP) # Step 3c (I2OSP)
c = long_to_bytes(m_int, k) c = long_to_bytes(m_int, k)
return c return c
@ -169,7 +169,7 @@ class PKCS1OAepCipher:
ct_int = bytes_to_long(ciphertext) ct_int = bytes_to_long(ciphertext)
# Step 2b (RSADP) # Step 2b (RSADP)
# m_int = self._key._decrypt(ct_int) # m_int = self._key._decrypt(ct_int)
m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n) # ty: ignore [unresolved-attribute] m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n)
# Complete step 2c (I2OSP) # Complete step 2c (I2OSP)
em = long_to_bytes(m_int, k) em = long_to_bytes(m_int, k)
# Step 3a # Step 3a
@ -191,12 +191,12 @@ class PKCS1OAepCipher:
# Step 3g # Step 3g
one_pos = hLen + db[hLen:].find(b"\x01") one_pos = hLen + db[hLen:].find(b"\x01")
lHash1 = db[:hLen] lHash1 = db[:hLen]
invalid = bord(y) | int(one_pos < hLen) # type: ignore invalid = bord(y) | int(one_pos < hLen) # type: ignore[arg-type]
hash_compare = strxor(lHash1, lHash) hash_compare = strxor(lHash1, lHash)
for x in hash_compare: for x in hash_compare:
invalid |= bord(x) # type: ignore invalid |= bord(x) # type: ignore[arg-type]
for x in db[hLen:one_pos]: for x in db[hLen:one_pos]:
invalid |= bord(x) # type: ignore invalid |= bord(x) # type: ignore[arg-type]
if invalid != 0: if invalid != 0:
raise ValueError("Incorrect decryption.") raise ValueError("Incorrect decryption.")
# Step 4 # Step 4

View File

@ -81,6 +81,8 @@ class AvatarUrlField(fields.Raw):
from models import Account from models import Account
if isinstance(obj, Account) and obj.avatar is not None: if isinstance(obj, Account) and obj.avatar is not None:
if obj.avatar.startswith(("http://", "https://")):
return obj.avatar
return file_helpers.get_signed_file_url(obj.avatar) return file_helpers.get_signed_file_url(obj.avatar)
return None return None

View File

@ -3,7 +3,7 @@ from functools import wraps
from typing import Any from typing import Any
from flask import current_app, g, has_request_context, request from flask import current_app, g, has_request_context, request
from flask_login.config import EXEMPT_METHODS # type: ignore from flask_login.config import EXEMPT_METHODS
from werkzeug.local import LocalProxy from werkzeug.local import LocalProxy
from configs import dify_config from configs import dify_config
@ -87,7 +87,7 @@ def _get_user() -> EndUser | Account | None:
if "_login_user" not in g: if "_login_user" not in g:
current_app.login_manager._load_user() # type: ignore current_app.login_manager._load_user() # type: ignore
return g._login_user # type: ignore return g._login_user
return None return None

View File

@ -1,8 +1,8 @@
import logging import logging
import sendgrid # type: ignore import sendgrid
from python_http_client.exceptions import ForbiddenError, UnauthorizedError from python_http_client.exceptions import ForbiddenError, UnauthorizedError
from sendgrid.helpers.mail import Content, Email, Mail, To # type: ignore from sendgrid.helpers.mail import Content, Email, Mail, To
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -12,6 +12,7 @@ from constants import (
COOKIE_NAME_CSRF_TOKEN, COOKIE_NAME_CSRF_TOKEN,
COOKIE_NAME_PASSPORT, COOKIE_NAME_PASSPORT,
COOKIE_NAME_REFRESH_TOKEN, COOKIE_NAME_REFRESH_TOKEN,
COOKIE_NAME_WEBAPP_ACCESS_TOKEN,
HEADER_NAME_CSRF_TOKEN, HEADER_NAME_CSRF_TOKEN,
HEADER_NAME_PASSPORT, HEADER_NAME_PASSPORT,
) )
@ -81,6 +82,14 @@ def extract_access_token(request: Request) -> str | None:
return _try_extract_from_cookie(request) or _try_extract_from_header(request) return _try_extract_from_cookie(request) or _try_extract_from_header(request)
def extract_webapp_access_token(request: Request) -> str | None:
    """
    Return the webapp access token carried by ``request``, if any.

    The webapp access-token cookie is consulted first; when it is missing or
    empty, whatever ``_try_extract_from_header`` yields is returned instead.
    """
    cookie_token = request.cookies.get(_real_cookie_name(COOKIE_NAME_WEBAPP_ACCESS_TOKEN))
    if cookie_token:
        return cookie_token
    return _try_extract_from_header(request)
def extract_webapp_passport(app_code: str, request: Request) -> str | None: def extract_webapp_passport(app_code: str, request: Request) -> str | None:
""" """
Try to extract app token from header or params. Try to extract app token from header or params.
@ -155,6 +164,10 @@ def clear_access_token_from_cookie(response: Response, samesite: str = "Lax"):
_clear_cookie(response, COOKIE_NAME_ACCESS_TOKEN, samesite) _clear_cookie(response, COOKIE_NAME_ACCESS_TOKEN, samesite)
def clear_webapp_access_token_from_cookie(response: Response, samesite: str = "Lax"):
    """
    Clear the webapp access-token cookie on ``response``.

    Delegates to ``_clear_cookie`` with ``COOKIE_NAME_WEBAPP_ACCESS_TOKEN``,
    forwarding ``samesite`` unchanged (defaults to ``"Lax"``).
    """
    _clear_cookie(response, COOKIE_NAME_WEBAPP_ACCESS_TOKEN, samesite)
def clear_refresh_token_from_cookie(response: Response): def clear_refresh_token_from_cookie(response: Response):
_clear_cookie(response, COOKIE_NAME_REFRESH_TOKEN) _clear_cookie(response, COOKIE_NAME_REFRESH_TOKEN)

View File

@ -22,55 +22,6 @@ def upgrade():
batch_op.add_column(sa.Column('app_mode', sa.String(length=255), nullable=True)) batch_op.add_column(sa.Column('app_mode', sa.String(length=255), nullable=True))
batch_op.create_index('message_app_mode_idx', ['app_mode'], unique=False) batch_op.create_index('message_app_mode_idx', ['app_mode'], unique=False)
conn = op.get_bind()
# Strategy: Update in batches to minimize lock time
# For large tables (millions of rows), this prevents long-running transactions
batch_size = 10000
print("Starting backfill of app_mode from conversations...")
# Use a more efficient UPDATE with JOIN
# This query updates messages.app_mode from conversations.mode
# Using string formatting for LIMIT since it's a constant
update_query = f"""
UPDATE messages m
SET app_mode = c.mode
FROM conversations c
WHERE m.conversation_id = c.id
AND m.app_mode IS NULL
AND m.id IN (
SELECT id FROM messages
WHERE app_mode IS NULL
LIMIT {batch_size}
)
"""
# Execute batched updates
total_updated = 0
iteration = 0
while True:
iteration += 1
result = conn.execute(sa.text(update_query))
# Check if result is None or has no rowcount
if result is None:
print("Warning: Query returned None, stopping backfill")
break
rows_updated = result.rowcount if hasattr(result, 'rowcount') else 0
total_updated += rows_updated
if rows_updated == 0:
break
print(f"Iteration {iteration}: Updated {rows_updated} messages (total: {total_updated})")
# For very large tables, add a small delay to reduce load
# Uncomment if needed: import time; time.sleep(0.1)
print(f"Backfill completed. Total messages updated: {total_updated}")
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@ -0,0 +1,36 @@
"""remove-builtin-template-user
Revision ID: ae662b25d9bc
Revises: d98acf217d43
Create Date: 2025-10-21 14:30:28.566192
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'ae662b25d9bc'
down_revision = 'd98acf217d43'
branch_labels = None
depends_on = None
def upgrade():
    """Drop the ``updated_by`` and ``created_by`` columns from ``pipeline_built_in_templates``."""
    with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
        # Drop order: updated_by first, then created_by.
        for column_name in ('updated_by', 'created_by'):
            batch_op.drop_column(column_name)
def downgrade():
    """Re-create the ``created_by`` / ``updated_by`` columns on ``pipeline_built_in_templates``.

    ``created_by`` is NOT NULL. Adding a NOT NULL column without a default
    fails on a table that already contains rows, so the column is added with a
    temporary nil-UUID server default (which backfills existing rows) and the
    default is removed right afterwards, leaving the column with no server
    default.

    The user ids removed by ``upgrade`` cannot be recovered: rows that exist at
    downgrade time end up with the nil UUID as a placeholder ``created_by``.
    """
    with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
        batch_op.add_column(
            sa.Column(
                'created_by',
                sa.UUID(),
                # Temporary default so existing rows satisfy NOT NULL.
                server_default=sa.text("'00000000-0000-0000-0000-000000000000'"),
                autoincrement=False,
                nullable=False,
            )
        )
        # Remove the temporary default so the column ends up with no server default.
        batch_op.alter_column(
            'created_by',
            existing_type=sa.UUID(),
            existing_nullable=False,
            server_default=None,
        )
        batch_op.add_column(sa.Column('updated_by', sa.UUID(), autoincrement=False, nullable=True))

View File

@ -5,7 +5,7 @@ from datetime import datetime
from typing import Any, Optional from typing import Any, Optional
import sqlalchemy as sa import sqlalchemy as sa
from flask_login import UserMixin # type: ignore[import-untyped] from flask_login import UserMixin
from sqlalchemy import DateTime, String, func, select from sqlalchemy import DateTime, String, func, select
from sqlalchemy.orm import Mapped, Session, mapped_column from sqlalchemy.orm import Mapped, Session, mapped_column
from typing_extensions import deprecated from typing_extensions import deprecated

View File

@ -1239,15 +1239,6 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
language = mapped_column(db.String(255), nullable=False) language = mapped_column(db.String(255), nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
created_by = mapped_column(StringUUID, nullable=False)
updated_by = mapped_column(StringUUID, nullable=True)
@property
def created_user_name(self):
account = db.session.query(Account).where(Account.id == self.created_by).first()
if account:
return account.name
return ""
class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]

View File

@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast
import sqlalchemy as sa import sqlalchemy as sa
from flask import request from flask import request
from flask_login import UserMixin # type: ignore[import-untyped] from flask_login import UserMixin
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy.orm import Mapped, Session, mapped_column from sqlalchemy.orm import Mapped, Session, mapped_column

View File

@ -219,7 +219,7 @@ class WorkflowToolProvider(TypeBase):
sa.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"), sa.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"),
) )
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False) id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
# name of the workflow provider # name of the workflow provider
name: Mapped[str] = mapped_column(String(255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False)
# label of the workflow provider # label of the workflow provider

View File

@ -1,6 +1,6 @@
[project] [project]
name = "dify-api" name = "dify-api"
version = "1.9.1" version = "1.9.2"
requires-python = ">=3.11,<3.13" requires-python = ">=3.11,<3.13"
dependencies = [ dependencies = [

View File

@ -16,7 +16,25 @@
"opentelemetry.instrumentation.requests", "opentelemetry.instrumentation.requests",
"opentelemetry.instrumentation.sqlalchemy", "opentelemetry.instrumentation.sqlalchemy",
"opentelemetry.instrumentation.redis", "opentelemetry.instrumentation.redis",
"opentelemetry.instrumentation.httpx" "langfuse",
"cloudscraper",
"readabilipy",
"pypandoc",
"pypdfium2",
"webvtt",
"flask_compress",
"oss2",
"baidubce.auth.bce_credentials",
"baidubce.bce_client_configuration",
"baidubce.services.bos.bos_client",
"clickzetta",
"google.cloud",
"obs",
"qcloud_cos",
"tos",
"gmpy2",
"sendgrid",
"sendgrid.helpers.mail"
], ],
"reportUnknownMemberType": "hint", "reportUnknownMemberType": "hint",
"reportUnknownParameterType": "hint", "reportUnknownParameterType": "hint",
@ -28,7 +46,7 @@
"reportUnnecessaryComparison": "hint", "reportUnnecessaryComparison": "hint",
"reportUnnecessaryIsInstance": "hint", "reportUnnecessaryIsInstance": "hint",
"reportUntypedFunctionDecorator": "hint", "reportUntypedFunctionDecorator": "hint",
"reportUnnecessaryTypeIgnoreComment": "hint",
"reportAttributeAccessIssue": "hint", "reportAttributeAccessIssue": "hint",
"pythonVersion": "3.11", "pythonVersion": "3.11",
"pythonPlatform": "All" "pythonPlatform": "All"

View File

@ -48,7 +48,7 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):
try: try:
repository_class = import_string(class_path) repository_class = import_string(class_path)
return repository_class(session_maker=session_maker) # type: ignore[no-any-return] return repository_class(session_maker=session_maker)
except (ImportError, Exception) as e: except (ImportError, Exception) as e:
raise RepositoryImportError( raise RepositoryImportError(
f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}" f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}"
@ -77,6 +77,6 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):
try: try:
repository_class = import_string(class_path) repository_class = import_string(class_path)
return repository_class(session_maker=session_maker) # type: ignore[no-any-return] return repository_class(session_maker=session_maker)
except (ImportError, Exception) as e: except (ImportError, Exception) as e:
raise RepositoryImportError(f"Failed to create APIWorkflowRunRepository from '{class_path}': {e}") from e raise RepositoryImportError(f"Failed to create APIWorkflowRunRepository from '{class_path}': {e}") from e

View File

@ -13,7 +13,7 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized from werkzeug.exceptions import Unauthorized
from configs import dify_config from configs import dify_config
from constants.languages import language_timezone_mapping, languages from constants.languages import get_valid_language, language_timezone_mapping
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client, redis_fallback from extensions.ext_redis import redis_client, redis_fallback
@ -1259,7 +1259,7 @@ class RegisterService:
return f"member_invite:token:{token}" return f"member_invite:token:{token}"
@classmethod @classmethod
def setup(cls, email: str, name: str, password: str, ip_address: str): def setup(cls, email: str, name: str, password: str, ip_address: str, language: str):
""" """
Setup dify Setup dify
@ -1269,11 +1269,10 @@ class RegisterService:
:param ip_address: ip address :param ip_address: ip address
""" """
try: try:
# Register
account = AccountService.create_account( account = AccountService.create_account(
email=email, email=email,
name=name, name=name,
interface_language=languages[0], interface_language=get_valid_language(language),
password=password, password=password,
is_setup=True, is_setup=True,
) )
@ -1315,7 +1314,7 @@ class RegisterService:
account = AccountService.create_account( account = AccountService.create_account(
email=email, email=email,
name=name, name=name,
interface_language=language or languages[0], interface_language=get_valid_language(language),
password=password, password=password,
is_setup=is_setup, is_setup=is_setup,
) )

View File

@ -7,7 +7,7 @@ from enum import StrEnum
from urllib.parse import urlparse from urllib.parse import urlparse
from uuid import uuid4 from uuid import uuid4
import yaml # type: ignore import yaml
from Crypto.Cipher import AES from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad from Crypto.Util.Padding import pad, unpad
from packaging import version from packaging import version
@ -563,7 +563,7 @@ class AppDslService:
else: else:
cls._append_model_config_export_data(export_data, app_model) cls._append_model_config_export_data(export_data, app_model)
return yaml.dump(export_data, allow_unicode=True) # type: ignore return yaml.dump(export_data, allow_unicode=True)
@classmethod @classmethod
def _append_workflow_export_data( def _append_workflow_export_data(

Some files were not shown because too many files have changed in this diff Show More