more typed orm (#28331)

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato 2025-11-21 14:23:32 +09:00 committed by GitHub
parent 5f61ca5e6f
commit 3c30d0f41b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 167 additions and 124 deletions

View File

@ -6,7 +6,7 @@ import sqlalchemy as sa
from sqlalchemy import DateTime, String, func
from sqlalchemy.orm import Mapped, mapped_column
from .base import Base
from .base import TypeBase
from .types import LongText, StringUUID
@ -17,16 +17,18 @@ class APIBasedExtensionPoint(enum.StrEnum):
APP_MODERATION_OUTPUT = "app.moderation.output"
class APIBasedExtension(Base):
class APIBasedExtension(TypeBase):
__tablename__ = "api_based_extensions"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="api_based_extension_pkey"),
sa.Index("api_based_extension_tenant_idx", "tenant_id"),
)
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
tenant_id = mapped_column(StringUUID, nullable=False)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
api_endpoint: Mapped[str] = mapped_column(String(255), nullable=False)
api_key = mapped_column(LongText, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
api_key: Mapped[str] = mapped_column(LongText, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)

View File

@ -6,62 +6,74 @@ from sqlalchemy.orm import Mapped, mapped_column
from libs.uuid_utils import uuidv7
from .base import Base
from .base import TypeBase
from .types import AdjustedJSON, LongText, StringUUID
class DatasourceOauthParamConfig(Base): # type: ignore[name-defined]
class DatasourceOauthParamConfig(TypeBase):
__tablename__ = "datasource_oauth_params"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="datasource_oauth_config_pkey"),
sa.UniqueConstraint("plugin_id", "provider", name="datasource_oauth_config_datasource_id_provider_idx"),
)
id = mapped_column(StringUUID, default=lambda: str(uuidv7()))
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
system_credentials: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False)
class DatasourceProvider(Base):
class DatasourceProvider(TypeBase):
__tablename__ = "datasource_providers"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="datasource_provider_pkey"),
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", "name", name="datasource_provider_unique_name"),
sa.Index("datasource_provider_auth_type_provider_idx", "tenant_id", "plugin_id", "provider"),
)
id = mapped_column(StringUUID, default=lambda: str(uuidv7()))
tenant_id = mapped_column(StringUUID, nullable=False)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
provider: Mapped[str] = mapped_column(sa.String(128), nullable=False)
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
auth_type: Mapped[str] = mapped_column(sa.String(255), nullable=False)
encrypted_credentials: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False)
avatar_url: Mapped[str] = mapped_column(LongText, nullable=True, default="default")
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
expires_at: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default="-1")
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False)
expires_at: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default="-1", default=-1)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
sa.DateTime,
nullable=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
)
class DatasourceOauthTenantParamConfig(Base):
class DatasourceOauthTenantParamConfig(TypeBase):
__tablename__ = "datasource_oauth_tenant_params"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="datasource_oauth_tenant_config_pkey"),
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="datasource_oauth_tenant_config_unique"),
)
id = mapped_column(StringUUID, default=lambda: str(uuidv7()))
tenant_id = mapped_column(StringUUID, nullable=False)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
client_params: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False, default={})
client_params: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False, default_factory=dict)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime,
nullable=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
)

View File

@ -16,14 +16,15 @@ from core.trigger.entities.entities import Subscription
from core.trigger.utils.endpoint import generate_plugin_trigger_endpoint_url, generate_webhook_trigger_endpoint
from libs.datetime_utils import naive_utc_now
from libs.uuid_utils import uuidv7
from models.base import Base, TypeBase
from models.engine import db
from models.enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTriggerStatus
from models.model import Account
from models.types import EnumText, LongText, StringUUID
from .base import Base, TypeBase
from .engine import db
from .enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTriggerStatus
from .model import Account
from .types import EnumText, LongText, StringUUID
class TriggerSubscription(Base):
class TriggerSubscription(TypeBase):
"""
Trigger provider model for managing credentials
Supports multiple credential instances per provider
@ -40,7 +41,7 @@ class TriggerSubscription(Base):
UniqueConstraint("tenant_id", "provider_id", "name", name="unique_trigger_provider"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
name: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription instance name")
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@ -62,12 +63,15 @@ class TriggerSubscription(Base):
Integer, default=-1, comment="Subscription instance expiration timestamp, -1 for never"
)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
server_default=func.current_timestamp(),
server_onupdate=func.current_timestamp(),
init=False,
)
def is_credential_expired(self) -> bool:
@ -100,24 +104,27 @@ class TriggerSubscription(Base):
# system level trigger oauth client params
class TriggerOAuthSystemClient(Base):
class TriggerOAuthSystemClient(TypeBase):
__tablename__ = "trigger_oauth_system_clients"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="trigger_oauth_system_client_pkey"),
sa.UniqueConstraint("plugin_id", "provider", name="trigger_oauth_system_client_plugin_id_provider_idx"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
plugin_id: Mapped[str] = mapped_column(String(255), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
# oauth params of the trigger provider
encrypted_oauth_params: Mapped[str] = mapped_column(LongText, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
server_default=func.current_timestamp(),
server_onupdate=func.current_timestamp(),
init=False,
)
@ -134,7 +141,7 @@ class TriggerOAuthTenantClient(Base):
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
plugin_id: Mapped[str] = mapped_column(String(255), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"), default=True)
# oauth params of the trigger provider
encrypted_oauth_params: Mapped[str] = mapped_column(LongText, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())

View File

@ -181,19 +181,21 @@ class TriggerProviderService:
# Create provider record
subscription = TriggerSubscription(
id=subscription_id or str(uuid.uuid4()),
tenant_id=tenant_id,
user_id=user_id,
name=name,
endpoint_id=endpoint_id,
provider_id=str(provider_id),
parameters=parameters,
properties=properties_encrypter.encrypt(dict(properties)),
credentials=credential_encrypter.encrypt(dict(credentials)) if credential_encrypter else {},
parameters=dict(parameters),
properties=dict(properties_encrypter.encrypt(dict(properties))),
credentials=dict(credential_encrypter.encrypt(dict(credentials)))
if credential_encrypter
else {},
credential_type=credential_type.value,
credential_expires_at=credential_expires_at,
expires_at=expires_at,
)
subscription.id = subscription_id or str(uuid.uuid4())
session.add(subscription)
session.commit()

View File

@ -69,13 +69,14 @@ class TestAPIBasedExtensionService:
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant is not None
# Setup extension data
extension_data = APIBasedExtension()
extension_data.tenant_id = tenant.id
extension_data.name = fake.company()
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data.api_key = fake.password(length=20)
extension_data = APIBasedExtension(
tenant_id=tenant.id,
name=fake.company(),
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
# Save extension
saved_extension = APIBasedExtensionService.save(extension_data)
@ -105,13 +106,14 @@ class TestAPIBasedExtensionService:
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant is not None
# Test empty name
extension_data = APIBasedExtension()
extension_data.tenant_id = tenant.id
extension_data.name = ""
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data.api_key = fake.password(length=20)
extension_data = APIBasedExtension(
tenant_id=tenant.id,
name="",
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
with pytest.raises(ValueError, match="name must not be empty"):
APIBasedExtensionService.save(extension_data)
@ -141,12 +143,14 @@ class TestAPIBasedExtensionService:
# Create multiple extensions
extensions = []
assert tenant is not None
for i in range(3):
extension_data = APIBasedExtension()
extension_data.tenant_id = tenant.id
extension_data.name = f"Extension {i}: {fake.company()}"
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data.api_key = fake.password(length=20)
extension_data = APIBasedExtension(
tenant_id=tenant.id,
name=f"Extension {i}: {fake.company()}",
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
saved_extension = APIBasedExtensionService.save(extension_data)
extensions.append(saved_extension)
@ -173,13 +177,14 @@ class TestAPIBasedExtensionService:
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant is not None
# Create an extension
extension_data = APIBasedExtension()
extension_data.tenant_id = tenant.id
extension_data.name = fake.company()
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data.api_key = fake.password(length=20)
extension_data = APIBasedExtension(
tenant_id=tenant.id,
name=fake.company(),
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
created_extension = APIBasedExtensionService.save(extension_data)
@ -217,13 +222,14 @@ class TestAPIBasedExtensionService:
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant is not None
# Create an extension first
extension_data = APIBasedExtension()
extension_data.tenant_id = tenant.id
extension_data.name = fake.company()
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data.api_key = fake.password(length=20)
extension_data = APIBasedExtension(
tenant_id=tenant.id,
name=fake.company(),
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
created_extension = APIBasedExtensionService.save(extension_data)
extension_id = created_extension.id
@ -245,22 +251,23 @@ class TestAPIBasedExtensionService:
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant is not None
# Create first extension
extension_data1 = APIBasedExtension()
extension_data1.tenant_id = tenant.id
extension_data1.name = "Test Extension"
extension_data1.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data1.api_key = fake.password(length=20)
extension_data1 = APIBasedExtension(
tenant_id=tenant.id,
name="Test Extension",
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
APIBasedExtensionService.save(extension_data1)
# Try to create second extension with same name
extension_data2 = APIBasedExtension()
extension_data2.tenant_id = tenant.id
extension_data2.name = "Test Extension" # Same name
extension_data2.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data2.api_key = fake.password(length=20)
extension_data2 = APIBasedExtension(
tenant_id=tenant.id,
name="Test Extension", # Same name
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
with pytest.raises(ValueError, match="name must be unique, it is already existed"):
APIBasedExtensionService.save(extension_data2)
@ -273,13 +280,14 @@ class TestAPIBasedExtensionService:
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant is not None
# Create initial extension
extension_data = APIBasedExtension()
extension_data.tenant_id = tenant.id
extension_data.name = fake.company()
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data.api_key = fake.password(length=20)
extension_data = APIBasedExtension(
tenant_id=tenant.id,
name=fake.company(),
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
created_extension = APIBasedExtensionService.save(extension_data)
@ -330,13 +338,14 @@ class TestAPIBasedExtensionService:
mock_external_service_dependencies["requestor_instance"].request.side_effect = ValueError(
"connection error: request timeout"
)
assert tenant is not None
# Setup extension data
extension_data = APIBasedExtension()
extension_data.tenant_id = tenant.id
extension_data.name = fake.company()
extension_data.api_endpoint = "https://invalid-endpoint.com/api"
extension_data.api_key = fake.password(length=20)
extension_data = APIBasedExtension(
tenant_id=tenant.id,
name=fake.company(),
api_endpoint="https://invalid-endpoint.com/api",
api_key=fake.password(length=20),
)
# Try to save extension with connection error
with pytest.raises(ValueError, match="connection error: request timeout"):
@ -352,13 +361,14 @@ class TestAPIBasedExtensionService:
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant is not None
# Setup extension data with short API key
extension_data = APIBasedExtension()
extension_data.tenant_id = tenant.id
extension_data.name = fake.company()
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data.api_key = "1234" # Less than 5 characters
extension_data = APIBasedExtension(
tenant_id=tenant.id,
name=fake.company(),
api_endpoint=f"https://{fake.domain_name()}/api",
api_key="1234", # Less than 5 characters
)
# Try to save extension with short API key
with pytest.raises(ValueError, match="api_key must be at least 5 characters"):
@ -372,13 +382,14 @@ class TestAPIBasedExtensionService:
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant is not None
# Test with None values
extension_data = APIBasedExtension()
extension_data.tenant_id = tenant.id
extension_data.name = None
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data.api_key = fake.password(length=20)
extension_data = APIBasedExtension(
tenant_id=tenant.id,
name=None, # type: ignore # why str become None here???
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
with pytest.raises(ValueError, match="name must not be empty"):
APIBasedExtensionService.save(extension_data)
@ -424,13 +435,14 @@ class TestAPIBasedExtensionService:
# Mock invalid ping response
mock_external_service_dependencies["requestor_instance"].request.return_value = {"result": "invalid"}
assert tenant is not None
# Setup extension data
extension_data = APIBasedExtension()
extension_data.tenant_id = tenant.id
extension_data.name = fake.company()
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data.api_key = fake.password(length=20)
extension_data = APIBasedExtension(
tenant_id=tenant.id,
name=fake.company(),
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
# Try to save extension with invalid ping response
with pytest.raises(ValueError, match="{'result': 'invalid'}"):
@ -447,13 +459,14 @@ class TestAPIBasedExtensionService:
# Mock ping response without result field
mock_external_service_dependencies["requestor_instance"].request.return_value = {"status": "ok"}
assert tenant is not None
# Setup extension data
extension_data = APIBasedExtension()
extension_data.tenant_id = tenant.id
extension_data.name = fake.company()
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data.api_key = fake.password(length=20)
extension_data = APIBasedExtension(
tenant_id=tenant.id,
name=fake.company(),
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
# Try to save extension with missing ping result
with pytest.raises(ValueError, match="{'status': 'ok'}"):
@ -472,13 +485,14 @@ class TestAPIBasedExtensionService:
account2, tenant2 = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant1 is not None
# Create extension in first tenant
extension_data = APIBasedExtension()
extension_data.tenant_id = tenant1.id
extension_data.name = fake.company()
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data.api_key = fake.password(length=20)
extension_data = APIBasedExtension(
tenant_id=tenant1.id,
name=fake.company(),
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
created_extension = APIBasedExtensionService.save(extension_data)

View File

@ -70,12 +70,13 @@ def test__convert_to_http_request_node_for_chatbot(default_variables):
api_based_extension_id = "api_based_extension_id"
mock_api_based_extension = APIBasedExtension(
id=api_based_extension_id,
tenant_id="tenant_id",
name="api-1",
api_key="encrypted_api_key",
api_endpoint="https://dify.ai",
)
mock_api_based_extension.id = api_based_extension_id
workflow_converter = WorkflowConverter()
workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension)
@ -131,11 +132,12 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables):
api_based_extension_id = "api_based_extension_id"
mock_api_based_extension = APIBasedExtension(
id=api_based_extension_id,
tenant_id="tenant_id",
name="api-1",
api_key="encrypted_api_key",
api_endpoint="https://dify.ai",
)
mock_api_based_extension.id = api_based_extension_id
workflow_converter = WorkflowConverter()
workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension)
@ -281,6 +283,7 @@ def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables):
assert llm_node["data"]["model"]["name"] == model
assert llm_node["data"]["model"]["mode"] == model_mode.value
template = prompt_template.simple_prompt_template
assert template is not None
for v in default_variables:
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
assert llm_node["data"]["prompt_template"][0]["text"] == template + "\n"
@ -323,6 +326,7 @@ def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variab
assert llm_node["data"]["model"]["name"] == model
assert llm_node["data"]["model"]["mode"] == model_mode.value
template = prompt_template.simple_prompt_template
assert template is not None
for v in default_variables:
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
assert llm_node["data"]["prompt_template"]["text"] == template + "\n"
@ -374,6 +378,7 @@ def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables)
assert llm_node["data"]["model"]["name"] == model
assert llm_node["data"]["model"]["mode"] == model_mode.value
assert isinstance(llm_node["data"]["prompt_template"], list)
assert prompt_template.advanced_chat_prompt_template is not None
assert len(llm_node["data"]["prompt_template"]) == len(prompt_template.advanced_chat_prompt_template.messages)
template = prompt_template.advanced_chat_prompt_template.messages[0].text
for v in default_variables:
@ -420,6 +425,7 @@ def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_var
assert llm_node["data"]["model"]["name"] == model
assert llm_node["data"]["model"]["mode"] == model_mode.value
assert isinstance(llm_node["data"]["prompt_template"], dict)
assert prompt_template.advanced_completion_prompt_template is not None
template = prompt_template.advanced_completion_prompt_template.prompt
for v in default_variables:
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")