diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index be6d837032..40e4bde389 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -4,7 +4,7 @@ from uuid import UUID from flask import request from flask_restx import Resource from flask_restx._http import HTTPStatus -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, NotFound @@ -51,6 +51,32 @@ class ConversationRenamePayload(BaseModel): class ConversationVariablesQuery(BaseModel): last_id: UUID | None = Field(default=None, description="Last variable ID for pagination") limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return") + variable_name: str | None = Field( + default=None, description="Filter variables by name", min_length=1, max_length=255 + ) + + @field_validator("variable_name", mode="before") + @classmethod + def validate_variable_name(cls, v: str | None) -> str | None: + """ + Validate variable_name to prevent injection attacks. + """ + if v is None: + return v + + # Only allow safe characters: alphanumeric, underscore, hyphen, period + if not v.replace("-", "").replace("_", "").replace(".", "").isalnum(): + raise ValueError( + "Variable name can only contain letters, numbers, hyphens (-), underscores (_), and periods (.)" + ) + + # Prevent SQL injection patterns + dangerous_patterns = ["'", '"', ";", "--", "/*", "*/", "xp_", "sp_"] + for pattern in dangerous_patterns: + if pattern in v.lower(): + raise ValueError(f"Variable name contains invalid characters: {pattern}") + + return v class ConversationVariableUpdatePayload(BaseModel): @@ -199,7 +225,7 @@ class ConversationVariablesApi(Resource): try: return ConversationService.get_conversational_variable( - app_model, conversation_id, end_user, query_args.limit, last_id + app_model, conversation_id, end_user, query_args.limit, last_id, query_args.variable_name ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 5253199552..659e7406fb 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -6,7 +6,9 @@ from typing import Any, Union from sqlalchemy import asc, desc, func, or_, select from sqlalchemy.orm import Session +from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom +from core.db.session_factory import session_factory from core.llm_generator.llm_generator import LLMGenerator from core.variables.types import SegmentType from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory @@ -202,6 +204,7 @@ class ConversationService: user: Union[Account, EndUser] | None, limit: int, last_id: str | None, + variable_name: str | None = None, ) -> InfiniteScrollPagination: conversation = cls.get_conversation(app_model, conversation_id, user) @@ -212,7 +215,25 @@ class ConversationService: .order_by(ConversationVariable.created_at) ) - with Session(db.engine) as session: + # Apply variable_name filter if provided + if variable_name: + # Filter using JSON extraction to match variable names case-insensitively + escaped_variable_name = variable_name.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + # Filter using JSON extraction to match variable names case-insensitively + if dify_config.DB_TYPE in ["mysql", "oceanbase", "seekdb"]: + stmt = stmt.where( + func.json_extract(ConversationVariable.data, "$.name").ilike( + f"%{escaped_variable_name}%", escape="\\" + ) + ) + elif dify_config.DB_TYPE == "postgresql": + stmt = stmt.where( + func.json_extract_path_text(ConversationVariable.data, "name").ilike( + f"%{escaped_variable_name}%", escape="\\" + ) + ) + + with session_factory.create_session() as session: if last_id: last_variable = session.scalar(stmt.where(ConversationVariable.id == last_id)) if not last_variable: @@ -279,7 +300,7 @@ class ConversationService: .where(ConversationVariable.id == variable_id) ) - with Session(db.engine) as session: + with session_factory.create_session() as session: existing_variable = session.scalar(stmt) if not existing_variable: raise ConversationVariableNotExistsError()