mirror of https://github.com/langgenius/dify.git
more typed orm (#28331)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
5f61ca5e6f
commit
3c30d0f41b
|
|
@ -6,7 +6,7 @@ import sqlalchemy as sa
|
|||
from sqlalchemy import DateTime, String, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import Base
|
||||
from .base import TypeBase
|
||||
from .types import LongText, StringUUID
|
||||
|
||||
|
||||
|
|
@ -17,16 +17,18 @@ class APIBasedExtensionPoint(enum.StrEnum):
|
|||
APP_MODERATION_OUTPUT = "app.moderation.output"
|
||||
|
||||
|
||||
class APIBasedExtension(Base):
|
||||
class APIBasedExtension(TypeBase):
|
||||
__tablename__ = "api_based_extensions"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="api_based_extension_pkey"),
|
||||
sa.Index("api_based_extension_tenant_idx", "tenant_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
api_endpoint: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
api_key = mapped_column(LongText, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
api_key: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,62 +6,74 @@ from sqlalchemy.orm import Mapped, mapped_column
|
|||
|
||||
from libs.uuid_utils import uuidv7
|
||||
|
||||
from .base import Base
|
||||
from .base import TypeBase
|
||||
from .types import AdjustedJSON, LongText, StringUUID
|
||||
|
||||
|
||||
class DatasourceOauthParamConfig(Base): # type: ignore[name-defined]
|
||||
class DatasourceOauthParamConfig(TypeBase):
|
||||
__tablename__ = "datasource_oauth_params"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="datasource_oauth_config_pkey"),
|
||||
sa.UniqueConstraint("plugin_id", "provider", name="datasource_oauth_config_datasource_id_provider_idx"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
|
||||
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
system_credentials: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False)
|
||||
|
||||
|
||||
class DatasourceProvider(Base):
|
||||
class DatasourceProvider(TypeBase):
|
||||
__tablename__ = "datasource_providers"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="datasource_provider_pkey"),
|
||||
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", "name", name="datasource_provider_unique_name"),
|
||||
sa.Index("datasource_provider_auth_type_provider_idx", "tenant_id", "plugin_id", "provider"),
|
||||
)
|
||||
id = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(sa.String(128), nullable=False)
|
||||
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
auth_type: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
encrypted_credentials: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False)
|
||||
avatar_url: Mapped[str] = mapped_column(LongText, nullable=True, default="default")
|
||||
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
expires_at: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default="-1")
|
||||
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False)
|
||||
expires_at: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default="-1", default=-1)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
|
||||
class DatasourceOauthTenantParamConfig(Base):
|
||||
class DatasourceOauthTenantParamConfig(TypeBase):
|
||||
__tablename__ = "datasource_oauth_tenant_params"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="datasource_oauth_tenant_config_pkey"),
|
||||
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="datasource_oauth_tenant_config_unique"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
client_params: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False, default={})
|
||||
client_params: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False, default_factory=dict)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -16,14 +16,15 @@ from core.trigger.entities.entities import Subscription
|
|||
from core.trigger.utils.endpoint import generate_plugin_trigger_endpoint_url, generate_webhook_trigger_endpoint
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.uuid_utils import uuidv7
|
||||
from models.base import Base, TypeBase
|
||||
from models.engine import db
|
||||
from models.enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTriggerStatus
|
||||
from models.model import Account
|
||||
from models.types import EnumText, LongText, StringUUID
|
||||
|
||||
from .base import Base, TypeBase
|
||||
from .engine import db
|
||||
from .enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTriggerStatus
|
||||
from .model import Account
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
|
||||
class TriggerSubscription(Base):
|
||||
class TriggerSubscription(TypeBase):
|
||||
"""
|
||||
Trigger provider model for managing credentials
|
||||
Supports multiple credential instances per provider
|
||||
|
|
@ -40,7 +41,7 @@ class TriggerSubscription(Base):
|
|||
UniqueConstraint("tenant_id", "provider_id", "name", name="unique_trigger_provider"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription instance name")
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
|
|
@ -62,12 +63,15 @@ class TriggerSubscription(Base):
|
|||
Integer, default=-1, comment="Subscription instance expiration timestamp, -1 for never"
|
||||
)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
server_onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
def is_credential_expired(self) -> bool:
|
||||
|
|
@ -100,24 +104,27 @@ class TriggerSubscription(Base):
|
|||
|
||||
|
||||
# system level trigger oauth client params
|
||||
class TriggerOAuthSystemClient(Base):
|
||||
class TriggerOAuthSystemClient(TypeBase):
|
||||
__tablename__ = "trigger_oauth_system_clients"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="trigger_oauth_system_client_pkey"),
|
||||
sa.UniqueConstraint("plugin_id", "provider", name="trigger_oauth_system_client_plugin_id_provider_idx"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
|
||||
plugin_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# oauth params of the trigger provider
|
||||
encrypted_oauth_params: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
server_onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -134,7 +141,7 @@ class TriggerOAuthTenantClient(Base):
|
|||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
plugin_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"), default=True)
|
||||
# oauth params of the trigger provider
|
||||
encrypted_oauth_params: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
|
|
|||
|
|
@ -181,19 +181,21 @@ class TriggerProviderService:
|
|||
|
||||
# Create provider record
|
||||
subscription = TriggerSubscription(
|
||||
id=subscription_id or str(uuid.uuid4()),
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
endpoint_id=endpoint_id,
|
||||
provider_id=str(provider_id),
|
||||
parameters=parameters,
|
||||
properties=properties_encrypter.encrypt(dict(properties)),
|
||||
credentials=credential_encrypter.encrypt(dict(credentials)) if credential_encrypter else {},
|
||||
parameters=dict(parameters),
|
||||
properties=dict(properties_encrypter.encrypt(dict(properties))),
|
||||
credentials=dict(credential_encrypter.encrypt(dict(credentials)))
|
||||
if credential_encrypter
|
||||
else {},
|
||||
credential_type=credential_type.value,
|
||||
credential_expires_at=credential_expires_at,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
subscription.id = subscription_id or str(uuid.uuid4())
|
||||
|
||||
session.add(subscription)
|
||||
session.commit()
|
||||
|
|
|
|||
|
|
@ -69,13 +69,14 @@ class TestAPIBasedExtensionService:
|
|||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
assert tenant is not None
|
||||
# Setup extension data
|
||||
extension_data = APIBasedExtension()
|
||||
extension_data.tenant_id = tenant.id
|
||||
extension_data.name = fake.company()
|
||||
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
|
||||
extension_data.api_key = fake.password(length=20)
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
# Save extension
|
||||
saved_extension = APIBasedExtensionService.save(extension_data)
|
||||
|
|
@ -105,13 +106,14 @@ class TestAPIBasedExtensionService:
|
|||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
assert tenant is not None
|
||||
# Test empty name
|
||||
extension_data = APIBasedExtension()
|
||||
extension_data.tenant_id = tenant.id
|
||||
extension_data.name = ""
|
||||
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
|
||||
extension_data.api_key = fake.password(length=20)
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name="",
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="name must not be empty"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
|
|
@ -141,12 +143,14 @@ class TestAPIBasedExtensionService:
|
|||
|
||||
# Create multiple extensions
|
||||
extensions = []
|
||||
assert tenant is not None
|
||||
for i in range(3):
|
||||
extension_data = APIBasedExtension()
|
||||
extension_data.tenant_id = tenant.id
|
||||
extension_data.name = f"Extension {i}: {fake.company()}"
|
||||
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
|
||||
extension_data.api_key = fake.password(length=20)
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name=f"Extension {i}: {fake.company()}",
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
saved_extension = APIBasedExtensionService.save(extension_data)
|
||||
extensions.append(saved_extension)
|
||||
|
|
@ -173,13 +177,14 @@ class TestAPIBasedExtensionService:
|
|||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
assert tenant is not None
|
||||
# Create an extension
|
||||
extension_data = APIBasedExtension()
|
||||
extension_data.tenant_id = tenant.id
|
||||
extension_data.name = fake.company()
|
||||
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
|
||||
extension_data.api_key = fake.password(length=20)
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
created_extension = APIBasedExtensionService.save(extension_data)
|
||||
|
||||
|
|
@ -217,13 +222,14 @@ class TestAPIBasedExtensionService:
|
|||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
assert tenant is not None
|
||||
# Create an extension first
|
||||
extension_data = APIBasedExtension()
|
||||
extension_data.tenant_id = tenant.id
|
||||
extension_data.name = fake.company()
|
||||
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
|
||||
extension_data.api_key = fake.password(length=20)
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
created_extension = APIBasedExtensionService.save(extension_data)
|
||||
extension_id = created_extension.id
|
||||
|
|
@ -245,22 +251,23 @@ class TestAPIBasedExtensionService:
|
|||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
assert tenant is not None
|
||||
# Create first extension
|
||||
extension_data1 = APIBasedExtension()
|
||||
extension_data1.tenant_id = tenant.id
|
||||
extension_data1.name = "Test Extension"
|
||||
extension_data1.api_endpoint = f"https://{fake.domain_name()}/api"
|
||||
extension_data1.api_key = fake.password(length=20)
|
||||
extension_data1 = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name="Test Extension",
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
APIBasedExtensionService.save(extension_data1)
|
||||
|
||||
# Try to create second extension with same name
|
||||
extension_data2 = APIBasedExtension()
|
||||
extension_data2.tenant_id = tenant.id
|
||||
extension_data2.name = "Test Extension" # Same name
|
||||
extension_data2.api_endpoint = f"https://{fake.domain_name()}/api"
|
||||
extension_data2.api_key = fake.password(length=20)
|
||||
extension_data2 = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name="Test Extension", # Same name
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="name must be unique, it is already existed"):
|
||||
APIBasedExtensionService.save(extension_data2)
|
||||
|
|
@ -273,13 +280,14 @@ class TestAPIBasedExtensionService:
|
|||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
assert tenant is not None
|
||||
# Create initial extension
|
||||
extension_data = APIBasedExtension()
|
||||
extension_data.tenant_id = tenant.id
|
||||
extension_data.name = fake.company()
|
||||
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
|
||||
extension_data.api_key = fake.password(length=20)
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
created_extension = APIBasedExtensionService.save(extension_data)
|
||||
|
||||
|
|
@ -330,13 +338,14 @@ class TestAPIBasedExtensionService:
|
|||
mock_external_service_dependencies["requestor_instance"].request.side_effect = ValueError(
|
||||
"connection error: request timeout"
|
||||
)
|
||||
|
||||
assert tenant is not None
|
||||
# Setup extension data
|
||||
extension_data = APIBasedExtension()
|
||||
extension_data.tenant_id = tenant.id
|
||||
extension_data.name = fake.company()
|
||||
extension_data.api_endpoint = "https://invalid-endpoint.com/api"
|
||||
extension_data.api_key = fake.password(length=20)
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
api_endpoint="https://invalid-endpoint.com/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
# Try to save extension with connection error
|
||||
with pytest.raises(ValueError, match="connection error: request timeout"):
|
||||
|
|
@ -352,13 +361,14 @@ class TestAPIBasedExtensionService:
|
|||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
assert tenant is not None
|
||||
# Setup extension data with short API key
|
||||
extension_data = APIBasedExtension()
|
||||
extension_data.tenant_id = tenant.id
|
||||
extension_data.name = fake.company()
|
||||
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
|
||||
extension_data.api_key = "1234" # Less than 5 characters
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key="1234", # Less than 5 characters
|
||||
)
|
||||
|
||||
# Try to save extension with short API key
|
||||
with pytest.raises(ValueError, match="api_key must be at least 5 characters"):
|
||||
|
|
@ -372,13 +382,14 @@ class TestAPIBasedExtensionService:
|
|||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
assert tenant is not None
|
||||
# Test with None values
|
||||
extension_data = APIBasedExtension()
|
||||
extension_data.tenant_id = tenant.id
|
||||
extension_data.name = None
|
||||
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
|
||||
extension_data.api_key = fake.password(length=20)
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name=None, # type: ignore # why str become None here???
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="name must not be empty"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
|
|
@ -424,13 +435,14 @@ class TestAPIBasedExtensionService:
|
|||
|
||||
# Mock invalid ping response
|
||||
mock_external_service_dependencies["requestor_instance"].request.return_value = {"result": "invalid"}
|
||||
|
||||
assert tenant is not None
|
||||
# Setup extension data
|
||||
extension_data = APIBasedExtension()
|
||||
extension_data.tenant_id = tenant.id
|
||||
extension_data.name = fake.company()
|
||||
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
|
||||
extension_data.api_key = fake.password(length=20)
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
# Try to save extension with invalid ping response
|
||||
with pytest.raises(ValueError, match="{'result': 'invalid'}"):
|
||||
|
|
@ -447,13 +459,14 @@ class TestAPIBasedExtensionService:
|
|||
|
||||
# Mock ping response without result field
|
||||
mock_external_service_dependencies["requestor_instance"].request.return_value = {"status": "ok"}
|
||||
|
||||
assert tenant is not None
|
||||
# Setup extension data
|
||||
extension_data = APIBasedExtension()
|
||||
extension_data.tenant_id = tenant.id
|
||||
extension_data.name = fake.company()
|
||||
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
|
||||
extension_data.api_key = fake.password(length=20)
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
# Try to save extension with missing ping result
|
||||
with pytest.raises(ValueError, match="{'status': 'ok'}"):
|
||||
|
|
@ -472,13 +485,14 @@ class TestAPIBasedExtensionService:
|
|||
account2, tenant2 = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
assert tenant1 is not None
|
||||
# Create extension in first tenant
|
||||
extension_data = APIBasedExtension()
|
||||
extension_data.tenant_id = tenant1.id
|
||||
extension_data.name = fake.company()
|
||||
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
|
||||
extension_data.api_key = fake.password(length=20)
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=tenant1.id,
|
||||
name=fake.company(),
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
created_extension = APIBasedExtensionService.save(extension_data)
|
||||
|
||||
|
|
|
|||
|
|
@ -70,12 +70,13 @@ def test__convert_to_http_request_node_for_chatbot(default_variables):
|
|||
|
||||
api_based_extension_id = "api_based_extension_id"
|
||||
mock_api_based_extension = APIBasedExtension(
|
||||
id=api_based_extension_id,
|
||||
tenant_id="tenant_id",
|
||||
name="api-1",
|
||||
api_key="encrypted_api_key",
|
||||
api_endpoint="https://dify.ai",
|
||||
)
|
||||
|
||||
mock_api_based_extension.id = api_based_extension_id
|
||||
workflow_converter = WorkflowConverter()
|
||||
workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension)
|
||||
|
||||
|
|
@ -131,11 +132,12 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables):
|
|||
|
||||
api_based_extension_id = "api_based_extension_id"
|
||||
mock_api_based_extension = APIBasedExtension(
|
||||
id=api_based_extension_id,
|
||||
tenant_id="tenant_id",
|
||||
name="api-1",
|
||||
api_key="encrypted_api_key",
|
||||
api_endpoint="https://dify.ai",
|
||||
)
|
||||
mock_api_based_extension.id = api_based_extension_id
|
||||
|
||||
workflow_converter = WorkflowConverter()
|
||||
workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension)
|
||||
|
|
@ -281,6 +283,7 @@ def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables):
|
|||
assert llm_node["data"]["model"]["name"] == model
|
||||
assert llm_node["data"]["model"]["mode"] == model_mode.value
|
||||
template = prompt_template.simple_prompt_template
|
||||
assert template is not None
|
||||
for v in default_variables:
|
||||
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
|
||||
assert llm_node["data"]["prompt_template"][0]["text"] == template + "\n"
|
||||
|
|
@ -323,6 +326,7 @@ def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variab
|
|||
assert llm_node["data"]["model"]["name"] == model
|
||||
assert llm_node["data"]["model"]["mode"] == model_mode.value
|
||||
template = prompt_template.simple_prompt_template
|
||||
assert template is not None
|
||||
for v in default_variables:
|
||||
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
|
||||
assert llm_node["data"]["prompt_template"]["text"] == template + "\n"
|
||||
|
|
@ -374,6 +378,7 @@ def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables)
|
|||
assert llm_node["data"]["model"]["name"] == model
|
||||
assert llm_node["data"]["model"]["mode"] == model_mode.value
|
||||
assert isinstance(llm_node["data"]["prompt_template"], list)
|
||||
assert prompt_template.advanced_chat_prompt_template is not None
|
||||
assert len(llm_node["data"]["prompt_template"]) == len(prompt_template.advanced_chat_prompt_template.messages)
|
||||
template = prompt_template.advanced_chat_prompt_template.messages[0].text
|
||||
for v in default_variables:
|
||||
|
|
@ -420,6 +425,7 @@ def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_var
|
|||
assert llm_node["data"]["model"]["name"] == model
|
||||
assert llm_node["data"]["model"]["mode"] == model_mode.value
|
||||
assert isinstance(llm_node["data"]["prompt_template"], dict)
|
||||
assert prompt_template.advanced_completion_prompt_template is not None
|
||||
template = prompt_template.advanced_completion_prompt_template.prompt
|
||||
for v in default_variables:
|
||||
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
|
||||
|
|
|
|||
Loading…
Reference in New Issue