diff --git a/api/models/comment.py b/api/models/comment.py index 059f974dc7..5bc79aae67 100644 --- a/api/models/comment.py +++ b/api/models/comment.py @@ -70,11 +70,15 @@ class WorkflowComment(Base): @property def created_by_account(self): """Get creator account.""" + if hasattr(self, "_created_by_account_cache"): + return self._created_by_account_cache return db.session.get(Account, self.created_by) @property def resolved_by_account(self): """Get resolver account.""" + if hasattr(self, "_resolved_by_account_cache"): + return self._resolved_by_account_cache if self.resolved_by: return db.session.get(Account, self.resolved_by) return None @@ -147,6 +151,8 @@ class WorkflowCommentReply(Base): @property def created_by_account(self): """Get creator account.""" + if hasattr(self, "_created_by_account_cache"): + return self._created_by_account_cache return db.session.get(Account, self.created_by) @@ -186,4 +192,6 @@ class WorkflowCommentMention(Base): @property def mentioned_user_account(self): """Get mentioned account.""" + if hasattr(self, "_mentioned_user_account_cache"): + return self._mentioned_user_account_cache return db.session.get(Account, self.mentioned_user_id) diff --git a/api/services/workflow_comment_service.py b/api/services/workflow_comment_service.py index 4b5fbf7a05..c25bfa65e7 100644 --- a/api/services/workflow_comment_service.py +++ b/api/services/workflow_comment_service.py @@ -1,4 +1,5 @@ import logging +from collections.abc import Sequence from typing import Optional from sqlalchemy import desc, select @@ -9,6 +10,7 @@ from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from libs.helper import uuid_value from models import WorkflowComment, WorkflowCommentMention, WorkflowCommentReply +from models.account import Account logger = logging.getLogger(__name__) @@ -25,7 +27,7 @@ class WorkflowCommentService: raise ValueError("Comment content cannot exceed 1000 characters") @staticmethod - def get_comments(tenant_id: str, app_id: str) -> list[WorkflowComment]: + def get_comments(tenant_id: str, app_id: str) -> Sequence[WorkflowComment]: """Get all comments for a workflow.""" with Session(db.engine) as session: # Get all comments with eager loading @@ -37,10 +39,42 @@ class WorkflowCommentService: ) comments = session.scalars(stmt).all() + + # Batch preload all Account objects to avoid N+1 queries + WorkflowCommentService._preload_accounts(session, comments) + return comments @staticmethod - def get_comment(tenant_id: str, app_id: str, comment_id: str, session: Session = None) -> WorkflowComment: + def _preload_accounts(session: Session, comments: Sequence[WorkflowComment]) -> None: + """Batch preload Account objects for comments, replies, and mentions.""" + # Collect all user IDs + user_ids: set[str] = set() + for comment in comments: + user_ids.add(comment.created_by) + if comment.resolved_by: + user_ids.add(comment.resolved_by) + user_ids.update(reply.created_by for reply in comment.replies) + user_ids.update(mention.mentioned_user_id for mention in comment.mentions) + + if not user_ids: + return + + # Batch query all accounts + accounts = session.scalars(select(Account).where(Account.id.in_(user_ids))).all() + account_map = {str(account.id): account for account in accounts} + + # Cache accounts on objects + for comment in comments: + comment._created_by_account_cache = account_map.get(comment.created_by) + comment._resolved_by_account_cache = account_map.get(comment.resolved_by) if comment.resolved_by else None + for reply in comment.replies: + reply._created_by_account_cache = account_map.get(reply.created_by) + for mention in comment.mentions: + mention._mentioned_user_account_cache = account_map.get(mention.mentioned_user_id) + + @staticmethod + def get_comment(tenant_id: str, app_id: str, comment_id: str, session: Session | None = None) -> WorkflowComment: """Get a specific comment.""" def _get_comment(session: Session) -> WorkflowComment: @@ -58,6 +92,9 @@ class WorkflowCommentService: if not comment: raise NotFound("Comment not found") + # Preload accounts to avoid N+1 queries + WorkflowCommentService._preload_accounts(session, [comment]) + return comment if session is not None: @@ -75,7 +112,7 @@ class WorkflowCommentService: position_x: float, position_y: float, mentioned_user_ids: Optional[list[str]] = None, - ) -> WorkflowComment: + ) -> dict: """Create a new workflow comment.""" WorkflowCommentService._validate_content(content) @@ -247,7 +284,7 @@ class WorkflowCommentService: @staticmethod def update_reply( reply_id: str, user_id: str, content: str, mentioned_user_ids: Optional[list[str]] = None - ) -> WorkflowCommentReply: + ) -> dict: """Update a comment reply.""" WorkflowCommentService._validate_content(content)