test: migrate test_workflow_draft_variable_service to SQLAlchemy 2.0 select() API (#34986)

This commit is contained in:
dataCenter430 2026-04-12 17:57:21 -07:00 committed by GitHub
parent 88c38ddeb3
commit bc2b9eec58
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -7,7 +7,7 @@ from graphon.nodes import BuiltinNodeTypes
from graphon.variables.segments import StringSegment from graphon.variables.segments import StringSegment
from graphon.variables.types import SegmentType from graphon.variables.types import SegmentType
from graphon.variables.variables import StringVariable from graphon.variables.variables import StringVariable
from sqlalchemy import delete from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
@ -38,21 +38,25 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
def setUp(self): def setUp(self):
self._test_app_id = str(uuid.uuid4()) self._test_app_id = str(uuid.uuid4())
self._test_user_id = str(uuid.uuid4())
self._session: Session = db.session() self._session: Session = db.session()
sys_var = WorkflowDraftVariable.new_sys_variable( sys_var = WorkflowDraftVariable.new_sys_variable(
app_id=self._test_app_id, app_id=self._test_app_id,
user_id=self._test_user_id,
name="sys_var", name="sys_var",
value=build_segment("sys_value"), value=build_segment("sys_value"),
node_execution_id=self._node_exec_id, node_execution_id=self._node_exec_id,
) )
conv_var = WorkflowDraftVariable.new_conversation_variable( conv_var = WorkflowDraftVariable.new_conversation_variable(
app_id=self._test_app_id, app_id=self._test_app_id,
user_id=self._test_user_id,
name="conv_var", name="conv_var",
value=build_segment("conv_value"), value=build_segment("conv_value"),
) )
node2_vars = [ node2_vars = [
WorkflowDraftVariable.new_node_variable( WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id, app_id=self._test_app_id,
user_id=self._test_user_id,
node_id=self._node2_id, node_id=self._node2_id,
name="int_var", name="int_var",
value=build_segment(1), value=build_segment(1),
@ -61,6 +65,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
), ),
WorkflowDraftVariable.new_node_variable( WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id, app_id=self._test_app_id,
user_id=self._test_user_id,
node_id=self._node2_id, node_id=self._node2_id,
name="str_var", name="str_var",
value=build_segment("str_value"), value=build_segment("str_value"),
@ -70,6 +75,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
] ]
node1_var = WorkflowDraftVariable.new_node_variable( node1_var = WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id, app_id=self._test_app_id,
user_id=self._test_user_id,
node_id=self._node1_id, node_id=self._node1_id,
name="str_var", name="str_var",
value=build_segment("str_value"), value=build_segment("str_value"),
@ -141,24 +147,27 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
def test_delete_node_variables(self): def test_delete_node_variables(self):
srv = self._get_test_srv() srv = self._get_test_srv()
srv.delete_node_variables(self._test_app_id, self._node2_id, user_id=self._test_user_id) srv.delete_node_variables(self._test_app_id, self._node2_id, user_id=self._test_user_id)
node2_var_count = ( node2_var_count = self._session.scalar(
self._session.query(WorkflowDraftVariable) select(func.count())
.select_from(WorkflowDraftVariable)
.where( .where(
WorkflowDraftVariable.app_id == self._test_app_id, WorkflowDraftVariable.app_id == self._test_app_id,
WorkflowDraftVariable.node_id == self._node2_id, WorkflowDraftVariable.node_id == self._node2_id,
WorkflowDraftVariable.user_id == self._test_user_id,
) )
.count()
) )
assert node2_var_count == 0 assert node2_var_count == 0
def test_delete_variable(self): def test_delete_variable(self):
srv = self._get_test_srv() srv = self._get_test_srv()
node_1_var = ( node_1_var = self._session.scalars(
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).one() select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id)
) ).one()
srv.delete_variable(node_1_var) srv.delete_variable(node_1_var)
exists = bool( exists = bool(
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).first() self._session.scalars(
select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id)
).first()
) )
assert exists is False assert exists is False
@ -248,9 +257,7 @@ class TestDraftVariableLoader(unittest.TestCase):
def tearDown(self): def tearDown(self):
with Session(bind=db.engine, expire_on_commit=False) as session: with Session(bind=db.engine, expire_on_commit=False) as session:
session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == self._test_app_id).delete( session.execute(delete(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == self._test_app_id))
synchronize_session=False
)
session.commit() session.commit()
def test_variable_loader_with_empty_selector(self): def test_variable_loader_with_empty_selector(self):
@ -431,9 +438,11 @@ class TestDraftVariableLoader(unittest.TestCase):
# Clean up # Clean up
with Session(bind=db.engine) as session: with Session(bind=db.engine) as session:
# Query and delete by ID to ensure they're tracked in this session # Query and delete by ID to ensure they're tracked in this session
session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete() session.execute(delete(WorkflowDraftVariable).where(WorkflowDraftVariable.id == offloaded_var.id))
session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete() session.execute(
session.query(UploadFile).filter_by(id=upload_file.id).delete() delete(WorkflowDraftVariableFile).where(WorkflowDraftVariableFile.id == variable_file.id)
)
session.execute(delete(UploadFile).where(UploadFile.id == upload_file.id))
session.commit() session.commit()
# Clean up storage # Clean up storage
try: try:
@ -534,9 +543,11 @@ class TestDraftVariableLoader(unittest.TestCase):
# Clean up # Clean up
with Session(bind=db.engine) as session: with Session(bind=db.engine) as session:
# Query and delete by ID to ensure they're tracked in this session # Query and delete by ID to ensure they're tracked in this session
session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete() session.execute(delete(WorkflowDraftVariable).where(WorkflowDraftVariable.id == offloaded_var.id))
session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete() session.execute(
session.query(UploadFile).filter_by(id=upload_file.id).delete() delete(WorkflowDraftVariableFile).where(WorkflowDraftVariableFile.id == variable_file.id)
)
session.execute(delete(UploadFile).where(UploadFile.id == upload_file.id))
session.commit() session.commit()
# Clean up storage # Clean up storage
try: try: