From c493e08df15aed46ed77d8e784209b70f8349912 Mon Sep 17 00:00:00 2001 From: Charles Yao Date: Tue, 11 Nov 2025 20:05:11 -0600 Subject: [PATCH] add new table of end user oauth --- api/models/__init__.py | 2 ++ api/models/tools.py | 52 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/api/models/__init__.py b/api/models/__init__.py index d5e017e036..2a228b6c06 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -80,6 +80,7 @@ from .task import CeleryTask, CeleryTaskSet from .tools import ( ApiToolProvider, BuiltinToolProvider, + EndUserAuthenticationProvider, ToolConversationVariables, ToolFile, ToolLabelBinding, @@ -148,6 +149,7 @@ __all__ = [ "DocumentSegment", "Embedding", "EndUser", + "EndUserAuthenticationProvider", "ExternalKnowledgeApis", "ExternalKnowledgeBindings", "IconType", diff --git a/api/models/tools.py b/api/models/tools.py index cc80ddcf51..3d42cfa682 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -114,6 +114,58 @@ class BuiltinToolProvider(TypeBase): return cast(dict[str, Any], json.loads(self.encrypted_credentials)) +class EndUserAuthenticationProvider(TypeBase): + """ + This table stores the authentication credentials for end users in tools. + Mimics the BuiltinToolProvider structure but for end users instead of tenants. + """ + + __tablename__ = "tool_enduser_authentication_providers" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="tool_enduser_authentication_provider_pkey"), + sa.UniqueConstraint("tenant_id", "provider", "end_user_id", "name", name="unique_enduser_authentication_provider"), + sa.Index("tool_enduser_authentication_provider_tenant_id_idx", "tenant_id"), + sa.Index("tool_enduser_authentication_provider_end_user_id_idx", "end_user_id"), + ) + + # id of the authentication provider + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False) + name: Mapped[str] = mapped_column( + String(256), + nullable=False, + server_default=sa.text("'API KEY 1'::character varying"), + ) + # id of the tenant + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + # id of the end user + end_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + # name of the tool provider + provider: Mapped[str] = mapped_column(String(256), nullable=False) + # encrypted credentials for the end user + encrypted_credentials: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"), init=False + ) + updated_at: Mapped[datetime] = mapped_column( + sa.DateTime, + nullable=False, + server_default=sa.text("CURRENT_TIMESTAMP(0)"), + onupdate=func.current_timestamp(), + init=False, + ) + # credential type, e.g., "api-key", "oauth2" + credential_type: Mapped[str] = mapped_column( + String(32), nullable=False, server_default=sa.text("'api-key'::character varying"), default="api-key" + ) + expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1"), default=-1) + + @property + def credentials(self) -> dict[str, Any]: + if not self.encrypted_credentials: + return {} + return cast(dict[str, Any], json.loads(self.encrypted_credentials)) + + class ApiToolProvider(TypeBase): """ The table stores the api providers.