diff --git a/api/.env.example b/api/.env.example index 1d8190ce5f..4df6adf348 100644 --- a/api/.env.example +++ b/api/.env.example @@ -434,6 +434,9 @@ CODE_EXECUTION_SSL_VERIFY=True CODE_EXECUTION_POOL_MAX_CONNECTIONS=100 CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS=20 CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY=5.0 +CODE_EXECUTION_CONNECT_TIMEOUT=10 +CODE_EXECUTION_READ_TIMEOUT=60 +CODE_EXECUTION_WRITE_TIMEOUT=10 CODE_MAX_NUMBER=9223372036854775807 CODE_MIN_NUMBER=-9223372036854775808 CODE_MAX_STRING_LENGTH=400000 diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 6ce72e80df..a02f8a4d49 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -548,7 +548,7 @@ class UpdateConfig(BaseSettings): class WorkflowVariableTruncationConfig(BaseSettings): WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE: PositiveInt = Field( - # 100KB + # 1000 KiB 1024_000, description="Maximum size for variable to trigger final truncation.", ) diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index d872e8201b..816d0e442f 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -145,7 +145,7 @@ class DatabaseConfig(BaseSettings): default="postgresql", ) - @computed_field # type: ignore[misc] + @computed_field # type: ignore[prop-decorator] @property def SQLALCHEMY_DATABASE_URI(self) -> str: db_extras = ( @@ -198,7 +198,7 @@ class DatabaseConfig(BaseSettings): default=os.cpu_count() or 1, ) - @computed_field # type: ignore[misc] + @computed_field # type: ignore[prop-decorator] @property def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]: # Parse DB_EXTRAS for 'options' diff --git a/api/constants/__init__.py b/api/constants/__init__.py index 248cdfc09f..e441395afc 100644 --- a/api/constants/__init__.py +++ b/api/constants/__init__.py @@ -56,11 +56,15 @@ else: } DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions) +# console COOKIE_NAME_ACCESS_TOKEN = "access_token" COOKIE_NAME_REFRESH_TOKEN = "refresh_token" -COOKIE_NAME_PASSPORT = "passport" COOKIE_NAME_CSRF_TOKEN = "csrf_token" +# webapp +COOKIE_NAME_WEBAPP_ACCESS_TOKEN = "webapp_access_token" +COOKIE_NAME_PASSPORT = "passport" + HEADER_NAME_CSRF_TOKEN = "X-CSRF-Token" HEADER_NAME_APP_CODE = "X-App-Code" HEADER_NAME_PASSPORT = "X-App-Passport" diff --git a/api/constants/languages.py b/api/constants/languages.py index a509ddcf5d..0312a558c9 100644 --- a/api/constants/languages.py +++ b/api/constants/languages.py @@ -31,3 +31,9 @@ def supported_language(lang): error = f"{lang} is not a valid language." 
raise ValueError(error) + + +def get_valid_language(lang: str | None) -> str: + if lang and lang in languages: + return lang + return languages[0] diff --git a/api/controllers/common/helpers.py b/api/controllers/common/helpers.py index 6a5197635e..ef89e66980 100644 --- a/api/controllers/common/helpers.py +++ b/api/controllers/common/helpers.py @@ -24,7 +24,7 @@ except ImportError: ) else: warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2) - magic = None # type: ignore + magic = None # type: ignore[assignment] from pydantic import BaseModel diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 277f9a60a8..c0a565b5da 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -4,7 +4,7 @@ from flask_restx import Resource, reqparse import services from configs import dify_config -from constants.languages import languages +from constants.languages import get_valid_language from controllers.console import console_ns from controllers.console.auth.error import ( AuthenticationFailedError, @@ -29,8 +29,6 @@ from libs.token import ( clear_access_token_from_cookie, clear_csrf_token_from_cookie, clear_refresh_token_from_cookie, - extract_access_token, - extract_csrf_token, set_access_token_to_cookie, set_csrf_token_to_cookie, set_refresh_token_to_cookie, @@ -206,10 +204,12 @@ class EmailCodeLoginApi(Resource): .add_argument("email", type=str, required=True, location="json") .add_argument("code", type=str, required=True, location="json") .add_argument("token", type=str, required=True, location="json") + .add_argument("language", type=str, required=False, location="json") ) args = parser.parse_args() user_email = args["email"] + language = args["language"] token_data = AccountService.get_email_code_login_data(args["token"]) if token_data is None: @@ -243,7 +243,9 @@ class EmailCodeLoginApi(Resource): if account is None: try: account = AccountService.create_account_and_tenant( - email=user_email, name=user_email, interface_language=languages[0] + email=user_email, + name=user_email, + interface_language=get_valid_language(language), ) except WorkSpaceNotAllowedCreateError: raise NotAllowedCreateWorkspace() @@ -286,13 +288,3 @@ class RefreshTokenApi(Resource): return response except Exception as e: return {"result": "fail", "message": str(e)}, 401 - - -# this api helps frontend to check whether user is authenticated -# TODO: remove in the future. frontend should redirect to login page by catching 401 status -@console_ns.route("/login/status") -class LoginStatus(Resource): - def get(self): - token = extract_access_token(request) - csrf_token = extract_csrf_token(request) - return {"logged_in": bool(token) and bool(csrf_token)} diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index 3022d937b9..125f603a5a 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -22,7 +22,7 @@ from core.errors.error import ( from core.model_runtime.errors.invoke import InvokeError from core.workflow.graph_engine.manager import GraphEngineManager from libs import helper -from libs.login import current_user as current_user_ +from libs.login import current_account_with_tenant from models.model import AppMode, InstalledApp from services.app_generate_service import AppGenerateService from services.errors.llm import InvokeRateLimitError @@ -31,8 +31,6 @@ from .. 
import console_ns logger = logging.getLogger(__name__) -current_user = current_user_._get_current_object() # type: ignore - @console_ns.route("/installed-apps//workflows/run") class InstalledAppWorkflowRunApi(InstalledAppResource): @@ -40,6 +38,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource): """ Run workflow """ + current_user, _ = current_account_with_tenant() app_model = installed_app.app if not app_model: raise NotWorkflowAppError() @@ -53,7 +52,6 @@ class InstalledAppWorkflowRunApi(InstalledAppResource): .add_argument("files", type=list, required=False, location="json") ) args = parser.parse_args() - assert current_user is not None try: response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True @@ -89,7 +87,6 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource): app_mode = AppMode.value_of(app_model.mode) if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - assert current_user is not None # Stop using both mechanisms for backward compatibility # Legacy stop flag mechanism (without user check) diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index 6d2b22bde3..1200349e2d 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -74,12 +74,17 @@ class SetupApi(Resource): .add_argument("email", type=email, required=True, location="json") .add_argument("name", type=StrLen(30), required=True, location="json") .add_argument("password", type=valid_password, required=True, location="json") + .add_argument("language", type=str, required=False, location="json") ) args = parser.parse_args() # setup RegisterService.setup( - email=args["email"], name=args["name"], password=args["password"], ip_address=extract_remote_ip(request) + email=args["email"], + name=args["name"], + password=args["password"], + ip_address=extract_remote_ip(request), + language=args["language"], ) return {"result": "success"}, 201 diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index 85b7df229f..8d8fe6b3a8 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -193,15 +193,16 @@ class MCPAppApi(Resource): except ValidationError as e: raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}") - def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str, session: Session) -> EndUser | None: - """Get end user from existing session - optimized query""" - return ( - session.query(EndUser) - .where(EndUser.tenant_id == tenant_id) - .where(EndUser.session_id == mcp_server_id) - .where(EndUser.type == "mcp") - .first() - ) + def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None: + """Get end user - manages its own database session""" + with Session(db.engine, expire_on_commit=False) as session, session.begin(): + return ( + session.query(EndUser) + .where(EndUser.tenant_id == tenant_id) + .where(EndUser.session_id == mcp_server_id) + .where(EndUser.type == "mcp") + .first() + ) def _create_end_user( self, client_name: str, tenant_id: str, app_id: str, mcp_server_id: str, session: Session @@ -229,7 +230,7 @@ class MCPAppApi(Resource): request_id: Union[int, str], ) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError | None: """Handle MCP request and return response""" - end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id, session) + end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id) if not end_user and 
isinstance(mcp_request.root, mcp_types.InitializeRequest): client_info = mcp_request.root.params.clientInfo diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py index f213fd8c90..244ef47982 100644 --- a/api/controllers/web/login.py +++ b/api/controllers/web/login.py @@ -17,8 +17,8 @@ from libs.helper import email from libs.passport import PassportService from libs.password import valid_password from libs.token import ( - clear_access_token_from_cookie, - extract_access_token, + clear_webapp_access_token_from_cookie, + extract_webapp_access_token, ) from services.account_service import AccountService from services.app_service import AppService @@ -81,7 +81,7 @@ class LoginStatusApi(Resource): ) def get(self): app_code = request.args.get("app_code") - token = extract_access_token(request) + token = extract_webapp_access_token(request) if not app_code: return { "logged_in": bool(token), @@ -128,7 +128,7 @@ class LogoutApi(Resource): response = make_response({"result": "success"}) # enterprise SSO sets same site to None in https deployment # so we need to logout by calling api - clear_access_token_from_cookie(response, samesite="None") + clear_webapp_access_token_from_cookie(response, samesite="None") return response diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index 776b743e92..6a2e0b65fb 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -12,10 +12,8 @@ from controllers.web import web_ns from controllers.web.error import WebAppAuthRequiredError from extensions.ext_database import db from libs.passport import PassportService -from libs.token import extract_access_token +from libs.token import extract_webapp_access_token from models.model import App, EndUser, Site -from services.app_service import AppService -from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService from services.webapp_auth_service import WebAppAuthService, WebAppAuthType @@ -37,23 +35,18 @@ class PassportResource(Resource): system_features = FeatureService.get_system_features() app_code = request.headers.get(HEADER_NAME_APP_CODE) user_id = request.args.get("user_id") - access_token = extract_access_token(request) - + access_token = extract_webapp_access_token(request) if app_code is None: raise Unauthorized("X-App-Code header is missing.") - app_id = AppService.get_app_id_by_code(app_code) - # exchange token for enterprise logined web user - enterprise_user_decoded = decode_enterprise_webapp_user_id(access_token) - if enterprise_user_decoded: - # a web user has already logged in, exchange a token for this app without redirecting to the login page - return exchange_token_for_existing_web_user( - app_code=app_code, enterprise_user_decoded=enterprise_user_decoded - ) - if system_features.webapp_auth.enabled: - app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id) - if not app_settings or not app_settings.access_mode == "public": - raise WebAppAuthRequiredError() + enterprise_user_decoded = decode_enterprise_webapp_user_id(access_token) + app_auth_type = WebAppAuthService.get_app_auth_type(app_code=app_code) + if app_auth_type != WebAppAuthType.PUBLIC: + if not enterprise_user_decoded: + raise WebAppAuthRequiredError() + return exchange_token_for_existing_web_user( + app_code=app_code, enterprise_user_decoded=enterprise_user_decoded, auth_type=app_auth_type + ) # get site from db and check if it is normal site = 
db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal")) @@ -124,7 +117,7 @@ def decode_enterprise_webapp_user_id(jwt_token: str | None): return decoded -def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict): +def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict, auth_type: WebAppAuthType): """ Exchange a token for an existing web user session. """ @@ -145,13 +138,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: if not app_model or app_model.status != "normal" or not app_model.enable_site: raise NotFound() - app_auth_type = WebAppAuthService.get_app_auth_type(app_code=app_code) - - if app_auth_type == WebAppAuthType.PUBLIC: + if auth_type == WebAppAuthType.PUBLIC: return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded) - elif app_auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external": + elif auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external": raise WebAppAuthRequiredError("Please login as external user.") - elif app_auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal": + elif auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal": raise WebAppAuthRequiredError("Please login as internal user.") end_user = None diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index c6d98374c1..7bd3b8a56e 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -211,8 +211,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): user=user, stream=streaming, ) - # FIXME: Type hinting issue here, ignore it for now, will fix it later - return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) # type: ignore + return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def _generate_worker( self, diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 1fb076b685..f8bfbce37a 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -255,7 +255,7 @@ class PipelineGenerator(BaseAppGenerator): json_text = json.dumps(text) upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id) features = FeatureService.get_features(dataset.tenant_id) - if features.billing.subscription.plan == "sandbox": + if features.billing.enabled and features.billing.subscription.plan == "sandbox": tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}" tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}" diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py index 01ecf0298f..c64f44a603 100644 --- a/api/core/app/apps/workflow/generate_response_converter.py +++ b/api/core/app/apps/workflow/generate_response_converter.py @@ -89,7 +89,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): - response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute] + 
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) else: response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index ffa10cd43c..565905be0d 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -98,7 +98,7 @@ class RateLimit: else: return RateLimitGenerator( rate_limit=self, - generator=generator, # ty: ignore [invalid-argument-type] + generator=generator, request_id=request_id, ) diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 45e3c0006b..26c7e60a4c 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -49,7 +49,7 @@ class BasedGenerateTaskPipeline: if isinstance(e, InvokeAuthorizationError): err = InvokeAuthorizationError("Incorrect API key provided") elif isinstance(e, InvokeError | ValueError): - err = e # ty: ignore [invalid-assignment] + err = e else: description = getattr(e, "description", None) err = Exception(description if description is not None else str(e)) diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index c4be429219..b10838f8c9 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -1868,7 +1868,7 @@ class ProviderConfigurations(BaseModel): if "/" not in key: key = str(ModelProviderID(key)) - return self.configurations.get(key, default) # type: ignore + return self.configurations.get(key, default) class ProviderModelBundle(BaseModel): diff --git a/api/core/helper/module_import_helper.py b/api/core/helper/module_import_helper.py index 6a2f27b8ba..2bada85582 100644 --- a/api/core/helper/module_import_helper.py +++ b/api/core/helper/module_import_helper.py @@ -20,7 +20,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz else: # Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly # FIXME: mypy does not support the type of spec.loader - spec = importlib.util.spec_from_file_location(module_name, py_file_path) # type: ignore + spec = importlib.util.spec_from_file_location(module_name, py_file_path) # type: ignore[assignment] if not spec or not spec.loader: raise Exception(f"Failed to load module {module_name} from {py_file_path!r}") if use_lazy_loader: diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 7822ed4268..36b38b7b45 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -49,62 +49,80 @@ class IndexingRunner: self.storage = storage self.model_manager = ModelManager() + def _handle_indexing_error(self, document_id: str, error: Exception) -> None: + """Handle indexing errors by updating document status.""" + logger.exception("consume document failed") + document = db.session.get(DatasetDocument, document_id) + if document: + document.indexing_status = "error" + error_message = getattr(error, "description", str(error)) + document.error = str(error_message) + document.stopped_at = naive_utc_now() + db.session.commit() + def run(self, dataset_documents: list[DatasetDocument]): """Run the indexing process.""" for dataset_document in dataset_documents: + document_id = dataset_document.id try: + # Re-query the document to ensure it's bound to the 
current session + requeried_document = db.session.get(DatasetDocument, document_id) + if not requeried_document: + logger.warning("Document not found, skipping document id: %s", document_id) + continue + # get dataset - dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first() + dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") # get the process rule stmt = select(DatasetProcessRule).where( - DatasetProcessRule.id == dataset_document.dataset_process_rule_id + DatasetProcessRule.id == requeried_document.dataset_process_rule_id ) processing_rule = db.session.scalar(stmt) if not processing_rule: raise ValueError("no process rule found") - index_type = dataset_document.doc_form + index_type = requeried_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() # extract - text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict()) + text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict()) # transform documents = self._transform( - index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict() + index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict() ) # save segment - self._load_segments(dataset, dataset_document, documents) + self._load_segments(dataset, requeried_document, documents) # load self._load( index_processor=index_processor, dataset=dataset, - dataset_document=dataset_document, + dataset_document=requeried_document, documents=documents, ) except DocumentIsPausedError: - raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}") + raise DocumentIsPausedError(f"Document paused, document id: {document_id}") except ProviderTokenNotInitError as e: - dataset_document.indexing_status = "error" - dataset_document.error = str(e.description) - dataset_document.stopped_at = naive_utc_now() - db.session.commit() + self._handle_indexing_error(document_id, e) except ObjectDeletedError: - logger.warning("Document deleted, document id: %s", dataset_document.id) + logger.warning("Document deleted, document id: %s", document_id) except Exception as e: - logger.exception("consume document failed") - dataset_document.indexing_status = "error" - dataset_document.error = str(e) - dataset_document.stopped_at = naive_utc_now() - db.session.commit() + self._handle_indexing_error(document_id, e) def run_in_splitting_status(self, dataset_document: DatasetDocument): """Run the indexing process when the index_status is splitting.""" + document_id = dataset_document.id try: + # Re-query the document to ensure it's bound to the current session + requeried_document = db.session.get(DatasetDocument, document_id) + if not requeried_document: + logger.warning("Document not found: %s", document_id) + return + # get dataset - dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first() + dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") @@ -112,57 +130,60 @@ class IndexingRunner: # get exist document_segment list and delete document_segments = ( db.session.query(DocumentSegment) - .filter_by(dataset_id=dataset.id, document_id=dataset_document.id) + .filter_by(dataset_id=dataset.id, document_id=requeried_document.id) .all() ) for document_segment in document_segments: 
db.session.delete(document_segment) - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX: # delete child chunks db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete() db.session.commit() # get the process rule - stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) + stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == requeried_document.dataset_process_rule_id) processing_rule = db.session.scalar(stmt) if not processing_rule: raise ValueError("no process rule found") - index_type = dataset_document.doc_form + index_type = requeried_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() # extract - text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict()) + text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict()) # transform documents = self._transform( - index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict() + index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict() ) # save segment - self._load_segments(dataset, dataset_document, documents) + self._load_segments(dataset, requeried_document, documents) # load self._load( - index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents + index_processor=index_processor, + dataset=dataset, + dataset_document=requeried_document, + documents=documents, ) except DocumentIsPausedError: - raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}") + raise DocumentIsPausedError(f"Document paused, document id: {document_id}") except ProviderTokenNotInitError as e: - dataset_document.indexing_status = "error" - dataset_document.error = str(e.description) - dataset_document.stopped_at = naive_utc_now() - db.session.commit() + self._handle_indexing_error(document_id, e) except Exception as e: - logger.exception("consume document failed") - dataset_document.indexing_status = "error" - dataset_document.error = str(e) - dataset_document.stopped_at = naive_utc_now() - db.session.commit() + self._handle_indexing_error(document_id, e) def run_in_indexing_status(self, dataset_document: DatasetDocument): """Run the indexing process when the index_status is indexing.""" + document_id = dataset_document.id try: + # Re-query the document to ensure it's bound to the current session + requeried_document = db.session.get(DatasetDocument, document_id) + if not requeried_document: + logger.warning("Document not found: %s", document_id) + return + # get dataset - dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first() + dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") @@ -170,7 +191,7 @@ class IndexingRunner: # get exist document_segment list and delete document_segments = ( db.session.query(DocumentSegment) - .filter_by(dataset_id=dataset.id, document_id=dataset_document.id) + .filter_by(dataset_id=dataset.id, document_id=requeried_document.id) .all() ) @@ -188,7 +209,7 @@ class IndexingRunner: "dataset_id": document_segment.dataset_id, }, ) - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX: child_chunks = document_segment.get_child_chunks() if 
child_chunks: child_documents = [] @@ -206,24 +227,20 @@ class IndexingRunner: document.children = child_documents documents.append(document) # build index - index_type = dataset_document.doc_form + index_type = requeried_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() self._load( - index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents + index_processor=index_processor, + dataset=dataset, + dataset_document=requeried_document, + documents=documents, ) except DocumentIsPausedError: - raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}") + raise DocumentIsPausedError(f"Document paused, document id: {document_id}") except ProviderTokenNotInitError as e: - dataset_document.indexing_status = "error" - dataset_document.error = str(e.description) - dataset_document.stopped_at = naive_utc_now() - db.session.commit() + self._handle_indexing_error(document_id, e) except Exception as e: - logger.exception("consume document failed") - dataset_document.indexing_status = "error" - dataset_document.error = str(e) - dataset_document.stopped_at = naive_utc_now() - db.session.commit() + self._handle_indexing_error(document_id, e) def indexing_estimate( self, @@ -398,7 +415,6 @@ class IndexingRunner: document_id=dataset_document.id, after_indexing_status="splitting", extra_update_params={ - DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs), DatasetDocument.parsing_completed_at: naive_utc_now(), }, ) @@ -738,6 +754,7 @@ class IndexingRunner: extra_update_params={ DatasetDocument.cleaning_completed_at: cur_time, DatasetDocument.splitting_completed_at: cur_time, + DatasetDocument.word_count: sum(len(doc.page_content) for doc in documents), }, ) diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index e64ac25ab1..bd893b17f1 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -100,7 +100,7 @@ class LLMGenerator: return name @classmethod - def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str): + def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str) -> Sequence[str]: output_parser = SuggestedQuestionsAfterAnswerOutputParser() format_instructions = output_parser.get_format_instructions() @@ -119,6 +119,8 @@ class LLMGenerator: prompt_messages = [UserPromptMessage(content=prompt)] + questions: Sequence[str] = [] + try: response: LLMResult = model_instance.invoke_llm( prompt_messages=list(prompt_messages), diff --git a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py index e78859cc1a..eec771181f 100644 --- a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py +++ b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py @@ -1,17 +1,26 @@ import json +import logging import re +from collections.abc import Sequence from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT +logger = logging.getLogger(__name__) + class SuggestedQuestionsAfterAnswerOutputParser: def get_format_instructions(self) -> str: return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT - def parse(self, text: str): + def parse(self, text: str) -> Sequence[str]: action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL) + questions: list[str] = [] if action_match is not 
None: - json_obj = json.loads(action_match.group(0).strip()) - else: - json_obj = [] - return json_obj + try: + json_obj = json.loads(action_match.group(0).strip()) + except json.JSONDecodeError as exc: + logger.warning("Failed to decode suggested questions payload: %s", exc) + else: + if isinstance(json_obj, list): + questions = [question for question in json_obj if isinstance(question, str)] + return questions diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 92e6b8ea60..4de4f403ce 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -2,7 +2,7 @@ import logging import os from datetime import datetime, timedelta -from langfuse import Langfuse # type: ignore +from langfuse import Langfuse from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance diff --git a/api/core/plugin/entities/parameters.py b/api/core/plugin/entities/parameters.py index 68b5c1084a..1e7f8e4c86 100644 --- a/api/core/plugin/entities/parameters.py +++ b/api/core/plugin/entities/parameters.py @@ -76,7 +76,7 @@ class PluginParameter(BaseModel): auto_generate: PluginParameterAutoGenerate | None = None template: PluginParameterTemplate | None = None required: bool = False - default: Union[float, int, str] | None = None + default: Union[float, int, str, bool] | None = None min: Union[float, int] | None = None max: Union[float, int] | None = None precision: int | None = None diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index 5095b46432..e9dc58eec8 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -180,7 +180,7 @@ class BasePluginClient: Make a request to the plugin daemon inner API and return the response as a model. 
""" response = self._request(method, path, headers, data, params, files) - return type_(**response.json()) # type: ignore + return type_(**response.json()) # type: ignore[return-value] def _request_with_plugin_daemon_response( self, diff --git a/api/core/plugin/impl/exc.py b/api/core/plugin/impl/exc.py index 23a69bd92f..e28a324217 100644 --- a/api/core/plugin/impl/exc.py +++ b/api/core/plugin/impl/exc.py @@ -40,7 +40,7 @@ class PluginDaemonBadRequestError(PluginDaemonClientSideError): description: str = "Bad Request" -class PluginInvokeError(PluginDaemonClientSideError): +class PluginInvokeError(PluginDaemonClientSideError, ValueError): description: str = "Invoke Error" def _get_error_object(self) -> Mapping: diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 99bbe615fb..45b19f25a0 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -72,6 +72,19 @@ default_retrieval_model: dict[str, Any] = { class DatasetRetrieval: def __init__(self, application_generate_entity=None): self.application_generate_entity = application_generate_entity + self._llm_usage = LLMUsage.empty_usage() + + @property + def llm_usage(self) -> LLMUsage: + return self._llm_usage.model_copy() + + def _record_usage(self, usage: LLMUsage | None) -> None: + if usage is None or usage.total_tokens <= 0: + return + if self._llm_usage.total_tokens == 0: + self._llm_usage = usage + else: + self._llm_usage = self._llm_usage.plus(usage) def retrieve( self, @@ -312,15 +325,18 @@ class DatasetRetrieval: ) tools.append(message_tool) dataset_id = None + router_usage = LLMUsage.empty_usage() if planning_strategy == PlanningStrategy.REACT_ROUTER: react_multi_dataset_router = ReactMultiDatasetRouter() - dataset_id = react_multi_dataset_router.invoke( + dataset_id, router_usage = react_multi_dataset_router.invoke( query, tools, model_config, model_instance, user_id, tenant_id ) elif planning_strategy == PlanningStrategy.ROUTER: function_call_router = FunctionCallMultiDatasetRouter() - dataset_id = function_call_router.invoke(query, tools, model_config, model_instance) + dataset_id, router_usage = function_call_router.invoke(query, tools, model_config, model_instance) + + self._record_usage(router_usage) if dataset_id: # get retrieval model config @@ -983,7 +999,8 @@ class DatasetRetrieval: ) # handle invoke result - result_text, _ = self._handle_invoke_result(invoke_result=invoke_result) + result_text, usage = self._handle_invoke_result(invoke_result=invoke_result) + self._record_usage(usage) result_text_json = parse_and_check_json_markdown(result_text, []) automatic_metadata_filters = [] diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py index de59c6380e..5f3e1a8cae 100644 --- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py +++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py @@ -2,7 +2,7 @@ from typing import Union from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMResult +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage @@ -13,15 +13,15 @@ class FunctionCallMultiDatasetRouter: dataset_tools: list[PromptMessageTool], 
model_config: ModelConfigWithCredentialsEntity, model_instance: ModelInstance, - ) -> Union[str, None]: + ) -> tuple[Union[str, None], LLMUsage]: """Given input, decided what to do. Returns: Action specifying what tool to use. """ if len(dataset_tools) == 0: - return None + return None, LLMUsage.empty_usage() elif len(dataset_tools) == 1: - return dataset_tools[0].name + return dataset_tools[0].name, LLMUsage.empty_usage() try: prompt_messages = [ @@ -34,9 +34,10 @@ class FunctionCallMultiDatasetRouter: stream=False, model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500}, ) + usage = result.usage or LLMUsage.empty_usage() if result.message.tool_calls: # get retrieval model config - return result.message.tool_calls[0].function.name - return None + return result.message.tool_calls[0].function.name, usage + return None, usage except Exception: - return None + return None, LLMUsage.empty_usage() diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index 59d36229b3..8f3bec2704 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -58,15 +58,15 @@ class ReactMultiDatasetRouter: model_instance: ModelInstance, user_id: str, tenant_id: str, - ) -> Union[str, None]: + ) -> tuple[Union[str, None], LLMUsage]: """Given input, decided what to do. Returns: Action specifying what tool to use. """ if len(dataset_tools) == 0: - return None + return None, LLMUsage.empty_usage() elif len(dataset_tools) == 1: - return dataset_tools[0].name + return dataset_tools[0].name, LLMUsage.empty_usage() try: return self._react_invoke( @@ -78,7 +78,7 @@ class ReactMultiDatasetRouter: tenant_id=tenant_id, ) except Exception: - return None + return None, LLMUsage.empty_usage() def _react_invoke( self, @@ -91,7 +91,7 @@ class ReactMultiDatasetRouter: prefix: str = PREFIX, suffix: str = SUFFIX, format_instructions: str = FORMAT_INSTRUCTIONS, - ) -> Union[str, None]: + ) -> tuple[Union[str, None], LLMUsage]: prompt: Union[list[ChatModelMessage], CompletionModelPromptTemplate] if model_config.mode == "chat": prompt = self.create_chat_prompt( @@ -120,7 +120,7 @@ class ReactMultiDatasetRouter: memory=None, model_config=model_config, ) - result_text, _ = self._invoke_llm( + result_text, usage = self._invoke_llm( completion_param=model_config.parameters, model_instance=model_instance, prompt_messages=prompt_messages, @@ -131,8 +131,8 @@ class ReactMultiDatasetRouter: output_parser = StructuredChatOutputParser() react_decision = output_parser.parse(result_text) if isinstance(react_decision, ReactAction): - return react_decision.tool - return None + return react_decision.tool, usage + return None, usage def _invoke_llm( self, diff --git a/api/core/repositories/celery_workflow_execution_repository.py b/api/core/repositories/celery_workflow_execution_repository.py index 460bb75722..c7f5942f5f 100644 --- a/api/core/repositories/celery_workflow_execution_repository.py +++ b/api/core/repositories/celery_workflow_execution_repository.py @@ -74,7 +74,7 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository): tenant_id = extract_tenant_id(user) if not tenant_id: raise ValueError("User must have a tenant_id or current_tenant_id") - self._tenant_id = tenant_id # type: ignore[assignment] # We've already checked tenant_id is not None + self._tenant_id = tenant_id # Store app context self._app_id = app_id diff --git 
a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py index 21a0b7eefe..9b8e45b1eb 100644 --- a/api/core/repositories/celery_workflow_node_execution_repository.py +++ b/api/core/repositories/celery_workflow_node_execution_repository.py @@ -81,7 +81,7 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): tenant_id = extract_tenant_id(user) if not tenant_id: raise ValueError("User must have a tenant_id or current_tenant_id") - self._tenant_id = tenant_id # type: ignore[assignment] # We've already checked tenant_id is not None + self._tenant_id = tenant_id # Store app context self._app_id = app_id diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py index 854c122331..02fcabab5d 100644 --- a/api/core/repositories/factory.py +++ b/api/core/repositories/factory.py @@ -60,7 +60,7 @@ class DifyCoreRepositoryFactory: try: repository_class = import_string(class_path) - return repository_class( # type: ignore[no-any-return] + return repository_class( session_factory=session_factory, user=user, app_id=app_id, @@ -96,7 +96,7 @@ class DifyCoreRepositoryFactory: try: repository_class = import_string(class_path) - return repository_class( # type: ignore[no-any-return] + return repository_class( session_factory=session_factory, user=user, app_id=app_id, diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index 2e94907f30..a391136a5c 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -157,7 +157,7 @@ class BuiltinToolProviderController(ToolProviderController): """ returns the tool that the provider can provide """ - return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None) # type: ignore + return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None) @property def need_credentials(self) -> bool: diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py index 8bc159bb85..5009f7ac21 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/tts.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py @@ -43,7 +43,7 @@ class TTSTool(BuiltinTool): content_text=tool_parameters.get("text"), # type: ignore user=user_id, tenant_id=self.runtime.tenant_id, - voice=voice, # type: ignore + voice=voice, ) buffer = io.BytesIO() for chunk in tts: diff --git a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py index 197b062e44..d0a41b940f 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py +++ b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py @@ -34,6 +34,7 @@ class LocaltimeToTimestampTool(BuiltinTool): yield self.create_text_message(f"{timestamp}") + # TODO: this method's type is messy @staticmethod def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None: try: diff --git a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py index babfa9bcd9..e23ae3b001 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py +++ b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py @@ -48,6 +48,6 @@ class 
TimezoneConversionTool(BuiltinTool): datetime_with_tz = input_timezone.localize(local_time) # timezone convert converted_datetime = datetime_with_tz.astimezone(output_timezone) - return converted_datetime.strftime(format=time_format) # type: ignore + return converted_datetime.strftime(time_format) except Exception as e: raise ToolInvokeError(str(e)) diff --git a/api/core/tools/mcp_tool/provider.py b/api/core/tools/mcp_tool/provider.py index 3404f5c3b4..557211c8c8 100644 --- a/api/core/tools/mcp_tool/provider.py +++ b/api/core/tools/mcp_tool/provider.py @@ -113,7 +113,7 @@ class MCPToolProviderController(ToolProviderController): """ pass - def get_tool(self, tool_name: str) -> MCPTool: # type: ignore + def get_tool(self, tool_name: str) -> MCPTool: """ return tool with given name """ @@ -136,7 +136,7 @@ class MCPToolProviderController(ToolProviderController): sse_read_timeout=self.sse_read_timeout, ) - def get_tools(self) -> list[MCPTool]: # type: ignore + def get_tools(self) -> list[MCPTool]: """ get all tools """ diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 39646b7fc8..90d5a647e9 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -26,7 +26,7 @@ class ToolLabelManager: labels = cls.filter_tool_labels(labels) if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): - provider_id = controller.provider_id # ty: ignore [unresolved-attribute] + provider_id = controller.provider_id else: raise ValueError("Unsupported tool type") @@ -51,7 +51,7 @@ class ToolLabelManager: Get tool labels """ if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): - provider_id = controller.provider_id # ty: ignore [unresolved-attribute] + provider_id = controller.provider_id elif isinstance(controller, BuiltinToolProviderController): return controller.tool_labels else: @@ -85,7 +85,7 @@ class ToolLabelManager: provider_ids = [] for controller in tool_providers: assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController) - provider_ids.append(controller.provider_id) # ty: ignore [unresolved-attribute] + provider_ids.append(controller.provider_id) labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all() diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 0641fa01fe..ff7dcc0e55 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -331,7 +331,8 @@ class ToolManager: workflow_provider_stmt = select(WorkflowToolProvider).where( WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id ) - workflow_provider = db.session.scalar(workflow_provider_stmt) + with Session(db.engine, expire_on_commit=False) as session, session.begin(): + workflow_provider = session.scalar(workflow_provider_stmt) if workflow_provider is None: raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index 915a22dd0f..f96510fb45 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -193,18 +193,18 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): DatasetDocument.enabled == True, DatasetDocument.archived == False, ) - document = 
db.session.scalar(dataset_document_stmt) # type: ignore + document = db.session.scalar(dataset_document_stmt) if dataset and document: source = RetrievalSourceMetadata( dataset_id=dataset.id, dataset_name=dataset.name, - document_id=document.id, # type: ignore - document_name=document.name, # type: ignore - data_source_type=document.data_source_type, # type: ignore + document_id=document.id, + document_name=document.name, + data_source_type=document.data_source_type, segment_id=segment.id, retriever_from=self.retriever_from, score=record.score or 0.0, - doc_metadata=document.doc_metadata, # type: ignore + doc_metadata=document.doc_metadata, ) if self.retriever_from == "dev": diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index c7ac3387e5..6eabde3991 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -62,6 +62,11 @@ class ApiBasedToolSchemaParser: root = root[ref] interface["operation"]["parameters"][i] = root for parameter in interface["operation"]["parameters"]: + # Handle complex type defaults that are not supported by PluginParameter + default_value = None + if "schema" in parameter and "default" in parameter["schema"]: + default_value = ApiBasedToolSchemaParser._sanitize_default_value(parameter["schema"]["default"]) + tool_parameter = ToolParameter( name=parameter["name"], label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]), @@ -72,9 +77,7 @@ class ApiBasedToolSchemaParser: required=parameter.get("required", False), form=ToolParameter.ToolParameterForm.LLM, llm_description=parameter.get("description"), - default=parameter["schema"]["default"] - if "schema" in parameter and "default" in parameter["schema"] - else None, + default=default_value, placeholder=I18nObject( en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "") ), @@ -134,6 +137,11 @@ class ApiBasedToolSchemaParser: required = body_schema.get("required", []) properties = body_schema.get("properties", {}) for name, property in properties.items(): + # Handle complex type defaults that are not supported by PluginParameter + default_value = ApiBasedToolSchemaParser._sanitize_default_value( + property.get("default", None) + ) + tool = ToolParameter( name=name, label=I18nObject(en_US=name, zh_Hans=name), @@ -144,12 +152,11 @@ class ApiBasedToolSchemaParser: required=name in required, form=ToolParameter.ToolParameterForm.LLM, llm_description=property.get("description", ""), - default=property.get("default", None), + default=default_value, placeholder=I18nObject( en_US=property.get("description", ""), zh_Hans=property.get("description", "") ), ) - # check if there is a type typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property) if typ: @@ -197,6 +204,22 @@ class ApiBasedToolSchemaParser: return bundles + @staticmethod + def _sanitize_default_value(value): + """ + Sanitize default values for PluginParameter compatibility. + Complex types (list, dict) are converted to None to avoid validation errors. 
+ + Args: + value: The default value from OpenAPI schema + + Returns: + None for complex types (list, dict), otherwise the original value + """ + if isinstance(value, (list, dict)): + return None + return value + @staticmethod def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None: parameter = parameter or {} @@ -217,7 +240,11 @@ class ApiBasedToolSchemaParser: return ToolParameter.ToolParameterType.STRING elif typ == "array": items = parameter.get("items") or parameter.get("schema", {}).get("items") - return ToolParameter.ToolParameterType.FILES if items and items.get("format") == "binary" else None + if items and items.get("format") == "binary": + return ToolParameter.ToolParameterType.FILES + else: + # For regular arrays, return ARRAY type instead of None + return ToolParameter.ToolParameterType.ARRAY else: return None diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index 52c16c34a0..ef6913d0bd 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -6,8 +6,8 @@ from typing import Any, cast from urllib.parse import unquote import chardet -import cloudscraper # type: ignore -from readabilipy import simple_json_from_html_string # type: ignore +import cloudscraper +from readabilipy import simple_json_from_html_string from core.helper import ssrf_proxy from core.rag.extractor import extract_processor @@ -63,8 +63,8 @@ def get_url(url: str, user_agent: str | None = None) -> str: response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) elif response.status_code == 403: scraper = cloudscraper.create_scraper() - scraper.perform_request = ssrf_proxy.make_request # type: ignore - response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) # type: ignore + scraper.perform_request = ssrf_proxy.make_request + response = scraper.get(url, headers=headers, timeout=(120, 300)) if response.status_code != 200: return f"URL returned status code {response.status_code}." 
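Note on the api/core/tools/utils/parser.py hunks above: the new _sanitize_default_value helper only drops list/dict defaults so they never reach PluginParameter validation; scalar defaults (str, int, float, bool) pass through unchanged. A minimal standalone sketch of that behavior, independent of the surrounding ToolParameter plumbing (the helper body is copied from the hunk, the assertions are illustrative):

def _sanitize_default_value(value):
    # Complex defaults (list, dict) cannot be expressed as PluginParameter.default,
    # so they are replaced with None; scalar defaults are returned unchanged.
    if isinstance(value, (list, dict)):
        return None
    return value

assert _sanitize_default_value(["a", "b"]) is None        # array default dropped
assert _sanitize_default_value({"k": "v"}) is None        # object default dropped
assert _sanitize_default_value("fallback") == "fallback"  # string default kept
assert _sanitize_default_value(3.5) == 3.5                # numeric default kept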
diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py index e9b5dab7d3..071154ee71 100644 --- a/api/core/tools/utils/yaml_utils.py +++ b/api/core/tools/utils/yaml_utils.py @@ -3,7 +3,7 @@ from functools import lru_cache from pathlib import Path from typing import Any -import yaml # type: ignore +import yaml from yaml import YAMLError logger = logging.getLogger(__name__) diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index 4d9c8895fc..d7afbc7389 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -1,6 +1,7 @@ from collections.abc import Mapping from pydantic import Field +from sqlalchemy.orm import Session from core.app.app_config.entities import VariableEntity, VariableEntityType from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager @@ -20,6 +21,7 @@ from core.tools.entities.tool_entities import ( from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db +from models.account import Account from models.model import App, AppMode from models.tools import WorkflowToolProvider from models.workflow import Workflow @@ -44,29 +46,34 @@ class WorkflowToolProviderController(ToolProviderController): @classmethod def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController": - app = db_provider.app + with Session(db.engine, expire_on_commit=False) as session, session.begin(): + provider = session.get(WorkflowToolProvider, db_provider.id) if db_provider.id else None + if not provider: + raise ValueError("workflow provider not found") + app = session.get(App, provider.app_id) + if not app: + raise ValueError("app not found") - if not app: - raise ValueError("app not found") + user = session.get(Account, provider.user_id) if provider.user_id else None - controller = WorkflowToolProviderController( - entity=ToolProviderEntity( - identity=ToolProviderIdentity( - author=db_provider.user.name if db_provider.user_id and db_provider.user else "", - name=db_provider.label, - label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label), - description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description), - icon=db_provider.icon, + controller = WorkflowToolProviderController( + entity=ToolProviderEntity( + identity=ToolProviderIdentity( + author=user.name if user else "", + name=provider.label, + label=I18nObject(en_US=provider.label, zh_Hans=provider.label), + description=I18nObject(en_US=provider.description, zh_Hans=provider.description), + icon=provider.icon, + ), + credentials_schema=[], + plugin_id=None, ), - credentials_schema=[], - plugin_id=None, - ), - provider_id=db_provider.id or "", - ) + provider_id=provider.id or "", + ) - # init tools - - controller.tools = [controller._get_db_provider_tool(db_provider, app)] + controller.tools = [ + controller._get_db_provider_tool(provider, app, session=session, user=user), + ] return controller @@ -74,7 +81,14 @@ class WorkflowToolProviderController(ToolProviderController): def provider_type(self) -> ToolProviderType: return ToolProviderType.WORKFLOW - def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool: + def _get_db_provider_tool( + self, + db_provider: WorkflowToolProvider, + app: App, + *, + session: Session, + user: Account | None = None, + ) -> WorkflowTool: """ get db 
provider tool :param db_provider: the db provider @@ -82,7 +96,7 @@ class WorkflowToolProviderController(ToolProviderController): :return: the tool """ workflow: Workflow | None = ( - db.session.query(Workflow) + session.query(Workflow) .where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version) .first() ) @@ -99,9 +113,7 @@ class WorkflowToolProviderController(ToolProviderController): variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph) def fetch_workflow_variable(variable_name: str) -> VariableEntity | None: - return next(filter(lambda x: x.variable == variable_name, variables), None) # type: ignore - - user = db_provider.user + return next(filter(lambda x: x.variable == variable_name, variables), None) workflow_tool_parameters = [] for parameter in parameters: @@ -187,22 +199,25 @@ class WorkflowToolProviderController(ToolProviderController): if self.tools is not None: return self.tools - db_providers: WorkflowToolProvider | None = ( - db.session.query(WorkflowToolProvider) - .where( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.app_id == self.provider_id, + with Session(db.engine, expire_on_commit=False) as session, session.begin(): + db_provider: WorkflowToolProvider | None = ( + session.query(WorkflowToolProvider) + .where( + WorkflowToolProvider.tenant_id == tenant_id, + WorkflowToolProvider.app_id == self.provider_id, + ) + .first() ) - .first() - ) - if not db_providers: - return [] - if not db_providers.app: - raise ValueError("app not found") + if not db_provider: + return [] - app = db_providers.app - self.tools = [self._get_db_provider_tool(db_providers, app)] + app = session.get(App, db_provider.app_id) + if not app: + raise ValueError("app not found") + + user = session.get(Account, db_provider.user_id) if db_provider.user_id else None + self.tools = [self._get_db_provider_tool(db_provider, app, session=session, user=user)] return self.tools diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 50c2327004..2cd46647a0 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -1,12 +1,14 @@ import json import logging -from collections.abc import Generator -from typing import Any +from collections.abc import Generator, Mapping, Sequence +from typing import Any, cast from flask import has_request_context from sqlalchemy import select +from sqlalchemy.orm import Session from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod +from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ( @@ -48,6 +50,7 @@ class WorkflowTool(Tool): self.workflow_entities = workflow_entities self.workflow_call_depth = workflow_call_depth self.label = label + self._latest_usage = LLMUsage.empty_usage() super().__init__(entity=entity, runtime=runtime) @@ -83,10 +86,11 @@ class WorkflowTool(Tool): assert self.runtime.invoke_from is not None user = self._resolve_user(user_id=user_id) - if user is None: raise ToolInvokeError("User not found") + self._latest_usage = LLMUsage.empty_usage() + result = generator.generate( app_model=app, workflow=workflow, @@ -110,9 +114,68 @@ class WorkflowTool(Tool): for file in files: yield self.create_file_message(file) # type: ignore + self._latest_usage = self._derive_usage_from_result(data) + yield 
self.create_text_message(json.dumps(outputs, ensure_ascii=False)) yield self.create_json_message(outputs) + @property + def latest_usage(self) -> LLMUsage: + return self._latest_usage + + @classmethod + def _derive_usage_from_result(cls, data: Mapping[str, Any]) -> LLMUsage: + usage_dict = cls._extract_usage_dict(data) + if usage_dict is not None: + return LLMUsage.from_metadata(cast(LLMUsageMetadata, dict(usage_dict))) + + total_tokens = data.get("total_tokens") + total_price = data.get("total_price") + if total_tokens is None and total_price is None: + return LLMUsage.empty_usage() + + usage_metadata: dict[str, Any] = {} + if total_tokens is not None: + try: + usage_metadata["total_tokens"] = int(str(total_tokens)) + except (TypeError, ValueError): + pass + if total_price is not None: + usage_metadata["total_price"] = str(total_price) + currency = data.get("currency") + if currency is not None: + usage_metadata["currency"] = currency + + if not usage_metadata: + return LLMUsage.empty_usage() + + return LLMUsage.from_metadata(cast(LLMUsageMetadata, usage_metadata)) + + @classmethod + def _extract_usage_dict(cls, payload: Mapping[str, Any]) -> Mapping[str, Any] | None: + usage_candidate = payload.get("usage") + if isinstance(usage_candidate, Mapping): + return usage_candidate + + metadata_candidate = payload.get("metadata") + if isinstance(metadata_candidate, Mapping): + usage_candidate = metadata_candidate.get("usage") + if isinstance(usage_candidate, Mapping): + return usage_candidate + + for value in payload.values(): + if isinstance(value, Mapping): + found = cls._extract_usage_dict(value) + if found is not None: + return found + elif isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): + for item in value: + if isinstance(item, Mapping): + found = cls._extract_usage_dict(item) + if found is not None: + return found + return None + def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool": """ fork a new tool with metadata @@ -179,16 +242,17 @@ class WorkflowTool(Tool): """ get the workflow by app id and version """ - if not version: - workflow = ( - db.session.query(Workflow) - .where(Workflow.app_id == app_id, Workflow.version != Workflow.VERSION_DRAFT) - .order_by(Workflow.created_at.desc()) - .first() - ) - else: - stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version) - workflow = db.session.scalar(stmt) + with Session(db.engine, expire_on_commit=False) as session, session.begin(): + if not version: + stmt = ( + select(Workflow) + .where(Workflow.app_id == app_id, Workflow.version != Workflow.VERSION_DRAFT) + .order_by(Workflow.created_at.desc()) + ) + workflow = session.scalars(stmt).first() + else: + stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version) + workflow = session.scalar(stmt) if not workflow: raise ValueError("workflow not found or not published") @@ -200,7 +264,8 @@ class WorkflowTool(Tool): get the app by app id """ stmt = select(App).where(App.id == app_id) - app = db.session.scalar(stmt) + with Session(db.engine, expire_on_commit=False) as session, session.begin(): + app = session.scalar(stmt) if not app: raise ValueError("app not found") diff --git a/api/core/variables/segment_group.py b/api/core/variables/segment_group.py index 0a41b64228..b363255b2c 100644 --- a/api/core/variables/segment_group.py +++ b/api/core/variables/segment_group.py @@ -4,7 +4,7 @@ from .types import SegmentType class SegmentGroup(Segment): value_type: SegmentType = SegmentType.GROUP - 
value: list[Segment] = None # type: ignore + value: list[Segment] @property def text(self): diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index 6c9e6d726e..406b4e6f93 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -19,7 +19,7 @@ class Segment(BaseModel): model_config = ConfigDict(frozen=True) value_type: SegmentType - value: Any = None + value: Any @field_validator("value_type") @classmethod @@ -74,12 +74,12 @@ class NoneSegment(Segment): class StringSegment(Segment): value_type: SegmentType = SegmentType.STRING - value: str = None # type: ignore + value: str class FloatSegment(Segment): value_type: SegmentType = SegmentType.FLOAT - value: float = None # type: ignore + value: float # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems. # The following tests cannot pass. # @@ -98,12 +98,12 @@ class FloatSegment(Segment): class IntegerSegment(Segment): value_type: SegmentType = SegmentType.INTEGER - value: int = None # type: ignore + value: int class ObjectSegment(Segment): value_type: SegmentType = SegmentType.OBJECT - value: Mapping[str, Any] = None # type: ignore + value: Mapping[str, Any] @property def text(self) -> str: @@ -136,7 +136,7 @@ class ArraySegment(Segment): class FileSegment(Segment): value_type: SegmentType = SegmentType.FILE - value: File = None # type: ignore + value: File @property def markdown(self) -> str: @@ -153,17 +153,17 @@ class FileSegment(Segment): class BooleanSegment(Segment): value_type: SegmentType = SegmentType.BOOLEAN - value: bool = None # type: ignore + value: bool class ArrayAnySegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_ANY - value: Sequence[Any] = None # type: ignore + value: Sequence[Any] class ArrayStringSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_STRING - value: Sequence[str] = None # type: ignore + value: Sequence[str] @property def text(self) -> str: @@ -175,17 +175,17 @@ class ArrayStringSegment(ArraySegment): class ArrayNumberSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_NUMBER - value: Sequence[float | int] = None # type: ignore + value: Sequence[float | int] class ArrayObjectSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_OBJECT - value: Sequence[Mapping[str, Any]] = None # type: ignore + value: Sequence[Mapping[str, Any]] class ArrayFileSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_FILE - value: Sequence[File] = None # type: ignore + value: Sequence[File] @property def markdown(self) -> str: @@ -205,7 +205,7 @@ class ArrayFileSegment(ArraySegment): class ArrayBooleanSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_BOOLEAN - value: Sequence[bool] = None # type: ignore + value: Sequence[bool] def get_segment_discriminator(v: Any) -> SegmentType | None: diff --git a/api/core/workflow/entities/__init__.py b/api/core/workflow/entities/__init__.py index be70e467a0..185f0ad620 100644 --- a/api/core/workflow/entities/__init__.py +++ b/api/core/workflow/entities/__init__.py @@ -1,3 +1,5 @@ +from ..runtime.graph_runtime_state import GraphRuntimeState +from ..runtime.variable_pool import VariablePool from .agent import AgentNodeStrategyInit from .graph_init_params import GraphInitParams from .workflow_execution import WorkflowExecution @@ -6,6 +8,8 @@ from .workflow_node_execution import WorkflowNodeExecution __all__ = [ "AgentNodeStrategyInit", "GraphInitParams", + "GraphRuntimeState", + "VariablePool", 
"WorkflowExecution", "WorkflowNodeExecution", ] diff --git a/api/core/workflow/graph/graph.py b/api/core/workflow/graph/graph.py index 20b5193875..d04724425c 100644 --- a/api/core/workflow/graph/graph.py +++ b/api/core/workflow/graph/graph.py @@ -3,11 +3,12 @@ from collections import defaultdict from collections.abc import Mapping, Sequence from typing import Protocol, cast, final -from core.workflow.enums import NodeExecutionType, NodeState, NodeType +from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType from core.workflow.nodes.base.node import Node from libs.typing import is_str, is_str_dict from .edge import Edge +from .validation import get_graph_validator logger = logging.getLogger(__name__) @@ -201,6 +202,17 @@ class Graph: return GraphBuilder(graph_cls=cls) + @classmethod + def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None: + """ + Promote nodes configured with FAIL_BRANCH error strategy to branch execution type. + + :param nodes: mapping of node ID to node instance + """ + for node in nodes.values(): + if node.error_strategy == ErrorStrategy.FAIL_BRANCH: + node.execution_type = NodeExecutionType.BRANCH + @classmethod def _mark_inactive_root_branches( cls, @@ -307,6 +319,9 @@ class Graph: # Create node instances nodes = cls._create_node_instances(node_configs_map, node_factory) + # Promote fail-branch nodes to branch execution type at graph level + cls._promote_fail_branch_nodes(nodes) + # Get root node instance root_node = nodes[root_node_id] @@ -314,7 +329,7 @@ class Graph: cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id) # Create and return the graph - return cls( + graph = cls( nodes=nodes, edges=edges, in_edges=in_edges, @@ -322,6 +337,11 @@ class Graph: root_node=root_node, ) + # Validate the graph structure using built-in validators + get_graph_validator().validate(graph) + + return graph + @property def node_ids(self) -> list[str]: """ diff --git a/api/core/workflow/graph/validation.py b/api/core/workflow/graph/validation.py new file mode 100644 index 0000000000..87aa7db2e4 --- /dev/null +++ b/api/core/workflow/graph/validation.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Protocol + +from core.workflow.enums import NodeExecutionType, NodeType + +if TYPE_CHECKING: + from .graph import Graph + + +@dataclass(frozen=True, slots=True) +class GraphValidationIssue: + """Immutable value object describing a single validation issue.""" + + code: str + message: str + node_id: str | None = None + + +class GraphValidationError(ValueError): + """Raised when graph validation fails.""" + + def __init__(self, issues: Sequence[GraphValidationIssue]) -> None: + if not issues: + raise ValueError("GraphValidationError requires at least one issue.") + self.issues: tuple[GraphValidationIssue, ...] = tuple(issues) + message = "; ".join(f"[{issue.code}] {issue.message}" for issue in self.issues) + super().__init__(message) + + +class GraphValidationRule(Protocol): + """Protocol that individual validation rules must satisfy.""" + + def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: + """Validate the provided graph and return any discovered issues.""" + ... 
+ + +@dataclass(frozen=True, slots=True) +class _EdgeEndpointValidator: + """Ensures all edges reference existing nodes.""" + + missing_node_code: str = "MISSING_NODE" + + def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: + issues: list[GraphValidationIssue] = [] + for edge in graph.edges.values(): + if edge.tail not in graph.nodes: + issues.append( + GraphValidationIssue( + code=self.missing_node_code, + message=f"Edge {edge.id} references unknown source node '{edge.tail}'.", + node_id=edge.tail, + ) + ) + if edge.head not in graph.nodes: + issues.append( + GraphValidationIssue( + code=self.missing_node_code, + message=f"Edge {edge.id} references unknown target node '{edge.head}'.", + node_id=edge.head, + ) + ) + return issues + + +@dataclass(frozen=True, slots=True) +class _RootNodeValidator: + """Validates root node invariants.""" + + invalid_root_code: str = "INVALID_ROOT" + container_entry_types: tuple[NodeType, ...] = (NodeType.ITERATION_START, NodeType.LOOP_START) + + def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: + root_node = graph.root_node + issues: list[GraphValidationIssue] = [] + if root_node.id not in graph.nodes: + issues.append( + GraphValidationIssue( + code=self.invalid_root_code, + message=f"Root node '{root_node.id}' is missing from the node registry.", + node_id=root_node.id, + ) + ) + return issues + + node_type = getattr(root_node, "node_type", None) + if root_node.execution_type != NodeExecutionType.ROOT and node_type not in self.container_entry_types: + issues.append( + GraphValidationIssue( + code=self.invalid_root_code, + message=f"Root node '{root_node.id}' must declare execution type 'root'.", + node_id=root_node.id, + ) + ) + return issues + + +@dataclass(frozen=True, slots=True) +class GraphValidator: + """Coordinates execution of graph validation rules.""" + + rules: tuple[GraphValidationRule, ...] + + def validate(self, graph: Graph) -> None: + """Validate the graph against all configured rules.""" + issues: list[GraphValidationIssue] = [] + for rule in self.rules: + issues.extend(rule.validate(graph)) + + if issues: + raise GraphValidationError(issues) + + +_DEFAULT_RULES: tuple[GraphValidationRule, ...] 
= ( + _EdgeEndpointValidator(), + _RootNodeValidator(), +) + + +def get_graph_validator() -> GraphValidator: + """Construct the validator composed of default rules.""" + return GraphValidator(_DEFAULT_RULES) diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index ce6eb33ecc..985ee5eef2 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -26,8 +26,8 @@ class AgentNodeData(BaseNodeData): class ParamsAutoGenerated(IntEnum): - CLOSE = auto() - OPEN = auto() + CLOSE = 0 + OPEN = 1 class AgentOldVersionModelFeatures(StrEnum): diff --git a/api/core/workflow/nodes/base/__init__.py b/api/core/workflow/nodes/base/__init__.py index 8cf31dc342..f83df0e323 100644 --- a/api/core/workflow/nodes/base/__init__.py +++ b/api/core/workflow/nodes/base/__init__.py @@ -1,4 +1,5 @@ from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData +from .usage_tracking_mixin import LLMUsageTrackingMixin __all__ = [ "BaseIterationNodeData", @@ -6,4 +7,5 @@ __all__ = [ "BaseLoopNodeData", "BaseLoopState", "BaseNodeData", + "LLMUsageTrackingMixin", ] diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index 5aef9d79cf..94b0d1d8bc 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -1,5 +1,6 @@ import json from abc import ABC +from builtins import type as type_ from collections.abc import Sequence from enum import StrEnum from typing import Any, Union @@ -58,10 +59,9 @@ class DefaultValue(BaseModel): raise DefaultValueTypeError(f"Invalid JSON format for value: {value}") @staticmethod - def _validate_array(value: Any, element_type: DefaultValueType) -> bool: + def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool: """Unified array type validation""" - # FIXME, type ignore here for do not find the reason mypy complain, if find the root cause, please fix it - return isinstance(value, list) and all(isinstance(x, element_type) for x in value) # type: ignore + return isinstance(value, list) and all(isinstance(x, element_type) for x in value) @staticmethod def _convert_number(value: str) -> float: diff --git a/api/core/workflow/nodes/base/usage_tracking_mixin.py b/api/core/workflow/nodes/base/usage_tracking_mixin.py new file mode 100644 index 0000000000..d9a0ef8972 --- /dev/null +++ b/api/core/workflow/nodes/base/usage_tracking_mixin.py @@ -0,0 +1,28 @@ +from core.model_runtime.entities.llm_entities import LLMUsage +from core.workflow.runtime import GraphRuntimeState + + +class LLMUsageTrackingMixin: + """Provides shared helpers for merging and recording LLM usage within workflow nodes.""" + + graph_runtime_state: GraphRuntimeState + + @staticmethod + def _merge_usage(current: LLMUsage, new_usage: LLMUsage | None) -> LLMUsage: + """Return a combined usage snapshot, preserving zero-value inputs.""" + if new_usage is None or new_usage.total_tokens <= 0: + return current + if current.total_tokens == 0: + return new_usage + return current.plus(new_usage) + + def _accumulate_usage(self, usage: LLMUsage) -> None: + """Push usage into the graph runtime accumulator for downstream reporting.""" + if usage.total_tokens <= 0: + return + + current_usage = self.graph_runtime_state.llm_usage + if current_usage.total_tokens == 0: + self.graph_runtime_state.llm_usage = usage.model_copy() + else: + self.graph_runtime_state.llm_usage = current_usage.plus(usage) diff --git 
a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index ae1061d72c..cd5f50aaab 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -10,10 +10,10 @@ from typing import Any import chardet import docx import pandas as pd -import pypandoc # type: ignore -import pypdfium2 # type: ignore -import webvtt # type: ignore -import yaml # type: ignore +import pypandoc +import pypdfium2 +import webvtt +import yaml from docx.document import Document from docx.oxml.table import CT_Tbl from docx.oxml.text.paragraph import CT_P diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 41060bd569..3a3a2290be 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, NewType, cast from flask import Flask, current_app from typing_extensions import TypeIs +from core.model_runtime.entities.llm_entities import LLMUsage from core.variables import IntegerVariable, NoneSegment from core.variables.segments import ArrayAnySegment, ArraySegment from core.variables.variables import VariableUnion @@ -34,6 +35,7 @@ from core.workflow.node_events import ( NodeRunResult, StreamCompletedEvent, ) +from core.workflow.nodes.base import LLMUsageTrackingMixin from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData @@ -58,7 +60,7 @@ logger = logging.getLogger(__name__) EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment) -class IterationNode(Node): +class IterationNode(LLMUsageTrackingMixin, Node): """ Iteration Node. 
""" @@ -118,6 +120,7 @@ class IterationNode(Node): started_at = naive_utc_now() iter_run_map: dict[str, float] = {} outputs: list[object] = [] + usage_accumulator = [LLMUsage.empty_usage()] yield IterationStartedEvent( start_at=started_at, @@ -130,22 +133,27 @@ class IterationNode(Node): iterator_list_value=iterator_list_value, outputs=outputs, iter_run_map=iter_run_map, + usage_accumulator=usage_accumulator, ) + self._accumulate_usage(usage_accumulator[0]) yield from self._handle_iteration_success( started_at=started_at, inputs=inputs, outputs=outputs, iterator_list_value=iterator_list_value, iter_run_map=iter_run_map, + usage=usage_accumulator[0], ) except IterationNodeError as e: + self._accumulate_usage(usage_accumulator[0]) yield from self._handle_iteration_failure( started_at=started_at, inputs=inputs, outputs=outputs, iterator_list_value=iterator_list_value, iter_run_map=iter_run_map, + usage=usage_accumulator[0], error=e, ) @@ -196,6 +204,7 @@ class IterationNode(Node): iterator_list_value: Sequence[object], outputs: list[object], iter_run_map: dict[str, float], + usage_accumulator: list[LLMUsage], ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: if self._node_data.is_parallel: # Parallel mode execution @@ -203,6 +212,7 @@ class IterationNode(Node): iterator_list_value=iterator_list_value, outputs=outputs, iter_run_map=iter_run_map, + usage_accumulator=usage_accumulator, ) else: # Sequential mode execution @@ -228,6 +238,9 @@ class IterationNode(Node): # Update the total tokens from this iteration self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens + usage_accumulator[0] = self._merge_usage( + usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage + ) iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() def _execute_parallel_iterations( @@ -235,6 +248,7 @@ class IterationNode(Node): iterator_list_value: Sequence[object], outputs: list[object], iter_run_map: dict[str, float], + usage_accumulator: list[LLMUsage], ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: # Initialize outputs list with None values to maintain order outputs.extend([None] * len(iterator_list_value)) @@ -245,7 +259,16 @@ class IterationNode(Node): with ThreadPoolExecutor(max_workers=max_workers) as executor: # Submit all iteration tasks future_to_index: dict[ - Future[tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]], + Future[ + tuple[ + datetime, + list[GraphNodeEventBase], + object | None, + int, + dict[str, VariableUnion], + LLMUsage, + ] + ], int, ] = {} for index, item in enumerate(iterator_list_value): @@ -264,7 +287,14 @@ class IterationNode(Node): index = future_to_index[future] try: result = future.result() - iter_start_at, events, output_value, tokens_used, conversation_snapshot = result + ( + iter_start_at, + events, + output_value, + tokens_used, + conversation_snapshot, + iteration_usage, + ) = result # Update outputs at the correct index outputs[index] = output_value @@ -276,6 +306,8 @@ class IterationNode(Node): self.graph_runtime_state.total_tokens += tokens_used iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() + usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage) + # Sync conversation variables after iteration completion self._sync_conversation_variables_from_snapshot(conversation_snapshot) @@ -303,7 +335,7 @@ class IterationNode(Node): item: object, flask_app: 
Flask, context_vars: contextvars.Context, - ) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]: + ) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion], LLMUsage]: """Execute a single iteration in parallel mode and return results.""" with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars): iter_start_at = datetime.now(UTC).replace(tzinfo=None) @@ -332,6 +364,7 @@ class IterationNode(Node): output_value, graph_engine.graph_runtime_state.total_tokens, conversation_snapshot, + graph_engine.graph_runtime_state.llm_usage, ) def _handle_iteration_success( @@ -341,6 +374,8 @@ class IterationNode(Node): outputs: list[object], iterator_list_value: Sequence[object], iter_run_map: dict[str, float], + *, + usage: LLMUsage, ) -> Generator[NodeEventBase, None, None]: # Flatten the list of lists if all outputs are lists flattened_outputs = self._flatten_outputs_if_needed(outputs) @@ -351,7 +386,9 @@ class IterationNode(Node): outputs={"output": flattened_outputs}, steps=len(iterator_list_value), metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, }, ) @@ -362,8 +399,11 @@ class IterationNode(Node): status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"output": flattened_outputs}, metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, }, + llm_usage=usage, ) ) @@ -400,6 +440,8 @@ class IterationNode(Node): outputs: list[object], iterator_list_value: Sequence[object], iter_run_map: dict[str, float], + *, + usage: LLMUsage, error: IterationNodeError, ) -> Generator[NodeEventBase, None, None]: # Flatten the list of lists if all outputs are lists (even in failure case) @@ -411,7 +453,9 @@ class IterationNode(Node): outputs={"output": flattened_outputs}, steps=len(iterator_list_value), metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, }, error=str(error), @@ -420,6 +464,12 @@ class IterationNode(Node): node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, error=str(error), + metadata={ + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, + }, + llm_usage=usage, ) ) diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 2dc3cb9320..4a63900527 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -15,14 +15,11 @@ from core.app.entities.app_invoke_entities import 
ModelConfigWithCredentialsEnti from core.entities.agent_entities import PlanningStrategy from core.entities.model_entities import ModelStatus from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.message_entities import ( - PromptMessageRole, -) -from core.model_runtime.entities.model_entities import ( - ModelFeature, - ModelType, -) +from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.message_entities import PromptMessageRole +from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.simple_prompt_transform import ModelMode from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.metadata_entities import Condition, MetadataCondition @@ -33,8 +30,14 @@ from core.variables import ( ) from core.variables.segments import ArrayObjectSegment from core.workflow.entities import GraphInitParams -from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from core.workflow.enums import ( + ErrorStrategy, + NodeType, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult +from core.workflow.nodes.base import LLMUsageTrackingMixin from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.knowledge_retrieval.template_prompts import ( @@ -80,7 +83,7 @@ default_retrieval_model = { } -class KnowledgeRetrievalNode(Node): +class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node): node_type = NodeType.KNOWLEDGE_RETRIEVAL _node_data: KnowledgeRetrievalNodeData @@ -141,7 +144,7 @@ class KnowledgeRetrievalNode(Node): def version(cls): return "1" - def _run(self) -> NodeRunResult: # type: ignore + def _run(self) -> NodeRunResult: # extract variables variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector) if not isinstance(variable, StringSegment): @@ -182,14 +185,21 @@ class KnowledgeRetrievalNode(Node): ) # retrieve knowledge + usage = LLMUsage.empty_usage() try: - results = self._fetch_dataset_retriever(node_data=self._node_data, query=query) + results, usage = self._fetch_dataset_retriever(node_data=self._node_data, query=query) outputs = {"result": ArrayObjectSegment(value=results)} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, - process_data={}, + process_data={"usage": jsonable_encoder(usage)}, outputs=outputs, # type: ignore + metadata={ + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, + }, + llm_usage=usage, ) except KnowledgeRetrievalNodeError as e: @@ -199,6 +209,7 @@ class KnowledgeRetrievalNode(Node): inputs=variables, error=str(e), error_type=type(e).__name__, + llm_usage=usage, ) # Temporary handle all exceptions from DatasetRetrieval class here. 
except Exception as e: @@ -207,11 +218,15 @@ class KnowledgeRetrievalNode(Node): inputs=variables, error=str(e), error_type=type(e).__name__, + llm_usage=usage, ) finally: db.session.close() - def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]: + def _fetch_dataset_retriever( + self, node_data: KnowledgeRetrievalNodeData, query: str + ) -> tuple[list[dict[str, Any]], LLMUsage]: + usage = LLMUsage.empty_usage() available_datasets = [] dataset_ids = node_data.dataset_ids @@ -245,9 +260,10 @@ class KnowledgeRetrievalNode(Node): if not dataset: continue available_datasets.append(dataset) - metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition( + metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition( [dataset.id for dataset in available_datasets], query, node_data ) + usage = self._merge_usage(usage, metadata_usage) all_documents = [] dataset_retrieval = DatasetRetrieval() if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: @@ -330,6 +346,8 @@ class KnowledgeRetrievalNode(Node): metadata_filter_document_ids=metadata_filter_document_ids, metadata_condition=metadata_condition, ) + usage = self._merge_usage(usage, dataset_retrieval.llm_usage) + dify_documents = [item for item in all_documents if item.provider == "dify"] external_documents = [item for item in all_documents if item.provider == "external"] retrieval_resource_list = [] @@ -406,11 +424,12 @@ class KnowledgeRetrievalNode(Node): ) for position, item in enumerate(retrieval_resource_list, start=1): item["metadata"]["position"] = position - return retrieval_resource_list + return retrieval_resource_list, usage def _get_metadata_filter_condition( self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData - ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]: + ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]: + usage = LLMUsage.empty_usage() document_query = db.session.query(Document).where( Document.dataset_id.in_(dataset_ids), Document.indexing_status == "completed", @@ -420,9 +439,12 @@ class KnowledgeRetrievalNode(Node): filters: list[Any] = [] metadata_condition = None if node_data.metadata_filtering_mode == "disabled": - return None, None + return None, None, usage elif node_data.metadata_filtering_mode == "automatic": - automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data) + automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func( + dataset_ids, query, node_data + ) + usage = self._merge_usage(usage, automatic_usage) if automatic_metadata_filters: conditions = [] for sequence, filter in enumerate(automatic_metadata_filters): @@ -443,7 +465,7 @@ class KnowledgeRetrievalNode(Node): metadata_condition = MetadataCondition( logical_operator=node_data.metadata_filtering_conditions.logical_operator if node_data.metadata_filtering_conditions - else "or", # type: ignore + else "or", conditions=conditions, ) elif node_data.metadata_filtering_mode == "manual": @@ -457,10 +479,10 @@ class KnowledgeRetrievalNode(Node): expected_value = self.graph_runtime_state.variable_pool.convert_template( expected_value ).value[0] - if expected_value.value_type in {"number", "integer", "float"}: # type: ignore - expected_value = expected_value.value # type: ignore - elif expected_value.value_type == "string": # type: ignore - expected_value = re.sub(r"[\r\n\t]+", " 
", expected_value.text).strip() # type: ignore + if expected_value.value_type in {"number", "integer", "float"}: + expected_value = expected_value.value + elif expected_value.value_type == "string": + expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() else: raise ValueError("Invalid expected metadata value type") conditions.append( @@ -487,7 +509,7 @@ class KnowledgeRetrievalNode(Node): if ( node_data.metadata_filtering_conditions and node_data.metadata_filtering_conditions.logical_operator == "and" - ): # type: ignore + ): document_query = document_query.where(and_(*filters)) else: document_query = document_query.where(or_(*filters)) @@ -496,11 +518,12 @@ class KnowledgeRetrievalNode(Node): metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore for document in documents: metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore - return metadata_filter_document_ids, metadata_condition + return metadata_filter_document_ids, metadata_condition, usage def _automatic_metadata_filter_func( self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData - ) -> list[dict[str, Any]]: + ) -> tuple[list[dict[str, Any]], LLMUsage]: + usage = LLMUsage.empty_usage() # get all metadata field stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)) metadata_fields = db.session.scalars(stmt).all() @@ -548,6 +571,7 @@ class KnowledgeRetrievalNode(Node): for event in generator: if isinstance(event, ModelInvokeCompletedEvent): result_text = event.text + usage = self._merge_usage(usage, event.usage) break result_text_json = parse_and_check_json_markdown(result_text, []) @@ -564,8 +588,8 @@ class KnowledgeRetrievalNode(Node): } ) except Exception: - return [] - return automatic_metadata_filters + return [], usage + return automatic_metadata_filters, usage def _process_metadata_filter_func( self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any] diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index e4637e6e95..1644f683bf 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -441,10 +441,14 @@ class LLMNode(Node): usage = LLMUsage.empty_usage() finish_reason = None full_text_buffer = io.StringIO() + collected_structured_output = None # Collect structured_output from streaming chunks # Consume the invoke result and handle generator exception try: for result in invoke_result: if isinstance(result, LLMResultChunkWithStructuredOutput): + # Collect structured_output from the chunk + if result.structured_output is not None: + collected_structured_output = dict(result.structured_output) yield result if isinstance(result, LLMResultChunk): contents = result.delta.message.content @@ -492,6 +496,8 @@ class LLMNode(Node): finish_reason=finish_reason, # Reasoning content for workflow variables and downstream nodes reasoning_content=reasoning_content, + # Pass structured output if collected from streaming chunks + structured_output=collected_structured_output, ) @staticmethod diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index b51790c0a2..ca39e5aa23 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -5,6 +5,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence from datetime import datetime from typing import TYPE_CHECKING, Any, Literal, cast +from 
core.model_runtime.entities.llm_entities import LLMUsage from core.variables import Segment, SegmentType from core.workflow.enums import ( ErrorStrategy, @@ -27,6 +28,7 @@ from core.workflow.node_events import ( NodeRunResult, StreamCompletedEvent, ) +from core.workflow.nodes.base import LLMUsageTrackingMixin from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData @@ -40,7 +42,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class LoopNode(Node): +class LoopNode(LLMUsageTrackingMixin, Node): """ Loop Node. """ @@ -108,7 +110,7 @@ class LoopNode(Node): raise ValueError(f"Invalid value for loop variable {loop_variable.label}") variable_selector = [self._node_id, loop_variable.label] variable = segment_to_variable(segment=processed_segment, selector=variable_selector) - self.graph_runtime_state.variable_pool.add(variable_selector, variable) + self.graph_runtime_state.variable_pool.add(variable_selector, variable.value) loop_variable_selectors[loop_variable.label] = variable_selector inputs[loop_variable.label] = processed_segment.value @@ -117,6 +119,7 @@ class LoopNode(Node): loop_duration_map: dict[str, float] = {} single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output + loop_usage = LLMUsage.empty_usage() # Start Loop event yield LoopStartedEvent( @@ -163,6 +166,9 @@ class LoopNode(Node): # Update the total tokens from this iteration cost_tokens += graph_engine.graph_runtime_state.total_tokens + # Accumulate usage from the sub-graph execution + loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage) + # Collect loop variable values after iteration single_loop_variable = {} for key, selector in loop_variable_selectors.items(): @@ -189,6 +195,7 @@ class LoopNode(Node): ) self.graph_runtime_state.total_tokens += cost_tokens + self._accumulate_usage(loop_usage) # Loop completed successfully yield LoopSucceededEvent( start_at=start_at, @@ -196,7 +203,9 @@ class LoopNode(Node): outputs=self._node_data.outputs, steps=loop_count, metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: cost_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, "completed_reason": "loop_break" if reach_break_condition else "loop_completed", WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, @@ -207,22 +216,28 @@ class LoopNode(Node): node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, }, outputs=self._node_data.outputs, inputs=inputs, + llm_usage=loop_usage, ) ) except Exception as e: + self._accumulate_usage(loop_usage) yield LoopFailedEvent( start_at=start_at, inputs=inputs, steps=loop_count, metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 
self.graph_runtime_state.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, "completed_reason": "error", WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, @@ -235,10 +250,13 @@ class LoopNode(Node): status=WorkflowNodeExecutionStatus.FAILED, error=str(e), metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, }, + llm_usage=loop_usage, ) ) diff --git a/api/core/workflow/nodes/node_factory.py b/api/core/workflow/nodes/node_factory.py index 87d1b8c435..84f63d57eb 100644 --- a/api/core/workflow/nodes/node_factory.py +++ b/api/core/workflow/nodes/node_factory.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, final from typing_extensions import override -from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType +from core.workflow.enums import NodeType from core.workflow.graph import NodeFactory from core.workflow.nodes.base.node import Node from libs.typing import is_str, is_str_dict @@ -82,8 +82,4 @@ class DifyNodeFactory(NodeFactory): raise ValueError(f"Node {node_id} missing data information") node_instance.init_node_data(node_data) - # If node has fail branch, change execution type to branch - if node_instance.error_strategy == ErrorStrategy.FAIL_BRANCH: - node_instance.execution_type = NodeExecutionType.BRANCH - return node_instance diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 2b65cc30b6..e250650fef 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -747,7 +747,7 @@ class ParameterExtractorNode(Node): if model_mode == ModelMode.CHAT: system_prompt_messages = ChatModelMessage( role=PromptMessageRole.SYSTEM, - text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str).replace("{{instructions}}", instruction), + text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction), ) user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) return [system_prompt_messages, user_prompt_message] diff --git a/api/core/workflow/nodes/parameter_extractor/prompts.py b/api/core/workflow/nodes/parameter_extractor/prompts.py index b74be8f206..1b29be4418 100644 --- a/api/core/workflow/nodes/parameter_extractor/prompts.py +++ b/api/core/workflow/nodes/parameter_extractor/prompts.py @@ -135,7 +135,7 @@ Here are the chat histories between human and assistant, inside -{{instructions}} +{instructions} """ diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 2e2c32ac93..69ab6f0718 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -6,10 +6,13 @@ from sqlalchemy.orm import Session from core.callback_handler.workflow_tool_callback_handler import 
DifyWorkflowCallbackHandler from core.file import File, FileTransferMethod +from core.model_runtime.entities.llm_entities import LLMUsage +from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.errors import ToolInvokeError from core.tools.tool_engine import ToolEngine from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.tools.workflow_as_tool.tool import WorkflowTool from core.variables.segments import ArrayAnySegment, ArrayFileSegment from core.variables.variables import ArrayAnyVariable from core.workflow.enums import ( @@ -136,13 +139,14 @@ class ToolNode(Node): try: # convert tool messages - yield from self._transform_message( + _ = yield from self._transform_message( messages=message_stream, tool_info=tool_info, parameters_for_log=parameters_for_log, user_id=self.user_id, tenant_id=self.tenant_id, node_id=self._node_id, + tool_runtime=tool_runtime, ) except ToolInvokeError as e: yield StreamCompletedEvent( @@ -236,7 +240,8 @@ class ToolNode(Node): user_id: str, tenant_id: str, node_id: str, - ) -> Generator: + tool_runtime: Tool, + ) -> Generator[NodeEventBase, None, LLMUsage]: """ Convert ToolInvokeMessages into tuple[plain_text, files] """ @@ -424,17 +429,34 @@ class ToolNode(Node): is_final=True, ) + usage = self._extract_tool_usage(tool_runtime) + + metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = { + WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, + } + if usage.total_tokens > 0: + metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens + metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price + metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency + yield StreamCompletedEvent( node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables}, - metadata={ - WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, - }, + metadata=metadata, inputs=parameters_for_log, + llm_usage=usage, ) ) + return usage + + @staticmethod + def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage: + if isinstance(tool_runtime, WorkflowTool): + return tool_runtime.latest_usage + return LLMUsage.empty_usage() + @classmethod def _extract_variable_selector_to_variable_mapping( cls, diff --git a/api/core/workflow/runtime/variable_pool.py b/api/core/workflow/runtime/variable_pool.py index 5fd6e894f1..d41a20dfd7 100644 --- a/api/core/workflow/runtime/variable_pool.py +++ b/api/core/workflow/runtime/variable_pool.py @@ -260,7 +260,7 @@ class VariablePool(BaseModel): # This ensures that we can keep the id of the system variables intact. 
if self._has(selector): continue - self.add(selector, value) # type: ignore + self.add(selector, value) @classmethod def empty(cls) -> "VariablePool": diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 08c0a1f35e..421d72a3a9 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -32,7 +32,8 @@ if [[ "${MODE}" == "worker" ]]; then exec celery -A celery_entrypoint.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \ --max-tasks-per-child ${MAX_TASKS_PER_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \ - -Q ${CELERY_QUEUES:-dataset,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation} + -Q ${CELERY_QUEUES:-dataset,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation} \ + --prefetch-multiplier=1 elif [[ "${MODE}" == "beat" ]]; then exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO} diff --git a/api/events/event_handlers/clean_when_dataset_deleted.py b/api/events/event_handlers/clean_when_dataset_deleted.py index 0f6aa0e778..1666e2e29f 100644 --- a/api/events/event_handlers/clean_when_dataset_deleted.py +++ b/api/events/event_handlers/clean_when_dataset_deleted.py @@ -6,8 +6,8 @@ from tasks.clean_dataset_task import clean_dataset_task @dataset_was_deleted.connect def handle(sender: Dataset, **kwargs): dataset = sender - assert dataset.doc_form - assert dataset.indexing_technique + if not dataset.doc_form or not dataset.indexing_technique: + return clean_dataset_task.delay( dataset.id, dataset.tenant_id, diff --git a/api/events/event_handlers/clean_when_document_deleted.py b/api/events/event_handlers/clean_when_document_deleted.py index bbc913b7cf..0add109b06 100644 --- a/api/events/event_handlers/clean_when_document_deleted.py +++ b/api/events/event_handlers/clean_when_document_deleted.py @@ -8,6 +8,6 @@ def handle(sender, **kwargs): dataset_id = kwargs.get("dataset_id") doc_form = kwargs.get("doc_form") file_id = kwargs.get("file_id") - assert dataset_id is not None - assert doc_form is not None + if not dataset_id or not doc_form: + return clean_document_task.delay(document_id, dataset_id, doc_form, file_id) diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py index 52fef4929f..82f0542b35 100644 --- a/api/extensions/ext_blueprints.py +++ b/api/extensions/ext_blueprints.py @@ -1,7 +1,12 @@ from configs import dify_config -from constants import HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN +from constants import HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN, HEADER_NAME_PASSPORT from dify_app import DifyApp +BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEADER_NAME_PASSPORT) +SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization") +AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN) +FILES_HEADERS: tuple[str, ...] 
= (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN) + def init_app(app: DifyApp): # register blueprint routers @@ -17,7 +22,7 @@ def init_app(app: DifyApp): CORS( service_api_bp, - allow_headers=["Content-Type", "Authorization", HEADER_NAME_APP_CODE], + allow_headers=list(SERVICE_API_HEADERS), methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], ) app.register_blueprint(service_api_bp) @@ -26,7 +31,7 @@ def init_app(app: DifyApp): web_bp, resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}}, supports_credentials=True, - allow_headers=["Content-Type", "Authorization", HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN], + allow_headers=list(AUTHENTICATED_HEADERS), methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], expose_headers=["X-Version", "X-Env"], ) @@ -36,7 +41,7 @@ def init_app(app: DifyApp): console_app_bp, resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}}, supports_credentials=True, - allow_headers=["Content-Type", "Authorization", HEADER_NAME_CSRF_TOKEN], + allow_headers=list(AUTHENTICATED_HEADERS), methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], expose_headers=["X-Version", "X-Env"], ) @@ -44,7 +49,7 @@ def init_app(app: DifyApp): CORS( files_bp, - allow_headers=["Content-Type", HEADER_NAME_CSRF_TOKEN], + allow_headers=list(FILES_HEADERS), methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], ) app.register_blueprint(files_bp) diff --git a/api/extensions/ext_compress.py b/api/extensions/ext_compress.py index 26ff6427be..9c3a663af4 100644 --- a/api/extensions/ext_compress.py +++ b/api/extensions/ext_compress.py @@ -7,7 +7,7 @@ def is_enabled() -> bool: def init_app(app: DifyApp): - from flask_compress import Compress # type: ignore + from flask_compress import Compress compress = Compress() compress.init_app(app) diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index e7816a2e88..ed4fe332c1 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -1,6 +1,6 @@ import json -import flask_login # type: ignore +import flask_login from flask import Response, request from flask_login import user_loaded_from_request, user_logged_in from werkzeug.exceptions import NotFound, Unauthorized diff --git a/api/extensions/ext_migrate.py b/api/extensions/ext_migrate.py index 5f862181fa..6d8f35c30d 100644 --- a/api/extensions/ext_migrate.py +++ b/api/extensions/ext_migrate.py @@ -2,7 +2,7 @@ from dify_app import DifyApp def init_app(app: DifyApp): - import flask_migrate # type: ignore + import flask_migrate from extensions.ext_database import db diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py index cb6e4849a9..20ac2503a2 100644 --- a/api/extensions/ext_otel.py +++ b/api/extensions/ext_otel.py @@ -103,7 +103,7 @@ def init_app(app: DifyApp): def shutdown_tracer(): provider = trace.get_tracer_provider() if hasattr(provider, "force_flush"): - provider.force_flush() # ty: ignore [call-non-callable] + provider.force_flush() class ExceptionLoggingHandler(logging.Handler): """Custom logging handler that creates spans for logging.exception() calls""" diff --git a/api/extensions/ext_proxy_fix.py b/api/extensions/ext_proxy_fix.py index c085aed986..fe6685f633 100644 --- a/api/extensions/ext_proxy_fix.py +++ b/api/extensions/ext_proxy_fix.py @@ -6,4 +6,4 @@ def init_app(app: DifyApp): if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED: from werkzeug.middleware.proxy_fix import ProxyFix - app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1) # type: ignore + app.wsgi_app = 
ProxyFix(app.wsgi_app, x_port=1) # type: ignore[method-assign] diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index 5ed7840211..c3aa8edf80 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -5,7 +5,7 @@ from dify_app import DifyApp def init_app(app: DifyApp): if dify_config.SENTRY_DSN: import sentry_sdk - from langfuse import parse_error # type: ignore + from langfuse import parse_error from sentry_sdk.integrations.celery import CeleryIntegration from sentry_sdk.integrations.flask import FlaskIntegration from werkzeug.exceptions import HTTPException diff --git a/api/extensions/storage/aliyun_oss_storage.py b/api/extensions/storage/aliyun_oss_storage.py index 5da4737138..2283581f62 100644 --- a/api/extensions/storage/aliyun_oss_storage.py +++ b/api/extensions/storage/aliyun_oss_storage.py @@ -1,7 +1,7 @@ import posixpath from collections.abc import Generator -import oss2 as aliyun_s3 # type: ignore +import oss2 as aliyun_s3 from configs import dify_config from extensions.storage.base_storage import BaseStorage diff --git a/api/extensions/storage/baidu_obs_storage.py b/api/extensions/storage/baidu_obs_storage.py index b94efa08be..0bb4648c0a 100644 --- a/api/extensions/storage/baidu_obs_storage.py +++ b/api/extensions/storage/baidu_obs_storage.py @@ -2,9 +2,9 @@ import base64 import hashlib from collections.abc import Generator -from baidubce.auth.bce_credentials import BceCredentials # type: ignore -from baidubce.bce_client_configuration import BceClientConfiguration # type: ignore -from baidubce.services.bos.bos_client import BosClient # type: ignore +from baidubce.auth.bce_credentials import BceCredentials +from baidubce.bce_client_configuration import BceClientConfiguration +from baidubce.services.bos.bos_client import BosClient from configs import dify_config from extensions.storage.base_storage import BaseStorage diff --git a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py index 06c528ca41..1cabc57e74 100644 --- a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py +++ b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py @@ -11,7 +11,7 @@ from collections.abc import Generator from io import BytesIO from pathlib import Path -import clickzetta # type: ignore[import] +import clickzetta from pydantic import BaseModel, model_validator from extensions.storage.base_storage import BaseStorage diff --git a/api/extensions/storage/clickzetta_volume/volume_permissions.py b/api/extensions/storage/clickzetta_volume/volume_permissions.py index 6dcf800abb..9d4ca689d8 100644 --- a/api/extensions/storage/clickzetta_volume/volume_permissions.py +++ b/api/extensions/storage/clickzetta_volume/volume_permissions.py @@ -34,7 +34,7 @@ class VolumePermissionManager: # Support two initialization methods: connection object or configuration dictionary if isinstance(connection_or_config, dict): # Create connection from configuration dictionary - import clickzetta # type: ignore[import-untyped] + import clickzetta config = connection_or_config self._connection = clickzetta.connect( diff --git a/api/extensions/storage/google_cloud_storage.py b/api/extensions/storage/google_cloud_storage.py index 7f59252f2f..d352996518 100644 --- a/api/extensions/storage/google_cloud_storage.py +++ b/api/extensions/storage/google_cloud_storage.py @@ -3,7 +3,7 @@ import io import json from collections.abc import Generator -from google.cloud import 
storage as google_cloud_storage # type: ignore +from google.cloud import storage as google_cloud_storage from configs import dify_config from extensions.storage.base_storage import BaseStorage diff --git a/api/extensions/storage/huawei_obs_storage.py b/api/extensions/storage/huawei_obs_storage.py index 3e75ecb7a9..74fed26f65 100644 --- a/api/extensions/storage/huawei_obs_storage.py +++ b/api/extensions/storage/huawei_obs_storage.py @@ -1,6 +1,6 @@ from collections.abc import Generator -from obs import ObsClient # type: ignore +from obs import ObsClient from configs import dify_config from extensions.storage.base_storage import BaseStorage diff --git a/api/extensions/storage/oracle_oci_storage.py b/api/extensions/storage/oracle_oci_storage.py index acc00cbd6b..c032803045 100644 --- a/api/extensions/storage/oracle_oci_storage.py +++ b/api/extensions/storage/oracle_oci_storage.py @@ -1,7 +1,7 @@ from collections.abc import Generator -import boto3 # type: ignore -from botocore.exceptions import ClientError # type: ignore +import boto3 +from botocore.exceptions import ClientError from configs import dify_config from extensions.storage.base_storage import BaseStorage diff --git a/api/extensions/storage/tencent_cos_storage.py b/api/extensions/storage/tencent_cos_storage.py index 9cdd3e67f7..ea5d982efc 100644 --- a/api/extensions/storage/tencent_cos_storage.py +++ b/api/extensions/storage/tencent_cos_storage.py @@ -1,6 +1,6 @@ from collections.abc import Generator -from qcloud_cos import CosConfig, CosS3Client # type: ignore +from qcloud_cos import CosConfig, CosS3Client from configs import dify_config from extensions.storage.base_storage import BaseStorage diff --git a/api/extensions/storage/volcengine_tos_storage.py b/api/extensions/storage/volcengine_tos_storage.py index 8ed8e4c170..a44959221f 100644 --- a/api/extensions/storage/volcengine_tos_storage.py +++ b/api/extensions/storage/volcengine_tos_storage.py @@ -1,6 +1,6 @@ from collections.abc import Generator -import tos # type: ignore +import tos from configs import dify_config from extensions.storage.base_storage import BaseStorage diff --git a/api/libs/external_api.py b/api/libs/external_api.py index f3ebcc4306..1a4fde960c 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -146,6 +146,6 @@ class ExternalApi(Api): kwargs["doc"] = dify_config.SWAGGER_UI_PATH if dify_config.SWAGGER_UI_ENABLED else False # manual separate call on construction and init_app to ensure configs in kwargs effective - super().__init__(app=None, *args, **kwargs) # type: ignore + super().__init__(app=None, *args, **kwargs) self.init_app(app, **kwargs) register_external_error_handlers(self) diff --git a/api/libs/gmpy2_pkcs10aep_cipher.py b/api/libs/gmpy2_pkcs10aep_cipher.py index fc38d51005..23eb8dca05 100644 --- a/api/libs/gmpy2_pkcs10aep_cipher.py +++ b/api/libs/gmpy2_pkcs10aep_cipher.py @@ -23,7 +23,7 @@ from hashlib import sha1 import Crypto.Hash.SHA1 import Crypto.Util.number -import gmpy2 # type: ignore +import gmpy2 from Crypto import Random from Crypto.Signature.pss import MGF1 from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes @@ -136,7 +136,7 @@ class PKCS1OAepCipher: # Step 3a (OS2IP) em_int = bytes_to_long(em) # Step 3b (RSAEP) - m_int = gmpy2.powmod(em_int, self._key.e, self._key.n) # ty: ignore [unresolved-attribute] + m_int = gmpy2.powmod(em_int, self._key.e, self._key.n) # Step 3c (I2OSP) c = long_to_bytes(m_int, k) return c @@ -169,7 +169,7 @@ class PKCS1OAepCipher: ct_int = bytes_to_long(ciphertext) # Step 2b 
(RSADP) # m_int = self._key._decrypt(ct_int) - m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n) # ty: ignore [unresolved-attribute] + m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n) # Complete step 2c (I2OSP) em = long_to_bytes(m_int, k) # Step 3a @@ -191,12 +191,12 @@ class PKCS1OAepCipher: # Step 3g one_pos = hLen + db[hLen:].find(b"\x01") lHash1 = db[:hLen] - invalid = bord(y) | int(one_pos < hLen) # type: ignore + invalid = bord(y) | int(one_pos < hLen) # type: ignore[arg-type] hash_compare = strxor(lHash1, lHash) for x in hash_compare: - invalid |= bord(x) # type: ignore + invalid |= bord(x) # type: ignore[arg-type] for x in db[hLen:one_pos]: - invalid |= bord(x) # type: ignore + invalid |= bord(x) # type: ignore[arg-type] if invalid != 0: raise ValueError("Incorrect decryption.") # Step 4 diff --git a/api/libs/helper.py b/api/libs/helper.py index b878141d8e..60484dd40b 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -81,6 +81,8 @@ class AvatarUrlField(fields.Raw): from models import Account if isinstance(obj, Account) and obj.avatar is not None: + if obj.avatar.startswith(("http://", "https://")): + return obj.avatar return file_helpers.get_signed_file_url(obj.avatar) return None diff --git a/api/libs/login.py b/api/libs/login.py index 5ed4bfae8f..4b8ee2d1f8 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -3,7 +3,7 @@ from functools import wraps from typing import Any from flask import current_app, g, has_request_context, request -from flask_login.config import EXEMPT_METHODS # type: ignore +from flask_login.config import EXEMPT_METHODS from werkzeug.local import LocalProxy from configs import dify_config @@ -87,7 +87,7 @@ def _get_user() -> EndUser | Account | None: if "_login_user" not in g: current_app.login_manager._load_user() # type: ignore - return g._login_user # type: ignore + return g._login_user return None diff --git a/api/libs/sendgrid.py b/api/libs/sendgrid.py index a270fa70fa..c047c54d06 100644 --- a/api/libs/sendgrid.py +++ b/api/libs/sendgrid.py @@ -1,8 +1,8 @@ import logging -import sendgrid # type: ignore +import sendgrid from python_http_client.exceptions import ForbiddenError, UnauthorizedError -from sendgrid.helpers.mail import Content, Email, Mail, To # type: ignore +from sendgrid.helpers.mail import Content, Email, Mail, To logger = logging.getLogger(__name__) diff --git a/api/libs/token.py b/api/libs/token.py index 4be25696e7..0b40f18143 100644 --- a/api/libs/token.py +++ b/api/libs/token.py @@ -12,6 +12,7 @@ from constants import ( COOKIE_NAME_CSRF_TOKEN, COOKIE_NAME_PASSPORT, COOKIE_NAME_REFRESH_TOKEN, + COOKIE_NAME_WEBAPP_ACCESS_TOKEN, HEADER_NAME_CSRF_TOKEN, HEADER_NAME_PASSPORT, ) @@ -81,6 +82,14 @@ def extract_access_token(request: Request) -> str | None: return _try_extract_from_cookie(request) or _try_extract_from_header(request) +def extract_webapp_access_token(request: Request) -> str | None: + """ + Try to extract webapp access token from cookie, then header. + """ + + return request.cookies.get(_real_cookie_name(COOKIE_NAME_WEBAPP_ACCESS_TOKEN)) or _try_extract_from_header(request) + + def extract_webapp_passport(app_code: str, request: Request) -> str | None: """ Try to extract app token from header or params. 
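# A minimal sketch (not a definitive implementation) of how the webapp-token helpers
# added to libs/token.py in this patch might be consumed from a Flask view. The
# webapp_logout() handler is purely illustrative; only extract_webapp_access_token()
# and clear_webapp_access_token_from_cookie() come from the patch itself.
from flask import Response, request

from libs.token import clear_webapp_access_token_from_cookie, extract_webapp_access_token


def webapp_logout() -> Response:
    # Reads the dedicated "webapp_access_token" cookie first, then falls back to the
    # Authorization header, mirroring the console-side extract_access_token().
    token = extract_webapp_access_token(request)
    resp = Response(status=204)
    if token is not None:
        # Expire only the webapp cookie; console cookies are left untouched.
        clear_webapp_access_token_from_cookie(resp)
    return resp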
@@ -155,6 +164,10 @@ def clear_access_token_from_cookie(response: Response, samesite: str = "Lax"): _clear_cookie(response, COOKIE_NAME_ACCESS_TOKEN, samesite) +def clear_webapp_access_token_from_cookie(response: Response, samesite: str = "Lax"): + _clear_cookie(response, COOKIE_NAME_WEBAPP_ACCESS_TOKEN, samesite) + + def clear_refresh_token_from_cookie(response: Response): _clear_cookie(response, COOKIE_NAME_REFRESH_TOKEN) diff --git a/api/migrations/versions/2025_10_14_1618-d98acf217d43_add_app_mode_for_messsage.py b/api/migrations/versions/2025_10_14_1618-d98acf217d43_add_app_mode_for_messsage.py index 7d6797fca0..910cf75838 100644 --- a/api/migrations/versions/2025_10_14_1618-d98acf217d43_add_app_mode_for_messsage.py +++ b/api/migrations/versions/2025_10_14_1618-d98acf217d43_add_app_mode_for_messsage.py @@ -22,55 +22,6 @@ def upgrade(): batch_op.add_column(sa.Column('app_mode', sa.String(length=255), nullable=True)) batch_op.create_index('message_app_mode_idx', ['app_mode'], unique=False) - conn = op.get_bind() - - # Strategy: Update in batches to minimize lock time - # For large tables (millions of rows), this prevents long-running transactions - batch_size = 10000 - - print("Starting backfill of app_mode from conversations...") - - # Use a more efficient UPDATE with JOIN - # This query updates messages.app_mode from conversations.mode - # Using string formatting for LIMIT since it's a constant - update_query = f""" - UPDATE messages m - SET app_mode = c.mode - FROM conversations c - WHERE m.conversation_id = c.id - AND m.app_mode IS NULL - AND m.id IN ( - SELECT id FROM messages - WHERE app_mode IS NULL - LIMIT {batch_size} - ) - """ - - # Execute batched updates - total_updated = 0 - iteration = 0 - while True: - iteration += 1 - result = conn.execute(sa.text(update_query)) - - # Check if result is None or has no rowcount - if result is None: - print("Warning: Query returned None, stopping backfill") - break - - rows_updated = result.rowcount if hasattr(result, 'rowcount') else 0 - total_updated += rows_updated - - if rows_updated == 0: - break - - print(f"Iteration {iteration}: Updated {rows_updated} messages (total: {total_updated})") - - # For very large tables, add a small delay to reduce load - # Uncomment if needed: import time; time.sleep(0.1) - - print(f"Backfill completed. Total messages updated: {total_updated}") - # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_10_21_1430-ae662b25d9bc_remove_builtin_template_user.py b/api/migrations/versions/2025_10_21_1430-ae662b25d9bc_remove_builtin_template_user.py new file mode 100644 index 0000000000..086a02e7c3 --- /dev/null +++ b/api/migrations/versions/2025_10_21_1430-ae662b25d9bc_remove_builtin_template_user.py @@ -0,0 +1,36 @@ +"""remove-builtin-template-user + +Revision ID: ae662b25d9bc +Revises: d98acf217d43 +Create Date: 2025-10-21 14:30:28.566192 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'ae662b25d9bc' +down_revision = 'd98acf217d43' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op: + batch_op.drop_column('updated_by') + batch_op.drop_column('created_by') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op: + batch_op.add_column(sa.Column('created_by', sa.UUID(), autoincrement=False, nullable=False)) + batch_op.add_column(sa.Column('updated_by', sa.UUID(), autoincrement=False, nullable=True)) + + # ### end Alembic commands ### diff --git a/api/models/account.py b/api/models/account.py index 86cd9e41b5..400a2c6362 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -5,7 +5,7 @@ from datetime import datetime from typing import Any, Optional import sqlalchemy as sa -from flask_login import UserMixin # type: ignore[import-untyped] +from flask_login import UserMixin from sqlalchemy import DateTime, String, func, select from sqlalchemy.orm import Mapped, Session, mapped_column from typing_extensions import deprecated diff --git a/api/models/dataset.py b/api/models/dataset.py index 5653445f2b..4a9e2688b8 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -1239,15 +1239,6 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined] language = mapped_column(db.String(255), nullable=False) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) - created_by = mapped_column(StringUUID, nullable=False) - updated_by = mapped_column(StringUUID, nullable=True) - - @property - def created_user_name(self): - account = db.session.query(Account).where(Account.id == self.created_by).first() - if account: - return account.name - return "" class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] diff --git a/api/models/model.py b/api/models/model.py index af22ab9538..8a8574e2fe 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast import sqlalchemy as sa from flask import request -from flask_login import UserMixin # type: ignore[import-untyped] +from flask_login import UserMixin from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text from sqlalchemy.orm import Mapped, Session, mapped_column diff --git a/api/models/tools.py b/api/models/tools.py index d45db3d350..12acc149b1 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -219,7 +219,7 @@ class WorkflowToolProvider(TypeBase): sa.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # name of the workflow provider name: Mapped[str] = mapped_column(String(255), nullable=False) # label of the workflow provider diff --git a/api/pyproject.toml b/api/pyproject.toml index 040d9658b3..5a9becaaef 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dify-api" -version = "1.9.1" +version = "1.9.2" requires-python = ">=3.11,<3.13" dependencies = [ diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json index bf4ec2314e..6a689b96df 100644 --- a/api/pyrightconfig.json +++ b/api/pyrightconfig.json @@ -16,7 +16,25 @@ "opentelemetry.instrumentation.requests", "opentelemetry.instrumentation.sqlalchemy", "opentelemetry.instrumentation.redis", - "opentelemetry.instrumentation.httpx" + "langfuse", + "cloudscraper", + "readabilipy", + "pypandoc", + "pypdfium2", + "webvtt", + "flask_compress", + "oss2", + 
"baidubce.auth.bce_credentials", + "baidubce.bce_client_configuration", + "baidubce.services.bos.bos_client", + "clickzetta", + "google.cloud", + "obs", + "qcloud_cos", + "tos", + "gmpy2", + "sendgrid", + "sendgrid.helpers.mail" ], "reportUnknownMemberType": "hint", "reportUnknownParameterType": "hint", @@ -28,7 +46,7 @@ "reportUnnecessaryComparison": "hint", "reportUnnecessaryIsInstance": "hint", "reportUntypedFunctionDecorator": "hint", - + "reportUnnecessaryTypeIgnoreComment": "hint", "reportAttributeAccessIssue": "hint", "pythonVersion": "3.11", "pythonPlatform": "All" diff --git a/api/repositories/factory.py b/api/repositories/factory.py index 0be9c8908c..96f9f886a4 100644 --- a/api/repositories/factory.py +++ b/api/repositories/factory.py @@ -48,7 +48,7 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory): try: repository_class = import_string(class_path) - return repository_class(session_maker=session_maker) # type: ignore[no-any-return] + return repository_class(session_maker=session_maker) except (ImportError, Exception) as e: raise RepositoryImportError( f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}" @@ -77,6 +77,6 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory): try: repository_class = import_string(class_path) - return repository_class(session_maker=session_maker) # type: ignore[no-any-return] + return repository_class(session_maker=session_maker) except (ImportError, Exception) as e: raise RepositoryImportError(f"Failed to create APIWorkflowRunRepository from '{class_path}': {e}") from e diff --git a/api/services/account_service.py b/api/services/account_service.py index cb0eb7a9dd..13c3993fb5 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -13,7 +13,7 @@ from sqlalchemy.orm import Session from werkzeug.exceptions import Unauthorized from configs import dify_config -from constants.languages import language_timezone_mapping, languages +from constants.languages import get_valid_language, language_timezone_mapping from events.tenant_event import tenant_was_created from extensions.ext_database import db from extensions.ext_redis import redis_client, redis_fallback @@ -1259,7 +1259,7 @@ class RegisterService: return f"member_invite:token:{token}" @classmethod - def setup(cls, email: str, name: str, password: str, ip_address: str): + def setup(cls, email: str, name: str, password: str, ip_address: str, language: str): """ Setup dify @@ -1269,11 +1269,10 @@ class RegisterService: :param ip_address: ip address """ try: - # Register account = AccountService.create_account( email=email, name=name, - interface_language=languages[0], + interface_language=get_valid_language(language), password=password, is_setup=True, ) @@ -1315,7 +1314,7 @@ class RegisterService: account = AccountService.create_account( email=email, name=name, - interface_language=language or languages[0], + interface_language=get_valid_language(language), password=password, is_setup=is_setup, ) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index e2915ebfbb..edb18a845a 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -7,7 +7,7 @@ from enum import StrEnum from urllib.parse import urlparse from uuid import uuid4 -import yaml # type: ignore +import yaml from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad from packaging import version @@ -563,7 +563,7 @@ class AppDslService: else: cls._append_model_config_export_data(export_data, app_model) - 
return yaml.dump(export_data, allow_unicode=True) # type: ignore + return yaml.dump(export_data, allow_unicode=True) @classmethod def _append_workflow_export_data( diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index f4047da6b8..c97d419545 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -241,9 +241,9 @@ class DatasetService: dataset.created_by = account.id dataset.updated_by = account.id dataset.tenant_id = tenant_id - dataset.embedding_model_provider = embedding_model.provider if embedding_model else None # type: ignore - dataset.embedding_model = embedding_model.model if embedding_model else None # type: ignore - dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None # type: ignore + dataset.embedding_model_provider = embedding_model.provider if embedding_model else None + dataset.embedding_model = embedding_model.model if embedding_model else None + dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None dataset.permission = permission or DatasetPermissionEnum.ONLY_ME dataset.provider = provider db.session.add(dataset) @@ -1416,6 +1416,8 @@ class DocumentService: # check document limit assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None + assert knowledge_config.data_source + assert knowledge_config.data_source.info_list.file_info_list features = FeatureService.get_features(current_user.current_tenant_id) @@ -1424,15 +1426,16 @@ class DocumentService: count = 0 if knowledge_config.data_source: if knowledge_config.data_source.info_list.data_source_type == "upload_file": - upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore + upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids count = len(upload_file_list) elif knowledge_config.data_source.info_list.data_source_type == "notion_import": - notion_info_list = knowledge_config.data_source.info_list.notion_info_list - for notion_info in notion_info_list: # type: ignore + notion_info_list = knowledge_config.data_source.info_list.notion_info_list or [] + for notion_info in notion_info_list: count = count + len(notion_info.pages) elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": website_info = knowledge_config.data_source.info_list.website_info_list - count = len(website_info.urls) # type: ignore + assert website_info + count = len(website_info.urls) batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) if features.billing.subscription.plan == "sandbox" and count > 1: @@ -1444,7 +1447,7 @@ class DocumentService: # if dataset is empty, update dataset data_source_type if not dataset.data_source_type: - dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore + dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type if not dataset.indexing_technique: if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: @@ -1481,7 +1484,7 @@ class DocumentService: knowledge_config.retrieval_model.model_dump() if knowledge_config.retrieval_model else default_retrieval_model - ) # type: ignore + ) documents = [] if knowledge_config.original_document_id: @@ -1523,11 +1526,12 @@ class DocumentService: db.session.flush() lock_name = f"add_document_lock_dataset_id_{dataset.id}" with redis_client.lock(lock_name, timeout=600): + assert dataset_process_rule position = 
DocumentService.get_documents_position(dataset.id) document_ids = [] duplicate_document_ids = [] - if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore - upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore + if knowledge_config.data_source.info_list.data_source_type == "upload_file": + upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids for file_id in upload_file_list: file = ( db.session.query(UploadFile) @@ -1540,7 +1544,7 @@ class DocumentService: raise FileNotExistsError() file_name = file.name - data_source_info = { + data_source_info: dict[str, str | bool] = { "upload_file_id": file_id, } # check duplicate @@ -1557,7 +1561,7 @@ class DocumentService: .first() ) if document: - document.dataset_process_rule_id = dataset_process_rule.id # type: ignore + document.dataset_process_rule_id = dataset_process_rule.id document.updated_at = naive_utc_now() document.created_from = created_from document.doc_form = knowledge_config.doc_form @@ -1571,8 +1575,8 @@ class DocumentService: continue document = DocumentService.build_document( dataset, - dataset_process_rule.id, # type: ignore - knowledge_config.data_source.info_list.data_source_type, # type: ignore + dataset_process_rule.id, + knowledge_config.data_source.info_list.data_source_type, knowledge_config.doc_form, knowledge_config.doc_language, data_source_info, @@ -1587,7 +1591,7 @@ class DocumentService: document_ids.append(document.id) documents.append(document) position += 1 - elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore + elif knowledge_config.data_source.info_list.data_source_type == "notion_import": notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore if not notion_info_list: raise ValueError("No notion info list found.") @@ -1616,15 +1620,15 @@ class DocumentService: "credential_id": notion_info.credential_id, "notion_workspace_id": workspace_id, "notion_page_id": page.page_id, - "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, + "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, # type: ignore "type": page.type, } # Truncate page name to 255 characters to prevent DB field length errors truncated_page_name = page.page_name[:255] if page.page_name else "nopagename" document = DocumentService.build_document( dataset, - dataset_process_rule.id, # type: ignore - knowledge_config.data_source.info_list.data_source_type, # type: ignore + dataset_process_rule.id, + knowledge_config.data_source.info_list.data_source_type, knowledge_config.doc_form, knowledge_config.doc_language, data_source_info, @@ -1644,8 +1648,8 @@ class DocumentService: # delete not selected documents if len(exist_document) > 0: clean_notion_document_task.delay(list(exist_document.values()), dataset.id) - elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore - website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore + elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": + website_info = knowledge_config.data_source.info_list.website_info_list if not website_info: raise ValueError("No website info list found.") urls = website_info.urls @@ -1663,8 +1667,8 @@ class DocumentService: document_name = url document = DocumentService.build_document( dataset, - dataset_process_rule.id, # type: ignore - 
knowledge_config.data_source.info_list.data_source_type, # type: ignore + dataset_process_rule.id, + knowledge_config.data_source.info_list.data_source_type, knowledge_config.doc_form, knowledge_config.doc_language, data_source_info, @@ -2071,7 +2075,7 @@ class DocumentService: # update document data source if document_data.data_source: file_name = "" - data_source_info = {} + data_source_info: dict[str, str | bool] = {} if document_data.data_source.info_list.data_source_type == "upload_file": if not document_data.data_source.info_list.file_info_list: raise ValueError("No file info list found.") @@ -2128,7 +2132,7 @@ class DocumentService: "url": url, "provider": website_info.provider, "job_id": website_info.job_id, - "only_main_content": website_info.only_main_content, # type: ignore + "only_main_content": website_info.only_main_content, "mode": "crawl", } document.data_source_type = document_data.data_source.info_list.data_source_type @@ -2154,7 +2158,7 @@ class DocumentService: db.session.query(DocumentSegment).filter_by(document_id=document.id).update( {DocumentSegment.status: "re_segment"} - ) # type: ignore + ) db.session.commit() # trigger async task document_indexing_update_task.delay(document.dataset_id, document.id) @@ -2164,25 +2168,26 @@ class DocumentService: def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account): assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None + assert knowledge_config.data_source features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: count = 0 - if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore + if knowledge_config.data_source.info_list.data_source_type == "upload_file": upload_file_list = ( - knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore - if knowledge_config.data_source.info_list.file_info_list # type: ignore + knowledge_config.data_source.info_list.file_info_list.file_ids + if knowledge_config.data_source.info_list.file_info_list else [] ) count = len(upload_file_list) - elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore - notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore + elif knowledge_config.data_source.info_list.data_source_type == "notion_import": + notion_info_list = knowledge_config.data_source.info_list.notion_info_list if notion_info_list: for notion_info in notion_info_list: count = count + len(notion_info.pages) - elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore - website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore + elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": + website_info = knowledge_config.data_source.info_list.website_info_list if website_info: count = len(website_info.urls) if features.billing.subscription.plan == "sandbox" and count > 1: @@ -2196,9 +2201,11 @@ class DocumentService: dataset_collection_binding_id = None retrieval_model = None if knowledge_config.indexing_technique == "high_quality": + assert knowledge_config.embedding_model_provider + assert knowledge_config.embedding_model dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - knowledge_config.embedding_model_provider, # type: ignore - knowledge_config.embedding_model, # type: ignore + 
knowledge_config.embedding_model_provider, + knowledge_config.embedding_model, ) dataset_collection_binding_id = dataset_collection_binding.id if knowledge_config.retrieval_model: @@ -2215,7 +2222,7 @@ class DocumentService: dataset = Dataset( tenant_id=tenant_id, name="", - data_source_type=knowledge_config.data_source.info_list.data_source_type, # type: ignore + data_source_type=knowledge_config.data_source.info_list.data_source_type, indexing_technique=knowledge_config.indexing_technique, created_by=account.id, embedding_model=knowledge_config.embedding_model, @@ -2224,7 +2231,7 @@ class DocumentService: retrieval_model=retrieval_model.model_dump() if retrieval_model else None, ) - db.session.add(dataset) # type: ignore + db.session.add(dataset) db.session.flush() documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 19d96cb972..148442f76e 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -174,6 +174,7 @@ class FeatureService: if dify_config.ENTERPRISE_ENABLED: features.webapp_copyright_enabled = True + features.knowledge_pipeline.publish_enabled = True cls._fulfill_params_from_workspace_info(features, tenant_id) return features diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 7fa82c6d22..337181728c 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -88,7 +88,7 @@ class HitTestingService: db.session.add(dataset_query) db.session.commit() - return cls.compact_retrieve_response(query, all_documents) # type: ignore + return cls.compact_retrieve_response(query, all_documents) @classmethod def external_retrieve( diff --git a/api/services/knowledge_service.py b/api/services/knowledge_service.py index 8df1a6ba14..02fe1d19bc 100644 --- a/api/services/knowledge_service.py +++ b/api/services/knowledge_service.py @@ -1,4 +1,4 @@ -import boto3 # type: ignore +import boto3 from configs import dify_config diff --git a/api/services/message_service.py b/api/services/message_service.py index 9fdff18622..7ed56d80f2 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -288,9 +288,10 @@ class MessageService: ) with measure_time() as timer: - questions: list[str] = LLMGenerator.generate_suggested_questions_after_answer( + questions_sequence = LLMGenerator.generate_suggested_questions_after_answer( tenant_id=app_model.tenant_id, histories=histories ) + questions: list[str] = list(questions_sequence) # get tracing instance trace_manager = TraceQueueManager(app_id=app_model.id) diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py index 5f280c9e57..b369994d2d 100644 --- a/api/services/metadata_service.py +++ b/api/services/metadata_service.py @@ -89,7 +89,7 @@ class MetadataService: document.doc_metadata = doc_metadata db.session.add(document) db.session.commit() - return metadata # type: ignore + return metadata except Exception: logger.exception("Update metadata name failed") finally: diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 2901a0d273..50ddbbf681 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -137,7 +137,7 @@ class ModelProviderService: :return: """ provider_configuration = self._get_provider_configuration(tenant_id, provider) - return 
provider_configuration.get_provider_credential(credential_id=credential_id) # type: ignore + return provider_configuration.get_provider_credential(credential_id=credential_id) def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict): """ @@ -225,7 +225,7 @@ class ModelProviderService: :return: """ provider_configuration = self._get_provider_configuration(tenant_id, provider) - return provider_configuration.get_custom_model_credential( # type: ignore + return provider_configuration.get_custom_model_credential( model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id ) diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index dec92a6faa..df5fa3e233 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -146,7 +146,7 @@ class PluginMigration: futures.append( thread_pool.submit( process_tenant, - current_app._get_current_object(), # type: ignore[attr-defined] + current_app._get_current_object(), # type: ignore tenant_id, ) ) diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index ec91f79606..908f9a2684 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -74,5 +74,4 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): "chunk_structure": pipeline_template.chunk_structure, "export_data": pipeline_template.yaml_content, "graph": graph_data, - "created_by": pipeline_template.created_user_name, } diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index b5dcec17d0..0628c8f22e 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -544,8 +544,8 @@ class BuiltinToolManageService: try: # handle include, exclude if is_filtered( - include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore - exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore + include_set=dify_config.POSITION_TOOL_INCLUDES_SET, + exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, data=provider_controller, name_func=lambda x: x.entity.identity.name, ): diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 2449536d5c..b1cc963681 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -4,6 +4,7 @@ from datetime import datetime from typing import Any from sqlalchemy import or_, select +from sqlalchemy.orm import Session from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool_provider import ToolProviderController @@ -13,6 +14,7 @@ from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurati from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db +from libs.uuid_utils import uuidv7 from models.model import App from models.tools import WorkflowToolProvider from models.workflow import Workflow @@ -63,27 +65,27 @@ class WorkflowToolManageService: if workflow is None: raise ValueError(f"Workflow not found for app {workflow_app_id}") - workflow_tool_provider = WorkflowToolProvider( 
- tenant_id=tenant_id, - user_id=user_id, - app_id=workflow_app_id, - name=name, - label=label, - icon=json.dumps(icon), - description=description, - parameter_configuration=json.dumps(parameters), - privacy_policy=privacy_policy, - version=workflow.version, - ) + with Session(db.engine, expire_on_commit=False) as session, session.begin(): + workflow_tool_provider = WorkflowToolProvider( + id=str(uuidv7()), + tenant_id=tenant_id, + user_id=user_id, + app_id=workflow_app_id, + name=name, + label=label, + icon=json.dumps(icon), + description=description, + parameter_configuration=json.dumps(parameters), + privacy_policy=privacy_policy, + version=workflow.version, + ) + session.add(workflow_tool_provider) try: WorkflowToolProviderController.from_db(workflow_tool_provider) except Exception as e: raise ValueError(str(e)) - db.session.add(workflow_tool_provider) - db.session.commit() - if labels is not None: ToolLabelManager.update_tool_labels( ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels @@ -168,7 +170,6 @@ class WorkflowToolManageService: except Exception as e: raise ValueError(str(e)) - db.session.add(workflow_tool_provider) db.session.commit() if labels is not None: diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py index d02508e4f3..4e13d2d964 100644 --- a/api/services/variable_truncator.py +++ b/api/services/variable_truncator.py @@ -17,6 +17,7 @@ from core.variables.segments import ( StringSegment, ) from core.variables.utils import dumps_with_segments +from core.workflow.nodes.variable_assigner.common.helpers import UpdatedVariable _MAX_DEPTH = 100 @@ -56,7 +57,7 @@ class UnknownTypeError(Exception): pass -JSONTypes: TypeAlias = int | float | str | list | dict | None | bool +JSONTypes: TypeAlias = int | float | str | list[object] | dict[str, object] | None | bool @dataclasses.dataclass(frozen=True) @@ -79,7 +80,7 @@ class VariableTruncator: self, string_length_limit=5000, array_element_limit: int = 20, - max_size_bytes: int = 1024_000, # 100KB + max_size_bytes: int = 1024_000, # 1000 KiB ): if string_length_limit <= 3: raise ValueError("string_length_limit should be greater than 3.") @@ -202,6 +203,9 @@ class VariableTruncator: """Recursively calculate JSON size without serialization.""" if isinstance(value, Segment): return VariableTruncator.calculate_json_size(value.value) + if isinstance(value, UpdatedVariable): + # TODO(Workflow): migrate UpdatedVariable serialization upstream and drop this fallback. + return VariableTruncator.calculate_json_size(value.model_dump(), depth=depth + 1) if depth > _MAX_DEPTH: raise MaxDepthExceededError() if isinstance(value, str): @@ -248,14 +252,14 @@ class VariableTruncator: truncated_value = value[:truncated_size] + "..." return _PartResult(truncated_value, self.calculate_json_size(truncated_value), True) - def _truncate_array(self, value: list, target_size: int) -> _PartResult[list]: + def _truncate_array(self, value: list[object], target_size: int) -> _PartResult[list[object]]: """ Truncate array with correct strategy: 1. First limit to 20 items 2. 
If still too large, truncate individual items """ - truncated_value: list[Any] = [] + truncated_value: list[object] = [] truncated = False used_size = self.calculate_json_size([]) @@ -278,7 +282,11 @@ class VariableTruncator: if used_size > target_size: break - part_result = self._truncate_json_primitives(item, target_size - used_size) + remaining_budget = target_size - used_size + if item is None or isinstance(item, (str, list, dict, bool, int, float)): + part_result = self._truncate_json_primitives(item, remaining_budget) + else: + raise UnknownTypeError(f"got unknown type {type(item)} in array truncation") truncated_value.append(part_result.value) used_size += part_result.value_size truncated = part_result.truncated @@ -369,10 +377,10 @@ class VariableTruncator: def _truncate_json_primitives(self, val: str, target_size: int) -> _PartResult[str]: ... @overload - def _truncate_json_primitives(self, val: list, target_size: int) -> _PartResult[list]: ... + def _truncate_json_primitives(self, val: list[object], target_size: int) -> _PartResult[list[object]]: ... @overload - def _truncate_json_primitives(self, val: dict, target_size: int) -> _PartResult[dict]: ... + def _truncate_json_primitives(self, val: dict[str, object], target_size: int) -> _PartResult[dict[str, object]]: ... @overload def _truncate_json_primitives(self, val: bool, target_size: int) -> _PartResult[bool]: ... # type: ignore @@ -387,10 +395,15 @@ class VariableTruncator: def _truncate_json_primitives(self, val: None, target_size: int) -> _PartResult[None]: ... def _truncate_json_primitives( - self, val: str | list | dict | bool | int | float | None, target_size: int + self, + val: UpdatedVariable | str | list[object] | dict[str, object] | bool | int | float | None, + target_size: int, ) -> _PartResult[Any]: """Truncate a value within an object to fit within budget.""" - if isinstance(val, str): + if isinstance(val, UpdatedVariable): + # TODO(Workflow): push UpdatedVariable normalization closer to its producer. 
+ return self._truncate_object(val.model_dump(), target_size) + elif isinstance(val, str): return self._truncate_string(val, target_size) elif isinstance(val, list): return self._truncate_array(val, target_size) diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index b528728364..bd95af2614 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -102,7 +102,7 @@ def batch_create_segment_to_index_task( for segment, tokens in zip(content, tokens_list): content = segment["content"] doc_id = str(uuid.uuid4()) - segment_hash = helper.generate_text_hash(content) # type: ignore + segment_hash = helper.generate_text_hash(content) max_position = ( db.session.query(func.max(DocumentSegment.position)) .where(DocumentSegment.document_id == dataset_document.id) diff --git a/api/tests/integration_tests/conftest.py b/api/tests/integration_tests/conftest.py index 9dc7b76e04..4395a9815a 100644 --- a/api/tests/integration_tests/conftest.py +++ b/api/tests/integration_tests/conftest.py @@ -58,6 +58,7 @@ def setup_account(request) -> Generator[Account, None, None]: name=name, password=secrets.token_hex(16), ip_address="localhost", + language="en-US", ) with _CACHED_APP.test_request_context(): diff --git a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py index 8a43d03a43..3984078ee9 100644 --- a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py @@ -5,11 +5,11 @@ from unittest.mock import MagicMock import pytest from _pytest.monkeypatch import MonkeyPatch -from pymochow import MochowClient # type: ignore -from pymochow.model.database import Database # type: ignore -from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState # type: ignore -from pymochow.model.schema import HNSWParams, VectorIndex # type: ignore -from pymochow.model.table import Table # type: ignore +from pymochow import MochowClient +from pymochow.model.database import Database +from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState +from pymochow.model.schema import HNSWParams, VectorIndex +from pymochow.model.table import Table class AttrDict(UserDict): diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py index 5130fcfe17..8f87d6a073 100644 --- a/api/tests/integration_tests/vdb/__mock/tcvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py @@ -3,15 +3,15 @@ from typing import Any, Union import pytest from _pytest.monkeypatch import MonkeyPatch -from tcvectordb import RPCVectorDBClient # type: ignore +from tcvectordb import RPCVectorDBClient from tcvectordb.model import enum from tcvectordb.model.collection import FilterIndexConfig -from tcvectordb.model.document import AnnSearch, Document, Filter, KeywordSearch, Rerank # type: ignore -from tcvectordb.model.enum import ReadConsistency # type: ignore -from tcvectordb.model.index import FilterIndex, HNSWParams, Index, IndexField, VectorIndex # type: ignore +from tcvectordb.model.document import AnnSearch, Document, Filter, KeywordSearch, Rerank +from tcvectordb.model.enum import ReadConsistency +from tcvectordb.model.index import FilterIndex, HNSWParams, Index, IndexField, VectorIndex from tcvectordb.rpc.model.collection import RPCCollection from tcvectordb.rpc.model.database import 
RPCDatabase -from xinference_client.types import Embedding # type: ignore +from xinference_client.types import Embedding class MockTcvectordbClass: diff --git a/api/tests/integration_tests/vdb/__mock/vikingdb.py b/api/tests/integration_tests/vdb/__mock/vikingdb.py index f351df8d5b..289c515b85 100644 --- a/api/tests/integration_tests/vdb/__mock/vikingdb.py +++ b/api/tests/integration_tests/vdb/__mock/vikingdb.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock import pytest from _pytest.monkeypatch import MonkeyPatch -from volcengine.viking_db import ( # type: ignore +from volcengine.viking_db import ( Collection, Data, DistanceType, diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index c59fc50f08..4d4e77a802 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -2299,6 +2299,7 @@ class TestRegisterService: name=admin_name, password=admin_password, ip_address=ip_address, + language="en-US", ) # Verify account was created @@ -2348,6 +2349,7 @@ class TestRegisterService: name=admin_name, password=admin_password, ip_address=ip_address, + language="en-US", ) # Verify no entities were created (rollback worked) diff --git a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py index ae0c7b7a6b..e2c616420f 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py @@ -6,6 +6,7 @@ from faker import Faker from core.tools.entities.api_entities import ToolProviderApiEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType +from libs.uuid_utils import uuidv7 from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService @@ -66,6 +67,7 @@ class TestToolTransformService: ) elif provider_type == "workflow": provider = WorkflowToolProvider( + id=str(uuidv7()), name=fake.company(), description=fake.text(max_nb_chars=100), icon='{"background": "#FF6B6B", "content": "🔧"}', @@ -758,6 +760,7 @@ class TestToolTransformService: # Create workflow tool provider provider = WorkflowToolProvider( + id=str(uuidv7()), name=fake.company(), description=fake.text(max_nb_chars=100), icon='{"background": "#FF6B6B", "content": "🔧"}', diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py index 5895f63f94..8423f1ab02 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py @@ -43,7 +43,7 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue: """Test with None input""" # The method signature expects Union[dict, list, Segment], but implementation handles None # We'll test the actual behavior by passing an empty dict instead - result = WorkflowResponseConverter._fetch_files_from_variable_value(None) # type: ignore + result = WorkflowResponseConverter._fetch_files_from_variable_value(None) assert result 
== [] def test_fetch_files_from_variable_value_with_empty_dict(self): diff --git a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py index 895ebdd751..fe9f0935d5 100644 --- a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py +++ b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py @@ -235,7 +235,7 @@ class TestIndividualHandlers: # Type assertion needed due to union type text_content = result.content[0] assert hasattr(text_content, "text") - assert text_content.text == "test answer" # type: ignore[attr-defined] + assert text_content.text == "test answer" def test_handle_call_tool_no_end_user(self): """Test call tool handler without end user""" diff --git a/api/tests/unit_tests/core/tools/utils/test_parser.py b/api/tests/unit_tests/core/tools/utils/test_parser.py index e1eab21ca4..f39158aa59 100644 --- a/api/tests/unit_tests/core/tools/utils/test_parser.py +++ b/api/tests/unit_tests/core/tools/utils/test_parser.py @@ -109,3 +109,83 @@ def test_parse_openapi_to_tool_bundle_properties_all_of(app): assert tool_bundles[0].parameters[0].llm_description == "desc prop1" # TODO: support enum in OpenAPI # assert set(tool_bundles[0].parameters[0].options) == {"option1", "option2", "option3"} + + +def test_parse_openapi_to_tool_bundle_default_value_type_casting(app): + """ + Test that default values are properly cast to match parameter types. + This addresses the issue where array default values like [] cause validation errors + when parameter type is inferred as string/number/boolean. + """ + openapi = { + "openapi": "3.0.0", + "info": {"title": "Test API", "version": "1.0.0"}, + "servers": [{"url": "https://example.com"}], + "paths": { + "/product/create": { + "post": { + "operationId": "createProduct", + "summary": "Create a product", + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "categories": { + "description": "List of category identifiers", + "default": [], + "type": "array", + "items": {"type": "string"}, + }, + "name": { + "description": "Product name", + "default": "Default Product", + "type": "string", + }, + "price": {"description": "Product price", "default": 0.0, "type": "number"}, + "available": { + "description": "Product availability", + "default": True, + "type": "boolean", + }, + }, + } + } + } + }, + "responses": {"200": {"description": "Default Response"}}, + } + } + }, + } + + with app.test_request_context(): + tool_bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi) + + assert len(tool_bundles) == 1 + bundle = tool_bundles[0] + assert len(bundle.parameters) == 4 + + # Find parameters by name + params_by_name = {param.name: param for param in bundle.parameters} + + # Check categories parameter (array type with [] default) + categories_param = params_by_name["categories"] + assert categories_param.type == "array" # Will be detected by _get_tool_parameter_type + assert categories_param.default is None # Array default [] is converted to None + + # Check name parameter (string type with string default) + name_param = params_by_name["name"] + assert name_param.type == "string" + assert name_param.default == "Default Product" + + # Check price parameter (number type with number default) + price_param = params_by_name["price"] + assert price_param.type == "number" + assert price_param.default == 0.0 + + # Check available parameter (boolean type with boolean default) + available_param = 
params_by_name["available"] + assert available_param.type == "boolean" + assert available_param.default is True diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py new file mode 100644 index 0000000000..b55d4998c4 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +import time +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool +from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType +from core.workflow.graph import Graph +from core.workflow.graph.validation import GraphValidationError +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.base.node import Node +from core.workflow.system_variable import SystemVariable +from models.enums import UserFrom + + +class _TestNode(Node): + node_type = NodeType.ANSWER + execution_type = NodeExecutionType.EXECUTABLE + + @classmethod + def version(cls) -> str: + return "test" + + def __init__( + self, + *, + id: str, + config: Mapping[str, object], + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + data = config.get("data", {}) + if isinstance(data, Mapping): + execution_type = data.get("execution_type") + if isinstance(execution_type, str): + self.execution_type = NodeExecutionType(execution_type) + self._base_node_data = BaseNodeData(title=str(data.get("title", self.id))) + self.data: dict[str, object] = {} + + def init_node_data(self, data: Mapping[str, object]) -> None: + title = str(data.get("title", self.id)) + desc = data.get("description") + error_strategy_value = data.get("error_strategy") + error_strategy: ErrorStrategy | None = None + if isinstance(error_strategy_value, ErrorStrategy): + error_strategy = error_strategy_value + elif isinstance(error_strategy_value, str): + error_strategy = ErrorStrategy(error_strategy_value) + self._base_node_data = BaseNodeData( + title=title, + desc=str(desc) if desc is not None else None, + error_strategy=error_strategy, + ) + self.data = dict(data) + + def _run(self): + raise NotImplementedError + + def _get_error_strategy(self) -> ErrorStrategy | None: + return self._base_node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._base_node_data.retry_config + + def _get_title(self) -> str: + return self._base_node_data.title + + def _get_description(self) -> str | None: + return self._base_node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._base_node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._base_node_data + + +@dataclass(slots=True) +class _SimpleNodeFactory: + graph_init_params: GraphInitParams + graph_runtime_state: GraphRuntimeState + + def create_node(self, node_config: Mapping[str, object]) -> _TestNode: + node_id = str(node_config["id"]) + node = _TestNode( + id=node_id, + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + ) + node.init_node_data(node_config.get("data", {})) + return node + + 
+@pytest.fixture +def graph_init_dependencies() -> tuple[_SimpleNodeFactory, dict[str, object]]: + graph_config: dict[str, object] = {"edges": [], "nodes": []} + init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config=graph_config, + user_id="user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ) + variable_pool = VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}) + runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + factory = _SimpleNodeFactory(graph_init_params=init_params, graph_runtime_state=runtime_state) + return factory, graph_config + + +def test_graph_initialization_runs_default_validators( + graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], +): + node_factory, graph_config = graph_init_dependencies + graph_config["nodes"] = [ + {"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}}, + {"id": "answer", "data": {"type": NodeType.ANSWER, "title": "Answer"}}, + ] + graph_config["edges"] = [ + {"source": "start", "target": "answer", "sourceHandle": "success"}, + ] + + graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + + assert graph.root_node.id == "start" + assert "answer" in graph.nodes + + +def test_graph_validation_fails_for_unknown_edge_targets( + graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], +) -> None: + node_factory, graph_config = graph_init_dependencies + graph_config["nodes"] = [ + {"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}}, + ] + graph_config["edges"] = [ + {"source": "start", "target": "missing", "sourceHandle": "success"}, + ] + + with pytest.raises(GraphValidationError) as exc: + Graph.init(graph_config=graph_config, node_factory=node_factory) + + assert any(issue.code == "MISSING_NODE" for issue in exc.value.issues) + + +def test_graph_promotes_fail_branch_nodes_to_branch_execution_type( + graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], +) -> None: + node_factory, graph_config = graph_init_dependencies + graph_config["nodes"] = [ + {"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}}, + { + "id": "branch", + "data": { + "type": NodeType.IF_ELSE, + "title": "Branch", + "error_strategy": ErrorStrategy.FAIL_BRANCH, + }, + }, + ] + graph_config["edges"] = [ + {"source": "start", "target": "branch", "sourceHandle": "success"}, + ] + + graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + + assert graph.nodes["branch"].execution_type == NodeExecutionType.BRANCH diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py index b9947d4693..b359284d00 100644 --- a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py @@ -212,7 +212,7 @@ class TestValidateResult: parameters=[ ParameterConfig( name="status", - type="select", # type: ignore + type="select", description="Status", required=True, options=["active", "inactive"], @@ -400,7 +400,7 @@ class TestTransformResult: parameters=[ ParameterConfig( name="status", - type="select", # type: 
ignore + type="select", description="Status", required=True, options=["active", "inactive"], @@ -414,7 +414,7 @@ class TestTransformResult: parameters=[ ParameterConfig( name="status", - type="select", # type: ignore + type="select", description="Status", required=True, options=["active", "inactive"], diff --git a/api/tests/unit_tests/core/workflow/test_system_variable.py b/api/tests/unit_tests/core/workflow/test_system_variable.py index 3ae5edb383..f76e81ae55 100644 --- a/api/tests/unit_tests/core/workflow/test_system_variable.py +++ b/api/tests/unit_tests/core/workflow/test_system_variable.py @@ -248,4 +248,4 @@ def test_constructor_with_extra_key(): # Test that SystemVariable should forbid extra keys with pytest.raises(ValidationError): # This should fail because there is an unexpected key. - SystemVariable(invalid_key=1) # type: ignore + SystemVariable(invalid_key=1) diff --git a/api/tests/unit_tests/libs/test_external_api.py b/api/tests/unit_tests/libs/test_external_api.py index c4c376a070..9aa157a651 100644 --- a/api/tests/unit_tests/libs/test_external_api.py +++ b/api/tests/unit_tests/libs/test_external_api.py @@ -14,36 +14,36 @@ def _create_api_app(): api = ExternalApi(bp) @api.route("/bad-request") - class Bad(Resource): # type: ignore - def get(self): # type: ignore + class Bad(Resource): + def get(self): raise BadRequest("invalid input") @api.route("/unauth") - class Unauth(Resource): # type: ignore - def get(self): # type: ignore + class Unauth(Resource): + def get(self): raise Unauthorized("auth required") @api.route("/value-error") - class ValErr(Resource): # type: ignore - def get(self): # type: ignore + class ValErr(Resource): + def get(self): raise ValueError("boom") @api.route("/quota") - class Quota(Resource): # type: ignore - def get(self): # type: ignore + class Quota(Resource): + def get(self): raise AppInvokeQuotaExceededError("quota exceeded") @api.route("/general") - class Gen(Resource): # type: ignore - def get(self): # type: ignore + class Gen(Resource): + def get(self): raise RuntimeError("oops") # Note: We avoid altering default_mediatype to keep normal error paths # Special 400 message rewrite @api.route("/json-empty") - class JsonEmpty(Resource): # type: ignore - def get(self): # type: ignore + class JsonEmpty(Resource): + def get(self): e = BadRequest() # Force the specific message the handler rewrites e.description = "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)" @@ -51,11 +51,11 @@ def _create_api_app(): # 400 mapping payload path @api.route("/param-errors") - class ParamErrors(Resource): # type: ignore - def get(self): # type: ignore + class ParamErrors(Resource): + def get(self): e = BadRequest() # Coerce a mapping description to trigger param error shaping - e.description = {"field": "is required"} # type: ignore[assignment] + e.description = {"field": "is required"} raise e app.register_blueprint(bp, url_prefix="/api") @@ -105,7 +105,7 @@ def test_external_api_param_mapping_and_quota_and_exc_info_none(): orig_exc_info = ext.sys.exc_info try: - ext.sys.exc_info = lambda: (None, None, None) # type: ignore[assignment] + ext.sys.exc_info = lambda: (None, None, None) app = _create_api_app() client = app.test_client() diff --git a/api/tests/unit_tests/libs/test_flask_utils.py b/api/tests/unit_tests/libs/test_flask_utils.py index e30433bfce..9cab0db24c 100644 --- a/api/tests/unit_tests/libs/test_flask_utils.py +++ b/api/tests/unit_tests/libs/test_flask_utils.py @@ -67,7 +67,7 @@ def 
test_current_user_not_accessible_across_threads(login_app: Flask, test_user: # without preserve_flask_contexts result["user_accessible"] = current_user.is_authenticated except Exception as e: - result["error"] = str(e) # type: ignore + result["error"] = str(e) # Run the function in a separate thread thread = threading.Thread(target=check_user_in_thread) @@ -110,7 +110,7 @@ def test_current_user_accessible_with_preserve_flask_contexts(login_app: Flask, else: result["user_accessible"] = False except Exception as e: - result["error"] = str(e) # type: ignore + result["error"] = str(e) # Run the function in a separate thread thread = threading.Thread(target=check_user_in_thread_with_manager) diff --git a/api/tests/unit_tests/libs/test_oauth_base.py b/api/tests/unit_tests/libs/test_oauth_base.py index 3e0c235fff..7b7f086dac 100644 --- a/api/tests/unit_tests/libs/test_oauth_base.py +++ b/api/tests/unit_tests/libs/test_oauth_base.py @@ -16,4 +16,4 @@ def test_oauth_base_methods_raise_not_implemented(): oauth.get_raw_user_info("token") with pytest.raises(NotImplementedError): - oauth._transform_user_info({}) # type: ignore[name-defined] + oauth._transform_user_info({}) diff --git a/api/tests/unit_tests/libs/test_token.py b/api/tests/unit_tests/libs/test_token.py index 22790fa4a6..a611d3eb0e 100644 --- a/api/tests/unit_tests/libs/test_token.py +++ b/api/tests/unit_tests/libs/test_token.py @@ -1,5 +1,5 @@ -from constants import COOKIE_NAME_ACCESS_TOKEN -from libs.token import extract_access_token +from constants import COOKIE_NAME_ACCESS_TOKEN, COOKIE_NAME_WEBAPP_ACCESS_TOKEN +from libs.token import extract_access_token, extract_webapp_access_token class MockRequest: @@ -14,10 +14,12 @@ def test_extract_access_token(): return MockRequest(headers, cookies, args) test_cases = [ - (_mock_request({"Authorization": "Bearer 123"}, {}, {}), "123"), - (_mock_request({}, {COOKIE_NAME_ACCESS_TOKEN: "123"}, {}), "123"), - (_mock_request({}, {}, {}), None), - (_mock_request({"Authorization": "Bearer_aaa 123"}, {}, {}), None), + (_mock_request({"Authorization": "Bearer 123"}, {}, {}), "123", "123"), + (_mock_request({}, {COOKIE_NAME_ACCESS_TOKEN: "123"}, {}), "123", None), + (_mock_request({}, {}, {}), None, None), + (_mock_request({"Authorization": "Bearer_aaa 123"}, {}, {}), None, None), + (_mock_request({}, {COOKIE_NAME_WEBAPP_ACCESS_TOKEN: "123"}, {}), None, "123"), ] - for request, expected in test_cases: - assert extract_access_token(request) == expected # pyright: ignore[reportArgumentType] + for request, expected_console, expected_webapp in test_cases: + assert extract_access_token(request) == expected_console # pyright: ignore[reportArgumentType] + assert extract_webapp_access_token(request) == expected_webapp # pyright: ignore[reportArgumentType] diff --git a/api/tests/unit_tests/oss/__mock/tencent_cos.py b/api/tests/unit_tests/oss/__mock/tencent_cos.py index c77c5b08f3..5189b68e87 100644 --- a/api/tests/unit_tests/oss/__mock/tencent_cos.py +++ b/api/tests/unit_tests/oss/__mock/tencent_cos.py @@ -3,8 +3,8 @@ from unittest.mock import MagicMock import pytest from _pytest.monkeypatch import MonkeyPatch -from qcloud_cos import CosS3Client # type: ignore -from qcloud_cos.streambody import StreamBody # type: ignore +from qcloud_cos import CosS3Client +from qcloud_cos.streambody import StreamBody from tests.unit_tests.oss.__mock.base import ( get_example_bucket, diff --git a/api/tests/unit_tests/oss/__mock/volcengine_tos.py b/api/tests/unit_tests/oss/__mock/volcengine_tos.py index 88df59f91c..649d93a202 
100644 --- a/api/tests/unit_tests/oss/__mock/volcengine_tos.py +++ b/api/tests/unit_tests/oss/__mock/volcengine_tos.py @@ -4,8 +4,8 @@ from unittest.mock import MagicMock import pytest from _pytest.monkeypatch import MonkeyPatch -from tos import TosClientV2 # type: ignore -from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput # type: ignore +from tos import TosClientV2 +from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput from tests.unit_tests.oss.__mock.base import ( get_example_bucket, diff --git a/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py b/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py index d289751800..303f0493bd 100644 --- a/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py +++ b/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py @@ -1,7 +1,7 @@ from unittest.mock import patch import pytest -from qcloud_cos import CosConfig # type: ignore +from qcloud_cos import CosConfig from extensions.storage.tencent_cos_storage import TencentCosStorage from tests.unit_tests.oss.__mock.base import ( diff --git a/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py b/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py index 1659205ec0..a06623a69e 100644 --- a/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py +++ b/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py @@ -1,7 +1,7 @@ from unittest.mock import patch import pytest -from tos import TosClientV2 # type: ignore +from tos import TosClientV2 from extensions.storage.volcengine_tos_storage import VolcengineTosStorage from tests.unit_tests.oss.__mock.base import ( diff --git a/api/tests/unit_tests/services/auth/test_api_key_auth_service.py b/api/tests/unit_tests/services/auth/test_api_key_auth_service.py index d23298f096..c6c3f677fb 100644 --- a/api/tests/unit_tests/services/auth/test_api_key_auth_service.py +++ b/api/tests/unit_tests/services/auth/test_api_key_auth_service.py @@ -125,13 +125,13 @@ class TestApiKeyAuthService: mock_session.commit = Mock() args_copy = self.mock_args.copy() - original_key = args_copy["credentials"]["config"]["api_key"] # type: ignore + original_key = args_copy["credentials"]["config"]["api_key"] ApiKeyAuthService.create_provider_auth(self.tenant_id, args_copy) # Verify original key is replaced with encrypted key - assert args_copy["credentials"]["config"]["api_key"] == encrypted_key # type: ignore - assert args_copy["credentials"]["config"]["api_key"] != original_key # type: ignore + assert args_copy["credentials"]["config"]["api_key"] == encrypted_key + assert args_copy["credentials"]["config"]["api_key"] != original_key # Verify encryption function is called correctly mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, original_key) @@ -268,7 +268,7 @@ class TestApiKeyAuthService: def test_validate_api_key_auth_args_empty_credentials(self): """Test API key auth args validation - empty credentials""" args = self.mock_args.copy() - args["credentials"] = None # type: ignore + args["credentials"] = None with pytest.raises(ValueError, match="credentials is required"): ApiKeyAuthService.validate_api_key_auth_args(args) @@ -284,7 +284,7 @@ class TestApiKeyAuthService: def test_validate_api_key_auth_args_missing_auth_type(self): """Test API key auth args validation - missing auth_type""" args = self.mock_args.copy() - del args["credentials"]["auth_type"] # type: ignore + del args["credentials"]["auth_type"] with pytest.raises(ValueError, 
match="auth_type is required"): ApiKeyAuthService.validate_api_key_auth_args(args) @@ -292,7 +292,7 @@ class TestApiKeyAuthService: def test_validate_api_key_auth_args_empty_auth_type(self): """Test API key auth args validation - empty auth_type""" args = self.mock_args.copy() - args["credentials"]["auth_type"] = "" # type: ignore + args["credentials"]["auth_type"] = "" with pytest.raises(ValueError, match="auth_type is required"): ApiKeyAuthService.validate_api_key_auth_args(args) @@ -380,7 +380,7 @@ class TestApiKeyAuthService: def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self): """Test API key auth args validation - dict credentials with list auth_type""" args = self.mock_args.copy() - args["credentials"]["auth_type"] = ["api_key"] # type: ignore # list instead of string + args["credentials"]["auth_type"] = ["api_key"] # Current implementation checks if auth_type exists and is truthy, list ["api_key"] is truthy # So this should not raise exception, this test should pass diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index 737202f8de..627a04bcd0 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -893,7 +893,7 @@ class TestRegisterService: mock_dify_setup.return_value = mock_dify_setup_instance # Execute test - RegisterService.setup("admin@example.com", "Admin User", "password123", "192.168.1.1") + RegisterService.setup("admin@example.com", "Admin User", "password123", "192.168.1.1", "en-US") # Verify results mock_create_account.assert_called_once_with( @@ -925,6 +925,7 @@ class TestRegisterService: "Admin User", "password123", "192.168.1.1", + "en-US", ) # Verify rollback operations were called diff --git a/api/tests/unit_tests/services/test_dataset_service_delete_dataset.py b/api/tests/unit_tests/services/test_dataset_service_delete_dataset.py new file mode 100644 index 0000000000..cc718c9997 --- /dev/null +++ b/api/tests/unit_tests/services/test_dataset_service_delete_dataset.py @@ -0,0 +1,216 @@ +from unittest.mock import Mock, patch + +import pytest + +from models.account import Account, TenantAccountRole +from models.dataset import Dataset +from services.dataset_service import DatasetService + + +class DatasetDeleteTestDataFactory: + """Factory class for creating test data and mock objects for dataset delete tests.""" + + @staticmethod + def create_dataset_mock( + dataset_id: str = "dataset-123", + tenant_id: str = "test-tenant-123", + created_by: str = "creator-456", + doc_form: str | None = None, + indexing_technique: str | None = "high_quality", + **kwargs, + ) -> Mock: + """Create a mock dataset with specified attributes.""" + dataset = Mock(spec=Dataset) + dataset.id = dataset_id + dataset.tenant_id = tenant_id + dataset.created_by = created_by + dataset.doc_form = doc_form + dataset.indexing_technique = indexing_technique + for key, value in kwargs.items(): + setattr(dataset, key, value) + return dataset + + @staticmethod + def create_user_mock( + user_id: str = "user-789", + tenant_id: str = "test-tenant-123", + role: TenantAccountRole = TenantAccountRole.ADMIN, + **kwargs, + ) -> Mock: + """Create a mock user with specified attributes.""" + user = Mock(spec=Account) + user.id = user_id + user.current_tenant_id = tenant_id + user.current_role = role + for key, value in kwargs.items(): + setattr(user, key, value) + return user + + +class TestDatasetServiceDeleteDataset: + """ + Comprehensive unit 
tests for DatasetService.delete_dataset method. + + This test suite covers all deletion scenarios including: + - Normal dataset deletion with documents + - Empty dataset deletion (no documents, doc_form is None) + - Dataset deletion with missing indexing_technique + - Permission checks + - Event handling + + This test suite provides regression protection for issue #27073. + """ + + @pytest.fixture + def mock_dataset_service_dependencies(self): + """Common mock setup for dataset service dependencies.""" + with ( + patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset, + patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm, + patch("extensions.ext_database.db.session") as mock_db, + patch("services.dataset_service.dataset_was_deleted") as mock_dataset_was_deleted, + ): + yield { + "get_dataset": mock_get_dataset, + "check_permission": mock_check_perm, + "db_session": mock_db, + "dataset_was_deleted": mock_dataset_was_deleted, + } + + def test_delete_dataset_with_documents_success(self, mock_dataset_service_dependencies): + """ + Test successful deletion of a dataset with documents. + + This test verifies: + - Dataset is retrieved correctly + - Permission check is performed + - dataset_was_deleted event is sent + - Dataset is deleted from database + - Method returns True + """ + # Arrange + dataset = DatasetDeleteTestDataFactory.create_dataset_mock( + doc_form="text_model", indexing_technique="high_quality" + ) + user = DatasetDeleteTestDataFactory.create_user_mock() + + mock_dataset_service_dependencies["get_dataset"].return_value = dataset + + # Act + result = DatasetService.delete_dataset(dataset.id, user) + + # Assert + assert result is True + mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id) + mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user) + mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset) + mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset) + mock_dataset_service_dependencies["db_session"].commit.assert_called_once() + + def test_delete_empty_dataset_success(self, mock_dataset_service_dependencies): + """ + Test successful deletion of an empty dataset (no documents, doc_form is None). + + This test verifies that: + - Empty datasets can be deleted without errors + - dataset_was_deleted event is sent (event handler will skip cleanup if doc_form is None) + - Dataset is deleted from database + - Method returns True + + This is the primary test for issue #27073 where deleting an empty dataset + caused internal server error due to assertion failure in event handlers. 
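# --- Editor's illustrative aside, not part of this patch ---------------------------
# The docstring above relies on the dataset-deletion event handler skipping index
# cleanup for empty datasets. That handler is not shown in this diff; the signal name,
# receiver, and guard below are assumptions reconstructed from these tests, sketched
# with blinker rather than the actual implementation.
from blinker import signal

dataset_was_deleted = signal("dataset-was-deleted")  # assumed signal name

@dataset_was_deleted.connect
def handle_dataset_was_deleted(sender, **kwargs):
    dataset = sender  # these tests send the deleted Dataset as the signal sender
    # An empty dataset still has doc_form / indexing_technique set to None; returning
    # early is the guard that avoids the assertion failure described in issue #27073.
    if dataset.doc_form is None or dataset.indexing_technique is None:
        return
    # ...normal index cleanup for the dataset's documents would run here...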
+ """ + # Arrange + dataset = DatasetDeleteTestDataFactory.create_dataset_mock(doc_form=None, indexing_technique=None) + user = DatasetDeleteTestDataFactory.create_user_mock() + + mock_dataset_service_dependencies["get_dataset"].return_value = dataset + + # Act + result = DatasetService.delete_dataset(dataset.id, user) + + # Assert - Verify complete deletion flow + assert result is True + mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id) + mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user) + mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset) + mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset) + mock_dataset_service_dependencies["db_session"].commit.assert_called_once() + + def test_delete_dataset_with_partial_none_values(self, mock_dataset_service_dependencies): + """ + Test deletion of dataset with partial None values. + + This test verifies that datasets with partial None values (e.g., doc_form exists + but indexing_technique is None) can be deleted successfully. The event handler + will skip cleanup if any required field is None. + + Improvement based on Gemini Code Assist suggestion: Added comprehensive assertions + to verify all core deletion operations are performed, not just event sending. + """ + # Arrange + dataset = DatasetDeleteTestDataFactory.create_dataset_mock(doc_form="text_model", indexing_technique=None) + user = DatasetDeleteTestDataFactory.create_user_mock() + + mock_dataset_service_dependencies["get_dataset"].return_value = dataset + + # Act + result = DatasetService.delete_dataset(dataset.id, user) + + # Assert - Verify complete deletion flow (Gemini suggestion implemented) + assert result is True + mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id) + mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user) + mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset) + mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset) + mock_dataset_service_dependencies["db_session"].commit.assert_called_once() + + def test_delete_dataset_with_doc_form_none_indexing_technique_exists(self, mock_dataset_service_dependencies): + """ + Test deletion of dataset where doc_form is None but indexing_technique exists. + + This edge case can occur in certain dataset configurations and should be handled + gracefully by the event handler's conditional check. 
+ """ + # Arrange + dataset = DatasetDeleteTestDataFactory.create_dataset_mock(doc_form=None, indexing_technique="high_quality") + user = DatasetDeleteTestDataFactory.create_user_mock() + + mock_dataset_service_dependencies["get_dataset"].return_value = dataset + + # Act + result = DatasetService.delete_dataset(dataset.id, user) + + # Assert - Verify complete deletion flow + assert result is True + mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id) + mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user) + mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset) + mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset) + mock_dataset_service_dependencies["db_session"].commit.assert_called_once() + + def test_delete_dataset_not_found(self, mock_dataset_service_dependencies): + """ + Test deletion attempt when dataset doesn't exist. + + This test verifies that: + - Method returns False when dataset is not found + - No deletion operations are performed + - No events are sent + """ + # Arrange + dataset_id = "non-existent-dataset" + user = DatasetDeleteTestDataFactory.create_user_mock() + + mock_dataset_service_dependencies["get_dataset"].return_value = None + + # Act + result = DatasetService.delete_dataset(dataset_id, user) + + # Assert + assert result is False + mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset_id) + mock_dataset_service_dependencies["check_permission"].assert_not_called() + mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_not_called() + mock_dataset_service_dependencies["db_session"].delete.assert_not_called() + mock_dataset_service_dependencies["db_session"].commit.assert_not_called() diff --git a/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py b/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py index 30990f8d50..e2607f0fb1 100644 --- a/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py +++ b/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py @@ -116,10 +116,10 @@ class TestSystemOAuthEncrypter: encrypter = SystemOAuthEncrypter("test_secret") with pytest.raises(Exception): # noqa: B017 - encrypter.encrypt_oauth_params(None) # type: ignore + encrypter.encrypt_oauth_params(None) with pytest.raises(Exception): # noqa: B017 - encrypter.encrypt_oauth_params("not_a_dict") # type: ignore + encrypter.encrypt_oauth_params("not_a_dict") def test_decrypt_oauth_params_basic(self): """Test basic OAuth parameters decryption""" @@ -207,12 +207,12 @@ class TestSystemOAuthEncrypter: encrypter = SystemOAuthEncrypter("test_secret") with pytest.raises(ValueError) as exc_info: - encrypter.decrypt_oauth_params(123) # type: ignore + encrypter.decrypt_oauth_params(123) assert "encrypted_data must be a string" in str(exc_info.value) with pytest.raises(ValueError) as exc_info: - encrypter.decrypt_oauth_params(None) # type: ignore + encrypter.decrypt_oauth_params(None) assert "encrypted_data must be a string" in str(exc_info.value) @@ -461,14 +461,14 @@ class TestConvenienceFunctions: """Test convenience functions with error conditions""" # Test encryption with invalid input with pytest.raises(Exception): # noqa: B017 - encrypt_system_oauth_params(None) # type: ignore + encrypt_system_oauth_params(None) # Test decryption with invalid input with pytest.raises(ValueError): decrypt_system_oauth_params("") with 
pytest.raises(ValueError): - decrypt_system_oauth_params(None) # type: ignore + decrypt_system_oauth_params(None) class TestErrorHandling: @@ -501,7 +501,7 @@ class TestErrorHandling: # Test non-string error with pytest.raises(ValueError) as exc_info: - encrypter.decrypt_oauth_params(123) # type: ignore + encrypter.decrypt_oauth_params(123) assert "encrypted_data must be a string" in str(exc_info.value) # Test invalid format error diff --git a/api/uv.lock b/api/uv.lock index e7e51acedf..066f9a58a4 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1292,7 +1292,7 @@ wheels = [ [[package]] name = "dify-api" -version = "1.9.1" +version = "1.9.2" source = { virtual = "." } dependencies = [ { name = "arize-phoenix-otel" }, diff --git a/dev/basedpyright-check b/dev/basedpyright-check index ef58ed1f57..1c87b27d6f 100755 --- a/dev/basedpyright-check +++ b/dev/basedpyright-check @@ -10,7 +10,7 @@ PATH_TO_CHECK="$1" # run basedpyright checks if [ -n "$PATH_TO_CHECK" ]; then - uv run --directory api --dev basedpyright "$PATH_TO_CHECK" + uv run --directory api --dev -- basedpyright --threads $(nproc) "$PATH_TO_CHECK" else - uv run --directory api --dev basedpyright + uv run --directory api --dev -- basedpyright --threads $(nproc) fi diff --git a/docker/.env.example b/docker/.env.example index b0e8d020ba..ca580dcb79 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -259,6 +259,18 @@ POSTGRES_MAINTENANCE_WORK_MEM=64MB # Reference: https://www.postgresql.org/docs/current/runtime-config-query.html#GUC-EFFECTIVE-CACHE-SIZE POSTGRES_EFFECTIVE_CACHE_SIZE=4096MB +# Sets the maximum allowed duration of any statement before termination. +# Default is 60000 milliseconds. +# +# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-STATEMENT-TIMEOUT +POSTGRES_STATEMENT_TIMEOUT=60000 + +# Sets the maximum allowed duration of any idle in-transaction session before termination. +# Default is 60000 milliseconds. +# +# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-IDLE-IN-TRANSACTION-SESSION-TIMEOUT +POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT=60000 + # ------------------------------ # Redis Configuration # This Redis configuration is used for caching and for pub/sub during conversation. diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 5a67c080cc..9650be90db 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -2,7 +2,7 @@ x-shared-env: &shared-api-worker-env services: # API service api: - image: langgenius/dify-api:1.9.1 + image: langgenius/dify-api:1.9.2 restart: always environment: # Use the shared environment variables. @@ -24,13 +24,6 @@ services: volumes: # Mount the storage directory to the container, for storing user files. - ./volumes/app/storage:/app/api/storage - # TODO: Remove this entrypoint override when weaviate-client 4.17.0 is included in the next Dify release - entrypoint: - - /bin/bash - - -c - - | - uv pip install --system weaviate-client==4.17.0 - exec /bin/bash /app/api/docker/entrypoint.sh networks: - ssrf_proxy_network - default @@ -38,7 +31,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:1.9.1 + image: langgenius/dify-api:1.9.2 restart: always environment: # Use the shared environment variables. @@ -58,13 +51,6 @@ services: volumes: # Mount the storage directory to the container, for storing user files. 
- ./volumes/app/storage:/app/api/storage - # TODO: Remove this entrypoint override when weaviate-client 4.17.0 is included in the next Dify release - entrypoint: - - /bin/bash - - -c - - | - uv pip install --system weaviate-client==4.17.0 - exec /bin/bash /app/api/docker/entrypoint.sh networks: - ssrf_proxy_network - default @@ -72,7 +58,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.9.1 + image: langgenius/dify-api:1.9.2 restart: always environment: # Use the shared environment variables. @@ -90,7 +76,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.9.1 + image: langgenius/dify-web:1.9.2 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -129,6 +115,8 @@ services: -c 'work_mem=${POSTGRES_WORK_MEM:-4MB}' -c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}' -c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}' + -c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-60000}' + -c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-60000}' volumes: - ./volumes/db/data:/var/lib/postgresql/data healthcheck: @@ -191,7 +179,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.3.0-local + image: langgenius/dify-plugin-daemon:0.3.3-local restart: always environment: # Use the shared environment variables. diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index ebc619a50f..9a1b9b53ba 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -15,6 +15,8 @@ services: -c 'work_mem=${POSTGRES_WORK_MEM:-4MB}' -c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}' -c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}' + -c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-60000}' + -c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-60000}' volumes: - ${PGDATA_HOST_VOLUME:-./volumes/db/data}:/var/lib/postgresql/data ports: @@ -85,7 +87,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.3.0-local + image: langgenius/dify-plugin-daemon:0.3.3-local restart: always env_file: - ./middleware.env diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 421b733e2b..d2ca6b859e 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -68,6 +68,8 @@ x-shared-env: &shared-api-worker-env POSTGRES_WORK_MEM: ${POSTGRES_WORK_MEM:-4MB} POSTGRES_MAINTENANCE_WORK_MEM: ${POSTGRES_MAINTENANCE_WORK_MEM:-64MB} POSTGRES_EFFECTIVE_CACHE_SIZE: ${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB} + POSTGRES_STATEMENT_TIMEOUT: ${POSTGRES_STATEMENT_TIMEOUT:-60000} + POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT: ${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-60000} REDIS_HOST: ${REDIS_HOST:-redis} REDIS_PORT: ${REDIS_PORT:-6379} REDIS_USERNAME: ${REDIS_USERNAME:-} @@ -609,7 +611,7 @@ x-shared-env: &shared-api-worker-env services: # API service api: - image: langgenius/dify-api:1.9.1 + image: langgenius/dify-api:1.9.2 restart: always environment: # Use the shared environment variables. @@ -631,13 +633,6 @@ services: volumes: # Mount the storage directory to the container, for storing user files. 
- ./volumes/app/storage:/app/api/storage - # TODO: Remove this entrypoint override when weaviate-client 4.17.0 is included in the next Dify release - entrypoint: - - /bin/bash - - -c - - | - uv pip install --system weaviate-client==4.17.0 - exec /bin/bash /app/api/docker/entrypoint.sh networks: - ssrf_proxy_network - default @@ -645,7 +640,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:1.9.1 + image: langgenius/dify-api:1.9.2 restart: always environment: # Use the shared environment variables. @@ -665,13 +660,6 @@ services: volumes: # Mount the storage directory to the container, for storing user files. - ./volumes/app/storage:/app/api/storage - # TODO: Remove this entrypoint override when weaviate-client 4.17.0 is included in the next Dify release - entrypoint: - - /bin/bash - - -c - - | - uv pip install --system weaviate-client==4.17.0 - exec /bin/bash /app/api/docker/entrypoint.sh networks: - ssrf_proxy_network - default @@ -679,7 +667,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.9.1 + image: langgenius/dify-api:1.9.2 restart: always environment: # Use the shared environment variables. @@ -697,7 +685,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.9.1 + image: langgenius/dify-web:1.9.2 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -736,6 +724,8 @@ services: -c 'work_mem=${POSTGRES_WORK_MEM:-4MB}' -c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}' -c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}' + -c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-60000}' + -c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-60000}' volumes: - ./volumes/db/data:/var/lib/postgresql/data healthcheck: @@ -798,7 +788,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.3.0-local + image: langgenius/dify-plugin-daemon:0.3.3-local restart: always environment: # Use the shared environment variables. diff --git a/docker/middleware.env.example b/docker/middleware.env.example index 2eba62f594..c9bb8c0528 100644 --- a/docker/middleware.env.example +++ b/docker/middleware.env.example @@ -40,6 +40,18 @@ POSTGRES_MAINTENANCE_WORK_MEM=64MB # Reference: https://www.postgresql.org/docs/current/runtime-config-query.html#GUC-EFFECTIVE-CACHE-SIZE POSTGRES_EFFECTIVE_CACHE_SIZE=4096MB +# Sets the maximum allowed duration of any statement before termination. +# Default is 60000 milliseconds. +# +# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-STATEMENT-TIMEOUT +POSTGRES_STATEMENT_TIMEOUT=60000 + +# Sets the maximum allowed duration of any idle in-transaction session before termination. +# Default is 60000 milliseconds. 
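# --- Editor's illustrative aside, not part of this patch ---
# Once the stack is running, the two timeouts introduced above can be verified from
# psql; the service name `db` and user `postgres` are assumptions about the compose
# setup, so adjust them to match your deployment:
#   docker compose exec db psql -U postgres -c 'SHOW statement_timeout;'
#   docker compose exec db psql -U postgres -c 'SHOW idle_in_transaction_session_timeout;'
# With the 60000 ms defaults above, each SHOW command reports the equivalent of 60 s.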
+# +# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-IDLE-IN-TRANSACTION-SESSION-TIMEOUT +POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT=60000 + # ----------------------------- # Environment Variables for redis Service # ----------------------------- diff --git a/web/.storybook/__mocks__/context-block.tsx b/web/.storybook/__mocks__/context-block.tsx new file mode 100644 index 0000000000..8a9d8625cc --- /dev/null +++ b/web/.storybook/__mocks__/context-block.tsx @@ -0,0 +1,4 @@ +// Mock for context-block plugin to avoid circular dependency in Storybook +export const ContextBlockNode = null +export const ContextBlockReplacementBlock = null +export default null diff --git a/web/.storybook/__mocks__/history-block.tsx b/web/.storybook/__mocks__/history-block.tsx new file mode 100644 index 0000000000..e3c3965d13 --- /dev/null +++ b/web/.storybook/__mocks__/history-block.tsx @@ -0,0 +1,4 @@ +// Mock for history-block plugin to avoid circular dependency in Storybook +export const HistoryBlockNode = null +export const HistoryBlockReplacementBlock = null +export default null diff --git a/web/.storybook/__mocks__/query-block.tsx b/web/.storybook/__mocks__/query-block.tsx new file mode 100644 index 0000000000..d82f51363a --- /dev/null +++ b/web/.storybook/__mocks__/query-block.tsx @@ -0,0 +1,4 @@ +// Mock for query-block plugin to avoid circular dependency in Storybook +export const QueryBlockNode = null +export const QueryBlockReplacementBlock = null +export default null diff --git a/web/.storybook/main.ts b/web/.storybook/main.ts index 0605c71346..ca56261431 100644 --- a/web/.storybook/main.ts +++ b/web/.storybook/main.ts @@ -1,4 +1,8 @@ import type { StorybookConfig } from '@storybook/nextjs' +import path from 'node:path' +import { fileURLToPath } from 'node:url' + +const storybookDir = path.dirname(fileURLToPath(import.meta.url)) const config: StorybookConfig = { stories: ['../app/components/**/*.stories.@(js|jsx|mjs|ts|tsx)'], @@ -25,5 +29,17 @@ const config: StorybookConfig = { docs: { defaultName: 'Documentation', }, + webpackFinal: async (config) => { + // Add alias to mock problematic modules with circular dependencies + config.resolve = config.resolve || {} + config.resolve.alias = { + ...config.resolve.alias, + // Mock the plugin index files to avoid circular dependencies + [path.resolve(storybookDir, '../app/components/base/prompt-editor/plugins/context-block/index.tsx')]: path.resolve(storybookDir, '__mocks__/context-block.tsx'), + [path.resolve(storybookDir, '../app/components/base/prompt-editor/plugins/history-block/index.tsx')]: path.resolve(storybookDir, '__mocks__/history-block.tsx'), + [path.resolve(storybookDir, '../app/components/base/prompt-editor/plugins/query-block/index.tsx')]: path.resolve(storybookDir, '__mocks__/query-block.tsx'), + } + return config + }, } export default config diff --git a/web/.storybook/utils/audio-player-manager.mock.ts b/web/.storybook/utils/audio-player-manager.mock.ts new file mode 100644 index 0000000000..aca8b56b76 --- /dev/null +++ b/web/.storybook/utils/audio-player-manager.mock.ts @@ -0,0 +1,64 @@ +import { AudioPlayerManager } from '@/app/components/base/audio-btn/audio.player.manager' + +type PlayerCallback = ((event: string) => void) | null + +class MockAudioPlayer { + private callback: PlayerCallback = null + private finishTimer?: ReturnType<typeof setTimeout> + + public setCallback(callback: PlayerCallback) { + this.callback = callback + } + + public playAudio() { + this.clearTimer() + this.callback?.('play') + this.finishTimer
= setTimeout(() => { + this.callback?.('ended') + }, 2000) + } + + public pauseAudio() { + this.clearTimer() + this.callback?.('paused') + } + + private clearTimer() { + if (this.finishTimer) + clearTimeout(this.finishTimer) + } +} + +class MockAudioPlayerManager { + private readonly player = new MockAudioPlayer() + + public getAudioPlayer( + _url: string, + _isPublic: boolean, + _id: string | undefined, + _msgContent: string | null | undefined, + _voice: string | undefined, + callback: PlayerCallback, + ) { + this.player.setCallback(callback) + return this.player + } + + public resetMsgId() { + // No-op for the mock + } +} + +export const ensureMockAudioManager = () => { + const managerAny = AudioPlayerManager as unknown as { + getInstance: () => AudioPlayerManager + __isStorybookMockInstalled?: boolean + } + + if (managerAny.__isStorybookMockInstalled) + return + + const mock = new MockAudioPlayerManager() + managerAny.getInstance = () => mock as unknown as AudioPlayerManager + managerAny.__isStorybookMockInstalled = true +} diff --git a/web/__tests__/navigation-utils.test.ts b/web/__tests__/navigation-utils.test.ts index 9a388505d6..fa4986e63d 100644 --- a/web/__tests__/navigation-utils.test.ts +++ b/web/__tests__/navigation-utils.test.ts @@ -160,8 +160,7 @@ describe('Navigation Utilities', () => { page: 1, limit: '', keyword: 'test', - empty: null, - undefined, + filter: '', }) expect(path).toBe('/datasets/123/documents?page=1&keyword=test') diff --git a/web/__tests__/real-browser-flicker.test.tsx b/web/__tests__/real-browser-flicker.test.tsx index f71e8de515..0a0ea0c062 100644 --- a/web/__tests__/real-browser-flicker.test.tsx +++ b/web/__tests__/real-browser-flicker.test.tsx @@ -39,28 +39,38 @@ const setupMockEnvironment = (storedTheme: string | null, systemPrefersDark = fa const isDarkQuery = DARK_MODE_MEDIA_QUERY.test(query) const matches = isDarkQuery ? 
systemPrefersDark : false + const handleAddListener = (listener: (event: MediaQueryListEvent) => void) => { + listeners.add(listener) + } + + const handleRemoveListener = (listener: (event: MediaQueryListEvent) => void) => { + listeners.delete(listener) + } + + const handleAddEventListener = (_event: string, listener: EventListener) => { + if (typeof listener === 'function') + listeners.add(listener as (event: MediaQueryListEvent) => void) + } + + const handleRemoveEventListener = (_event: string, listener: EventListener) => { + if (typeof listener === 'function') + listeners.delete(listener as (event: MediaQueryListEvent) => void) + } + + const handleDispatchEvent = (event: Event) => { + listeners.forEach(listener => listener(event as MediaQueryListEvent)) + return true + } + const mediaQueryList: MediaQueryList = { matches, media: query, onchange: null, - addListener: (listener: MediaQueryListListener) => { - listeners.add(listener) - }, - removeListener: (listener: MediaQueryListListener) => { - listeners.delete(listener) - }, - addEventListener: (_event, listener: EventListener) => { - if (typeof listener === 'function') - listeners.add(listener as MediaQueryListListener) - }, - removeEventListener: (_event, listener: EventListener) => { - if (typeof listener === 'function') - listeners.delete(listener as MediaQueryListListener) - }, - dispatchEvent: (event: Event) => { - listeners.forEach(listener => listener(event as MediaQueryListEvent)) - return true - }, + addListener: handleAddListener, + removeListener: handleRemoveListener, + addEventListener: handleAddEventListener, + removeEventListener: handleRemoveEventListener, + dispatchEvent: handleDispatchEvent, } return mediaQueryList @@ -69,6 +79,121 @@ const setupMockEnvironment = (storedTheme: string | null, systemPrefersDark = fa jest.spyOn(window, 'matchMedia').mockImplementation(mockMatchMedia) } +// Helper function to create timing page component +const createTimingPageComponent = ( + timingData: Array<{ phase: string; timestamp: number; styles: { backgroundColor: string; color: string } }>, +) => { + const recordTiming = (phase: string, styles: { backgroundColor: string; color: string }) => { + timingData.push({ + phase, + timestamp: performance.now(), + styles, + }) + } + + const TimingPageComponent = () => { + const [mounted, setMounted] = useState(false) + const { theme } = useTheme() + const isDark = mounted ? theme === 'dark' : false + + const currentStyles = { + backgroundColor: isDark ? '#1f2937' : '#ffffff', + color: isDark ? '#ffffff' : '#000000', + } + + recordTiming(mounted ? 'CSR' : 'Initial', currentStyles) + + useEffect(() => { + setMounted(true) + }, []) + + return ( +
+
+ Phase: {mounted ? 'CSR' : 'Initial'} | Theme: {theme} | Visual: {isDark ? 'dark' : 'light'} +
+
+ ) + } + + return TimingPageComponent +} + +// Helper function to create CSS test component +const createCSSTestComponent = ( + cssStates: Array<{ className: string; timestamp: number }>, +) => { + const recordCSSState = (className: string) => { + cssStates.push({ + className, + timestamp: performance.now(), + }) + } + + const CSSTestComponent = () => { + const [mounted, setMounted] = useState(false) + const { theme } = useTheme() + const isDark = mounted ? theme === 'dark' : false + + const className = `min-h-screen ${isDark ? 'bg-gray-900 text-white' : 'bg-white text-black'}` + + recordCSSState(className) + + useEffect(() => { + setMounted(true) + }, []) + + return ( +
+
Classes: {className}
+
+ ) + } + + return CSSTestComponent +} + +// Helper function to create performance test component +const createPerformanceTestComponent = ( + performanceMarks: Array<{ event: string; timestamp: number }>, +) => { + const recordPerformanceMark = (event: string) => { + performanceMarks.push({ event, timestamp: performance.now() }) + } + + const PerformanceTestComponent = () => { + const [mounted, setMounted] = useState(false) + const { theme } = useTheme() + + recordPerformanceMark('component-render') + + useEffect(() => { + recordPerformanceMark('mount-start') + setMounted(true) + recordPerformanceMark('mount-complete') + }, []) + + useEffect(() => { + if (theme) + recordPerformanceMark('theme-available') + }, [theme]) + + return ( +
+ Mounted: {mounted.toString()} | Theme: {theme || 'loading'} +
+ ) + } + + return PerformanceTestComponent +} + // Simulate real page component based on Dify's actual theme usage const PageComponent = () => { const [mounted, setMounted] = useState(false) @@ -227,39 +352,7 @@ describe('Real Browser Environment Dark Mode Flicker Test', () => { setupMockEnvironment('dark') const timingData: Array<{ phase: string; timestamp: number; styles: any }> = [] - - const TimingPageComponent = () => { - const [mounted, setMounted] = useState(false) - const { theme } = useTheme() - const isDark = mounted ? theme === 'dark' : false - - // Record timing and styles for each render phase - const currentStyles = { - backgroundColor: isDark ? '#1f2937' : '#ffffff', - color: isDark ? '#ffffff' : '#000000', - } - - timingData.push({ - phase: mounted ? 'CSR' : 'Initial', - timestamp: performance.now(), - styles: currentStyles, - }) - - useEffect(() => { - setMounted(true) - }, []) - - return ( -
-
- Phase: {mounted ? 'CSR' : 'Initial'} | Theme: {theme} | Visual: {isDark ? 'dark' : 'light'} -
-
- ) - } + const TimingPageComponent = createTimingPageComponent(timingData) render( @@ -295,33 +388,7 @@ describe('Real Browser Environment Dark Mode Flicker Test', () => { setupMockEnvironment('dark') const cssStates: Array<{ className: string; timestamp: number }> = [] - - const CSSTestComponent = () => { - const [mounted, setMounted] = useState(false) - const { theme } = useTheme() - const isDark = mounted ? theme === 'dark' : false - - // Simulate Tailwind CSS class application - const className = `min-h-screen ${isDark ? 'bg-gray-900 text-white' : 'bg-white text-black'}` - - cssStates.push({ - className, - timestamp: performance.now(), - }) - - useEffect(() => { - setMounted(true) - }, []) - - return ( -
-
Classes: {className}
-
- ) - } + const CSSTestComponent = createCSSTestComponent(cssStates) render( @@ -413,34 +480,12 @@ describe('Real Browser Environment Dark Mode Flicker Test', () => { test('verifies ThemeProvider position fix reduces initialization delay', async () => { const performanceMarks: Array<{ event: string; timestamp: number }> = [] - const PerformanceTestComponent = () => { - const [mounted, setMounted] = useState(false) - const { theme } = useTheme() - - performanceMarks.push({ event: 'component-render', timestamp: performance.now() }) - - useEffect(() => { - performanceMarks.push({ event: 'mount-start', timestamp: performance.now() }) - setMounted(true) - performanceMarks.push({ event: 'mount-complete', timestamp: performance.now() }) - }, []) - - useEffect(() => { - if (theme) - performanceMarks.push({ event: 'theme-available', timestamp: performance.now() }) - }, [theme]) - - return ( -
- Mounted: {mounted.toString()} | Theme: {theme || 'loading'} -
- ) - } - setupMockEnvironment('dark') expect(window.localStorage.getItem('theme')).toBe('dark') + const PerformanceTestComponent = createPerformanceTestComponent(performanceMarks) + render( diff --git a/web/__tests__/unified-tags-logic.test.ts b/web/__tests__/unified-tags-logic.test.ts index c920e28e0a..ec73a6a268 100644 --- a/web/__tests__/unified-tags-logic.test.ts +++ b/web/__tests__/unified-tags-logic.test.ts @@ -70,14 +70,18 @@ describe('Unified Tags Editing - Pure Logic Tests', () => { }) describe('Fallback Logic (from layout-main.tsx)', () => { + type Tag = { id: string; name: string } + type AppDetail = { tags: Tag[] } + type FallbackResult = { tags?: Tag[] } | null + // no-op it('should trigger fallback when tags are missing or empty', () => { - const appDetailWithoutTags = { tags: [] } - const appDetailWithTags = { tags: [{ id: 'tag1' }] } - const appDetailWithUndefinedTags = { tags: undefined as any } + const appDetailWithoutTags: AppDetail = { tags: [] } + const appDetailWithTags: AppDetail = { tags: [{ id: 'tag1', name: 't' }] } + const appDetailWithUndefinedTags: { tags: Tag[] | undefined } = { tags: undefined } // This simulates the condition in layout-main.tsx - const shouldFallback1 = !appDetailWithoutTags.tags || appDetailWithoutTags.tags.length === 0 - const shouldFallback2 = !appDetailWithTags.tags || appDetailWithTags.tags.length === 0 + const shouldFallback1 = appDetailWithoutTags.tags.length === 0 + const shouldFallback2 = appDetailWithTags.tags.length === 0 const shouldFallback3 = !appDetailWithUndefinedTags.tags || appDetailWithUndefinedTags.tags.length === 0 expect(shouldFallback1).toBe(true) // Empty array should trigger fallback @@ -86,24 +90,26 @@ describe('Unified Tags Editing - Pure Logic Tests', () => { }) it('should preserve tags when fallback succeeds', () => { - const originalAppDetail = { tags: [] as any[] } - const fallbackResult = { tags: [{ id: 'tag1', name: 'fallback-tag' }] } + const originalAppDetail: AppDetail = { tags: [] } + const fallbackResult: { tags?: Tag[] } = { tags: [{ id: 'tag1', name: 'fallback-tag' }] } // This simulates the successful fallback in layout-main.tsx - if (fallbackResult?.tags) - originalAppDetail.tags = fallbackResult.tags + const tags = fallbackResult.tags + if (tags) + originalAppDetail.tags = tags expect(originalAppDetail.tags).toEqual(fallbackResult.tags) expect(originalAppDetail.tags.length).toBe(1) }) it('should continue with empty tags when fallback fails', () => { - const originalAppDetail: { tags: any[] } = { tags: [] } - const fallbackResult: { tags?: any[] } | null = null + const originalAppDetail: AppDetail = { tags: [] } + const fallbackResult = null as FallbackResult // This simulates fallback failure in layout-main.tsx - if (fallbackResult?.tags) - originalAppDetail.tags = fallbackResult.tags + const tags: Tag[] | undefined = fallbackResult && 'tags' in fallbackResult ? 
fallbackResult.tags : undefined + if (tags) + originalAppDetail.tags = tags expect(originalAppDetail.tags).toEqual([]) }) diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx index e4c3f60c12..0ad02ad7f3 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx @@ -73,7 +73,7 @@ const ConfigPopup: FC = ({ } }, [onChooseProvider]) - const handleConfigUpdated = useCallback((payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig) => { + const handleConfigUpdated = useCallback((payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | TencentConfig) => { onConfigUpdated(currentProvider!, payload) hideConfigModal() }, [currentProvider, hideConfigModal, onConfigUpdated]) diff --git a/web/app/(commonLayout)/layout.tsx b/web/app/(commonLayout)/layout.tsx index ed1c995e25..be9c4fe49a 100644 --- a/web/app/(commonLayout)/layout.tsx +++ b/web/app/(commonLayout)/layout.tsx @@ -9,6 +9,7 @@ import { EventEmitterContextProvider } from '@/context/event-emitter' import { ProviderContextProvider } from '@/context/provider-context' import { ModalContextProvider } from '@/context/modal-context' import GotoAnything from '@/app/components/goto-anything' +import Zendesk from '@/app/components/base/zendesk' const Layout = ({ children }: { children: ReactNode }) => { return ( @@ -28,6 +29,7 @@ const Layout = ({ children }: { children: ReactNode }) => { + ) diff --git a/web/app/(shareLayout)/components/splash.tsx b/web/app/(shareLayout)/components/splash.tsx index c26ea7e045..16d291d4b4 100644 --- a/web/app/(shareLayout)/components/splash.tsx +++ b/web/app/(shareLayout)/components/splash.tsx @@ -6,7 +6,6 @@ import { useWebAppStore } from '@/context/web-app-context' import { useRouter, useSearchParams } from 'next/navigation' import AppUnavailable from '@/app/components/base/app-unavailable' import { useTranslation } from 'react-i18next' -import { AccessMode } from '@/models/access-control' import { webAppLoginStatus, webAppLogout } from '@/service/webapp-auth' import { fetchAccessToken } from '@/service/share' import Loading from '@/app/components/base/loading' @@ -35,7 +34,6 @@ const Splash: FC = ({ children }) => { router.replace(url) }, [getSigninUrl, router, webAppLogout, shareCode]) - const needCheckIsLogin = webAppAccessMode !== AccessMode.PUBLIC const [isLoading, setIsLoading] = useState(true) useEffect(() => { if (message) { @@ -58,8 +56,8 @@ const Splash: FC = ({ children }) => { } (async () => { - const { userLoggedIn, appLoggedIn } = await webAppLoginStatus(needCheckIsLogin, shareCode!) - + // if access mode is public, user login is always true, but the app login(passport) may be expired + const { userLoggedIn, appLoggedIn } = await webAppLoginStatus(shareCode!) 
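      // --- Editor's illustrative aside, not part of this patch ---
      // Shape assumed by the destructuring above; the real type comes from
      // '@/service/webapp-auth' and the field notes are inferences, not source.
      type WebAppLoginStatusSketch = {
        userLoggedIn: boolean // end-user session; treated as logged in for public apps
        appLoggedIn: boolean // per-app passport, which can still expire for public apps
      }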
if (userLoggedIn && appLoggedIn) { redirectOrFinish() } @@ -87,7 +85,6 @@ const Splash: FC = ({ children }) => { router, message, webAppAccessMode, - needCheckIsLogin, tokenFromUrl]) if (message) { diff --git a/web/app/components/app/annotation/index.tsx b/web/app/components/app/annotation/index.tsx index 264b1ac727..bc63b85f6d 100644 --- a/web/app/components/app/annotation/index.tsx +++ b/web/app/components/app/annotation/index.tsx @@ -53,7 +53,6 @@ const Annotation: FC = (props) => { const [isShowViewModal, setIsShowViewModal] = useState(false) const [selectedIds, setSelectedIds] = useState([]) const debouncedQueryParams = useDebounce(queryParams, { wait: 500 }) - const [isBatchDeleting, setIsBatchDeleting] = useState(false) const fetchAnnotationConfig = async () => { const res = await doFetchAnnotationConfig(appDetail.id) @@ -108,9 +107,6 @@ const Annotation: FC = (props) => { } const handleBatchDelete = async () => { - if (isBatchDeleting) - return - setIsBatchDeleting(true) try { await delAnnotations(appDetail.id, selectedIds) Toast.notify({ message: t('common.api.actionSuccess'), type: 'success' }) @@ -121,9 +117,6 @@ const Annotation: FC = (props) => { catch (e: any) { Toast.notify({ type: 'error', message: e.message || t('common.api.actionFailed') }) } - finally { - setIsBatchDeleting(false) - } } const handleView = (item: AnnotationItem) => { @@ -213,7 +206,6 @@ const Annotation: FC = (props) => { onSelectedIdsChange={setSelectedIds} onBatchDelete={handleBatchDelete} onCancel={() => setSelectedIds([])} - isBatchDeleting={isBatchDeleting} /> :
} diff --git a/web/app/components/app/annotation/list.tsx b/web/app/components/app/annotation/list.tsx index 6705ac5768..70ecedb869 100644 --- a/web/app/components/app/annotation/list.tsx +++ b/web/app/components/app/annotation/list.tsx @@ -19,7 +19,6 @@ type Props = { onSelectedIdsChange: (selectedIds: string[]) => void onBatchDelete: () => Promise onCancel: () => void - isBatchDeleting?: boolean } const List: FC = ({ @@ -30,7 +29,6 @@ const List: FC = ({ onSelectedIdsChange, onBatchDelete, onCancel, - isBatchDeleting, }) => { const { t } = useTranslation() const { formatTime } = useTimestamp() @@ -142,7 +140,6 @@ const List: FC = ({ selectedIds={selectedIds} onBatchDelete={onBatchDelete} onCancel={onCancel} - isBatchDeleting={isBatchDeleting} /> )} diff --git a/web/app/components/app/app-publisher/features-wrapper.tsx b/web/app/components/app/app-publisher/features-wrapper.tsx index dadd112135..409c390f4b 100644 --- a/web/app/components/app/app-publisher/features-wrapper.tsx +++ b/web/app/components/app/app-publisher/features-wrapper.tsx @@ -1,6 +1,6 @@ import React, { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' -import produce from 'immer' +import { produce } from 'immer' import type { AppPublisherProps } from '@/app/components/app/app-publisher' import Confirm from '@/app/components/base/confirm' import AppPublisher from '@/app/components/app/app-publisher' diff --git a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx index e2d37bb9de..aa8d0f65ca 100644 --- a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx @@ -5,7 +5,7 @@ import copy from 'copy-to-clipboard' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import { useBoolean } from 'ahooks' -import produce from 'immer' +import { produce } from 'immer' import { RiDeleteBinLine, RiErrorWarningFill, @@ -78,7 +78,9 @@ const AdvancedPromptInput: FC = ({ const handleOpenExternalDataToolModal = () => { setShowExternalDataToolModal({ payload: {}, - onSaveCallback: (newExternalDataTool: ExternalDataTool) => { + onSaveCallback: (newExternalDataTool?: ExternalDataTool) => { + if (!newExternalDataTool) + return eventEmitter?.emit({ type: ADD_EXTERNAL_DATA_TOOL, payload: newExternalDataTool, diff --git a/web/app/components/app/configuration/config-prompt/index.tsx b/web/app/components/app/configuration/config-prompt/index.tsx index 1caca47bcb..ec34588e41 100644 --- a/web/app/components/app/configuration/config-prompt/index.tsx +++ b/web/app/components/app/configuration/config-prompt/index.tsx @@ -2,7 +2,7 @@ import type { FC } from 'react' import React from 'react' import { useContext } from 'use-context-selector' -import produce from 'immer' +import { produce } from 'immer' import { RiAddLine, } from '@remixicon/react' diff --git a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx index a7bdc550d1..8634232b2b 100644 --- a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx @@ -3,7 +3,7 @@ import type { FC } from 'react' import React, { useState } from 'react' import { useTranslation } from 'react-i18next' import { useBoolean } from 'ahooks' 
-import produce from 'immer' +import { produce } from 'immer' import { useContext } from 'use-context-selector' import ConfirmAddVar from './confirm-add-var' import PromptEditorHeightResizeWrap from './prompt-editor-height-resize-wrap' @@ -76,7 +76,9 @@ const Prompt: FC = ({ const handleOpenExternalDataToolModal = () => { setShowExternalDataToolModal({ payload: {}, - onSaveCallback: (newExternalDataTool: ExternalDataTool) => { + onSaveCallback: (newExternalDataTool?: ExternalDataTool) => { + if (!newExternalDataTool) + return eventEmitter?.emit({ type: ADD_EXTERNAL_DATA_TOOL, payload: newExternalDataTool, diff --git a/web/app/components/app/configuration/config-var/config-modal/index.tsx b/web/app/components/app/configuration/config-var/config-modal/index.tsx index 8a02ca8caa..3f32c9b0c7 100644 --- a/web/app/components/app/configuration/config-var/config-modal/index.tsx +++ b/web/app/components/app/configuration/config-var/config-modal/index.tsx @@ -3,7 +3,7 @@ import type { ChangeEvent, FC } from 'react' import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' -import produce from 'immer' +import { produce } from 'immer' import ModalFoot from '../modal-foot' import ConfigSelect from '../config-select' import ConfigString from '../config-string' @@ -320,7 +320,7 @@ const ConfigModal: FC = ({ {type === InputVarType.paragraph && (