feat: support var filer in conversation service (#29245)

This commit is contained in:
wangxiaolei 2025-12-22 21:48:11 +08:00 committed by GitHub
parent accc91e89d
commit 65e8fdc0e4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 51 additions and 4 deletions

View File

@ -4,7 +4,7 @@ from uuid import UUID
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
from flask_restx._http import HTTPStatus from flask_restx._http import HTTPStatus
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound from werkzeug.exceptions import BadRequest, NotFound
@ -51,6 +51,32 @@ class ConversationRenamePayload(BaseModel):
class ConversationVariablesQuery(BaseModel): class ConversationVariablesQuery(BaseModel):
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination") last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return") limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
variable_name: str | None = Field(
default=None, description="Filter variables by name", min_length=1, max_length=255
)
@field_validator("variable_name", mode="before")
@classmethod
def validate_variable_name(cls, v: str | None) -> str | None:
"""
Validate variable_name to prevent injection attacks.
"""
if v is None:
return v
# Only allow safe characters: alphanumeric, underscore, hyphen, period
if not v.replace("-", "").replace("_", "").replace(".", "").isalnum():
raise ValueError(
"Variable name can only contain letters, numbers, hyphens (-), underscores (_), and periods (.)"
)
# Prevent SQL injection patterns
dangerous_patterns = ["'", '"', ";", "--", "/*", "*/", "xp_", "sp_"]
for pattern in dangerous_patterns:
if pattern in v.lower():
raise ValueError(f"Variable name contains invalid characters: {pattern}")
return v
class ConversationVariableUpdatePayload(BaseModel): class ConversationVariableUpdatePayload(BaseModel):
@ -199,7 +225,7 @@ class ConversationVariablesApi(Resource):
try: try:
return ConversationService.get_conversational_variable( return ConversationService.get_conversational_variable(
app_model, conversation_id, end_user, query_args.limit, last_id app_model, conversation_id, end_user, query_args.limit, last_id, query_args.variable_name
) )
except services.errors.conversation.ConversationNotExistsError: except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")

View File

@ -6,7 +6,9 @@ from typing import Any, Union
from sqlalchemy import asc, desc, func, or_, select from sqlalchemy import asc, desc, func, or_, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.db.session_factory import session_factory
from core.llm_generator.llm_generator import LLMGenerator from core.llm_generator.llm_generator import LLMGenerator
from core.variables.types import SegmentType from core.variables.types import SegmentType
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
@ -202,6 +204,7 @@ class ConversationService:
user: Union[Account, EndUser] | None, user: Union[Account, EndUser] | None,
limit: int, limit: int,
last_id: str | None, last_id: str | None,
variable_name: str | None = None,
) -> InfiniteScrollPagination: ) -> InfiniteScrollPagination:
conversation = cls.get_conversation(app_model, conversation_id, user) conversation = cls.get_conversation(app_model, conversation_id, user)
@ -212,7 +215,25 @@ class ConversationService:
.order_by(ConversationVariable.created_at) .order_by(ConversationVariable.created_at)
) )
with Session(db.engine) as session: # Apply variable_name filter if provided
if variable_name:
# Filter using JSON extraction to match variable names case-insensitively
escaped_variable_name = variable_name.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
# Filter using JSON extraction to match variable names case-insensitively
if dify_config.DB_TYPE in ["mysql", "oceanbase", "seekdb"]:
stmt = stmt.where(
func.json_extract(ConversationVariable.data, "$.name").ilike(
f"%{escaped_variable_name}%", escape="\\"
)
)
elif dify_config.DB_TYPE == "postgresql":
stmt = stmt.where(
func.json_extract_path_text(ConversationVariable.data, "name").ilike(
f"%{escaped_variable_name}%", escape="\\"
)
)
with session_factory.create_session() as session:
if last_id: if last_id:
last_variable = session.scalar(stmt.where(ConversationVariable.id == last_id)) last_variable = session.scalar(stmt.where(ConversationVariable.id == last_id))
if not last_variable: if not last_variable:
@ -279,7 +300,7 @@ class ConversationService:
.where(ConversationVariable.id == variable_id) .where(ConversationVariable.id == variable_id)
) )
with Session(db.engine) as session: with session_factory.create_session() as session:
existing_variable = session.scalar(stmt) existing_variable = session.scalar(stmt)
if not existing_variable: if not existing_variable:
raise ConversationVariableNotExistsError() raise ConversationVariableNotExistsError()