mirror of https://github.com/langgenius/dify.git
feat: support var filer in conversation service (#29245)
This commit is contained in:
parent
accc91e89d
commit
65e8fdc0e4
|
|
@ -4,7 +4,7 @@ from uuid import UUID
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from flask_restx._http import HTTPStatus
|
from flask_restx._http import HTTPStatus
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import BadRequest, NotFound
|
from werkzeug.exceptions import BadRequest, NotFound
|
||||||
|
|
||||||
|
|
@ -51,6 +51,32 @@ class ConversationRenamePayload(BaseModel):
|
||||||
class ConversationVariablesQuery(BaseModel):
|
class ConversationVariablesQuery(BaseModel):
|
||||||
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
|
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
|
||||||
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
|
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
|
||||||
|
variable_name: str | None = Field(
|
||||||
|
default=None, description="Filter variables by name", min_length=1, max_length=255
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("variable_name", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def validate_variable_name(cls, v: str | None) -> str | None:
|
||||||
|
"""
|
||||||
|
Validate variable_name to prevent injection attacks.
|
||||||
|
"""
|
||||||
|
if v is None:
|
||||||
|
return v
|
||||||
|
|
||||||
|
# Only allow safe characters: alphanumeric, underscore, hyphen, period
|
||||||
|
if not v.replace("-", "").replace("_", "").replace(".", "").isalnum():
|
||||||
|
raise ValueError(
|
||||||
|
"Variable name can only contain letters, numbers, hyphens (-), underscores (_), and periods (.)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prevent SQL injection patterns
|
||||||
|
dangerous_patterns = ["'", '"', ";", "--", "/*", "*/", "xp_", "sp_"]
|
||||||
|
for pattern in dangerous_patterns:
|
||||||
|
if pattern in v.lower():
|
||||||
|
raise ValueError(f"Variable name contains invalid characters: {pattern}")
|
||||||
|
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
class ConversationVariableUpdatePayload(BaseModel):
|
class ConversationVariableUpdatePayload(BaseModel):
|
||||||
|
|
@ -199,7 +225,7 @@ class ConversationVariablesApi(Resource):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return ConversationService.get_conversational_variable(
|
return ConversationService.get_conversational_variable(
|
||||||
app_model, conversation_id, end_user, query_args.limit, last_id
|
app_model, conversation_id, end_user, query_args.limit, last_id, query_args.variable_name
|
||||||
)
|
)
|
||||||
except services.errors.conversation.ConversationNotExistsError:
|
except services.errors.conversation.ConversationNotExistsError:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,9 @@ from typing import Any, Union
|
||||||
from sqlalchemy import asc, desc, func, or_, select
|
from sqlalchemy import asc, desc, func, or_, select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.db.session_factory import session_factory
|
||||||
from core.llm_generator.llm_generator import LLMGenerator
|
from core.llm_generator.llm_generator import LLMGenerator
|
||||||
from core.variables.types import SegmentType
|
from core.variables.types import SegmentType
|
||||||
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
|
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
|
||||||
|
|
@ -202,6 +204,7 @@ class ConversationService:
|
||||||
user: Union[Account, EndUser] | None,
|
user: Union[Account, EndUser] | None,
|
||||||
limit: int,
|
limit: int,
|
||||||
last_id: str | None,
|
last_id: str | None,
|
||||||
|
variable_name: str | None = None,
|
||||||
) -> InfiniteScrollPagination:
|
) -> InfiniteScrollPagination:
|
||||||
conversation = cls.get_conversation(app_model, conversation_id, user)
|
conversation = cls.get_conversation(app_model, conversation_id, user)
|
||||||
|
|
||||||
|
|
@ -212,7 +215,25 @@ class ConversationService:
|
||||||
.order_by(ConversationVariable.created_at)
|
.order_by(ConversationVariable.created_at)
|
||||||
)
|
)
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
# Apply variable_name filter if provided
|
||||||
|
if variable_name:
|
||||||
|
# Filter using JSON extraction to match variable names case-insensitively
|
||||||
|
escaped_variable_name = variable_name.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||||
|
# Filter using JSON extraction to match variable names case-insensitively
|
||||||
|
if dify_config.DB_TYPE in ["mysql", "oceanbase", "seekdb"]:
|
||||||
|
stmt = stmt.where(
|
||||||
|
func.json_extract(ConversationVariable.data, "$.name").ilike(
|
||||||
|
f"%{escaped_variable_name}%", escape="\\"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif dify_config.DB_TYPE == "postgresql":
|
||||||
|
stmt = stmt.where(
|
||||||
|
func.json_extract_path_text(ConversationVariable.data, "name").ilike(
|
||||||
|
f"%{escaped_variable_name}%", escape="\\"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with session_factory.create_session() as session:
|
||||||
if last_id:
|
if last_id:
|
||||||
last_variable = session.scalar(stmt.where(ConversationVariable.id == last_id))
|
last_variable = session.scalar(stmt.where(ConversationVariable.id == last_id))
|
||||||
if not last_variable:
|
if not last_variable:
|
||||||
|
|
@ -279,7 +300,7 @@ class ConversationService:
|
||||||
.where(ConversationVariable.id == variable_id)
|
.where(ConversationVariable.id == variable_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with session_factory.create_session() as session:
|
||||||
existing_variable = session.scalar(stmt)
|
existing_variable = session.scalar(stmt)
|
||||||
if not existing_variable:
|
if not existing_variable:
|
||||||
raise ConversationVariableNotExistsError()
|
raise ConversationVariableNotExistsError()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue