Merge branch 'feat/mcp-06-18' into deploy/dev

Novice 2025-10-23 17:20:21 +08:00
commit b48a7c7cda
No known key found for this signature in database
GPG Key ID: EE3F68E3105DAAAB
466 changed files with 31391 additions and 7797 deletions

View File

@ -437,6 +437,9 @@ CODE_EXECUTION_SSL_VERIFY=True
CODE_EXECUTION_POOL_MAX_CONNECTIONS=100
CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS=20
CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY=5.0
CODE_EXECUTION_CONNECT_TIMEOUT=10
CODE_EXECUTION_READ_TIMEOUT=60
CODE_EXECUTION_WRITE_TIMEOUT=10
CODE_MAX_NUMBER=9223372036854775807
CODE_MIN_NUMBER=-9223372036854775808
CODE_MAX_STRING_LENGTH=400000
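The three new CODE_EXECUTION_*_TIMEOUT values complement the existing pool settings. Below is a minimal sketch of how these environment variables could be mapped onto an httpx client; the mapping is an assumption for illustration, not the actual Dify wiring.

# Hypothetical sketch: mapping the CODE_EXECUTION_* settings onto an httpx client.
# The variable names come from the config above; the client wiring itself is an assumption.
import os
import httpx

limits = httpx.Limits(
    max_connections=int(os.getenv("CODE_EXECUTION_POOL_MAX_CONNECTIONS", "100")),
    max_keepalive_connections=int(os.getenv("CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS", "20")),
    keepalive_expiry=float(os.getenv("CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY", "5.0")),
)
timeout = httpx.Timeout(
    connect=float(os.getenv("CODE_EXECUTION_CONNECT_TIMEOUT", "10")),
    read=float(os.getenv("CODE_EXECUTION_READ_TIMEOUT", "60")),
    write=float(os.getenv("CODE_EXECUTION_WRITE_TIMEOUT", "10")),
    pool=None,
)
client = httpx.Client(
    limits=limits,
    timeout=timeout,
    verify=os.getenv("CODE_EXECUTION_SSL_VERIFY", "True") == "True",
)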

View File

@ -548,7 +548,7 @@ class UpdateConfig(BaseSettings):
class WorkflowVariableTruncationConfig(BaseSettings):
WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE: PositiveInt = Field(
# 100KB
# 1000 KiB
1024_000,
description="Maximum size for variable to trigger final truncation.",
)
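For reference, the constant and the corrected comment now agree: 1024_000 bytes is exactly 1000 KiB (1000 * 1024), which the old "100KB" note understated by roughly a factor of ten. A one-line check:

# Sanity check of the comment above: 1_024_000 bytes == 1000 KiB.
assert 1_024_000 == 1000 * 1024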

View File

@ -145,7 +145,7 @@ class DatabaseConfig(BaseSettings):
default="postgresql",
)
@computed_field # type: ignore[misc]
@computed_field # type: ignore[prop-decorator]
@property
def SQLALCHEMY_DATABASE_URI(self) -> str:
db_extras = (
@ -198,7 +198,7 @@ class DatabaseConfig(BaseSettings):
default=os.cpu_count() or 1,
)
@computed_field # type: ignore[misc]
@computed_field # type: ignore[prop-decorator]
@property
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
# Parse DB_EXTRAS for 'options'
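Both hunks only change the mypy suppression: newer mypy versions report decorators stacked on top of @property under the specific prop-decorator error code instead of the generic misc code. A minimal, hypothetical pydantic-settings example of the same pattern; the field names are illustrative, not the real DatabaseConfig.

# Illustrative only: @computed_field stacked on @property, suppressed with the newer mypy code.
from pydantic import Field, computed_field
from pydantic_settings import BaseSettings

class ExampleDatabaseConfig(BaseSettings):
    DB_HOST: str = Field(default="localhost")
    DB_PORT: int = Field(default=5432)

    @computed_field  # type: ignore[prop-decorator]
    @property
    def DATABASE_URI(self) -> str:
        return f"postgresql://{self.DB_HOST}:{self.DB_PORT}/example"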

View File

@ -56,11 +56,15 @@ else:
}
DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)
# console
COOKIE_NAME_ACCESS_TOKEN = "access_token"
COOKIE_NAME_REFRESH_TOKEN = "refresh_token"
COOKIE_NAME_PASSPORT = "passport"
COOKIE_NAME_CSRF_TOKEN = "csrf_token"
# webapp
COOKIE_NAME_WEBAPP_ACCESS_TOKEN = "webapp_access_token"
COOKIE_NAME_PASSPORT = "passport"
HEADER_NAME_CSRF_TOKEN = "X-CSRF-Token"
HEADER_NAME_APP_CODE = "X-App-Code"
HEADER_NAME_PASSPORT = "X-App-Passport"

View File

@ -31,3 +31,9 @@ def supported_language(lang):
error = f"{lang} is not a valid language."
raise ValueError(error)
def get_valid_language(lang: str | None) -> str:
if lang and lang in languages:
return lang
return languages[0]
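A quick usage sketch of the new helper. The concrete contents of `languages` are an assumption here; in the real module it is the list of supported interface languages with the default entry first.

# Hypothetical values; `languages` is the module-level list the helper reads.
languages = ["en-US", "zh-Hans", "ja-JP"]

def get_valid_language(lang: str | None) -> str:
    # mirrors the helper above so the sketch runs standalone
    if lang and lang in languages:
        return lang
    return languages[0]

assert get_valid_language("ja-JP") == "ja-JP"  # supported value passes through
assert get_valid_language("xx-XX") == "en-US"  # unknown value falls back to the default
assert get_valid_language(None) == "en-US"     # missing value falls back as well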

View File

@ -24,7 +24,7 @@ except ImportError:
)
else:
warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2)
magic = None # type: ignore
magic = None # type: ignore[assignment]
from pydantic import BaseModel

View File

@ -4,7 +4,7 @@ from flask_restx import Resource, reqparse
import services
from configs import dify_config
from constants.languages import languages
from constants.languages import get_valid_language
from controllers.console import console_ns
from controllers.console.auth.error import (
AuthenticationFailedError,
@ -207,10 +207,12 @@ class EmailCodeLoginApi(Resource):
.add_argument("email", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
user_email = args["email"]
language = args["language"]
token_data = AccountService.get_email_code_login_data(args["token"])
if token_data is None:
@ -244,7 +246,9 @@ class EmailCodeLoginApi(Resource):
if account is None:
try:
account = AccountService.create_account_and_tenant(
email=user_email, name=user_email, interface_language=languages[0]
email=user_email,
name=user_email,
interface_language=get_valid_language(language),
)
except WorkSpaceNotAllowedCreateError:
raise NotAllowedCreateWorkspace()

View File

@ -22,7 +22,7 @@ from core.errors.error import (
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from libs import helper
from libs.login import current_user as current_user_
from libs.login import current_account_with_tenant
from models.model import AppMode, InstalledApp
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
@ -31,8 +31,6 @@ from .. import console_ns
logger = logging.getLogger(__name__)
current_user = current_user_._get_current_object() # type: ignore
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
class InstalledAppWorkflowRunApi(InstalledAppResource):
@ -40,6 +38,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
"""
Run workflow
"""
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if not app_model:
raise NotWorkflowAppError()
@ -53,7 +52,6 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
.add_argument("files", type=list, required=False, location="json")
)
args = parser.parse_args()
assert current_user is not None
try:
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
@ -89,7 +87,6 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
assert current_user is not None
# Stop via both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)

View File

@ -74,12 +74,17 @@ class SetupApi(Resource):
.add_argument("email", type=email, required=True, location="json")
.add_argument("name", type=StrLen(30), required=True, location="json")
.add_argument("password", type=valid_password, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
# setup
RegisterService.setup(
email=args["email"], name=args["name"], password=args["password"], ip_address=extract_remote_ip(request)
email=args["email"],
name=args["name"],
password=args["password"],
ip_address=extract_remote_ip(request),
language=args["language"],
)
return {"result": "success"}, 201

View File

@ -30,7 +30,7 @@ from models.provider_ids import ToolProviderID
from services.plugin.oauth_service import OAuthProxyService
from services.tools.api_tools_manage_service import ApiToolManageService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.tools.mcp_tools_manage_service import MCPToolManageService
from services.tools.mcp_tools_manage_service import MCPToolManageService, OAuthDataType
from services.tools.tool_labels_service import ToolLabelsService
from services.tools.tools_manage_service import ToolCommonService
from services.tools.tools_transform_service import ToolTransformService
@ -897,10 +897,6 @@ class ToolProviderMCPApi(Resource):
args = parser.parse_args()
user, tenant_id = current_account_with_tenant()
# Validate server URL
if not is_valid_url(args["server_url"]):
raise ValueError("Server URL is not valid.")
# Parse and validate models
configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
@ -941,15 +937,21 @@ class ToolProviderMCPApi(Resource):
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
)
args = parser.parse_args()
if not is_valid_url(args["server_url"]):
if "[__HIDDEN__]" in args["server_url"]:
pass
else:
raise ValueError("Server URL is not valid.")
configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
_, current_tenant_id = current_account_with_tenant()
# Step 1: Validate server URL change if needed (includes URL format validation and network operation)
validation_result = None
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
validation_result = service.validate_server_url_change(
tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"]
)
# No need to check for errors here; exceptions will be raised directly
# Step 2: Perform database update in a transaction
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.update_provider(
@ -964,6 +966,7 @@ class ToolProviderMCPApi(Resource):
headers=args["headers"],
configuration=configuration,
authentication=authentication,
validation_result=validation_result,
)
return {"result": "success"}
@ -998,47 +1001,49 @@ class ToolMCPAuthApi(Resource):
provider_id = args["provider_id"]
_, tenant_id = current_account_with_tenant()
with Session(db.engine) as session:
with session.begin():
service = MCPToolManageService(session=session)
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
if not db_provider:
raise ValueError("provider not found")
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
if not db_provider:
raise ValueError("provider not found")
# Convert to entity
provider_entity = db_provider.to_entity()
server_url = provider_entity.decrypt_server_url()
headers = provider_entity.decrypt_authentication()
# Convert to entity
provider_entity = db_provider.to_entity()
server_url = provider_entity.decrypt_server_url()
headers = provider_entity.decrypt_authentication()
# Try to connect without active transaction
# Try to connect without active transaction
try:
# Use MCPClientWithAuthRetry to handle authentication automatically
with MCPClient(
server_url=server_url,
headers=headers,
timeout=provider_entity.timeout,
sse_read_timeout=provider_entity.sse_read_timeout,
):
# Create new transaction for update
with session.begin():
service.update_provider_credentials(
provider=db_provider,
credentials=provider_entity.credentials,
authed=True,
)
return {"result": "success"}
except MCPAuthError as e:
service = MCPToolManageService(session=session)
try:
# Use MCPClientWithAuthRetry to handle authentication automatically
with MCPClient(
server_url=server_url,
headers=headers,
timeout=provider_entity.timeout,
sse_read_timeout=provider_entity.sse_read_timeout,
):
# Create new transaction for update
with session.begin():
service.update_provider_credentials(
provider=db_provider,
credentials=provider_entity.credentials,
authed=True,
)
return {"result": "success"}
except MCPAuthError as e:
service = MCPToolManageService(session=session)
try:
return auth(provider_entity, service, args.get("authorization_code"))
except MCPRefreshTokenError as e:
with session.begin():
service.clear_provider_credentials(provider=db_provider)
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
except MCPError as e:
auth_result = auth(provider_entity, args.get("authorization_code"))
with session.begin():
response = service.execute_auth_actions(auth_result)
return response
except MCPRefreshTokenError as e:
with session.begin():
service.clear_provider_credentials(provider=db_provider)
raise ValueError(f"Failed to connect to MCP server: {e}") from e
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
except MCPError as e:
with session.begin():
service.clear_provider_credentials(provider=db_provider)
raise ValueError(f"Failed to connect to MCP server: {e}") from e
@console_ns.route("/workspaces/current/tool-provider/mcp/tools/<path:provider_id>")
@ -1048,7 +1053,7 @@ class ToolMCPDetailApi(Resource):
@account_initialization_required
def get(self, provider_id):
_, tenant_id = current_account_with_tenant()
with Session(db.engine) as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
@ -1062,7 +1067,7 @@ class ToolMCPListAllApi(Resource):
def get(self):
_, tenant_id = current_account_with_tenant()
with Session(db.engine, expire_on_commit=False) as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
tools = service.list_providers(tenant_id=tenant_id)
@ -1100,6 +1105,11 @@ class ToolMCPCallbackApi(Resource):
# Create service instance for handle_callback
with Session(db.engine) as session, session.begin():
mcp_service = MCPToolManageService(session=session)
handle_callback(state_key, authorization_code, mcp_service)
# handle_callback now returns state data and tokens
state_data, tokens = handle_callback(state_key, authorization_code)
# Save tokens using the service layer
mcp_service.save_oauth_data(
state_data.provider_id, state_data.tenant_id, tokens.model_dump(), OAuthDataType.TOKENS
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")

View File

@ -193,15 +193,16 @@ class MCPAppApi(Resource):
except ValidationError as e:
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str, session: Session) -> EndUser | None:
"""Get end user from existing session - optimized query"""
return (
session.query(EndUser)
.where(EndUser.tenant_id == tenant_id)
.where(EndUser.session_id == mcp_server_id)
.where(EndUser.type == "mcp")
.first()
)
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None:
"""Get end user - manages its own database session"""
with Session(db.engine, expire_on_commit=False) as session, session.begin():
return (
session.query(EndUser)
.where(EndUser.tenant_id == tenant_id)
.where(EndUser.session_id == mcp_server_id)
.where(EndUser.type == "mcp")
.first()
)
def _create_end_user(
self, client_name: str, tenant_id: str, app_id: str, mcp_server_id: str, session: Session
@ -229,7 +230,7 @@ class MCPAppApi(Resource):
request_id: Union[int, str],
) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError | None:
"""Handle MCP request and return response"""
end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id, session)
end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id)
if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest):
client_info = mcp_request.root.params.clientInfo

View File

@ -17,8 +17,8 @@ from libs.helper import email
from libs.passport import PassportService
from libs.password import valid_password
from libs.token import (
clear_access_token_from_cookie,
extract_access_token,
clear_webapp_access_token_from_cookie,
extract_webapp_access_token,
)
from services.account_service import AccountService
from services.app_service import AppService
@ -81,7 +81,7 @@ class LoginStatusApi(Resource):
)
def get(self):
app_code = request.args.get("app_code")
token = extract_access_token(request)
token = extract_webapp_access_token(request)
if not app_code:
return {
"logged_in": bool(token),
@ -128,7 +128,7 @@ class LogoutApi(Resource):
response = make_response({"result": "success"})
# enterprise SSO sets the cookie's SameSite to None in HTTPS deployments,
# so we need to log out by calling the API
clear_access_token_from_cookie(response, samesite="None")
clear_webapp_access_token_from_cookie(response, samesite="None")
return response

View File

@ -12,10 +12,8 @@ from controllers.web import web_ns
from controllers.web.error import WebAppAuthRequiredError
from extensions.ext_database import db
from libs.passport import PassportService
from libs.token import extract_access_token
from libs.token import extract_webapp_access_token
from models.model import App, EndUser, Site
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService, WebAppAuthType
@ -37,23 +35,18 @@ class PassportResource(Resource):
system_features = FeatureService.get_system_features()
app_code = request.headers.get(HEADER_NAME_APP_CODE)
user_id = request.args.get("user_id")
access_token = extract_access_token(request)
access_token = extract_webapp_access_token(request)
if app_code is None:
raise Unauthorized("X-App-Code header is missing.")
app_id = AppService.get_app_id_by_code(app_code)
# exchange token for an enterprise logged-in web user
enterprise_user_decoded = decode_enterprise_webapp_user_id(access_token)
if enterprise_user_decoded:
# a web user has already logged in, exchange a token for this app without redirecting to the login page
return exchange_token_for_existing_web_user(
app_code=app_code, enterprise_user_decoded=enterprise_user_decoded
)
if system_features.webapp_auth.enabled:
app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id)
if not app_settings or not app_settings.access_mode == "public":
raise WebAppAuthRequiredError()
enterprise_user_decoded = decode_enterprise_webapp_user_id(access_token)
app_auth_type = WebAppAuthService.get_app_auth_type(app_code=app_code)
if app_auth_type != WebAppAuthType.PUBLIC:
if not enterprise_user_decoded:
raise WebAppAuthRequiredError()
return exchange_token_for_existing_web_user(
app_code=app_code, enterprise_user_decoded=enterprise_user_decoded, auth_type=app_auth_type
)
# get site from db and check if it is normal
site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal"))
@ -124,7 +117,7 @@ def decode_enterprise_webapp_user_id(jwt_token: str | None):
return decoded
def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict):
def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict, auth_type: WebAppAuthType):
"""
Exchange a token for an existing web user session.
"""
@ -145,13 +138,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound()
app_auth_type = WebAppAuthService.get_app_auth_type(app_code=app_code)
if app_auth_type == WebAppAuthType.PUBLIC:
if auth_type == WebAppAuthType.PUBLIC:
return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded)
elif app_auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external":
elif auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external":
raise WebAppAuthRequiredError("Please login as external user.")
elif app_auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal":
elif auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal":
raise WebAppAuthRequiredError("Please login as internal user.")
end_user = None

View File

@ -211,8 +211,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
user=user,
stream=streaming,
)
# FIXME: Type hinting issue here, ignore it for now, will fix it later
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) # type: ignore
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def _generate_worker(
self,

View File

@ -255,7 +255,7 @@ class PipelineGenerator(BaseAppGenerator):
json_text = json.dumps(text)
upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id)
features = FeatureService.get_features(dataset.tenant_id)
if features.billing.subscription.plan == "sandbox":
if features.billing.enabled and features.billing.subscription.plan == "sandbox":
tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"

View File

@ -89,7 +89,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@ -98,7 +98,7 @@ class RateLimit:
else:
return RateLimitGenerator(
rate_limit=self,
generator=generator, # ty: ignore [invalid-argument-type]
generator=generator,
request_id=request_id,
)

View File

@ -49,7 +49,7 @@ class BasedGenerateTaskPipeline:
if isinstance(e, InvokeAuthorizationError):
err = InvokeAuthorizationError("Incorrect API key provided")
elif isinstance(e, InvokeError | ValueError):
err = e # ty: ignore [invalid-assignment]
err = e
else:
description = getattr(e, "description", None)
err = Exception(description if description is not None else str(e))

View File

@ -1868,7 +1868,7 @@ class ProviderConfigurations(BaseModel):
if "/" not in key:
key = str(ModelProviderID(key))
return self.configurations.get(key, default) # type: ignore
return self.configurations.get(key, default)
class ProviderModelBundle(BaseModel):

View File

@ -20,7 +20,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz
else:
# Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
# FIXME: mypy does not support the type of spec.loader
spec = importlib.util.spec_from_file_location(module_name, py_file_path) # type: ignore
spec = importlib.util.spec_from_file_location(module_name, py_file_path) # type: ignore[assignment]
if not spec or not spec.loader:
raise Exception(f"Failed to load module {module_name} from {py_file_path!r}")
if use_lazy_loader:

View File

@ -49,62 +49,80 @@ class IndexingRunner:
self.storage = storage
self.model_manager = ModelManager()
def _handle_indexing_error(self, document_id: str, error: Exception) -> None:
"""Handle indexing errors by updating document status."""
logger.exception("consume document failed")
document = db.session.get(DatasetDocument, document_id)
if document:
document.indexing_status = "error"
error_message = getattr(error, "description", str(error))
document.error = str(error_message)
document.stopped_at = naive_utc_now()
db.session.commit()
def run(self, dataset_documents: list[DatasetDocument]):
"""Run the indexing process."""
for dataset_document in dataset_documents:
document_id = dataset_document.id
try:
# Re-query the document to ensure it's bound to the current session
requeried_document = db.session.get(DatasetDocument, document_id)
if not requeried_document:
logger.warning("Document not found, skipping document id: %s", document_id)
continue
# get dataset
dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
# get the process rule
stmt = select(DatasetProcessRule).where(
DatasetProcessRule.id == dataset_document.dataset_process_rule_id
DatasetProcessRule.id == requeried_document.dataset_process_rule_id
)
processing_rule = db.session.scalar(stmt)
if not processing_rule:
raise ValueError("no process rule found")
index_type = dataset_document.doc_form
index_type = requeried_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
# extract
text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
# transform
documents = self._transform(
index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict()
index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict()
)
# save segment
self._load_segments(dataset, dataset_document, documents)
self._load_segments(dataset, requeried_document, documents)
# load
self._load(
index_processor=index_processor,
dataset=dataset,
dataset_document=dataset_document,
dataset_document=requeried_document,
documents=documents,
)
except DocumentIsPausedError:
raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}")
raise DocumentIsPausedError(f"Document paused, document id: {document_id}")
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error"
dataset_document.error = str(e.description)
dataset_document.stopped_at = naive_utc_now()
db.session.commit()
self._handle_indexing_error(document_id, e)
except ObjectDeletedError:
logger.warning("Document deleted, document id: %s", dataset_document.id)
logger.warning("Document deleted, document id: %s", document_id)
except Exception as e:
logger.exception("consume document failed")
dataset_document.indexing_status = "error"
dataset_document.error = str(e)
dataset_document.stopped_at = naive_utc_now()
db.session.commit()
self._handle_indexing_error(document_id, e)
def run_in_splitting_status(self, dataset_document: DatasetDocument):
"""Run the indexing process when the index_status is splitting."""
document_id = dataset_document.id
try:
# Re-query the document to ensure it's bound to the current session
requeried_document = db.session.get(DatasetDocument, document_id)
if not requeried_document:
logger.warning("Document not found: %s", document_id)
return
# get dataset
dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
@ -112,57 +130,60 @@ class IndexingRunner:
# get exist document_segment list and delete
document_segments = (
db.session.query(DocumentSegment)
.filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
.filter_by(dataset_id=dataset.id, document_id=requeried_document.id)
.all()
)
for document_segment in document_segments:
db.session.delete(document_segment)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# delete child chunks
db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit()
# get the process rule
stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == requeried_document.dataset_process_rule_id)
processing_rule = db.session.scalar(stmt)
if not processing_rule:
raise ValueError("no process rule found")
index_type = dataset_document.doc_form
index_type = requeried_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
# extract
text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
# transform
documents = self._transform(
index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict()
index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict()
)
# save segment
self._load_segments(dataset, dataset_document, documents)
self._load_segments(dataset, requeried_document, documents)
# load
self._load(
index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
index_processor=index_processor,
dataset=dataset,
dataset_document=requeried_document,
documents=documents,
)
except DocumentIsPausedError:
raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}")
raise DocumentIsPausedError(f"Document paused, document id: {document_id}")
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error"
dataset_document.error = str(e.description)
dataset_document.stopped_at = naive_utc_now()
db.session.commit()
self._handle_indexing_error(document_id, e)
except Exception as e:
logger.exception("consume document failed")
dataset_document.indexing_status = "error"
dataset_document.error = str(e)
dataset_document.stopped_at = naive_utc_now()
db.session.commit()
self._handle_indexing_error(document_id, e)
def run_in_indexing_status(self, dataset_document: DatasetDocument):
"""Run the indexing process when the index_status is indexing."""
document_id = dataset_document.id
try:
# Re-query the document to ensure it's bound to the current session
requeried_document = db.session.get(DatasetDocument, document_id)
if not requeried_document:
logger.warning("Document not found: %s", document_id)
return
# get dataset
dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
@ -170,7 +191,7 @@ class IndexingRunner:
# get exist document_segment list and delete
document_segments = (
db.session.query(DocumentSegment)
.filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
.filter_by(dataset_id=dataset.id, document_id=requeried_document.id)
.all()
)
@ -188,7 +209,7 @@ class IndexingRunner:
"dataset_id": document_segment.dataset_id,
},
)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = document_segment.get_child_chunks()
if child_chunks:
child_documents = []
@ -206,24 +227,20 @@ class IndexingRunner:
document.children = child_documents
documents.append(document)
# build index
index_type = dataset_document.doc_form
index_type = requeried_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
self._load(
index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
index_processor=index_processor,
dataset=dataset,
dataset_document=requeried_document,
documents=documents,
)
except DocumentIsPausedError:
raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}")
raise DocumentIsPausedError(f"Document paused, document id: {document_id}")
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error"
dataset_document.error = str(e.description)
dataset_document.stopped_at = naive_utc_now()
db.session.commit()
self._handle_indexing_error(document_id, e)
except Exception as e:
logger.exception("consume document failed")
dataset_document.indexing_status = "error"
dataset_document.error = str(e)
dataset_document.stopped_at = naive_utc_now()
db.session.commit()
self._handle_indexing_error(document_id, e)
def indexing_estimate(
self,
@ -398,7 +415,6 @@ class IndexingRunner:
document_id=dataset_document.id,
after_indexing_status="splitting",
extra_update_params={
DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs),
DatasetDocument.parsing_completed_at: naive_utc_now(),
},
)
@ -738,6 +754,7 @@ class IndexingRunner:
extra_update_params={
DatasetDocument.cleaning_completed_at: cur_time,
DatasetDocument.splitting_completed_at: cur_time,
DatasetDocument.word_count: sum(len(doc.page_content) for doc in documents),
},
)

View File

@ -103,7 +103,7 @@ class LLMGenerator:
return name
@classmethod
def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str) -> Sequence[str]:
output_parser = SuggestedQuestionsAfterAnswerOutputParser()
format_instructions = output_parser.get_format_instructions()
@ -122,6 +122,8 @@ class LLMGenerator:
prompt_messages = [UserPromptMessage(content=prompt)]
questions: Sequence[str] = []
try:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages),

View File

@ -1,17 +1,26 @@
import json
import logging
import re
from collections.abc import Sequence
from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
logger = logging.getLogger(__name__)
class SuggestedQuestionsAfterAnswerOutputParser:
def get_format_instructions(self) -> str:
return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
def parse(self, text: str):
def parse(self, text: str) -> Sequence[str]:
action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL)
questions: list[str] = []
if action_match is not None:
json_obj = json.loads(action_match.group(0).strip())
else:
json_obj = []
return json_obj
try:
json_obj = json.loads(action_match.group(0).strip())
except json.JSONDecodeError as exc:
logger.warning("Failed to decode suggested questions payload: %s", exc)
else:
if isinstance(json_obj, list):
questions = [question for question in json_obj if isinstance(question, str)]
return questions
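A short usage sketch of the hardened parser: non-string entries are filtered out, and inputs with no array or with malformed JSON now produce an empty list (the latter also logs a warning) instead of raising.

# Illustrative calls against the parser defined above.
parser = SuggestedQuestionsAfterAnswerOutputParser()

assert parser.parse('Sure! ["What is RAG?", "How are chunks sized?", 42]') == [
    "What is RAG?",
    "How are chunks sized?",
]  # the non-string entry (42) is dropped
assert parser.parse("no JSON array here") == []  # no bracketed payload -> empty list
assert parser.parse("[not valid json]") == []    # JSONDecodeError -> warning + empty list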

View File

@ -4,14 +4,14 @@ import json
import os
import secrets
import urllib.parse
from typing import TYPE_CHECKING
from urllib.parse import urljoin, urlparse
from httpx import ConnectError, HTTPStatusError, RequestError
from pydantic import BaseModel, ValidationError
from pydantic import ValidationError
from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType
from core.helper import ssrf_proxy
from core.mcp.entities import AuthAction, AuthActionType, AuthResult, OAuthCallbackState
from core.mcp.error import MCPRefreshTokenError
from core.mcp.types import (
LATEST_PROTOCOL_VERSION,
@ -23,23 +23,10 @@ from core.mcp.types import (
)
from extensions.ext_redis import redis_client
if TYPE_CHECKING:
from services.tools.mcp_tools_manage_service import MCPToolManageService
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
class OAuthCallbackState(BaseModel):
provider_id: str
tenant_id: str
server_url: str
metadata: OAuthMetadata | None = None
client_information: OAuthClientInformation
code_verifier: str
redirect_uri: str
def generate_pkce_challenge() -> tuple[str, str]:
"""Generate PKCE challenge and verifier."""
code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
@ -86,8 +73,13 @@ def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
raise ValueError(f"Invalid state parameter: {str(e)}")
def handle_callback(state_key: str, authorization_code: str, mcp_service: "MCPToolManageService") -> OAuthCallbackState:
"""Handle the callback from the OAuth provider."""
def handle_callback(state_key: str, authorization_code: str) -> tuple[OAuthCallbackState, OAuthTokens]:
"""
Handle the callback from the OAuth provider.
Returns:
A tuple of (callback_state, tokens) that can be used by the caller to save data.
"""
# Retrieve state data from Redis (state is automatically deleted after retrieval)
full_state_data = _retrieve_redis_state(state_key)
@ -100,10 +92,7 @@ def handle_callback(state_key: str, authorization_code: str, mcp_service: "MCPTo
full_state_data.redirect_uri,
)
# Save tokens using the service layer
mcp_service.save_oauth_data(full_state_data.provider_id, full_state_data.tenant_id, tokens.model_dump(), "tokens")
return full_state_data
return full_state_data, tokens
def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
@ -361,11 +350,24 @@ def register_client(
def auth(
provider: MCPProviderEntity,
mcp_service: "MCPToolManageService",
authorization_code: str | None = None,
state_param: str | None = None,
) -> dict[str, str]:
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
) -> AuthResult:
"""
Orchestrates the full auth flow with a server using secure Redis state storage.
This function performs only network operations and returns actions that need
to be performed by the caller (such as saving data to the database).
Args:
provider: The MCP provider entity
authorization_code: Optional authorization code from OAuth callback
state_param: Optional state parameter from OAuth callback
Returns:
AuthResult containing actions to be performed and response data
"""
actions: list[AuthAction] = []
server_url = provider.decrypt_server_url()
server_metadata = discover_oauth_metadata(server_url)
client_metadata = provider.client_metadata
@ -407,9 +409,14 @@ def auth(
except RequestError as e:
raise ValueError(f"Could not register OAuth client: {e}")
# Save client information using service layer
mcp_service.save_oauth_data(
provider_id, tenant_id, {"client_information": full_information.model_dump()}, "client_info"
# Return action to save client information
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_CLIENT_INFO,
data={"client_information": full_information.model_dump()},
provider_id=provider_id,
tenant_id=tenant_id,
)
)
client_information = full_information
@ -426,12 +433,20 @@ def auth(
scope,
)
# Save tokens and grant type
# Return action to save tokens and grant type
token_data = tokens.model_dump()
token_data["grant_type"] = MCPSupportGrantType.CLIENT_CREDENTIALS.value
mcp_service.save_oauth_data(provider_id, tenant_id, token_data, "tokens")
return {"result": "success"}
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_TOKENS,
data=token_data,
provider_id=provider_id,
tenant_id=tenant_id,
)
)
return AuthResult(actions=actions, response={"result": "success"})
except (RequestError, ValueError, KeyError) as e:
# RequestError: HTTP request failed
# ValueError: Invalid response data
@ -465,10 +480,17 @@ def auth(
redirect_uri,
)
# Save tokens using service layer
mcp_service.save_oauth_data(provider_id, tenant_id, tokens.model_dump(), "tokens")
# Return action to save tokens
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_TOKENS,
data=tokens.model_dump(),
provider_id=provider_id,
tenant_id=tenant_id,
)
)
return {"result": "success"}
return AuthResult(actions=actions, response={"result": "success"})
provider_tokens = provider.retrieve_tokens()
@ -479,10 +501,17 @@ def auth(
server_url, server_metadata, client_information, provider_tokens.refresh_token
)
# Save new tokens using service layer
mcp_service.save_oauth_data(provider_id, tenant_id, new_tokens.model_dump(), "tokens")
# Return action to save new tokens
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_TOKENS,
data=new_tokens.model_dump(),
provider_id=provider_id,
tenant_id=tenant_id,
)
)
return {"result": "success"}
return AuthResult(actions=actions, response={"result": "success"})
except (RequestError, ValueError, KeyError) as e:
# RequestError: HTTP request failed
# ValueError: Invalid response data
@ -499,7 +528,14 @@ def auth(
tenant_id,
)
# Save code verifier using service layer
mcp_service.save_oauth_data(provider_id, tenant_id, {"code_verifier": code_verifier}, "code_verifier")
# Return action to save code verifier
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_CODE_VERIFIER,
data={"code_verifier": code_verifier},
provider_id=provider_id,
tenant_id=tenant_id,
)
)
return {"authorization_url": authorization_url}
return AuthResult(actions=actions, response={"authorization_url": authorization_url})

View File

@ -7,15 +7,15 @@ authentication failures and retries operations after refreshing tokens.
import logging
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Optional
from typing import Any
from sqlalchemy.orm import Session
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.error import MCPAuthError
from core.mcp.mcp_client import MCPClient
from core.mcp.types import CallToolResult, Tool
if TYPE_CHECKING:
from services.tools.mcp_tools_manage_service import MCPToolManageService
from extensions.ext_database import db
logger = logging.getLogger(__name__)
@ -26,6 +26,9 @@ class MCPClientWithAuthRetry(MCPClient):
This class extends MCPClient and intercepts MCPAuthError exceptions
to refresh authentication before retrying failed operations.
Note: This class uses lazy session creation - database sessions are only
created when authentication retry is actually needed, not on every request.
"""
def __init__(
@ -35,11 +38,8 @@ class MCPClientWithAuthRetry(MCPClient):
timeout: float | None = None,
sse_read_timeout: float | None = None,
provider_entity: MCPProviderEntity | None = None,
auth_callback: Callable[[MCPProviderEntity, "MCPToolManageService", Optional[str]], dict[str, str]]
| None = None,
authorization_code: str | None = None,
by_server_id: bool = False,
mcp_service: Optional["MCPToolManageService"] = None,
):
"""
Initialize the MCP client with auth retry capability.
@ -50,31 +50,30 @@ class MCPClientWithAuthRetry(MCPClient):
timeout: Request timeout
sse_read_timeout: SSE read timeout
provider_entity: Provider entity for authentication
auth_callback: Authentication callback function
authorization_code: Optional authorization code for initial auth
by_server_id: Whether to look up provider by server ID
mcp_service: MCP service instance
"""
super().__init__(server_url, headers, timeout, sse_read_timeout)
self.provider_entity = provider_entity
self.auth_callback = auth_callback
self.authorization_code = authorization_code
self.by_server_id = by_server_id
self.mcp_service = mcp_service
self._has_retried = False
def _handle_auth_error(self, error: MCPAuthError) -> None:
"""
Handle authentication error by refreshing tokens.
This method creates a short-lived database session only when authentication
retry is needed, minimizing database connection hold time.
Args:
error: The authentication error
Raises:
MCPAuthError: If authentication fails or max retries reached
"""
if not self.provider_entity or not self.auth_callback or not self.mcp_service:
if not self.provider_entity:
raise error
if self._has_retried:
raise error
@ -82,13 +81,23 @@ class MCPClientWithAuthRetry(MCPClient):
self._has_retried = True
try:
# Perform authentication
self.auth_callback(self.provider_entity, self.mcp_service, self.authorization_code)
# Create a temporary session only for auth retry
# This session is short-lived and only exists during the auth operation
# Retrieve new tokens
self.provider_entity = self.mcp_service.get_provider_entity(
self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id
)
from services.tools.mcp_tools_manage_service import MCPToolManageService
with Session(db.engine) as session, session.begin():
mcp_service = MCPToolManageService(session=session)
# Perform authentication using the service's auth method
mcp_service.auth_with_actions(self.provider_entity, self.authorization_code)
# Retrieve new tokens
self.provider_entity = mcp_service.get_provider_entity(
self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id
)
# Session is closed here, before we update headers
token = self.provider_entity.retrieve_tokens()
if not token:
raise MCPAuthError("Authentication failed - no token received")
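With auth_callback and mcp_service gone from the constructor, a caller only has to supply the provider entity; the retry path opens its own short-lived session when needed. A usage sketch mirroring the pattern in MCPTool._invoke further down; the ids and tool name are placeholders.

# Hypothetical caller; the ids are placeholders.
from sqlalchemy.orm import Session
from extensions.ext_database import db
from services.tools.mcp_tools_manage_service import MCPToolManageService

provider_id = "example-provider-id"
tenant_id = "example-tenant-id"

# Step 1: load and decrypt credentials in a short-lived session.
with Session(db.engine, expire_on_commit=False) as session:
    provider_entity = MCPToolManageService(session=session).get_provider_entity(
        provider_id, tenant_id, by_server_id=True
    )
server_url = provider_entity.decrypt_server_url()
headers = provider_entity.decrypt_headers()

# Step 2: network call without holding a database connection; an auth retry,
# if triggered, creates its own temporary session internally.
with MCPClientWithAuthRetry(
    server_url=server_url,
    headers=headers,
    timeout=30.0,
    sse_read_timeout=300.0,
    provider_entity=provider_entity,
) as client:
    result = client.invoke_tool(tool_name="example_tool", tool_args={})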

View File

@ -1,8 +1,11 @@
from dataclasses import dataclass
from enum import StrEnum
from typing import Any, Generic, TypeVar
from pydantic import BaseModel
from core.mcp.session.base_session import BaseSession
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestId, RequestParams
from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAuthMetadata, RequestId, RequestParams
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
@ -17,3 +20,41 @@ class RequestContext(Generic[SessionT, LifespanContextT]):
meta: RequestParams.Meta | None
session: SessionT
lifespan_context: LifespanContextT
class AuthActionType(StrEnum):
"""Types of actions that can be performed during auth flow."""
SAVE_CLIENT_INFO = "save_client_info"
SAVE_TOKENS = "save_tokens"
SAVE_CODE_VERIFIER = "save_code_verifier"
START_AUTHORIZATION = "start_authorization"
SUCCESS = "success"
class AuthAction(BaseModel):
"""Represents an action that needs to be performed as a result of auth flow."""
action_type: AuthActionType
data: dict[str, Any]
provider_id: str | None = None
tenant_id: str | None = None
class AuthResult(BaseModel):
"""Result of auth function containing actions to be performed and response data."""
actions: list[AuthAction]
response: dict[str, str]
class OAuthCallbackState(BaseModel):
"""State data stored in Redis during OAuth callback flow."""
provider_id: str
tenant_id: str
server_url: str
metadata: OAuthMetadata | None = None
client_information: OAuthClientInformation
code_verifier: str
redirect_uri: str

View File

@ -2,7 +2,7 @@ import logging
import os
from datetime import datetime, timedelta
from langfuse import Langfuse # type: ignore
from langfuse import Langfuse
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance

View File

@ -76,7 +76,7 @@ class PluginParameter(BaseModel):
auto_generate: PluginParameterAutoGenerate | None = None
template: PluginParameterTemplate | None = None
required: bool = False
default: Union[float, int, str] | None = None
default: Union[float, int, str, bool] | None = None
min: Union[float, int] | None = None
max: Union[float, int] | None = None
precision: int | None = None

View File

@ -180,7 +180,7 @@ class BasePluginClient:
Make a request to the plugin daemon inner API and return the response as a model.
"""
response = self._request(method, path, headers, data, params, files)
return type_(**response.json()) # type: ignore
return type_(**response.json()) # type: ignore[return-value]
def _request_with_plugin_daemon_response(
self,

View File

@ -40,7 +40,7 @@ class PluginDaemonBadRequestError(PluginDaemonClientSideError):
description: str = "Bad Request"
class PluginInvokeError(PluginDaemonClientSideError):
class PluginInvokeError(PluginDaemonClientSideError, ValueError):
description: str = "Invoke Error"
def _get_error_object(self) -> Mapping:

View File

@ -72,6 +72,19 @@ default_retrieval_model: dict[str, Any] = {
class DatasetRetrieval:
def __init__(self, application_generate_entity=None):
self.application_generate_entity = application_generate_entity
self._llm_usage = LLMUsage.empty_usage()
@property
def llm_usage(self) -> LLMUsage:
return self._llm_usage.model_copy()
def _record_usage(self, usage: LLMUsage | None) -> None:
if usage is None or usage.total_tokens <= 0:
return
if self._llm_usage.total_tokens == 0:
self._llm_usage = usage
else:
self._llm_usage = self._llm_usage.plus(usage)
def retrieve(
self,
@ -312,15 +325,18 @@ class DatasetRetrieval:
)
tools.append(message_tool)
dataset_id = None
router_usage = LLMUsage.empty_usage()
if planning_strategy == PlanningStrategy.REACT_ROUTER:
react_multi_dataset_router = ReactMultiDatasetRouter()
dataset_id = react_multi_dataset_router.invoke(
dataset_id, router_usage = react_multi_dataset_router.invoke(
query, tools, model_config, model_instance, user_id, tenant_id
)
elif planning_strategy == PlanningStrategy.ROUTER:
function_call_router = FunctionCallMultiDatasetRouter()
dataset_id = function_call_router.invoke(query, tools, model_config, model_instance)
dataset_id, router_usage = function_call_router.invoke(query, tools, model_config, model_instance)
self._record_usage(router_usage)
if dataset_id:
# get retrieval model config
@ -983,7 +999,8 @@ class DatasetRetrieval:
)
# handle invoke result
result_text, _ = self._handle_invoke_result(invoke_result=invoke_result)
result_text, usage = self._handle_invoke_result(invoke_result=invoke_result)
self._record_usage(usage)
result_text_json = parse_and_check_json_markdown(result_text, [])
automatic_metadata_filters = []
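DatasetRetrieval now tracks the LLM usage spent on dataset routing and metadata-filter generation so callers can read it back from the llm_usage property. A small sketch of the accumulation behaviour; LLMUsage.empty_usage() and .plus() are the helpers already used above, and the token counts are made up.

# Illustrative only; the token numbers are arbitrary.
from core.model_runtime.entities.llm_entities import LLMUsage

retrieval = DatasetRetrieval()
first = LLMUsage.empty_usage().model_copy(update={"total_tokens": 120})
second = LLMUsage.empty_usage().model_copy(update={"total_tokens": 80})

retrieval._record_usage(first)   # first non-empty usage is stored as-is
retrieval._record_usage(second)  # later usage is merged via LLMUsage.plus
retrieval._record_usage(None)    # None (and zero-token usage) is ignored

assert retrieval.llm_usage.total_tokens == 200  # .plus is assumed to sum token counts; the property returns a copy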

View File

@ -2,7 +2,7 @@ from typing import Union
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage
@ -13,15 +13,15 @@ class FunctionCallMultiDatasetRouter:
dataset_tools: list[PromptMessageTool],
model_config: ModelConfigWithCredentialsEntity,
model_instance: ModelInstance,
) -> Union[str, None]:
) -> tuple[Union[str, None], LLMUsage]:
"""Given input, decided what to do.
Returns:
Action specifying what tool to use.
"""
if len(dataset_tools) == 0:
return None
return None, LLMUsage.empty_usage()
elif len(dataset_tools) == 1:
return dataset_tools[0].name
return dataset_tools[0].name, LLMUsage.empty_usage()
try:
prompt_messages = [
@ -34,9 +34,10 @@ class FunctionCallMultiDatasetRouter:
stream=False,
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
)
usage = result.usage or LLMUsage.empty_usage()
if result.message.tool_calls:
# get retrieval model config
return result.message.tool_calls[0].function.name
return None
return result.message.tool_calls[0].function.name, usage
return None, usage
except Exception:
return None
return None, LLMUsage.empty_usage()

View File

@ -58,15 +58,15 @@ class ReactMultiDatasetRouter:
model_instance: ModelInstance,
user_id: str,
tenant_id: str,
) -> Union[str, None]:
) -> tuple[Union[str, None], LLMUsage]:
"""Given input, decided what to do.
Returns:
Action specifying what tool to use.
"""
if len(dataset_tools) == 0:
return None
return None, LLMUsage.empty_usage()
elif len(dataset_tools) == 1:
return dataset_tools[0].name
return dataset_tools[0].name, LLMUsage.empty_usage()
try:
return self._react_invoke(
@ -78,7 +78,7 @@ class ReactMultiDatasetRouter:
tenant_id=tenant_id,
)
except Exception:
return None
return None, LLMUsage.empty_usage()
def _react_invoke(
self,
@ -91,7 +91,7 @@ class ReactMultiDatasetRouter:
prefix: str = PREFIX,
suffix: str = SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
) -> Union[str, None]:
) -> tuple[Union[str, None], LLMUsage]:
prompt: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
if model_config.mode == "chat":
prompt = self.create_chat_prompt(
@ -120,7 +120,7 @@ class ReactMultiDatasetRouter:
memory=None,
model_config=model_config,
)
result_text, _ = self._invoke_llm(
result_text, usage = self._invoke_llm(
completion_param=model_config.parameters,
model_instance=model_instance,
prompt_messages=prompt_messages,
@ -131,8 +131,8 @@ class ReactMultiDatasetRouter:
output_parser = StructuredChatOutputParser()
react_decision = output_parser.parse(result_text)
if isinstance(react_decision, ReactAction):
return react_decision.tool
return None
return react_decision.tool, usage
return None, usage
def _invoke_llm(
self,

View File

@ -74,7 +74,7 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
tenant_id = extract_tenant_id(user)
if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id # type: ignore[assignment] # We've already checked tenant_id is not None
self._tenant_id = tenant_id
# Store app context
self._app_id = app_id

View File

@ -81,7 +81,7 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
tenant_id = extract_tenant_id(user)
if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id # type: ignore[assignment] # We've already checked tenant_id is not None
self._tenant_id = tenant_id
# Store app context
self._app_id = app_id

View File

@ -60,7 +60,7 @@ class DifyCoreRepositoryFactory:
try:
repository_class = import_string(class_path)
return repository_class( # type: ignore[no-any-return]
return repository_class(
session_factory=session_factory,
user=user,
app_id=app_id,
@ -96,7 +96,7 @@ class DifyCoreRepositoryFactory:
try:
repository_class = import_string(class_path)
return repository_class( # type: ignore[no-any-return]
return repository_class(
session_factory=session_factory,
user=user,
app_id=app_id,

View File

@ -157,7 +157,7 @@ class BuiltinToolProviderController(ToolProviderController):
"""
returns the tool with the given name offered by this provider, or None if it does not exist
"""
return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None) # type: ignore
return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)
@property
def need_credentials(self) -> bool:

View File

@ -43,7 +43,7 @@ class TTSTool(BuiltinTool):
content_text=tool_parameters.get("text"), # type: ignore
user=user_id,
tenant_id=self.runtime.tenant_id,
voice=voice, # type: ignore
voice=voice,
)
buffer = io.BytesIO()
for chunk in tts:

View File

@ -34,6 +34,7 @@ class LocaltimeToTimestampTool(BuiltinTool):
yield self.create_text_message(f"{timestamp}")
# TODO: this method's type is messy
@staticmethod
def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None:
try:

View File

@ -48,6 +48,6 @@ class TimezoneConversionTool(BuiltinTool):
datetime_with_tz = input_timezone.localize(local_time)
# timezone convert
converted_datetime = datetime_with_tz.astimezone(output_timezone)
return converted_datetime.strftime(format=time_format) # type: ignore
return converted_datetime.strftime(time_format)
except Exception as e:
raise ToolInvokeError(str(e))

View File

@ -113,7 +113,7 @@ class MCPToolProviderController(ToolProviderController):
"""
pass
def get_tool(self, tool_name: str) -> MCPTool: # type: ignore
def get_tool(self, tool_name: str) -> MCPTool:
"""
return tool with given name
"""
@ -136,7 +136,7 @@ class MCPToolProviderController(ToolProviderController):
sse_read_timeout=self.sse_read_timeout,
)
def get_tools(self) -> list[MCPTool]: # type: ignore
def get_tools(self) -> list[MCPTool]:
"""
get all tools
"""

View File

@ -3,7 +3,6 @@ import json
from collections.abc import Generator
from typing import Any
from core.mcp.auth.auth_flow import auth
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPConnectionError
from core.mcp.types import CallToolResult, ImageContent, TextContent
@ -125,71 +124,39 @@ class MCPTool(Tool):
headers = self.headers.copy() if self.headers else {}
tool_parameters = self._handle_none_parameter(tool_parameters)
# Get provider entity to access tokens
from sqlalchemy.orm import Session
# Get MCP service from invoke parameters or create new one
provider_entity = None
mcp_service = None
from extensions.ext_database import db
from services.tools.mcp_tools_manage_service import MCPToolManageService
# Check if mcp_service is passed in tool_parameters
if "_mcp_service" in tool_parameters:
mcp_service = tool_parameters.pop("_mcp_service")
if mcp_service:
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
headers = provider_entity.decrypt_headers()
# Try to get existing token and add to headers
if not headers:
tokens = provider_entity.retrieve_tokens()
if tokens and tokens.access_token:
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
# Step 1: Load provider entity and credentials in a short-lived session
# This minimizes database connection hold time
with Session(db.engine, expire_on_commit=False) as session:
mcp_service = MCPToolManageService(session=session)
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
# Use MCPClientWithAuthRetry to handle authentication automatically
try:
with MCPClientWithAuthRetry(
server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url,
headers=headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
provider_entity=provider_entity,
auth_callback=auth if mcp_service else None,
mcp_service=mcp_service,
) as mcp_client:
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPConnectionError as e:
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
except (ValueError, TypeError, KeyError) as e:
# Catch specific exceptions that might occur during tool invocation
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
else:
# Fallback to creating service with database session
from sqlalchemy.orm import Session
# Decrypt and prepare all credentials before closing session
server_url = provider_entity.decrypt_server_url()
headers = provider_entity.decrypt_headers()
from extensions.ext_database import db
from services.tools.mcp_tools_manage_service import MCPToolManageService
# Try to get existing token and add to headers
if not headers:
tokens = provider_entity.retrieve_tokens()
if tokens and tokens.access_token:
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
with Session(db.engine, expire_on_commit=False) as session:
mcp_service = MCPToolManageService(session=session)
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
headers = provider_entity.decrypt_headers()
# Try to get existing token and add to headers
if not headers:
tokens = provider_entity.retrieve_tokens()
if tokens and tokens.access_token:
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
# Use MCPClientWithAuthRetry to handle authentication automatically
try:
with MCPClientWithAuthRetry(
server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url,
headers=headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
provider_entity=provider_entity,
auth_callback=auth if mcp_service else None,
mcp_service=mcp_service,
) as mcp_client:
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPConnectionError as e:
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
except Exception as e:
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
# Step 2: Session is now closed, perform network operations without holding database connection
# MCPClientWithAuthRetry will create a new session lazily only if auth retry is needed
try:
with MCPClientWithAuthRetry(
server_url=server_url,
headers=headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
provider_entity=provider_entity,
) as mcp_client:
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPConnectionError as e:
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
except Exception as e:
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
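The "Step 1 / Step 2" comments above describe a general pattern worth spelling out: read and copy everything you need inside a short-lived SQLAlchemy session, close it, and only then perform the slow network call so no database connection is pinned during I/O. A minimal, self-contained sketch of that pattern (the Provider model and call_tool helper are illustrative, not Dify classes):

from sqlalchemy import Column, String, create_engine, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Provider(Base):
    __tablename__ = "providers"
    id = Column(String, primary_key=True)
    server_url = Column(String)

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

def call_tool(provider_id: str) -> str:
    # Step 1: short-lived session; expire_on_commit=False keeps loaded
    # attributes readable after the session closes.
    with Session(engine, expire_on_commit=False) as session:
        provider = session.scalar(select(Provider).where(Provider.id == provider_id))
        if provider is None:
            raise ValueError("provider not found")
        server_url = provider.server_url  # copy what is needed while the session is open
    # Step 2: the session is closed, so the long-running network call below
    # no longer holds a database connection from the pool.
    return f"would call {server_url} here"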

View File

@ -26,7 +26,7 @@ class ToolLabelManager:
labels = cls.filter_tool_labels(labels)
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
provider_id = controller.provider_id # ty: ignore [unresolved-attribute]
provider_id = controller.provider_id
else:
raise ValueError("Unsupported tool type")
@ -51,7 +51,7 @@ class ToolLabelManager:
Get tool labels
"""
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
provider_id = controller.provider_id # ty: ignore [unresolved-attribute]
provider_id = controller.provider_id
elif isinstance(controller, BuiltinToolProviderController):
return controller.tool_labels
else:
@ -85,7 +85,7 @@ class ToolLabelManager:
provider_ids = []
for controller in tool_providers:
assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
provider_ids.append(controller.provider_id) # ty: ignore [unresolved-attribute]
provider_ids.append(controller.provider_id)
labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all()

View File

@ -331,7 +331,8 @@ class ToolManager:
workflow_provider_stmt = select(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id
)
workflow_provider = db.session.scalar(workflow_provider_stmt)
with Session(db.engine, expire_on_commit=False) as session, session.begin():
workflow_provider = session.scalar(workflow_provider_stmt)
if workflow_provider is None:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

View File

@ -193,18 +193,18 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
document = db.session.scalar(dataset_document_stmt) # type: ignore
document = db.session.scalar(dataset_document_stmt)
if dataset and document:
source = RetrievalSourceMetadata(
dataset_id=dataset.id,
dataset_name=dataset.name,
document_id=document.id, # type: ignore
document_name=document.name, # type: ignore
data_source_type=document.data_source_type, # type: ignore
document_id=document.id,
document_name=document.name,
data_source_type=document.data_source_type,
segment_id=segment.id,
retriever_from=self.retriever_from,
score=record.score or 0.0,
doc_metadata=document.doc_metadata, # type: ignore
doc_metadata=document.doc_metadata,
)
if self.retriever_from == "dev":

View File

@ -62,6 +62,11 @@ class ApiBasedToolSchemaParser:
root = root[ref]
interface["operation"]["parameters"][i] = root
for parameter in interface["operation"]["parameters"]:
# Handle complex type defaults that are not supported by PluginParameter
default_value = None
if "schema" in parameter and "default" in parameter["schema"]:
default_value = ApiBasedToolSchemaParser._sanitize_default_value(parameter["schema"]["default"])
tool_parameter = ToolParameter(
name=parameter["name"],
label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
@ -72,9 +77,7 @@ class ApiBasedToolSchemaParser:
required=parameter.get("required", False),
form=ToolParameter.ToolParameterForm.LLM,
llm_description=parameter.get("description"),
default=parameter["schema"]["default"]
if "schema" in parameter and "default" in parameter["schema"]
else None,
default=default_value,
placeholder=I18nObject(
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
),
@ -134,6 +137,11 @@ class ApiBasedToolSchemaParser:
required = body_schema.get("required", [])
properties = body_schema.get("properties", {})
for name, property in properties.items():
# Handle complex type defaults that are not supported by PluginParameter
default_value = ApiBasedToolSchemaParser._sanitize_default_value(
property.get("default", None)
)
tool = ToolParameter(
name=name,
label=I18nObject(en_US=name, zh_Hans=name),
@ -144,12 +152,11 @@ class ApiBasedToolSchemaParser:
required=name in required,
form=ToolParameter.ToolParameterForm.LLM,
llm_description=property.get("description", ""),
default=property.get("default", None),
default=default_value,
placeholder=I18nObject(
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
),
)
# check if there is a type
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
if typ:
@ -197,6 +204,22 @@ class ApiBasedToolSchemaParser:
return bundles
@staticmethod
def _sanitize_default_value(value):
"""
Sanitize default values for PluginParameter compatibility.
Complex types (list, dict) are converted to None to avoid validation errors.
Args:
value: The default value from OpenAPI schema
Returns:
None for complex types (list, dict), otherwise the original value
"""
if isinstance(value, (list, dict)):
return None
return value
@staticmethod
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
parameter = parameter or {}
@ -217,7 +240,11 @@ class ApiBasedToolSchemaParser:
return ToolParameter.ToolParameterType.STRING
elif typ == "array":
items = parameter.get("items") or parameter.get("schema", {}).get("items")
return ToolParameter.ToolParameterType.FILES if items and items.get("format") == "binary" else None
if items and items.get("format") == "binary":
return ToolParameter.ToolParameterType.FILES
else:
# For regular arrays, return ARRAY type instead of None
return ToolParameter.ToolParameterType.ARRAY
else:
return None
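A quick illustration of the two behaviours introduced above, in isolation: complex defaults (list or dict) are dropped to None before they can trip parameter validation, while scalar defaults pass through unchanged. The helper name below is illustrative:

def sanitize_default(value):
    # list/dict defaults cannot be represented as a scalar parameter default
    if isinstance(value, (list, dict)):
        return None
    return value

assert sanitize_default("celsius") == "celsius"
assert sanitize_default(3) == 3
assert sanitize_default(["a", "b"]) is None
assert sanitize_default({"unit": "celsius"}) is None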

View File

@ -6,8 +6,8 @@ from typing import Any, cast
from urllib.parse import unquote
import chardet
import cloudscraper # type: ignore
from readabilipy import simple_json_from_html_string # type: ignore
import cloudscraper
from readabilipy import simple_json_from_html_string
from core.helper import ssrf_proxy
from core.rag.extractor import extract_processor
@ -63,8 +63,8 @@ def get_url(url: str, user_agent: str | None = None) -> str:
response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
elif response.status_code == 403:
scraper = cloudscraper.create_scraper()
scraper.perform_request = ssrf_proxy.make_request # type: ignore
response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) # type: ignore
scraper.perform_request = ssrf_proxy.make_request
response = scraper.get(url, headers=headers, timeout=(120, 300))
if response.status_code != 200:
return f"URL returned status code {response.status_code}."

View File

@ -3,7 +3,7 @@ from functools import lru_cache
from pathlib import Path
from typing import Any
import yaml # type: ignore
import yaml
from yaml import YAMLError
logger = logging.getLogger(__name__)

View File

@ -1,6 +1,7 @@
from collections.abc import Mapping
from pydantic import Field
from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
@ -20,6 +21,7 @@ from core.tools.entities.tool_entities import (
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.account import Account
from models.model import App, AppMode
from models.tools import WorkflowToolProvider
from models.workflow import Workflow
@ -44,29 +46,34 @@ class WorkflowToolProviderController(ToolProviderController):
@classmethod
def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
app = db_provider.app
with Session(db.engine, expire_on_commit=False) as session, session.begin():
provider = session.get(WorkflowToolProvider, db_provider.id) if db_provider.id else None
if not provider:
raise ValueError("workflow provider not found")
app = session.get(App, provider.app_id)
if not app:
raise ValueError("app not found")
if not app:
raise ValueError("app not found")
user = session.get(Account, provider.user_id) if provider.user_id else None
controller = WorkflowToolProviderController(
entity=ToolProviderEntity(
identity=ToolProviderIdentity(
author=db_provider.user.name if db_provider.user_id and db_provider.user else "",
name=db_provider.label,
label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label),
description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
icon=db_provider.icon,
controller = WorkflowToolProviderController(
entity=ToolProviderEntity(
identity=ToolProviderIdentity(
author=user.name if user else "",
name=provider.label,
label=I18nObject(en_US=provider.label, zh_Hans=provider.label),
description=I18nObject(en_US=provider.description, zh_Hans=provider.description),
icon=provider.icon,
),
credentials_schema=[],
plugin_id=None,
),
credentials_schema=[],
plugin_id=None,
),
provider_id=db_provider.id or "",
)
provider_id=provider.id or "",
)
# init tools
controller.tools = [controller._get_db_provider_tool(db_provider, app)]
controller.tools = [
controller._get_db_provider_tool(provider, app, session=session, user=user),
]
return controller
@ -74,7 +81,14 @@ class WorkflowToolProviderController(ToolProviderController):
def provider_type(self) -> ToolProviderType:
return ToolProviderType.WORKFLOW
def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool:
def _get_db_provider_tool(
self,
db_provider: WorkflowToolProvider,
app: App,
*,
session: Session,
user: Account | None = None,
) -> WorkflowTool:
"""
get db provider tool
:param db_provider: the db provider
@ -82,7 +96,7 @@ class WorkflowToolProviderController(ToolProviderController):
:return: the tool
"""
workflow: Workflow | None = (
db.session.query(Workflow)
session.query(Workflow)
.where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
.first()
)
@ -99,9 +113,7 @@ class WorkflowToolProviderController(ToolProviderController):
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
return next(filter(lambda x: x.variable == variable_name, variables), None) # type: ignore
user = db_provider.user
return next(filter(lambda x: x.variable == variable_name, variables), None)
workflow_tool_parameters = []
for parameter in parameters:
@ -187,22 +199,25 @@ class WorkflowToolProviderController(ToolProviderController):
if self.tools is not None:
return self.tools
db_providers: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.app_id == self.provider_id,
with Session(db.engine, expire_on_commit=False) as session, session.begin():
db_provider: WorkflowToolProvider | None = (
session.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.app_id == self.provider_id,
)
.first()
)
.first()
)
if not db_providers:
return []
if not db_providers.app:
raise ValueError("app not found")
if not db_provider:
return []
app = db_providers.app
self.tools = [self._get_db_provider_tool(db_providers, app)]
app = session.get(App, db_provider.app_id)
if not app:
raise ValueError("app not found")
user = session.get(Account, db_provider.user_id) if db_provider.user_id else None
self.tools = [self._get_db_provider_tool(db_provider, app, session=session, user=user)]
return self.tools

View File

@ -1,12 +1,14 @@
import json
import logging
from collections.abc import Generator
from typing import Any
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
from flask import has_request_context
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import (
@ -48,6 +50,7 @@ class WorkflowTool(Tool):
self.workflow_entities = workflow_entities
self.workflow_call_depth = workflow_call_depth
self.label = label
self._latest_usage = LLMUsage.empty_usage()
super().__init__(entity=entity, runtime=runtime)
@ -83,10 +86,11 @@ class WorkflowTool(Tool):
assert self.runtime.invoke_from is not None
user = self._resolve_user(user_id=user_id)
if user is None:
raise ToolInvokeError("User not found")
self._latest_usage = LLMUsage.empty_usage()
result = generator.generate(
app_model=app,
workflow=workflow,
@ -110,9 +114,68 @@ class WorkflowTool(Tool):
for file in files:
yield self.create_file_message(file) # type: ignore
self._latest_usage = self._derive_usage_from_result(data)
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
yield self.create_json_message(outputs)
@property
def latest_usage(self) -> LLMUsage:
return self._latest_usage
@classmethod
def _derive_usage_from_result(cls, data: Mapping[str, Any]) -> LLMUsage:
usage_dict = cls._extract_usage_dict(data)
if usage_dict is not None:
return LLMUsage.from_metadata(cast(LLMUsageMetadata, dict(usage_dict)))
total_tokens = data.get("total_tokens")
total_price = data.get("total_price")
if total_tokens is None and total_price is None:
return LLMUsage.empty_usage()
usage_metadata: dict[str, Any] = {}
if total_tokens is not None:
try:
usage_metadata["total_tokens"] = int(str(total_tokens))
except (TypeError, ValueError):
pass
if total_price is not None:
usage_metadata["total_price"] = str(total_price)
currency = data.get("currency")
if currency is not None:
usage_metadata["currency"] = currency
if not usage_metadata:
return LLMUsage.empty_usage()
return LLMUsage.from_metadata(cast(LLMUsageMetadata, usage_metadata))
@classmethod
def _extract_usage_dict(cls, payload: Mapping[str, Any]) -> Mapping[str, Any] | None:
usage_candidate = payload.get("usage")
if isinstance(usage_candidate, Mapping):
return usage_candidate
metadata_candidate = payload.get("metadata")
if isinstance(metadata_candidate, Mapping):
usage_candidate = metadata_candidate.get("usage")
if isinstance(usage_candidate, Mapping):
return usage_candidate
for value in payload.values():
if isinstance(value, Mapping):
found = cls._extract_usage_dict(value)
if found is not None:
return found
elif isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
for item in value:
if isinstance(item, Mapping):
found = cls._extract_usage_dict(item)
if found is not None:
return found
return None
def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
"""
fork a new tool with metadata
@ -179,16 +242,17 @@ class WorkflowTool(Tool):
"""
get the workflow by app id and version
"""
if not version:
workflow = (
db.session.query(Workflow)
.where(Workflow.app_id == app_id, Workflow.version != Workflow.VERSION_DRAFT)
.order_by(Workflow.created_at.desc())
.first()
)
else:
stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
workflow = db.session.scalar(stmt)
with Session(db.engine, expire_on_commit=False) as session, session.begin():
if not version:
stmt = (
select(Workflow)
.where(Workflow.app_id == app_id, Workflow.version != Workflow.VERSION_DRAFT)
.order_by(Workflow.created_at.desc())
)
workflow = session.scalars(stmt).first()
else:
stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
workflow = session.scalar(stmt)
if not workflow:
raise ValueError("workflow not found or not published")
@ -200,7 +264,8 @@ class WorkflowTool(Tool):
get the app by app id
"""
stmt = select(App).where(App.id == app_id)
app = db.session.scalar(stmt)
with Session(db.engine, expire_on_commit=False) as session, session.begin():
app = session.scalar(stmt)
if not app:
raise ValueError("app not found")

View File

@ -4,7 +4,7 @@ from .types import SegmentType
class SegmentGroup(Segment):
value_type: SegmentType = SegmentType.GROUP
value: list[Segment] = None # type: ignore
value: list[Segment]
@property
def text(self):

View File

@ -19,7 +19,7 @@ class Segment(BaseModel):
model_config = ConfigDict(frozen=True)
value_type: SegmentType
value: Any = None
value: Any
@field_validator("value_type")
@classmethod
@ -74,12 +74,12 @@ class NoneSegment(Segment):
class StringSegment(Segment):
value_type: SegmentType = SegmentType.STRING
value: str = None # type: ignore
value: str
class FloatSegment(Segment):
value_type: SegmentType = SegmentType.FLOAT
value: float = None # type: ignore
value: float
# NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
# The following tests cannot pass.
#
@ -98,12 +98,12 @@ class FloatSegment(Segment):
class IntegerSegment(Segment):
value_type: SegmentType = SegmentType.INTEGER
value: int = None # type: ignore
value: int
class ObjectSegment(Segment):
value_type: SegmentType = SegmentType.OBJECT
value: Mapping[str, Any] = None # type: ignore
value: Mapping[str, Any]
@property
def text(self) -> str:
@ -136,7 +136,7 @@ class ArraySegment(Segment):
class FileSegment(Segment):
value_type: SegmentType = SegmentType.FILE
value: File = None # type: ignore
value: File
@property
def markdown(self) -> str:
@ -153,17 +153,17 @@ class FileSegment(Segment):
class BooleanSegment(Segment):
value_type: SegmentType = SegmentType.BOOLEAN
value: bool = None # type: ignore
value: bool
class ArrayAnySegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_ANY
value: Sequence[Any] = None # type: ignore
value: Sequence[Any]
class ArrayStringSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_STRING
value: Sequence[str] = None # type: ignore
value: Sequence[str]
@property
def text(self) -> str:
@ -175,17 +175,17 @@ class ArrayStringSegment(ArraySegment):
class ArrayNumberSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_NUMBER
value: Sequence[float | int] = None # type: ignore
value: Sequence[float | int]
class ArrayObjectSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_OBJECT
value: Sequence[Mapping[str, Any]] = None # type: ignore
value: Sequence[Mapping[str, Any]]
class ArrayFileSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_FILE
value: Sequence[File] = None # type: ignore
value: Sequence[File]
@property
def markdown(self) -> str:
@ -247,7 +247,7 @@ class VersionedMemorySegment(Segment):
class ArrayBooleanSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
value: Sequence[bool] = None # type: ignore
value: Sequence[bool]
def get_segment_discriminator(v: Any) -> SegmentType | None:
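The segment changes above replace `value: X = None  # type: ignore` with a plain `value: X`, turning each value into a genuinely required Pydantic field instead of one that silently defaults to None. A tiny sketch of the behavioural difference (the class name is hypothetical):

from pydantic import BaseModel, ValidationError

class StringSegmentSketch(BaseModel):
    value: str  # required: no implicit None default anymore

try:
    StringSegmentSketch()  # missing value now fails at construction time
except ValidationError:
    print("value is required")

print(StringSegmentSketch(value="ok").value)  # ok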

View File

@ -1,3 +1,5 @@
from ..runtime.graph_runtime_state import GraphRuntimeState
from ..runtime.variable_pool import VariablePool
from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .workflow_execution import WorkflowExecution
@ -6,6 +8,8 @@ from .workflow_node_execution import WorkflowNodeExecution
__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
"GraphRuntimeState",
"VariablePool",
"WorkflowExecution",
"WorkflowNodeExecution",
]

View File

@ -3,11 +3,12 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Protocol, cast, final
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
from core.workflow.nodes.base.node import Node
from libs.typing import is_str, is_str_dict
from .edge import Edge
from .validation import get_graph_validator
logger = logging.getLogger(__name__)
@ -201,6 +202,17 @@ class Graph:
return GraphBuilder(graph_cls=cls)
@classmethod
def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None:
"""
Promote nodes configured with FAIL_BRANCH error strategy to branch execution type.
:param nodes: mapping of node ID to node instance
"""
for node in nodes.values():
if node.error_strategy == ErrorStrategy.FAIL_BRANCH:
node.execution_type = NodeExecutionType.BRANCH
@classmethod
def _mark_inactive_root_branches(
cls,
@ -307,6 +319,9 @@ class Graph:
# Create node instances
nodes = cls._create_node_instances(node_configs_map, node_factory)
# Promote fail-branch nodes to branch execution type at graph level
cls._promote_fail_branch_nodes(nodes)
# Get root node instance
root_node = nodes[root_node_id]
@ -314,7 +329,7 @@ class Graph:
cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)
# Create and return the graph
return cls(
graph = cls(
nodes=nodes,
edges=edges,
in_edges=in_edges,
@ -322,6 +337,11 @@ class Graph:
root_node=root_node,
)
# Validate the graph structure using built-in validators
get_graph_validator().validate(graph)
return graph
@property
def node_ids(self) -> list[str]:
"""

View File

@ -0,0 +1,125 @@
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Protocol
from core.workflow.enums import NodeExecutionType, NodeType
if TYPE_CHECKING:
from .graph import Graph
@dataclass(frozen=True, slots=True)
class GraphValidationIssue:
"""Immutable value object describing a single validation issue."""
code: str
message: str
node_id: str | None = None
class GraphValidationError(ValueError):
"""Raised when graph validation fails."""
def __init__(self, issues: Sequence[GraphValidationIssue]) -> None:
if not issues:
raise ValueError("GraphValidationError requires at least one issue.")
self.issues: tuple[GraphValidationIssue, ...] = tuple(issues)
message = "; ".join(f"[{issue.code}] {issue.message}" for issue in self.issues)
super().__init__(message)
class GraphValidationRule(Protocol):
"""Protocol that individual validation rules must satisfy."""
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
"""Validate the provided graph and return any discovered issues."""
...
@dataclass(frozen=True, slots=True)
class _EdgeEndpointValidator:
"""Ensures all edges reference existing nodes."""
missing_node_code: str = "MISSING_NODE"
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
issues: list[GraphValidationIssue] = []
for edge in graph.edges.values():
if edge.tail not in graph.nodes:
issues.append(
GraphValidationIssue(
code=self.missing_node_code,
message=f"Edge {edge.id} references unknown source node '{edge.tail}'.",
node_id=edge.tail,
)
)
if edge.head not in graph.nodes:
issues.append(
GraphValidationIssue(
code=self.missing_node_code,
message=f"Edge {edge.id} references unknown target node '{edge.head}'.",
node_id=edge.head,
)
)
return issues
@dataclass(frozen=True, slots=True)
class _RootNodeValidator:
"""Validates root node invariants."""
invalid_root_code: str = "INVALID_ROOT"
container_entry_types: tuple[NodeType, ...] = (NodeType.ITERATION_START, NodeType.LOOP_START)
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
root_node = graph.root_node
issues: list[GraphValidationIssue] = []
if root_node.id not in graph.nodes:
issues.append(
GraphValidationIssue(
code=self.invalid_root_code,
message=f"Root node '{root_node.id}' is missing from the node registry.",
node_id=root_node.id,
)
)
return issues
node_type = getattr(root_node, "node_type", None)
if root_node.execution_type != NodeExecutionType.ROOT and node_type not in self.container_entry_types:
issues.append(
GraphValidationIssue(
code=self.invalid_root_code,
message=f"Root node '{root_node.id}' must declare execution type 'root'.",
node_id=root_node.id,
)
)
return issues
@dataclass(frozen=True, slots=True)
class GraphValidator:
"""Coordinates execution of graph validation rules."""
rules: tuple[GraphValidationRule, ...]
def validate(self, graph: Graph) -> None:
"""Validate the graph against all configured rules."""
issues: list[GraphValidationIssue] = []
for rule in self.rules:
issues.extend(rule.validate(graph))
if issues:
raise GraphValidationError(issues)
_DEFAULT_RULES: tuple[GraphValidationRule, ...] = (
_EdgeEndpointValidator(),
_RootNodeValidator(),
)
def get_graph_validator() -> GraphValidator:
"""Construct the validator composed of default rules."""
return GraphValidator(_DEFAULT_RULES)
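The validator above composes independent rules behind a small Protocol and raises once with every discovered issue attached. A toy, framework-free sketch of how that composition is used (the duck-typed graph dict and rule below are illustrative, not the real Graph API):

from dataclasses import dataclass

@dataclass(frozen=True)
class Issue:
    code: str
    message: str

class NonEmptyGraphRule:
    def validate(self, graph: dict) -> list[Issue]:
        return [] if graph.get("nodes") else [Issue("EMPTY_GRAPH", "graph has no nodes")]

def run_rules(graph: dict, rules: list) -> None:
    issues = [issue for rule in rules for issue in rule.validate(graph)]
    if issues:
        raise ValueError("; ".join(f"[{i.code}] {i.message}" for i in issues))

run_rules({"nodes": {"start": {}}}, [NonEmptyGraphRule()])  # passes silently
# run_rules({"nodes": {}}, [NonEmptyGraphRule()])           # would raise ValueError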

View File

@ -26,8 +26,8 @@ class AgentNodeData(BaseNodeData):
class ParamsAutoGenerated(IntEnum):
CLOSE = auto()
OPEN = auto()
CLOSE = 0
OPEN = 1
class AgentOldVersionModelFeatures(StrEnum):

View File

@ -1,4 +1,5 @@
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
from .usage_tracking_mixin import LLMUsageTrackingMixin
__all__ = [
"BaseIterationNodeData",
@ -6,4 +7,5 @@ __all__ = [
"BaseLoopNodeData",
"BaseLoopState",
"BaseNodeData",
"LLMUsageTrackingMixin",
]

View File

@ -1,5 +1,6 @@
import json
from abc import ABC
from builtins import type as type_
from collections.abc import Sequence
from enum import StrEnum
from typing import Any, Union
@ -58,10 +59,9 @@ class DefaultValue(BaseModel):
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
@staticmethod
def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
"""Unified array type validation"""
# FIXME, type ignore here for do not find the reason mypy complain, if find the root cause, please fix it
return isinstance(value, list) and all(isinstance(x, element_type) for x in value) # type: ignore
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
@staticmethod
def _convert_number(value: str) -> float:

View File

@ -0,0 +1,28 @@
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.runtime import GraphRuntimeState
class LLMUsageTrackingMixin:
"""Provides shared helpers for merging and recording LLM usage within workflow nodes."""
graph_runtime_state: GraphRuntimeState
@staticmethod
def _merge_usage(current: LLMUsage, new_usage: LLMUsage | None) -> LLMUsage:
"""Return a combined usage snapshot, preserving zero-value inputs."""
if new_usage is None or new_usage.total_tokens <= 0:
return current
if current.total_tokens == 0:
return new_usage
return current.plus(new_usage)
def _accumulate_usage(self, usage: LLMUsage) -> None:
"""Push usage into the graph runtime accumulator for downstream reporting."""
if usage.total_tokens <= 0:
return
current_usage = self.graph_runtime_state.llm_usage
if current_usage.total_tokens == 0:
self.graph_runtime_state.llm_usage = usage.model_copy()
else:
self.graph_runtime_state.llm_usage = current_usage.plus(usage)
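The merge rule above has two deliberate edge cases: a missing or zero-token snapshot never overwrites real numbers, and a zero running total is replaced outright rather than summed. A standalone sketch with a stand-in usage type:

from dataclasses import dataclass

@dataclass(frozen=True)
class Usage:
    total_tokens: int = 0

    def plus(self, other: "Usage") -> "Usage":
        return Usage(self.total_tokens + other.total_tokens)

def merge_usage(current: Usage, new: Usage | None) -> Usage:
    if new is None or new.total_tokens <= 0:
        return current
    if current.total_tokens == 0:
        return new
    return current.plus(new)

print(merge_usage(Usage(0), Usage(120)).total_tokens)   # 120
print(merge_usage(Usage(120), Usage(30)).total_tokens)  # 150
print(merge_usage(Usage(120), None).total_tokens)       # 120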

View File

@ -10,10 +10,10 @@ from typing import Any
import chardet
import docx
import pandas as pd
import pypandoc # type: ignore
import pypdfium2 # type: ignore
import webvtt # type: ignore
import yaml # type: ignore
import pypandoc
import pypdfium2
import webvtt
import yaml
from docx.document import Document
from docx.oxml.table import CT_Tbl
from docx.oxml.text.paragraph import CT_P

View File

@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, NewType, cast
from flask import Flask, current_app
from typing_extensions import TypeIs
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.variables import VariableUnion
@ -34,6 +35,7 @@ from core.workflow.node_events import (
NodeRunResult,
StreamCompletedEvent,
)
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
@ -58,7 +60,7 @@ logger = logging.getLogger(__name__)
EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)
class IterationNode(Node):
class IterationNode(LLMUsageTrackingMixin, Node):
"""
Iteration Node.
"""
@ -118,6 +120,7 @@ class IterationNode(Node):
started_at = naive_utc_now()
iter_run_map: dict[str, float] = {}
outputs: list[object] = []
usage_accumulator = [LLMUsage.empty_usage()]
yield IterationStartedEvent(
start_at=started_at,
@ -130,22 +133,27 @@ class IterationNode(Node):
iterator_list_value=iterator_list_value,
outputs=outputs,
iter_run_map=iter_run_map,
usage_accumulator=usage_accumulator,
)
self._accumulate_usage(usage_accumulator[0])
yield from self._handle_iteration_success(
started_at=started_at,
inputs=inputs,
outputs=outputs,
iterator_list_value=iterator_list_value,
iter_run_map=iter_run_map,
usage=usage_accumulator[0],
)
except IterationNodeError as e:
self._accumulate_usage(usage_accumulator[0])
yield from self._handle_iteration_failure(
started_at=started_at,
inputs=inputs,
outputs=outputs,
iterator_list_value=iterator_list_value,
iter_run_map=iter_run_map,
usage=usage_accumulator[0],
error=e,
)
@ -196,6 +204,7 @@ class IterationNode(Node):
iterator_list_value: Sequence[object],
outputs: list[object],
iter_run_map: dict[str, float],
usage_accumulator: list[LLMUsage],
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
if self._node_data.is_parallel:
# Parallel mode execution
@ -203,6 +212,7 @@ class IterationNode(Node):
iterator_list_value=iterator_list_value,
outputs=outputs,
iter_run_map=iter_run_map,
usage_accumulator=usage_accumulator,
)
else:
# Sequential mode execution
@ -228,6 +238,9 @@ class IterationNode(Node):
# Update the total tokens from this iteration
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
usage_accumulator[0] = self._merge_usage(
usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage
)
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
def _execute_parallel_iterations(
@ -235,6 +248,7 @@ class IterationNode(Node):
iterator_list_value: Sequence[object],
outputs: list[object],
iter_run_map: dict[str, float],
usage_accumulator: list[LLMUsage],
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
# Initialize outputs list with None values to maintain order
outputs.extend([None] * len(iterator_list_value))
@ -245,7 +259,16 @@ class IterationNode(Node):
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all iteration tasks
future_to_index: dict[
Future[tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]],
Future[
tuple[
datetime,
list[GraphNodeEventBase],
object | None,
int,
dict[str, VariableUnion],
LLMUsage,
]
],
int,
] = {}
for index, item in enumerate(iterator_list_value):
@ -264,7 +287,14 @@ class IterationNode(Node):
index = future_to_index[future]
try:
result = future.result()
iter_start_at, events, output_value, tokens_used, conversation_snapshot = result
(
iter_start_at,
events,
output_value,
tokens_used,
conversation_snapshot,
iteration_usage,
) = result
# Update outputs at the correct index
outputs[index] = output_value
@ -276,6 +306,8 @@ class IterationNode(Node):
self.graph_runtime_state.total_tokens += tokens_used
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage)
# Sync conversation variables after iteration completion
self._sync_conversation_variables_from_snapshot(conversation_snapshot)
@ -303,7 +335,7 @@ class IterationNode(Node):
item: object,
flask_app: Flask,
context_vars: contextvars.Context,
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]:
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion], LLMUsage]:
"""Execute a single iteration in parallel mode and return results."""
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
@ -332,6 +364,7 @@ class IterationNode(Node):
output_value,
graph_engine.graph_runtime_state.total_tokens,
conversation_snapshot,
graph_engine.graph_runtime_state.llm_usage,
)
def _handle_iteration_success(
@ -341,6 +374,8 @@ class IterationNode(Node):
outputs: list[object],
iterator_list_value: Sequence[object],
iter_run_map: dict[str, float],
*,
usage: LLMUsage,
) -> Generator[NodeEventBase, None, None]:
# Flatten the list of lists if all outputs are lists
flattened_outputs = self._flatten_outputs_if_needed(outputs)
@ -351,7 +386,9 @@ class IterationNode(Node):
outputs={"output": flattened_outputs},
steps=len(iterator_list_value),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
},
)
@ -362,8 +399,11 @@ class IterationNode(Node):
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": flattened_outputs},
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage,
)
)
@ -400,6 +440,8 @@ class IterationNode(Node):
outputs: list[object],
iterator_list_value: Sequence[object],
iter_run_map: dict[str, float],
*,
usage: LLMUsage,
error: IterationNodeError,
) -> Generator[NodeEventBase, None, None]:
# Flatten the list of lists if all outputs are lists (even in failure case)
@ -411,7 +453,9 @@ class IterationNode(Node):
outputs={"output": flattened_outputs},
steps=len(iterator_list_value),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
},
error=str(error),
@ -420,6 +464,12 @@ class IterationNode(Node):
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(error),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage,
)
)
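usage_accumulator above is a one-element list rather than a bare value because both the sequential and parallel helpers mutate it in place; the caller then reads the final total without each helper having to return it. A minimal sketch of that shared-cell pattern (the worker and numbers are illustrative):

from concurrent.futures import ThreadPoolExecutor

def run_iteration(item: int) -> int:
    return item * 10  # stand-in for the tokens consumed by one iteration

accumulator = [0]  # one-element list shared as a mutable cell, like usage_accumulator

def collect(tokens: int, acc: list[int]) -> None:
    acc[0] += tokens  # the mutation is visible to the caller through the shared list

with ThreadPoolExecutor(max_workers=4) as executor:
    for tokens in executor.map(run_iteration, range(5)):
        collect(tokens, accumulator)

print(accumulator[0])  # 100 -> total across all five iterations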

View File

@ -15,14 +15,11 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import (
PromptMessageRole,
)
from core.model_runtime.entities.model_entities import (
ModelFeature,
ModelType,
)
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.metadata_entities import Condition, MetadataCondition
@ -33,8 +30,14 @@ from core.variables import (
)
from core.variables.segments import ArrayObjectSegment
from core.workflow.entities import GraphInitParams
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import (
ErrorStrategy,
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
@ -80,7 +83,7 @@ default_retrieval_model = {
}
class KnowledgeRetrievalNode(Node):
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node):
node_type = NodeType.KNOWLEDGE_RETRIEVAL
_node_data: KnowledgeRetrievalNodeData
@ -141,7 +144,7 @@ class KnowledgeRetrievalNode(Node):
def version(cls):
return "1"
def _run(self) -> NodeRunResult: # type: ignore
def _run(self) -> NodeRunResult:
# extract variables
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
if not isinstance(variable, StringSegment):
@ -182,14 +185,21 @@ class KnowledgeRetrievalNode(Node):
)
# retrieve knowledge
usage = LLMUsage.empty_usage()
try:
results = self._fetch_dataset_retriever(node_data=self._node_data, query=query)
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, query=query)
outputs = {"result": ArrayObjectSegment(value=results)}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
process_data={},
process_data={"usage": jsonable_encoder(usage)},
outputs=outputs, # type: ignore
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage,
)
except KnowledgeRetrievalNodeError as e:
@ -199,6 +209,7 @@ class KnowledgeRetrievalNode(Node):
inputs=variables,
error=str(e),
error_type=type(e).__name__,
llm_usage=usage,
)
# Temporary handle all exceptions from DatasetRetrieval class here.
except Exception as e:
@ -207,11 +218,15 @@ class KnowledgeRetrievalNode(Node):
inputs=variables,
error=str(e),
error_type=type(e).__name__,
llm_usage=usage,
)
finally:
db.session.close()
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]:
def _fetch_dataset_retriever(
self, node_data: KnowledgeRetrievalNodeData, query: str
) -> tuple[list[dict[str, Any]], LLMUsage]:
usage = LLMUsage.empty_usage()
available_datasets = []
dataset_ids = node_data.dataset_ids
@ -245,9 +260,10 @@ class KnowledgeRetrievalNode(Node):
if not dataset:
continue
available_datasets.append(dataset)
metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition(
metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition(
[dataset.id for dataset in available_datasets], query, node_data
)
usage = self._merge_usage(usage, metadata_usage)
all_documents = []
dataset_retrieval = DatasetRetrieval()
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
@ -330,6 +346,8 @@ class KnowledgeRetrievalNode(Node):
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
)
usage = self._merge_usage(usage, dataset_retrieval.llm_usage)
dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"]
retrieval_resource_list = []
@ -406,11 +424,12 @@ class KnowledgeRetrievalNode(Node):
)
for position, item in enumerate(retrieval_resource_list, start=1):
item["metadata"]["position"] = position
return retrieval_resource_list
return retrieval_resource_list, usage
def _get_metadata_filter_condition(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]:
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]:
usage = LLMUsage.empty_usage()
document_query = db.session.query(Document).where(
Document.dataset_id.in_(dataset_ids),
Document.indexing_status == "completed",
@ -420,9 +439,12 @@ class KnowledgeRetrievalNode(Node):
filters: list[Any] = []
metadata_condition = None
if node_data.metadata_filtering_mode == "disabled":
return None, None
return None, None, usage
elif node_data.metadata_filtering_mode == "automatic":
automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data)
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
dataset_ids, query, node_data
)
usage = self._merge_usage(usage, automatic_usage)
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
@ -443,7 +465,7 @@ class KnowledgeRetrievalNode(Node):
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator
if node_data.metadata_filtering_conditions
else "or", # type: ignore
else "or",
conditions=conditions,
)
elif node_data.metadata_filtering_mode == "manual":
@ -457,10 +479,10 @@ class KnowledgeRetrievalNode(Node):
expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value
).value[0]
if expected_value.value_type in {"number", "integer", "float"}: # type: ignore
expected_value = expected_value.value # type: ignore
elif expected_value.value_type == "string": # type: ignore
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore
if expected_value.value_type in {"number", "integer", "float"}:
expected_value = expected_value.value
elif expected_value.value_type == "string":
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
else:
raise ValueError("Invalid expected metadata value type")
conditions.append(
@ -487,7 +509,7 @@ class KnowledgeRetrievalNode(Node):
if (
node_data.metadata_filtering_conditions
and node_data.metadata_filtering_conditions.logical_operator == "and"
): # type: ignore
):
document_query = document_query.where(and_(*filters))
else:
document_query = document_query.where(or_(*filters))
@ -496,11 +518,12 @@ class KnowledgeRetrievalNode(Node):
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
for document in documents:
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
return metadata_filter_document_ids, metadata_condition
return metadata_filter_document_ids, metadata_condition, usage
def _automatic_metadata_filter_func(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> list[dict[str, Any]]:
) -> tuple[list[dict[str, Any]], LLMUsage]:
usage = LLMUsage.empty_usage()
# get all metadata field
stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
metadata_fields = db.session.scalars(stmt).all()
@ -548,6 +571,7 @@ class KnowledgeRetrievalNode(Node):
for event in generator:
if isinstance(event, ModelInvokeCompletedEvent):
result_text = event.text
usage = self._merge_usage(usage, event.usage)
break
result_text_json = parse_and_check_json_markdown(result_text, [])
@ -564,8 +588,8 @@ class KnowledgeRetrievalNode(Node):
}
)
except Exception:
return []
return automatic_metadata_filters
return [], usage
return automatic_metadata_filters, usage
def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]

View File

@ -452,10 +452,14 @@ class LLMNode(Node):
usage = LLMUsage.empty_usage()
finish_reason = None
full_text_buffer = io.StringIO()
collected_structured_output = None # Collect structured_output from streaming chunks
# Consume the invoke result and handle generator exception
try:
for result in invoke_result:
if isinstance(result, LLMResultChunkWithStructuredOutput):
# Collect structured_output from the chunk
if result.structured_output is not None:
collected_structured_output = dict(result.structured_output)
yield result
if isinstance(result, LLMResultChunk):
contents = result.delta.message.content
@ -503,6 +507,8 @@ class LLMNode(Node):
finish_reason=finish_reason,
# Reasoning content for workflow variables and downstream nodes
reasoning_content=reasoning_content,
# Pass structured output if collected from streaming chunks
structured_output=collected_structured_output,
)
@staticmethod

View File

@ -5,6 +5,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, cast
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import Segment, SegmentType
from core.workflow.enums import (
ErrorStrategy,
@ -27,6 +28,7 @@ from core.workflow.node_events import (
NodeRunResult,
StreamCompletedEvent,
)
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
@ -40,7 +42,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class LoopNode(Node):
class LoopNode(LLMUsageTrackingMixin, Node):
"""
Loop Node.
"""
@ -108,7 +110,7 @@ class LoopNode(Node):
raise ValueError(f"Invalid value for loop variable {loop_variable.label}")
variable_selector = [self._node_id, loop_variable.label]
variable = segment_to_variable(segment=processed_segment, selector=variable_selector)
self.graph_runtime_state.variable_pool.add(variable_selector, variable)
self.graph_runtime_state.variable_pool.add(variable_selector, variable.value)
loop_variable_selectors[loop_variable.label] = variable_selector
inputs[loop_variable.label] = processed_segment.value
@ -117,6 +119,7 @@ class LoopNode(Node):
loop_duration_map: dict[str, float] = {}
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
loop_usage = LLMUsage.empty_usage()
# Start Loop event
yield LoopStartedEvent(
@ -163,6 +166,9 @@ class LoopNode(Node):
# Update the total tokens from this iteration
cost_tokens += graph_engine.graph_runtime_state.total_tokens
# Accumulate usage from the sub-graph execution
loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage)
# Collect loop variable values after iteration
single_loop_variable = {}
for key, selector in loop_variable_selectors.items():
@ -189,6 +195,7 @@ class LoopNode(Node):
)
self.graph_runtime_state.total_tokens += cost_tokens
self._accumulate_usage(loop_usage)
# Loop completed successfully
yield LoopSucceededEvent(
start_at=start_at,
@ -196,7 +203,9 @@ class LoopNode(Node):
outputs=self._node_data.outputs,
steps=loop_count,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: cost_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
"completed_reason": "loop_break" if reach_break_condition else "loop_completed",
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
@ -207,22 +216,28 @@ class LoopNode(Node):
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
outputs=self._node_data.outputs,
inputs=inputs,
llm_usage=loop_usage,
)
)
except Exception as e:
self._accumulate_usage(loop_usage)
yield LoopFailedEvent(
start_at=start_at,
inputs=inputs,
steps=loop_count,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
"completed_reason": "error",
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
@ -235,10 +250,13 @@ class LoopNode(Node):
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
llm_usage=loop_usage,
)
)

View File

@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, final
from typing_extensions import override
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.enums import NodeType
from core.workflow.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from libs.typing import is_str, is_str_dict
@ -82,8 +82,4 @@ class DifyNodeFactory(NodeFactory):
raise ValueError(f"Node {node_id} missing data information")
node_instance.init_node_data(node_data)
# If node has fail branch, change execution type to branch
if node_instance.error_strategy == ErrorStrategy.FAIL_BRANCH:
node_instance.execution_type = NodeExecutionType.BRANCH
return node_instance

View File

@ -747,7 +747,7 @@ class ParameterExtractorNode(Node):
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
role=PromptMessageRole.SYSTEM,
text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str).replace("{{instructions}}", instruction),
text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction),
)
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
return [system_prompt_messages, user_prompt_message]

View File

@ -135,7 +135,7 @@ Here are the chat histories between human and assistant, inside <histories></his
### Instructions:
Some extra information are provided below, you should always follow the instructions as possible as you can.
<instructions>
{{instructions}}
{instructions}
</instructions>
"""

View File

@ -6,10 +6,13 @@ from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file import File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.tools.workflow_as_tool.tool import WorkflowTool
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.enums import (
@ -136,13 +139,14 @@ class ToolNode(Node):
try:
# convert tool messages
yield from self._transform_message(
_ = yield from self._transform_message(
messages=message_stream,
tool_info=tool_info,
parameters_for_log=parameters_for_log,
user_id=self.user_id,
tenant_id=self.tenant_id,
node_id=self._node_id,
tool_runtime=tool_runtime,
)
except ToolInvokeError as e:
yield StreamCompletedEvent(
@ -236,7 +240,8 @@ class ToolNode(Node):
user_id: str,
tenant_id: str,
node_id: str,
) -> Generator:
tool_runtime: Tool,
) -> Generator[NodeEventBase, None, LLMUsage]:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
@ -424,17 +429,34 @@ class ToolNode(Node):
is_final=True,
)
usage = self._extract_tool_usage(tool_runtime)
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
}
if usage.total_tokens > 0:
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
metadata={
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
},
metadata=metadata,
inputs=parameters_for_log,
llm_usage=usage,
)
)
return usage
@staticmethod
def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage:
if isinstance(tool_runtime, WorkflowTool):
return tool_runtime.latest_usage
return LLMUsage.empty_usage()
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
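The `_ = yield from self._transform_message(...)` change and the new `Generator[NodeEventBase, None, LLMUsage]` signature rely on a core generator feature: a delegating `yield from` expression re-yields every event from the inner generator and also evaluates to whatever that generator returns. A minimal standalone illustration:

from collections.abc import Generator

def transform() -> Generator[str, None, int]:
    yield "chunk-1"
    yield "chunk-2"
    return 42  # final value handed back to the delegating generator

def run() -> Generator[str, None, None]:
    total = yield from transform()  # re-yields the chunks, captures the return value
    yield f"usage={total}"

print(list(run()))  # ['chunk-1', 'chunk-2', 'usage=42']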

View File

@ -277,7 +277,7 @@ class VariablePool(BaseModel):
# This ensures that we can keep the id of the system variables intact.
if self._has(selector):
continue
self.add(selector, value) # type: ignore
self.add(selector, value)
@classmethod
def empty(cls) -> "VariablePool":

View File

@ -32,7 +32,8 @@ if [[ "${MODE}" == "worker" ]]; then
exec celery -A celery_entrypoint.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \
--max-tasks-per-child ${MAX_TASKS_PER_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \
-Q ${CELERY_QUEUES:-dataset,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation}
-Q ${CELERY_QUEUES:-dataset,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation} \
--prefetch-multiplier=1
elif [[ "${MODE}" == "beat" ]]; then
exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}

View File

@ -6,8 +6,8 @@ from tasks.clean_dataset_task import clean_dataset_task
@dataset_was_deleted.connect
def handle(sender: Dataset, **kwargs):
dataset = sender
assert dataset.doc_form
assert dataset.indexing_technique
if not dataset.doc_form or not dataset.indexing_technique:
return
clean_dataset_task.delay(
dataset.id,
dataset.tenant_id,

View File

@ -8,6 +8,6 @@ def handle(sender, **kwargs):
dataset_id = kwargs.get("dataset_id")
doc_form = kwargs.get("doc_form")
file_id = kwargs.get("file_id")
assert dataset_id is not None
assert doc_form is not None
if not dataset_id or not doc_form:
return
clean_document_task.delay(document_id, dataset_id, doc_form, file_id)

View File

@ -2,6 +2,11 @@ from configs import dify_config
from constants import HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN, HEADER_NAME_PASSPORT
from dify_app import DifyApp
BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEADER_NAME_PASSPORT)
SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization")
AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN)
FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
def init_app(app: DifyApp):
# register blueprint routers
@ -17,7 +22,7 @@ def init_app(app: DifyApp):
CORS(
service_api_bp,
allow_headers=["Content-Type", "Authorization", HEADER_NAME_APP_CODE],
allow_headers=list(SERVICE_API_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
)
app.register_blueprint(service_api_bp)
@ -27,11 +32,11 @@ def init_app(app: DifyApp):
resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
allow_headers=[
"Content-Type",
"Authorization",
HEADER_NAME_APP_CODE,
HEADER_NAME_CSRF_TOKEN,
HEADER_NAME_PASSPORT
"Content-Type",
"Authorization",
HEADER_NAME_APP_CODE,
HEADER_NAME_CSRF_TOKEN,
HEADER_NAME_PASSPORT,
],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
@ -42,7 +47,7 @@ def init_app(app: DifyApp):
console_app_bp,
resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
allow_headers=["Content-Type", "Authorization", HEADER_NAME_CSRF_TOKEN],
allow_headers=list(AUTHENTICATED_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
)
@ -50,7 +55,7 @@ def init_app(app: DifyApp):
CORS(
files_bp,
allow_headers=["Content-Type", HEADER_NAME_CSRF_TOKEN],
allow_headers=list(FILES_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
)
app.register_blueprint(files_bp)
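
A quick sketch of how the shared header tuples above compose, assuming the imported header-name constants resolve to the usual X-App-Code / X-App-Passport / X-CSRF-Token strings (their values are not shown in this diff):

# Illustrative expansion of the tuple composition; the literal header names are assumptions here.
BASE_CORS_HEADERS = ("Content-Type", "X-App-Code", "X-App-Passport")
SERVICE_API_HEADERS = (*BASE_CORS_HEADERS, "Authorization")
AUTHENTICATED_HEADERS = (*SERVICE_API_HEADERS, "X-CSRF-Token")
FILES_HEADERS = (*BASE_CORS_HEADERS, "X-CSRF-Token")

# Each blueprint then receives list(...) of the tuple it needs, e.g. list(AUTHENTICATED_HEADERS)
# == ["Content-Type", "X-App-Code", "X-App-Passport", "Authorization", "X-CSRF-Token"]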

View File

@ -7,7 +7,7 @@ def is_enabled() -> bool:
def init_app(app: DifyApp):
from flask_compress import Compress # type: ignore
from flask_compress import Compress
compress = Compress()
compress.init_app(app)

View File

@ -1,6 +1,6 @@
import json
import flask_login # type: ignore
import flask_login
from flask import Response, request
from flask_login import user_loaded_from_request, user_logged_in
from werkzeug.exceptions import NotFound, Unauthorized

View File

@ -2,7 +2,7 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
import flask_migrate # type: ignore
import flask_migrate
from extensions.ext_database import db

View File

@ -103,7 +103,7 @@ def init_app(app: DifyApp):
def shutdown_tracer():
provider = trace.get_tracer_provider()
if hasattr(provider, "force_flush"):
provider.force_flush() # ty: ignore [call-non-callable]
provider.force_flush()
class ExceptionLoggingHandler(logging.Handler):
"""Custom logging handler that creates spans for logging.exception() calls"""

View File

@ -6,4 +6,4 @@ def init_app(app: DifyApp):
if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED:
from werkzeug.middleware.proxy_fix import ProxyFix
app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1) # type: ignore
app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1) # type: ignore[method-assign]

View File

@ -5,7 +5,7 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
if dify_config.SENTRY_DSN:
import sentry_sdk
from langfuse import parse_error # type: ignore
from langfuse import parse_error
from sentry_sdk.integrations.celery import CeleryIntegration
from sentry_sdk.integrations.flask import FlaskIntegration
from werkzeug.exceptions import HTTPException

View File

@ -1,7 +1,7 @@
import posixpath
from collections.abc import Generator
import oss2 as aliyun_s3 # type: ignore
import oss2 as aliyun_s3
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@ -2,9 +2,9 @@ import base64
import hashlib
from collections.abc import Generator
from baidubce.auth.bce_credentials import BceCredentials # type: ignore
from baidubce.bce_client_configuration import BceClientConfiguration # type: ignore
from baidubce.services.bos.bos_client import BosClient # type: ignore
from baidubce.auth.bce_credentials import BceCredentials
from baidubce.bce_client_configuration import BceClientConfiguration
from baidubce.services.bos.bos_client import BosClient
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@ -11,7 +11,7 @@ from collections.abc import Generator
from io import BytesIO
from pathlib import Path
import clickzetta # type: ignore[import]
import clickzetta
from pydantic import BaseModel, model_validator
from extensions.storage.base_storage import BaseStorage

View File

@ -34,7 +34,7 @@ class VolumePermissionManager:
# Support two initialization methods: connection object or configuration dictionary
if isinstance(connection_or_config, dict):
# Create connection from configuration dictionary
import clickzetta # type: ignore[import-untyped]
import clickzetta
config = connection_or_config
self._connection = clickzetta.connect(

View File

@ -3,7 +3,7 @@ import io
import json
from collections.abc import Generator
from google.cloud import storage as google_cloud_storage # type: ignore
from google.cloud import storage as google_cloud_storage
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@ -1,6 +1,6 @@
from collections.abc import Generator
from obs import ObsClient # type: ignore
from obs import ObsClient
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@ -1,7 +1,7 @@
from collections.abc import Generator
import boto3 # type: ignore
from botocore.exceptions import ClientError # type: ignore
import boto3
from botocore.exceptions import ClientError
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@ -1,6 +1,6 @@
from collections.abc import Generator
from qcloud_cos import CosConfig, CosS3Client # type: ignore
from qcloud_cos import CosConfig, CosS3Client
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@ -1,6 +1,6 @@
from collections.abc import Generator
import tos # type: ignore
import tos
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@ -146,6 +146,6 @@ class ExternalApi(Api):
kwargs["doc"] = dify_config.SWAGGER_UI_PATH if dify_config.SWAGGER_UI_ENABLED else False
# manual separate call on construction and init_app to ensure configs in kwargs effective
super().__init__(app=None, *args, **kwargs) # type: ignore
super().__init__(app=None, *args, **kwargs)
self.init_app(app, **kwargs)
register_external_error_handlers(self)

View File

@ -23,7 +23,7 @@ from hashlib import sha1
import Crypto.Hash.SHA1
import Crypto.Util.number
import gmpy2 # type: ignore
import gmpy2
from Crypto import Random
from Crypto.Signature.pss import MGF1
from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes
@ -136,7 +136,7 @@ class PKCS1OAepCipher:
# Step 3a (OS2IP)
em_int = bytes_to_long(em)
# Step 3b (RSAEP)
m_int = gmpy2.powmod(em_int, self._key.e, self._key.n) # ty: ignore [unresolved-attribute]
m_int = gmpy2.powmod(em_int, self._key.e, self._key.n)
# Step 3c (I2OSP)
c = long_to_bytes(m_int, k)
return c
@ -169,7 +169,7 @@ class PKCS1OAepCipher:
ct_int = bytes_to_long(ciphertext)
# Step 2b (RSADP)
# m_int = self._key._decrypt(ct_int)
m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n) # ty: ignore [unresolved-attribute]
m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n)
# Complete step 2c (I2OSP)
em = long_to_bytes(m_int, k)
# Step 3a
@ -191,12 +191,12 @@ class PKCS1OAepCipher:
# Step 3g
one_pos = hLen + db[hLen:].find(b"\x01")
lHash1 = db[:hLen]
invalid = bord(y) | int(one_pos < hLen) # type: ignore
invalid = bord(y) | int(one_pos < hLen) # type: ignore[arg-type]
hash_compare = strxor(lHash1, lHash)
for x in hash_compare:
invalid |= bord(x) # type: ignore
invalid |= bord(x) # type: ignore[arg-type]
for x in db[hLen:one_pos]:
invalid |= bord(x) # type: ignore
invalid |= bord(x) # type: ignore[arg-type]
if invalid != 0:
raise ValueError("Incorrect decryption.")
# Step 4
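
As a sanity check on the RSAEP/RSADP steps above, gmpy2.powmod is interchangeable with Python's built-in three-argument pow; a toy example with textbook RSA numbers (illustrative only, not real key material):

# Toy parameters only (n = 61 * 53); real values come from the loaded key.
import gmpy2

n, e, d, m = 3233, 17, 2753, 65
c = int(gmpy2.powmod(m, e, n))          # RSAEP: c = m^e mod n
assert c == pow(m, e, n)
assert int(gmpy2.powmod(c, d, n)) == m  # RSADP recovers the plaintext integer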

View File

@ -81,6 +81,8 @@ class AvatarUrlField(fields.Raw):
from models import Account
if isinstance(obj, Account) and obj.avatar is not None:
if obj.avatar.startswith(("http://", "https://")):
return obj.avatar
return file_helpers.get_signed_file_url(obj.avatar)
return None

View File

@ -3,7 +3,7 @@ from functools import wraps
from typing import Any
from flask import current_app, g, has_request_context, request
from flask_login.config import EXEMPT_METHODS # type: ignore
from flask_login.config import EXEMPT_METHODS
from werkzeug.local import LocalProxy
from configs import dify_config
@ -87,7 +87,7 @@ def _get_user() -> EndUser | Account | None:
if "_login_user" not in g:
current_app.login_manager._load_user() # type: ignore
return g._login_user # type: ignore
return g._login_user
return None

View File

@ -1,8 +1,8 @@
import logging
import sendgrid # type: ignore
import sendgrid
from python_http_client.exceptions import ForbiddenError, UnauthorizedError
from sendgrid.helpers.mail import Content, Email, Mail, To # type: ignore
from sendgrid.helpers.mail import Content, Email, Mail, To
logger = logging.getLogger(__name__)

View File

@ -12,6 +12,7 @@ from constants import (
COOKIE_NAME_CSRF_TOKEN,
COOKIE_NAME_PASSPORT,
COOKIE_NAME_REFRESH_TOKEN,
COOKIE_NAME_WEBAPP_ACCESS_TOKEN,
HEADER_NAME_CSRF_TOKEN,
HEADER_NAME_PASSPORT,
)
@ -81,6 +82,14 @@ def extract_access_token(request: Request) -> str | None:
return _try_extract_from_cookie(request) or _try_extract_from_header(request)
def extract_webapp_access_token(request: Request) -> str | None:
"""
Try to extract webapp access token from cookie, then header.
"""
return request.cookies.get(_real_cookie_name(COOKIE_NAME_WEBAPP_ACCESS_TOKEN)) or _try_extract_from_header(request)
def extract_webapp_passport(app_code: str, request: Request) -> str | None:
"""
Try to extract app token from header or params.
@ -155,6 +164,10 @@ def clear_access_token_from_cookie(response: Response, samesite: str = "Lax"):
_clear_cookie(response, COOKIE_NAME_ACCESS_TOKEN, samesite)
def clear_webapp_access_token_from_cookie(response: Response, samesite: str = "Lax"):
_clear_cookie(response, COOKIE_NAME_WEBAPP_ACCESS_TOKEN, samesite)
def clear_refresh_token_from_cookie(response: Response):
_clear_cookie(response, COOKIE_NAME_REFRESH_TOKEN)
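
The new webapp_access_token helpers follow the same cookie-first, header-second lookup used for the console token; a stand-alone sketch of that pattern (names below are illustrative, not the repo's actual helpers):

from flask import Request

WEBAPP_COOKIE = "webapp_access_token"  # mirrors COOKIE_NAME_WEBAPP_ACCESS_TOKEN

def lookup_webapp_token(request: Request) -> str | None:
    # prefer the dedicated cookie set at web app login
    token = request.cookies.get(WEBAPP_COOKIE)
    if token:
        return token
    # otherwise fall back to a bearer-style Authorization header (assumed header format)
    auth = request.headers.get("Authorization", "")
    if auth.startswith("Bearer "):
        return auth.removeprefix("Bearer ")
    return None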

View File

@ -0,0 +1,36 @@
"""remove-builtin-template-user
Revision ID: ae662b25d9bc
Revises: d98acf217d43
Create Date: 2025-10-21 14:30:28.566192
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'ae662b25d9bc'
down_revision = 'd98acf217d43'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
batch_op.drop_column('updated_by')
batch_op.drop_column('created_by')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
batch_op.add_column(sa.Column('created_by', sa.UUID(), autoincrement=False, nullable=False))
batch_op.add_column(sa.Column('updated_by', sa.UUID(), autoincrement=False, nullable=True))
# ### end Alembic commands ###

View File

@ -5,7 +5,7 @@ from datetime import datetime
from typing import Any, Optional
import sqlalchemy as sa
from flask_login import UserMixin # type: ignore[import-untyped]
from flask_login import UserMixin
from sqlalchemy import DateTime, String, func, select
from sqlalchemy.orm import Mapped, Session, mapped_column
from typing_extensions import deprecated

View File

@ -1239,15 +1239,6 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
language = mapped_column(db.String(255), nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
created_by = mapped_column(StringUUID, nullable=False)
updated_by = mapped_column(StringUUID, nullable=True)
@property
def created_user_name(self):
account = db.session.query(Account).where(Account.id == self.created_by).first()
if account:
return account.name
return ""
class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]

View File

@ -219,7 +219,7 @@ class WorkflowToolProvider(TypeBase):
sa.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
# name of the workflow provider
name: Mapped[str] = mapped_column(String(255), nullable=False)
# label of the workflow provider

View File

@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.9.1"
version = "1.9.2"
requires-python = ">=3.11,<3.13"
dependencies = [

View File

@ -16,7 +16,25 @@
"opentelemetry.instrumentation.requests",
"opentelemetry.instrumentation.sqlalchemy",
"opentelemetry.instrumentation.redis",
"opentelemetry.instrumentation.httpx"
"langfuse",
"cloudscraper",
"readabilipy",
"pypandoc",
"pypdfium2",
"webvtt",
"flask_compress",
"oss2",
"baidubce.auth.bce_credentials",
"baidubce.bce_client_configuration",
"baidubce.services.bos.bos_client",
"clickzetta",
"google.cloud",
"obs",
"qcloud_cos",
"tos",
"gmpy2",
"sendgrid",
"sendgrid.helpers.mail"
],
"reportUnknownMemberType": "hint",
"reportUnknownParameterType": "hint",
@ -28,7 +46,7 @@
"reportUnnecessaryComparison": "hint",
"reportUnnecessaryIsInstance": "hint",
"reportUntypedFunctionDecorator": "hint",
"reportUnnecessaryTypeIgnoreComment": "hint",
"reportAttributeAccessIssue": "hint",
"pythonVersion": "3.11",
"pythonPlatform": "All"

Some files were not shown because too many files have changed in this diff.