mirror of https://github.com/langgenius/dify.git
refactor(api): update subscription handling in trigger provider
- Replaced SubscriptionSchema with SubscriptionConstructor in various parts of the trigger provider implementation to streamline subscription management. - Enhanced the PluginTriggerProviderController to utilize the new subscription constructor for retrieving default properties and credential schemas. - Removed the deprecated get_provider_subscription_schema method from TriggerManager. - Updated TriggerSubscriptionBuilderService to reflect changes in subscription handling, ensuring compatibility with the new structure. These changes improve the clarity and maintainability of the subscription handling within the trigger provider architecture.
This commit is contained in:
parent
a06d2892f8
commit
5e3e6b0bd8
|
|
@ -7,7 +7,7 @@ from core.entities.provider_entities import ProviderConfig
|
|||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.trigger.entities.entities import (
|
||||
SubscriptionSchema,
|
||||
SubscriptionConstructor,
|
||||
TriggerCreationMethod,
|
||||
TriggerDescription,
|
||||
TriggerIdentity,
|
||||
|
|
@ -52,12 +52,13 @@ class TriggerProviderApiEntity(BaseModel):
|
|||
description="Supported creation methods for the trigger provider. like 'OAUTH', 'APIKEY', 'MANUAL'.",
|
||||
)
|
||||
|
||||
credentials_schema: list[ProviderConfig] = Field(description="The credentials schema of the trigger provider")
|
||||
oauth_client_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list, description="The schema of the OAuth client"
|
||||
subscription_constructor: Optional[SubscriptionConstructor] = Field(
|
||||
default=None, description="The subscription constructor of the trigger provider"
|
||||
)
|
||||
subscription_schema: Optional[SubscriptionSchema] = Field(
|
||||
description="The subscription schema of the trigger provider"
|
||||
|
||||
subscription_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list,
|
||||
description="The subscription schema of the trigger provider",
|
||||
)
|
||||
triggers: list[TriggerApiEntity] = Field(description="The triggers of the trigger provider")
|
||||
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ class TriggerParameterType(StrEnum):
|
|||
ARRAY = "array"
|
||||
DYNAMIC_SELECT = "dynamic-select"
|
||||
CHECKBOX = "checkbox"
|
||||
|
||||
|
||||
def as_normal_type(self):
|
||||
return as_normal_type(self)
|
||||
|
||||
|
|
@ -119,32 +119,30 @@ class OAuthSchema(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class SubscriptionSchema(BaseModel):
|
||||
class SubscriptionConstructor(BaseModel):
|
||||
"""
|
||||
The subscription schema of the trigger provider
|
||||
The subscription constructor of the trigger provider
|
||||
"""
|
||||
|
||||
parameters_schema: list[TriggerParameter] | None = Field(
|
||||
default_factory=list,
|
||||
description="The parameters schema required to create a subscription",
|
||||
parameters: list[TriggerParameter] = Field(
|
||||
default_factory=list, description="The parameters schema of the subscription constructor"
|
||||
)
|
||||
|
||||
properties_schema: list[ProviderConfig] | None = Field(
|
||||
credentials_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list,
|
||||
description="The configuration schema stored in the subscription entity",
|
||||
description="The credentials schema of the subscription constructor",
|
||||
)
|
||||
|
||||
oauth_schema: Optional[OAuthSchema] = Field(
|
||||
default=None,
|
||||
description="The OAuth schema of the subscription constructor if OAuth is supported",
|
||||
)
|
||||
|
||||
def get_default_parameters(self) -> Mapping[str, Any]:
|
||||
"""Get the default parameters from the parameters schema"""
|
||||
if not self.parameters_schema:
|
||||
if not self.parameters:
|
||||
return {}
|
||||
return {param.name: param.default for param in self.parameters_schema if param.default}
|
||||
|
||||
def get_default_properties(self) -> Mapping[str, Any]:
|
||||
"""Get the default properties from the properties schema"""
|
||||
if not self.properties_schema:
|
||||
return {}
|
||||
return {prop.name: prop.default for prop in self.properties_schema if prop.default}
|
||||
return {param.name: param.default for param in self.parameters if param.default}
|
||||
|
||||
|
||||
class TriggerProviderEntity(BaseModel):
|
||||
|
|
@ -153,16 +151,12 @@ class TriggerProviderEntity(BaseModel):
|
|||
"""
|
||||
|
||||
identity: TriggerProviderIdentity = Field(..., description="The identity of the trigger provider")
|
||||
credentials_schema: list[ProviderConfig] = Field(
|
||||
subscription_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list,
|
||||
description="The credentials schema of the trigger provider",
|
||||
description="The configuration schema stored in the subscription entity",
|
||||
)
|
||||
oauth_schema: Optional[OAuthSchema] = Field(
|
||||
default=None,
|
||||
description="The OAuth schema of the trigger provider if OAuth is supported",
|
||||
)
|
||||
subscription_schema: SubscriptionSchema = Field(
|
||||
description="The subscription schema for trigger(webhook, polling, etc.) subscription parameters",
|
||||
subscription_constructor: SubscriptionConstructor = Field(
|
||||
description="The subscription constructor of the trigger provider",
|
||||
)
|
||||
triggers: list[TriggerEntity] = Field(default=[], description="The triggers of the trigger provider")
|
||||
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from core.trigger.entities.api_entities import TriggerApiEntity, TriggerProvider
|
|||
from core.trigger.entities.entities import (
|
||||
ProviderConfig,
|
||||
Subscription,
|
||||
SubscriptionSchema,
|
||||
SubscriptionConstructor,
|
||||
TriggerCreationMethod,
|
||||
TriggerEntity,
|
||||
TriggerProviderEntity,
|
||||
|
|
@ -81,13 +81,12 @@ class PluginTriggerProviderController:
|
|||
if self.entity.identity.icon_dark
|
||||
else None
|
||||
)
|
||||
supported_creation_methods = []
|
||||
if self.entity.oauth_schema:
|
||||
subscription_constructor = self.entity.subscription_constructor
|
||||
supported_creation_methods = [TriggerCreationMethod.MANUAL]
|
||||
if subscription_constructor and subscription_constructor.oauth_schema:
|
||||
supported_creation_methods.append(TriggerCreationMethod.OAUTH)
|
||||
if self.entity.credentials_schema:
|
||||
if subscription_constructor and subscription_constructor.credentials_schema:
|
||||
supported_creation_methods.append(TriggerCreationMethod.APIKEY)
|
||||
if self.entity.subscription_schema:
|
||||
supported_creation_methods.append(TriggerCreationMethod.MANUAL)
|
||||
return TriggerProviderApiEntity(
|
||||
author=self.entity.identity.author,
|
||||
name=self.entity.identity.name,
|
||||
|
|
@ -98,8 +97,7 @@ class PluginTriggerProviderController:
|
|||
tags=self.entity.identity.tags,
|
||||
plugin_id=self.plugin_id,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
credentials_schema=self.entity.credentials_schema,
|
||||
oauth_client_schema=self.entity.oauth_schema.client_schema if self.entity.oauth_schema else [],
|
||||
subscription_constructor=subscription_constructor,
|
||||
subscription_schema=self.entity.subscription_schema,
|
||||
supported_creation_methods=supported_creation_methods,
|
||||
triggers=[
|
||||
|
|
@ -139,13 +137,21 @@ class PluginTriggerProviderController:
|
|||
return trigger
|
||||
return None
|
||||
|
||||
def get_subscription_schema(self) -> SubscriptionSchema:
|
||||
def get_subscription_default_properties(self) -> Mapping[str, Any]:
|
||||
"""
|
||||
Get subscription schema for this provider
|
||||
Get default properties for this provider
|
||||
|
||||
:return: List of subscription config schemas
|
||||
:return: Default properties
|
||||
"""
|
||||
return self.entity.subscription_schema
|
||||
return {prop.name: prop.default for prop in self.entity.subscription_schema if prop.default}
|
||||
|
||||
def get_subscription_constructor(self) -> SubscriptionConstructor:
|
||||
"""
|
||||
Get subscription constructor for this provider
|
||||
|
||||
:return: Subscription constructor
|
||||
"""
|
||||
return self.entity.subscription_constructor
|
||||
|
||||
def validate_credentials(self, user_id: str, credentials: Mapping[str, str]) -> None:
|
||||
"""
|
||||
|
|
@ -155,7 +161,7 @@ class PluginTriggerProviderController:
|
|||
:return: Validation response
|
||||
"""
|
||||
# First validate against schema
|
||||
for config in self.entity.credentials_schema:
|
||||
for config in self.entity.subscription_constructor.credentials_schema:
|
||||
if config.required and config.name not in credentials:
|
||||
raise TriggerProviderCredentialValidationError(f"Missing required credential field: {config.name}")
|
||||
|
||||
|
|
@ -180,9 +186,10 @@ class PluginTriggerProviderController:
|
|||
:return: List of supported credential types
|
||||
"""
|
||||
types = []
|
||||
if self.entity.oauth_schema:
|
||||
subscription_constructor = self.entity.subscription_constructor
|
||||
if subscription_constructor and subscription_constructor.oauth_schema:
|
||||
types.append(CredentialType.OAUTH2)
|
||||
if self.entity.credentials_schema:
|
||||
if subscription_constructor and subscription_constructor.credentials_schema:
|
||||
types.append(CredentialType.API_KEY)
|
||||
return types
|
||||
|
||||
|
|
@ -193,11 +200,20 @@ class PluginTriggerProviderController:
|
|||
:param credential_type: The type of credential (oauth or api_key)
|
||||
:return: List of provider config schemas
|
||||
"""
|
||||
subscription_constructor = self.entity.subscription_constructor
|
||||
credential_type = CredentialType.of(credential_type) if isinstance(credential_type, str) else credential_type
|
||||
if credential_type == CredentialType.OAUTH2:
|
||||
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
|
||||
return (
|
||||
subscription_constructor.oauth_schema.credentials_schema.copy()
|
||||
if subscription_constructor and subscription_constructor.oauth_schema
|
||||
else []
|
||||
)
|
||||
if credential_type == CredentialType.API_KEY:
|
||||
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
|
||||
return (
|
||||
subscription_constructor.credentials_schema.copy()
|
||||
if subscription_constructor and subscription_constructor.credentials_schema
|
||||
else []
|
||||
)
|
||||
if credential_type == CredentialType.UNAUTHORIZED:
|
||||
return []
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
|
@ -214,7 +230,12 @@ class PluginTriggerProviderController:
|
|||
|
||||
:return: List of OAuth client config schemas
|
||||
"""
|
||||
return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else []
|
||||
subscription_constructor = self.entity.subscription_constructor
|
||||
return (
|
||||
subscription_constructor.oauth_schema.client_schema.copy()
|
||||
if subscription_constructor and subscription_constructor.oauth_schema
|
||||
else []
|
||||
)
|
||||
|
||||
def get_properties_schema(self) -> list[BasicProviderConfig]:
|
||||
"""
|
||||
|
|
@ -223,8 +244,8 @@ class PluginTriggerProviderController:
|
|||
:return: List of properties config schemas
|
||||
"""
|
||||
return (
|
||||
[x.to_basic_provider_config() for x in self.entity.subscription_schema.properties_schema.copy()]
|
||||
if self.entity.subscription_schema.properties_schema
|
||||
[x.to_basic_provider_config() for x in self.entity.subscription_schema.copy()]
|
||||
if self.entity.subscription_schema
|
||||
else []
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ from core.plugin.impl.exc import PluginInvokeError
|
|||
from core.plugin.impl.trigger import PluginTriggerManager
|
||||
from core.trigger.entities.entities import (
|
||||
Subscription,
|
||||
SubscriptionSchema,
|
||||
TriggerEntity,
|
||||
Unsubscription,
|
||||
)
|
||||
|
|
@ -226,17 +225,6 @@ class TriggerManager:
|
|||
provider = cls.get_trigger_provider(tenant_id, provider_id)
|
||||
return provider.unsubscribe_trigger(user_id=user_id, subscription=subscription, credentials=credentials)
|
||||
|
||||
@classmethod
|
||||
def get_provider_subscription_schema(cls, tenant_id: str, provider_id: TriggerProviderID) -> SubscriptionSchema:
|
||||
"""
|
||||
Get provider subscription schema
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param provider_id: Provider ID
|
||||
:return: List of subscription config schemas
|
||||
"""
|
||||
return cls.get_trigger_provider(tenant_id, provider_id).get_subscription_schema()
|
||||
|
||||
@classmethod
|
||||
def refresh_trigger(
|
||||
cls,
|
||||
|
|
|
|||
|
|
@ -152,7 +152,7 @@ class TriggerSubscriptionBuilderService:
|
|||
if not provider_controller:
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
|
||||
subscription_schema = provider_controller.get_subscription_schema()
|
||||
subscription_constructor = provider_controller.get_subscription_constructor()
|
||||
subscription_id = str(uuid.uuid4())
|
||||
subscription_builder = SubscriptionBuilder(
|
||||
id=subscription_id,
|
||||
|
|
@ -161,8 +161,8 @@ class TriggerSubscriptionBuilderService:
|
|||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=str(provider_id),
|
||||
parameters=subscription_schema.get_default_parameters(),
|
||||
properties=subscription_schema.get_default_properties(),
|
||||
parameters=subscription_constructor.get_default_parameters() if subscription_constructor else {},
|
||||
properties=provider_controller.get_subscription_default_properties(),
|
||||
credentials={},
|
||||
credential_type=credential_type,
|
||||
credential_expires_at=-1,
|
||||
|
|
|
|||
Loading…
Reference in New Issue