mirror of https://github.com/langgenius/dify.git
add cancel provider credential
This commit is contained in:
parent
874406d934
commit
32b2d19622
|
|
@ -175,6 +175,22 @@ class ModelProviderCredentialSwitchApi(Resource):
|
|||
return {"result": "success"}
|
||||
|
||||
|
||||
class ModelProviderCredentialCancelApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
service = ModelProviderService()
|
||||
service.cancel_provider_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=provider,
|
||||
)
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class ModelProviderValidateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -289,6 +305,9 @@ api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-provider
|
|||
api.add_resource(
|
||||
ModelProviderCredentialSwitchApi, "/workspaces/current/model-providers/<path:provider>/credentials/switch"
|
||||
)
|
||||
api.add_resource(
|
||||
ModelProviderCredentialCancelApi, "/workspaces/current/model-providers/<path:provider>/credentials/cancel"
|
||||
)
|
||||
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate")
|
||||
|
||||
api.add_resource(
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ from core.plugin.entities.plugin import ModelProviderID
|
|||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.provider import (
|
||||
CredentialStatus,
|
||||
LoadBalancingModelConfig,
|
||||
Provider,
|
||||
ProviderCredential,
|
||||
|
|
@ -43,6 +44,7 @@ from models.provider import (
|
|||
TenantPreferredModelProvider,
|
||||
)
|
||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||
from services.entities.model_provider_entities import CustomConfigurationStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -188,6 +190,18 @@ class ProviderConfiguration(BaseModel):
|
|||
if current_quota_configuration.is_valid
|
||||
else SystemConfigurationStatus.QUOTA_EXCEEDED
|
||||
)
|
||||
|
||||
def get_custom_configuration_status(self) -> Optional[CustomConfigurationStatus]:
|
||||
"""
|
||||
Get custom configuration status.
|
||||
:return:
|
||||
"""
|
||||
if not self.is_custom_configuration_available():
|
||||
return CustomConfigurationStatus.NO_CONFIGURE
|
||||
elif self.custom_configuration.provider.current_credential_status:
|
||||
return self.custom_configuration.provider.current_credential_status
|
||||
|
||||
return CustomConfigurationStatus.ACTIVE
|
||||
|
||||
def is_custom_configuration_available(self) -> bool:
|
||||
"""
|
||||
|
|
@ -643,6 +657,7 @@ class ProviderConfiguration(BaseModel):
|
|||
self.switch_preferred_provider_type(provider_type=ProviderType.SYSTEM, session=session)
|
||||
elif provider_record and provider_record.credential_id == credential_id:
|
||||
provider_record.credential_id = None
|
||||
provider_record.credential_status = CredentialStatus.REMOVED.value
|
||||
provider_record.updated_at = naive_utc_now()
|
||||
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
|
|
@ -681,6 +696,34 @@ class ProviderConfiguration(BaseModel):
|
|||
|
||||
try:
|
||||
provider_record.credential_id = credential_record.id
|
||||
provider_record.credential_status = CredentialStatus.ACTIVE.value
|
||||
provider_record.updated_at = naive_utc_now()
|
||||
session.commit()
|
||||
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=provider_record.id,
|
||||
cache_type=ProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
provider_model_credentials_cache.delete()
|
||||
self.switch_preferred_provider_type(ProviderType.CUSTOM, session=session)
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
|
||||
def cancel_provider_credential(self):
|
||||
"""
|
||||
Cancel select the active provider credential.
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
provider_record = self._get_provider_record(session)
|
||||
if not provider_record:
|
||||
raise ValueError("Provider record not found.")
|
||||
|
||||
try:
|
||||
provider_record.credential_id = None
|
||||
provider_record.credential_status = CredentialStatus.CANCEL.value
|
||||
provider_record.updated_at = naive_utc_now()
|
||||
session.commit()
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from core.entities.parameter_entities import (
|
|||
)
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from models.provider import CredentialStatus
|
||||
|
||||
|
||||
class ProviderQuotaType(Enum):
|
||||
|
|
@ -97,6 +98,7 @@ class CustomProviderConfiguration(BaseModel):
|
|||
credentials: dict
|
||||
current_credential_id: Optional[str] = None
|
||||
current_credential_name: Optional[str] = None
|
||||
current_credential_status: Optional[CredentialStatus] = None
|
||||
available_credentials: list[CredentialConfiguration] = []
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -711,6 +711,7 @@ class ProviderManager:
|
|||
credentials=provider_credentials,
|
||||
current_credential_name=custom_provider_record.credential_name,
|
||||
current_credential_id=custom_provider_record.credential_id,
|
||||
current_credential_status=custom_provider_record.credential_status,
|
||||
available_credentials=self.get_provider_available_credentials(
|
||||
tenant_id, custom_provider_record.provider_name
|
||||
),
|
||||
|
|
|
|||
|
|
@ -0,0 +1,33 @@
|
|||
"""Add credential status for provider table
|
||||
|
||||
Revision ID: cf7c38a32b2d
|
||||
Revises: c20211f18133
|
||||
Create Date: 2025-09-11 15:37:17.771298
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'cf7c38a32b2d'
|
||||
down_revision = 'c20211f18133'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('providers', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('credential_status', sa.String(length=20), server_default=sa.text("'active'::character varying"), nullable=True))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('providers', schema=None) as batch_op:
|
||||
batch_op.drop_column('credential_status')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import Enum, StrEnum
|
||||
from functools import cached_property
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -40,6 +40,11 @@ class ProviderQuotaType(Enum):
|
|||
if member.value == value:
|
||||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
class CredentialStatus(StrEnum):
|
||||
ACTIVE = "active"
|
||||
CANCELED = "canceled"
|
||||
REMOVED = "removed"
|
||||
|
||||
|
||||
class Provider(Base):
|
||||
|
|
@ -65,6 +70,9 @@ class Provider(Base):
|
|||
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
|
||||
last_used: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
|
||||
credential_status: Mapped[Optional[str]] = mapped_column(
|
||||
String(20), nullable=True, server_default=text("'active'::character varying")
|
||||
)
|
||||
|
||||
quota_type: Mapped[Optional[str]] = mapped_column(
|
||||
String(40), nullable=True, server_default=text("''::character varying")
|
||||
|
|
|
|||
|
|
@ -34,6 +34,8 @@ class CustomConfigurationStatus(Enum):
|
|||
|
||||
ACTIVE = "active"
|
||||
NO_CONFIGURE = "no-configure"
|
||||
CANCELED = "canceled"
|
||||
REMOVED = "removed"
|
||||
|
||||
|
||||
class CustomConfigurationResponse(BaseModel):
|
||||
|
|
|
|||
|
|
@ -89,9 +89,7 @@ class ModelProviderService:
|
|||
model_credential_schema=provider_configuration.provider.model_credential_schema,
|
||||
preferred_provider_type=provider_configuration.preferred_provider_type,
|
||||
custom_configuration=CustomConfigurationResponse(
|
||||
status=CustomConfigurationStatus.ACTIVE
|
||||
if provider_configuration.is_custom_configuration_available()
|
||||
else CustomConfigurationStatus.NO_CONFIGURE,
|
||||
status=provider_configuration.get_custom_configuration_status(),
|
||||
current_credential_id=getattr(provider_config, "current_credential_id", None),
|
||||
current_credential_name=getattr(provider_config, "current_credential_name", None),
|
||||
available_credentials=getattr(provider_config, "available_credentials", []),
|
||||
|
|
@ -214,6 +212,16 @@ class ModelProviderService:
|
|||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.switch_active_provider_credential(credential_id=credential_id)
|
||||
|
||||
def cancel_provider_credential(self, tenant_id: str, provider: str):
|
||||
"""
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param credential_id: credential id
|
||||
:return:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.cancel_provider_credential()
|
||||
|
||||
def get_model_credential(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | None
|
||||
) -> Optional[dict]:
|
||||
|
|
|
|||
Loading…
Reference in New Issue