diff --git a/api/.env.example b/api/.env.example
index c5cdccc0cf..db4dfba326 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -437,6 +437,9 @@ CODE_EXECUTION_SSL_VERIFY=True
 CODE_EXECUTION_POOL_MAX_CONNECTIONS=100
 CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS=20
 CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY=5.0
+CODE_EXECUTION_CONNECT_TIMEOUT=10
+CODE_EXECUTION_READ_TIMEOUT=60
+CODE_EXECUTION_WRITE_TIMEOUT=10
 CODE_MAX_NUMBER=9223372036854775807
 CODE_MIN_NUMBER=-9223372036854775808
 CODE_MAX_STRING_LENGTH=400000
diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py
index 00b6b16577..d8fdbf7102 100644
--- a/api/configs/feature/__init__.py
+++ b/api/configs/feature/__init__.py
@@ -548,7 +548,7 @@ class UpdateConfig(BaseSettings):
 
 class WorkflowVariableTruncationConfig(BaseSettings):
     WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE: PositiveInt = Field(
-        # 100KB
+        # 1000 KiB
         1024_000,
         description="Maximum size for variable to trigger final truncation.",
     )
diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py
index d872e8201b..816d0e442f 100644
--- a/api/configs/middleware/__init__.py
+++ b/api/configs/middleware/__init__.py
@@ -145,7 +145,7 @@ class DatabaseConfig(BaseSettings):
         default="postgresql",
     )
 
-    @computed_field  # type: ignore[misc]
+    @computed_field  # type: ignore[prop-decorator]
     @property
     def SQLALCHEMY_DATABASE_URI(self) -> str:
         db_extras = (
@@ -198,7 +198,7 @@ class DatabaseConfig(BaseSettings):
         default=os.cpu_count() or 1,
     )
 
-    @computed_field  # type: ignore[misc]
+    @computed_field  # type: ignore[prop-decorator]
     @property
     def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
         # Parse DB_EXTRAS for 'options'
diff --git a/api/constants/__init__.py b/api/constants/__init__.py
index 248cdfc09f..e441395afc 100644
--- a/api/constants/__init__.py
+++ b/api/constants/__init__.py
@@ -56,11 +56,15 @@ else:
     }
 DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)
 
+# console
 COOKIE_NAME_ACCESS_TOKEN = "access_token"
 COOKIE_NAME_REFRESH_TOKEN = "refresh_token"
-COOKIE_NAME_PASSPORT = "passport"
 COOKIE_NAME_CSRF_TOKEN = "csrf_token"
 
+# webapp
+COOKIE_NAME_WEBAPP_ACCESS_TOKEN = "webapp_access_token"
+COOKIE_NAME_PASSPORT = "passport"
+
 HEADER_NAME_CSRF_TOKEN = "X-CSRF-Token"
 HEADER_NAME_APP_CODE = "X-App-Code"
 HEADER_NAME_PASSPORT = "X-App-Passport"
diff --git a/api/constants/languages.py b/api/constants/languages.py
index a509ddcf5d..0312a558c9 100644
--- a/api/constants/languages.py
+++ b/api/constants/languages.py
@@ -31,3 +31,9 @@ def supported_language(lang):
 
     error = f"{lang} is not a valid language."
raise ValueError(error) + + +def get_valid_language(lang: str | None) -> str: + if lang and lang in languages: + return lang + return languages[0] diff --git a/api/controllers/common/helpers.py b/api/controllers/common/helpers.py index 6a5197635e..ef89e66980 100644 --- a/api/controllers/common/helpers.py +++ b/api/controllers/common/helpers.py @@ -24,7 +24,7 @@ except ImportError: ) else: warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2) - magic = None # type: ignore + magic = None # type: ignore[assignment] from pydantic import BaseModel diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 154d0cefcb..4fcdb5ea2a 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -4,7 +4,7 @@ from flask_restx import Resource, reqparse import services from configs import dify_config -from constants.languages import languages +from constants.languages import get_valid_language from controllers.console import console_ns from controllers.console.auth.error import ( AuthenticationFailedError, @@ -207,10 +207,12 @@ class EmailCodeLoginApi(Resource): .add_argument("email", type=str, required=True, location="json") .add_argument("code", type=str, required=True, location="json") .add_argument("token", type=str, required=True, location="json") + .add_argument("language", type=str, required=False, location="json") ) args = parser.parse_args() user_email = args["email"] + language = args["language"] token_data = AccountService.get_email_code_login_data(args["token"]) if token_data is None: @@ -244,7 +246,9 @@ class EmailCodeLoginApi(Resource): if account is None: try: account = AccountService.create_account_and_tenant( - email=user_email, name=user_email, interface_language=languages[0] + email=user_email, + name=user_email, + interface_language=get_valid_language(language), ) except WorkSpaceNotAllowedCreateError: raise NotAllowedCreateWorkspace() diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index 3022d937b9..125f603a5a 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -22,7 +22,7 @@ from core.errors.error import ( from core.model_runtime.errors.invoke import InvokeError from core.workflow.graph_engine.manager import GraphEngineManager from libs import helper -from libs.login import current_user as current_user_ +from libs.login import current_account_with_tenant from models.model import AppMode, InstalledApp from services.app_generate_service import AppGenerateService from services.errors.llm import InvokeRateLimitError @@ -31,8 +31,6 @@ from .. 
import console_ns logger = logging.getLogger(__name__) -current_user = current_user_._get_current_object() # type: ignore - @console_ns.route("/installed-apps//workflows/run") class InstalledAppWorkflowRunApi(InstalledAppResource): @@ -40,6 +38,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource): """ Run workflow """ + current_user, _ = current_account_with_tenant() app_model = installed_app.app if not app_model: raise NotWorkflowAppError() @@ -53,7 +52,6 @@ class InstalledAppWorkflowRunApi(InstalledAppResource): .add_argument("files", type=list, required=False, location="json") ) args = parser.parse_args() - assert current_user is not None try: response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True @@ -89,7 +87,6 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource): app_mode = AppMode.value_of(app_model.mode) if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - assert current_user is not None # Stop using both mechanisms for backward compatibility # Legacy stop flag mechanism (without user check) diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index 6d2b22bde3..1200349e2d 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -74,12 +74,17 @@ class SetupApi(Resource): .add_argument("email", type=email, required=True, location="json") .add_argument("name", type=StrLen(30), required=True, location="json") .add_argument("password", type=valid_password, required=True, location="json") + .add_argument("language", type=str, required=False, location="json") ) args = parser.parse_args() # setup RegisterService.setup( - email=args["email"], name=args["name"], password=args["password"], ip_address=extract_remote_ip(request) + email=args["email"], + name=args["name"], + password=args["password"], + ip_address=extract_remote_ip(request), + language=args["language"], ) return {"result": "success"}, 201 diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 54d64e7085..a8d4f0f5de 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -30,7 +30,7 @@ from models.provider_ids import ToolProviderID from services.plugin.oauth_service import OAuthProxyService from services.tools.api_tools_manage_service import ApiToolManageService from services.tools.builtin_tools_manage_service import BuiltinToolManageService -from services.tools.mcp_tools_manage_service import MCPToolManageService +from services.tools.mcp_tools_manage_service import MCPToolManageService, OAuthDataType from services.tools.tool_labels_service import ToolLabelsService from services.tools.tools_manage_service import ToolCommonService from services.tools.tools_transform_service import ToolTransformService @@ -897,10 +897,6 @@ class ToolProviderMCPApi(Resource): args = parser.parse_args() user, tenant_id = current_account_with_tenant() - # Validate server URL - if not is_valid_url(args["server_url"]): - raise ValueError("Server URL is not valid.") - # Parse and validate models configuration = MCPConfiguration.model_validate(args["configuration"]) authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None @@ -941,15 +937,21 @@ class ToolProviderMCPApi(Resource): .add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={}) ) args = 
parser.parse_args() - if not is_valid_url(args["server_url"]): - if "[__HIDDEN__]" in args["server_url"]: - pass - else: - raise ValueError("Server URL is not valid.") configuration = MCPConfiguration.model_validate(args["configuration"]) authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None _, current_tenant_id = current_account_with_tenant() + # Step 1: Validate server URL change if needed (includes URL format validation and network operation) + validation_result = None + with Session(db.engine) as session: + service = MCPToolManageService(session=session) + validation_result = service.validate_server_url_change( + tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"] + ) + + # No need to check for errors here, exceptions will be raised directly + + # Step 2: Perform database update in a transaction with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) service.update_provider( @@ -964,6 +966,7 @@ class ToolProviderMCPApi(Resource): headers=args["headers"], configuration=configuration, authentication=authentication, + validation_result=validation_result, ) return {"result": "success"} @@ -998,47 +1001,49 @@ class ToolMCPAuthApi(Resource): provider_id = args["provider_id"] _, tenant_id = current_account_with_tenant() - with Session(db.engine) as session: - with session.begin(): - service = MCPToolManageService(session=session) - db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id) - if not db_provider: - raise ValueError("provider not found") + with Session(db.engine) as session, session.begin(): + service = MCPToolManageService(session=session) + db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id) + if not db_provider: + raise ValueError("provider not found") - # Convert to entity - provider_entity = db_provider.to_entity() - server_url = provider_entity.decrypt_server_url() - headers = provider_entity.decrypt_authentication() + # Convert to entity + provider_entity = db_provider.to_entity() + server_url = provider_entity.decrypt_server_url() + headers = provider_entity.decrypt_authentication() - # Try to connect without active transaction + # Try to connect without active transaction + try: + # Use MCPClientWithAuthRetry to handle authentication automatically + with MCPClient( + server_url=server_url, + headers=headers, + timeout=provider_entity.timeout, + sse_read_timeout=provider_entity.sse_read_timeout, + ): + # Create new transaction for update + with session.begin(): + service.update_provider_credentials( + provider=db_provider, + credentials=provider_entity.credentials, + authed=True, + ) + return {"result": "success"} + except MCPAuthError as e: + service = MCPToolManageService(session=session) try: - # Use MCPClientWithAuthRetry to handle authentication automatically - with MCPClient( - server_url=server_url, - headers=headers, - timeout=provider_entity.timeout, - sse_read_timeout=provider_entity.sse_read_timeout, - ): - # Create new transaction for update - with session.begin(): - service.update_provider_credentials( - provider=db_provider, - credentials=provider_entity.credentials, - authed=True, - ) - return {"result": "success"} - except MCPAuthError as e: - service = MCPToolManageService(session=session) - try: - return auth(provider_entity, service, args.get("authorization_code")) - except MCPRefreshTokenError as e: - with session.begin(): - 
service.clear_provider_credentials(provider=db_provider) - raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e - except MCPError as e: + auth_result = auth(provider_entity, args.get("authorization_code")) + with session.begin(): + response = service.execute_auth_actions(auth_result) + return response + except MCPRefreshTokenError as e: with session.begin(): service.clear_provider_credentials(provider=db_provider) - raise ValueError(f"Failed to connect to MCP server: {e}") from e + raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e + except MCPError as e: + with session.begin(): + service.clear_provider_credentials(provider=db_provider) + raise ValueError(f"Failed to connect to MCP server: {e}") from e @console_ns.route("/workspaces/current/tool-provider/mcp/tools/") @@ -1048,7 +1053,7 @@ class ToolMCPDetailApi(Resource): @account_initialization_required def get(self, provider_id): _, tenant_id = current_account_with_tenant() - with Session(db.engine) as session: + with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id) return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True)) @@ -1062,7 +1067,7 @@ class ToolMCPListAllApi(Resource): def get(self): _, tenant_id = current_account_with_tenant() - with Session(db.engine, expire_on_commit=False) as session: + with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) tools = service.list_providers(tenant_id=tenant_id) @@ -1100,6 +1105,11 @@ class ToolMCPCallbackApi(Resource): # Create service instance for handle_callback with Session(db.engine) as session, session.begin(): mcp_service = MCPToolManageService(session=session) - handle_callback(state_key, authorization_code, mcp_service) + # handle_callback now returns state data and tokens + state_data, tokens = handle_callback(state_key, authorization_code) + # Save tokens using the service layer + mcp_service.save_oauth_data( + state_data.provider_id, state_data.tenant_id, tokens.model_dump(), OAuthDataType.TOKENS + ) return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index 85b7df229f..8d8fe6b3a8 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -193,15 +193,16 @@ class MCPAppApi(Resource): except ValidationError as e: raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}") - def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str, session: Session) -> EndUser | None: - """Get end user from existing session - optimized query""" - return ( - session.query(EndUser) - .where(EndUser.tenant_id == tenant_id) - .where(EndUser.session_id == mcp_server_id) - .where(EndUser.type == "mcp") - .first() - ) + def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None: + """Get end user - manages its own database session""" + with Session(db.engine, expire_on_commit=False) as session, session.begin(): + return ( + session.query(EndUser) + .where(EndUser.tenant_id == tenant_id) + .where(EndUser.session_id == mcp_server_id) + .where(EndUser.type == "mcp") + .first() + ) def _create_end_user( self, client_name: str, tenant_id: str, app_id: str, mcp_server_id: str, session: Session @@ -229,7 +230,7 @@ class MCPAppApi(Resource): request_id: Union[int, str], ) -> 
mcp_types.JSONRPCResponse | mcp_types.JSONRPCError | None: """Handle MCP request and return response""" - end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id, session) + end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id) if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest): client_info = mcp_request.root.params.clientInfo diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py index f213fd8c90..244ef47982 100644 --- a/api/controllers/web/login.py +++ b/api/controllers/web/login.py @@ -17,8 +17,8 @@ from libs.helper import email from libs.passport import PassportService from libs.password import valid_password from libs.token import ( - clear_access_token_from_cookie, - extract_access_token, + clear_webapp_access_token_from_cookie, + extract_webapp_access_token, ) from services.account_service import AccountService from services.app_service import AppService @@ -81,7 +81,7 @@ class LoginStatusApi(Resource): ) def get(self): app_code = request.args.get("app_code") - token = extract_access_token(request) + token = extract_webapp_access_token(request) if not app_code: return { "logged_in": bool(token), @@ -128,7 +128,7 @@ class LogoutApi(Resource): response = make_response({"result": "success"}) # enterprise SSO sets same site to None in https deployment # so we need to logout by calling api - clear_access_token_from_cookie(response, samesite="None") + clear_webapp_access_token_from_cookie(response, samesite="None") return response diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index 776b743e92..6a2e0b65fb 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -12,10 +12,8 @@ from controllers.web import web_ns from controllers.web.error import WebAppAuthRequiredError from extensions.ext_database import db from libs.passport import PassportService -from libs.token import extract_access_token +from libs.token import extract_webapp_access_token from models.model import App, EndUser, Site -from services.app_service import AppService -from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService from services.webapp_auth_service import WebAppAuthService, WebAppAuthType @@ -37,23 +35,18 @@ class PassportResource(Resource): system_features = FeatureService.get_system_features() app_code = request.headers.get(HEADER_NAME_APP_CODE) user_id = request.args.get("user_id") - access_token = extract_access_token(request) - + access_token = extract_webapp_access_token(request) if app_code is None: raise Unauthorized("X-App-Code header is missing.") - app_id = AppService.get_app_id_by_code(app_code) - # exchange token for enterprise logined web user - enterprise_user_decoded = decode_enterprise_webapp_user_id(access_token) - if enterprise_user_decoded: - # a web user has already logged in, exchange a token for this app without redirecting to the login page - return exchange_token_for_existing_web_user( - app_code=app_code, enterprise_user_decoded=enterprise_user_decoded - ) - if system_features.webapp_auth.enabled: - app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id) - if not app_settings or not app_settings.access_mode == "public": - raise WebAppAuthRequiredError() + enterprise_user_decoded = decode_enterprise_webapp_user_id(access_token) + app_auth_type = WebAppAuthService.get_app_auth_type(app_code=app_code) + if app_auth_type != WebAppAuthType.PUBLIC: + if not 
enterprise_user_decoded: + raise WebAppAuthRequiredError() + return exchange_token_for_existing_web_user( + app_code=app_code, enterprise_user_decoded=enterprise_user_decoded, auth_type=app_auth_type + ) # get site from db and check if it is normal site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal")) @@ -124,7 +117,7 @@ def decode_enterprise_webapp_user_id(jwt_token: str | None): return decoded -def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict): +def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict, auth_type: WebAppAuthType): """ Exchange a token for an existing web user session. """ @@ -145,13 +138,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: if not app_model or app_model.status != "normal" or not app_model.enable_site: raise NotFound() - app_auth_type = WebAppAuthService.get_app_auth_type(app_code=app_code) - - if app_auth_type == WebAppAuthType.PUBLIC: + if auth_type == WebAppAuthType.PUBLIC: return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded) - elif app_auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external": + elif auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external": raise WebAppAuthRequiredError("Please login as external user.") - elif app_auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal": + elif auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal": raise WebAppAuthRequiredError("Please login as internal user.") end_user = None diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index c6d98374c1..7bd3b8a56e 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -211,8 +211,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): user=user, stream=streaming, ) - # FIXME: Type hinting issue here, ignore it for now, will fix it later - return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) # type: ignore + return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def _generate_worker( self, diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 1fb076b685..f8bfbce37a 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -255,7 +255,7 @@ class PipelineGenerator(BaseAppGenerator): json_text = json.dumps(text) upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id) features = FeatureService.get_features(dataset.tenant_id) - if features.billing.subscription.plan == "sandbox": + if features.billing.enabled and features.billing.subscription.plan == "sandbox": tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}" tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}" diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py index 01ecf0298f..c64f44a603 100644 --- a/api/core/app/apps/workflow/generate_response_converter.py +++ b/api/core/app/apps/workflow/generate_response_converter.py @@ -89,7 +89,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) elif 
isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): - response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute] + response_chunk.update(sub_stream_response.to_ignore_detail_dict()) else: response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index ffa10cd43c..565905be0d 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -98,7 +98,7 @@ class RateLimit: else: return RateLimitGenerator( rate_limit=self, - generator=generator, # ty: ignore [invalid-argument-type] + generator=generator, request_id=request_id, ) diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 45e3c0006b..26c7e60a4c 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -49,7 +49,7 @@ class BasedGenerateTaskPipeline: if isinstance(e, InvokeAuthorizationError): err = InvokeAuthorizationError("Incorrect API key provided") elif isinstance(e, InvokeError | ValueError): - err = e # ty: ignore [invalid-assignment] + err = e else: description = getattr(e, "description", None) err = Exception(description if description is not None else str(e)) diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index c4be429219..b10838f8c9 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -1868,7 +1868,7 @@ class ProviderConfigurations(BaseModel): if "/" not in key: key = str(ModelProviderID(key)) - return self.configurations.get(key, default) # type: ignore + return self.configurations.get(key, default) class ProviderModelBundle(BaseModel): diff --git a/api/core/helper/module_import_helper.py b/api/core/helper/module_import_helper.py index 6a2f27b8ba..2bada85582 100644 --- a/api/core/helper/module_import_helper.py +++ b/api/core/helper/module_import_helper.py @@ -20,7 +20,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz else: # Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly # FIXME: mypy does not support the type of spec.loader - spec = importlib.util.spec_from_file_location(module_name, py_file_path) # type: ignore + spec = importlib.util.spec_from_file_location(module_name, py_file_path) # type: ignore[assignment] if not spec or not spec.loader: raise Exception(f"Failed to load module {module_name} from {py_file_path!r}") if use_lazy_loader: diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 7822ed4268..36b38b7b45 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -49,62 +49,80 @@ class IndexingRunner: self.storage = storage self.model_manager = ModelManager() + def _handle_indexing_error(self, document_id: str, error: Exception) -> None: + """Handle indexing errors by updating document status.""" + logger.exception("consume document failed") + document = db.session.get(DatasetDocument, document_id) + if document: + document.indexing_status = "error" + error_message = getattr(error, "description", str(error)) + document.error = str(error_message) + document.stopped_at = naive_utc_now() + db.session.commit() + def run(self, dataset_documents: 
list[DatasetDocument]): """Run the indexing process.""" for dataset_document in dataset_documents: + document_id = dataset_document.id try: + # Re-query the document to ensure it's bound to the current session + requeried_document = db.session.get(DatasetDocument, document_id) + if not requeried_document: + logger.warning("Document not found, skipping document id: %s", document_id) + continue + # get dataset - dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first() + dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") # get the process rule stmt = select(DatasetProcessRule).where( - DatasetProcessRule.id == dataset_document.dataset_process_rule_id + DatasetProcessRule.id == requeried_document.dataset_process_rule_id ) processing_rule = db.session.scalar(stmt) if not processing_rule: raise ValueError("no process rule found") - index_type = dataset_document.doc_form + index_type = requeried_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() # extract - text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict()) + text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict()) # transform documents = self._transform( - index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict() + index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict() ) # save segment - self._load_segments(dataset, dataset_document, documents) + self._load_segments(dataset, requeried_document, documents) # load self._load( index_processor=index_processor, dataset=dataset, - dataset_document=dataset_document, + dataset_document=requeried_document, documents=documents, ) except DocumentIsPausedError: - raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}") + raise DocumentIsPausedError(f"Document paused, document id: {document_id}") except ProviderTokenNotInitError as e: - dataset_document.indexing_status = "error" - dataset_document.error = str(e.description) - dataset_document.stopped_at = naive_utc_now() - db.session.commit() + self._handle_indexing_error(document_id, e) except ObjectDeletedError: - logger.warning("Document deleted, document id: %s", dataset_document.id) + logger.warning("Document deleted, document id: %s", document_id) except Exception as e: - logger.exception("consume document failed") - dataset_document.indexing_status = "error" - dataset_document.error = str(e) - dataset_document.stopped_at = naive_utc_now() - db.session.commit() + self._handle_indexing_error(document_id, e) def run_in_splitting_status(self, dataset_document: DatasetDocument): """Run the indexing process when the index_status is splitting.""" + document_id = dataset_document.id try: + # Re-query the document to ensure it's bound to the current session + requeried_document = db.session.get(DatasetDocument, document_id) + if not requeried_document: + logger.warning("Document not found: %s", document_id) + return + # get dataset - dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first() + dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") @@ -112,57 +130,60 @@ class IndexingRunner: # get exist document_segment list and delete document_segments = ( db.session.query(DocumentSegment) - 
.filter_by(dataset_id=dataset.id, document_id=dataset_document.id) + .filter_by(dataset_id=dataset.id, document_id=requeried_document.id) .all() ) for document_segment in document_segments: db.session.delete(document_segment) - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX: # delete child chunks db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete() db.session.commit() # get the process rule - stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) + stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == requeried_document.dataset_process_rule_id) processing_rule = db.session.scalar(stmt) if not processing_rule: raise ValueError("no process rule found") - index_type = dataset_document.doc_form + index_type = requeried_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() # extract - text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict()) + text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict()) # transform documents = self._transform( - index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict() + index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict() ) # save segment - self._load_segments(dataset, dataset_document, documents) + self._load_segments(dataset, requeried_document, documents) # load self._load( - index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents + index_processor=index_processor, + dataset=dataset, + dataset_document=requeried_document, + documents=documents, ) except DocumentIsPausedError: - raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}") + raise DocumentIsPausedError(f"Document paused, document id: {document_id}") except ProviderTokenNotInitError as e: - dataset_document.indexing_status = "error" - dataset_document.error = str(e.description) - dataset_document.stopped_at = naive_utc_now() - db.session.commit() + self._handle_indexing_error(document_id, e) except Exception as e: - logger.exception("consume document failed") - dataset_document.indexing_status = "error" - dataset_document.error = str(e) - dataset_document.stopped_at = naive_utc_now() - db.session.commit() + self._handle_indexing_error(document_id, e) def run_in_indexing_status(self, dataset_document: DatasetDocument): """Run the indexing process when the index_status is indexing.""" + document_id = dataset_document.id try: + # Re-query the document to ensure it's bound to the current session + requeried_document = db.session.get(DatasetDocument, document_id) + if not requeried_document: + logger.warning("Document not found: %s", document_id) + return + # get dataset - dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first() + dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") @@ -170,7 +191,7 @@ class IndexingRunner: # get exist document_segment list and delete document_segments = ( db.session.query(DocumentSegment) - .filter_by(dataset_id=dataset.id, document_id=dataset_document.id) + .filter_by(dataset_id=dataset.id, document_id=requeried_document.id) .all() ) @@ -188,7 +209,7 @@ class IndexingRunner: "dataset_id": document_segment.dataset_id, }, ) - if 
dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX: child_chunks = document_segment.get_child_chunks() if child_chunks: child_documents = [] @@ -206,24 +227,20 @@ class IndexingRunner: document.children = child_documents documents.append(document) # build index - index_type = dataset_document.doc_form + index_type = requeried_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() self._load( - index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents + index_processor=index_processor, + dataset=dataset, + dataset_document=requeried_document, + documents=documents, ) except DocumentIsPausedError: - raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}") + raise DocumentIsPausedError(f"Document paused, document id: {document_id}") except ProviderTokenNotInitError as e: - dataset_document.indexing_status = "error" - dataset_document.error = str(e.description) - dataset_document.stopped_at = naive_utc_now() - db.session.commit() + self._handle_indexing_error(document_id, e) except Exception as e: - logger.exception("consume document failed") - dataset_document.indexing_status = "error" - dataset_document.error = str(e) - dataset_document.stopped_at = naive_utc_now() - db.session.commit() + self._handle_indexing_error(document_id, e) def indexing_estimate( self, @@ -398,7 +415,6 @@ class IndexingRunner: document_id=dataset_document.id, after_indexing_status="splitting", extra_update_params={ - DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs), DatasetDocument.parsing_completed_at: naive_utc_now(), }, ) @@ -738,6 +754,7 @@ class IndexingRunner: extra_update_params={ DatasetDocument.cleaning_completed_at: cur_time, DatasetDocument.splitting_completed_at: cur_time, + DatasetDocument.word_count: sum(len(doc.page_content) for doc in documents), }, ) diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index ea5334c011..67b86dfc43 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -103,7 +103,7 @@ class LLMGenerator: return name @classmethod - def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str): + def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str) -> Sequence[str]: output_parser = SuggestedQuestionsAfterAnswerOutputParser() format_instructions = output_parser.get_format_instructions() @@ -122,6 +122,8 @@ class LLMGenerator: prompt_messages = [UserPromptMessage(content=prompt)] + questions: Sequence[str] = [] + try: response: LLMResult = model_instance.invoke_llm( prompt_messages=list(prompt_messages), diff --git a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py index e78859cc1a..eec771181f 100644 --- a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py +++ b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py @@ -1,17 +1,26 @@ import json +import logging import re +from collections.abc import Sequence from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT +logger = logging.getLogger(__name__) + class SuggestedQuestionsAfterAnswerOutputParser: def get_format_instructions(self) -> str: return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT - def 
parse(self, text: str): + def parse(self, text: str) -> Sequence[str]: action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL) + questions: list[str] = [] if action_match is not None: - json_obj = json.loads(action_match.group(0).strip()) - else: - json_obj = [] - return json_obj + try: + json_obj = json.loads(action_match.group(0).strip()) + except json.JSONDecodeError as exc: + logger.warning("Failed to decode suggested questions payload: %s", exc) + else: + if isinstance(json_obj, list): + questions = [question for question in json_obj if isinstance(question, str)] + return questions diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index a1fcd6e033..951c22f6dd 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -4,14 +4,14 @@ import json import os import secrets import urllib.parse -from typing import TYPE_CHECKING from urllib.parse import urljoin, urlparse from httpx import ConnectError, HTTPStatusError, RequestError -from pydantic import BaseModel, ValidationError +from pydantic import ValidationError from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType from core.helper import ssrf_proxy +from core.mcp.entities import AuthAction, AuthActionType, AuthResult, OAuthCallbackState from core.mcp.error import MCPRefreshTokenError from core.mcp.types import ( LATEST_PROTOCOL_VERSION, @@ -23,23 +23,10 @@ from core.mcp.types import ( ) from extensions.ext_redis import redis_client -if TYPE_CHECKING: - from services.tools.mcp_tools_manage_service import MCPToolManageService - OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:" -class OAuthCallbackState(BaseModel): - provider_id: str - tenant_id: str - server_url: str - metadata: OAuthMetadata | None = None - client_information: OAuthClientInformation - code_verifier: str - redirect_uri: str - - def generate_pkce_challenge() -> tuple[str, str]: """Generate PKCE challenge and verifier.""" code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8") @@ -86,8 +73,13 @@ def _retrieve_redis_state(state_key: str) -> OAuthCallbackState: raise ValueError(f"Invalid state parameter: {str(e)}") -def handle_callback(state_key: str, authorization_code: str, mcp_service: "MCPToolManageService") -> OAuthCallbackState: - """Handle the callback from the OAuth provider.""" +def handle_callback(state_key: str, authorization_code: str) -> tuple[OAuthCallbackState, OAuthTokens]: + """ + Handle the callback from the OAuth provider. + + Returns: + A tuple of (callback_state, tokens) that can be used by the caller to save data. 
+ """ # Retrieve state data from Redis (state is automatically deleted after retrieval) full_state_data = _retrieve_redis_state(state_key) @@ -100,10 +92,7 @@ def handle_callback(state_key: str, authorization_code: str, mcp_service: "MCPTo full_state_data.redirect_uri, ) - # Save tokens using the service layer - mcp_service.save_oauth_data(full_state_data.provider_id, full_state_data.tenant_id, tokens.model_dump(), "tokens") - - return full_state_data + return full_state_data, tokens def check_support_resource_discovery(server_url: str) -> tuple[bool, str]: @@ -361,11 +350,24 @@ def register_client( def auth( provider: MCPProviderEntity, - mcp_service: "MCPToolManageService", authorization_code: str | None = None, state_param: str | None = None, -) -> dict[str, str]: - """Orchestrates the full auth flow with a server using secure Redis state storage.""" +) -> AuthResult: + """ + Orchestrates the full auth flow with a server using secure Redis state storage. + + This function performs only network operations and returns actions that need + to be performed by the caller (such as saving data to database). + + Args: + provider: The MCP provider entity + authorization_code: Optional authorization code from OAuth callback + state_param: Optional state parameter from OAuth callback + + Returns: + AuthResult containing actions to be performed and response data + """ + actions: list[AuthAction] = [] server_url = provider.decrypt_server_url() server_metadata = discover_oauth_metadata(server_url) client_metadata = provider.client_metadata @@ -407,9 +409,14 @@ def auth( except RequestError as e: raise ValueError(f"Could not register OAuth client: {e}") - # Save client information using service layer - mcp_service.save_oauth_data( - provider_id, tenant_id, {"client_information": full_information.model_dump()}, "client_info" + # Return action to save client information + actions.append( + AuthAction( + action_type=AuthActionType.SAVE_CLIENT_INFO, + data={"client_information": full_information.model_dump()}, + provider_id=provider_id, + tenant_id=tenant_id, + ) ) client_information = full_information @@ -426,12 +433,20 @@ def auth( scope, ) - # Save tokens and grant type + # Return action to save tokens and grant type token_data = tokens.model_dump() token_data["grant_type"] = MCPSupportGrantType.CLIENT_CREDENTIALS.value - mcp_service.save_oauth_data(provider_id, tenant_id, token_data, "tokens") - return {"result": "success"} + actions.append( + AuthAction( + action_type=AuthActionType.SAVE_TOKENS, + data=token_data, + provider_id=provider_id, + tenant_id=tenant_id, + ) + ) + + return AuthResult(actions=actions, response={"result": "success"}) except (RequestError, ValueError, KeyError) as e: # RequestError: HTTP request failed # ValueError: Invalid response data @@ -465,10 +480,17 @@ def auth( redirect_uri, ) - # Save tokens using service layer - mcp_service.save_oauth_data(provider_id, tenant_id, tokens.model_dump(), "tokens") + # Return action to save tokens + actions.append( + AuthAction( + action_type=AuthActionType.SAVE_TOKENS, + data=tokens.model_dump(), + provider_id=provider_id, + tenant_id=tenant_id, + ) + ) - return {"result": "success"} + return AuthResult(actions=actions, response={"result": "success"}) provider_tokens = provider.retrieve_tokens() @@ -479,10 +501,17 @@ def auth( server_url, server_metadata, client_information, provider_tokens.refresh_token ) - # Save new tokens using service layer - mcp_service.save_oauth_data(provider_id, tenant_id, new_tokens.model_dump(), "tokens") + # 
Return action to save new tokens + actions.append( + AuthAction( + action_type=AuthActionType.SAVE_TOKENS, + data=new_tokens.model_dump(), + provider_id=provider_id, + tenant_id=tenant_id, + ) + ) - return {"result": "success"} + return AuthResult(actions=actions, response={"result": "success"}) except (RequestError, ValueError, KeyError) as e: # RequestError: HTTP request failed # ValueError: Invalid response data @@ -499,7 +528,14 @@ def auth( tenant_id, ) - # Save code verifier using service layer - mcp_service.save_oauth_data(provider_id, tenant_id, {"code_verifier": code_verifier}, "code_verifier") + # Return action to save code verifier + actions.append( + AuthAction( + action_type=AuthActionType.SAVE_CODE_VERIFIER, + data={"code_verifier": code_verifier}, + provider_id=provider_id, + tenant_id=tenant_id, + ) + ) - return {"authorization_url": authorization_url} + return AuthResult(actions=actions, response={"authorization_url": authorization_url}) diff --git a/api/core/mcp/auth_client.py b/api/core/mcp/auth_client.py index 95f552f5db..942c8d3c23 100644 --- a/api/core/mcp/auth_client.py +++ b/api/core/mcp/auth_client.py @@ -7,15 +7,15 @@ authentication failures and retries operations after refreshing tokens. import logging from collections.abc import Callable -from typing import TYPE_CHECKING, Any, Optional +from typing import Any + +from sqlalchemy.orm import Session from core.entities.mcp_provider import MCPProviderEntity from core.mcp.error import MCPAuthError from core.mcp.mcp_client import MCPClient from core.mcp.types import CallToolResult, Tool - -if TYPE_CHECKING: - from services.tools.mcp_tools_manage_service import MCPToolManageService +from extensions.ext_database import db logger = logging.getLogger(__name__) @@ -26,6 +26,9 @@ class MCPClientWithAuthRetry(MCPClient): This class extends MCPClient and intercepts MCPAuthError exceptions to refresh authentication before retrying failed operations. + + Note: This class uses lazy session creation - database sessions are only + created when authentication retry is actually needed, not on every request. """ def __init__( @@ -35,11 +38,8 @@ class MCPClientWithAuthRetry(MCPClient): timeout: float | None = None, sse_read_timeout: float | None = None, provider_entity: MCPProviderEntity | None = None, - auth_callback: Callable[[MCPProviderEntity, "MCPToolManageService", Optional[str]], dict[str, str]] - | None = None, authorization_code: str | None = None, by_server_id: bool = False, - mcp_service: Optional["MCPToolManageService"] = None, ): """ Initialize the MCP client with auth retry capability. @@ -50,31 +50,30 @@ class MCPClientWithAuthRetry(MCPClient): timeout: Request timeout sse_read_timeout: SSE read timeout provider_entity: Provider entity for authentication - auth_callback: Authentication callback function authorization_code: Optional authorization code for initial auth by_server_id: Whether to look up provider by server ID - mcp_service: MCP service instance """ super().__init__(server_url, headers, timeout, sse_read_timeout) self.provider_entity = provider_entity - self.auth_callback = auth_callback self.authorization_code = authorization_code self.by_server_id = by_server_id - self.mcp_service = mcp_service self._has_retried = False def _handle_auth_error(self, error: MCPAuthError) -> None: """ Handle authentication error by refreshing tokens. + This method creates a short-lived database session only when authentication + retry is needed, minimizing database connection hold time. 
+
         Args:
             error: The authentication error
 
         Raises:
             MCPAuthError: If authentication fails or max retries reached
         """
-        if not self.provider_entity or not self.auth_callback or not self.mcp_service:
+        if not self.provider_entity:
             raise error
 
         if self._has_retried:
             raise error
@@ -82,13 +81,23 @@ class MCPClientWithAuthRetry(MCPClient):
         self._has_retried = True
 
         try:
-            # Perform authentication
-            self.auth_callback(self.provider_entity, self.mcp_service, self.authorization_code)
+            # Create a temporary session only for auth retry
+            # This session is short-lived and only exists during the auth operation
-            # Retrieve new tokens
-            self.provider_entity = self.mcp_service.get_provider_entity(
-                self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id
-            )
+            from services.tools.mcp_tools_manage_service import MCPToolManageService
+
+            with Session(db.engine) as session, session.begin():
+                mcp_service = MCPToolManageService(session=session)
+
+                # Perform authentication using the service's auth method
+                mcp_service.auth_with_actions(self.provider_entity, self.authorization_code)
+
+                # Retrieve new tokens
+                self.provider_entity = mcp_service.get_provider_entity(
+                    self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id
+                )
+
+            # Session is closed here, before we update headers
 
             token = self.provider_entity.retrieve_tokens()
             if not token:
                 raise MCPAuthError("Authentication failed - no token received")
diff --git a/api/core/mcp/entities.py b/api/core/mcp/entities.py
index 9e414ab2b3..08823daab1 100644
--- a/api/core/mcp/entities.py
+++ b/api/core/mcp/entities.py
@@ -1,8 +1,11 @@
 from dataclasses import dataclass
+from enum import StrEnum
 from typing import Any, Generic, TypeVar
 
+from pydantic import BaseModel
+
 from core.mcp.session.base_session import BaseSession
-from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestId, RequestParams
+from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAuthMetadata, RequestId, RequestParams
 
 SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
 
@@ -17,3 +20,41 @@ class RequestContext(Generic[SessionT, LifespanContextT]):
     meta: RequestParams.Meta | None
     session: SessionT
     lifespan_context: LifespanContextT
+
+
+class AuthActionType(StrEnum):
+    """Types of actions that can be performed during auth flow."""
+
+    SAVE_CLIENT_INFO = "save_client_info"
+    SAVE_TOKENS = "save_tokens"
+    SAVE_CODE_VERIFIER = "save_code_verifier"
+    START_AUTHORIZATION = "start_authorization"
+    SUCCESS = "success"
+
+
+class AuthAction(BaseModel):
+    """Represents an action that needs to be performed as a result of auth flow."""
+
+    action_type: AuthActionType
+    data: dict[str, Any]
+    provider_id: str | None = None
+    tenant_id: str | None = None
+
+
+class AuthResult(BaseModel):
+    """Result of auth function containing actions to be performed and response data."""
+
+    actions: list[AuthAction]
+    response: dict[str, str]
+
+
+class OAuthCallbackState(BaseModel):
+    """State data stored in Redis during OAuth callback flow."""
+
+    provider_id: str
+    tenant_id: str
+    server_url: str
+    metadata: OAuthMetadata | None = None
+    client_information: OAuthClientInformation
+    code_verifier: str
+    redirect_uri: str
diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py
index 92e6b8ea60..4de4f403ce 100644
--- a/api/core/ops/langfuse_trace/langfuse_trace.py
+++ b/api/core/ops/langfuse_trace/langfuse_trace.py
@@ -2,7 +2,7 @@ import logging
 import os
 from datetime import datetime, timedelta
 
-from langfuse import Langfuse  # type: ignore
+from langfuse import Langfuse
 from sqlalchemy.orm import sessionmaker
 
 from core.ops.base_trace_instance import BaseTraceInstance
diff --git a/api/core/plugin/entities/parameters.py b/api/core/plugin/entities/parameters.py
index 68b5c1084a..1e7f8e4c86 100644
--- a/api/core/plugin/entities/parameters.py
+++ b/api/core/plugin/entities/parameters.py
@@ -76,7 +76,7 @@ class PluginParameter(BaseModel):
     auto_generate: PluginParameterAutoGenerate | None = None
     template: PluginParameterTemplate | None = None
     required: bool = False
-    default: Union[float, int, str] | None = None
+    default: Union[float, int, str, bool] | None = None
     min: Union[float, int] | None = None
     max: Union[float, int] | None = None
     precision: int | None = None
diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py
index 5095b46432..e9dc58eec8 100644
--- a/api/core/plugin/impl/base.py
+++ b/api/core/plugin/impl/base.py
@@ -180,7 +180,7 @@ class BasePluginClient:
         Make a request to the plugin daemon inner API and return the response as a model.
         """
         response = self._request(method, path, headers, data, params, files)
-        return type_(**response.json())  # type: ignore
+        return type_(**response.json())  # type: ignore[return-value]
 
     def _request_with_plugin_daemon_response(
         self,
diff --git a/api/core/plugin/impl/exc.py b/api/core/plugin/impl/exc.py
index 23a69bd92f..e28a324217 100644
--- a/api/core/plugin/impl/exc.py
+++ b/api/core/plugin/impl/exc.py
@@ -40,7 +40,7 @@ class PluginDaemonBadRequestError(PluginDaemonClientSideError):
     description: str = "Bad Request"
 
 
-class PluginInvokeError(PluginDaemonClientSideError):
+class PluginInvokeError(PluginDaemonClientSideError, ValueError):
     description: str = "Invoke Error"
 
     def _get_error_object(self) -> Mapping:
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index 99bbe615fb..45b19f25a0 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -72,6 +72,19 @@ default_retrieval_model: dict[str, Any] = {
 class DatasetRetrieval:
     def __init__(self, application_generate_entity=None):
         self.application_generate_entity = application_generate_entity
+        self._llm_usage = LLMUsage.empty_usage()
+
+    @property
+    def llm_usage(self) -> LLMUsage:
+        return self._llm_usage.model_copy()
+
+    def _record_usage(self, usage: LLMUsage | None) -> None:
+        if usage is None or usage.total_tokens <= 0:
+            return
+        if self._llm_usage.total_tokens == 0:
+            self._llm_usage = usage
+        else:
+            self._llm_usage = self._llm_usage.plus(usage)
 
     def retrieve(
         self,
@@ -312,15 +325,18 @@ class DatasetRetrieval:
                 )
                 tools.append(message_tool)
 
         dataset_id = None
+        router_usage = LLMUsage.empty_usage()
         if planning_strategy == PlanningStrategy.REACT_ROUTER:
             react_multi_dataset_router = ReactMultiDatasetRouter()
-            dataset_id = react_multi_dataset_router.invoke(
+            dataset_id, router_usage = react_multi_dataset_router.invoke(
                 query, tools, model_config, model_instance, user_id, tenant_id
             )
         elif planning_strategy == PlanningStrategy.ROUTER:
             function_call_router = FunctionCallMultiDatasetRouter()
-            dataset_id = function_call_router.invoke(query, tools, model_config, model_instance)
+            dataset_id, router_usage = function_call_router.invoke(query, tools, model_config, model_instance)
+
+        self._record_usage(router_usage)
 
         if dataset_id:
             # get retrieval model config
@@ -983,7 +999,8 @@ class DatasetRetrieval:
             )
 
             # handle invoke result
-            result_text, _ = self._handle_invoke_result(invoke_result=invoke_result)
+            result_text, usage = self._handle_invoke_result(invoke_result=invoke_result)
+            self._record_usage(usage)
 
             result_text_json = parse_and_check_json_markdown(result_text, [])
             automatic_metadata_filters = []
diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py
index de59c6380e..5f3e1a8cae 100644
--- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py
+++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py
@@ -2,7 +2,7 @@ from typing import Union
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.model_manager import ModelInstance
-from core.model_runtime.entities.llm_entities import LLMResult
+from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
 from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage
 
 
@@ -13,15 +13,15 @@ class FunctionCallMultiDatasetRouter:
         dataset_tools: list[PromptMessageTool],
         model_config: ModelConfigWithCredentialsEntity,
         model_instance: ModelInstance,
-    ) -> Union[str, None]:
+    ) -> tuple[Union[str, None], LLMUsage]:
        """Given input, decided what to do.
        Returns:
            Action specifying what tool to use.
        """
         if len(dataset_tools) == 0:
-            return None
+            return None, LLMUsage.empty_usage()
         elif len(dataset_tools) == 1:
-            return dataset_tools[0].name
+            return dataset_tools[0].name, LLMUsage.empty_usage()
 
         try:
             prompt_messages = [
@@ -34,9 +34,10 @@ class FunctionCallMultiDatasetRouter:
                 stream=False,
                 model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
             )
+            usage = result.usage or LLMUsage.empty_usage()
             if result.message.tool_calls:
                 # get retrieval model config
-                return result.message.tool_calls[0].function.name
-            return None
+                return result.message.tool_calls[0].function.name, usage
+            return None, usage
         except Exception:
-            return None
+            return None, LLMUsage.empty_usage()
diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py
index 59d36229b3..8f3bec2704 100644
--- a/api/core/rag/retrieval/router/multi_dataset_react_route.py
+++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py
@@ -58,15 +58,15 @@ class ReactMultiDatasetRouter:
         model_instance: ModelInstance,
         user_id: str,
         tenant_id: str,
-    ) -> Union[str, None]:
+    ) -> tuple[Union[str, None], LLMUsage]:
        """Given input, decided what to do.
        Returns:
            Action specifying what tool to use.
""" if len(dataset_tools) == 0: - return None + return None, LLMUsage.empty_usage() elif len(dataset_tools) == 1: - return dataset_tools[0].name + return dataset_tools[0].name, LLMUsage.empty_usage() try: return self._react_invoke( @@ -78,7 +78,7 @@ class ReactMultiDatasetRouter: tenant_id=tenant_id, ) except Exception: - return None + return None, LLMUsage.empty_usage() def _react_invoke( self, @@ -91,7 +91,7 @@ class ReactMultiDatasetRouter: prefix: str = PREFIX, suffix: str = SUFFIX, format_instructions: str = FORMAT_INSTRUCTIONS, - ) -> Union[str, None]: + ) -> tuple[Union[str, None], LLMUsage]: prompt: Union[list[ChatModelMessage], CompletionModelPromptTemplate] if model_config.mode == "chat": prompt = self.create_chat_prompt( @@ -120,7 +120,7 @@ class ReactMultiDatasetRouter: memory=None, model_config=model_config, ) - result_text, _ = self._invoke_llm( + result_text, usage = self._invoke_llm( completion_param=model_config.parameters, model_instance=model_instance, prompt_messages=prompt_messages, @@ -131,8 +131,8 @@ class ReactMultiDatasetRouter: output_parser = StructuredChatOutputParser() react_decision = output_parser.parse(result_text) if isinstance(react_decision, ReactAction): - return react_decision.tool - return None + return react_decision.tool, usage + return None, usage def _invoke_llm( self, diff --git a/api/core/repositories/celery_workflow_execution_repository.py b/api/core/repositories/celery_workflow_execution_repository.py index 460bb75722..c7f5942f5f 100644 --- a/api/core/repositories/celery_workflow_execution_repository.py +++ b/api/core/repositories/celery_workflow_execution_repository.py @@ -74,7 +74,7 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository): tenant_id = extract_tenant_id(user) if not tenant_id: raise ValueError("User must have a tenant_id or current_tenant_id") - self._tenant_id = tenant_id # type: ignore[assignment] # We've already checked tenant_id is not None + self._tenant_id = tenant_id # Store app context self._app_id = app_id diff --git a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py index 21a0b7eefe..9b8e45b1eb 100644 --- a/api/core/repositories/celery_workflow_node_execution_repository.py +++ b/api/core/repositories/celery_workflow_node_execution_repository.py @@ -81,7 +81,7 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): tenant_id = extract_tenant_id(user) if not tenant_id: raise ValueError("User must have a tenant_id or current_tenant_id") - self._tenant_id = tenant_id # type: ignore[assignment] # We've already checked tenant_id is not None + self._tenant_id = tenant_id # Store app context self._app_id = app_id diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py index 854c122331..02fcabab5d 100644 --- a/api/core/repositories/factory.py +++ b/api/core/repositories/factory.py @@ -60,7 +60,7 @@ class DifyCoreRepositoryFactory: try: repository_class = import_string(class_path) - return repository_class( # type: ignore[no-any-return] + return repository_class( session_factory=session_factory, user=user, app_id=app_id, @@ -96,7 +96,7 @@ class DifyCoreRepositoryFactory: try: repository_class = import_string(class_path) - return repository_class( # type: ignore[no-any-return] + return repository_class( session_factory=session_factory, user=user, app_id=app_id, diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index 
2e94907f30..a391136a5c 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -157,7 +157,7 @@ class BuiltinToolProviderController(ToolProviderController): """ returns the tool that the provider can provide """ - return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None) # type: ignore + return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None) @property def need_credentials(self) -> bool: diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py index 8bc159bb85..5009f7ac21 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/tts.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py @@ -43,7 +43,7 @@ class TTSTool(BuiltinTool): content_text=tool_parameters.get("text"), # type: ignore user=user_id, tenant_id=self.runtime.tenant_id, - voice=voice, # type: ignore + voice=voice, ) buffer = io.BytesIO() for chunk in tts: diff --git a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py index 197b062e44..d0a41b940f 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py +++ b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py @@ -34,6 +34,7 @@ class LocaltimeToTimestampTool(BuiltinTool): yield self.create_text_message(f"{timestamp}") + # TODO: this method's type is messy @staticmethod def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None: try: diff --git a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py index babfa9bcd9..e23ae3b001 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py +++ b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py @@ -48,6 +48,6 @@ class TimezoneConversionTool(BuiltinTool): datetime_with_tz = input_timezone.localize(local_time) # timezone convert converted_datetime = datetime_with_tz.astimezone(output_timezone) - return converted_datetime.strftime(format=time_format) # type: ignore + return converted_datetime.strftime(time_format) except Exception as e: raise ToolInvokeError(str(e)) diff --git a/api/core/tools/mcp_tool/provider.py b/api/core/tools/mcp_tool/provider.py index 3404f5c3b4..557211c8c8 100644 --- a/api/core/tools/mcp_tool/provider.py +++ b/api/core/tools/mcp_tool/provider.py @@ -113,7 +113,7 @@ class MCPToolProviderController(ToolProviderController): """ pass - def get_tool(self, tool_name: str) -> MCPTool: # type: ignore + def get_tool(self, tool_name: str) -> MCPTool: """ return tool with given name """ @@ -136,7 +136,7 @@ class MCPToolProviderController(ToolProviderController): sse_read_timeout=self.sse_read_timeout, ) - def get_tools(self) -> list[MCPTool]: # type: ignore + def get_tools(self) -> list[MCPTool]: """ get all tools """ diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 290077ecd8..a476859f29 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -3,7 +3,6 @@ import json from collections.abc import Generator from typing import Any -from core.mcp.auth.auth_flow import auth from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.error import MCPConnectionError from core.mcp.types import CallToolResult, ImageContent, TextContent 
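The rewrite of the invoke path in the next hunk collapses the two previous code paths into one two-step flow: provider credentials are loaded and decrypted inside a short-lived SQLAlchemy session, and the MCP network call only happens after that session is closed, so no pooled database connection is held across slow network I/O. A minimal, self-contained sketch of that pattern, not Dify code; the Credentials dataclass and both helpers are illustrative stand-ins:

from dataclasses import dataclass

from sqlalchemy import create_engine
from sqlalchemy.orm import Session

engine = create_engine("sqlite://")  # placeholder in-memory engine, only for the sketch


@dataclass
class Credentials:
    # Stand-in for the decrypted server URL and headers prepared in Step 1.
    server_url: str
    headers: dict[str, str]


def load_credentials(session: Session, provider_id: str) -> Credentials:
    # The real code queries the provider row and decrypts its fields; the sketch fabricates one.
    return Credentials(server_url=f"https://mcp.example/{provider_id}", headers={})


def call_remote_server(credentials: Credentials) -> str:
    # Stand-in for the MCP client call; no real network I/O in this sketch.
    return f"called {credentials.server_url}"


def invoke_remote_tool(provider_id: str) -> str:
    # Step 1: short-lived session; load and decrypt everything needed, then leave the block.
    with Session(engine, expire_on_commit=False) as session:
        credentials = load_credentials(session, provider_id)
    # Step 2: the pooled connection is released before any slow network call starts.
    return call_remote_server(credentials)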
@@ -125,71 +124,39 @@ class MCPTool(Tool): headers = self.headers.copy() if self.headers else {} tool_parameters = self._handle_none_parameter(tool_parameters) - # Get provider entity to access tokens + from sqlalchemy.orm import Session - # Get MCP service from invoke parameters or create new one - provider_entity = None - mcp_service = None + from extensions.ext_database import db + from services.tools.mcp_tools_manage_service import MCPToolManageService - # Check if mcp_service is passed in tool_parameters - if "_mcp_service" in tool_parameters: - mcp_service = tool_parameters.pop("_mcp_service") - if mcp_service: - provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True) - headers = provider_entity.decrypt_headers() - # Try to get existing token and add to headers - if not headers: - tokens = provider_entity.retrieve_tokens() - if tokens and tokens.access_token: - headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}" + # Step 1: Load provider entity and credentials in a short-lived session + # This minimizes database connection hold time + with Session(db.engine, expire_on_commit=False) as session: + mcp_service = MCPToolManageService(session=session) + provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True) - # Use MCPClientWithAuthRetry to handle authentication automatically - try: - with MCPClientWithAuthRetry( - server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url, - headers=headers, - timeout=self.timeout, - sse_read_timeout=self.sse_read_timeout, - provider_entity=provider_entity, - auth_callback=auth if mcp_service else None, - mcp_service=mcp_service, - ) as mcp_client: - return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) - except MCPConnectionError as e: - raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e - except (ValueError, TypeError, KeyError) as e: - # Catch specific exceptions that might occur during tool invocation - raise ToolInvokeError(f"Failed to invoke tool: {e}") from e - else: - # Fallback to creating service with database session - from sqlalchemy.orm import Session + # Decrypt and prepare all credentials before closing session + server_url = provider_entity.decrypt_server_url() + headers = provider_entity.decrypt_headers() - from extensions.ext_database import db - from services.tools.mcp_tools_manage_service import MCPToolManageService + # Try to get existing token and add to headers + if not headers: + tokens = provider_entity.retrieve_tokens() + if tokens and tokens.access_token: + headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}" - with Session(db.engine, expire_on_commit=False) as session: - mcp_service = MCPToolManageService(session=session) - provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True) - headers = provider_entity.decrypt_headers() - # Try to get existing token and add to headers - if not headers: - tokens = provider_entity.retrieve_tokens() - if tokens and tokens.access_token: - headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}" - - # Use MCPClientWithAuthRetry to handle authentication automatically - try: - with MCPClientWithAuthRetry( - server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url, - headers=headers, - timeout=self.timeout, - sse_read_timeout=self.sse_read_timeout, - 
provider_entity=provider_entity, - auth_callback=auth if mcp_service else None, - mcp_service=mcp_service, - ) as mcp_client: - return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) - except MCPConnectionError as e: - raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e - except Exception as e: - raise ToolInvokeError(f"Failed to invoke tool: {e}") from e + # Step 2: Session is now closed, perform network operations without holding database connection + # MCPClientWithAuthRetry will create a new session lazily only if auth retry is needed + try: + with MCPClientWithAuthRetry( + server_url=server_url, + headers=headers, + timeout=self.timeout, + sse_read_timeout=self.sse_read_timeout, + provider_entity=provider_entity, + ) as mcp_client: + return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) + except MCPConnectionError as e: + raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e + except Exception as e: + raise ToolInvokeError(f"Failed to invoke tool: {e}") from e diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 39646b7fc8..90d5a647e9 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -26,7 +26,7 @@ class ToolLabelManager: labels = cls.filter_tool_labels(labels) if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): - provider_id = controller.provider_id # ty: ignore [unresolved-attribute] + provider_id = controller.provider_id else: raise ValueError("Unsupported tool type") @@ -51,7 +51,7 @@ class ToolLabelManager: Get tool labels """ if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): - provider_id = controller.provider_id # ty: ignore [unresolved-attribute] + provider_id = controller.provider_id elif isinstance(controller, BuiltinToolProviderController): return controller.tool_labels else: @@ -85,7 +85,7 @@ class ToolLabelManager: provider_ids = [] for controller in tool_providers: assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController) - provider_ids.append(controller.provider_id) # ty: ignore [unresolved-attribute] + provider_ids.append(controller.provider_id) labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all() diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 0641fa01fe..ff7dcc0e55 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -331,7 +331,8 @@ class ToolManager: workflow_provider_stmt = select(WorkflowToolProvider).where( WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id ) - workflow_provider = db.session.scalar(workflow_provider_stmt) + with Session(db.engine, expire_on_commit=False) as session, session.begin(): + workflow_provider = session.scalar(workflow_provider_stmt) if workflow_provider is None: raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index 915a22dd0f..f96510fb45 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -193,18 +193,18 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): DatasetDocument.enabled == True, 
DatasetDocument.archived == False, ) - document = db.session.scalar(dataset_document_stmt) # type: ignore + document = db.session.scalar(dataset_document_stmt) if dataset and document: source = RetrievalSourceMetadata( dataset_id=dataset.id, dataset_name=dataset.name, - document_id=document.id, # type: ignore - document_name=document.name, # type: ignore - data_source_type=document.data_source_type, # type: ignore + document_id=document.id, + document_name=document.name, + data_source_type=document.data_source_type, segment_id=segment.id, retriever_from=self.retriever_from, score=record.score or 0.0, - doc_metadata=document.doc_metadata, # type: ignore + doc_metadata=document.doc_metadata, ) if self.retriever_from == "dev": diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index c7ac3387e5..6eabde3991 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -62,6 +62,11 @@ class ApiBasedToolSchemaParser: root = root[ref] interface["operation"]["parameters"][i] = root for parameter in interface["operation"]["parameters"]: + # Handle complex type defaults that are not supported by PluginParameter + default_value = None + if "schema" in parameter and "default" in parameter["schema"]: + default_value = ApiBasedToolSchemaParser._sanitize_default_value(parameter["schema"]["default"]) + tool_parameter = ToolParameter( name=parameter["name"], label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]), @@ -72,9 +77,7 @@ class ApiBasedToolSchemaParser: required=parameter.get("required", False), form=ToolParameter.ToolParameterForm.LLM, llm_description=parameter.get("description"), - default=parameter["schema"]["default"] - if "schema" in parameter and "default" in parameter["schema"] - else None, + default=default_value, placeholder=I18nObject( en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "") ), @@ -134,6 +137,11 @@ class ApiBasedToolSchemaParser: required = body_schema.get("required", []) properties = body_schema.get("properties", {}) for name, property in properties.items(): + # Handle complex type defaults that are not supported by PluginParameter + default_value = ApiBasedToolSchemaParser._sanitize_default_value( + property.get("default", None) + ) + tool = ToolParameter( name=name, label=I18nObject(en_US=name, zh_Hans=name), @@ -144,12 +152,11 @@ class ApiBasedToolSchemaParser: required=name in required, form=ToolParameter.ToolParameterForm.LLM, llm_description=property.get("description", ""), - default=property.get("default", None), + default=default_value, placeholder=I18nObject( en_US=property.get("description", ""), zh_Hans=property.get("description", "") ), ) - # check if there is a type typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property) if typ: @@ -197,6 +204,22 @@ class ApiBasedToolSchemaParser: return bundles + @staticmethod + def _sanitize_default_value(value): + """ + Sanitize default values for PluginParameter compatibility. + Complex types (list, dict) are converted to None to avoid validation errors. 
+ + Args: + value: The default value from OpenAPI schema + + Returns: + None for complex types (list, dict), otherwise the original value + """ + if isinstance(value, (list, dict)): + return None + return value + @staticmethod def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None: parameter = parameter or {} @@ -217,7 +240,11 @@ class ApiBasedToolSchemaParser: return ToolParameter.ToolParameterType.STRING elif typ == "array": items = parameter.get("items") or parameter.get("schema", {}).get("items") - return ToolParameter.ToolParameterType.FILES if items and items.get("format") == "binary" else None + if items and items.get("format") == "binary": + return ToolParameter.ToolParameterType.FILES + else: + # For regular arrays, return ARRAY type instead of None + return ToolParameter.ToolParameterType.ARRAY else: return None diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index 52c16c34a0..ef6913d0bd 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -6,8 +6,8 @@ from typing import Any, cast from urllib.parse import unquote import chardet -import cloudscraper # type: ignore -from readabilipy import simple_json_from_html_string # type: ignore +import cloudscraper +from readabilipy import simple_json_from_html_string from core.helper import ssrf_proxy from core.rag.extractor import extract_processor @@ -63,8 +63,8 @@ def get_url(url: str, user_agent: str | None = None) -> str: response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) elif response.status_code == 403: scraper = cloudscraper.create_scraper() - scraper.perform_request = ssrf_proxy.make_request # type: ignore - response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) # type: ignore + scraper.perform_request = ssrf_proxy.make_request + response = scraper.get(url, headers=headers, timeout=(120, 300)) if response.status_code != 200: return f"URL returned status code {response.status_code}." 
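To make the parser changes above concrete, here is a small self-contained restatement of the new default-value sanitization and array-type mapping (plain functions and string stand-ins for ToolParameterType, so it runs outside ApiBasedToolSchemaParser; only the behavior shown in the hunks above is assumed):

from typing import Any


def sanitize_default_value(value: Any) -> Any:
    # Complex defaults (list/dict) cannot be expressed as PluginParameter defaults,
    # so they are dropped instead of failing validation downstream.
    if isinstance(value, (list, dict)):
        return None
    return value


def array_parameter_type(items: dict[str, Any] | None) -> str:
    # Binary arrays still map to file uploads; any other array now maps to "array" instead of None.
    if items and items.get("format") == "binary":
        return "files"
    return "array"


assert sanitize_default_value(["a", "b"]) is None          # array default is dropped
assert sanitize_default_value({"k": "v"}) is None          # object default is dropped
assert sanitize_default_value("celsius") == "celsius"      # scalar defaults pass through
assert array_parameter_type({"type": "string"}) == "array"
assert array_parameter_type({"format": "binary"}) == "files"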
diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py index e9b5dab7d3..071154ee71 100644 --- a/api/core/tools/utils/yaml_utils.py +++ b/api/core/tools/utils/yaml_utils.py @@ -3,7 +3,7 @@ from functools import lru_cache from pathlib import Path from typing import Any -import yaml # type: ignore +import yaml from yaml import YAMLError logger = logging.getLogger(__name__) diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index 4d9c8895fc..d7afbc7389 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -1,6 +1,7 @@ from collections.abc import Mapping from pydantic import Field +from sqlalchemy.orm import Session from core.app.app_config.entities import VariableEntity, VariableEntityType from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager @@ -20,6 +21,7 @@ from core.tools.entities.tool_entities import ( from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db +from models.account import Account from models.model import App, AppMode from models.tools import WorkflowToolProvider from models.workflow import Workflow @@ -44,29 +46,34 @@ class WorkflowToolProviderController(ToolProviderController): @classmethod def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController": - app = db_provider.app + with Session(db.engine, expire_on_commit=False) as session, session.begin(): + provider = session.get(WorkflowToolProvider, db_provider.id) if db_provider.id else None + if not provider: + raise ValueError("workflow provider not found") + app = session.get(App, provider.app_id) + if not app: + raise ValueError("app not found") - if not app: - raise ValueError("app not found") + user = session.get(Account, provider.user_id) if provider.user_id else None - controller = WorkflowToolProviderController( - entity=ToolProviderEntity( - identity=ToolProviderIdentity( - author=db_provider.user.name if db_provider.user_id and db_provider.user else "", - name=db_provider.label, - label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label), - description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description), - icon=db_provider.icon, + controller = WorkflowToolProviderController( + entity=ToolProviderEntity( + identity=ToolProviderIdentity( + author=user.name if user else "", + name=provider.label, + label=I18nObject(en_US=provider.label, zh_Hans=provider.label), + description=I18nObject(en_US=provider.description, zh_Hans=provider.description), + icon=provider.icon, + ), + credentials_schema=[], + plugin_id=None, ), - credentials_schema=[], - plugin_id=None, - ), - provider_id=db_provider.id or "", - ) + provider_id=provider.id or "", + ) - # init tools - - controller.tools = [controller._get_db_provider_tool(db_provider, app)] + controller.tools = [ + controller._get_db_provider_tool(provider, app, session=session, user=user), + ] return controller @@ -74,7 +81,14 @@ class WorkflowToolProviderController(ToolProviderController): def provider_type(self) -> ToolProviderType: return ToolProviderType.WORKFLOW - def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool: + def _get_db_provider_tool( + self, + db_provider: WorkflowToolProvider, + app: App, + *, + session: Session, + user: Account | None = None, + ) -> WorkflowTool: """ get db 
provider tool :param db_provider: the db provider @@ -82,7 +96,7 @@ class WorkflowToolProviderController(ToolProviderController): :return: the tool """ workflow: Workflow | None = ( - db.session.query(Workflow) + session.query(Workflow) .where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version) .first() ) @@ -99,9 +113,7 @@ class WorkflowToolProviderController(ToolProviderController): variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph) def fetch_workflow_variable(variable_name: str) -> VariableEntity | None: - return next(filter(lambda x: x.variable == variable_name, variables), None) # type: ignore - - user = db_provider.user + return next(filter(lambda x: x.variable == variable_name, variables), None) workflow_tool_parameters = [] for parameter in parameters: @@ -187,22 +199,25 @@ class WorkflowToolProviderController(ToolProviderController): if self.tools is not None: return self.tools - db_providers: WorkflowToolProvider | None = ( - db.session.query(WorkflowToolProvider) - .where( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.app_id == self.provider_id, + with Session(db.engine, expire_on_commit=False) as session, session.begin(): + db_provider: WorkflowToolProvider | None = ( + session.query(WorkflowToolProvider) + .where( + WorkflowToolProvider.tenant_id == tenant_id, + WorkflowToolProvider.app_id == self.provider_id, + ) + .first() ) - .first() - ) - if not db_providers: - return [] - if not db_providers.app: - raise ValueError("app not found") + if not db_provider: + return [] - app = db_providers.app - self.tools = [self._get_db_provider_tool(db_providers, app)] + app = session.get(App, db_provider.app_id) + if not app: + raise ValueError("app not found") + + user = session.get(Account, db_provider.user_id) if db_provider.user_id else None + self.tools = [self._get_db_provider_tool(db_provider, app, session=session, user=user)] return self.tools diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 50c2327004..2cd46647a0 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -1,12 +1,14 @@ import json import logging -from collections.abc import Generator -from typing import Any +from collections.abc import Generator, Mapping, Sequence +from typing import Any, cast from flask import has_request_context from sqlalchemy import select +from sqlalchemy.orm import Session from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod +from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ( @@ -48,6 +50,7 @@ class WorkflowTool(Tool): self.workflow_entities = workflow_entities self.workflow_call_depth = workflow_call_depth self.label = label + self._latest_usage = LLMUsage.empty_usage() super().__init__(entity=entity, runtime=runtime) @@ -83,10 +86,11 @@ class WorkflowTool(Tool): assert self.runtime.invoke_from is not None user = self._resolve_user(user_id=user_id) - if user is None: raise ToolInvokeError("User not found") + self._latest_usage = LLMUsage.empty_usage() + result = generator.generate( app_model=app, workflow=workflow, @@ -110,9 +114,68 @@ class WorkflowTool(Tool): for file in files: yield self.create_file_message(file) # type: ignore + self._latest_usage = self._derive_usage_from_result(data) + yield 
self.create_text_message(json.dumps(outputs, ensure_ascii=False)) yield self.create_json_message(outputs) + @property + def latest_usage(self) -> LLMUsage: + return self._latest_usage + + @classmethod + def _derive_usage_from_result(cls, data: Mapping[str, Any]) -> LLMUsage: + usage_dict = cls._extract_usage_dict(data) + if usage_dict is not None: + return LLMUsage.from_metadata(cast(LLMUsageMetadata, dict(usage_dict))) + + total_tokens = data.get("total_tokens") + total_price = data.get("total_price") + if total_tokens is None and total_price is None: + return LLMUsage.empty_usage() + + usage_metadata: dict[str, Any] = {} + if total_tokens is not None: + try: + usage_metadata["total_tokens"] = int(str(total_tokens)) + except (TypeError, ValueError): + pass + if total_price is not None: + usage_metadata["total_price"] = str(total_price) + currency = data.get("currency") + if currency is not None: + usage_metadata["currency"] = currency + + if not usage_metadata: + return LLMUsage.empty_usage() + + return LLMUsage.from_metadata(cast(LLMUsageMetadata, usage_metadata)) + + @classmethod + def _extract_usage_dict(cls, payload: Mapping[str, Any]) -> Mapping[str, Any] | None: + usage_candidate = payload.get("usage") + if isinstance(usage_candidate, Mapping): + return usage_candidate + + metadata_candidate = payload.get("metadata") + if isinstance(metadata_candidate, Mapping): + usage_candidate = metadata_candidate.get("usage") + if isinstance(usage_candidate, Mapping): + return usage_candidate + + for value in payload.values(): + if isinstance(value, Mapping): + found = cls._extract_usage_dict(value) + if found is not None: + return found + elif isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): + for item in value: + if isinstance(item, Mapping): + found = cls._extract_usage_dict(item) + if found is not None: + return found + return None + def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool": """ fork a new tool with metadata @@ -179,16 +242,17 @@ class WorkflowTool(Tool): """ get the workflow by app id and version """ - if not version: - workflow = ( - db.session.query(Workflow) - .where(Workflow.app_id == app_id, Workflow.version != Workflow.VERSION_DRAFT) - .order_by(Workflow.created_at.desc()) - .first() - ) - else: - stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version) - workflow = db.session.scalar(stmt) + with Session(db.engine, expire_on_commit=False) as session, session.begin(): + if not version: + stmt = ( + select(Workflow) + .where(Workflow.app_id == app_id, Workflow.version != Workflow.VERSION_DRAFT) + .order_by(Workflow.created_at.desc()) + ) + workflow = session.scalars(stmt).first() + else: + stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version) + workflow = session.scalar(stmt) if not workflow: raise ValueError("workflow not found or not published") @@ -200,7 +264,8 @@ class WorkflowTool(Tool): get the app by app id """ stmt = select(App).where(App.id == app_id) - app = db.session.scalar(stmt) + with Session(db.engine, expire_on_commit=False) as session, session.begin(): + app = session.scalar(stmt) if not app: raise ValueError("app not found") diff --git a/api/core/variables/segment_group.py b/api/core/variables/segment_group.py index 0a41b64228..b363255b2c 100644 --- a/api/core/variables/segment_group.py +++ b/api/core/variables/segment_group.py @@ -4,7 +4,7 @@ from .types import SegmentType class SegmentGroup(Segment): value_type: SegmentType = SegmentType.GROUP - 
value: list[Segment] = None # type: ignore + value: list[Segment] @property def text(self): diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index 0a50cccbca..95e3c2d1b6 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -19,7 +19,7 @@ class Segment(BaseModel): model_config = ConfigDict(frozen=True) value_type: SegmentType - value: Any = None + value: Any @field_validator("value_type") @classmethod @@ -74,12 +74,12 @@ class NoneSegment(Segment): class StringSegment(Segment): value_type: SegmentType = SegmentType.STRING - value: str = None # type: ignore + value: str class FloatSegment(Segment): value_type: SegmentType = SegmentType.FLOAT - value: float = None # type: ignore + value: float # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems. # The following tests cannot pass. # @@ -98,12 +98,12 @@ class FloatSegment(Segment): class IntegerSegment(Segment): value_type: SegmentType = SegmentType.INTEGER - value: int = None # type: ignore + value: int class ObjectSegment(Segment): value_type: SegmentType = SegmentType.OBJECT - value: Mapping[str, Any] = None # type: ignore + value: Mapping[str, Any] @property def text(self) -> str: @@ -136,7 +136,7 @@ class ArraySegment(Segment): class FileSegment(Segment): value_type: SegmentType = SegmentType.FILE - value: File = None # type: ignore + value: File @property def markdown(self) -> str: @@ -153,17 +153,17 @@ class FileSegment(Segment): class BooleanSegment(Segment): value_type: SegmentType = SegmentType.BOOLEAN - value: bool = None # type: ignore + value: bool class ArrayAnySegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_ANY - value: Sequence[Any] = None # type: ignore + value: Sequence[Any] class ArrayStringSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_STRING - value: Sequence[str] = None # type: ignore + value: Sequence[str] @property def text(self) -> str: @@ -175,17 +175,17 @@ class ArrayStringSegment(ArraySegment): class ArrayNumberSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_NUMBER - value: Sequence[float | int] = None # type: ignore + value: Sequence[float | int] class ArrayObjectSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_OBJECT - value: Sequence[Mapping[str, Any]] = None # type: ignore + value: Sequence[Mapping[str, Any]] class ArrayFileSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_FILE - value: Sequence[File] = None # type: ignore + value: Sequence[File] @property def markdown(self) -> str: @@ -247,7 +247,7 @@ class VersionedMemorySegment(Segment): class ArrayBooleanSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_BOOLEAN - value: Sequence[bool] = None # type: ignore + value: Sequence[bool] def get_segment_discriminator(v: Any) -> SegmentType | None: diff --git a/api/core/workflow/entities/__init__.py b/api/core/workflow/entities/__init__.py index be70e467a0..185f0ad620 100644 --- a/api/core/workflow/entities/__init__.py +++ b/api/core/workflow/entities/__init__.py @@ -1,3 +1,5 @@ +from ..runtime.graph_runtime_state import GraphRuntimeState +from ..runtime.variable_pool import VariablePool from .agent import AgentNodeStrategyInit from .graph_init_params import GraphInitParams from .workflow_execution import WorkflowExecution @@ -6,6 +8,8 @@ from .workflow_node_execution import WorkflowNodeExecution __all__ = [ "AgentNodeStrategyInit", "GraphInitParams", + "GraphRuntimeState", + "VariablePool", 
"WorkflowExecution", "WorkflowNodeExecution", ] diff --git a/api/core/workflow/graph/graph.py b/api/core/workflow/graph/graph.py index 20b5193875..d04724425c 100644 --- a/api/core/workflow/graph/graph.py +++ b/api/core/workflow/graph/graph.py @@ -3,11 +3,12 @@ from collections import defaultdict from collections.abc import Mapping, Sequence from typing import Protocol, cast, final -from core.workflow.enums import NodeExecutionType, NodeState, NodeType +from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType from core.workflow.nodes.base.node import Node from libs.typing import is_str, is_str_dict from .edge import Edge +from .validation import get_graph_validator logger = logging.getLogger(__name__) @@ -201,6 +202,17 @@ class Graph: return GraphBuilder(graph_cls=cls) + @classmethod + def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None: + """ + Promote nodes configured with FAIL_BRANCH error strategy to branch execution type. + + :param nodes: mapping of node ID to node instance + """ + for node in nodes.values(): + if node.error_strategy == ErrorStrategy.FAIL_BRANCH: + node.execution_type = NodeExecutionType.BRANCH + @classmethod def _mark_inactive_root_branches( cls, @@ -307,6 +319,9 @@ class Graph: # Create node instances nodes = cls._create_node_instances(node_configs_map, node_factory) + # Promote fail-branch nodes to branch execution type at graph level + cls._promote_fail_branch_nodes(nodes) + # Get root node instance root_node = nodes[root_node_id] @@ -314,7 +329,7 @@ class Graph: cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id) # Create and return the graph - return cls( + graph = cls( nodes=nodes, edges=edges, in_edges=in_edges, @@ -322,6 +337,11 @@ class Graph: root_node=root_node, ) + # Validate the graph structure using built-in validators + get_graph_validator().validate(graph) + + return graph + @property def node_ids(self) -> list[str]: """ diff --git a/api/core/workflow/graph/validation.py b/api/core/workflow/graph/validation.py new file mode 100644 index 0000000000..87aa7db2e4 --- /dev/null +++ b/api/core/workflow/graph/validation.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Protocol + +from core.workflow.enums import NodeExecutionType, NodeType + +if TYPE_CHECKING: + from .graph import Graph + + +@dataclass(frozen=True, slots=True) +class GraphValidationIssue: + """Immutable value object describing a single validation issue.""" + + code: str + message: str + node_id: str | None = None + + +class GraphValidationError(ValueError): + """Raised when graph validation fails.""" + + def __init__(self, issues: Sequence[GraphValidationIssue]) -> None: + if not issues: + raise ValueError("GraphValidationError requires at least one issue.") + self.issues: tuple[GraphValidationIssue, ...] = tuple(issues) + message = "; ".join(f"[{issue.code}] {issue.message}" for issue in self.issues) + super().__init__(message) + + +class GraphValidationRule(Protocol): + """Protocol that individual validation rules must satisfy.""" + + def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: + """Validate the provided graph and return any discovered issues.""" + ... 
+ + +@dataclass(frozen=True, slots=True) +class _EdgeEndpointValidator: + """Ensures all edges reference existing nodes.""" + + missing_node_code: str = "MISSING_NODE" + + def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: + issues: list[GraphValidationIssue] = [] + for edge in graph.edges.values(): + if edge.tail not in graph.nodes: + issues.append( + GraphValidationIssue( + code=self.missing_node_code, + message=f"Edge {edge.id} references unknown source node '{edge.tail}'.", + node_id=edge.tail, + ) + ) + if edge.head not in graph.nodes: + issues.append( + GraphValidationIssue( + code=self.missing_node_code, + message=f"Edge {edge.id} references unknown target node '{edge.head}'.", + node_id=edge.head, + ) + ) + return issues + + +@dataclass(frozen=True, slots=True) +class _RootNodeValidator: + """Validates root node invariants.""" + + invalid_root_code: str = "INVALID_ROOT" + container_entry_types: tuple[NodeType, ...] = (NodeType.ITERATION_START, NodeType.LOOP_START) + + def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: + root_node = graph.root_node + issues: list[GraphValidationIssue] = [] + if root_node.id not in graph.nodes: + issues.append( + GraphValidationIssue( + code=self.invalid_root_code, + message=f"Root node '{root_node.id}' is missing from the node registry.", + node_id=root_node.id, + ) + ) + return issues + + node_type = getattr(root_node, "node_type", None) + if root_node.execution_type != NodeExecutionType.ROOT and node_type not in self.container_entry_types: + issues.append( + GraphValidationIssue( + code=self.invalid_root_code, + message=f"Root node '{root_node.id}' must declare execution type 'root'.", + node_id=root_node.id, + ) + ) + return issues + + +@dataclass(frozen=True, slots=True) +class GraphValidator: + """Coordinates execution of graph validation rules.""" + + rules: tuple[GraphValidationRule, ...] + + def validate(self, graph: Graph) -> None: + """Validate the graph against all configured rules.""" + issues: list[GraphValidationIssue] = [] + for rule in self.rules: + issues.extend(rule.validate(graph)) + + if issues: + raise GraphValidationError(issues) + + +_DEFAULT_RULES: tuple[GraphValidationRule, ...] 
= ( + _EdgeEndpointValidator(), + _RootNodeValidator(), +) + + +def get_graph_validator() -> GraphValidator: + """Construct the validator composed of default rules.""" + return GraphValidator(_DEFAULT_RULES) diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index ce6eb33ecc..985ee5eef2 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -26,8 +26,8 @@ class AgentNodeData(BaseNodeData): class ParamsAutoGenerated(IntEnum): - CLOSE = auto() - OPEN = auto() + CLOSE = 0 + OPEN = 1 class AgentOldVersionModelFeatures(StrEnum): diff --git a/api/core/workflow/nodes/base/__init__.py b/api/core/workflow/nodes/base/__init__.py index 8cf31dc342..f83df0e323 100644 --- a/api/core/workflow/nodes/base/__init__.py +++ b/api/core/workflow/nodes/base/__init__.py @@ -1,4 +1,5 @@ from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData +from .usage_tracking_mixin import LLMUsageTrackingMixin __all__ = [ "BaseIterationNodeData", @@ -6,4 +7,5 @@ __all__ = [ "BaseLoopNodeData", "BaseLoopState", "BaseNodeData", + "LLMUsageTrackingMixin", ] diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index 5aef9d79cf..94b0d1d8bc 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -1,5 +1,6 @@ import json from abc import ABC +from builtins import type as type_ from collections.abc import Sequence from enum import StrEnum from typing import Any, Union @@ -58,10 +59,9 @@ class DefaultValue(BaseModel): raise DefaultValueTypeError(f"Invalid JSON format for value: {value}") @staticmethod - def _validate_array(value: Any, element_type: DefaultValueType) -> bool: + def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool: """Unified array type validation""" - # FIXME, type ignore here for do not find the reason mypy complain, if find the root cause, please fix it - return isinstance(value, list) and all(isinstance(x, element_type) for x in value) # type: ignore + return isinstance(value, list) and all(isinstance(x, element_type) for x in value) @staticmethod def _convert_number(value: str) -> float: diff --git a/api/core/workflow/nodes/base/usage_tracking_mixin.py b/api/core/workflow/nodes/base/usage_tracking_mixin.py new file mode 100644 index 0000000000..d9a0ef8972 --- /dev/null +++ b/api/core/workflow/nodes/base/usage_tracking_mixin.py @@ -0,0 +1,28 @@ +from core.model_runtime.entities.llm_entities import LLMUsage +from core.workflow.runtime import GraphRuntimeState + + +class LLMUsageTrackingMixin: + """Provides shared helpers for merging and recording LLM usage within workflow nodes.""" + + graph_runtime_state: GraphRuntimeState + + @staticmethod + def _merge_usage(current: LLMUsage, new_usage: LLMUsage | None) -> LLMUsage: + """Return a combined usage snapshot, preserving zero-value inputs.""" + if new_usage is None or new_usage.total_tokens <= 0: + return current + if current.total_tokens == 0: + return new_usage + return current.plus(new_usage) + + def _accumulate_usage(self, usage: LLMUsage) -> None: + """Push usage into the graph runtime accumulator for downstream reporting.""" + if usage.total_tokens <= 0: + return + + current_usage = self.graph_runtime_state.llm_usage + if current_usage.total_tokens == 0: + self.graph_runtime_state.llm_usage = usage.model_copy() + else: + self.graph_runtime_state.llm_usage = current_usage.plus(usage) diff --git 
a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index ae1061d72c..cd5f50aaab 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -10,10 +10,10 @@ from typing import Any import chardet import docx import pandas as pd -import pypandoc # type: ignore -import pypdfium2 # type: ignore -import webvtt # type: ignore -import yaml # type: ignore +import pypandoc +import pypdfium2 +import webvtt +import yaml from docx.document import Document from docx.oxml.table import CT_Tbl from docx.oxml.text.paragraph import CT_P diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 41060bd569..3a3a2290be 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, NewType, cast from flask import Flask, current_app from typing_extensions import TypeIs +from core.model_runtime.entities.llm_entities import LLMUsage from core.variables import IntegerVariable, NoneSegment from core.variables.segments import ArrayAnySegment, ArraySegment from core.variables.variables import VariableUnion @@ -34,6 +35,7 @@ from core.workflow.node_events import ( NodeRunResult, StreamCompletedEvent, ) +from core.workflow.nodes.base import LLMUsageTrackingMixin from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData @@ -58,7 +60,7 @@ logger = logging.getLogger(__name__) EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment) -class IterationNode(Node): +class IterationNode(LLMUsageTrackingMixin, Node): """ Iteration Node. 
""" @@ -118,6 +120,7 @@ class IterationNode(Node): started_at = naive_utc_now() iter_run_map: dict[str, float] = {} outputs: list[object] = [] + usage_accumulator = [LLMUsage.empty_usage()] yield IterationStartedEvent( start_at=started_at, @@ -130,22 +133,27 @@ class IterationNode(Node): iterator_list_value=iterator_list_value, outputs=outputs, iter_run_map=iter_run_map, + usage_accumulator=usage_accumulator, ) + self._accumulate_usage(usage_accumulator[0]) yield from self._handle_iteration_success( started_at=started_at, inputs=inputs, outputs=outputs, iterator_list_value=iterator_list_value, iter_run_map=iter_run_map, + usage=usage_accumulator[0], ) except IterationNodeError as e: + self._accumulate_usage(usage_accumulator[0]) yield from self._handle_iteration_failure( started_at=started_at, inputs=inputs, outputs=outputs, iterator_list_value=iterator_list_value, iter_run_map=iter_run_map, + usage=usage_accumulator[0], error=e, ) @@ -196,6 +204,7 @@ class IterationNode(Node): iterator_list_value: Sequence[object], outputs: list[object], iter_run_map: dict[str, float], + usage_accumulator: list[LLMUsage], ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: if self._node_data.is_parallel: # Parallel mode execution @@ -203,6 +212,7 @@ class IterationNode(Node): iterator_list_value=iterator_list_value, outputs=outputs, iter_run_map=iter_run_map, + usage_accumulator=usage_accumulator, ) else: # Sequential mode execution @@ -228,6 +238,9 @@ class IterationNode(Node): # Update the total tokens from this iteration self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens + usage_accumulator[0] = self._merge_usage( + usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage + ) iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() def _execute_parallel_iterations( @@ -235,6 +248,7 @@ class IterationNode(Node): iterator_list_value: Sequence[object], outputs: list[object], iter_run_map: dict[str, float], + usage_accumulator: list[LLMUsage], ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: # Initialize outputs list with None values to maintain order outputs.extend([None] * len(iterator_list_value)) @@ -245,7 +259,16 @@ class IterationNode(Node): with ThreadPoolExecutor(max_workers=max_workers) as executor: # Submit all iteration tasks future_to_index: dict[ - Future[tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]], + Future[ + tuple[ + datetime, + list[GraphNodeEventBase], + object | None, + int, + dict[str, VariableUnion], + LLMUsage, + ] + ], int, ] = {} for index, item in enumerate(iterator_list_value): @@ -264,7 +287,14 @@ class IterationNode(Node): index = future_to_index[future] try: result = future.result() - iter_start_at, events, output_value, tokens_used, conversation_snapshot = result + ( + iter_start_at, + events, + output_value, + tokens_used, + conversation_snapshot, + iteration_usage, + ) = result # Update outputs at the correct index outputs[index] = output_value @@ -276,6 +306,8 @@ class IterationNode(Node): self.graph_runtime_state.total_tokens += tokens_used iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() + usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage) + # Sync conversation variables after iteration completion self._sync_conversation_variables_from_snapshot(conversation_snapshot) @@ -303,7 +335,7 @@ class IterationNode(Node): item: object, flask_app: 
Flask, context_vars: contextvars.Context, - ) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]: + ) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion], LLMUsage]: """Execute a single iteration in parallel mode and return results.""" with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars): iter_start_at = datetime.now(UTC).replace(tzinfo=None) @@ -332,6 +364,7 @@ class IterationNode(Node): output_value, graph_engine.graph_runtime_state.total_tokens, conversation_snapshot, + graph_engine.graph_runtime_state.llm_usage, ) def _handle_iteration_success( @@ -341,6 +374,8 @@ class IterationNode(Node): outputs: list[object], iterator_list_value: Sequence[object], iter_run_map: dict[str, float], + *, + usage: LLMUsage, ) -> Generator[NodeEventBase, None, None]: # Flatten the list of lists if all outputs are lists flattened_outputs = self._flatten_outputs_if_needed(outputs) @@ -351,7 +386,9 @@ class IterationNode(Node): outputs={"output": flattened_outputs}, steps=len(iterator_list_value), metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, }, ) @@ -362,8 +399,11 @@ class IterationNode(Node): status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"output": flattened_outputs}, metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, }, + llm_usage=usage, ) ) @@ -400,6 +440,8 @@ class IterationNode(Node): outputs: list[object], iterator_list_value: Sequence[object], iter_run_map: dict[str, float], + *, + usage: LLMUsage, error: IterationNodeError, ) -> Generator[NodeEventBase, None, None]: # Flatten the list of lists if all outputs are lists (even in failure case) @@ -411,7 +453,9 @@ class IterationNode(Node): outputs={"output": flattened_outputs}, steps=len(iterator_list_value), metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, }, error=str(error), @@ -420,6 +464,12 @@ class IterationNode(Node): node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, error=str(error), + metadata={ + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, + }, + llm_usage=usage, ) ) diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 2dc3cb9320..4a63900527 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -15,14 +15,11 @@ from core.app.entities.app_invoke_entities import 
ModelConfigWithCredentialsEnti from core.entities.agent_entities import PlanningStrategy from core.entities.model_entities import ModelStatus from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.message_entities import ( - PromptMessageRole, -) -from core.model_runtime.entities.model_entities import ( - ModelFeature, - ModelType, -) +from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.message_entities import PromptMessageRole +from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.simple_prompt_transform import ModelMode from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.metadata_entities import Condition, MetadataCondition @@ -33,8 +30,14 @@ from core.variables import ( ) from core.variables.segments import ArrayObjectSegment from core.workflow.entities import GraphInitParams -from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from core.workflow.enums import ( + ErrorStrategy, + NodeType, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult +from core.workflow.nodes.base import LLMUsageTrackingMixin from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.knowledge_retrieval.template_prompts import ( @@ -80,7 +83,7 @@ default_retrieval_model = { } -class KnowledgeRetrievalNode(Node): +class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node): node_type = NodeType.KNOWLEDGE_RETRIEVAL _node_data: KnowledgeRetrievalNodeData @@ -141,7 +144,7 @@ class KnowledgeRetrievalNode(Node): def version(cls): return "1" - def _run(self) -> NodeRunResult: # type: ignore + def _run(self) -> NodeRunResult: # extract variables variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector) if not isinstance(variable, StringSegment): @@ -182,14 +185,21 @@ class KnowledgeRetrievalNode(Node): ) # retrieve knowledge + usage = LLMUsage.empty_usage() try: - results = self._fetch_dataset_retriever(node_data=self._node_data, query=query) + results, usage = self._fetch_dataset_retriever(node_data=self._node_data, query=query) outputs = {"result": ArrayObjectSegment(value=results)} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, - process_data={}, + process_data={"usage": jsonable_encoder(usage)}, outputs=outputs, # type: ignore + metadata={ + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, + }, + llm_usage=usage, ) except KnowledgeRetrievalNodeError as e: @@ -199,6 +209,7 @@ class KnowledgeRetrievalNode(Node): inputs=variables, error=str(e), error_type=type(e).__name__, + llm_usage=usage, ) # Temporary handle all exceptions from DatasetRetrieval class here. 
except Exception as e: @@ -207,11 +218,15 @@ class KnowledgeRetrievalNode(Node): inputs=variables, error=str(e), error_type=type(e).__name__, + llm_usage=usage, ) finally: db.session.close() - def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]: + def _fetch_dataset_retriever( + self, node_data: KnowledgeRetrievalNodeData, query: str + ) -> tuple[list[dict[str, Any]], LLMUsage]: + usage = LLMUsage.empty_usage() available_datasets = [] dataset_ids = node_data.dataset_ids @@ -245,9 +260,10 @@ class KnowledgeRetrievalNode(Node): if not dataset: continue available_datasets.append(dataset) - metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition( + metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition( [dataset.id for dataset in available_datasets], query, node_data ) + usage = self._merge_usage(usage, metadata_usage) all_documents = [] dataset_retrieval = DatasetRetrieval() if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: @@ -330,6 +346,8 @@ class KnowledgeRetrievalNode(Node): metadata_filter_document_ids=metadata_filter_document_ids, metadata_condition=metadata_condition, ) + usage = self._merge_usage(usage, dataset_retrieval.llm_usage) + dify_documents = [item for item in all_documents if item.provider == "dify"] external_documents = [item for item in all_documents if item.provider == "external"] retrieval_resource_list = [] @@ -406,11 +424,12 @@ class KnowledgeRetrievalNode(Node): ) for position, item in enumerate(retrieval_resource_list, start=1): item["metadata"]["position"] = position - return retrieval_resource_list + return retrieval_resource_list, usage def _get_metadata_filter_condition( self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData - ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]: + ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]: + usage = LLMUsage.empty_usage() document_query = db.session.query(Document).where( Document.dataset_id.in_(dataset_ids), Document.indexing_status == "completed", @@ -420,9 +439,12 @@ class KnowledgeRetrievalNode(Node): filters: list[Any] = [] metadata_condition = None if node_data.metadata_filtering_mode == "disabled": - return None, None + return None, None, usage elif node_data.metadata_filtering_mode == "automatic": - automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data) + automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func( + dataset_ids, query, node_data + ) + usage = self._merge_usage(usage, automatic_usage) if automatic_metadata_filters: conditions = [] for sequence, filter in enumerate(automatic_metadata_filters): @@ -443,7 +465,7 @@ class KnowledgeRetrievalNode(Node): metadata_condition = MetadataCondition( logical_operator=node_data.metadata_filtering_conditions.logical_operator if node_data.metadata_filtering_conditions - else "or", # type: ignore + else "or", conditions=conditions, ) elif node_data.metadata_filtering_mode == "manual": @@ -457,10 +479,10 @@ class KnowledgeRetrievalNode(Node): expected_value = self.graph_runtime_state.variable_pool.convert_template( expected_value ).value[0] - if expected_value.value_type in {"number", "integer", "float"}: # type: ignore - expected_value = expected_value.value # type: ignore - elif expected_value.value_type == "string": # type: ignore - expected_value = re.sub(r"[\r\n\t]+", " 
", expected_value.text).strip() # type: ignore + if expected_value.value_type in {"number", "integer", "float"}: + expected_value = expected_value.value + elif expected_value.value_type == "string": + expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() else: raise ValueError("Invalid expected metadata value type") conditions.append( @@ -487,7 +509,7 @@ class KnowledgeRetrievalNode(Node): if ( node_data.metadata_filtering_conditions and node_data.metadata_filtering_conditions.logical_operator == "and" - ): # type: ignore + ): document_query = document_query.where(and_(*filters)) else: document_query = document_query.where(or_(*filters)) @@ -496,11 +518,12 @@ class KnowledgeRetrievalNode(Node): metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore for document in documents: metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore - return metadata_filter_document_ids, metadata_condition + return metadata_filter_document_ids, metadata_condition, usage def _automatic_metadata_filter_func( self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData - ) -> list[dict[str, Any]]: + ) -> tuple[list[dict[str, Any]], LLMUsage]: + usage = LLMUsage.empty_usage() # get all metadata field stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)) metadata_fields = db.session.scalars(stmt).all() @@ -548,6 +571,7 @@ class KnowledgeRetrievalNode(Node): for event in generator: if isinstance(event, ModelInvokeCompletedEvent): result_text = event.text + usage = self._merge_usage(usage, event.usage) break result_text_json = parse_and_check_json_markdown(result_text, []) @@ -564,8 +588,8 @@ class KnowledgeRetrievalNode(Node): } ) except Exception: - return [] - return automatic_metadata_filters + return [], usage + return automatic_metadata_filters, usage def _process_metadata_filter_func( self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any] diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index bbbd55a3b0..330109418f 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -452,10 +452,14 @@ class LLMNode(Node): usage = LLMUsage.empty_usage() finish_reason = None full_text_buffer = io.StringIO() + collected_structured_output = None # Collect structured_output from streaming chunks # Consume the invoke result and handle generator exception try: for result in invoke_result: if isinstance(result, LLMResultChunkWithStructuredOutput): + # Collect structured_output from the chunk + if result.structured_output is not None: + collected_structured_output = dict(result.structured_output) yield result if isinstance(result, LLMResultChunk): contents = result.delta.message.content @@ -503,6 +507,8 @@ class LLMNode(Node): finish_reason=finish_reason, # Reasoning content for workflow variables and downstream nodes reasoning_content=reasoning_content, + # Pass structured output if collected from streaming chunks + structured_output=collected_structured_output, ) @staticmethod diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index b51790c0a2..ca39e5aa23 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -5,6 +5,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence from datetime import datetime from typing import TYPE_CHECKING, Any, Literal, cast +from 
core.model_runtime.entities.llm_entities import LLMUsage from core.variables import Segment, SegmentType from core.workflow.enums import ( ErrorStrategy, @@ -27,6 +28,7 @@ from core.workflow.node_events import ( NodeRunResult, StreamCompletedEvent, ) +from core.workflow.nodes.base import LLMUsageTrackingMixin from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData @@ -40,7 +42,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class LoopNode(Node): +class LoopNode(LLMUsageTrackingMixin, Node): """ Loop Node. """ @@ -108,7 +110,7 @@ class LoopNode(Node): raise ValueError(f"Invalid value for loop variable {loop_variable.label}") variable_selector = [self._node_id, loop_variable.label] variable = segment_to_variable(segment=processed_segment, selector=variable_selector) - self.graph_runtime_state.variable_pool.add(variable_selector, variable) + self.graph_runtime_state.variable_pool.add(variable_selector, variable.value) loop_variable_selectors[loop_variable.label] = variable_selector inputs[loop_variable.label] = processed_segment.value @@ -117,6 +119,7 @@ class LoopNode(Node): loop_duration_map: dict[str, float] = {} single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output + loop_usage = LLMUsage.empty_usage() # Start Loop event yield LoopStartedEvent( @@ -163,6 +166,9 @@ class LoopNode(Node): # Update the total tokens from this iteration cost_tokens += graph_engine.graph_runtime_state.total_tokens + # Accumulate usage from the sub-graph execution + loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage) + # Collect loop variable values after iteration single_loop_variable = {} for key, selector in loop_variable_selectors.items(): @@ -189,6 +195,7 @@ class LoopNode(Node): ) self.graph_runtime_state.total_tokens += cost_tokens + self._accumulate_usage(loop_usage) # Loop completed successfully yield LoopSucceededEvent( start_at=start_at, @@ -196,7 +203,9 @@ class LoopNode(Node): outputs=self._node_data.outputs, steps=loop_count, metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: cost_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, "completed_reason": "loop_break" if reach_break_condition else "loop_completed", WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, @@ -207,22 +216,28 @@ class LoopNode(Node): node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, }, outputs=self._node_data.outputs, inputs=inputs, + llm_usage=loop_usage, ) ) except Exception as e: + self._accumulate_usage(loop_usage) yield LoopFailedEvent( start_at=start_at, inputs=inputs, steps=loop_count, metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 
self.graph_runtime_state.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, "completed_reason": "error", WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, @@ -235,10 +250,13 @@ class LoopNode(Node): status=WorkflowNodeExecutionStatus.FAILED, error=str(e), metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, }, + llm_usage=loop_usage, ) ) diff --git a/api/core/workflow/nodes/node_factory.py b/api/core/workflow/nodes/node_factory.py index 87d1b8c435..84f63d57eb 100644 --- a/api/core/workflow/nodes/node_factory.py +++ b/api/core/workflow/nodes/node_factory.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, final from typing_extensions import override -from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType +from core.workflow.enums import NodeType from core.workflow.graph import NodeFactory from core.workflow.nodes.base.node import Node from libs.typing import is_str, is_str_dict @@ -82,8 +82,4 @@ class DifyNodeFactory(NodeFactory): raise ValueError(f"Node {node_id} missing data information") node_instance.init_node_data(node_data) - # If node has fail branch, change execution type to branch - if node_instance.error_strategy == ErrorStrategy.FAIL_BRANCH: - node_instance.execution_type = NodeExecutionType.BRANCH - return node_instance diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 2b65cc30b6..e250650fef 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -747,7 +747,7 @@ class ParameterExtractorNode(Node): if model_mode == ModelMode.CHAT: system_prompt_messages = ChatModelMessage( role=PromptMessageRole.SYSTEM, - text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str).replace("{{instructions}}", instruction), + text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction), ) user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) return [system_prompt_messages, user_prompt_message] diff --git a/api/core/workflow/nodes/parameter_extractor/prompts.py b/api/core/workflow/nodes/parameter_extractor/prompts.py index b74be8f206..1b29be4418 100644 --- a/api/core/workflow/nodes/parameter_extractor/prompts.py +++ b/api/core/workflow/nodes/parameter_extractor/prompts.py @@ -135,7 +135,7 @@ Here are the chat histories between human and assistant, inside -{{instructions}} +{instructions} """ diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 2e2c32ac93..69ab6f0718 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -6,10 +6,13 @@ from sqlalchemy.orm import Session from core.callback_handler.workflow_tool_callback_handler import 
DifyWorkflowCallbackHandler from core.file import File, FileTransferMethod +from core.model_runtime.entities.llm_entities import LLMUsage +from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.errors import ToolInvokeError from core.tools.tool_engine import ToolEngine from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.tools.workflow_as_tool.tool import WorkflowTool from core.variables.segments import ArrayAnySegment, ArrayFileSegment from core.variables.variables import ArrayAnyVariable from core.workflow.enums import ( @@ -136,13 +139,14 @@ class ToolNode(Node): try: # convert tool messages - yield from self._transform_message( + _ = yield from self._transform_message( messages=message_stream, tool_info=tool_info, parameters_for_log=parameters_for_log, user_id=self.user_id, tenant_id=self.tenant_id, node_id=self._node_id, + tool_runtime=tool_runtime, ) except ToolInvokeError as e: yield StreamCompletedEvent( @@ -236,7 +240,8 @@ class ToolNode(Node): user_id: str, tenant_id: str, node_id: str, - ) -> Generator: + tool_runtime: Tool, + ) -> Generator[NodeEventBase, None, LLMUsage]: """ Convert ToolInvokeMessages into tuple[plain_text, files] """ @@ -424,17 +429,34 @@ class ToolNode(Node): is_final=True, ) + usage = self._extract_tool_usage(tool_runtime) + + metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = { + WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, + } + if usage.total_tokens > 0: + metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens + metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price + metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency + yield StreamCompletedEvent( node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables}, - metadata={ - WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, - }, + metadata=metadata, inputs=parameters_for_log, + llm_usage=usage, ) ) + return usage + + @staticmethod + def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage: + if isinstance(tool_runtime, WorkflowTool): + return tool_runtime.latest_usage + return LLMUsage.empty_usage() + @classmethod def _extract_variable_selector_to_variable_mapping( cls, diff --git a/api/core/workflow/runtime/variable_pool.py b/api/core/workflow/runtime/variable_pool.py index 7a73797128..353f7ba373 100644 --- a/api/core/workflow/runtime/variable_pool.py +++ b/api/core/workflow/runtime/variable_pool.py @@ -277,7 +277,7 @@ class VariablePool(BaseModel): # This ensures that we can keep the id of the system variables intact. 
if self._has(selector): continue - self.add(selector, value) # type: ignore + self.add(selector, value) @classmethod def empty(cls) -> "VariablePool": diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 05ad1d575b..27da5ebb57 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -32,7 +32,8 @@ if [[ "${MODE}" == "worker" ]]; then exec celery -A celery_entrypoint.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \ --max-tasks-per-child ${MAX_TASKS_PER_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \ - -Q ${CELERY_QUEUES:-dataset,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation} + -Q ${CELERY_QUEUES:-dataset,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation} \ + --prefetch-multiplier=1 elif [[ "${MODE}" == "beat" ]]; then exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO} diff --git a/api/events/event_handlers/clean_when_dataset_deleted.py b/api/events/event_handlers/clean_when_dataset_deleted.py index 0f6aa0e778..1666e2e29f 100644 --- a/api/events/event_handlers/clean_when_dataset_deleted.py +++ b/api/events/event_handlers/clean_when_dataset_deleted.py @@ -6,8 +6,8 @@ from tasks.clean_dataset_task import clean_dataset_task @dataset_was_deleted.connect def handle(sender: Dataset, **kwargs): dataset = sender - assert dataset.doc_form - assert dataset.indexing_technique + if not dataset.doc_form or not dataset.indexing_technique: + return clean_dataset_task.delay( dataset.id, dataset.tenant_id, diff --git a/api/events/event_handlers/clean_when_document_deleted.py b/api/events/event_handlers/clean_when_document_deleted.py index bbc913b7cf..0add109b06 100644 --- a/api/events/event_handlers/clean_when_document_deleted.py +++ b/api/events/event_handlers/clean_when_document_deleted.py @@ -8,6 +8,6 @@ def handle(sender, **kwargs): dataset_id = kwargs.get("dataset_id") doc_form = kwargs.get("doc_form") file_id = kwargs.get("file_id") - assert dataset_id is not None - assert doc_form is not None + if not dataset_id or not doc_form: + return clean_document_task.delay(document_id, dataset_id, doc_form, file_id) diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py index 6027833419..b7cb6bc44b 100644 --- a/api/extensions/ext_blueprints.py +++ b/api/extensions/ext_blueprints.py @@ -2,6 +2,11 @@ from configs import dify_config from constants import HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN, HEADER_NAME_PASSPORT from dify_app import DifyApp +BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEADER_NAME_PASSPORT) +SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization") +AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN) +FILES_HEADERS: tuple[str, ...] 
= (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN) + def init_app(app: DifyApp): # register blueprint routers @@ -17,7 +22,7 @@ def init_app(app: DifyApp): CORS( service_api_bp, - allow_headers=["Content-Type", "Authorization", HEADER_NAME_APP_CODE], + allow_headers=list(SERVICE_API_HEADERS), methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], ) app.register_blueprint(service_api_bp) @@ -27,11 +32,11 @@ def init_app(app: DifyApp): resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}}, supports_credentials=True, allow_headers=[ - "Content-Type", - "Authorization", - HEADER_NAME_APP_CODE, - HEADER_NAME_CSRF_TOKEN, - HEADER_NAME_PASSPORT + "Content-Type", + "Authorization", + HEADER_NAME_APP_CODE, + HEADER_NAME_CSRF_TOKEN, + HEADER_NAME_PASSPORT, ], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], expose_headers=["X-Version", "X-Env"], @@ -42,7 +47,7 @@ def init_app(app: DifyApp): console_app_bp, resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}}, supports_credentials=True, - allow_headers=["Content-Type", "Authorization", HEADER_NAME_CSRF_TOKEN], + allow_headers=list(AUTHENTICATED_HEADERS), methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], expose_headers=["X-Version", "X-Env"], ) @@ -50,7 +55,7 @@ def init_app(app: DifyApp): CORS( files_bp, - allow_headers=["Content-Type", HEADER_NAME_CSRF_TOKEN], + allow_headers=list(FILES_HEADERS), methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], ) app.register_blueprint(files_bp) diff --git a/api/extensions/ext_compress.py b/api/extensions/ext_compress.py index 26ff6427be..9c3a663af4 100644 --- a/api/extensions/ext_compress.py +++ b/api/extensions/ext_compress.py @@ -7,7 +7,7 @@ def is_enabled() -> bool: def init_app(app: DifyApp): - from flask_compress import Compress # type: ignore + from flask_compress import Compress compress = Compress() compress.init_app(app) diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index e7816a2e88..ed4fe332c1 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -1,6 +1,6 @@ import json -import flask_login # type: ignore +import flask_login from flask import Response, request from flask_login import user_loaded_from_request, user_logged_in from werkzeug.exceptions import NotFound, Unauthorized diff --git a/api/extensions/ext_migrate.py b/api/extensions/ext_migrate.py index 5f862181fa..6d8f35c30d 100644 --- a/api/extensions/ext_migrate.py +++ b/api/extensions/ext_migrate.py @@ -2,7 +2,7 @@ from dify_app import DifyApp def init_app(app: DifyApp): - import flask_migrate # type: ignore + import flask_migrate from extensions.ext_database import db diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py index cb6e4849a9..20ac2503a2 100644 --- a/api/extensions/ext_otel.py +++ b/api/extensions/ext_otel.py @@ -103,7 +103,7 @@ def init_app(app: DifyApp): def shutdown_tracer(): provider = trace.get_tracer_provider() if hasattr(provider, "force_flush"): - provider.force_flush() # ty: ignore [call-non-callable] + provider.force_flush() class ExceptionLoggingHandler(logging.Handler): """Custom logging handler that creates spans for logging.exception() calls""" diff --git a/api/extensions/ext_proxy_fix.py b/api/extensions/ext_proxy_fix.py index c085aed986..fe6685f633 100644 --- a/api/extensions/ext_proxy_fix.py +++ b/api/extensions/ext_proxy_fix.py @@ -6,4 +6,4 @@ def init_app(app: DifyApp): if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED: from werkzeug.middleware.proxy_fix import ProxyFix - 
app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1) # type: ignore + app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1) # type: ignore[method-assign] diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index 5ed7840211..c3aa8edf80 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -5,7 +5,7 @@ from dify_app import DifyApp def init_app(app: DifyApp): if dify_config.SENTRY_DSN: import sentry_sdk - from langfuse import parse_error # type: ignore + from langfuse import parse_error from sentry_sdk.integrations.celery import CeleryIntegration from sentry_sdk.integrations.flask import FlaskIntegration from werkzeug.exceptions import HTTPException diff --git a/api/extensions/storage/aliyun_oss_storage.py b/api/extensions/storage/aliyun_oss_storage.py index 5da4737138..2283581f62 100644 --- a/api/extensions/storage/aliyun_oss_storage.py +++ b/api/extensions/storage/aliyun_oss_storage.py @@ -1,7 +1,7 @@ import posixpath from collections.abc import Generator -import oss2 as aliyun_s3 # type: ignore +import oss2 as aliyun_s3 from configs import dify_config from extensions.storage.base_storage import BaseStorage diff --git a/api/extensions/storage/baidu_obs_storage.py b/api/extensions/storage/baidu_obs_storage.py index b94efa08be..0bb4648c0a 100644 --- a/api/extensions/storage/baidu_obs_storage.py +++ b/api/extensions/storage/baidu_obs_storage.py @@ -2,9 +2,9 @@ import base64 import hashlib from collections.abc import Generator -from baidubce.auth.bce_credentials import BceCredentials # type: ignore -from baidubce.bce_client_configuration import BceClientConfiguration # type: ignore -from baidubce.services.bos.bos_client import BosClient # type: ignore +from baidubce.auth.bce_credentials import BceCredentials +from baidubce.bce_client_configuration import BceClientConfiguration +from baidubce.services.bos.bos_client import BosClient from configs import dify_config from extensions.storage.base_storage import BaseStorage diff --git a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py index 06c528ca41..1cabc57e74 100644 --- a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py +++ b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py @@ -11,7 +11,7 @@ from collections.abc import Generator from io import BytesIO from pathlib import Path -import clickzetta # type: ignore[import] +import clickzetta from pydantic import BaseModel, model_validator from extensions.storage.base_storage import BaseStorage diff --git a/api/extensions/storage/clickzetta_volume/volume_permissions.py b/api/extensions/storage/clickzetta_volume/volume_permissions.py index 6dcf800abb..9d4ca689d8 100644 --- a/api/extensions/storage/clickzetta_volume/volume_permissions.py +++ b/api/extensions/storage/clickzetta_volume/volume_permissions.py @@ -34,7 +34,7 @@ class VolumePermissionManager: # Support two initialization methods: connection object or configuration dictionary if isinstance(connection_or_config, dict): # Create connection from configuration dictionary - import clickzetta # type: ignore[import-untyped] + import clickzetta config = connection_or_config self._connection = clickzetta.connect( diff --git a/api/extensions/storage/google_cloud_storage.py b/api/extensions/storage/google_cloud_storage.py index 7f59252f2f..d352996518 100644 --- a/api/extensions/storage/google_cloud_storage.py +++ b/api/extensions/storage/google_cloud_storage.py @@ -3,7 +3,7 @@ import io 
import json from collections.abc import Generator -from google.cloud import storage as google_cloud_storage # type: ignore +from google.cloud import storage as google_cloud_storage from configs import dify_config from extensions.storage.base_storage import BaseStorage diff --git a/api/extensions/storage/huawei_obs_storage.py b/api/extensions/storage/huawei_obs_storage.py index 3e75ecb7a9..74fed26f65 100644 --- a/api/extensions/storage/huawei_obs_storage.py +++ b/api/extensions/storage/huawei_obs_storage.py @@ -1,6 +1,6 @@ from collections.abc import Generator -from obs import ObsClient # type: ignore +from obs import ObsClient from configs import dify_config from extensions.storage.base_storage import BaseStorage diff --git a/api/extensions/storage/oracle_oci_storage.py b/api/extensions/storage/oracle_oci_storage.py index acc00cbd6b..c032803045 100644 --- a/api/extensions/storage/oracle_oci_storage.py +++ b/api/extensions/storage/oracle_oci_storage.py @@ -1,7 +1,7 @@ from collections.abc import Generator -import boto3 # type: ignore -from botocore.exceptions import ClientError # type: ignore +import boto3 +from botocore.exceptions import ClientError from configs import dify_config from extensions.storage.base_storage import BaseStorage diff --git a/api/extensions/storage/tencent_cos_storage.py b/api/extensions/storage/tencent_cos_storage.py index 9cdd3e67f7..ea5d982efc 100644 --- a/api/extensions/storage/tencent_cos_storage.py +++ b/api/extensions/storage/tencent_cos_storage.py @@ -1,6 +1,6 @@ from collections.abc import Generator -from qcloud_cos import CosConfig, CosS3Client # type: ignore +from qcloud_cos import CosConfig, CosS3Client from configs import dify_config from extensions.storage.base_storage import BaseStorage diff --git a/api/extensions/storage/volcengine_tos_storage.py b/api/extensions/storage/volcengine_tos_storage.py index 8ed8e4c170..a44959221f 100644 --- a/api/extensions/storage/volcengine_tos_storage.py +++ b/api/extensions/storage/volcengine_tos_storage.py @@ -1,6 +1,6 @@ from collections.abc import Generator -import tos # type: ignore +import tos from configs import dify_config from extensions.storage.base_storage import BaseStorage diff --git a/api/libs/external_api.py b/api/libs/external_api.py index f3ebcc4306..1a4fde960c 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -146,6 +146,6 @@ class ExternalApi(Api): kwargs["doc"] = dify_config.SWAGGER_UI_PATH if dify_config.SWAGGER_UI_ENABLED else False # manual separate call on construction and init_app to ensure configs in kwargs effective - super().__init__(app=None, *args, **kwargs) # type: ignore + super().__init__(app=None, *args, **kwargs) self.init_app(app, **kwargs) register_external_error_handlers(self) diff --git a/api/libs/gmpy2_pkcs10aep_cipher.py b/api/libs/gmpy2_pkcs10aep_cipher.py index fc38d51005..23eb8dca05 100644 --- a/api/libs/gmpy2_pkcs10aep_cipher.py +++ b/api/libs/gmpy2_pkcs10aep_cipher.py @@ -23,7 +23,7 @@ from hashlib import sha1 import Crypto.Hash.SHA1 import Crypto.Util.number -import gmpy2 # type: ignore +import gmpy2 from Crypto import Random from Crypto.Signature.pss import MGF1 from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes @@ -136,7 +136,7 @@ class PKCS1OAepCipher: # Step 3a (OS2IP) em_int = bytes_to_long(em) # Step 3b (RSAEP) - m_int = gmpy2.powmod(em_int, self._key.e, self._key.n) # ty: ignore [unresolved-attribute] + m_int = gmpy2.powmod(em_int, self._key.e, self._key.n) # Step 3c (I2OSP) c = long_to_bytes(m_int, k) return c @@ -169,7 
+169,7 @@ class PKCS1OAepCipher: ct_int = bytes_to_long(ciphertext) # Step 2b (RSADP) # m_int = self._key._decrypt(ct_int) - m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n) # ty: ignore [unresolved-attribute] + m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n) # Complete step 2c (I2OSP) em = long_to_bytes(m_int, k) # Step 3a @@ -191,12 +191,12 @@ class PKCS1OAepCipher: # Step 3g one_pos = hLen + db[hLen:].find(b"\x01") lHash1 = db[:hLen] - invalid = bord(y) | int(one_pos < hLen) # type: ignore + invalid = bord(y) | int(one_pos < hLen) # type: ignore[arg-type] hash_compare = strxor(lHash1, lHash) for x in hash_compare: - invalid |= bord(x) # type: ignore + invalid |= bord(x) # type: ignore[arg-type] for x in db[hLen:one_pos]: - invalid |= bord(x) # type: ignore + invalid |= bord(x) # type: ignore[arg-type] if invalid != 0: raise ValueError("Incorrect decryption.") # Step 4 diff --git a/api/libs/helper.py b/api/libs/helper.py index b878141d8e..60484dd40b 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -81,6 +81,8 @@ class AvatarUrlField(fields.Raw): from models import Account if isinstance(obj, Account) and obj.avatar is not None: + if obj.avatar.startswith(("http://", "https://")): + return obj.avatar return file_helpers.get_signed_file_url(obj.avatar) return None diff --git a/api/libs/login.py b/api/libs/login.py index 5ed4bfae8f..4b8ee2d1f8 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -3,7 +3,7 @@ from functools import wraps from typing import Any from flask import current_app, g, has_request_context, request -from flask_login.config import EXEMPT_METHODS # type: ignore +from flask_login.config import EXEMPT_METHODS from werkzeug.local import LocalProxy from configs import dify_config @@ -87,7 +87,7 @@ def _get_user() -> EndUser | Account | None: if "_login_user" not in g: current_app.login_manager._load_user() # type: ignore - return g._login_user # type: ignore + return g._login_user return None diff --git a/api/libs/sendgrid.py b/api/libs/sendgrid.py index a270fa70fa..c047c54d06 100644 --- a/api/libs/sendgrid.py +++ b/api/libs/sendgrid.py @@ -1,8 +1,8 @@ import logging -import sendgrid # type: ignore +import sendgrid from python_http_client.exceptions import ForbiddenError, UnauthorizedError -from sendgrid.helpers.mail import Content, Email, Mail, To # type: ignore +from sendgrid.helpers.mail import Content, Email, Mail, To logger = logging.getLogger(__name__) diff --git a/api/libs/token.py b/api/libs/token.py index 4be25696e7..0b40f18143 100644 --- a/api/libs/token.py +++ b/api/libs/token.py @@ -12,6 +12,7 @@ from constants import ( COOKIE_NAME_CSRF_TOKEN, COOKIE_NAME_PASSPORT, COOKIE_NAME_REFRESH_TOKEN, + COOKIE_NAME_WEBAPP_ACCESS_TOKEN, HEADER_NAME_CSRF_TOKEN, HEADER_NAME_PASSPORT, ) @@ -81,6 +82,14 @@ def extract_access_token(request: Request) -> str | None: return _try_extract_from_cookie(request) or _try_extract_from_header(request) +def extract_webapp_access_token(request: Request) -> str | None: + """ + Try to extract webapp access token from cookie, then header. + """ + + return request.cookies.get(_real_cookie_name(COOKIE_NAME_WEBAPP_ACCESS_TOKEN)) or _try_extract_from_header(request) + + def extract_webapp_passport(app_code: str, request: Request) -> str | None: """ Try to extract app token from header or params. 
@@ -155,6 +164,10 @@ def clear_access_token_from_cookie(response: Response, samesite: str = "Lax"): _clear_cookie(response, COOKIE_NAME_ACCESS_TOKEN, samesite) +def clear_webapp_access_token_from_cookie(response: Response, samesite: str = "Lax"): + _clear_cookie(response, COOKIE_NAME_WEBAPP_ACCESS_TOKEN, samesite) + + def clear_refresh_token_from_cookie(response: Response): _clear_cookie(response, COOKIE_NAME_REFRESH_TOKEN) diff --git a/api/migrations/versions/2025_10_21_1430-ae662b25d9bc_remove_builtin_template_user.py b/api/migrations/versions/2025_10_21_1430-ae662b25d9bc_remove_builtin_template_user.py new file mode 100644 index 0000000000..086a02e7c3 --- /dev/null +++ b/api/migrations/versions/2025_10_21_1430-ae662b25d9bc_remove_builtin_template_user.py @@ -0,0 +1,36 @@ +"""remove-builtin-template-user + +Revision ID: ae662b25d9bc +Revises: d98acf217d43 +Create Date: 2025-10-21 14:30:28.566192 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'ae662b25d9bc' +down_revision = 'd98acf217d43' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op: + batch_op.drop_column('updated_by') + batch_op.drop_column('created_by') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op: + batch_op.add_column(sa.Column('created_by', sa.UUID(), autoincrement=False, nullable=False)) + batch_op.add_column(sa.Column('updated_by', sa.UUID(), autoincrement=False, nullable=True)) + + # ### end Alembic commands ### diff --git a/api/models/account.py b/api/models/account.py index 86cd9e41b5..400a2c6362 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -5,7 +5,7 @@ from datetime import datetime from typing import Any, Optional import sqlalchemy as sa -from flask_login import UserMixin # type: ignore[import-untyped] +from flask_login import UserMixin from sqlalchemy import DateTime, String, func, select from sqlalchemy.orm import Mapped, Session, mapped_column from typing_extensions import deprecated diff --git a/api/models/dataset.py b/api/models/dataset.py index 5653445f2b..4a9e2688b8 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -1239,15 +1239,6 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined] language = mapped_column(db.String(255), nullable=False) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) - created_by = mapped_column(StringUUID, nullable=False) - updated_by = mapped_column(StringUUID, nullable=True) - - @property - def created_user_name(self): - account = db.session.query(Account).where(Account.id == self.created_by).first() - if account: - return account.name - return "" class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] diff --git a/api/models/tools.py b/api/models/tools.py index d45db3d350..12acc149b1 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -219,7 +219,7 @@ class WorkflowToolProvider(TypeBase): sa.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), 
init=False) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # name of the workflow provider name: Mapped[str] = mapped_column(String(255), nullable=False) # label of the workflow provider diff --git a/api/pyproject.toml b/api/pyproject.toml index a14c120f55..7275693e99 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dify-api" -version = "1.9.1" +version = "1.9.2" requires-python = ">=3.11,<3.13" dependencies = [ diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json index bf4ec2314e..6a689b96df 100644 --- a/api/pyrightconfig.json +++ b/api/pyrightconfig.json @@ -16,7 +16,25 @@ "opentelemetry.instrumentation.requests", "opentelemetry.instrumentation.sqlalchemy", "opentelemetry.instrumentation.redis", - "opentelemetry.instrumentation.httpx" + "langfuse", + "cloudscraper", + "readabilipy", + "pypandoc", + "pypdfium2", + "webvtt", + "flask_compress", + "oss2", + "baidubce.auth.bce_credentials", + "baidubce.bce_client_configuration", + "baidubce.services.bos.bos_client", + "clickzetta", + "google.cloud", + "obs", + "qcloud_cos", + "tos", + "gmpy2", + "sendgrid", + "sendgrid.helpers.mail" ], "reportUnknownMemberType": "hint", "reportUnknownParameterType": "hint", @@ -28,7 +46,7 @@ "reportUnnecessaryComparison": "hint", "reportUnnecessaryIsInstance": "hint", "reportUntypedFunctionDecorator": "hint", - + "reportUnnecessaryTypeIgnoreComment": "hint", "reportAttributeAccessIssue": "hint", "pythonVersion": "3.11", "pythonPlatform": "All" diff --git a/api/repositories/factory.py b/api/repositories/factory.py index 0be9c8908c..96f9f886a4 100644 --- a/api/repositories/factory.py +++ b/api/repositories/factory.py @@ -48,7 +48,7 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory): try: repository_class = import_string(class_path) - return repository_class(session_maker=session_maker) # type: ignore[no-any-return] + return repository_class(session_maker=session_maker) except (ImportError, Exception) as e: raise RepositoryImportError( f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}" @@ -77,6 +77,6 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory): try: repository_class = import_string(class_path) - return repository_class(session_maker=session_maker) # type: ignore[no-any-return] + return repository_class(session_maker=session_maker) except (ImportError, Exception) as e: raise RepositoryImportError(f"Failed to create APIWorkflowRunRepository from '{class_path}': {e}") from e diff --git a/api/services/account_service.py b/api/services/account_service.py index a1d2d73085..4976a2121c 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -13,7 +13,7 @@ from sqlalchemy.orm import Session from werkzeug.exceptions import Unauthorized from configs import dify_config -from constants.languages import language_timezone_mapping, languages +from constants.languages import get_valid_language, language_timezone_mapping from events.tenant_event import tenant_was_created from extensions.ext_database import db from extensions.ext_redis import redis_client, redis_fallback @@ -1264,7 +1264,7 @@ class RegisterService: return f"member_invite:token:{token}" @classmethod - def setup(cls, email: str, name: str, password: str, ip_address: str): + def setup(cls, email: str, name: str, password: str, ip_address: str, language: str): """ Setup dify @@ -1274,11 +1274,10 @@ class RegisterService: :param ip_address: ip address """ try: - # Register account = 
AccountService.create_account( email=email, name=name, - interface_language=languages[0], + interface_language=get_valid_language(language), password=password, is_setup=True, ) @@ -1320,7 +1319,7 @@ class RegisterService: account = AccountService.create_account( email=email, name=name, - interface_language=language or languages[0], + interface_language=get_valid_language(language), password=password, is_setup=is_setup, ) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index e2915ebfbb..edb18a845a 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -7,7 +7,7 @@ from enum import StrEnum from urllib.parse import urlparse from uuid import uuid4 -import yaml # type: ignore +import yaml from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad from packaging import version @@ -563,7 +563,7 @@ class AppDslService: else: cls._append_model_config_export_data(export_data, app_model) - return yaml.dump(export_data, allow_unicode=True) # type: ignore + return yaml.dump(export_data, allow_unicode=True) @classmethod def _append_workflow_export_data( diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index f4047da6b8..c97d419545 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -241,9 +241,9 @@ class DatasetService: dataset.created_by = account.id dataset.updated_by = account.id dataset.tenant_id = tenant_id - dataset.embedding_model_provider = embedding_model.provider if embedding_model else None # type: ignore - dataset.embedding_model = embedding_model.model if embedding_model else None # type: ignore - dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None # type: ignore + dataset.embedding_model_provider = embedding_model.provider if embedding_model else None + dataset.embedding_model = embedding_model.model if embedding_model else None + dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None dataset.permission = permission or DatasetPermissionEnum.ONLY_ME dataset.provider = provider db.session.add(dataset) @@ -1416,6 +1416,8 @@ class DocumentService: # check document limit assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None + assert knowledge_config.data_source + assert knowledge_config.data_source.info_list.file_info_list features = FeatureService.get_features(current_user.current_tenant_id) @@ -1424,15 +1426,16 @@ class DocumentService: count = 0 if knowledge_config.data_source: if knowledge_config.data_source.info_list.data_source_type == "upload_file": - upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore + upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids count = len(upload_file_list) elif knowledge_config.data_source.info_list.data_source_type == "notion_import": - notion_info_list = knowledge_config.data_source.info_list.notion_info_list - for notion_info in notion_info_list: # type: ignore + notion_info_list = knowledge_config.data_source.info_list.notion_info_list or [] + for notion_info in notion_info_list: count = count + len(notion_info.pages) elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": website_info = knowledge_config.data_source.info_list.website_info_list - count = len(website_info.urls) # type: ignore + assert website_info + count = len(website_info.urls) batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) if 
features.billing.subscription.plan == "sandbox" and count > 1: @@ -1444,7 +1447,7 @@ class DocumentService: # if dataset is empty, update dataset data_source_type if not dataset.data_source_type: - dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore + dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type if not dataset.indexing_technique: if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: @@ -1481,7 +1484,7 @@ class DocumentService: knowledge_config.retrieval_model.model_dump() if knowledge_config.retrieval_model else default_retrieval_model - ) # type: ignore + ) documents = [] if knowledge_config.original_document_id: @@ -1523,11 +1526,12 @@ class DocumentService: db.session.flush() lock_name = f"add_document_lock_dataset_id_{dataset.id}" with redis_client.lock(lock_name, timeout=600): + assert dataset_process_rule position = DocumentService.get_documents_position(dataset.id) document_ids = [] duplicate_document_ids = [] - if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore - upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore + if knowledge_config.data_source.info_list.data_source_type == "upload_file": + upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids for file_id in upload_file_list: file = ( db.session.query(UploadFile) @@ -1540,7 +1544,7 @@ class DocumentService: raise FileNotExistsError() file_name = file.name - data_source_info = { + data_source_info: dict[str, str | bool] = { "upload_file_id": file_id, } # check duplicate @@ -1557,7 +1561,7 @@ class DocumentService: .first() ) if document: - document.dataset_process_rule_id = dataset_process_rule.id # type: ignore + document.dataset_process_rule_id = dataset_process_rule.id document.updated_at = naive_utc_now() document.created_from = created_from document.doc_form = knowledge_config.doc_form @@ -1571,8 +1575,8 @@ class DocumentService: continue document = DocumentService.build_document( dataset, - dataset_process_rule.id, # type: ignore - knowledge_config.data_source.info_list.data_source_type, # type: ignore + dataset_process_rule.id, + knowledge_config.data_source.info_list.data_source_type, knowledge_config.doc_form, knowledge_config.doc_language, data_source_info, @@ -1587,7 +1591,7 @@ class DocumentService: document_ids.append(document.id) documents.append(document) position += 1 - elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore + elif knowledge_config.data_source.info_list.data_source_type == "notion_import": notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore if not notion_info_list: raise ValueError("No notion info list found.") @@ -1616,15 +1620,15 @@ class DocumentService: "credential_id": notion_info.credential_id, "notion_workspace_id": workspace_id, "notion_page_id": page.page_id, - "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, + "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, # type: ignore "type": page.type, } # Truncate page name to 255 characters to prevent DB field length errors truncated_page_name = page.page_name[:255] if page.page_name else "nopagename" document = DocumentService.build_document( dataset, - dataset_process_rule.id, # type: ignore - knowledge_config.data_source.info_list.data_source_type, # type: ignore + dataset_process_rule.id, + 
knowledge_config.data_source.info_list.data_source_type, knowledge_config.doc_form, knowledge_config.doc_language, data_source_info, @@ -1644,8 +1648,8 @@ class DocumentService: # delete not selected documents if len(exist_document) > 0: clean_notion_document_task.delay(list(exist_document.values()), dataset.id) - elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore - website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore + elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": + website_info = knowledge_config.data_source.info_list.website_info_list if not website_info: raise ValueError("No website info list found.") urls = website_info.urls @@ -1663,8 +1667,8 @@ class DocumentService: document_name = url document = DocumentService.build_document( dataset, - dataset_process_rule.id, # type: ignore - knowledge_config.data_source.info_list.data_source_type, # type: ignore + dataset_process_rule.id, + knowledge_config.data_source.info_list.data_source_type, knowledge_config.doc_form, knowledge_config.doc_language, data_source_info, @@ -2071,7 +2075,7 @@ class DocumentService: # update document data source if document_data.data_source: file_name = "" - data_source_info = {} + data_source_info: dict[str, str | bool] = {} if document_data.data_source.info_list.data_source_type == "upload_file": if not document_data.data_source.info_list.file_info_list: raise ValueError("No file info list found.") @@ -2128,7 +2132,7 @@ class DocumentService: "url": url, "provider": website_info.provider, "job_id": website_info.job_id, - "only_main_content": website_info.only_main_content, # type: ignore + "only_main_content": website_info.only_main_content, "mode": "crawl", } document.data_source_type = document_data.data_source.info_list.data_source_type @@ -2154,7 +2158,7 @@ class DocumentService: db.session.query(DocumentSegment).filter_by(document_id=document.id).update( {DocumentSegment.status: "re_segment"} - ) # type: ignore + ) db.session.commit() # trigger async task document_indexing_update_task.delay(document.dataset_id, document.id) @@ -2164,25 +2168,26 @@ class DocumentService: def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account): assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None + assert knowledge_config.data_source features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: count = 0 - if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore + if knowledge_config.data_source.info_list.data_source_type == "upload_file": upload_file_list = ( - knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore - if knowledge_config.data_source.info_list.file_info_list # type: ignore + knowledge_config.data_source.info_list.file_info_list.file_ids + if knowledge_config.data_source.info_list.file_info_list else [] ) count = len(upload_file_list) - elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore - notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore + elif knowledge_config.data_source.info_list.data_source_type == "notion_import": + notion_info_list = knowledge_config.data_source.info_list.notion_info_list if notion_info_list: for notion_info in notion_info_list: count = count + len(notion_info.pages) - elif 
knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore - website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore + elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": + website_info = knowledge_config.data_source.info_list.website_info_list if website_info: count = len(website_info.urls) if features.billing.subscription.plan == "sandbox" and count > 1: @@ -2196,9 +2201,11 @@ class DocumentService: dataset_collection_binding_id = None retrieval_model = None if knowledge_config.indexing_technique == "high_quality": + assert knowledge_config.embedding_model_provider + assert knowledge_config.embedding_model dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - knowledge_config.embedding_model_provider, # type: ignore - knowledge_config.embedding_model, # type: ignore + knowledge_config.embedding_model_provider, + knowledge_config.embedding_model, ) dataset_collection_binding_id = dataset_collection_binding.id if knowledge_config.retrieval_model: @@ -2215,7 +2222,7 @@ class DocumentService: dataset = Dataset( tenant_id=tenant_id, name="", - data_source_type=knowledge_config.data_source.info_list.data_source_type, # type: ignore + data_source_type=knowledge_config.data_source.info_list.data_source_type, indexing_technique=knowledge_config.indexing_technique, created_by=account.id, embedding_model=knowledge_config.embedding_model, @@ -2224,7 +2231,7 @@ class DocumentService: retrieval_model=retrieval_model.model_dump() if retrieval_model else None, ) - db.session.add(dataset) # type: ignore + db.session.add(dataset) db.session.flush() documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 1b5805b220..0fb8cb0bb7 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -178,6 +178,7 @@ class FeatureService: if dify_config.ENTERPRISE_ENABLED: features.webapp_copyright_enabled = True + features.knowledge_pipeline.publish_enabled = True cls._fulfill_params_from_workspace_info(features, tenant_id) return features diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 7fa82c6d22..337181728c 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -88,7 +88,7 @@ class HitTestingService: db.session.add(dataset_query) db.session.commit() - return cls.compact_retrieve_response(query, all_documents) # type: ignore + return cls.compact_retrieve_response(query, all_documents) @classmethod def external_retrieve( diff --git a/api/services/knowledge_service.py b/api/services/knowledge_service.py index 8df1a6ba14..02fe1d19bc 100644 --- a/api/services/knowledge_service.py +++ b/api/services/knowledge_service.py @@ -1,4 +1,4 @@ -import boto3 # type: ignore +import boto3 from configs import dify_config diff --git a/api/services/message_service.py b/api/services/message_service.py index 9fdff18622..7ed56d80f2 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -288,9 +288,10 @@ class MessageService: ) with measure_time() as timer: - questions: list[str] = LLMGenerator.generate_suggested_questions_after_answer( + questions_sequence = LLMGenerator.generate_suggested_questions_after_answer( tenant_id=app_model.tenant_id, histories=histories ) + questions: list[str] = list(questions_sequence) # get tracing instance 
trace_manager = TraceQueueManager(app_id=app_model.id) diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py index 5f280c9e57..b369994d2d 100644 --- a/api/services/metadata_service.py +++ b/api/services/metadata_service.py @@ -89,7 +89,7 @@ class MetadataService: document.doc_metadata = doc_metadata db.session.add(document) db.session.commit() - return metadata # type: ignore + return metadata except Exception: logger.exception("Update metadata name failed") finally: diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 2901a0d273..50ddbbf681 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -137,7 +137,7 @@ class ModelProviderService: :return: """ provider_configuration = self._get_provider_configuration(tenant_id, provider) - return provider_configuration.get_provider_credential(credential_id=credential_id) # type: ignore + return provider_configuration.get_provider_credential(credential_id=credential_id) def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict): """ @@ -225,7 +225,7 @@ class ModelProviderService: :return: """ provider_configuration = self._get_provider_configuration(tenant_id, provider) - return provider_configuration.get_custom_model_credential( # type: ignore + return provider_configuration.get_custom_model_credential( model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id ) diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index dec92a6faa..df5fa3e233 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -146,7 +146,7 @@ class PluginMigration: futures.append( thread_pool.submit( process_tenant, - current_app._get_current_object(), # type: ignore[attr-defined] + current_app._get_current_object(), # type: ignore tenant_id, ) ) diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index ec91f79606..908f9a2684 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -74,5 +74,4 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): "chunk_structure": pipeline_template.chunk_structure, "export_data": pipeline_template.yaml_content, "graph": graph_data, - "created_by": pipeline_template.created_user_name, } diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index b5dcec17d0..0628c8f22e 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -544,8 +544,8 @@ class BuiltinToolManageService: try: # handle include, exclude if is_filtered( - include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore - exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore + include_set=dify_config.POSITION_TOOL_INCLUDES_SET, + exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, data=provider_controller, name_func=lambda x: x.entity.identity.name, ): diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index 18f4c9250e..b24483b9c6 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -1,10 +1,12 @@ import hashlib 
import json import logging -from collections.abc import Callable from datetime import datetime +from enum import StrEnum from typing import Any +from urllib.parse import urlparse +from pydantic import BaseModel, Field from sqlalchemy import or_, select from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session @@ -12,6 +14,7 @@ from sqlalchemy.orm import Session from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity from core.helper import encrypter from core.helper.provider_cache import NoOpProviderCredentialCache +from core.mcp.auth.auth_flow import auth from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.error import MCPAuthError, MCPError from core.tools.entities.api_entities import ToolProviderApiEntity @@ -28,6 +31,38 @@ EMPTY_TOOLS_JSON = "[]" EMPTY_CREDENTIALS_JSON = "{}" +class OAuthDataType(StrEnum): + """Types of OAuth data that can be saved.""" + + TOKENS = "tokens" + CLIENT_INFO = "client_info" + CODE_VERIFIER = "code_verifier" + MIXED = "mixed" + + +class ReconnectResult(BaseModel): + """Result of reconnecting to an MCP provider""" + + authed: bool = Field(description="Whether the provider is authenticated") + tools: str = Field(description="JSON string of tool list") + encrypted_credentials: str = Field(description="JSON string of encrypted credentials") + + +class ServerUrlValidationResult(BaseModel): + """Result of server URL validation check""" + + needs_validation: bool + validation_passed: bool = False + reconnect_result: ReconnectResult | None = None + encrypted_server_url: str | None = None + server_url_hash: str | None = None + + @property + def should_update_server_url(self) -> bool: + """Check if server URL should be updated based on validation result""" + return self.needs_validation and self.validation_passed and self.reconnect_result is not None + + class MCPToolManageService: """Service class for managing MCP tools and providers.""" @@ -91,6 +126,10 @@ class MCPToolManageService: headers: dict[str, str] | None = None, ) -> ToolProviderApiEntity: """Create a new MCP provider.""" + # Validate URL format + if not self._is_valid_url(server_url): + raise ValueError("Server URL is not valid.") + server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() # Check for existing provider @@ -99,13 +138,12 @@ class MCPToolManageService: # Encrypt sensitive data encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url) encrypted_headers = self._prepare_encrypted_dict(headers, tenant_id) if headers else None - if authentication is not None and authentication.client_id and authentication.client_secret: - # Build the full credentials structure with encrypted client_id and client_secret + encrypted_credentials = None + if authentication is not None and authentication.client_id: encrypted_credentials = self._build_and_encrypt_credentials( authentication.client_id, authentication.client_secret, tenant_id ) - else: - encrypted_credentials = None + # Create provider mcp_tool = MCPToolProvider( tenant_id=tenant_id, @@ -142,24 +180,39 @@ class MCPToolManageService: headers: dict[str, str] | None = None, configuration: MCPConfiguration, authentication: MCPAuthentication | None = None, + validation_result: ServerUrlValidationResult | None = None, ) -> None: - """Update an MCP provider.""" + """ + Update an MCP provider. + + Args: + validation_result: Pre-validation result from validate_server_url_change. 
+ If provided and contains reconnect_result, it will be used + instead of performing network operations. + """ mcp_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) - reconnect_result = None + # Check for duplicate name (excluding current provider) + if name != mcp_provider.name: + stmt = select(MCPToolProvider).where( + MCPToolProvider.tenant_id == tenant_id, + MCPToolProvider.name == name, + MCPToolProvider.id != provider_id, + ) + existing_provider = self._session.scalar(stmt) + if existing_provider: + raise ValueError(f"MCP tool {name} already exists") + + # Get URL update data from validation result encrypted_server_url = None server_url_hash = None + reconnect_result = None - # Handle server URL update - if UNCHANGED_SERVER_URL_PLACEHOLDER not in server_url: - encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url) - server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() - - if server_url_hash != mcp_provider.server_url_hash: - reconnect_result = self._reconnect_provider( - server_url=server_url, - provider=mcp_provider, - ) + if validation_result and validation_result.encrypted_server_url: + # Use all data from validation result + encrypted_server_url = validation_result.encrypted_server_url + server_url_hash = validation_result.server_url_hash + reconnect_result = validation_result.reconnect_result try: # Update basic fields @@ -169,63 +222,35 @@ class MCPToolManageService: mcp_provider.server_identifier = server_identifier # Update server URL if changed - if encrypted_server_url is not None and server_url_hash is not None: + if encrypted_server_url and server_url_hash: mcp_provider.server_url = encrypted_server_url mcp_provider.server_url_hash = server_url_hash if reconnect_result: - mcp_provider.authed = reconnect_result["authed"] - mcp_provider.tools = reconnect_result["tools"] - mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"] + mcp_provider.authed = reconnect_result.authed + mcp_provider.tools = reconnect_result.tools + mcp_provider.encrypted_credentials = reconnect_result.encrypted_credentials - # Update optional fields - if configuration.timeout is not None: - mcp_provider.timeout = configuration.timeout - if configuration.sse_read_timeout is not None: - mcp_provider.sse_read_timeout = configuration.sse_read_timeout + # Update optional configuration fields + self._update_optional_fields(mcp_provider, configuration) + + # Update headers if provided if headers is not None: - if headers: - # Build headers preserving unchanged masked values - final_headers = self._merge_headers_with_masked(incoming_headers=headers, mcp_provider=mcp_provider) - encrypted_headers_dict = self._prepare_encrypted_dict(final_headers, tenant_id) - mcp_provider.encrypted_headers = encrypted_headers_dict - else: - # Clear headers if empty dict passed - mcp_provider.encrypted_headers = None + mcp_provider.encrypted_headers = self._process_headers(headers, mcp_provider, tenant_id) # Update credentials if provided - if authentication is not None and authentication.client_id and authentication.client_secret: - # Merge with existing credentials to handle masked values - ( - final_client_id, - final_client_secret, - ) = self._merge_credentials_with_masked( - authentication.client_id, authentication.client_secret, mcp_provider - ) + if authentication and authentication.client_id: + mcp_provider.encrypted_credentials = self._process_credentials(authentication, mcp_provider, tenant_id) - # Build and encrypt new credentials - 
encrypted_credentials = self._build_and_encrypt_credentials( - final_client_id, final_client_secret, tenant_id - ) - mcp_provider.encrypted_credentials = encrypted_credentials - - self._session.commit() + # Flush changes to database + self._session.flush() except IntegrityError as e: - self._session.rollback() self._handle_integrity_error(e, name, server_url, server_identifier) - except (ValueError, AttributeError, TypeError) as e: - # Catch specific exceptions that might occur during update - # ValueError: invalid data provided - # AttributeError: missing required attributes - # TypeError: type conversion errors - self._session.rollback() - raise def delete_provider(self, *, tenant_id: str, provider_id: str) -> None: """Delete an MCP provider.""" mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) self._session.delete(mcp_tool) - self._session.commit() def list_providers(self, *, tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]: """List all MCP providers for a tenant.""" @@ -241,8 +266,6 @@ class MCPToolManageService: def list_provider_tools(self, *, tenant_id: str, provider_id: str) -> ToolProviderApiEntity: """List tools from remote MCP server.""" - from core.mcp.auth.auth_flow import auth - # Load provider and convert to entity db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) provider_entity = db_provider.to_entity() @@ -257,9 +280,7 @@ class MCPToolManageService: # Retrieve tools from remote server server_url = provider_entity.decrypt_server_url() try: - tools = self._retrieve_remote_mcp_tools( - server_url, headers, provider_entity, lambda p, s, c: auth(p, self, c) - ) + tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity) except MCPError as e: raise ValueError(f"Failed to connect to MCP server: {e}") @@ -305,9 +326,12 @@ class MCPToolManageService: if not authed: provider.tools = EMPTY_TOOLS_JSON - self._session.commit() + # Flush changes to database + self._session.flush() - def save_oauth_data(self, provider_id: str, tenant_id: str, data: dict[str, Any], data_type: str = "mixed") -> None: + def save_oauth_data( + self, provider_id: str, tenant_id: str, data: dict[str, Any], data_type: OAuthDataType = OAuthDataType.MIXED + ) -> None: """ Save OAuth-related data (tokens, client info, code verifier). 
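Note on the save_oauth_data signature above: callers now pass an OAuthDataType member rather than a raw string, and the service methods only flush, so the enclosing unit of work owns the commit. A minimal sketch of the intended call pattern, assuming the service lives in services.tools.mcp_tools_manage_service and is constructed with a SQLAlchemy session as in the tests further down; the helper name and surrounding wiring are illustrative, not part of this patch:

from services.tools.mcp_tools_manage_service import MCPToolManageService, OAuthDataType


def persist_tokens(session, provider_id: str, tenant_id: str, tokens: dict) -> None:
    # Illustrative helper: the service flushes but no longer commits,
    # so the caller decides where the transaction boundary sits.
    service = MCPToolManageService(session)
    service.save_oauth_data(provider_id, tenant_id, tokens, OAuthDataType.TOKENS)
    session.commit()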
@@ -315,12 +339,14 @@ class MCPToolManageService: provider_id: Provider ID tenant_id: Tenant ID data: Data to save (tokens, client info, or code verifier) - data_type: Type of data ('tokens', 'client_info', 'code_verifier', 'mixed') + data_type: Type of OAuth data to save """ db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) # Determine if this makes the provider authenticated - authed = data_type == "tokens" or (data_type == "mixed" and "access_token" in data) or None + authed = ( + data_type == OAuthDataType.TOKENS or (data_type == OAuthDataType.MIXED and "access_token" in data) or None + ) self.update_provider_credentials(provider=db_provider, credentials=data, authed=authed) @@ -330,7 +356,6 @@ class MCPToolManageService: provider.encrypted_credentials = EMPTY_CREDENTIALS_JSON provider.updated_at = datetime.now() provider.authed = False - self._session.commit() # ========== Private Helper Methods ========== @@ -406,41 +431,123 @@ class MCPToolManageService: server_url: str, headers: dict[str, str], provider_entity: MCPProviderEntity, - auth_callback: Callable[[MCPProviderEntity, "MCPToolManageService", str | None], dict[str, str]], ): """Retrieve tools from remote MCP server.""" with MCPClientWithAuthRetry( - server_url, + server_url=server_url, headers=headers, timeout=provider_entity.timeout, sse_read_timeout=provider_entity.sse_read_timeout, provider_entity=provider_entity, - auth_callback=auth_callback, - mcp_service=self, ) as mcp_client: return mcp_client.list_tools() - def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> dict[str, Any]: - """Attempt to reconnect to MCP provider with new server URL.""" - from core.mcp.auth.auth_flow import auth + def execute_auth_actions(self, auth_result: Any) -> dict[str, str]: + """ + Execute the actions returned by the auth function. + This method processes the AuthResult and performs the necessary database operations. + + Args: + auth_result: The result from the auth function + + Returns: + The response from the auth result + """ + from core.mcp.entities import AuthAction, AuthActionType + + action: AuthAction + for action in auth_result.actions: + if action.provider_id is None or action.tenant_id is None: + continue + + if action.action_type == AuthActionType.SAVE_CLIENT_INFO: + self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CLIENT_INFO) + elif action.action_type == AuthActionType.SAVE_TOKENS: + self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.TOKENS) + elif action.action_type == AuthActionType.SAVE_CODE_VERIFIER: + self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CODE_VERIFIER) + + return auth_result.response + + def auth_with_actions( + self, provider_entity: MCPProviderEntity, authorization_code: str | None = None + ) -> dict[str, str]: + """ + Perform authentication and execute all resulting actions. + + This method is used by MCPClientWithAuthRetry for automatic re-authentication. 
+ + Args: + provider_entity: The MCP provider entity + authorization_code: Optional authorization code + + Returns: + Response dictionary from auth result + """ + auth_result = auth(provider_entity, authorization_code) + return self.execute_auth_actions(auth_result) + + def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult: + """Attempt to reconnect to MCP provider with new server URL.""" provider_entity = provider.to_entity() headers = provider_entity.headers try: - tools = self._retrieve_remote_mcp_tools( - server_url, headers, provider_entity, lambda p, s, c: auth(p, self, c) + tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity) + return ReconnectResult( + authed=True, + tools=json.dumps([tool.model_dump() for tool in tools]), + encrypted_credentials=EMPTY_CREDENTIALS_JSON, ) - return { - "authed": True, - "tools": json.dumps([tool.model_dump() for tool in tools]), - "encrypted_credentials": EMPTY_CREDENTIALS_JSON, - } except MCPAuthError: - return {"authed": False, "tools": EMPTY_TOOLS_JSON, "encrypted_credentials": EMPTY_CREDENTIALS_JSON} + return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON) except MCPError as e: raise ValueError(f"Failed to re-connect MCP server: {e}") from e + def validate_server_url_change( + self, *, tenant_id: str, provider_id: str, new_server_url: str + ) -> ServerUrlValidationResult: + """ + Validate server URL change by attempting to connect to the new server. + This method should be called BEFORE update_provider to perform network operations + outside of the database transaction. + + Returns: + ServerUrlValidationResult: Validation result with connection status and tools if successful + """ + # Handle hidden/unchanged URL + if UNCHANGED_SERVER_URL_PLACEHOLDER in new_server_url: + return ServerUrlValidationResult(needs_validation=False) + + # Validate URL format + if not self._is_valid_url(new_server_url): + raise ValueError("Server URL is not valid.") + + # Always encrypt and hash the URL + encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url) + new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest() + + # Get current provider + provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) + + # Check if URL is actually different + if new_server_url_hash == provider.server_url_hash: + # URL hasn't changed, but still return the encrypted data + return ServerUrlValidationResult( + needs_validation=False, encrypted_server_url=encrypted_server_url, server_url_hash=new_server_url_hash + ) + + # Perform validation by attempting to connect + reconnect_result = self._reconnect_provider(server_url=new_server_url, provider=provider) + return ServerUrlValidationResult( + needs_validation=True, + validation_passed=True, + reconnect_result=reconnect_result, + encrypted_server_url=encrypted_server_url, + server_url_hash=new_server_url_hash, + ) + def _build_tool_provider_response( self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list ) -> ToolProviderApiEntity: @@ -466,6 +573,45 @@ class MCPToolManageService: raise ValueError(f"MCP tool {server_identifier} already exists") raise + def _is_valid_url(self, url: str) -> bool: + """Validate URL format.""" + if not url: + return False + try: + parsed = urlparse(url) + return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"] + except (ValueError, TypeError): + return False + + def 
_update_optional_fields(self, mcp_provider: MCPToolProvider, configuration: MCPConfiguration) -> None: + """Update optional configuration fields using setattr for cleaner code.""" + field_mapping = {"timeout": configuration.timeout, "sse_read_timeout": configuration.sse_read_timeout} + + for field, value in field_mapping.items(): + if value is not None: + setattr(mcp_provider, field, value) + + def _process_headers(self, headers: dict[str, str], mcp_provider: MCPToolProvider, tenant_id: str) -> str | None: + """Process headers update, handling empty dict to clear headers.""" + if not headers: + return None + + # Merge with existing headers to preserve masked values + final_headers = self._merge_headers_with_masked(incoming_headers=headers, mcp_provider=mcp_provider) + return self._prepare_encrypted_dict(final_headers, tenant_id) + + def _process_credentials( + self, authentication: MCPAuthentication, mcp_provider: MCPToolProvider, tenant_id: str + ) -> str: + """Process credentials update, handling masked values.""" + # Merge with existing credentials + final_client_id, final_client_secret = self._merge_credentials_with_masked( + authentication.client_id, authentication.client_secret, mcp_provider + ) + + # Build and encrypt + return self._build_and_encrypt_credentials(final_client_id, final_client_secret, tenant_id) + def _merge_headers_with_masked( self, incoming_headers: dict[str, str], mcp_provider: MCPToolProvider ) -> dict[str, str]: @@ -530,12 +676,12 @@ class MCPToolManageService: # Create a flat structure with all credential data credentials_data = { "client_id": client_id, - "encrypted_client_secret": client_secret, "client_name": CLIENT_NAME, "is_dynamic_registration": False, } - - # Only client_id and client_secret need encryption - secret_fields = ["encrypted_client_secret"] if client_secret else [] + secret_fields = [] + if client_secret is not None: + credentials_data["encrypted_client_secret"] = encrypter.encrypt_token(tenant_id, client_secret) + secret_fields = ["encrypted_client_secret"] client_info = self._encrypt_dict_fields(credentials_data, secret_fields, tenant_id) return json.dumps({"client_information": client_info}) diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 2449536d5c..b1cc963681 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -4,6 +4,7 @@ from datetime import datetime from typing import Any from sqlalchemy import or_, select +from sqlalchemy.orm import Session from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool_provider import ToolProviderController @@ -13,6 +14,7 @@ from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurati from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db +from libs.uuid_utils import uuidv7 from models.model import App from models.tools import WorkflowToolProvider from models.workflow import Workflow @@ -63,27 +65,27 @@ class WorkflowToolManageService: if workflow is None: raise ValueError(f"Workflow not found for app {workflow_app_id}") - workflow_tool_provider = WorkflowToolProvider( - tenant_id=tenant_id, - user_id=user_id, - app_id=workflow_app_id, - name=name, - label=label, - icon=json.dumps(icon), - description=description, - parameter_configuration=json.dumps(parameters), - 
privacy_policy=privacy_policy, - version=workflow.version, - ) + with Session(db.engine, expire_on_commit=False) as session, session.begin(): + workflow_tool_provider = WorkflowToolProvider( + id=str(uuidv7()), + tenant_id=tenant_id, + user_id=user_id, + app_id=workflow_app_id, + name=name, + label=label, + icon=json.dumps(icon), + description=description, + parameter_configuration=json.dumps(parameters), + privacy_policy=privacy_policy, + version=workflow.version, + ) + session.add(workflow_tool_provider) try: WorkflowToolProviderController.from_db(workflow_tool_provider) except Exception as e: raise ValueError(str(e)) - db.session.add(workflow_tool_provider) - db.session.commit() - if labels is not None: ToolLabelManager.update_tool_labels( ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels @@ -168,7 +170,6 @@ class WorkflowToolManageService: except Exception as e: raise ValueError(str(e)) - db.session.add(workflow_tool_provider) db.session.commit() if labels is not None: diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py index d02508e4f3..4e13d2d964 100644 --- a/api/services/variable_truncator.py +++ b/api/services/variable_truncator.py @@ -17,6 +17,7 @@ from core.variables.segments import ( StringSegment, ) from core.variables.utils import dumps_with_segments +from core.workflow.nodes.variable_assigner.common.helpers import UpdatedVariable _MAX_DEPTH = 100 @@ -56,7 +57,7 @@ class UnknownTypeError(Exception): pass -JSONTypes: TypeAlias = int | float | str | list | dict | None | bool +JSONTypes: TypeAlias = int | float | str | list[object] | dict[str, object] | None | bool @dataclasses.dataclass(frozen=True) @@ -79,7 +80,7 @@ class VariableTruncator: self, string_length_limit=5000, array_element_limit: int = 20, - max_size_bytes: int = 1024_000, # 100KB + max_size_bytes: int = 1024_000, # 1000 KiB ): if string_length_limit <= 3: raise ValueError("string_length_limit should be greater than 3.") @@ -202,6 +203,9 @@ class VariableTruncator: """Recursively calculate JSON size without serialization.""" if isinstance(value, Segment): return VariableTruncator.calculate_json_size(value.value) + if isinstance(value, UpdatedVariable): + # TODO(Workflow): migrate UpdatedVariable serialization upstream and drop this fallback. + return VariableTruncator.calculate_json_size(value.model_dump(), depth=depth + 1) if depth > _MAX_DEPTH: raise MaxDepthExceededError() if isinstance(value, str): @@ -248,14 +252,14 @@ class VariableTruncator: truncated_value = value[:truncated_size] + "..." return _PartResult(truncated_value, self.calculate_json_size(truncated_value), True) - def _truncate_array(self, value: list, target_size: int) -> _PartResult[list]: + def _truncate_array(self, value: list[object], target_size: int) -> _PartResult[list[object]]: """ Truncate array with correct strategy: 1. First limit to 20 items 2. 
If still too large, truncate individual items """ - truncated_value: list[Any] = [] + truncated_value: list[object] = [] truncated = False used_size = self.calculate_json_size([]) @@ -278,7 +282,11 @@ class VariableTruncator: if used_size > target_size: break - part_result = self._truncate_json_primitives(item, target_size - used_size) + remaining_budget = target_size - used_size + if item is None or isinstance(item, (str, list, dict, bool, int, float)): + part_result = self._truncate_json_primitives(item, remaining_budget) + else: + raise UnknownTypeError(f"got unknown type {type(item)} in array truncation") truncated_value.append(part_result.value) used_size += part_result.value_size truncated = part_result.truncated @@ -369,10 +377,10 @@ class VariableTruncator: def _truncate_json_primitives(self, val: str, target_size: int) -> _PartResult[str]: ... @overload - def _truncate_json_primitives(self, val: list, target_size: int) -> _PartResult[list]: ... + def _truncate_json_primitives(self, val: list[object], target_size: int) -> _PartResult[list[object]]: ... @overload - def _truncate_json_primitives(self, val: dict, target_size: int) -> _PartResult[dict]: ... + def _truncate_json_primitives(self, val: dict[str, object], target_size: int) -> _PartResult[dict[str, object]]: ... @overload def _truncate_json_primitives(self, val: bool, target_size: int) -> _PartResult[bool]: ... # type: ignore @@ -387,10 +395,15 @@ class VariableTruncator: def _truncate_json_primitives(self, val: None, target_size: int) -> _PartResult[None]: ... def _truncate_json_primitives( - self, val: str | list | dict | bool | int | float | None, target_size: int + self, + val: UpdatedVariable | str | list[object] | dict[str, object] | bool | int | float | None, + target_size: int, ) -> _PartResult[Any]: """Truncate a value within an object to fit within budget.""" - if isinstance(val, str): + if isinstance(val, UpdatedVariable): + # TODO(Workflow): push UpdatedVariable normalization closer to its producer. 
+ return self._truncate_object(val.model_dump(), target_size) + elif isinstance(val, str): return self._truncate_string(val, target_size) elif isinstance(val, list): return self._truncate_array(val, target_size) diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index b528728364..bd95af2614 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -102,7 +102,7 @@ def batch_create_segment_to_index_task( for segment, tokens in zip(content, tokens_list): content = segment["content"] doc_id = str(uuid.uuid4()) - segment_hash = helper.generate_text_hash(content) # type: ignore + segment_hash = helper.generate_text_hash(content) max_position = ( db.session.query(func.max(DocumentSegment.position)) .where(DocumentSegment.document_id == dataset_document.id) diff --git a/api/tests/integration_tests/conftest.py b/api/tests/integration_tests/conftest.py index 9dc7b76e04..4395a9815a 100644 --- a/api/tests/integration_tests/conftest.py +++ b/api/tests/integration_tests/conftest.py @@ -58,6 +58,7 @@ def setup_account(request) -> Generator[Account, None, None]: name=name, password=secrets.token_hex(16), ip_address="localhost", + language="en-US", ) with _CACHED_APP.test_request_context(): diff --git a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py index 8a43d03a43..3984078ee9 100644 --- a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py @@ -5,11 +5,11 @@ from unittest.mock import MagicMock import pytest from _pytest.monkeypatch import MonkeyPatch -from pymochow import MochowClient # type: ignore -from pymochow.model.database import Database # type: ignore -from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState # type: ignore -from pymochow.model.schema import HNSWParams, VectorIndex # type: ignore -from pymochow.model.table import Table # type: ignore +from pymochow import MochowClient +from pymochow.model.database import Database +from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState +from pymochow.model.schema import HNSWParams, VectorIndex +from pymochow.model.table import Table class AttrDict(UserDict): diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py index 5130fcfe17..8f87d6a073 100644 --- a/api/tests/integration_tests/vdb/__mock/tcvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py @@ -3,15 +3,15 @@ from typing import Any, Union import pytest from _pytest.monkeypatch import MonkeyPatch -from tcvectordb import RPCVectorDBClient # type: ignore +from tcvectordb import RPCVectorDBClient from tcvectordb.model import enum from tcvectordb.model.collection import FilterIndexConfig -from tcvectordb.model.document import AnnSearch, Document, Filter, KeywordSearch, Rerank # type: ignore -from tcvectordb.model.enum import ReadConsistency # type: ignore -from tcvectordb.model.index import FilterIndex, HNSWParams, Index, IndexField, VectorIndex # type: ignore +from tcvectordb.model.document import AnnSearch, Document, Filter, KeywordSearch, Rerank +from tcvectordb.model.enum import ReadConsistency +from tcvectordb.model.index import FilterIndex, HNSWParams, Index, IndexField, VectorIndex from tcvectordb.rpc.model.collection import RPCCollection from tcvectordb.rpc.model.database import 
RPCDatabase -from xinference_client.types import Embedding # type: ignore +from xinference_client.types import Embedding class MockTcvectordbClass: diff --git a/api/tests/integration_tests/vdb/__mock/vikingdb.py b/api/tests/integration_tests/vdb/__mock/vikingdb.py index f351df8d5b..289c515b85 100644 --- a/api/tests/integration_tests/vdb/__mock/vikingdb.py +++ b/api/tests/integration_tests/vdb/__mock/vikingdb.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock import pytest from _pytest.monkeypatch import MonkeyPatch -from volcengine.viking_db import ( # type: ignore +from volcengine.viking_db import ( Collection, Data, DistanceType, diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index c59fc50f08..4d4e77a802 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -2299,6 +2299,7 @@ class TestRegisterService: name=admin_name, password=admin_password, ip_address=ip_address, + language="en-US", ) # Verify account was created @@ -2348,6 +2349,7 @@ class TestRegisterService: name=admin_name, password=admin_password, ip_address=ip_address, + language="en-US", ) # Verify no entities were created (rollback worked) diff --git a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py index 6565179f7a..3c77d0c0da 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py @@ -1108,75 +1108,6 @@ class TestMCPToolManageService: assert icon_data["content"] == "🚀" assert icon_data["background"] == "#4ECDC4" - def test_update_mcp_provider_with_server_url_change( - self, db_session_with_containers, mock_external_service_dependencies - ): - """ - Test successful update of MCP provider with server URL change. 
- - This test verifies: - - Proper handling of server URL changes - - Correct reconnection logic - - Database state updates - - External service integration - """ - # Arrange: Create test data - fake = Faker() - account, tenant = self._create_test_account_and_tenant( - db_session_with_containers, mock_external_service_dependencies - ) - - # Create MCP provider - mcp_provider = self._create_test_mcp_provider( - db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id - ) - - from extensions.ext_database import db - - db.session.commit() - - # Mock the reconnection method - with patch.object(MCPToolManageService, "_reconnect_provider") as mock_reconnect: - mock_reconnect.return_value = { - "authed": True, - "tools": '[{"name": "test_tool"}]', - "encrypted_credentials": "{}", - } - - # Act: Execute the method under test - from core.entities.mcp_provider import MCPConfiguration - from extensions.ext_database import db - - service = MCPToolManageService(db.session()) - service.update_provider( - tenant_id=tenant.id, - provider_id=mcp_provider.id, - name="Updated MCP Provider", - server_url="https://new-example.com/mcp", - icon="🚀", - icon_type="emoji", - icon_background="#4ECDC4", - server_identifier="updated_identifier_123", - configuration=MCPConfiguration( - timeout=45.0, - sse_read_timeout=400.0, - ), - ) - - # Assert: Verify the expected outcomes - db.session.refresh(mcp_provider) - assert mcp_provider.name == "Updated MCP Provider" - assert mcp_provider.server_identifier == "updated_identifier_123" - assert mcp_provider.timeout == 45.0 - assert mcp_provider.sse_read_timeout == 400.0 - assert mcp_provider.updated_at is not None - - # Verify reconnection was called - mock_reconnect.assert_called_once_with( - server_url="https://new-example.com/mcp", - provider=mcp_provider, - ) - def test_update_mcp_provider_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies): """ Test error handling when updating MCP provider with duplicate name. 
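The test removed above exercised the old single-call update path. With this patch, a server URL change is validated first (network I/O outside the database transaction) and the resulting ServerUrlValidationResult is handed to update_provider. A minimal sketch of the new two-step flow, reusing the values from the deleted test; the wrapper function and session handling are illustrative assumptions, not part of the patch:

from core.entities.mcp_provider import MCPConfiguration
from services.tools.mcp_tools_manage_service import MCPToolManageService


def update_with_url_check(session, tenant_id: str, provider_id: str, new_url: str) -> None:
    service = MCPToolManageService(session)
    # Connects to the new server before any database rows are touched.
    validation = service.validate_server_url_change(
        tenant_id=tenant_id, provider_id=provider_id, new_server_url=new_url
    )
    service.update_provider(
        tenant_id=tenant_id,
        provider_id=provider_id,
        name="Updated MCP Provider",
        server_url=new_url,
        icon="🚀",
        icon_type="emoji",
        icon_background="#4ECDC4",
        server_identifier="updated_identifier_123",
        configuration=MCPConfiguration(timeout=45.0, sse_read_timeout=400.0),
        validation_result=validation,
    )
    session.commit()  # update_provider only flushes; the caller commits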
@@ -1387,14 +1318,14 @@ class TestMCPToolManageService: # Assert: Verify the expected outcomes assert result is not None - assert result["authed"] is True - assert result["tools"] is not None - assert result["encrypted_credentials"] == "{}" + assert result.authed is True + assert result.tools is not None + assert result.encrypted_credentials == "{}" # Verify tools were properly serialized import json - tools_data = json.loads(result["tools"]) + tools_data = json.loads(result.tools) assert len(tools_data) == 2 assert tools_data[0]["name"] == "test_tool_1" assert tools_data[1]["name"] == "test_tool_2" @@ -1441,9 +1372,9 @@ class TestMCPToolManageService: # Assert: Verify the expected outcomes assert result is not None - assert result["authed"] is False - assert result["tools"] == "[]" - assert result["encrypted_credentials"] == "{}" + assert result.authed is False + assert result.tools == "[]" + assert result.encrypted_credentials == "{}" def test_re_connect_mcp_provider_connection_error( self, db_session_with_containers, mock_external_service_dependencies diff --git a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py index ae0c7b7a6b..e2c616420f 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py @@ -6,6 +6,7 @@ from faker import Faker from core.tools.entities.api_entities import ToolProviderApiEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType +from libs.uuid_utils import uuidv7 from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService @@ -66,6 +67,7 @@ class TestToolTransformService: ) elif provider_type == "workflow": provider = WorkflowToolProvider( + id=str(uuidv7()), name=fake.company(), description=fake.text(max_nb_chars=100), icon='{"background": "#FF6B6B", "content": "🔧"}', @@ -758,6 +760,7 @@ class TestToolTransformService: # Create workflow tool provider provider = WorkflowToolProvider( + id=str(uuidv7()), name=fake.company(), description=fake.text(max_nb_chars=100), icon='{"background": "#FF6B6B", "content": "🔧"}', diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py index 5895f63f94..8423f1ab02 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py @@ -43,7 +43,7 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue: """Test with None input""" # The method signature expects Union[dict, list, Segment], but implementation handles None # We'll test the actual behavior by passing an empty dict instead - result = WorkflowResponseConverter._fetch_files_from_variable_value(None) # type: ignore + result = WorkflowResponseConverter._fetch_files_from_variable_value(None) assert result == [] def test_fetch_files_from_variable_value_with_empty_dict(self): diff --git a/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py b/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py index 26b5d1f7ce..12a9f11205 100644 --- a/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py +++ 
b/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py @@ -21,6 +21,7 @@ from core.mcp.auth.auth_flow import ( register_client, start_authorization, ) +from core.mcp.entities import AuthActionType, AuthResult from core.mcp.types import ( OAuthClientInformation, OAuthClientInformationFull, @@ -527,9 +528,10 @@ class TestCallbackHandling: # Setup service mock_service = Mock() - result = handle_callback("state-key", "auth-code", mock_service) + state_result, tokens_result = handle_callback("state-key", "auth-code") - assert result == state_data + assert state_result == state_data + assert tokens_result == tokens # Verify calls mock_retrieve_state.assert_called_once_with("state-key") @@ -541,9 +543,8 @@ class TestCallbackHandling: "test-verifier", "https://redirect.example.com", ) - mock_service.save_oauth_data.assert_called_once_with( - "test-provider", "test-tenant", tokens.model_dump(), "tokens" - ) + # Note: handle_callback no longer saves tokens directly, it just returns them + # The caller (e.g., controller) is responsible for saving via execute_auth_actions class TestAuthOrchestration: @@ -589,21 +590,28 @@ class TestAuthOrchestration: ) mock_start_auth.return_value = ("https://auth.example.com/authorize?...", "code-verifier") - result = auth(mock_provider, mock_service) + result = auth(mock_provider) - assert result == {"authorization_url": "https://auth.example.com/authorize?..."} + # auth() now returns AuthResult + assert isinstance(result, AuthResult) + assert result.response == {"authorization_url": "https://auth.example.com/authorize?..."} + + # Verify that the result contains the correct actions + assert len(result.actions) == 2 + # Check for SAVE_CLIENT_INFO action + client_info_action = next(a for a in result.actions if a.action_type == AuthActionType.SAVE_CLIENT_INFO) + assert client_info_action.data == {"client_information": mock_register.return_value.model_dump()} + assert client_info_action.provider_id == "provider-id" + assert client_info_action.tenant_id == "tenant-id" + + # Check for SAVE_CODE_VERIFIER action + verifier_action = next(a for a in result.actions if a.action_type == AuthActionType.SAVE_CODE_VERIFIER) + assert verifier_action.data == {"code_verifier": "code-verifier"} + assert verifier_action.provider_id == "provider-id" + assert verifier_action.tenant_id == "tenant-id" # Verify calls mock_register.assert_called_once() - mock_service.save_oauth_data.assert_any_call( - "provider-id", - "tenant-id", - {"client_information": mock_register.return_value.model_dump()}, - "client_info", - ) - mock_service.save_oauth_data.assert_any_call( - "provider-id", "tenant-id", {"code_verifier": "code-verifier"}, "code_verifier" - ) @patch("core.mcp.auth.auth_flow.discover_oauth_metadata") @patch("core.mcp.auth.auth_flow._retrieve_redis_state") @@ -637,12 +645,18 @@ class TestAuthOrchestration: tokens = OAuthTokens(access_token="new-token", token_type="Bearer", expires_in=3600) mock_exchange.return_value = tokens - result = auth(mock_provider, mock_service, authorization_code="auth-code", state_param="state-key") + result = auth(mock_provider, authorization_code="auth-code", state_param="state-key") - assert result == {"result": "success"} + # auth() now returns AuthResult, not a dict + assert isinstance(result, AuthResult) + assert result.response == {"result": "success"} - # Verify token save - mock_service.save_oauth_data.assert_called_with("provider-id", "tenant-id", tokens.model_dump(), "tokens") + # Verify that the result contains the correct action + assert 
len(result.actions) == 1 + assert result.actions[0].action_type == AuthActionType.SAVE_TOKENS + assert result.actions[0].data == tokens.model_dump() + assert result.actions[0].provider_id == "provider-id" + assert result.actions[0].tenant_id == "tenant-id" @patch("core.mcp.auth.auth_flow.discover_oauth_metadata") def test_auth_exchange_code_without_state(self, mock_discover, mock_provider, mock_service): @@ -658,7 +672,7 @@ class TestAuthOrchestration: mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client") with pytest.raises(ValueError) as exc_info: - auth(mock_provider, mock_service, authorization_code="auth-code") + auth(mock_provider, authorization_code="auth-code") assert "State parameter is required" in str(exc_info.value) @@ -691,15 +705,21 @@ class TestAuthOrchestration: grant_types_supported=["authorization_code"], ) - result = auth(mock_provider, mock_service) + result = auth(mock_provider) - assert result == {"result": "success"} + # auth() now returns AuthResult + assert isinstance(result, AuthResult) + assert result.response == {"result": "success"} + + # Verify that the result contains the correct action + assert len(result.actions) == 1 + assert result.actions[0].action_type == AuthActionType.SAVE_TOKENS + assert result.actions[0].data == new_tokens.model_dump() + assert result.actions[0].provider_id == "provider-id" + assert result.actions[0].tenant_id == "tenant-id" # Verify refresh was called mock_refresh.assert_called_once() - mock_service.save_oauth_data.assert_called_with( - "provider-id", "tenant-id", new_tokens.model_dump(), "tokens" - ) @patch("core.mcp.auth.auth_flow.discover_oauth_metadata") def test_auth_registration_fails_with_code(self, mock_discover, mock_provider, mock_service): @@ -715,6 +735,6 @@ class TestAuthOrchestration: mock_provider.retrieve_client_information.return_value = None with pytest.raises(ValueError) as exc_info: - auth(mock_provider, mock_service, authorization_code="auth-code") + auth(mock_provider, authorization_code="auth-code") assert "Existing OAuth client information is required" in str(exc_info.value) diff --git a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py index 895ebdd751..fe9f0935d5 100644 --- a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py +++ b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py @@ -235,7 +235,7 @@ class TestIndividualHandlers: # Type assertion needed due to union type text_content = result.content[0] assert hasattr(text_content, "text") - assert text_content.text == "test answer" # type: ignore[attr-defined] + assert text_content.text == "test answer" def test_handle_call_tool_no_end_user(self): """Test call tool handler without end user""" diff --git a/api/tests/unit_tests/core/mcp/test_auth_client.py b/api/tests/unit_tests/core/mcp/test_auth_client.py deleted file mode 100644 index 7b06c9df4d..0000000000 --- a/api/tests/unit_tests/core/mcp/test_auth_client.py +++ /dev/null @@ -1,420 +0,0 @@ -"""Unit tests for MCP auth client with retry logic.""" - -from types import TracebackType -from unittest.mock import Mock, patch - -import pytest - -from core.entities.mcp_provider import MCPProviderEntity -from core.mcp.auth_client import MCPClientWithAuthRetry -from core.mcp.error import MCPAuthError -from core.mcp.mcp_client import MCPClient -from core.mcp.types import CallToolResult, TextContent, Tool, ToolAnnotations - - -class TestMCPClientWithAuthRetry: - 
"""Test suite for MCPClientWithAuthRetry.""" - - @pytest.fixture - def mock_provider_entity(self): - """Create a mock provider entity.""" - provider = Mock(spec=MCPProviderEntity) - provider.id = "test-provider-id" - provider.tenant_id = "test-tenant-id" - provider.retrieve_tokens.return_value = Mock( - access_token="test-token", token_type="Bearer", expires_in=3600, refresh_token=None - ) - return provider - - @pytest.fixture - def mock_mcp_service(self): - """Create a mock MCP service.""" - service = Mock() - service.get_provider_entity.return_value = Mock( - retrieve_tokens=lambda: Mock( - access_token="new-test-token", token_type="Bearer", expires_in=3600, refresh_token=None - ) - ) - return service - - @pytest.fixture - def auth_callback(self): - """Create a mock auth callback.""" - return Mock() - - def test_init(self, mock_provider_entity, mock_mcp_service, auth_callback): - """Test client initialization.""" - client = MCPClientWithAuthRetry( - server_url="http://test.example.com", - headers={"Authorization": "Bearer test"}, - timeout=30.0, - sse_read_timeout=60.0, - provider_entity=mock_provider_entity, - auth_callback=auth_callback, - authorization_code="test-auth-code", - by_server_id=True, - mcp_service=mock_mcp_service, - ) - - assert client.server_url == "http://test.example.com" - assert client.headers == {"Authorization": "Bearer test"} - assert client.timeout == 30.0 - assert client.sse_read_timeout == 60.0 - assert client.provider_entity == mock_provider_entity - assert client.auth_callback == auth_callback - assert client.authorization_code == "test-auth-code" - assert client.by_server_id is True - assert client.mcp_service == mock_mcp_service - assert client._has_retried is False - # In inheritance design, we don't have _client attribute - assert hasattr(client, "_session") # Inherited from MCPClient - - def test_inheritance_structure(self): - """Test that MCPClientWithAuthRetry properly inherits from MCPClient.""" - client = MCPClientWithAuthRetry( - server_url="http://test.example.com", - headers={"Authorization": "Bearer test"}, - ) - - # Verify inheritance - assert isinstance(client, MCPClient) - - # Verify inherited attributes are accessible - assert hasattr(client, "server_url") - assert hasattr(client, "headers") - assert hasattr(client, "_session") - assert hasattr(client, "_exit_stack") - assert hasattr(client, "_initialized") - - def test_handle_auth_error_no_retry_components(self): - """Test auth error handling when retry components are missing.""" - client = MCPClientWithAuthRetry(server_url="http://test.example.com") - error = MCPAuthError("Auth failed") - - with pytest.raises(MCPAuthError) as exc_info: - client._handle_auth_error(error) - - assert exc_info.value == error - - def test_handle_auth_error_already_retried(self, mock_provider_entity, mock_mcp_service, auth_callback): - """Test auth error handling when already retried.""" - client = MCPClientWithAuthRetry( - server_url="http://test.example.com", - provider_entity=mock_provider_entity, - auth_callback=auth_callback, - mcp_service=mock_mcp_service, - ) - client._has_retried = True - error = MCPAuthError("Auth failed") - - with pytest.raises(MCPAuthError) as exc_info: - client._handle_auth_error(error) - - assert exc_info.value == error - auth_callback.assert_not_called() - - def test_handle_auth_error_successful_refresh(self, mock_provider_entity, mock_mcp_service, auth_callback): - """Test successful auth refresh on error.""" - client = MCPClientWithAuthRetry( - 
server_url="http://test.example.com", - provider_entity=mock_provider_entity, - auth_callback=auth_callback, - authorization_code="test-code", - by_server_id=True, - mcp_service=mock_mcp_service, - ) - - # Configure mocks - new_provider = Mock(spec=MCPProviderEntity) - new_provider.id = "test-provider-id" - new_provider.tenant_id = "test-tenant-id" - new_provider.retrieve_tokens.return_value = Mock( - access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None - ) - mock_mcp_service.get_provider_entity.return_value = new_provider - - error = MCPAuthError("Auth failed") - client._handle_auth_error(error) - - # Verify auth flow - auth_callback.assert_called_once_with(mock_provider_entity, mock_mcp_service, "test-code") - mock_mcp_service.get_provider_entity.assert_called_once_with( - "test-provider-id", "test-tenant-id", by_server_id=True - ) - assert client.headers["Authorization"] == "Bearer new-token" - assert client.authorization_code is None # Should be cleared after use - assert client._has_retried is True - - def test_handle_auth_error_refresh_fails(self, mock_provider_entity, mock_mcp_service, auth_callback): - """Test auth refresh failure.""" - client = MCPClientWithAuthRetry( - server_url="http://test.example.com", - provider_entity=mock_provider_entity, - auth_callback=auth_callback, - mcp_service=mock_mcp_service, - ) - - auth_callback.side_effect = Exception("Auth callback failed") - - error = MCPAuthError("Original auth failed") - with pytest.raises(MCPAuthError) as exc_info: - client._handle_auth_error(error) - - assert "Authentication retry failed" in str(exc_info.value) - - def test_handle_auth_error_no_token_received(self, mock_provider_entity, mock_mcp_service, auth_callback): - """Test auth refresh when no token is received.""" - client = MCPClientWithAuthRetry( - server_url="http://test.example.com", - provider_entity=mock_provider_entity, - auth_callback=auth_callback, - mcp_service=mock_mcp_service, - ) - - # Configure mock to return no token - new_provider = Mock(spec=MCPProviderEntity) - new_provider.retrieve_tokens.return_value = None - mock_mcp_service.get_provider_entity.return_value = new_provider - - error = MCPAuthError("Auth failed") - with pytest.raises(MCPAuthError) as exc_info: - client._handle_auth_error(error) - - assert "no token received" in str(exc_info.value) - - def test_execute_with_retry_success(self): - """Test successful execution without retry.""" - client = MCPClientWithAuthRetry(server_url="http://test.example.com") - - mock_func = Mock(return_value="success") - result = client._execute_with_retry(mock_func, "arg1", kwarg1="value1") - - assert result == "success" - mock_func.assert_called_once_with("arg1", kwarg1="value1") - assert client._has_retried is False - - def test_execute_with_retry_auth_error_then_success(self, mock_provider_entity, mock_mcp_service, auth_callback): - """Test execution with auth error followed by successful retry.""" - client = MCPClientWithAuthRetry( - server_url="http://test.example.com", - provider_entity=mock_provider_entity, - auth_callback=auth_callback, - mcp_service=mock_mcp_service, - ) - - # Configure new provider with token - new_provider = Mock(spec=MCPProviderEntity) - new_provider.retrieve_tokens.return_value = Mock( - access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None - ) - mock_mcp_service.get_provider_entity.return_value = new_provider - - # Mock function that fails first, then succeeds - mock_func = Mock(side_effect=[MCPAuthError("Auth failed"), 
"success"]) - - # Mock the exit stack and session cleanup - with ( - patch.object(client, "_exit_stack") as mock_exit_stack, - patch.object(client, "_session") as mock_session, - patch.object(client, "_initialize") as mock_initialize, - ): - client._initialized = True - result = client._execute_with_retry(mock_func, "arg1", kwarg1="value1") - - assert result == "success" - assert mock_func.call_count == 2 - mock_func.assert_called_with("arg1", kwarg1="value1") - auth_callback.assert_called_once() - mock_exit_stack.close.assert_called_once() - mock_initialize.assert_called_once() - assert client._has_retried is False # Reset after completion - - def test_execute_with_retry_non_auth_error(self): - """Test execution with non-auth error (no retry).""" - client = MCPClientWithAuthRetry(server_url="http://test.example.com") - - mock_func = Mock(side_effect=ValueError("Some other error")) - - with pytest.raises(ValueError) as exc_info: - client._execute_with_retry(mock_func) - - assert str(exc_info.value) == "Some other error" - mock_func.assert_called_once() - - def test_context_manager_enter(self): - """Test context manager enter.""" - client = MCPClientWithAuthRetry(server_url="http://test.example.com") - - with patch.object(client, "_initialize") as mock_initialize: - result = client.__enter__() - - assert result == client - assert client._initialized is True - mock_initialize.assert_called_once() - - def test_context_manager_enter_with_auth_error(self, mock_provider_entity, mock_mcp_service, auth_callback): - """Test context manager enter with auth error and retry.""" - # Configure new provider with token - new_provider = Mock(spec=MCPProviderEntity) - new_provider.retrieve_tokens.return_value = Mock( - access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None - ) - mock_mcp_service.get_provider_entity.return_value = new_provider - - client = MCPClientWithAuthRetry( - server_url="http://test.example.com", - provider_entity=mock_provider_entity, - auth_callback=auth_callback, - mcp_service=mock_mcp_service, - ) - - # Mock parent class __enter__ to raise auth error first, then succeed - with patch.object(MCPClient, "__enter__") as mock_parent_enter: - mock_parent_enter.side_effect = [MCPAuthError("Auth failed"), client] - - result = client.__enter__() - - assert result == client - assert mock_parent_enter.call_count == 2 - auth_callback.assert_called_once() - - def test_context_manager_exit(self): - """Test context manager exit.""" - client = MCPClientWithAuthRetry(server_url="http://test.example.com") - - with patch.object(client, "cleanup") as mock_cleanup: - exc_type: type[BaseException] | None = None - exc_val: BaseException | None = None - exc_tb: TracebackType | None = None - client.__exit__(exc_type, exc_val, exc_tb) - - mock_cleanup.assert_called_once() - - def test_list_tools_not_initialized(self): - """Test list_tools when client not initialized.""" - client = MCPClientWithAuthRetry(server_url="http://test.example.com") - - with pytest.raises(ValueError) as exc_info: - client.list_tools() - - assert "Session not initialized" in str(exc_info.value) - - def test_list_tools_success(self): - """Test successful list_tools call.""" - client = MCPClientWithAuthRetry(server_url="http://test.example.com") - - expected_tools = [ - Tool( - name="test-tool", - description="A test tool", - inputSchema={"type": "object", "properties": {}}, - annotations=ToolAnnotations(title="Test Tool"), - ) - ] - - # Mock the parent class list_tools method - with patch.object(MCPClient, 
"list_tools", return_value=expected_tools): - result = client.list_tools() - assert result == expected_tools - - def test_list_tools_with_auth_retry(self, mock_provider_entity, mock_mcp_service, auth_callback): - """Test list_tools with auth retry.""" - client = MCPClientWithAuthRetry( - server_url="http://test.example.com", - provider_entity=mock_provider_entity, - auth_callback=auth_callback, - mcp_service=mock_mcp_service, - ) - - # Configure new provider with token - new_provider = Mock(spec=MCPProviderEntity) - new_provider.retrieve_tokens.return_value = Mock( - access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None - ) - mock_mcp_service.get_provider_entity.return_value = new_provider - - expected_tools = [Tool(name="test-tool", description="A test tool", inputSchema={})] - - # Mock parent class list_tools to raise auth error first, then succeed - with patch.object(MCPClient, "list_tools") as mock_list_tools: - mock_list_tools.side_effect = [MCPAuthError("Auth failed"), expected_tools] - - result = client.list_tools() - - assert result == expected_tools - assert mock_list_tools.call_count == 2 - auth_callback.assert_called_once() - - def test_invoke_tool_not_initialized(self): - """Test invoke_tool when client not initialized.""" - client = MCPClientWithAuthRetry(server_url="http://test.example.com") - - with pytest.raises(ValueError) as exc_info: - client.invoke_tool("test-tool", {"arg": "value"}) - - assert "Session not initialized" in str(exc_info.value) - - def test_invoke_tool_success(self): - """Test successful invoke_tool call.""" - client = MCPClientWithAuthRetry(server_url="http://test.example.com") - - expected_result = CallToolResult( - content=[TextContent(type="text", text="Tool executed successfully")], isError=False - ) - - # Mock the parent class invoke_tool method - with patch.object(MCPClient, "invoke_tool", return_value=expected_result) as mock_invoke: - result = client.invoke_tool("test-tool", {"arg": "value"}) - - assert result == expected_result - mock_invoke.assert_called_once_with("test-tool", {"arg": "value"}) - - def test_invoke_tool_with_auth_retry(self, mock_provider_entity, mock_mcp_service, auth_callback): - """Test invoke_tool with auth retry.""" - client = MCPClientWithAuthRetry( - server_url="http://test.example.com", - provider_entity=mock_provider_entity, - auth_callback=auth_callback, - mcp_service=mock_mcp_service, - ) - - # Configure new provider with token - new_provider = Mock(spec=MCPProviderEntity) - new_provider.retrieve_tokens.return_value = Mock( - access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None - ) - mock_mcp_service.get_provider_entity.return_value = new_provider - - expected_result = CallToolResult(content=[TextContent(type="text", text="Success")], isError=False) - - # Mock parent class invoke_tool to raise auth error first, then succeed - with patch.object(MCPClient, "invoke_tool") as mock_invoke_tool: - mock_invoke_tool.side_effect = [MCPAuthError("Auth failed"), expected_result] - - result = client.invoke_tool("test-tool", {"arg": "value"}) - - assert result == expected_result - assert mock_invoke_tool.call_count == 2 - mock_invoke_tool.assert_called_with("test-tool", {"arg": "value"}) - auth_callback.assert_called_once() - - def test_cleanup(self): - """Test cleanup method.""" - client = MCPClientWithAuthRetry(server_url="http://test.example.com") - - # Mock the parent class cleanup method - with patch.object(MCPClient, "cleanup") as mock_cleanup: - client.cleanup() - 
mock_cleanup.assert_called_once() - - def test_cleanup_no_client(self): - """Test cleanup when no client exists.""" - client = MCPClientWithAuthRetry(server_url="http://test.example.com") - - # Should not raise - client.cleanup() - - # Since MCPClientWithAuthRetry inherits from MCPClient, - # it doesn't have a _client attribute. The test should just - # verify that cleanup can be called without error. - assert not hasattr(client, "_client") diff --git a/api/tests/unit_tests/core/tools/utils/test_parser.py b/api/tests/unit_tests/core/tools/utils/test_parser.py index e1eab21ca4..f39158aa59 100644 --- a/api/tests/unit_tests/core/tools/utils/test_parser.py +++ b/api/tests/unit_tests/core/tools/utils/test_parser.py @@ -109,3 +109,83 @@ def test_parse_openapi_to_tool_bundle_properties_all_of(app): assert tool_bundles[0].parameters[0].llm_description == "desc prop1" # TODO: support enum in OpenAPI # assert set(tool_bundles[0].parameters[0].options) == {"option1", "option2", "option3"} + + +def test_parse_openapi_to_tool_bundle_default_value_type_casting(app): + """ + Test that default values are properly cast to match parameter types. + This addresses the issue where array default values like [] cause validation errors + when parameter type is inferred as string/number/boolean. + """ + openapi = { + "openapi": "3.0.0", + "info": {"title": "Test API", "version": "1.0.0"}, + "servers": [{"url": "https://example.com"}], + "paths": { + "/product/create": { + "post": { + "operationId": "createProduct", + "summary": "Create a product", + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "categories": { + "description": "List of category identifiers", + "default": [], + "type": "array", + "items": {"type": "string"}, + }, + "name": { + "description": "Product name", + "default": "Default Product", + "type": "string", + }, + "price": {"description": "Product price", "default": 0.0, "type": "number"}, + "available": { + "description": "Product availability", + "default": True, + "type": "boolean", + }, + }, + } + } + } + }, + "responses": {"200": {"description": "Default Response"}}, + } + } + }, + } + + with app.test_request_context(): + tool_bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi) + + assert len(tool_bundles) == 1 + bundle = tool_bundles[0] + assert len(bundle.parameters) == 4 + + # Find parameters by name + params_by_name = {param.name: param for param in bundle.parameters} + + # Check categories parameter (array type with [] default) + categories_param = params_by_name["categories"] + assert categories_param.type == "array" # Will be detected by _get_tool_parameter_type + assert categories_param.default is None # Array default [] is converted to None + + # Check name parameter (string type with string default) + name_param = params_by_name["name"] + assert name_param.type == "string" + assert name_param.default == "Default Product" + + # Check price parameter (number type with number default) + price_param = params_by_name["price"] + assert price_param.type == "number" + assert price_param.default == 0.0 + + # Check available parameter (boolean type with boolean default) + available_param = params_by_name["available"] + assert available_param.type == "boolean" + assert available_param.default is True diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py new file mode 100644 index 0000000000..b55d4998c4 --- /dev/null 
+++ b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +import time +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool +from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType +from core.workflow.graph import Graph +from core.workflow.graph.validation import GraphValidationError +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.base.node import Node +from core.workflow.system_variable import SystemVariable +from models.enums import UserFrom + + +class _TestNode(Node): + node_type = NodeType.ANSWER + execution_type = NodeExecutionType.EXECUTABLE + + @classmethod + def version(cls) -> str: + return "test" + + def __init__( + self, + *, + id: str, + config: Mapping[str, object], + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + data = config.get("data", {}) + if isinstance(data, Mapping): + execution_type = data.get("execution_type") + if isinstance(execution_type, str): + self.execution_type = NodeExecutionType(execution_type) + self._base_node_data = BaseNodeData(title=str(data.get("title", self.id))) + self.data: dict[str, object] = {} + + def init_node_data(self, data: Mapping[str, object]) -> None: + title = str(data.get("title", self.id)) + desc = data.get("description") + error_strategy_value = data.get("error_strategy") + error_strategy: ErrorStrategy | None = None + if isinstance(error_strategy_value, ErrorStrategy): + error_strategy = error_strategy_value + elif isinstance(error_strategy_value, str): + error_strategy = ErrorStrategy(error_strategy_value) + self._base_node_data = BaseNodeData( + title=title, + desc=str(desc) if desc is not None else None, + error_strategy=error_strategy, + ) + self.data = dict(data) + + def _run(self): + raise NotImplementedError + + def _get_error_strategy(self) -> ErrorStrategy | None: + return self._base_node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._base_node_data.retry_config + + def _get_title(self) -> str: + return self._base_node_data.title + + def _get_description(self) -> str | None: + return self._base_node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._base_node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._base_node_data + + +@dataclass(slots=True) +class _SimpleNodeFactory: + graph_init_params: GraphInitParams + graph_runtime_state: GraphRuntimeState + + def create_node(self, node_config: Mapping[str, object]) -> _TestNode: + node_id = str(node_config["id"]) + node = _TestNode( + id=node_id, + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + ) + node.init_node_data(node_config.get("data", {})) + return node + + +@pytest.fixture +def graph_init_dependencies() -> tuple[_SimpleNodeFactory, dict[str, object]]: + graph_config: dict[str, object] = {"edges": [], "nodes": []} + init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config=graph_config, + user_id="user", + 
user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ) + variable_pool = VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}) + runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + factory = _SimpleNodeFactory(graph_init_params=init_params, graph_runtime_state=runtime_state) + return factory, graph_config + + +def test_graph_initialization_runs_default_validators( + graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], +): + node_factory, graph_config = graph_init_dependencies + graph_config["nodes"] = [ + {"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}}, + {"id": "answer", "data": {"type": NodeType.ANSWER, "title": "Answer"}}, + ] + graph_config["edges"] = [ + {"source": "start", "target": "answer", "sourceHandle": "success"}, + ] + + graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + + assert graph.root_node.id == "start" + assert "answer" in graph.nodes + + +def test_graph_validation_fails_for_unknown_edge_targets( + graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], +) -> None: + node_factory, graph_config = graph_init_dependencies + graph_config["nodes"] = [ + {"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}}, + ] + graph_config["edges"] = [ + {"source": "start", "target": "missing", "sourceHandle": "success"}, + ] + + with pytest.raises(GraphValidationError) as exc: + Graph.init(graph_config=graph_config, node_factory=node_factory) + + assert any(issue.code == "MISSING_NODE" for issue in exc.value.issues) + + +def test_graph_promotes_fail_branch_nodes_to_branch_execution_type( + graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], +) -> None: + node_factory, graph_config = graph_init_dependencies + graph_config["nodes"] = [ + {"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}}, + { + "id": "branch", + "data": { + "type": NodeType.IF_ELSE, + "title": "Branch", + "error_strategy": ErrorStrategy.FAIL_BRANCH, + }, + }, + ] + graph_config["edges"] = [ + {"source": "start", "target": "branch", "sourceHandle": "success"}, + ] + + graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + + assert graph.nodes["branch"].execution_type == NodeExecutionType.BRANCH diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py index b9947d4693..b359284d00 100644 --- a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py @@ -212,7 +212,7 @@ class TestValidateResult: parameters=[ ParameterConfig( name="status", - type="select", # type: ignore + type="select", description="Status", required=True, options=["active", "inactive"], @@ -400,7 +400,7 @@ class TestTransformResult: parameters=[ ParameterConfig( name="status", - type="select", # type: ignore + type="select", description="Status", required=True, options=["active", "inactive"], @@ -414,7 +414,7 @@ class TestTransformResult: parameters=[ ParameterConfig( name="status", - type="select", # type: ignore + type="select", description="Status", required=True, options=["active", "inactive"], diff 
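The graph validation tests above also document the calling contract: `Graph.init` either returns a wired graph or raises `GraphValidationError`, whose `issues` carry machine-readable codes such as `MISSING_NODE`. A hypothetical usage sketch of how a caller might surface those issues (the wrapper itself is illustrative and not part of this change):

```python
# Illustrative wrapper only; Graph, GraphValidationError and the issue codes are
# taken from the tests above, everything else here is an assumed usage example.
from core.workflow.graph import Graph
from core.workflow.graph.validation import GraphValidationError


def build_graph_or_report(graph_config, node_factory):
    try:
        return Graph.init(graph_config=graph_config, node_factory=node_factory)
    except GraphValidationError as exc:
        # Each issue exposes a code such as "MISSING_NODE" for programmatic handling.
        codes = ", ".join(issue.code for issue in exc.issues)
        raise ValueError(f"workflow graph failed validation: {codes}") from exc
```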
--git a/api/tests/unit_tests/core/workflow/test_system_variable.py b/api/tests/unit_tests/core/workflow/test_system_variable.py index 3ae5edb383..f76e81ae55 100644 --- a/api/tests/unit_tests/core/workflow/test_system_variable.py +++ b/api/tests/unit_tests/core/workflow/test_system_variable.py @@ -248,4 +248,4 @@ def test_constructor_with_extra_key(): # Test that SystemVariable should forbid extra keys with pytest.raises(ValidationError): # This should fail because there is an unexpected key. - SystemVariable(invalid_key=1) # type: ignore + SystemVariable(invalid_key=1) diff --git a/api/tests/unit_tests/libs/test_external_api.py b/api/tests/unit_tests/libs/test_external_api.py index c4c376a070..9aa157a651 100644 --- a/api/tests/unit_tests/libs/test_external_api.py +++ b/api/tests/unit_tests/libs/test_external_api.py @@ -14,36 +14,36 @@ def _create_api_app(): api = ExternalApi(bp) @api.route("/bad-request") - class Bad(Resource): # type: ignore - def get(self): # type: ignore + class Bad(Resource): + def get(self): raise BadRequest("invalid input") @api.route("/unauth") - class Unauth(Resource): # type: ignore - def get(self): # type: ignore + class Unauth(Resource): + def get(self): raise Unauthorized("auth required") @api.route("/value-error") - class ValErr(Resource): # type: ignore - def get(self): # type: ignore + class ValErr(Resource): + def get(self): raise ValueError("boom") @api.route("/quota") - class Quota(Resource): # type: ignore - def get(self): # type: ignore + class Quota(Resource): + def get(self): raise AppInvokeQuotaExceededError("quota exceeded") @api.route("/general") - class Gen(Resource): # type: ignore - def get(self): # type: ignore + class Gen(Resource): + def get(self): raise RuntimeError("oops") # Note: We avoid altering default_mediatype to keep normal error paths # Special 400 message rewrite @api.route("/json-empty") - class JsonEmpty(Resource): # type: ignore - def get(self): # type: ignore + class JsonEmpty(Resource): + def get(self): e = BadRequest() # Force the specific message the handler rewrites e.description = "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)" @@ -51,11 +51,11 @@ def _create_api_app(): # 400 mapping payload path @api.route("/param-errors") - class ParamErrors(Resource): # type: ignore - def get(self): # type: ignore + class ParamErrors(Resource): + def get(self): e = BadRequest() # Coerce a mapping description to trigger param error shaping - e.description = {"field": "is required"} # type: ignore[assignment] + e.description = {"field": "is required"} raise e app.register_blueprint(bp, url_prefix="/api") @@ -105,7 +105,7 @@ def test_external_api_param_mapping_and_quota_and_exc_info_none(): orig_exc_info = ext.sys.exc_info try: - ext.sys.exc_info = lambda: (None, None, None) # type: ignore[assignment] + ext.sys.exc_info = lambda: (None, None, None) app = _create_api_app() client = app.test_client() diff --git a/api/tests/unit_tests/libs/test_flask_utils.py b/api/tests/unit_tests/libs/test_flask_utils.py index e30433bfce..9cab0db24c 100644 --- a/api/tests/unit_tests/libs/test_flask_utils.py +++ b/api/tests/unit_tests/libs/test_flask_utils.py @@ -67,7 +67,7 @@ def test_current_user_not_accessible_across_threads(login_app: Flask, test_user: # without preserve_flask_contexts result["user_accessible"] = current_user.is_authenticated except Exception as e: - result["error"] = str(e) # type: ignore + result["error"] = str(e) # Run the function in a separate thread thread = threading.Thread(target=check_user_in_thread) @@ 
-110,7 +110,7 @@ def test_current_user_accessible_with_preserve_flask_contexts(login_app: Flask, else: result["user_accessible"] = False except Exception as e: - result["error"] = str(e) # type: ignore + result["error"] = str(e) # Run the function in a separate thread thread = threading.Thread(target=check_user_in_thread_with_manager) diff --git a/api/tests/unit_tests/libs/test_oauth_base.py b/api/tests/unit_tests/libs/test_oauth_base.py index 3e0c235fff..7b7f086dac 100644 --- a/api/tests/unit_tests/libs/test_oauth_base.py +++ b/api/tests/unit_tests/libs/test_oauth_base.py @@ -16,4 +16,4 @@ def test_oauth_base_methods_raise_not_implemented(): oauth.get_raw_user_info("token") with pytest.raises(NotImplementedError): - oauth._transform_user_info({}) # type: ignore[name-defined] + oauth._transform_user_info({}) diff --git a/api/tests/unit_tests/libs/test_token.py b/api/tests/unit_tests/libs/test_token.py index 22790fa4a6..a611d3eb0e 100644 --- a/api/tests/unit_tests/libs/test_token.py +++ b/api/tests/unit_tests/libs/test_token.py @@ -1,5 +1,5 @@ -from constants import COOKIE_NAME_ACCESS_TOKEN -from libs.token import extract_access_token +from constants import COOKIE_NAME_ACCESS_TOKEN, COOKIE_NAME_WEBAPP_ACCESS_TOKEN +from libs.token import extract_access_token, extract_webapp_access_token class MockRequest: @@ -14,10 +14,12 @@ def test_extract_access_token(): return MockRequest(headers, cookies, args) test_cases = [ - (_mock_request({"Authorization": "Bearer 123"}, {}, {}), "123"), - (_mock_request({}, {COOKIE_NAME_ACCESS_TOKEN: "123"}, {}), "123"), - (_mock_request({}, {}, {}), None), - (_mock_request({"Authorization": "Bearer_aaa 123"}, {}, {}), None), + (_mock_request({"Authorization": "Bearer 123"}, {}, {}), "123", "123"), + (_mock_request({}, {COOKIE_NAME_ACCESS_TOKEN: "123"}, {}), "123", None), + (_mock_request({}, {}, {}), None, None), + (_mock_request({"Authorization": "Bearer_aaa 123"}, {}, {}), None, None), + (_mock_request({}, {COOKIE_NAME_WEBAPP_ACCESS_TOKEN: "123"}, {}), None, "123"), ] - for request, expected in test_cases: - assert extract_access_token(request) == expected # pyright: ignore[reportArgumentType] + for request, expected_console, expected_webapp in test_cases: + assert extract_access_token(request) == expected_console # pyright: ignore[reportArgumentType] + assert extract_webapp_access_token(request) == expected_webapp # pyright: ignore[reportArgumentType] diff --git a/api/tests/unit_tests/oss/__mock/tencent_cos.py b/api/tests/unit_tests/oss/__mock/tencent_cos.py index c77c5b08f3..5189b68e87 100644 --- a/api/tests/unit_tests/oss/__mock/tencent_cos.py +++ b/api/tests/unit_tests/oss/__mock/tencent_cos.py @@ -3,8 +3,8 @@ from unittest.mock import MagicMock import pytest from _pytest.monkeypatch import MonkeyPatch -from qcloud_cos import CosS3Client # type: ignore -from qcloud_cos.streambody import StreamBody # type: ignore +from qcloud_cos import CosS3Client +from qcloud_cos.streambody import StreamBody from tests.unit_tests.oss.__mock.base import ( get_example_bucket, diff --git a/api/tests/unit_tests/oss/__mock/volcengine_tos.py b/api/tests/unit_tests/oss/__mock/volcengine_tos.py index 88df59f91c..649d93a202 100644 --- a/api/tests/unit_tests/oss/__mock/volcengine_tos.py +++ b/api/tests/unit_tests/oss/__mock/volcengine_tos.py @@ -4,8 +4,8 @@ from unittest.mock import MagicMock import pytest from _pytest.monkeypatch import MonkeyPatch -from tos import TosClientV2 # type: ignore -from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, 
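The updated token tests spell out the lookup order for both console and web-app requests: the `Authorization: Bearer` header wins, then the matching cookie. A rough sketch of what the test cases imply, with the caveat that the real helpers in `libs/token.py` may differ in structure and edge-case handling:

```python
# Rough sketch of the extraction order implied by the test cases above; the
# function names are placeholders, not the actual libs/token.py definitions.
from constants import COOKIE_NAME_ACCESS_TOKEN, COOKIE_NAME_WEBAPP_ACCESS_TOKEN


def _bearer_token(request) -> str | None:
    scheme, _, token = request.headers.get("Authorization", "").partition(" ")
    if scheme != "Bearer" or not token:
        return None
    return token


def extract_access_token_sketch(request) -> str | None:
    # Console requests: Authorization header first, then the console access_token cookie.
    return _bearer_token(request) or request.cookies.get(COOKIE_NAME_ACCESS_TOKEN)


def extract_webapp_access_token_sketch(request) -> str | None:
    # Web-app requests: Authorization header first, then the webapp_access_token cookie.
    return _bearer_token(request) or request.cookies.get(COOKIE_NAME_WEBAPP_ACCESS_TOKEN)
```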
PutObjectOutput # type: ignore +from tos import TosClientV2 +from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput from tests.unit_tests.oss.__mock.base import ( get_example_bucket, diff --git a/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py b/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py index d289751800..303f0493bd 100644 --- a/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py +++ b/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py @@ -1,7 +1,7 @@ from unittest.mock import patch import pytest -from qcloud_cos import CosConfig # type: ignore +from qcloud_cos import CosConfig from extensions.storage.tencent_cos_storage import TencentCosStorage from tests.unit_tests.oss.__mock.base import ( diff --git a/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py b/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py index 1659205ec0..a06623a69e 100644 --- a/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py +++ b/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py @@ -1,7 +1,7 @@ from unittest.mock import patch import pytest -from tos import TosClientV2 # type: ignore +from tos import TosClientV2 from extensions.storage.volcengine_tos_storage import VolcengineTosStorage from tests.unit_tests.oss.__mock.base import ( diff --git a/api/tests/unit_tests/services/auth/test_api_key_auth_service.py b/api/tests/unit_tests/services/auth/test_api_key_auth_service.py index d23298f096..c6c3f677fb 100644 --- a/api/tests/unit_tests/services/auth/test_api_key_auth_service.py +++ b/api/tests/unit_tests/services/auth/test_api_key_auth_service.py @@ -125,13 +125,13 @@ class TestApiKeyAuthService: mock_session.commit = Mock() args_copy = self.mock_args.copy() - original_key = args_copy["credentials"]["config"]["api_key"] # type: ignore + original_key = args_copy["credentials"]["config"]["api_key"] ApiKeyAuthService.create_provider_auth(self.tenant_id, args_copy) # Verify original key is replaced with encrypted key - assert args_copy["credentials"]["config"]["api_key"] == encrypted_key # type: ignore - assert args_copy["credentials"]["config"]["api_key"] != original_key # type: ignore + assert args_copy["credentials"]["config"]["api_key"] == encrypted_key + assert args_copy["credentials"]["config"]["api_key"] != original_key # Verify encryption function is called correctly mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, original_key) @@ -268,7 +268,7 @@ class TestApiKeyAuthService: def test_validate_api_key_auth_args_empty_credentials(self): """Test API key auth args validation - empty credentials""" args = self.mock_args.copy() - args["credentials"] = None # type: ignore + args["credentials"] = None with pytest.raises(ValueError, match="credentials is required"): ApiKeyAuthService.validate_api_key_auth_args(args) @@ -284,7 +284,7 @@ class TestApiKeyAuthService: def test_validate_api_key_auth_args_missing_auth_type(self): """Test API key auth args validation - missing auth_type""" args = self.mock_args.copy() - del args["credentials"]["auth_type"] # type: ignore + del args["credentials"]["auth_type"] with pytest.raises(ValueError, match="auth_type is required"): ApiKeyAuthService.validate_api_key_auth_args(args) @@ -292,7 +292,7 @@ class TestApiKeyAuthService: def test_validate_api_key_auth_args_empty_auth_type(self): """Test API key auth args validation - empty auth_type""" args = self.mock_args.copy() - args["credentials"]["auth_type"] = "" # type: ignore + 
args["credentials"]["auth_type"] = "" with pytest.raises(ValueError, match="auth_type is required"): ApiKeyAuthService.validate_api_key_auth_args(args) @@ -380,7 +380,7 @@ class TestApiKeyAuthService: def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self): """Test API key auth args validation - dict credentials with list auth_type""" args = self.mock_args.copy() - args["credentials"]["auth_type"] = ["api_key"] # type: ignore # list instead of string + args["credentials"]["auth_type"] = ["api_key"] # Current implementation checks if auth_type exists and is truthy, list ["api_key"] is truthy # So this should not raise exception, this test should pass diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index aec8efd880..e35ba74c56 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -898,7 +898,7 @@ class TestRegisterService: mock_dify_setup.return_value = mock_dify_setup_instance # Execute test - RegisterService.setup("admin@example.com", "Admin User", "password123", "192.168.1.1") + RegisterService.setup("admin@example.com", "Admin User", "password123", "192.168.1.1", "en-US") # Verify results mock_create_account.assert_called_once_with( @@ -930,6 +930,7 @@ class TestRegisterService: "Admin User", "password123", "192.168.1.1", + "en-US", ) # Verify rollback operations were called diff --git a/api/tests/unit_tests/services/test_dataset_service_delete_dataset.py b/api/tests/unit_tests/services/test_dataset_service_delete_dataset.py new file mode 100644 index 0000000000..cc718c9997 --- /dev/null +++ b/api/tests/unit_tests/services/test_dataset_service_delete_dataset.py @@ -0,0 +1,216 @@ +from unittest.mock import Mock, patch + +import pytest + +from models.account import Account, TenantAccountRole +from models.dataset import Dataset +from services.dataset_service import DatasetService + + +class DatasetDeleteTestDataFactory: + """Factory class for creating test data and mock objects for dataset delete tests.""" + + @staticmethod + def create_dataset_mock( + dataset_id: str = "dataset-123", + tenant_id: str = "test-tenant-123", + created_by: str = "creator-456", + doc_form: str | None = None, + indexing_technique: str | None = "high_quality", + **kwargs, + ) -> Mock: + """Create a mock dataset with specified attributes.""" + dataset = Mock(spec=Dataset) + dataset.id = dataset_id + dataset.tenant_id = tenant_id + dataset.created_by = created_by + dataset.doc_form = doc_form + dataset.indexing_technique = indexing_technique + for key, value in kwargs.items(): + setattr(dataset, key, value) + return dataset + + @staticmethod + def create_user_mock( + user_id: str = "user-789", + tenant_id: str = "test-tenant-123", + role: TenantAccountRole = TenantAccountRole.ADMIN, + **kwargs, + ) -> Mock: + """Create a mock user with specified attributes.""" + user = Mock(spec=Account) + user.id = user_id + user.current_tenant_id = tenant_id + user.current_role = role + for key, value in kwargs.items(): + setattr(user, key, value) + return user + + +class TestDatasetServiceDeleteDataset: + """ + Comprehensive unit tests for DatasetService.delete_dataset method. 
+ + This test suite covers all deletion scenarios including: + - Normal dataset deletion with documents + - Empty dataset deletion (no documents, doc_form is None) + - Dataset deletion with missing indexing_technique + - Permission checks + - Event handling + + This test suite provides regression protection for issue #27073. + """ + + @pytest.fixture + def mock_dataset_service_dependencies(self): + """Common mock setup for dataset service dependencies.""" + with ( + patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset, + patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm, + patch("extensions.ext_database.db.session") as mock_db, + patch("services.dataset_service.dataset_was_deleted") as mock_dataset_was_deleted, + ): + yield { + "get_dataset": mock_get_dataset, + "check_permission": mock_check_perm, + "db_session": mock_db, + "dataset_was_deleted": mock_dataset_was_deleted, + } + + def test_delete_dataset_with_documents_success(self, mock_dataset_service_dependencies): + """ + Test successful deletion of a dataset with documents. + + This test verifies: + - Dataset is retrieved correctly + - Permission check is performed + - dataset_was_deleted event is sent + - Dataset is deleted from database + - Method returns True + """ + # Arrange + dataset = DatasetDeleteTestDataFactory.create_dataset_mock( + doc_form="text_model", indexing_technique="high_quality" + ) + user = DatasetDeleteTestDataFactory.create_user_mock() + + mock_dataset_service_dependencies["get_dataset"].return_value = dataset + + # Act + result = DatasetService.delete_dataset(dataset.id, user) + + # Assert + assert result is True + mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id) + mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user) + mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset) + mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset) + mock_dataset_service_dependencies["db_session"].commit.assert_called_once() + + def test_delete_empty_dataset_success(self, mock_dataset_service_dependencies): + """ + Test successful deletion of an empty dataset (no documents, doc_form is None). + + This test verifies that: + - Empty datasets can be deleted without errors + - dataset_was_deleted event is sent (event handler will skip cleanup if doc_form is None) + - Dataset is deleted from database + - Method returns True + + This is the primary test for issue #27073 where deleting an empty dataset + caused internal server error due to assertion failure in event handlers. 
+ """ + # Arrange + dataset = DatasetDeleteTestDataFactory.create_dataset_mock(doc_form=None, indexing_technique=None) + user = DatasetDeleteTestDataFactory.create_user_mock() + + mock_dataset_service_dependencies["get_dataset"].return_value = dataset + + # Act + result = DatasetService.delete_dataset(dataset.id, user) + + # Assert - Verify complete deletion flow + assert result is True + mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id) + mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user) + mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset) + mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset) + mock_dataset_service_dependencies["db_session"].commit.assert_called_once() + + def test_delete_dataset_with_partial_none_values(self, mock_dataset_service_dependencies): + """ + Test deletion of dataset with partial None values. + + This test verifies that datasets with partial None values (e.g., doc_form exists + but indexing_technique is None) can be deleted successfully. The event handler + will skip cleanup if any required field is None. + + Improvement based on Gemini Code Assist suggestion: Added comprehensive assertions + to verify all core deletion operations are performed, not just event sending. + """ + # Arrange + dataset = DatasetDeleteTestDataFactory.create_dataset_mock(doc_form="text_model", indexing_technique=None) + user = DatasetDeleteTestDataFactory.create_user_mock() + + mock_dataset_service_dependencies["get_dataset"].return_value = dataset + + # Act + result = DatasetService.delete_dataset(dataset.id, user) + + # Assert - Verify complete deletion flow (Gemini suggestion implemented) + assert result is True + mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id) + mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user) + mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset) + mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset) + mock_dataset_service_dependencies["db_session"].commit.assert_called_once() + + def test_delete_dataset_with_doc_form_none_indexing_technique_exists(self, mock_dataset_service_dependencies): + """ + Test deletion of dataset where doc_form is None but indexing_technique exists. + + This edge case can occur in certain dataset configurations and should be handled + gracefully by the event handler's conditional check. 
+ """ + # Arrange + dataset = DatasetDeleteTestDataFactory.create_dataset_mock(doc_form=None, indexing_technique="high_quality") + user = DatasetDeleteTestDataFactory.create_user_mock() + + mock_dataset_service_dependencies["get_dataset"].return_value = dataset + + # Act + result = DatasetService.delete_dataset(dataset.id, user) + + # Assert - Verify complete deletion flow + assert result is True + mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id) + mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user) + mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset) + mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset) + mock_dataset_service_dependencies["db_session"].commit.assert_called_once() + + def test_delete_dataset_not_found(self, mock_dataset_service_dependencies): + """ + Test deletion attempt when dataset doesn't exist. + + This test verifies that: + - Method returns False when dataset is not found + - No deletion operations are performed + - No events are sent + """ + # Arrange + dataset_id = "non-existent-dataset" + user = DatasetDeleteTestDataFactory.create_user_mock() + + mock_dataset_service_dependencies["get_dataset"].return_value = None + + # Act + result = DatasetService.delete_dataset(dataset_id, user) + + # Assert + assert result is False + mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset_id) + mock_dataset_service_dependencies["check_permission"].assert_not_called() + mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_not_called() + mock_dataset_service_dependencies["db_session"].delete.assert_not_called() + mock_dataset_service_dependencies["db_session"].commit.assert_not_called() diff --git a/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py b/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py index 30990f8d50..e2607f0fb1 100644 --- a/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py +++ b/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py @@ -116,10 +116,10 @@ class TestSystemOAuthEncrypter: encrypter = SystemOAuthEncrypter("test_secret") with pytest.raises(Exception): # noqa: B017 - encrypter.encrypt_oauth_params(None) # type: ignore + encrypter.encrypt_oauth_params(None) with pytest.raises(Exception): # noqa: B017 - encrypter.encrypt_oauth_params("not_a_dict") # type: ignore + encrypter.encrypt_oauth_params("not_a_dict") def test_decrypt_oauth_params_basic(self): """Test basic OAuth parameters decryption""" @@ -207,12 +207,12 @@ class TestSystemOAuthEncrypter: encrypter = SystemOAuthEncrypter("test_secret") with pytest.raises(ValueError) as exc_info: - encrypter.decrypt_oauth_params(123) # type: ignore + encrypter.decrypt_oauth_params(123) assert "encrypted_data must be a string" in str(exc_info.value) with pytest.raises(ValueError) as exc_info: - encrypter.decrypt_oauth_params(None) # type: ignore + encrypter.decrypt_oauth_params(None) assert "encrypted_data must be a string" in str(exc_info.value) @@ -461,14 +461,14 @@ class TestConvenienceFunctions: """Test convenience functions with error conditions""" # Test encryption with invalid input with pytest.raises(Exception): # noqa: B017 - encrypt_system_oauth_params(None) # type: ignore + encrypt_system_oauth_params(None) # Test decryption with invalid input with pytest.raises(ValueError): decrypt_system_oauth_params("") with 
pytest.raises(ValueError): - decrypt_system_oauth_params(None) # type: ignore + decrypt_system_oauth_params(None) class TestErrorHandling: @@ -501,7 +501,7 @@ class TestErrorHandling: # Test non-string error with pytest.raises(ValueError) as exc_info: - encrypter.decrypt_oauth_params(123) # type: ignore + encrypter.decrypt_oauth_params(123) assert "encrypted_data must be a string" in str(exc_info.value) # Test invalid format error diff --git a/api/uv.lock b/api/uv.lock index 3558f81c59..c94f71bf62 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1301,7 +1301,7 @@ wheels = [ [[package]] name = "dify-api" -version = "1.9.1" +version = "1.9.2" source = { virtual = "." } dependencies = [ { name = "arize-phoenix-otel" }, diff --git a/dev/basedpyright-check b/dev/basedpyright-check index ef58ed1f57..1c87b27d6f 100755 --- a/dev/basedpyright-check +++ b/dev/basedpyright-check @@ -10,7 +10,7 @@ PATH_TO_CHECK="$1" # run basedpyright checks if [ -n "$PATH_TO_CHECK" ]; then - uv run --directory api --dev basedpyright "$PATH_TO_CHECK" + uv run --directory api --dev -- basedpyright --threads $(nproc) "$PATH_TO_CHECK" else - uv run --directory api --dev basedpyright + uv run --directory api --dev -- basedpyright --threads $(nproc) fi diff --git a/docker/.env.example b/docker/.env.example index a150730de8..31b9d54345 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -264,6 +264,18 @@ POSTGRES_MAINTENANCE_WORK_MEM=64MB # Reference: https://www.postgresql.org/docs/current/runtime-config-query.html#GUC-EFFECTIVE-CACHE-SIZE POSTGRES_EFFECTIVE_CACHE_SIZE=4096MB +# Sets the maximum allowed duration of any statement before termination. +# Default is 60000 milliseconds. +# +# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-STATEMENT-TIMEOUT +POSTGRES_STATEMENT_TIMEOUT=60000 + +# Sets the maximum allowed duration of any idle in-transaction session before termination. +# Default is 60000 milliseconds. +# +# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-IDLE-IN-TRANSACTION-SESSION-TIMEOUT +POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT=60000 + # ------------------------------ # Redis Configuration # This Redis configuration is used for caching and for pub/sub during conversation. diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 5a67c080cc..9650be90db 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -2,7 +2,7 @@ x-shared-env: &shared-api-worker-env services: # API service api: - image: langgenius/dify-api:1.9.1 + image: langgenius/dify-api:1.9.2 restart: always environment: # Use the shared environment variables. @@ -24,13 +24,6 @@ services: volumes: # Mount the storage directory to the container, for storing user files. - ./volumes/app/storage:/app/api/storage - # TODO: Remove this entrypoint override when weaviate-client 4.17.0 is included in the next Dify release - entrypoint: - - /bin/bash - - -c - - | - uv pip install --system weaviate-client==4.17.0 - exec /bin/bash /app/api/docker/entrypoint.sh networks: - ssrf_proxy_network - default @@ -38,7 +31,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:1.9.1 + image: langgenius/dify-api:1.9.2 restart: always environment: # Use the shared environment variables. @@ -58,13 +51,6 @@ services: volumes: # Mount the storage directory to the container, for storing user files. 
- ./volumes/app/storage:/app/api/storage - # TODO: Remove this entrypoint override when weaviate-client 4.17.0 is included in the next Dify release - entrypoint: - - /bin/bash - - -c - - | - uv pip install --system weaviate-client==4.17.0 - exec /bin/bash /app/api/docker/entrypoint.sh networks: - ssrf_proxy_network - default @@ -72,7 +58,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.9.1 + image: langgenius/dify-api:1.9.2 restart: always environment: # Use the shared environment variables. @@ -90,7 +76,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.9.1 + image: langgenius/dify-web:1.9.2 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -129,6 +115,8 @@ services: -c 'work_mem=${POSTGRES_WORK_MEM:-4MB}' -c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}' -c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}' + -c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-60000}' + -c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-60000}' volumes: - ./volumes/db/data:/var/lib/postgresql/data healthcheck: @@ -191,7 +179,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.3.0-local + image: langgenius/dify-plugin-daemon:0.3.3-local restart: always environment: # Use the shared environment variables. diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index ebc619a50f..9a1b9b53ba 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -15,6 +15,8 @@ services: -c 'work_mem=${POSTGRES_WORK_MEM:-4MB}' -c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}' -c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}' + -c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-60000}' + -c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-60000}' volumes: - ${PGDATA_HOST_VOLUME:-./volumes/db/data}:/var/lib/postgresql/data ports: @@ -85,7 +87,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.3.0-local + image: langgenius/dify-plugin-daemon:0.3.3-local restart: always env_file: - ./middleware.env diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 421b733e2b..d2ca6b859e 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -68,6 +68,8 @@ x-shared-env: &shared-api-worker-env POSTGRES_WORK_MEM: ${POSTGRES_WORK_MEM:-4MB} POSTGRES_MAINTENANCE_WORK_MEM: ${POSTGRES_MAINTENANCE_WORK_MEM:-64MB} POSTGRES_EFFECTIVE_CACHE_SIZE: ${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB} + POSTGRES_STATEMENT_TIMEOUT: ${POSTGRES_STATEMENT_TIMEOUT:-60000} + POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT: ${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-60000} REDIS_HOST: ${REDIS_HOST:-redis} REDIS_PORT: ${REDIS_PORT:-6379} REDIS_USERNAME: ${REDIS_USERNAME:-} @@ -609,7 +611,7 @@ x-shared-env: &shared-api-worker-env services: # API service api: - image: langgenius/dify-api:1.9.1 + image: langgenius/dify-api:1.9.2 restart: always environment: # Use the shared environment variables. @@ -631,13 +633,6 @@ services: volumes: # Mount the storage directory to the container, for storing user files. 
- ./volumes/app/storage:/app/api/storage - # TODO: Remove this entrypoint override when weaviate-client 4.17.0 is included in the next Dify release - entrypoint: - - /bin/bash - - -c - - | - uv pip install --system weaviate-client==4.17.0 - exec /bin/bash /app/api/docker/entrypoint.sh networks: - ssrf_proxy_network - default @@ -645,7 +640,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:1.9.1 + image: langgenius/dify-api:1.9.2 restart: always environment: # Use the shared environment variables. @@ -665,13 +660,6 @@ services: volumes: # Mount the storage directory to the container, for storing user files. - ./volumes/app/storage:/app/api/storage - # TODO: Remove this entrypoint override when weaviate-client 4.17.0 is included in the next Dify release - entrypoint: - - /bin/bash - - -c - - | - uv pip install --system weaviate-client==4.17.0 - exec /bin/bash /app/api/docker/entrypoint.sh networks: - ssrf_proxy_network - default @@ -679,7 +667,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.9.1 + image: langgenius/dify-api:1.9.2 restart: always environment: # Use the shared environment variables. @@ -697,7 +685,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.9.1 + image: langgenius/dify-web:1.9.2 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -736,6 +724,8 @@ services: -c 'work_mem=${POSTGRES_WORK_MEM:-4MB}' -c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}' -c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}' + -c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-60000}' + -c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-60000}' volumes: - ./volumes/db/data:/var/lib/postgresql/data healthcheck: @@ -798,7 +788,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.3.0-local + image: langgenius/dify-plugin-daemon:0.3.3-local restart: always environment: # Use the shared environment variables. diff --git a/docker/middleware.env.example b/docker/middleware.env.example index 2eba62f594..c9bb8c0528 100644 --- a/docker/middleware.env.example +++ b/docker/middleware.env.example @@ -40,6 +40,18 @@ POSTGRES_MAINTENANCE_WORK_MEM=64MB # Reference: https://www.postgresql.org/docs/current/runtime-config-query.html#GUC-EFFECTIVE-CACHE-SIZE POSTGRES_EFFECTIVE_CACHE_SIZE=4096MB +# Sets the maximum allowed duration of any statement before termination. +# Default is 60000 milliseconds. +# +# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-STATEMENT-TIMEOUT +POSTGRES_STATEMENT_TIMEOUT=60000 + +# Sets the maximum allowed duration of any idle in-transaction session before termination. +# Default is 60000 milliseconds. 
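Once the stack is up, the two new Postgres timeouts can be confirmed directly against the database. A quick check, assuming psycopg2 is installed locally and substituting your own connection details for the placeholders below:

```python
# Sanity check that statement_timeout and idle_in_transaction_session_timeout
# are applied; connection details are placeholders, adjust to your deployment.
import psycopg2

conn = psycopg2.connect(host="localhost", port=5432, dbname="dify", user="postgres", password="change-me")
try:
    with conn.cursor() as cur:
        cur.execute("SHOW statement_timeout")
        print("statement_timeout:", cur.fetchone()[0])  # 60000 ms is reported as "1min"
        cur.execute("SHOW idle_in_transaction_session_timeout")
        print("idle_in_transaction_session_timeout:", cur.fetchone()[0])
finally:
    conn.close()
```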
+# +# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-IDLE-IN-TRANSACTION-SESSION-TIMEOUT +POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT=60000 + # ----------------------------- # Environment Variables for redis Service # ----------------------------- diff --git a/web/.storybook/__mocks__/context-block.tsx b/web/.storybook/__mocks__/context-block.tsx new file mode 100644 index 0000000000..8a9d8625cc --- /dev/null +++ b/web/.storybook/__mocks__/context-block.tsx @@ -0,0 +1,4 @@ +// Mock for context-block plugin to avoid circular dependency in Storybook +export const ContextBlockNode = null +export const ContextBlockReplacementBlock = null +export default null diff --git a/web/.storybook/__mocks__/history-block.tsx b/web/.storybook/__mocks__/history-block.tsx new file mode 100644 index 0000000000..e3c3965d13 --- /dev/null +++ b/web/.storybook/__mocks__/history-block.tsx @@ -0,0 +1,4 @@ +// Mock for history-block plugin to avoid circular dependency in Storybook +export const HistoryBlockNode = null +export const HistoryBlockReplacementBlock = null +export default null diff --git a/web/.storybook/__mocks__/query-block.tsx b/web/.storybook/__mocks__/query-block.tsx new file mode 100644 index 0000000000..d82f51363a --- /dev/null +++ b/web/.storybook/__mocks__/query-block.tsx @@ -0,0 +1,4 @@ +// Mock for query-block plugin to avoid circular dependency in Storybook +export const QueryBlockNode = null +export const QueryBlockReplacementBlock = null +export default null diff --git a/web/.storybook/main.ts b/web/.storybook/main.ts index 0605c71346..ca56261431 100644 --- a/web/.storybook/main.ts +++ b/web/.storybook/main.ts @@ -1,4 +1,8 @@ import type { StorybookConfig } from '@storybook/nextjs' +import path from 'node:path' +import { fileURLToPath } from 'node:url' + +const storybookDir = path.dirname(fileURLToPath(import.meta.url)) const config: StorybookConfig = { stories: ['../app/components/**/*.stories.@(js|jsx|mjs|ts|tsx)'], @@ -25,5 +29,17 @@ const config: StorybookConfig = { docs: { defaultName: 'Documentation', }, + webpackFinal: async (config) => { + // Add alias to mock problematic modules with circular dependencies + config.resolve = config.resolve || {} + config.resolve.alias = { + ...config.resolve.alias, + // Mock the plugin index files to avoid circular dependencies + [path.resolve(storybookDir, '../app/components/base/prompt-editor/plugins/context-block/index.tsx')]: path.resolve(storybookDir, '__mocks__/context-block.tsx'), + [path.resolve(storybookDir, '../app/components/base/prompt-editor/plugins/history-block/index.tsx')]: path.resolve(storybookDir, '__mocks__/history-block.tsx'), + [path.resolve(storybookDir, '../app/components/base/prompt-editor/plugins/query-block/index.tsx')]: path.resolve(storybookDir, '__mocks__/query-block.tsx'), + } + return config + }, } export default config diff --git a/web/.storybook/utils/audio-player-manager.mock.ts b/web/.storybook/utils/audio-player-manager.mock.ts new file mode 100644 index 0000000000..aca8b56b76 --- /dev/null +++ b/web/.storybook/utils/audio-player-manager.mock.ts @@ -0,0 +1,64 @@ +import { AudioPlayerManager } from '@/app/components/base/audio-btn/audio.player.manager' + +type PlayerCallback = ((event: string) => void) | null + +class MockAudioPlayer { + private callback: PlayerCallback = null + private finishTimer?: ReturnType + + public setCallback(callback: PlayerCallback) { + this.callback = callback + } + + public playAudio() { + this.clearTimer() + this.callback?.('play') + this.finishTimer 
= setTimeout(() => { + this.callback?.('ended') + }, 2000) + } + + public pauseAudio() { + this.clearTimer() + this.callback?.('paused') + } + + private clearTimer() { + if (this.finishTimer) + clearTimeout(this.finishTimer) + } +} + +class MockAudioPlayerManager { + private readonly player = new MockAudioPlayer() + + public getAudioPlayer( + _url: string, + _isPublic: boolean, + _id: string | undefined, + _msgContent: string | null | undefined, + _voice: string | undefined, + callback: PlayerCallback, + ) { + this.player.setCallback(callback) + return this.player + } + + public resetMsgId() { + // No-op for the mock + } +} + +export const ensureMockAudioManager = () => { + const managerAny = AudioPlayerManager as unknown as { + getInstance: () => AudioPlayerManager + __isStorybookMockInstalled?: boolean + } + + if (managerAny.__isStorybookMockInstalled) + return + + const mock = new MockAudioPlayerManager() + managerAny.getInstance = () => mock as unknown as AudioPlayerManager + managerAny.__isStorybookMockInstalled = true +} diff --git a/web/__tests__/navigation-utils.test.ts b/web/__tests__/navigation-utils.test.ts index 9a388505d6..fa4986e63d 100644 --- a/web/__tests__/navigation-utils.test.ts +++ b/web/__tests__/navigation-utils.test.ts @@ -160,8 +160,7 @@ describe('Navigation Utilities', () => { page: 1, limit: '', keyword: 'test', - empty: null, - undefined, + filter: '', }) expect(path).toBe('/datasets/123/documents?page=1&keyword=test') diff --git a/web/__tests__/real-browser-flicker.test.tsx b/web/__tests__/real-browser-flicker.test.tsx index f71e8de515..0a0ea0c062 100644 --- a/web/__tests__/real-browser-flicker.test.tsx +++ b/web/__tests__/real-browser-flicker.test.tsx @@ -39,28 +39,38 @@ const setupMockEnvironment = (storedTheme: string | null, systemPrefersDark = fa const isDarkQuery = DARK_MODE_MEDIA_QUERY.test(query) const matches = isDarkQuery ? 
systemPrefersDark : false + const handleAddListener = (listener: (event: MediaQueryListEvent) => void) => { + listeners.add(listener) + } + + const handleRemoveListener = (listener: (event: MediaQueryListEvent) => void) => { + listeners.delete(listener) + } + + const handleAddEventListener = (_event: string, listener: EventListener) => { + if (typeof listener === 'function') + listeners.add(listener as (event: MediaQueryListEvent) => void) + } + + const handleRemoveEventListener = (_event: string, listener: EventListener) => { + if (typeof listener === 'function') + listeners.delete(listener as (event: MediaQueryListEvent) => void) + } + + const handleDispatchEvent = (event: Event) => { + listeners.forEach(listener => listener(event as MediaQueryListEvent)) + return true + } + const mediaQueryList: MediaQueryList = { matches, media: query, onchange: null, - addListener: (listener: MediaQueryListListener) => { - listeners.add(listener) - }, - removeListener: (listener: MediaQueryListListener) => { - listeners.delete(listener) - }, - addEventListener: (_event, listener: EventListener) => { - if (typeof listener === 'function') - listeners.add(listener as MediaQueryListListener) - }, - removeEventListener: (_event, listener: EventListener) => { - if (typeof listener === 'function') - listeners.delete(listener as MediaQueryListListener) - }, - dispatchEvent: (event: Event) => { - listeners.forEach(listener => listener(event as MediaQueryListEvent)) - return true - }, + addListener: handleAddListener, + removeListener: handleRemoveListener, + addEventListener: handleAddEventListener, + removeEventListener: handleRemoveEventListener, + dispatchEvent: handleDispatchEvent, } return mediaQueryList @@ -69,6 +79,121 @@ const setupMockEnvironment = (storedTheme: string | null, systemPrefersDark = fa jest.spyOn(window, 'matchMedia').mockImplementation(mockMatchMedia) } +// Helper function to create timing page component +const createTimingPageComponent = ( + timingData: Array<{ phase: string; timestamp: number; styles: { backgroundColor: string; color: string } }>, +) => { + const recordTiming = (phase: string, styles: { backgroundColor: string; color: string }) => { + timingData.push({ + phase, + timestamp: performance.now(), + styles, + }) + } + + const TimingPageComponent = () => { + const [mounted, setMounted] = useState(false) + const { theme } = useTheme() + const isDark = mounted ? theme === 'dark' : false + + const currentStyles = { + backgroundColor: isDark ? '#1f2937' : '#ffffff', + color: isDark ? '#ffffff' : '#000000', + } + + recordTiming(mounted ? 'CSR' : 'Initial', currentStyles) + + useEffect(() => { + setMounted(true) + }, []) + + return ( +
+
+ Phase: {mounted ? 'CSR' : 'Initial'} | Theme: {theme} | Visual: {isDark ? 'dark' : 'light'} +
+
+ ) + } + + return TimingPageComponent +} + +// Helper function to create CSS test component +const createCSSTestComponent = ( + cssStates: Array<{ className: string; timestamp: number }>, +) => { + const recordCSSState = (className: string) => { + cssStates.push({ + className, + timestamp: performance.now(), + }) + } + + const CSSTestComponent = () => { + const [mounted, setMounted] = useState(false) + const { theme } = useTheme() + const isDark = mounted ? theme === 'dark' : false + + const className = `min-h-screen ${isDark ? 'bg-gray-900 text-white' : 'bg-white text-black'}` + + recordCSSState(className) + + useEffect(() => { + setMounted(true) + }, []) + + return ( +
+
Classes: {className}
+
+ ) + } + + return CSSTestComponent +} + +// Helper function to create performance test component +const createPerformanceTestComponent = ( + performanceMarks: Array<{ event: string; timestamp: number }>, +) => { + const recordPerformanceMark = (event: string) => { + performanceMarks.push({ event, timestamp: performance.now() }) + } + + const PerformanceTestComponent = () => { + const [mounted, setMounted] = useState(false) + const { theme } = useTheme() + + recordPerformanceMark('component-render') + + useEffect(() => { + recordPerformanceMark('mount-start') + setMounted(true) + recordPerformanceMark('mount-complete') + }, []) + + useEffect(() => { + if (theme) + recordPerformanceMark('theme-available') + }, [theme]) + + return ( +
+ Mounted: {mounted.toString()} | Theme: {theme || 'loading'} +
+ ) + } + + return PerformanceTestComponent +} + // Simulate real page component based on Dify's actual theme usage const PageComponent = () => { const [mounted, setMounted] = useState(false) @@ -227,39 +352,7 @@ describe('Real Browser Environment Dark Mode Flicker Test', () => { setupMockEnvironment('dark') const timingData: Array<{ phase: string; timestamp: number; styles: any }> = [] - - const TimingPageComponent = () => { - const [mounted, setMounted] = useState(false) - const { theme } = useTheme() - const isDark = mounted ? theme === 'dark' : false - - // Record timing and styles for each render phase - const currentStyles = { - backgroundColor: isDark ? '#1f2937' : '#ffffff', - color: isDark ? '#ffffff' : '#000000', - } - - timingData.push({ - phase: mounted ? 'CSR' : 'Initial', - timestamp: performance.now(), - styles: currentStyles, - }) - - useEffect(() => { - setMounted(true) - }, []) - - return ( -
-
- Phase: {mounted ? 'CSR' : 'Initial'} | Theme: {theme} | Visual: {isDark ? 'dark' : 'light'} -
-
- ) - } + const TimingPageComponent = createTimingPageComponent(timingData) render( @@ -295,33 +388,7 @@ describe('Real Browser Environment Dark Mode Flicker Test', () => { setupMockEnvironment('dark') const cssStates: Array<{ className: string; timestamp: number }> = [] - - const CSSTestComponent = () => { - const [mounted, setMounted] = useState(false) - const { theme } = useTheme() - const isDark = mounted ? theme === 'dark' : false - - // Simulate Tailwind CSS class application - const className = `min-h-screen ${isDark ? 'bg-gray-900 text-white' : 'bg-white text-black'}` - - cssStates.push({ - className, - timestamp: performance.now(), - }) - - useEffect(() => { - setMounted(true) - }, []) - - return ( -
-
Classes: {className}
-
- ) - } + const CSSTestComponent = createCSSTestComponent(cssStates) render( @@ -413,34 +480,12 @@ describe('Real Browser Environment Dark Mode Flicker Test', () => { test('verifies ThemeProvider position fix reduces initialization delay', async () => { const performanceMarks: Array<{ event: string; timestamp: number }> = [] - const PerformanceTestComponent = () => { - const [mounted, setMounted] = useState(false) - const { theme } = useTheme() - - performanceMarks.push({ event: 'component-render', timestamp: performance.now() }) - - useEffect(() => { - performanceMarks.push({ event: 'mount-start', timestamp: performance.now() }) - setMounted(true) - performanceMarks.push({ event: 'mount-complete', timestamp: performance.now() }) - }, []) - - useEffect(() => { - if (theme) - performanceMarks.push({ event: 'theme-available', timestamp: performance.now() }) - }, [theme]) - - return ( -
- Mounted: {mounted.toString()} | Theme: {theme || 'loading'} -
- ) - } - setupMockEnvironment('dark') expect(window.localStorage.getItem('theme')).toBe('dark') + const PerformanceTestComponent = createPerformanceTestComponent(performanceMarks) + render( diff --git a/web/__tests__/unified-tags-logic.test.ts b/web/__tests__/unified-tags-logic.test.ts index c920e28e0a..ec73a6a268 100644 --- a/web/__tests__/unified-tags-logic.test.ts +++ b/web/__tests__/unified-tags-logic.test.ts @@ -70,14 +70,18 @@ describe('Unified Tags Editing - Pure Logic Tests', () => { }) describe('Fallback Logic (from layout-main.tsx)', () => { + type Tag = { id: string; name: string } + type AppDetail = { tags: Tag[] } + type FallbackResult = { tags?: Tag[] } | null + // no-op it('should trigger fallback when tags are missing or empty', () => { - const appDetailWithoutTags = { tags: [] } - const appDetailWithTags = { tags: [{ id: 'tag1' }] } - const appDetailWithUndefinedTags = { tags: undefined as any } + const appDetailWithoutTags: AppDetail = { tags: [] } + const appDetailWithTags: AppDetail = { tags: [{ id: 'tag1', name: 't' }] } + const appDetailWithUndefinedTags: { tags: Tag[] | undefined } = { tags: undefined } // This simulates the condition in layout-main.tsx - const shouldFallback1 = !appDetailWithoutTags.tags || appDetailWithoutTags.tags.length === 0 - const shouldFallback2 = !appDetailWithTags.tags || appDetailWithTags.tags.length === 0 + const shouldFallback1 = appDetailWithoutTags.tags.length === 0 + const shouldFallback2 = appDetailWithTags.tags.length === 0 const shouldFallback3 = !appDetailWithUndefinedTags.tags || appDetailWithUndefinedTags.tags.length === 0 expect(shouldFallback1).toBe(true) // Empty array should trigger fallback @@ -86,24 +90,26 @@ describe('Unified Tags Editing - Pure Logic Tests', () => { }) it('should preserve tags when fallback succeeds', () => { - const originalAppDetail = { tags: [] as any[] } - const fallbackResult = { tags: [{ id: 'tag1', name: 'fallback-tag' }] } + const originalAppDetail: AppDetail = { tags: [] } + const fallbackResult: { tags?: Tag[] } = { tags: [{ id: 'tag1', name: 'fallback-tag' }] } // This simulates the successful fallback in layout-main.tsx - if (fallbackResult?.tags) - originalAppDetail.tags = fallbackResult.tags + const tags = fallbackResult.tags + if (tags) + originalAppDetail.tags = tags expect(originalAppDetail.tags).toEqual(fallbackResult.tags) expect(originalAppDetail.tags.length).toBe(1) }) it('should continue with empty tags when fallback fails', () => { - const originalAppDetail: { tags: any[] } = { tags: [] } - const fallbackResult: { tags?: any[] } | null = null + const originalAppDetail: AppDetail = { tags: [] } + const fallbackResult = null as FallbackResult // This simulates fallback failure in layout-main.tsx - if (fallbackResult?.tags) - originalAppDetail.tags = fallbackResult.tags + const tags: Tag[] | undefined = fallbackResult && 'tags' in fallbackResult ? 
fallbackResult.tags : undefined + if (tags) + originalAppDetail.tags = tags expect(originalAppDetail.tags).toEqual([]) }) diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx index e4c3f60c12..0ad02ad7f3 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx @@ -73,7 +73,7 @@ const ConfigPopup: FC = ({ } }, [onChooseProvider]) - const handleConfigUpdated = useCallback((payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig) => { + const handleConfigUpdated = useCallback((payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | TencentConfig) => { onConfigUpdated(currentProvider!, payload) hideConfigModal() }, [currentProvider, hideConfigModal, onConfigUpdated]) diff --git a/web/app/(commonLayout)/layout.tsx b/web/app/(commonLayout)/layout.tsx index ed1c995e25..be9c4fe49a 100644 --- a/web/app/(commonLayout)/layout.tsx +++ b/web/app/(commonLayout)/layout.tsx @@ -9,6 +9,7 @@ import { EventEmitterContextProvider } from '@/context/event-emitter' import { ProviderContextProvider } from '@/context/provider-context' import { ModalContextProvider } from '@/context/modal-context' import GotoAnything from '@/app/components/goto-anything' +import Zendesk from '@/app/components/base/zendesk' const Layout = ({ children }: { children: ReactNode }) => { return ( @@ -28,6 +29,7 @@ const Layout = ({ children }: { children: ReactNode }) => { + ) diff --git a/web/app/(shareLayout)/components/splash.tsx b/web/app/(shareLayout)/components/splash.tsx index c26ea7e045..16d291d4b4 100644 --- a/web/app/(shareLayout)/components/splash.tsx +++ b/web/app/(shareLayout)/components/splash.tsx @@ -6,7 +6,6 @@ import { useWebAppStore } from '@/context/web-app-context' import { useRouter, useSearchParams } from 'next/navigation' import AppUnavailable from '@/app/components/base/app-unavailable' import { useTranslation } from 'react-i18next' -import { AccessMode } from '@/models/access-control' import { webAppLoginStatus, webAppLogout } from '@/service/webapp-auth' import { fetchAccessToken } from '@/service/share' import Loading from '@/app/components/base/loading' @@ -35,7 +34,6 @@ const Splash: FC = ({ children }) => { router.replace(url) }, [getSigninUrl, router, webAppLogout, shareCode]) - const needCheckIsLogin = webAppAccessMode !== AccessMode.PUBLIC const [isLoading, setIsLoading] = useState(true) useEffect(() => { if (message) { @@ -58,8 +56,8 @@ const Splash: FC = ({ children }) => { } (async () => { - const { userLoggedIn, appLoggedIn } = await webAppLoginStatus(needCheckIsLogin, shareCode!) - + // if access mode is public, user login is always true, but the app login(passport) may be expired + const { userLoggedIn, appLoggedIn } = await webAppLoginStatus(shareCode!) 
if (userLoggedIn && appLoggedIn) { redirectOrFinish() } @@ -87,7 +85,6 @@ const Splash: FC = ({ children }) => { router, message, webAppAccessMode, - needCheckIsLogin, tokenFromUrl]) if (message) { diff --git a/web/app/components/app/annotation/index.tsx b/web/app/components/app/annotation/index.tsx index 264b1ac727..bc63b85f6d 100644 --- a/web/app/components/app/annotation/index.tsx +++ b/web/app/components/app/annotation/index.tsx @@ -53,7 +53,6 @@ const Annotation: FC = (props) => { const [isShowViewModal, setIsShowViewModal] = useState(false) const [selectedIds, setSelectedIds] = useState([]) const debouncedQueryParams = useDebounce(queryParams, { wait: 500 }) - const [isBatchDeleting, setIsBatchDeleting] = useState(false) const fetchAnnotationConfig = async () => { const res = await doFetchAnnotationConfig(appDetail.id) @@ -108,9 +107,6 @@ const Annotation: FC = (props) => { } const handleBatchDelete = async () => { - if (isBatchDeleting) - return - setIsBatchDeleting(true) try { await delAnnotations(appDetail.id, selectedIds) Toast.notify({ message: t('common.api.actionSuccess'), type: 'success' }) @@ -121,9 +117,6 @@ const Annotation: FC = (props) => { catch (e: any) { Toast.notify({ type: 'error', message: e.message || t('common.api.actionFailed') }) } - finally { - setIsBatchDeleting(false) - } } const handleView = (item: AnnotationItem) => { @@ -213,7 +206,6 @@ const Annotation: FC = (props) => { onSelectedIdsChange={setSelectedIds} onBatchDelete={handleBatchDelete} onCancel={() => setSelectedIds([])} - isBatchDeleting={isBatchDeleting} /> :
} diff --git a/web/app/components/app/annotation/list.tsx b/web/app/components/app/annotation/list.tsx index 6705ac5768..70ecedb869 100644 --- a/web/app/components/app/annotation/list.tsx +++ b/web/app/components/app/annotation/list.tsx @@ -19,7 +19,6 @@ type Props = { onSelectedIdsChange: (selectedIds: string[]) => void onBatchDelete: () => Promise onCancel: () => void - isBatchDeleting?: boolean } const List: FC = ({ @@ -30,7 +29,6 @@ const List: FC = ({ onSelectedIdsChange, onBatchDelete, onCancel, - isBatchDeleting, }) => { const { t } = useTranslation() const { formatTime } = useTimestamp() @@ -142,7 +140,6 @@ const List: FC = ({ selectedIds={selectedIds} onBatchDelete={onBatchDelete} onCancel={onCancel} - isBatchDeleting={isBatchDeleting} /> )} diff --git a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx index 70e0334e98..aa8d0f65ca 100644 --- a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx @@ -78,7 +78,9 @@ const AdvancedPromptInput: FC = ({ const handleOpenExternalDataToolModal = () => { setShowExternalDataToolModal({ payload: {}, - onSaveCallback: (newExternalDataTool: ExternalDataTool) => { + onSaveCallback: (newExternalDataTool?: ExternalDataTool) => { + if (!newExternalDataTool) + return eventEmitter?.emit({ type: ADD_EXTERNAL_DATA_TOOL, payload: newExternalDataTool, diff --git a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx index 169e8a14a2..8634232b2b 100644 --- a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx @@ -76,7 +76,9 @@ const Prompt: FC = ({ const handleOpenExternalDataToolModal = () => { setShowExternalDataToolModal({ payload: {}, - onSaveCallback: (newExternalDataTool: ExternalDataTool) => { + onSaveCallback: (newExternalDataTool?: ExternalDataTool) => { + if (!newExternalDataTool) + return eventEmitter?.emit({ type: ADD_EXTERNAL_DATA_TOOL, payload: newExternalDataTool, diff --git a/web/app/components/app/configuration/config-var/config-modal/index.tsx b/web/app/components/app/configuration/config-var/config-modal/index.tsx index de7d2c9eac..3f32c9b0c7 100644 --- a/web/app/components/app/configuration/config-var/config-modal/index.tsx +++ b/web/app/components/app/configuration/config-var/config-modal/index.tsx @@ -320,7 +320,7 @@ const ConfigModal: FC = ({ {type === InputVarType.paragraph && (