add new table of end user oauth

This commit is contained in:
Charles Yao 2025-11-11 20:05:11 -06:00
parent 850c5fec32
commit c493e08df1
2 changed files with 54 additions and 0 deletions

View File

@ -80,6 +80,7 @@ from .task import CeleryTask, CeleryTaskSet
from .tools import (
ApiToolProvider,
BuiltinToolProvider,
EndUserAuthenticationProvider,
ToolConversationVariables,
ToolFile,
ToolLabelBinding,
@ -148,6 +149,7 @@ __all__ = [
"DocumentSegment",
"Embedding",
"EndUser",
"EndUserAuthenticationProvider",
"ExternalKnowledgeApis",
"ExternalKnowledgeBindings",
"IconType",

View File

@ -114,6 +114,58 @@ class BuiltinToolProvider(TypeBase):
return cast(dict[str, Any], json.loads(self.encrypted_credentials))
class EndUserAuthenticationProvider(TypeBase):
"""
This table stores the authentication credentials for end users in tools.
Mimics the BuiltinToolProvider structure but for end users instead of tenants.
"""
__tablename__ = "tool_enduser_authentication_providers"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="tool_enduser_authentication_provider_pkey"),
sa.UniqueConstraint("tenant_id", "provider", "end_user_id", "name", name="unique_enduser_authentication_provider"),
sa.Index("tool_enduser_authentication_provider_tenant_id_idx", "tenant_id"),
sa.Index("tool_enduser_authentication_provider_end_user_id_idx", "end_user_id"),
)
# id of the authentication provider
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
name: Mapped[str] = mapped_column(
String(256),
nullable=False,
server_default=sa.text("'API KEY 1'::character varying"),
)
# id of the tenant
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# id of the end user
end_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# name of the tool provider
provider: Mapped[str] = mapped_column(String(256), nullable=False)
# encrypted credentials for the end user
encrypted_credentials: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"), init=False
)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime,
nullable=False,
server_default=sa.text("CURRENT_TIMESTAMP(0)"),
onupdate=func.current_timestamp(),
init=False,
)
# credential type, e.g., "api-key", "oauth2"
credential_type: Mapped[str] = mapped_column(
String(32), nullable=False, server_default=sa.text("'api-key'::character varying"), default="api-key"
)
expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1"), default=-1)
@property
def credentials(self) -> dict[str, Any]:
if not self.encrypted_credentials:
return {}
return cast(dict[str, Any], json.loads(self.encrypted_credentials))
class ApiToolProvider(TypeBase):
"""
The table stores the api providers.