Merge branch 'main' into feat/mcp-06-18

Novice 2025-10-23 17:01:25 +08:00
commit e7a575a33c
585 changed files with 31247 additions and 7723 deletions

View File

@ -434,6 +434,9 @@ CODE_EXECUTION_SSL_VERIFY=True
CODE_EXECUTION_POOL_MAX_CONNECTIONS=100
CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS=20
CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY=5.0
CODE_EXECUTION_CONNECT_TIMEOUT=10
CODE_EXECUTION_READ_TIMEOUT=60
CODE_EXECUTION_WRITE_TIMEOUT=10
CODE_MAX_NUMBER=9223372036854775807
CODE_MIN_NUMBER=-9223372036854775808
CODE_MAX_STRING_LENGTH=400000

View File

@ -548,7 +548,7 @@ class UpdateConfig(BaseSettings):
class WorkflowVariableTruncationConfig(BaseSettings):
WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE: PositiveInt = Field(
# 100KB
# 1000 KiB
1024_000,
description="Maximum size for variable to trigger final truncation.",
)
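Note (not part of the commit): the updated comment and the literal agree, since 1000 KiB is exactly 1,024,000 bytes. A minimal check:

    # 1000 KiB = 1000 * 1024 bytes; Python allows underscores in numeric
    # literals, so 1024_000 is simply 1,024,000.
    assert 1024_000 == 1000 * 1024 == 1_024_000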

View File

@ -145,7 +145,7 @@ class DatabaseConfig(BaseSettings):
default="postgresql",
)
@computed_field # type: ignore[misc]
@computed_field # type: ignore[prop-decorator]
@property
def SQLALCHEMY_DATABASE_URI(self) -> str:
db_extras = (
@ -198,7 +198,7 @@ class DatabaseConfig(BaseSettings):
default=os.cpu_count() or 1,
)
@computed_field # type: ignore[misc]
@computed_field # type: ignore[prop-decorator]
@property
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
# Parse DB_EXTRAS for 'options'

View File

@ -56,11 +56,15 @@ else:
}
DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)
# console
COOKIE_NAME_ACCESS_TOKEN = "access_token"
COOKIE_NAME_REFRESH_TOKEN = "refresh_token"
COOKIE_NAME_PASSPORT = "passport"
COOKIE_NAME_CSRF_TOKEN = "csrf_token"
# webapp
COOKIE_NAME_WEBAPP_ACCESS_TOKEN = "webapp_access_token"
COOKIE_NAME_PASSPORT = "passport"
HEADER_NAME_CSRF_TOKEN = "X-CSRF-Token"
HEADER_NAME_APP_CODE = "X-App-Code"
HEADER_NAME_PASSPORT = "X-App-Passport"

View File

@ -31,3 +31,9 @@ def supported_language(lang):
error = f"{lang} is not a valid language."
raise ValueError(error)
def get_valid_language(lang: str | None) -> str:
if lang and lang in languages:
return lang
return languages[0]
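A minimal usage sketch (not part of the commit) of the new helper, assuming `languages` is the module-level list of supported locale codes it closes over and that "xx-YY" is not in that list:

    from constants.languages import get_valid_language, languages

    # Unknown or missing languages fall back to the first supported one.
    assert get_valid_language(None) == languages[0]
    assert get_valid_language("xx-YY") == languages[0]

    # A supported code is returned unchanged.
    assert get_valid_language(languages[0]) == languages[0]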

View File

@ -24,7 +24,7 @@ except ImportError:
)
else:
warnings.warn("To use python-magic to guess MIME types, you need to install `libmagic`", stacklevel=2)
magic = None # type: ignore
magic = None # type: ignore[assignment]
from pydantic import BaseModel

View File

@ -4,7 +4,7 @@ from flask_restx import Resource, reqparse
import services
from configs import dify_config
from constants.languages import languages
from constants.languages import get_valid_language
from controllers.console import console_ns
from controllers.console.auth.error import (
AuthenticationFailedError,
@ -29,8 +29,6 @@ from libs.token import (
clear_access_token_from_cookie,
clear_csrf_token_from_cookie,
clear_refresh_token_from_cookie,
extract_access_token,
extract_csrf_token,
set_access_token_to_cookie,
set_csrf_token_to_cookie,
set_refresh_token_to_cookie,
@ -206,10 +204,12 @@ class EmailCodeLoginApi(Resource):
.add_argument("email", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
user_email = args["email"]
language = args["language"]
token_data = AccountService.get_email_code_login_data(args["token"])
if token_data is None:
@ -243,7 +243,9 @@ class EmailCodeLoginApi(Resource):
if account is None:
try:
account = AccountService.create_account_and_tenant(
email=user_email, name=user_email, interface_language=languages[0]
email=user_email,
name=user_email,
interface_language=get_valid_language(language),
)
except WorkSpaceNotAllowedCreateError:
raise NotAllowedCreateWorkspace()
@ -286,13 +288,3 @@ class RefreshTokenApi(Resource):
return response
except Exception as e:
return {"result": "fail", "message": str(e)}, 401
# this API helps the frontend check whether the user is authenticated
# TODO: remove in the future; the frontend should redirect to the login page when it catches a 401 status
@console_ns.route("/login/status")
class LoginStatus(Resource):
def get(self):
token = extract_access_token(request)
csrf_token = extract_csrf_token(request)
return {"logged_in": bool(token) and bool(csrf_token)}

View File

@ -22,7 +22,7 @@ from core.errors.error import (
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from libs import helper
from libs.login import current_user as current_user_
from libs.login import current_account_with_tenant
from models.model import AppMode, InstalledApp
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
@ -31,8 +31,6 @@ from .. import console_ns
logger = logging.getLogger(__name__)
current_user = current_user_._get_current_object() # type: ignore
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
class InstalledAppWorkflowRunApi(InstalledAppResource):
@ -40,6 +38,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
"""
Run workflow
"""
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if not app_model:
raise NotWorkflowAppError()
@ -53,7 +52,6 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
.add_argument("files", type=list, required=False, location="json")
)
args = parser.parse_args()
assert current_user is not None
try:
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
@ -89,7 +87,6 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
assert current_user is not None
# Stop the task via both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)

View File

@ -74,12 +74,17 @@ class SetupApi(Resource):
.add_argument("email", type=email, required=True, location="json")
.add_argument("name", type=StrLen(30), required=True, location="json")
.add_argument("password", type=valid_password, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
# setup
RegisterService.setup(
email=args["email"], name=args["name"], password=args["password"], ip_address=extract_remote_ip(request)
email=args["email"],
name=args["name"],
password=args["password"],
ip_address=extract_remote_ip(request),
language=args["language"],
)
return {"result": "success"}, 201

View File

@ -193,15 +193,16 @@ class MCPAppApi(Resource):
except ValidationError as e:
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str, session: Session) -> EndUser | None:
"""Get end user from existing session - optimized query"""
return (
session.query(EndUser)
.where(EndUser.tenant_id == tenant_id)
.where(EndUser.session_id == mcp_server_id)
.where(EndUser.type == "mcp")
.first()
)
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None:
"""Get end user - manages its own database session"""
with Session(db.engine, expire_on_commit=False) as session, session.begin():
return (
session.query(EndUser)
.where(EndUser.tenant_id == tenant_id)
.where(EndUser.session_id == mcp_server_id)
.where(EndUser.type == "mcp")
.first()
)
def _create_end_user(
self, client_name: str, tenant_id: str, app_id: str, mcp_server_id: str, session: Session
@ -229,7 +230,7 @@ class MCPAppApi(Resource):
request_id: Union[int, str],
) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError | None:
"""Handle MCP request and return response"""
end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id, session)
end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id)
if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest):
client_info = mcp_request.root.params.clientInfo
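A rough sketch (illustrative only, with stub names) of the session-per-call pattern the reworked `_retrieve_end_user` adopts: open a short-lived SQLAlchemy session, run the lookup inside `session.begin()`, and keep the returned object readable afterwards via `expire_on_commit=False`.

    from sqlalchemy import String, create_engine, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class EndUserStub(Base):  # stand-in for models.model.EndUser
        __tablename__ = "end_user_stub"
        id: Mapped[str] = mapped_column(String, primary_key=True)
        tenant_id: Mapped[str] = mapped_column(String)
        session_id: Mapped[str] = mapped_column(String)
        type: Mapped[str] = mapped_column(String)

    engine = create_engine("sqlite:///:memory:")  # placeholder engine, not db.engine
    Base.metadata.create_all(engine)

    def retrieve_end_user(tenant_id: str, mcp_server_id: str) -> EndUserStub | None:
        # Each call owns its session; session.begin() commits or rolls back on exit,
        # and expire_on_commit=False lets the caller read attributes afterwards.
        with Session(engine, expire_on_commit=False) as session, session.begin():
            stmt = (
                select(EndUserStub)
                .where(EndUserStub.tenant_id == tenant_id)
                .where(EndUserStub.session_id == mcp_server_id)
                .where(EndUserStub.type == "mcp")
            )
            return session.scalars(stmt).first()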

View File

@ -17,8 +17,8 @@ from libs.helper import email
from libs.passport import PassportService
from libs.password import valid_password
from libs.token import (
clear_access_token_from_cookie,
extract_access_token,
clear_webapp_access_token_from_cookie,
extract_webapp_access_token,
)
from services.account_service import AccountService
from services.app_service import AppService
@ -81,7 +81,7 @@ class LoginStatusApi(Resource):
)
def get(self):
app_code = request.args.get("app_code")
token = extract_access_token(request)
token = extract_webapp_access_token(request)
if not app_code:
return {
"logged_in": bool(token),
@ -128,7 +128,7 @@ class LogoutApi(Resource):
response = make_response({"result": "success"})
# enterprise SSO sets SameSite to None in HTTPS deployments,
# so we need to log out by calling the API
clear_access_token_from_cookie(response, samesite="None")
clear_webapp_access_token_from_cookie(response, samesite="None")
return response

View File

@ -12,10 +12,8 @@ from controllers.web import web_ns
from controllers.web.error import WebAppAuthRequiredError
from extensions.ext_database import db
from libs.passport import PassportService
from libs.token import extract_access_token
from libs.token import extract_webapp_access_token
from models.model import App, EndUser, Site
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService, WebAppAuthType
@ -37,23 +35,18 @@ class PassportResource(Resource):
system_features = FeatureService.get_system_features()
app_code = request.headers.get(HEADER_NAME_APP_CODE)
user_id = request.args.get("user_id")
access_token = extract_access_token(request)
access_token = extract_webapp_access_token(request)
if app_code is None:
raise Unauthorized("X-App-Code header is missing.")
app_id = AppService.get_app_id_by_code(app_code)
# exchange token for an enterprise logged-in web user
enterprise_user_decoded = decode_enterprise_webapp_user_id(access_token)
if enterprise_user_decoded:
# a web user has already logged in, exchange a token for this app without redirecting to the login page
return exchange_token_for_existing_web_user(
app_code=app_code, enterprise_user_decoded=enterprise_user_decoded
)
if system_features.webapp_auth.enabled:
app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id)
if not app_settings or not app_settings.access_mode == "public":
raise WebAppAuthRequiredError()
enterprise_user_decoded = decode_enterprise_webapp_user_id(access_token)
app_auth_type = WebAppAuthService.get_app_auth_type(app_code=app_code)
if app_auth_type != WebAppAuthType.PUBLIC:
if not enterprise_user_decoded:
raise WebAppAuthRequiredError()
return exchange_token_for_existing_web_user(
app_code=app_code, enterprise_user_decoded=enterprise_user_decoded, auth_type=app_auth_type
)
# get site from db and check if it is normal
site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal"))
@ -124,7 +117,7 @@ def decode_enterprise_webapp_user_id(jwt_token: str | None):
return decoded
def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict):
def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict, auth_type: WebAppAuthType):
"""
Exchange a token for an existing web user session.
"""
@ -145,13 +138,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound()
app_auth_type = WebAppAuthService.get_app_auth_type(app_code=app_code)
if app_auth_type == WebAppAuthType.PUBLIC:
if auth_type == WebAppAuthType.PUBLIC:
return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded)
elif app_auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external":
elif auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external":
raise WebAppAuthRequiredError("Please login as external user.")
elif app_auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal":
elif auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal":
raise WebAppAuthRequiredError("Please login as internal user.")
end_user = None

View File

@ -211,8 +211,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
user=user,
stream=streaming,
)
# FIXME: Type hinting issue here, ignore it for now, will fix it later
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) # type: ignore
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def _generate_worker(
self,

View File

@ -255,7 +255,7 @@ class PipelineGenerator(BaseAppGenerator):
json_text = json.dumps(text)
upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id)
features = FeatureService.get_features(dataset.tenant_id)
if features.billing.subscription.plan == "sandbox":
if features.billing.enabled and features.billing.subscription.plan == "sandbox":
tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"

View File

@ -89,7 +89,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@ -98,7 +98,7 @@ class RateLimit:
else:
return RateLimitGenerator(
rate_limit=self,
generator=generator, # ty: ignore [invalid-argument-type]
generator=generator,
request_id=request_id,
)

View File

@ -49,7 +49,7 @@ class BasedGenerateTaskPipeline:
if isinstance(e, InvokeAuthorizationError):
err = InvokeAuthorizationError("Incorrect API key provided")
elif isinstance(e, InvokeError | ValueError):
err = e # ty: ignore [invalid-assignment]
err = e
else:
description = getattr(e, "description", None)
err = Exception(description if description is not None else str(e))

View File

@ -1868,7 +1868,7 @@ class ProviderConfigurations(BaseModel):
if "/" not in key:
key = str(ModelProviderID(key))
return self.configurations.get(key, default) # type: ignore
return self.configurations.get(key, default)
class ProviderModelBundle(BaseModel):

View File

@ -20,7 +20,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz
else:
# Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
# FIXME: mypy does not support the type of spec.loader
spec = importlib.util.spec_from_file_location(module_name, py_file_path) # type: ignore
spec = importlib.util.spec_from_file_location(module_name, py_file_path) # type: ignore[assignment]
if not spec or not spec.loader:
raise Exception(f"Failed to load module {module_name} from {py_file_path!r}")
if use_lazy_loader:

View File

@ -49,62 +49,80 @@ class IndexingRunner:
self.storage = storage
self.model_manager = ModelManager()
def _handle_indexing_error(self, document_id: str, error: Exception) -> None:
"""Handle indexing errors by updating document status."""
logger.exception("consume document failed")
document = db.session.get(DatasetDocument, document_id)
if document:
document.indexing_status = "error"
error_message = getattr(error, "description", str(error))
document.error = str(error_message)
document.stopped_at = naive_utc_now()
db.session.commit()
def run(self, dataset_documents: list[DatasetDocument]):
"""Run the indexing process."""
for dataset_document in dataset_documents:
document_id = dataset_document.id
try:
# Re-query the document to ensure it's bound to the current session
requeried_document = db.session.get(DatasetDocument, document_id)
if not requeried_document:
logger.warning("Document not found, skipping document id: %s", document_id)
continue
# get dataset
dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
# get the process rule
stmt = select(DatasetProcessRule).where(
DatasetProcessRule.id == dataset_document.dataset_process_rule_id
DatasetProcessRule.id == requeried_document.dataset_process_rule_id
)
processing_rule = db.session.scalar(stmt)
if not processing_rule:
raise ValueError("no process rule found")
index_type = dataset_document.doc_form
index_type = requeried_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
# extract
text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
# transform
documents = self._transform(
index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict()
index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict()
)
# save segment
self._load_segments(dataset, dataset_document, documents)
self._load_segments(dataset, requeried_document, documents)
# load
self._load(
index_processor=index_processor,
dataset=dataset,
dataset_document=dataset_document,
dataset_document=requeried_document,
documents=documents,
)
except DocumentIsPausedError:
raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}")
raise DocumentIsPausedError(f"Document paused, document id: {document_id}")
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error"
dataset_document.error = str(e.description)
dataset_document.stopped_at = naive_utc_now()
db.session.commit()
self._handle_indexing_error(document_id, e)
except ObjectDeletedError:
logger.warning("Document deleted, document id: %s", dataset_document.id)
logger.warning("Document deleted, document id: %s", document_id)
except Exception as e:
logger.exception("consume document failed")
dataset_document.indexing_status = "error"
dataset_document.error = str(e)
dataset_document.stopped_at = naive_utc_now()
db.session.commit()
self._handle_indexing_error(document_id, e)
def run_in_splitting_status(self, dataset_document: DatasetDocument):
"""Run the indexing process when the index_status is splitting."""
document_id = dataset_document.id
try:
# Re-query the document to ensure it's bound to the current session
requeried_document = db.session.get(DatasetDocument, document_id)
if not requeried_document:
logger.warning("Document not found: %s", document_id)
return
# get dataset
dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
@ -112,57 +130,60 @@ class IndexingRunner:
# get exist document_segment list and delete
document_segments = (
db.session.query(DocumentSegment)
.filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
.filter_by(dataset_id=dataset.id, document_id=requeried_document.id)
.all()
)
for document_segment in document_segments:
db.session.delete(document_segment)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# delete child chunks
db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit()
# get the process rule
stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == requeried_document.dataset_process_rule_id)
processing_rule = db.session.scalar(stmt)
if not processing_rule:
raise ValueError("no process rule found")
index_type = dataset_document.doc_form
index_type = requeried_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
# extract
text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
# transform
documents = self._transform(
index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict()
index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict()
)
# save segment
self._load_segments(dataset, dataset_document, documents)
self._load_segments(dataset, requeried_document, documents)
# load
self._load(
index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
index_processor=index_processor,
dataset=dataset,
dataset_document=requeried_document,
documents=documents,
)
except DocumentIsPausedError:
raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}")
raise DocumentIsPausedError(f"Document paused, document id: {document_id}")
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error"
dataset_document.error = str(e.description)
dataset_document.stopped_at = naive_utc_now()
db.session.commit()
self._handle_indexing_error(document_id, e)
except Exception as e:
logger.exception("consume document failed")
dataset_document.indexing_status = "error"
dataset_document.error = str(e)
dataset_document.stopped_at = naive_utc_now()
db.session.commit()
self._handle_indexing_error(document_id, e)
def run_in_indexing_status(self, dataset_document: DatasetDocument):
"""Run the indexing process when the index_status is indexing."""
document_id = dataset_document.id
try:
# Re-query the document to ensure it's bound to the current session
requeried_document = db.session.get(DatasetDocument, document_id)
if not requeried_document:
logger.warning("Document not found: %s", document_id)
return
# get dataset
dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
@ -170,7 +191,7 @@ class IndexingRunner:
# get exist document_segment list and delete
document_segments = (
db.session.query(DocumentSegment)
.filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
.filter_by(dataset_id=dataset.id, document_id=requeried_document.id)
.all()
)
@ -188,7 +209,7 @@ class IndexingRunner:
"dataset_id": document_segment.dataset_id,
},
)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = document_segment.get_child_chunks()
if child_chunks:
child_documents = []
@ -206,24 +227,20 @@ class IndexingRunner:
document.children = child_documents
documents.append(document)
# build index
index_type = dataset_document.doc_form
index_type = requeried_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
self._load(
index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
index_processor=index_processor,
dataset=dataset,
dataset_document=requeried_document,
documents=documents,
)
except DocumentIsPausedError:
raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}")
raise DocumentIsPausedError(f"Document paused, document id: {document_id}")
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error"
dataset_document.error = str(e.description)
dataset_document.stopped_at = naive_utc_now()
db.session.commit()
self._handle_indexing_error(document_id, e)
except Exception as e:
logger.exception("consume document failed")
dataset_document.indexing_status = "error"
dataset_document.error = str(e)
dataset_document.stopped_at = naive_utc_now()
db.session.commit()
self._handle_indexing_error(document_id, e)
def indexing_estimate(
self,
@ -398,7 +415,6 @@ class IndexingRunner:
document_id=dataset_document.id,
after_indexing_status="splitting",
extra_update_params={
DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs),
DatasetDocument.parsing_completed_at: naive_utc_now(),
},
)
@ -738,6 +754,7 @@ class IndexingRunner:
extra_update_params={
DatasetDocument.cleaning_completed_at: cur_time,
DatasetDocument.splitting_completed_at: cur_time,
DatasetDocument.word_count: sum(len(doc.page_content) for doc in documents),
},
)

View File

@ -100,7 +100,7 @@ class LLMGenerator:
return name
@classmethod
def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str) -> Sequence[str]:
output_parser = SuggestedQuestionsAfterAnswerOutputParser()
format_instructions = output_parser.get_format_instructions()
@ -119,6 +119,8 @@ class LLMGenerator:
prompt_messages = [UserPromptMessage(content=prompt)]
questions: Sequence[str] = []
try:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages),

View File

@ -1,17 +1,26 @@
import json
import logging
import re
from collections.abc import Sequence
from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
logger = logging.getLogger(__name__)
class SuggestedQuestionsAfterAnswerOutputParser:
def get_format_instructions(self) -> str:
return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
def parse(self, text: str):
def parse(self, text: str) -> Sequence[str]:
action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL)
questions: list[str] = []
if action_match is not None:
json_obj = json.loads(action_match.group(0).strip())
else:
json_obj = []
return json_obj
try:
json_obj = json.loads(action_match.group(0).strip())
except json.JSONDecodeError as exc:
logger.warning("Failed to decode suggested questions payload: %s", exc)
else:
if isinstance(json_obj, list):
questions = [question for question in json_obj if isinstance(question, str)]
return questions
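A small behavioural sketch (not from the repo) of the hardened parser above, assuming the class is imported from its module:

    parser = SuggestedQuestionsAfterAnswerOutputParser()

    # Only string items survive the isinstance filter.
    assert parser.parse('Sure: ["Q1", "Q2", 3]') == ["Q1", "Q2"]

    # Malformed JSON inside the brackets is logged and swallowed.
    assert parser.parse("[not valid json]") == []

    # No bracketed payload at all.
    assert parser.parse("no questions here") == []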

View File

@ -2,7 +2,7 @@ import logging
import os
from datetime import datetime, timedelta
from langfuse import Langfuse # type: ignore
from langfuse import Langfuse
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance

View File

@ -76,7 +76,7 @@ class PluginParameter(BaseModel):
auto_generate: PluginParameterAutoGenerate | None = None
template: PluginParameterTemplate | None = None
required: bool = False
default: Union[float, int, str] | None = None
default: Union[float, int, str, bool] | None = None
min: Union[float, int] | None = None
max: Union[float, int] | None = None
precision: int | None = None

View File

@ -180,7 +180,7 @@ class BasePluginClient:
Make a request to the plugin daemon inner API and return the response as a model.
"""
response = self._request(method, path, headers, data, params, files)
return type_(**response.json()) # type: ignore
return type_(**response.json()) # type: ignore[return-value]
def _request_with_plugin_daemon_response(
self,

View File

@ -40,7 +40,7 @@ class PluginDaemonBadRequestError(PluginDaemonClientSideError):
description: str = "Bad Request"
class PluginInvokeError(PluginDaemonClientSideError):
class PluginInvokeError(PluginDaemonClientSideError, ValueError):
description: str = "Invoke Error"
def _get_error_object(self) -> Mapping:

View File

@ -72,6 +72,19 @@ default_retrieval_model: dict[str, Any] = {
class DatasetRetrieval:
def __init__(self, application_generate_entity=None):
self.application_generate_entity = application_generate_entity
self._llm_usage = LLMUsage.empty_usage()
@property
def llm_usage(self) -> LLMUsage:
return self._llm_usage.model_copy()
def _record_usage(self, usage: LLMUsage | None) -> None:
if usage is None or usage.total_tokens <= 0:
return
if self._llm_usage.total_tokens == 0:
self._llm_usage = usage
else:
self._llm_usage = self._llm_usage.plus(usage)
def retrieve(
self,
@ -312,15 +325,18 @@ class DatasetRetrieval:
)
tools.append(message_tool)
dataset_id = None
router_usage = LLMUsage.empty_usage()
if planning_strategy == PlanningStrategy.REACT_ROUTER:
react_multi_dataset_router = ReactMultiDatasetRouter()
dataset_id = react_multi_dataset_router.invoke(
dataset_id, router_usage = react_multi_dataset_router.invoke(
query, tools, model_config, model_instance, user_id, tenant_id
)
elif planning_strategy == PlanningStrategy.ROUTER:
function_call_router = FunctionCallMultiDatasetRouter()
dataset_id = function_call_router.invoke(query, tools, model_config, model_instance)
dataset_id, router_usage = function_call_router.invoke(query, tools, model_config, model_instance)
self._record_usage(router_usage)
if dataset_id:
# get retrieval model config
@ -983,7 +999,8 @@ class DatasetRetrieval:
)
# handle invoke result
result_text, _ = self._handle_invoke_result(invoke_result=invoke_result)
result_text, usage = self._handle_invoke_result(invoke_result=invoke_result)
self._record_usage(usage)
result_text_json = parse_and_check_json_markdown(result_text, [])
automatic_metadata_filters = []

View File

@ -2,7 +2,7 @@ from typing import Union
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage
@ -13,15 +13,15 @@ class FunctionCallMultiDatasetRouter:
dataset_tools: list[PromptMessageTool],
model_config: ModelConfigWithCredentialsEntity,
model_instance: ModelInstance,
) -> Union[str, None]:
) -> tuple[Union[str, None], LLMUsage]:
"""Given input, decided what to do.
Returns:
Action specifying what tool to use.
"""
if len(dataset_tools) == 0:
return None
return None, LLMUsage.empty_usage()
elif len(dataset_tools) == 1:
return dataset_tools[0].name
return dataset_tools[0].name, LLMUsage.empty_usage()
try:
prompt_messages = [
@ -34,9 +34,10 @@ class FunctionCallMultiDatasetRouter:
stream=False,
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
)
usage = result.usage or LLMUsage.empty_usage()
if result.message.tool_calls:
# get retrieval model config
return result.message.tool_calls[0].function.name
return None
return result.message.tool_calls[0].function.name, usage
return None, usage
except Exception:
return None
return None, LLMUsage.empty_usage()

View File

@ -58,15 +58,15 @@ class ReactMultiDatasetRouter:
model_instance: ModelInstance,
user_id: str,
tenant_id: str,
) -> Union[str, None]:
) -> tuple[Union[str, None], LLMUsage]:
"""Given input, decided what to do.
Returns:
Action specifying what tool to use.
"""
if len(dataset_tools) == 0:
return None
return None, LLMUsage.empty_usage()
elif len(dataset_tools) == 1:
return dataset_tools[0].name
return dataset_tools[0].name, LLMUsage.empty_usage()
try:
return self._react_invoke(
@ -78,7 +78,7 @@ class ReactMultiDatasetRouter:
tenant_id=tenant_id,
)
except Exception:
return None
return None, LLMUsage.empty_usage()
def _react_invoke(
self,
@ -91,7 +91,7 @@ class ReactMultiDatasetRouter:
prefix: str = PREFIX,
suffix: str = SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
) -> Union[str, None]:
) -> tuple[Union[str, None], LLMUsage]:
prompt: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
if model_config.mode == "chat":
prompt = self.create_chat_prompt(
@ -120,7 +120,7 @@ class ReactMultiDatasetRouter:
memory=None,
model_config=model_config,
)
result_text, _ = self._invoke_llm(
result_text, usage = self._invoke_llm(
completion_param=model_config.parameters,
model_instance=model_instance,
prompt_messages=prompt_messages,
@ -131,8 +131,8 @@ class ReactMultiDatasetRouter:
output_parser = StructuredChatOutputParser()
react_decision = output_parser.parse(result_text)
if isinstance(react_decision, ReactAction):
return react_decision.tool
return None
return react_decision.tool, usage
return None, usage
def _invoke_llm(
self,

View File

@ -74,7 +74,7 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
tenant_id = extract_tenant_id(user)
if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id # type: ignore[assignment] # We've already checked tenant_id is not None
self._tenant_id = tenant_id
# Store app context
self._app_id = app_id

View File

@ -81,7 +81,7 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
tenant_id = extract_tenant_id(user)
if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id # type: ignore[assignment] # We've already checked tenant_id is not None
self._tenant_id = tenant_id
# Store app context
self._app_id = app_id

View File

@ -60,7 +60,7 @@ class DifyCoreRepositoryFactory:
try:
repository_class = import_string(class_path)
return repository_class( # type: ignore[no-any-return]
return repository_class(
session_factory=session_factory,
user=user,
app_id=app_id,
@ -96,7 +96,7 @@ class DifyCoreRepositoryFactory:
try:
repository_class = import_string(class_path)
return repository_class( # type: ignore[no-any-return]
return repository_class(
session_factory=session_factory,
user=user,
app_id=app_id,

View File

@ -157,7 +157,7 @@ class BuiltinToolProviderController(ToolProviderController):
"""
returns the tool with the given name if this provider provides it, otherwise None
"""
return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None) # type: ignore
return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)
@property
def need_credentials(self) -> bool:

View File

@ -43,7 +43,7 @@ class TTSTool(BuiltinTool):
content_text=tool_parameters.get("text"), # type: ignore
user=user_id,
tenant_id=self.runtime.tenant_id,
voice=voice, # type: ignore
voice=voice,
)
buffer = io.BytesIO()
for chunk in tts:

View File

@ -34,6 +34,7 @@ class LocaltimeToTimestampTool(BuiltinTool):
yield self.create_text_message(f"{timestamp}")
# TODO: this method's type is messy
@staticmethod
def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None:
try:

View File

@ -48,6 +48,6 @@ class TimezoneConversionTool(BuiltinTool):
datetime_with_tz = input_timezone.localize(local_time)
# timezone convert
converted_datetime = datetime_with_tz.astimezone(output_timezone)
return converted_datetime.strftime(format=time_format) # type: ignore
return converted_datetime.strftime(time_format)
except Exception as e:
raise ToolInvokeError(str(e))

View File

@ -113,7 +113,7 @@ class MCPToolProviderController(ToolProviderController):
"""
pass
def get_tool(self, tool_name: str) -> MCPTool: # type: ignore
def get_tool(self, tool_name: str) -> MCPTool:
"""
return tool with given name
"""
@ -136,7 +136,7 @@ class MCPToolProviderController(ToolProviderController):
sse_read_timeout=self.sse_read_timeout,
)
def get_tools(self) -> list[MCPTool]: # type: ignore
def get_tools(self) -> list[MCPTool]:
"""
get all tools
"""

View File

@ -26,7 +26,7 @@ class ToolLabelManager:
labels = cls.filter_tool_labels(labels)
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
provider_id = controller.provider_id # ty: ignore [unresolved-attribute]
provider_id = controller.provider_id
else:
raise ValueError("Unsupported tool type")
@ -51,7 +51,7 @@ class ToolLabelManager:
Get tool labels
"""
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
provider_id = controller.provider_id # ty: ignore [unresolved-attribute]
provider_id = controller.provider_id
elif isinstance(controller, BuiltinToolProviderController):
return controller.tool_labels
else:
@ -85,7 +85,7 @@ class ToolLabelManager:
provider_ids = []
for controller in tool_providers:
assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
provider_ids.append(controller.provider_id) # ty: ignore [unresolved-attribute]
provider_ids.append(controller.provider_id)
labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all()

View File

@ -331,7 +331,8 @@ class ToolManager:
workflow_provider_stmt = select(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id
)
workflow_provider = db.session.scalar(workflow_provider_stmt)
with Session(db.engine, expire_on_commit=False) as session, session.begin():
workflow_provider = session.scalar(workflow_provider_stmt)
if workflow_provider is None:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

View File

@ -193,18 +193,18 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
document = db.session.scalar(dataset_document_stmt) # type: ignore
document = db.session.scalar(dataset_document_stmt)
if dataset and document:
source = RetrievalSourceMetadata(
dataset_id=dataset.id,
dataset_name=dataset.name,
document_id=document.id, # type: ignore
document_name=document.name, # type: ignore
data_source_type=document.data_source_type, # type: ignore
document_id=document.id,
document_name=document.name,
data_source_type=document.data_source_type,
segment_id=segment.id,
retriever_from=self.retriever_from,
score=record.score or 0.0,
doc_metadata=document.doc_metadata, # type: ignore
doc_metadata=document.doc_metadata,
)
if self.retriever_from == "dev":

View File

@ -62,6 +62,11 @@ class ApiBasedToolSchemaParser:
root = root[ref]
interface["operation"]["parameters"][i] = root
for parameter in interface["operation"]["parameters"]:
# Handle complex type defaults that are not supported by PluginParameter
default_value = None
if "schema" in parameter and "default" in parameter["schema"]:
default_value = ApiBasedToolSchemaParser._sanitize_default_value(parameter["schema"]["default"])
tool_parameter = ToolParameter(
name=parameter["name"],
label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
@ -72,9 +77,7 @@ class ApiBasedToolSchemaParser:
required=parameter.get("required", False),
form=ToolParameter.ToolParameterForm.LLM,
llm_description=parameter.get("description"),
default=parameter["schema"]["default"]
if "schema" in parameter and "default" in parameter["schema"]
else None,
default=default_value,
placeholder=I18nObject(
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
),
@ -134,6 +137,11 @@ class ApiBasedToolSchemaParser:
required = body_schema.get("required", [])
properties = body_schema.get("properties", {})
for name, property in properties.items():
# Handle complex type defaults that are not supported by PluginParameter
default_value = ApiBasedToolSchemaParser._sanitize_default_value(
property.get("default", None)
)
tool = ToolParameter(
name=name,
label=I18nObject(en_US=name, zh_Hans=name),
@ -144,12 +152,11 @@ class ApiBasedToolSchemaParser:
required=name in required,
form=ToolParameter.ToolParameterForm.LLM,
llm_description=property.get("description", ""),
default=property.get("default", None),
default=default_value,
placeholder=I18nObject(
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
),
)
# check if there is a type
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
if typ:
@ -197,6 +204,22 @@ class ApiBasedToolSchemaParser:
return bundles
@staticmethod
def _sanitize_default_value(value):
"""
Sanitize default values for PluginParameter compatibility.
Complex types (list, dict) are converted to None to avoid validation errors.
Args:
value: The default value from OpenAPI schema
Returns:
None for complex types (list, dict), otherwise the original value
"""
if isinstance(value, (list, dict)):
return None
return value
@staticmethod
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
parameter = parameter or {}
@ -217,7 +240,11 @@ class ApiBasedToolSchemaParser:
return ToolParameter.ToolParameterType.STRING
elif typ == "array":
items = parameter.get("items") or parameter.get("schema", {}).get("items")
return ToolParameter.ToolParameterType.FILES if items and items.get("format") == "binary" else None
if items and items.get("format") == "binary":
return ToolParameter.ToolParameterType.FILES
else:
# For regular arrays, return ARRAY type instead of None
return ToolParameter.ToolParameterType.ARRAY
else:
return None
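Behaviour sketch (not part of the commit) for the `_sanitize_default_value` helper introduced above: scalar defaults pass through untouched, while container defaults collapse to None so they no longer trip PluginParameter validation.

    assert ApiBasedToolSchemaParser._sanitize_default_value("fast") == "fast"
    assert ApiBasedToolSchemaParser._sanitize_default_value(3) == 3
    assert ApiBasedToolSchemaParser._sanitize_default_value(["a", "b"]) is None
    assert ApiBasedToolSchemaParser._sanitize_default_value({"k": "v"}) is None
    assert ApiBasedToolSchemaParser._sanitize_default_value(None) is None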

View File

@ -6,8 +6,8 @@ from typing import Any, cast
from urllib.parse import unquote
import chardet
import cloudscraper # type: ignore
from readabilipy import simple_json_from_html_string # type: ignore
import cloudscraper
from readabilipy import simple_json_from_html_string
from core.helper import ssrf_proxy
from core.rag.extractor import extract_processor
@ -63,8 +63,8 @@ def get_url(url: str, user_agent: str | None = None) -> str:
response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
elif response.status_code == 403:
scraper = cloudscraper.create_scraper()
scraper.perform_request = ssrf_proxy.make_request # type: ignore
response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) # type: ignore
scraper.perform_request = ssrf_proxy.make_request
response = scraper.get(url, headers=headers, timeout=(120, 300))
if response.status_code != 200:
return f"URL returned status code {response.status_code}."

View File

@ -3,7 +3,7 @@ from functools import lru_cache
from pathlib import Path
from typing import Any
import yaml # type: ignore
import yaml
from yaml import YAMLError
logger = logging.getLogger(__name__)

View File

@ -1,6 +1,7 @@
from collections.abc import Mapping
from pydantic import Field
from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
@ -20,6 +21,7 @@ from core.tools.entities.tool_entities import (
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.account import Account
from models.model import App, AppMode
from models.tools import WorkflowToolProvider
from models.workflow import Workflow
@ -44,29 +46,34 @@ class WorkflowToolProviderController(ToolProviderController):
@classmethod
def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
app = db_provider.app
with Session(db.engine, expire_on_commit=False) as session, session.begin():
provider = session.get(WorkflowToolProvider, db_provider.id) if db_provider.id else None
if not provider:
raise ValueError("workflow provider not found")
app = session.get(App, provider.app_id)
if not app:
raise ValueError("app not found")
if not app:
raise ValueError("app not found")
user = session.get(Account, provider.user_id) if provider.user_id else None
controller = WorkflowToolProviderController(
entity=ToolProviderEntity(
identity=ToolProviderIdentity(
author=db_provider.user.name if db_provider.user_id and db_provider.user else "",
name=db_provider.label,
label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label),
description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
icon=db_provider.icon,
controller = WorkflowToolProviderController(
entity=ToolProviderEntity(
identity=ToolProviderIdentity(
author=user.name if user else "",
name=provider.label,
label=I18nObject(en_US=provider.label, zh_Hans=provider.label),
description=I18nObject(en_US=provider.description, zh_Hans=provider.description),
icon=provider.icon,
),
credentials_schema=[],
plugin_id=None,
),
credentials_schema=[],
plugin_id=None,
),
provider_id=db_provider.id or "",
)
provider_id=provider.id or "",
)
# init tools
controller.tools = [controller._get_db_provider_tool(db_provider, app)]
controller.tools = [
controller._get_db_provider_tool(provider, app, session=session, user=user),
]
return controller
@ -74,7 +81,14 @@ class WorkflowToolProviderController(ToolProviderController):
def provider_type(self) -> ToolProviderType:
return ToolProviderType.WORKFLOW
def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool:
def _get_db_provider_tool(
self,
db_provider: WorkflowToolProvider,
app: App,
*,
session: Session,
user: Account | None = None,
) -> WorkflowTool:
"""
get db provider tool
:param db_provider: the db provider
@ -82,7 +96,7 @@ class WorkflowToolProviderController(ToolProviderController):
:return: the tool
"""
workflow: Workflow | None = (
db.session.query(Workflow)
session.query(Workflow)
.where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
.first()
)
@ -99,9 +113,7 @@ class WorkflowToolProviderController(ToolProviderController):
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
return next(filter(lambda x: x.variable == variable_name, variables), None) # type: ignore
user = db_provider.user
return next(filter(lambda x: x.variable == variable_name, variables), None)
workflow_tool_parameters = []
for parameter in parameters:
@ -187,22 +199,25 @@ class WorkflowToolProviderController(ToolProviderController):
if self.tools is not None:
return self.tools
db_providers: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.app_id == self.provider_id,
with Session(db.engine, expire_on_commit=False) as session, session.begin():
db_provider: WorkflowToolProvider | None = (
session.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.app_id == self.provider_id,
)
.first()
)
.first()
)
if not db_providers:
return []
if not db_providers.app:
raise ValueError("app not found")
if not db_provider:
return []
app = db_providers.app
self.tools = [self._get_db_provider_tool(db_providers, app)]
app = session.get(App, db_provider.app_id)
if not app:
raise ValueError("app not found")
user = session.get(Account, db_provider.user_id) if db_provider.user_id else None
self.tools = [self._get_db_provider_tool(db_provider, app, session=session, user=user)]
return self.tools

View File

@ -1,12 +1,14 @@
import json
import logging
from collections.abc import Generator
from typing import Any
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
from flask import has_request_context
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import (
@ -48,6 +50,7 @@ class WorkflowTool(Tool):
self.workflow_entities = workflow_entities
self.workflow_call_depth = workflow_call_depth
self.label = label
self._latest_usage = LLMUsage.empty_usage()
super().__init__(entity=entity, runtime=runtime)
@ -83,10 +86,11 @@ class WorkflowTool(Tool):
assert self.runtime.invoke_from is not None
user = self._resolve_user(user_id=user_id)
if user is None:
raise ToolInvokeError("User not found")
self._latest_usage = LLMUsage.empty_usage()
result = generator.generate(
app_model=app,
workflow=workflow,
@ -110,9 +114,68 @@ class WorkflowTool(Tool):
for file in files:
yield self.create_file_message(file) # type: ignore
self._latest_usage = self._derive_usage_from_result(data)
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
yield self.create_json_message(outputs)
@property
def latest_usage(self) -> LLMUsage:
return self._latest_usage
@classmethod
def _derive_usage_from_result(cls, data: Mapping[str, Any]) -> LLMUsage:
usage_dict = cls._extract_usage_dict(data)
if usage_dict is not None:
return LLMUsage.from_metadata(cast(LLMUsageMetadata, dict(usage_dict)))
total_tokens = data.get("total_tokens")
total_price = data.get("total_price")
if total_tokens is None and total_price is None:
return LLMUsage.empty_usage()
usage_metadata: dict[str, Any] = {}
if total_tokens is not None:
try:
usage_metadata["total_tokens"] = int(str(total_tokens))
except (TypeError, ValueError):
pass
if total_price is not None:
usage_metadata["total_price"] = str(total_price)
currency = data.get("currency")
if currency is not None:
usage_metadata["currency"] = currency
if not usage_metadata:
return LLMUsage.empty_usage()
return LLMUsage.from_metadata(cast(LLMUsageMetadata, usage_metadata))
@classmethod
def _extract_usage_dict(cls, payload: Mapping[str, Any]) -> Mapping[str, Any] | None:
usage_candidate = payload.get("usage")
if isinstance(usage_candidate, Mapping):
return usage_candidate
metadata_candidate = payload.get("metadata")
if isinstance(metadata_candidate, Mapping):
usage_candidate = metadata_candidate.get("usage")
if isinstance(usage_candidate, Mapping):
return usage_candidate
for value in payload.values():
if isinstance(value, Mapping):
found = cls._extract_usage_dict(value)
if found is not None:
return found
elif isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
for item in value:
if isinstance(item, Mapping):
found = cls._extract_usage_dict(item)
if found is not None:
return found
return None
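Illustrative only: how the recursive lookup above digs a nested `usage` mapping out of a workflow result payload. The payload shape here is hypothetical, not taken from the repo.

    payload = {
        "data": {
            "outputs": {"answer": "ok"},
            "metadata": {
                "usage": {"total_tokens": 42, "total_price": "0.0008", "currency": "USD"},
            },
        }
    }

    # Walks nested mappings and sequences until a "usage" mapping is found.
    usage_dict = WorkflowTool._extract_usage_dict(payload)
    assert usage_dict == {"total_tokens": 42, "total_price": "0.0008", "currency": "USD"}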
def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
"""
fork a new tool with metadata
@ -179,16 +242,17 @@ class WorkflowTool(Tool):
"""
get the workflow by app id and version
"""
if not version:
workflow = (
db.session.query(Workflow)
.where(Workflow.app_id == app_id, Workflow.version != Workflow.VERSION_DRAFT)
.order_by(Workflow.created_at.desc())
.first()
)
else:
stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
workflow = db.session.scalar(stmt)
with Session(db.engine, expire_on_commit=False) as session, session.begin():
if not version:
stmt = (
select(Workflow)
.where(Workflow.app_id == app_id, Workflow.version != Workflow.VERSION_DRAFT)
.order_by(Workflow.created_at.desc())
)
workflow = session.scalars(stmt).first()
else:
stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
workflow = session.scalar(stmt)
if not workflow:
raise ValueError("workflow not found or not published")
@ -200,7 +264,8 @@ class WorkflowTool(Tool):
get the app by app id
"""
stmt = select(App).where(App.id == app_id)
app = db.session.scalar(stmt)
with Session(db.engine, expire_on_commit=False) as session, session.begin():
app = session.scalar(stmt)
if not app:
raise ValueError("app not found")

View File

@ -4,7 +4,7 @@ from .types import SegmentType
class SegmentGroup(Segment):
value_type: SegmentType = SegmentType.GROUP
value: list[Segment] = None # type: ignore
value: list[Segment]
@property
def text(self):

View File

@ -19,7 +19,7 @@ class Segment(BaseModel):
model_config = ConfigDict(frozen=True)
value_type: SegmentType
value: Any = None
value: Any
@field_validator("value_type")
@classmethod
@ -74,12 +74,12 @@ class NoneSegment(Segment):
class StringSegment(Segment):
value_type: SegmentType = SegmentType.STRING
value: str = None # type: ignore
value: str
class FloatSegment(Segment):
value_type: SegmentType = SegmentType.FLOAT
value: float = None # type: ignore
value: float
# NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
# The following tests cannot pass.
#
@ -98,12 +98,12 @@ class FloatSegment(Segment):
class IntegerSegment(Segment):
value_type: SegmentType = SegmentType.INTEGER
value: int = None # type: ignore
value: int
class ObjectSegment(Segment):
value_type: SegmentType = SegmentType.OBJECT
value: Mapping[str, Any] = None # type: ignore
value: Mapping[str, Any]
@property
def text(self) -> str:
@ -136,7 +136,7 @@ class ArraySegment(Segment):
class FileSegment(Segment):
value_type: SegmentType = SegmentType.FILE
value: File = None # type: ignore
value: File
@property
def markdown(self) -> str:
@ -153,17 +153,17 @@ class FileSegment(Segment):
class BooleanSegment(Segment):
value_type: SegmentType = SegmentType.BOOLEAN
value: bool = None # type: ignore
value: bool
class ArrayAnySegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_ANY
value: Sequence[Any] = None # type: ignore
value: Sequence[Any]
class ArrayStringSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_STRING
value: Sequence[str] = None # type: ignore
value: Sequence[str]
@property
def text(self) -> str:
@ -175,17 +175,17 @@ class ArrayStringSegment(ArraySegment):
class ArrayNumberSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_NUMBER
value: Sequence[float | int] = None # type: ignore
value: Sequence[float | int]
class ArrayObjectSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_OBJECT
value: Sequence[Mapping[str, Any]] = None # type: ignore
value: Sequence[Mapping[str, Any]]
class ArrayFileSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_FILE
value: Sequence[File] = None # type: ignore
value: Sequence[File]
@property
def markdown(self) -> str:
@ -205,7 +205,7 @@ class ArrayFileSegment(ArraySegment):
class ArrayBooleanSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
value: Sequence[bool] = None # type: ignore
value: Sequence[bool]
def get_segment_discriminator(v: Any) -> SegmentType | None:

View File

@ -1,3 +1,5 @@
from ..runtime.graph_runtime_state import GraphRuntimeState
from ..runtime.variable_pool import VariablePool
from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .workflow_execution import WorkflowExecution
@ -6,6 +8,8 @@ from .workflow_node_execution import WorkflowNodeExecution
__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
"GraphRuntimeState",
"VariablePool",
"WorkflowExecution",
"WorkflowNodeExecution",
]

View File

@ -3,11 +3,12 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Protocol, cast, final
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
from core.workflow.nodes.base.node import Node
from libs.typing import is_str, is_str_dict
from .edge import Edge
from .validation import get_graph_validator
logger = logging.getLogger(__name__)
@ -201,6 +202,17 @@ class Graph:
return GraphBuilder(graph_cls=cls)
@classmethod
def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None:
"""
Promote nodes configured with FAIL_BRANCH error strategy to branch execution type.
:param nodes: mapping of node ID to node instance
"""
for node in nodes.values():
if node.error_strategy == ErrorStrategy.FAIL_BRANCH:
node.execution_type = NodeExecutionType.BRANCH
@classmethod
def _mark_inactive_root_branches(
cls,
@ -307,6 +319,9 @@ class Graph:
# Create node instances
nodes = cls._create_node_instances(node_configs_map, node_factory)
# Promote fail-branch nodes to branch execution type at graph level
cls._promote_fail_branch_nodes(nodes)
# Get root node instance
root_node = nodes[root_node_id]
@ -314,7 +329,7 @@ class Graph:
cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)
# Create and return the graph
return cls(
graph = cls(
nodes=nodes,
edges=edges,
in_edges=in_edges,
@ -322,6 +337,11 @@ class Graph:
root_node=root_node,
)
# Validate the graph structure using built-in validators
get_graph_validator().validate(graph)
return graph
@property
def node_ids(self) -> list[str]:
"""

View File

@ -0,0 +1,125 @@
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Protocol
from core.workflow.enums import NodeExecutionType, NodeType
if TYPE_CHECKING:
from .graph import Graph
@dataclass(frozen=True, slots=True)
class GraphValidationIssue:
"""Immutable value object describing a single validation issue."""
code: str
message: str
node_id: str | None = None
class GraphValidationError(ValueError):
"""Raised when graph validation fails."""
def __init__(self, issues: Sequence[GraphValidationIssue]) -> None:
if not issues:
raise ValueError("GraphValidationError requires at least one issue.")
self.issues: tuple[GraphValidationIssue, ...] = tuple(issues)
message = "; ".join(f"[{issue.code}] {issue.message}" for issue in self.issues)
super().__init__(message)
class GraphValidationRule(Protocol):
"""Protocol that individual validation rules must satisfy."""
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
"""Validate the provided graph and return any discovered issues."""
...
@dataclass(frozen=True, slots=True)
class _EdgeEndpointValidator:
"""Ensures all edges reference existing nodes."""
missing_node_code: str = "MISSING_NODE"
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
issues: list[GraphValidationIssue] = []
for edge in graph.edges.values():
if edge.tail not in graph.nodes:
issues.append(
GraphValidationIssue(
code=self.missing_node_code,
message=f"Edge {edge.id} references unknown source node '{edge.tail}'.",
node_id=edge.tail,
)
)
if edge.head not in graph.nodes:
issues.append(
GraphValidationIssue(
code=self.missing_node_code,
message=f"Edge {edge.id} references unknown target node '{edge.head}'.",
node_id=edge.head,
)
)
return issues
@dataclass(frozen=True, slots=True)
class _RootNodeValidator:
"""Validates root node invariants."""
invalid_root_code: str = "INVALID_ROOT"
container_entry_types: tuple[NodeType, ...] = (NodeType.ITERATION_START, NodeType.LOOP_START)
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
root_node = graph.root_node
issues: list[GraphValidationIssue] = []
if root_node.id not in graph.nodes:
issues.append(
GraphValidationIssue(
code=self.invalid_root_code,
message=f"Root node '{root_node.id}' is missing from the node registry.",
node_id=root_node.id,
)
)
return issues
node_type = getattr(root_node, "node_type", None)
if root_node.execution_type != NodeExecutionType.ROOT and node_type not in self.container_entry_types:
issues.append(
GraphValidationIssue(
code=self.invalid_root_code,
message=f"Root node '{root_node.id}' must declare execution type 'root'.",
node_id=root_node.id,
)
)
return issues
@dataclass(frozen=True, slots=True)
class GraphValidator:
"""Coordinates execution of graph validation rules."""
rules: tuple[GraphValidationRule, ...]
def validate(self, graph: Graph) -> None:
"""Validate the graph against all configured rules."""
issues: list[GraphValidationIssue] = []
for rule in self.rules:
issues.extend(rule.validate(graph))
if issues:
raise GraphValidationError(issues)
_DEFAULT_RULES: tuple[GraphValidationRule, ...] = (
_EdgeEndpointValidator(),
_RootNodeValidator(),
)
def get_graph_validator() -> GraphValidator:
"""Construct the validator composed of default rules."""
return GraphValidator(_DEFAULT_RULES)
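
A minimal sketch of how this new validation surface composes, assuming the code sits alongside the module above (so GraphValidationIssue, GraphValidator, GraphValidationError and _DEFAULT_RULES are in scope). The extra rule and the validate_or_log helper are hypothetical illustrations, not part of the commit; they only use Edge/Graph attributes already referenced by the built-in rules (edge.id, edge.tail, edge.head, graph.edges).

import logging
from collections.abc import Sequence

logger = logging.getLogger(__name__)


class _NoSelfLoopRule:
    """Hypothetical extra rule: flags edges whose source and target are the same node."""

    def validate(self, graph: "Graph") -> Sequence[GraphValidationIssue]:
        return [
            GraphValidationIssue(
                code="SELF_LOOP",
                message=f"Edge {edge.id} connects node '{edge.tail}' to itself.",
                node_id=edge.tail,
            )
            for edge in graph.edges.values()
            if edge.tail == edge.head
        ]


def validate_or_log(graph: "Graph") -> bool:
    """Run the default rules plus the extra one; log issues instead of propagating."""
    validator = GraphValidator(rules=(*_DEFAULT_RULES, _NoSelfLoopRule()))
    try:
        validator.validate(graph)
    except GraphValidationError as exc:
        for issue in exc.issues:
            logger.warning("graph validation failed: [%s] %s", issue.code, issue.message)
        return False
    return True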

View File

@ -26,8 +26,8 @@ class AgentNodeData(BaseNodeData):
class ParamsAutoGenerated(IntEnum):
CLOSE = auto()
OPEN = auto()
CLOSE = 0
OPEN = 1
class AgentOldVersionModelFeatures(StrEnum):

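Background on the change above: enum.auto() numbers IntEnum members starting at 1, so the old definition gave CLOSE/OPEN the values 1/2, while pinning them keeps the 0/1 values the rest of the code presumably expects. A quick, self-contained illustration:

from enum import IntEnum, auto


class WithAuto(IntEnum):
    CLOSE = auto()
    OPEN = auto()


class Pinned(IntEnum):
    CLOSE = 0
    OPEN = 1


assert (WithAuto.CLOSE.value, WithAuto.OPEN.value) == (1, 2)
assert (Pinned.CLOSE.value, Pinned.OPEN.value) == (0, 1)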
View File

@ -1,4 +1,5 @@
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
from .usage_tracking_mixin import LLMUsageTrackingMixin
__all__ = [
"BaseIterationNodeData",
@ -6,4 +7,5 @@ __all__ = [
"BaseLoopNodeData",
"BaseLoopState",
"BaseNodeData",
"LLMUsageTrackingMixin",
]

View File

@ -1,5 +1,6 @@
import json
from abc import ABC
from builtins import type as type_
from collections.abc import Sequence
from enum import StrEnum
from typing import Any, Union
@ -58,10 +59,9 @@ class DefaultValue(BaseModel):
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
@staticmethod
def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
"""Unified array type validation"""
# FIXME, type ignore here for do not find the reason mypy complain, if find the root cause, please fix it
return isinstance(value, list) and all(isinstance(x, element_type) for x in value) # type: ignore
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
@staticmethod
def _convert_number(value: str) -> float:

View File

@ -0,0 +1,28 @@
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.runtime import GraphRuntimeState
class LLMUsageTrackingMixin:
"""Provides shared helpers for merging and recording LLM usage within workflow nodes."""
graph_runtime_state: GraphRuntimeState
@staticmethod
def _merge_usage(current: LLMUsage, new_usage: LLMUsage | None) -> LLMUsage:
"""Return a combined usage snapshot, preserving zero-value inputs."""
if new_usage is None or new_usage.total_tokens <= 0:
return current
if current.total_tokens == 0:
return new_usage
return current.plus(new_usage)
def _accumulate_usage(self, usage: LLMUsage) -> None:
"""Push usage into the graph runtime accumulator for downstream reporting."""
if usage.total_tokens <= 0:
return
current_usage = self.graph_runtime_state.llm_usage
if current_usage.total_tokens == 0:
self.graph_runtime_state.llm_usage = usage.model_copy()
else:
self.graph_runtime_state.llm_usage = current_usage.plus(usage)
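
For orientation, a reduced sketch of the merge semantics the mixin encodes. It assumes the imports shown elsewhere in this diff (LLMUsage from core.model_runtime.entities.llm_entities, LLMUsageTrackingMixin from core.workflow.nodes.base); in practice the snapshots come from sub-graph graph_runtime_state.llm_usage, and _accumulate_usage then records the merged total on the node's own runtime state.

from collections.abc import Iterable

from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.nodes.base import LLMUsageTrackingMixin


def merge_step_usages(step_usages: Iterable[LLMUsage]) -> LLMUsage:
    """Fold per-step usage snapshots the same way the mixin does inside nodes."""
    total = LLMUsage.empty_usage()
    for usage in step_usages:
        # Empty updates are skipped; the first non-empty snapshot is adopted
        # as-is; later ones are added via LLMUsage.plus().
        total = LLMUsageTrackingMixin._merge_usage(total, usage)
    return total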

View File

@ -10,10 +10,10 @@ from typing import Any
import chardet
import docx
import pandas as pd
import pypandoc # type: ignore
import pypdfium2 # type: ignore
import webvtt # type: ignore
import yaml # type: ignore
import pypandoc
import pypdfium2
import webvtt
import yaml
from docx.document import Document
from docx.oxml.table import CT_Tbl
from docx.oxml.text.paragraph import CT_P

View File

@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, NewType, cast
from flask import Flask, current_app
from typing_extensions import TypeIs
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.variables import VariableUnion
@ -34,6 +35,7 @@ from core.workflow.node_events import (
NodeRunResult,
StreamCompletedEvent,
)
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
@ -58,7 +60,7 @@ logger = logging.getLogger(__name__)
EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)
class IterationNode(Node):
class IterationNode(LLMUsageTrackingMixin, Node):
"""
Iteration Node.
"""
@ -118,6 +120,7 @@ class IterationNode(Node):
started_at = naive_utc_now()
iter_run_map: dict[str, float] = {}
outputs: list[object] = []
usage_accumulator = [LLMUsage.empty_usage()]
yield IterationStartedEvent(
start_at=started_at,
@ -130,22 +133,27 @@ class IterationNode(Node):
iterator_list_value=iterator_list_value,
outputs=outputs,
iter_run_map=iter_run_map,
usage_accumulator=usage_accumulator,
)
self._accumulate_usage(usage_accumulator[0])
yield from self._handle_iteration_success(
started_at=started_at,
inputs=inputs,
outputs=outputs,
iterator_list_value=iterator_list_value,
iter_run_map=iter_run_map,
usage=usage_accumulator[0],
)
except IterationNodeError as e:
self._accumulate_usage(usage_accumulator[0])
yield from self._handle_iteration_failure(
started_at=started_at,
inputs=inputs,
outputs=outputs,
iterator_list_value=iterator_list_value,
iter_run_map=iter_run_map,
usage=usage_accumulator[0],
error=e,
)
@ -196,6 +204,7 @@ class IterationNode(Node):
iterator_list_value: Sequence[object],
outputs: list[object],
iter_run_map: dict[str, float],
usage_accumulator: list[LLMUsage],
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
if self._node_data.is_parallel:
# Parallel mode execution
@ -203,6 +212,7 @@ class IterationNode(Node):
iterator_list_value=iterator_list_value,
outputs=outputs,
iter_run_map=iter_run_map,
usage_accumulator=usage_accumulator,
)
else:
# Sequential mode execution
@ -228,6 +238,9 @@ class IterationNode(Node):
# Update the total tokens from this iteration
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
usage_accumulator[0] = self._merge_usage(
usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage
)
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
def _execute_parallel_iterations(
@ -235,6 +248,7 @@ class IterationNode(Node):
iterator_list_value: Sequence[object],
outputs: list[object],
iter_run_map: dict[str, float],
usage_accumulator: list[LLMUsage],
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
# Initialize outputs list with None values to maintain order
outputs.extend([None] * len(iterator_list_value))
@ -245,7 +259,16 @@ class IterationNode(Node):
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all iteration tasks
future_to_index: dict[
Future[tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]],
Future[
tuple[
datetime,
list[GraphNodeEventBase],
object | None,
int,
dict[str, VariableUnion],
LLMUsage,
]
],
int,
] = {}
for index, item in enumerate(iterator_list_value):
@ -264,7 +287,14 @@ class IterationNode(Node):
index = future_to_index[future]
try:
result = future.result()
iter_start_at, events, output_value, tokens_used, conversation_snapshot = result
(
iter_start_at,
events,
output_value,
tokens_used,
conversation_snapshot,
iteration_usage,
) = result
# Update outputs at the correct index
outputs[index] = output_value
@ -276,6 +306,8 @@ class IterationNode(Node):
self.graph_runtime_state.total_tokens += tokens_used
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage)
# Sync conversation variables after iteration completion
self._sync_conversation_variables_from_snapshot(conversation_snapshot)
@ -303,7 +335,7 @@ class IterationNode(Node):
item: object,
flask_app: Flask,
context_vars: contextvars.Context,
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]:
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion], LLMUsage]:
"""Execute a single iteration in parallel mode and return results."""
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
@ -332,6 +364,7 @@ class IterationNode(Node):
output_value,
graph_engine.graph_runtime_state.total_tokens,
conversation_snapshot,
graph_engine.graph_runtime_state.llm_usage,
)
def _handle_iteration_success(
@ -341,6 +374,8 @@ class IterationNode(Node):
outputs: list[object],
iterator_list_value: Sequence[object],
iter_run_map: dict[str, float],
*,
usage: LLMUsage,
) -> Generator[NodeEventBase, None, None]:
# Flatten the list of lists if all outputs are lists
flattened_outputs = self._flatten_outputs_if_needed(outputs)
@ -351,7 +386,9 @@ class IterationNode(Node):
outputs={"output": flattened_outputs},
steps=len(iterator_list_value),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
},
)
@ -362,8 +399,11 @@ class IterationNode(Node):
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": flattened_outputs},
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage,
)
)
@ -400,6 +440,8 @@ class IterationNode(Node):
outputs: list[object],
iterator_list_value: Sequence[object],
iter_run_map: dict[str, float],
*,
usage: LLMUsage,
error: IterationNodeError,
) -> Generator[NodeEventBase, None, None]:
# Flatten the list of lists if all outputs are lists (even in failure case)
@ -411,7 +453,9 @@ class IterationNode(Node):
outputs={"output": flattened_outputs},
steps=len(iterator_list_value),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
},
error=str(error),
@ -420,6 +464,12 @@ class IterationNode(Node):
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(error),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage,
)
)
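
A note on the usage_accumulator = [LLMUsage.empty_usage()] shape above: the single-element list lets the _execute_*_iterations helpers swap the merged snapshot in place so the outer _run (including its except branch) observes it, whereas rebinding a plain parameter inside a helper would be invisible to the caller. A reduced, framework-free sketch of that pattern:

def add_chunk(acc: list[int], value: int) -> None:
    # Updating the single cell is visible to the caller;
    # `acc = [acc[0] + value]` would only rebind the local name.
    acc[0] = acc[0] + value


totals = [0]
for chunk in (3, 4, 5):
    add_chunk(totals, chunk)
assert totals[0] == 12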

View File

@ -15,14 +15,11 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import (
PromptMessageRole,
)
from core.model_runtime.entities.model_entities import (
ModelFeature,
ModelType,
)
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.metadata_entities import Condition, MetadataCondition
@ -33,8 +30,14 @@ from core.variables import (
)
from core.variables.segments import ArrayObjectSegment
from core.workflow.entities import GraphInitParams
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import (
ErrorStrategy,
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
@ -80,7 +83,7 @@ default_retrieval_model = {
}
class KnowledgeRetrievalNode(Node):
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node):
node_type = NodeType.KNOWLEDGE_RETRIEVAL
_node_data: KnowledgeRetrievalNodeData
@ -141,7 +144,7 @@ class KnowledgeRetrievalNode(Node):
def version(cls):
return "1"
def _run(self) -> NodeRunResult: # type: ignore
def _run(self) -> NodeRunResult:
# extract variables
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
if not isinstance(variable, StringSegment):
@ -182,14 +185,21 @@ class KnowledgeRetrievalNode(Node):
)
# retrieve knowledge
usage = LLMUsage.empty_usage()
try:
results = self._fetch_dataset_retriever(node_data=self._node_data, query=query)
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, query=query)
outputs = {"result": ArrayObjectSegment(value=results)}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
process_data={},
process_data={"usage": jsonable_encoder(usage)},
outputs=outputs, # type: ignore
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage,
)
except KnowledgeRetrievalNodeError as e:
@ -199,6 +209,7 @@ class KnowledgeRetrievalNode(Node):
inputs=variables,
error=str(e),
error_type=type(e).__name__,
llm_usage=usage,
)
# Temporary handle all exceptions from DatasetRetrieval class here.
except Exception as e:
@ -207,11 +218,15 @@ class KnowledgeRetrievalNode(Node):
inputs=variables,
error=str(e),
error_type=type(e).__name__,
llm_usage=usage,
)
finally:
db.session.close()
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]:
def _fetch_dataset_retriever(
self, node_data: KnowledgeRetrievalNodeData, query: str
) -> tuple[list[dict[str, Any]], LLMUsage]:
usage = LLMUsage.empty_usage()
available_datasets = []
dataset_ids = node_data.dataset_ids
@ -245,9 +260,10 @@ class KnowledgeRetrievalNode(Node):
if not dataset:
continue
available_datasets.append(dataset)
metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition(
metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition(
[dataset.id for dataset in available_datasets], query, node_data
)
usage = self._merge_usage(usage, metadata_usage)
all_documents = []
dataset_retrieval = DatasetRetrieval()
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
@ -330,6 +346,8 @@ class KnowledgeRetrievalNode(Node):
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
)
usage = self._merge_usage(usage, dataset_retrieval.llm_usage)
dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"]
retrieval_resource_list = []
@ -406,11 +424,12 @@ class KnowledgeRetrievalNode(Node):
)
for position, item in enumerate(retrieval_resource_list, start=1):
item["metadata"]["position"] = position
return retrieval_resource_list
return retrieval_resource_list, usage
def _get_metadata_filter_condition(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]:
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]:
usage = LLMUsage.empty_usage()
document_query = db.session.query(Document).where(
Document.dataset_id.in_(dataset_ids),
Document.indexing_status == "completed",
@ -420,9 +439,12 @@ class KnowledgeRetrievalNode(Node):
filters: list[Any] = []
metadata_condition = None
if node_data.metadata_filtering_mode == "disabled":
return None, None
return None, None, usage
elif node_data.metadata_filtering_mode == "automatic":
automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data)
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
dataset_ids, query, node_data
)
usage = self._merge_usage(usage, automatic_usage)
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
@ -443,7 +465,7 @@ class KnowledgeRetrievalNode(Node):
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator
if node_data.metadata_filtering_conditions
else "or", # type: ignore
else "or",
conditions=conditions,
)
elif node_data.metadata_filtering_mode == "manual":
@ -457,10 +479,10 @@ class KnowledgeRetrievalNode(Node):
expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value
).value[0]
if expected_value.value_type in {"number", "integer", "float"}: # type: ignore
expected_value = expected_value.value # type: ignore
elif expected_value.value_type == "string": # type: ignore
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore
if expected_value.value_type in {"number", "integer", "float"}:
expected_value = expected_value.value
elif expected_value.value_type == "string":
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
else:
raise ValueError("Invalid expected metadata value type")
conditions.append(
@ -487,7 +509,7 @@ class KnowledgeRetrievalNode(Node):
if (
node_data.metadata_filtering_conditions
and node_data.metadata_filtering_conditions.logical_operator == "and"
): # type: ignore
):
document_query = document_query.where(and_(*filters))
else:
document_query = document_query.where(or_(*filters))
@ -496,11 +518,12 @@ class KnowledgeRetrievalNode(Node):
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
for document in documents:
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
return metadata_filter_document_ids, metadata_condition
return metadata_filter_document_ids, metadata_condition, usage
def _automatic_metadata_filter_func(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> list[dict[str, Any]]:
) -> tuple[list[dict[str, Any]], LLMUsage]:
usage = LLMUsage.empty_usage()
# get all metadata fields
stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
metadata_fields = db.session.scalars(stmt).all()
@ -548,6 +571,7 @@ class KnowledgeRetrievalNode(Node):
for event in generator:
if isinstance(event, ModelInvokeCompletedEvent):
result_text = event.text
usage = self._merge_usage(usage, event.usage)
break
result_text_json = parse_and_check_json_markdown(result_text, [])
@ -564,8 +588,8 @@ class KnowledgeRetrievalNode(Node):
}
)
except Exception:
return []
return automatic_metadata_filters
return [], usage
return automatic_metadata_filters, usage
def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]

View File

@ -441,10 +441,14 @@ class LLMNode(Node):
usage = LLMUsage.empty_usage()
finish_reason = None
full_text_buffer = io.StringIO()
collected_structured_output = None # Collect structured_output from streaming chunks
# Consume the invoke result and handle generator exception
try:
for result in invoke_result:
if isinstance(result, LLMResultChunkWithStructuredOutput):
# Collect structured_output from the chunk
if result.structured_output is not None:
collected_structured_output = dict(result.structured_output)
yield result
if isinstance(result, LLMResultChunk):
contents = result.delta.message.content
@ -492,6 +496,8 @@ class LLMNode(Node):
finish_reason=finish_reason,
# Reasoning content for workflow variables and downstream nodes
reasoning_content=reasoning_content,
# Pass structured output if collected from streaming chunks
structured_output=collected_structured_output,
)
@staticmethod

View File

@ -5,6 +5,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, cast
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import Segment, SegmentType
from core.workflow.enums import (
ErrorStrategy,
@ -27,6 +28,7 @@ from core.workflow.node_events import (
NodeRunResult,
StreamCompletedEvent,
)
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
@ -40,7 +42,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class LoopNode(Node):
class LoopNode(LLMUsageTrackingMixin, Node):
"""
Loop Node.
"""
@ -108,7 +110,7 @@ class LoopNode(Node):
raise ValueError(f"Invalid value for loop variable {loop_variable.label}")
variable_selector = [self._node_id, loop_variable.label]
variable = segment_to_variable(segment=processed_segment, selector=variable_selector)
self.graph_runtime_state.variable_pool.add(variable_selector, variable)
self.graph_runtime_state.variable_pool.add(variable_selector, variable.value)
loop_variable_selectors[loop_variable.label] = variable_selector
inputs[loop_variable.label] = processed_segment.value
@ -117,6 +119,7 @@ class LoopNode(Node):
loop_duration_map: dict[str, float] = {}
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
loop_usage = LLMUsage.empty_usage()
# Start Loop event
yield LoopStartedEvent(
@ -163,6 +166,9 @@ class LoopNode(Node):
# Update the total tokens from this iteration
cost_tokens += graph_engine.graph_runtime_state.total_tokens
# Accumulate usage from the sub-graph execution
loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage)
# Collect loop variable values after iteration
single_loop_variable = {}
for key, selector in loop_variable_selectors.items():
@ -189,6 +195,7 @@ class LoopNode(Node):
)
self.graph_runtime_state.total_tokens += cost_tokens
self._accumulate_usage(loop_usage)
# Loop completed successfully
yield LoopSucceededEvent(
start_at=start_at,
@ -196,7 +203,9 @@ class LoopNode(Node):
outputs=self._node_data.outputs,
steps=loop_count,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: cost_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
"completed_reason": "loop_break" if reach_break_condition else "loop_completed",
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
@ -207,22 +216,28 @@ class LoopNode(Node):
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
outputs=self._node_data.outputs,
inputs=inputs,
llm_usage=loop_usage,
)
)
except Exception as e:
self._accumulate_usage(loop_usage)
yield LoopFailedEvent(
start_at=start_at,
inputs=inputs,
steps=loop_count,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
"completed_reason": "error",
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
@ -235,10 +250,13 @@ class LoopNode(Node):
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
llm_usage=loop_usage,
)
)

View File

@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, final
from typing_extensions import override
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.enums import NodeType
from core.workflow.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from libs.typing import is_str, is_str_dict
@ -82,8 +82,4 @@ class DifyNodeFactory(NodeFactory):
raise ValueError(f"Node {node_id} missing data information")
node_instance.init_node_data(node_data)
# If node has fail branch, change execution type to branch
if node_instance.error_strategy == ErrorStrategy.FAIL_BRANCH:
node_instance.execution_type = NodeExecutionType.BRANCH
return node_instance

View File

@ -747,7 +747,7 @@ class ParameterExtractorNode(Node):
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
role=PromptMessageRole.SYSTEM,
text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str).replace("{{instructions}}", instruction),
text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction),
)
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
return [system_prompt_messages, user_prompt_message]

View File

@ -135,7 +135,7 @@ Here are the chat histories between human and assistant, inside <histories></his
### Instructions:
Some extra information is provided below; you should always follow the instructions as closely as you can.
<instructions>
{{instructions}}
{instructions}
</instructions>
"""

View File

@ -6,10 +6,13 @@ from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file import File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.tools.workflow_as_tool.tool import WorkflowTool
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.enums import (
@ -136,13 +139,14 @@ class ToolNode(Node):
try:
# convert tool messages
yield from self._transform_message(
_ = yield from self._transform_message(
messages=message_stream,
tool_info=tool_info,
parameters_for_log=parameters_for_log,
user_id=self.user_id,
tenant_id=self.tenant_id,
node_id=self._node_id,
tool_runtime=tool_runtime,
)
except ToolInvokeError as e:
yield StreamCompletedEvent(
@ -236,7 +240,8 @@ class ToolNode(Node):
user_id: str,
tenant_id: str,
node_id: str,
) -> Generator:
tool_runtime: Tool,
) -> Generator[NodeEventBase, None, LLMUsage]:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
@ -424,17 +429,34 @@ class ToolNode(Node):
is_final=True,
)
usage = self._extract_tool_usage(tool_runtime)
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
}
if usage.total_tokens > 0:
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
metadata={
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
},
metadata=metadata,
inputs=parameters_for_log,
llm_usage=usage,
)
)
return usage
@staticmethod
def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage:
if isinstance(tool_runtime, WorkflowTool):
return tool_runtime.latest_usage
return LLMUsage.empty_usage()
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,

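The `_ = yield from self._transform_message(...)` call above works because yield from evaluates to the inner generator's return value; that is how _transform_message can stream events and still hand the extracted LLMUsage back to its caller (the assignment to `_` simply discards it at this call site). A reduced sketch of the mechanism, independent of the node classes:

from collections.abc import Generator


def stream_with_summary() -> Generator[str, None, int]:
    yield "event-1"
    yield "event-2"
    return 42  # surfaced as the value of `yield from`


def outer() -> Generator[str, None, None]:
    total = yield from stream_with_summary()
    yield f"inner generator returned {total}"


assert list(outer()) == ["event-1", "event-2", "inner generator returned 42"]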
View File

@ -260,7 +260,7 @@ class VariablePool(BaseModel):
# This ensures that we can keep the id of the system variables intact.
if self._has(selector):
continue
self.add(selector, value) # type: ignore
self.add(selector, value)
@classmethod
def empty(cls) -> "VariablePool":

View File

@ -32,7 +32,8 @@ if [[ "${MODE}" == "worker" ]]; then
exec celery -A celery_entrypoint.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \
--max-tasks-per-child ${MAX_TASKS_PER_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \
-Q ${CELERY_QUEUES:-dataset,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation}
-Q ${CELERY_QUEUES:-dataset,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation} \
--prefetch-multiplier=1
elif [[ "${MODE}" == "beat" ]]; then
exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}

View File

@ -6,8 +6,8 @@ from tasks.clean_dataset_task import clean_dataset_task
@dataset_was_deleted.connect
def handle(sender: Dataset, **kwargs):
dataset = sender
assert dataset.doc_form
assert dataset.indexing_technique
if not dataset.doc_form or not dataset.indexing_technique:
return
clean_dataset_task.delay(
dataset.id,
dataset.tenant_id,

View File

@ -8,6 +8,6 @@ def handle(sender, **kwargs):
dataset_id = kwargs.get("dataset_id")
doc_form = kwargs.get("doc_form")
file_id = kwargs.get("file_id")
assert dataset_id is not None
assert doc_form is not None
if not dataset_id or not doc_form:
return
clean_document_task.delay(document_id, dataset_id, doc_form, file_id)

View File

@ -1,7 +1,12 @@
from configs import dify_config
from constants import HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN
from constants import HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN, HEADER_NAME_PASSPORT
from dify_app import DifyApp
BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEADER_NAME_PASSPORT)
SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization")
AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN)
FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
def init_app(app: DifyApp):
# register blueprint routers
@ -17,7 +22,7 @@ def init_app(app: DifyApp):
CORS(
service_api_bp,
allow_headers=["Content-Type", "Authorization", HEADER_NAME_APP_CODE],
allow_headers=list(SERVICE_API_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
)
app.register_blueprint(service_api_bp)
@ -26,7 +31,7 @@ def init_app(app: DifyApp):
web_bp,
resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
allow_headers=["Content-Type", "Authorization", HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN],
allow_headers=list(AUTHENTICATED_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
)
@ -36,7 +41,7 @@ def init_app(app: DifyApp):
console_app_bp,
resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
allow_headers=["Content-Type", "Authorization", HEADER_NAME_CSRF_TOKEN],
allow_headers=list(AUTHENTICATED_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
)
@ -44,7 +49,7 @@ def init_app(app: DifyApp):
CORS(
files_bp,
allow_headers=["Content-Type", HEADER_NAME_CSRF_TOKEN],
allow_headers=list(FILES_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
)
app.register_blueprint(files_bp)

View File

@ -7,7 +7,7 @@ def is_enabled() -> bool:
def init_app(app: DifyApp):
from flask_compress import Compress # type: ignore
from flask_compress import Compress
compress = Compress()
compress.init_app(app)

View File

@ -1,6 +1,6 @@
import json
import flask_login # type: ignore
import flask_login
from flask import Response, request
from flask_login import user_loaded_from_request, user_logged_in
from werkzeug.exceptions import NotFound, Unauthorized

View File

@ -2,7 +2,7 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
import flask_migrate # type: ignore
import flask_migrate
from extensions.ext_database import db

View File

@ -103,7 +103,7 @@ def init_app(app: DifyApp):
def shutdown_tracer():
provider = trace.get_tracer_provider()
if hasattr(provider, "force_flush"):
provider.force_flush() # ty: ignore [call-non-callable]
provider.force_flush()
class ExceptionLoggingHandler(logging.Handler):
"""Custom logging handler that creates spans for logging.exception() calls"""

View File

@ -6,4 +6,4 @@ def init_app(app: DifyApp):
if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED:
from werkzeug.middleware.proxy_fix import ProxyFix
app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1) # type: ignore
app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1) # type: ignore[method-assign]

View File

@ -5,7 +5,7 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
if dify_config.SENTRY_DSN:
import sentry_sdk
from langfuse import parse_error # type: ignore
from langfuse import parse_error
from sentry_sdk.integrations.celery import CeleryIntegration
from sentry_sdk.integrations.flask import FlaskIntegration
from werkzeug.exceptions import HTTPException

View File

@ -1,7 +1,7 @@
import posixpath
from collections.abc import Generator
import oss2 as aliyun_s3 # type: ignore
import oss2 as aliyun_s3
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@ -2,9 +2,9 @@ import base64
import hashlib
from collections.abc import Generator
from baidubce.auth.bce_credentials import BceCredentials # type: ignore
from baidubce.bce_client_configuration import BceClientConfiguration # type: ignore
from baidubce.services.bos.bos_client import BosClient # type: ignore
from baidubce.auth.bce_credentials import BceCredentials
from baidubce.bce_client_configuration import BceClientConfiguration
from baidubce.services.bos.bos_client import BosClient
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@ -11,7 +11,7 @@ from collections.abc import Generator
from io import BytesIO
from pathlib import Path
import clickzetta # type: ignore[import]
import clickzetta
from pydantic import BaseModel, model_validator
from extensions.storage.base_storage import BaseStorage

View File

@ -34,7 +34,7 @@ class VolumePermissionManager:
# Support two initialization methods: connection object or configuration dictionary
if isinstance(connection_or_config, dict):
# Create connection from configuration dictionary
import clickzetta # type: ignore[import-untyped]
import clickzetta
config = connection_or_config
self._connection = clickzetta.connect(

View File

@ -3,7 +3,7 @@ import io
import json
from collections.abc import Generator
from google.cloud import storage as google_cloud_storage # type: ignore
from google.cloud import storage as google_cloud_storage
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@ -1,6 +1,6 @@
from collections.abc import Generator
from obs import ObsClient # type: ignore
from obs import ObsClient
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@ -1,7 +1,7 @@
from collections.abc import Generator
import boto3 # type: ignore
from botocore.exceptions import ClientError # type: ignore
import boto3
from botocore.exceptions import ClientError
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@ -1,6 +1,6 @@
from collections.abc import Generator
from qcloud_cos import CosConfig, CosS3Client # type: ignore
from qcloud_cos import CosConfig, CosS3Client
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@ -1,6 +1,6 @@
from collections.abc import Generator
import tos # type: ignore
import tos
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@ -146,6 +146,6 @@ class ExternalApi(Api):
kwargs["doc"] = dify_config.SWAGGER_UI_PATH if dify_config.SWAGGER_UI_ENABLED else False
# call construction and init_app separately to ensure the configs in kwargs take effect
super().__init__(app=None, *args, **kwargs) # type: ignore
super().__init__(app=None, *args, **kwargs)
self.init_app(app, **kwargs)
register_external_error_handlers(self)

View File

@ -23,7 +23,7 @@ from hashlib import sha1
import Crypto.Hash.SHA1
import Crypto.Util.number
import gmpy2 # type: ignore
import gmpy2
from Crypto import Random
from Crypto.Signature.pss import MGF1
from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes
@ -136,7 +136,7 @@ class PKCS1OAepCipher:
# Step 3a (OS2IP)
em_int = bytes_to_long(em)
# Step 3b (RSAEP)
m_int = gmpy2.powmod(em_int, self._key.e, self._key.n) # ty: ignore [unresolved-attribute]
m_int = gmpy2.powmod(em_int, self._key.e, self._key.n)
# Step 3c (I2OSP)
c = long_to_bytes(m_int, k)
return c
@ -169,7 +169,7 @@ class PKCS1OAepCipher:
ct_int = bytes_to_long(ciphertext)
# Step 2b (RSADP)
# m_int = self._key._decrypt(ct_int)
m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n) # ty: ignore [unresolved-attribute]
m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n)
# Complete step 2c (I2OSP)
em = long_to_bytes(m_int, k)
# Step 3a
@ -191,12 +191,12 @@ class PKCS1OAepCipher:
# Step 3g
one_pos = hLen + db[hLen:].find(b"\x01")
lHash1 = db[:hLen]
invalid = bord(y) | int(one_pos < hLen) # type: ignore
invalid = bord(y) | int(one_pos < hLen) # type: ignore[arg-type]
hash_compare = strxor(lHash1, lHash)
for x in hash_compare:
invalid |= bord(x) # type: ignore
invalid |= bord(x) # type: ignore[arg-type]
for x in db[hLen:one_pos]:
invalid |= bord(x) # type: ignore
invalid |= bord(x) # type: ignore[arg-type]
if invalid != 0:
raise ValueError("Incorrect decryption.")
# Step 4

View File

@ -81,6 +81,8 @@ class AvatarUrlField(fields.Raw):
from models import Account
if isinstance(obj, Account) and obj.avatar is not None:
if obj.avatar.startswith(("http://", "https://")):
return obj.avatar
return file_helpers.get_signed_file_url(obj.avatar)
return None

View File

@ -3,7 +3,7 @@ from functools import wraps
from typing import Any
from flask import current_app, g, has_request_context, request
from flask_login.config import EXEMPT_METHODS # type: ignore
from flask_login.config import EXEMPT_METHODS
from werkzeug.local import LocalProxy
from configs import dify_config
@ -87,7 +87,7 @@ def _get_user() -> EndUser | Account | None:
if "_login_user" not in g:
current_app.login_manager._load_user() # type: ignore
return g._login_user # type: ignore
return g._login_user
return None

View File

@ -1,8 +1,8 @@
import logging
import sendgrid # type: ignore
import sendgrid
from python_http_client.exceptions import ForbiddenError, UnauthorizedError
from sendgrid.helpers.mail import Content, Email, Mail, To # type: ignore
from sendgrid.helpers.mail import Content, Email, Mail, To
logger = logging.getLogger(__name__)

View File

@ -12,6 +12,7 @@ from constants import (
COOKIE_NAME_CSRF_TOKEN,
COOKIE_NAME_PASSPORT,
COOKIE_NAME_REFRESH_TOKEN,
COOKIE_NAME_WEBAPP_ACCESS_TOKEN,
HEADER_NAME_CSRF_TOKEN,
HEADER_NAME_PASSPORT,
)
@ -81,6 +82,14 @@ def extract_access_token(request: Request) -> str | None:
return _try_extract_from_cookie(request) or _try_extract_from_header(request)
def extract_webapp_access_token(request: Request) -> str | None:
"""
Try to extract webapp access token from cookie, then header.
"""
return request.cookies.get(_real_cookie_name(COOKIE_NAME_WEBAPP_ACCESS_TOKEN)) or _try_extract_from_header(request)
def extract_webapp_passport(app_code: str, request: Request) -> str | None:
"""
Try to extract app token from header or params.
@ -155,6 +164,10 @@ def clear_access_token_from_cookie(response: Response, samesite: str = "Lax"):
_clear_cookie(response, COOKIE_NAME_ACCESS_TOKEN, samesite)
def clear_webapp_access_token_from_cookie(response: Response, samesite: str = "Lax"):
_clear_cookie(response, COOKIE_NAME_WEBAPP_ACCESS_TOKEN, samesite)
def clear_refresh_token_from_cookie(response: Response):
_clear_cookie(response, COOKIE_NAME_REFRESH_TOKEN)

View File

@ -22,55 +22,6 @@ def upgrade():
batch_op.add_column(sa.Column('app_mode', sa.String(length=255), nullable=True))
batch_op.create_index('message_app_mode_idx', ['app_mode'], unique=False)
conn = op.get_bind()
# Strategy: Update in batches to minimize lock time
# For large tables (millions of rows), this prevents long-running transactions
batch_size = 10000
print("Starting backfill of app_mode from conversations...")
# Use a more efficient UPDATE with JOIN
# This query updates messages.app_mode from conversations.mode
# Using string formatting for LIMIT since it's a constant
update_query = f"""
UPDATE messages m
SET app_mode = c.mode
FROM conversations c
WHERE m.conversation_id = c.id
AND m.app_mode IS NULL
AND m.id IN (
SELECT id FROM messages
WHERE app_mode IS NULL
LIMIT {batch_size}
)
"""
# Execute batched updates
total_updated = 0
iteration = 0
while True:
iteration += 1
result = conn.execute(sa.text(update_query))
# Check if result is None or has no rowcount
if result is None:
print("Warning: Query returned None, stopping backfill")
break
rows_updated = result.rowcount if hasattr(result, 'rowcount') else 0
total_updated += rows_updated
if rows_updated == 0:
break
print(f"Iteration {iteration}: Updated {rows_updated} messages (total: {total_updated})")
# For very large tables, add a small delay to reduce load
# Uncomment if needed: import time; time.sleep(0.1)
print(f"Backfill completed. Total messages updated: {total_updated}")
# ### end Alembic commands ###

View File

@ -0,0 +1,36 @@
"""remove-builtin-template-user
Revision ID: ae662b25d9bc
Revises: d98acf217d43
Create Date: 2025-10-21 14:30:28.566192
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'ae662b25d9bc'
down_revision = 'd98acf217d43'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
batch_op.drop_column('updated_by')
batch_op.drop_column('created_by')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
batch_op.add_column(sa.Column('created_by', sa.UUID(), autoincrement=False, nullable=False))
batch_op.add_column(sa.Column('updated_by', sa.UUID(), autoincrement=False, nullable=True))
# ### end Alembic commands ###

View File

@ -5,7 +5,7 @@ from datetime import datetime
from typing import Any, Optional
import sqlalchemy as sa
from flask_login import UserMixin # type: ignore[import-untyped]
from flask_login import UserMixin
from sqlalchemy import DateTime, String, func, select
from sqlalchemy.orm import Mapped, Session, mapped_column
from typing_extensions import deprecated

View File

@ -1239,15 +1239,6 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
language = mapped_column(db.String(255), nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
created_by = mapped_column(StringUUID, nullable=False)
updated_by = mapped_column(StringUUID, nullable=True)
@property
def created_user_name(self):
account = db.session.query(Account).where(Account.id == self.created_by).first()
if account:
return account.name
return ""
class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]

View File

@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast
import sqlalchemy as sa
from flask import request
from flask_login import UserMixin # type: ignore[import-untyped]
from flask_login import UserMixin
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy.orm import Mapped, Session, mapped_column

View File

@ -219,7 +219,7 @@ class WorkflowToolProvider(TypeBase):
sa.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
# name of the workflow provider
name: Mapped[str] = mapped_column(String(255), nullable=False)
# label of the workflow provider

View File

@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.9.1"
version = "1.9.2"
requires-python = ">=3.11,<3.13"
dependencies = [

View File

@ -16,7 +16,25 @@
"opentelemetry.instrumentation.requests",
"opentelemetry.instrumentation.sqlalchemy",
"opentelemetry.instrumentation.redis",
"opentelemetry.instrumentation.httpx"
"langfuse",
"cloudscraper",
"readabilipy",
"pypandoc",
"pypdfium2",
"webvtt",
"flask_compress",
"oss2",
"baidubce.auth.bce_credentials",
"baidubce.bce_client_configuration",
"baidubce.services.bos.bos_client",
"clickzetta",
"google.cloud",
"obs",
"qcloud_cos",
"tos",
"gmpy2",
"sendgrid",
"sendgrid.helpers.mail"
],
"reportUnknownMemberType": "hint",
"reportUnknownParameterType": "hint",
@ -28,7 +46,7 @@
"reportUnnecessaryComparison": "hint",
"reportUnnecessaryIsInstance": "hint",
"reportUntypedFunctionDecorator": "hint",
"reportUnnecessaryTypeIgnoreComment": "hint",
"reportAttributeAccessIssue": "hint",
"pythonVersion": "3.11",
"pythonPlatform": "All"

View File

@ -48,7 +48,7 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):
try:
repository_class = import_string(class_path)
return repository_class(session_maker=session_maker) # type: ignore[no-any-return]
return repository_class(session_maker=session_maker)
except (ImportError, Exception) as e:
raise RepositoryImportError(
f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}"
@ -77,6 +77,6 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):
try:
repository_class = import_string(class_path)
return repository_class(session_maker=session_maker) # type: ignore[no-any-return]
return repository_class(session_maker=session_maker)
except (ImportError, Exception) as e:
raise RepositoryImportError(f"Failed to create APIWorkflowRunRepository from '{class_path}': {e}") from e

View File

@ -13,7 +13,7 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
from configs import dify_config
from constants.languages import language_timezone_mapping, languages
from constants.languages import get_valid_language, language_timezone_mapping
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client, redis_fallback
@ -1259,7 +1259,7 @@ class RegisterService:
return f"member_invite:token:{token}"
@classmethod
def setup(cls, email: str, name: str, password: str, ip_address: str):
def setup(cls, email: str, name: str, password: str, ip_address: str, language: str):
"""
Setup dify
@ -1269,11 +1269,10 @@ class RegisterService:
:param ip_address: ip address
"""
try:
# Register
account = AccountService.create_account(
email=email,
name=name,
interface_language=languages[0],
interface_language=get_valid_language(language),
password=password,
is_setup=True,
)
@ -1315,7 +1314,7 @@ class RegisterService:
account = AccountService.create_account(
email=email,
name=name,
interface_language=language or languages[0],
interface_language=get_valid_language(language),
password=password,
is_setup=is_setup,
)

View File

@ -7,7 +7,7 @@ from enum import StrEnum
from urllib.parse import urlparse
from uuid import uuid4
import yaml # type: ignore
import yaml
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from packaging import version
@ -563,7 +563,7 @@ class AppDslService:
else:
cls._append_model_config_export_data(export_data, app_model)
return yaml.dump(export_data, allow_unicode=True) # type: ignore
return yaml.dump(export_data, allow_unicode=True)
@classmethod
def _append_workflow_export_data(

Some files were not shown because too many files have changed in this diff.