mirror of
https://github.com/langgenius/dify.git
synced 2026-06-07 16:32:01 +08:00
feat(plugin): cache plugin model providers by tenant (#36449)
Co-authored-by: WH-2099 <wh2099@pm.me>
This commit is contained in:
parent
2a0c098857
commit
8d99326fb3
@ -657,6 +657,7 @@ PLUGIN_REMOTE_INSTALL_PORT=5003
|
||||
PLUGIN_REMOTE_INSTALL_HOST=localhost
|
||||
PLUGIN_MAX_PACKAGE_SIZE=15728640
|
||||
PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600
|
||||
PLUGIN_MODEL_PROVIDERS_CACHE_TTL=86400
|
||||
INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
|
||||
|
||||
# Marketplace configuration
|
||||
|
||||
@ -11,6 +11,7 @@ from configs import dify_config
|
||||
from core.helper import encrypter
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.plugin.plugin_service import PluginService
|
||||
from core.tools.utils.system_encryption import encrypt_system_params
|
||||
from extensions.ext_database import db
|
||||
from models import Tenant
|
||||
@ -20,7 +21,6 @@ from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
|
||||
from models.tools import ToolOAuthSystemClient
|
||||
from services.plugin.data_migration import PluginDataMigration
|
||||
from services.plugin.plugin_migration import PluginMigration
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -265,6 +265,11 @@ class PluginConfig(BaseSettings):
|
||||
default=60 * 60,
|
||||
)
|
||||
|
||||
PLUGIN_MODEL_PROVIDERS_CACHE_TTL: PositiveInt = Field(
|
||||
description="TTL in seconds for caching tenant plugin model providers in Redis",
|
||||
default=60 * 60 * 24,
|
||||
)
|
||||
|
||||
PLUGIN_MAX_FILE_SIZE: PositiveInt = Field(
|
||||
description="Maximum allowed size (bytes) for plugin-generated files",
|
||||
default=50 * 1024 * 1024,
|
||||
|
||||
@ -15,6 +15,7 @@ from controllers.console import console_ns
|
||||
from controllers.console.workspace import plugin_permission_required
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.plugin.plugin_service import PluginService
|
||||
from fields.base import ResponseModel
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
@ -22,7 +23,6 @@ from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermissi
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
from services.plugin.plugin_parameter_service import PluginParameterService
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
|
||||
class ParserList(BaseModel):
|
||||
|
||||
@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
import hashlib
|
||||
import logging
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
from threading import Lock
|
||||
from typing import IO, Any, Literal, cast, overload, override
|
||||
|
||||
from pydantic import ValidationError
|
||||
@ -13,9 +12,9 @@ from configs import dify_config
|
||||
from core.llm_generator.output_parser.structured_output import (
|
||||
invoke_llm_with_structured_output as invoke_llm_with_structured_output_helper,
|
||||
)
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.plugin.impl.asset import PluginAssetManager
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
from core.plugin.plugin_service import PluginService
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
@ -101,35 +100,36 @@ class _PluginStructuredOutputModelInstance:
|
||||
|
||||
|
||||
class PluginModelRuntime(ModelRuntime):
|
||||
"""Plugin-backed runtime adapter bound to tenant context and optional caller scope."""
|
||||
"""Plugin-backed runtime adapter bound to tenant context and optional caller scope.
|
||||
|
||||
Provider discovery goes through ``PluginService`` so the plugin lifecycle
|
||||
methods and provider reads share one tenant-scoped cache owner.
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
user_id: str | None
|
||||
client: PluginModelClient
|
||||
_provider_entities: tuple[ProviderEntity, ...] | None
|
||||
_provider_entities_lock: Lock
|
||||
_plugin_service: type[PluginService]
|
||||
|
||||
def __init__(self, tenant_id: str, user_id: str | None, client: PluginModelClient) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str | None,
|
||||
client: PluginModelClient,
|
||||
plugin_service: type[PluginService],
|
||||
) -> None:
|
||||
if client is None:
|
||||
raise ValueError("client is required.")
|
||||
if plugin_service is None:
|
||||
raise ValueError("plugin_service is required.")
|
||||
self.tenant_id = tenant_id
|
||||
self.user_id = user_id
|
||||
self.client = client
|
||||
self._provider_entities = None
|
||||
self._provider_entities_lock = Lock()
|
||||
self._plugin_service = plugin_service
|
||||
|
||||
@override
|
||||
def fetch_model_providers(self) -> Sequence[ProviderEntity]:
|
||||
if self._provider_entities is not None:
|
||||
return self._provider_entities
|
||||
|
||||
with self._provider_entities_lock:
|
||||
if self._provider_entities is None:
|
||||
self._provider_entities = tuple(
|
||||
self._to_provider_entity(provider) for provider in self.client.fetch_model_providers(self.tenant_id)
|
||||
)
|
||||
|
||||
return self._provider_entities
|
||||
return self._plugin_service.fetch_plugin_model_providers(tenant_id=self.tenant_id, client=self.client)
|
||||
|
||||
@override
|
||||
def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
|
||||
@ -628,34 +628,6 @@ class PluginModelRuntime(ModelRuntime):
|
||||
text=text,
|
||||
)
|
||||
|
||||
def _get_provider_short_name_alias(self, provider: PluginModelProviderEntity) -> str:
|
||||
"""
|
||||
Expose a bare provider alias only for the canonical provider mapping.
|
||||
|
||||
Multiple plugins can publish the same short provider slug. If every
|
||||
provider entity keeps that slug in ``provider_name``, callers that still
|
||||
resolve by short name become order-dependent. Restrict the alias to the
|
||||
provider selected by ``ModelProviderID`` so legacy short-name lookups
|
||||
remain deterministic while the runtime surface stays canonical.
|
||||
"""
|
||||
try:
|
||||
canonical_provider_id = ModelProviderID(provider.provider)
|
||||
except ValueError:
|
||||
return ""
|
||||
|
||||
if canonical_provider_id.plugin_id != provider.plugin_id:
|
||||
return ""
|
||||
if canonical_provider_id.provider_name != provider.provider:
|
||||
return ""
|
||||
|
||||
return provider.provider
|
||||
|
||||
def _to_provider_entity(self, provider: PluginModelProviderEntity) -> ProviderEntity:
|
||||
declaration = provider.declaration.model_copy(deep=True)
|
||||
declaration.provider = f"{provider.plugin_id}/{provider.provider}"
|
||||
declaration.provider_name = self._get_provider_short_name_alias(provider)
|
||||
return declaration
|
||||
|
||||
def _get_provider_schema(self, provider: str) -> ProviderEntity:
|
||||
providers = self.fetch_model_providers()
|
||||
provider_entity = next((item for item in providers if item.provider == provider), None)
|
||||
|
||||
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
from core.plugin.plugin_service import PluginService
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from graphon.model_runtime.model_providers.base.ai_model import AIModel
|
||||
@ -117,6 +118,7 @@ def create_plugin_model_runtime(*, tenant_id: str, user_id: str | None = None) -
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
client=PluginModelClient(),
|
||||
plugin_service=PluginService,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -1,8 +1,17 @@
|
||||
"""Core plugin service and tenant-scoped plugin metadata cache ownership.
|
||||
|
||||
This module owns plugin daemon management calls that are shared by API services
|
||||
and core runtimes. Plugin model provider discovery is cached here, alongside
|
||||
plugin install, uninstall, and upgrade invalidation, so all cache mutations for
|
||||
plugin-owned provider metadata stay tenant-scoped and in one place.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from mimetypes import guess_type
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, TypeAdapter, ValidationError
|
||||
from redis import RedisError
|
||||
from sqlalchemy import delete, select, update
|
||||
from sqlalchemy.orm import Session
|
||||
from yarl import URL
|
||||
@ -22,16 +31,20 @@ from core.plugin.entities.plugin import (
|
||||
from core.plugin.entities.plugin_daemon import (
|
||||
PluginDecodeResponse,
|
||||
PluginInstallTask,
|
||||
PluginInstallTaskStatus,
|
||||
PluginListResponse,
|
||||
PluginModelProviderEntity,
|
||||
PluginVerification,
|
||||
)
|
||||
from core.plugin.impl.asset import PluginAssetManager
|
||||
from core.plugin.impl.debugging import PluginDebuggingClient
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider
|
||||
from models.provider_ids import GenericProviderID
|
||||
from models.provider_ids import GenericProviderID, ModelProviderID
|
||||
from services.enterprise.plugin_manager_service import (
|
||||
PluginManagerService,
|
||||
PreUninstallPluginRequest,
|
||||
@ -40,6 +53,7 @@ from services.errors.plugin import PluginInstallationForbiddenError
|
||||
from services.feature_service import FeatureService, PluginInstallationScope
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_provider_entities_adapter: TypeAdapter[list[ProviderEntity]] = TypeAdapter(list[ProviderEntity])
|
||||
|
||||
|
||||
class PluginService:
|
||||
@ -53,6 +67,102 @@ class PluginService:
|
||||
|
||||
REDIS_KEY_PREFIX = "plugin_service:latest_plugin:"
|
||||
REDIS_TTL = 60 * 5 # 5 minutes
|
||||
PLUGIN_MODEL_PROVIDERS_REDIS_KEY_PREFIX = "plugin_model_providers:tenant_id:"
|
||||
PLUGIN_INSTALL_TASK_TERMINAL_STATUSES = (PluginInstallTaskStatus.Success, PluginInstallTaskStatus.Failed)
|
||||
|
||||
@classmethod
|
||||
def _get_plugin_model_providers_cache_key(cls, tenant_id: str) -> str:
|
||||
return f"{cls.PLUGIN_MODEL_PROVIDERS_REDIS_KEY_PREFIX}{tenant_id}"
|
||||
|
||||
@staticmethod
|
||||
def _get_provider_short_name_alias(provider: PluginModelProviderEntity) -> str:
|
||||
"""
|
||||
Expose a bare provider alias only for the canonical provider mapping.
|
||||
|
||||
Multiple plugins can publish the same short provider slug. If every
|
||||
provider entity keeps that slug in ``provider_name``, callers that still
|
||||
resolve by short name become order-dependent. Restrict the alias to the
|
||||
provider selected by ``ModelProviderID`` so legacy short-name lookups
|
||||
remain deterministic while the runtime surface stays canonical.
|
||||
"""
|
||||
try:
|
||||
canonical_provider_id = ModelProviderID(provider.provider)
|
||||
except ValueError:
|
||||
return ""
|
||||
|
||||
if canonical_provider_id.plugin_id != provider.plugin_id:
|
||||
return ""
|
||||
if canonical_provider_id.provider_name != provider.provider:
|
||||
return ""
|
||||
|
||||
return provider.provider
|
||||
|
||||
@classmethod
|
||||
def _to_provider_entity(cls, provider: PluginModelProviderEntity) -> ProviderEntity:
|
||||
declaration = provider.declaration.model_copy(deep=True)
|
||||
declaration.provider = f"{provider.plugin_id}/{provider.provider}"
|
||||
declaration.provider_name = cls._get_provider_short_name_alias(provider)
|
||||
return declaration
|
||||
|
||||
@classmethod
|
||||
def _load_cached_plugin_model_providers(cls, tenant_id: str) -> tuple[ProviderEntity, ...] | None:
|
||||
cache_key = cls._get_plugin_model_providers_cache_key(tenant_id)
|
||||
try:
|
||||
cached_providers = redis_client.get(cache_key)
|
||||
except (RedisError, RuntimeError):
|
||||
logger.warning("Failed to read cached plugin model providers for tenant %s.", tenant_id, exc_info=True)
|
||||
return None
|
||||
|
||||
if not cached_providers:
|
||||
return None
|
||||
|
||||
try:
|
||||
return tuple(_provider_entities_adapter.validate_json(cached_providers))
|
||||
except (TypeError, ValueError, ValidationError):
|
||||
logger.warning(
|
||||
"Invalid cached plugin model providers for tenant %s; deleting cache.", tenant_id, exc_info=True
|
||||
)
|
||||
cls.invalidate_plugin_model_providers_cache(tenant_id)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _store_cached_plugin_model_providers(cls, tenant_id: str, providers: Sequence[ProviderEntity]) -> None:
|
||||
cache_key = cls._get_plugin_model_providers_cache_key(tenant_id)
|
||||
try:
|
||||
payload = _provider_entities_adapter.dump_json(list(providers)).decode("utf-8")
|
||||
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_PROVIDERS_CACHE_TTL, payload)
|
||||
except (RedisError, RuntimeError):
|
||||
logger.warning("Failed to cache plugin model providers for tenant %s.", tenant_id, exc_info=True)
|
||||
|
||||
@classmethod
|
||||
def invalidate_plugin_model_providers_cache(cls, tenant_id: str) -> None:
|
||||
"""Delete the tenant-scoped plugin model provider list cache."""
|
||||
try:
|
||||
redis_client.delete(cls._get_plugin_model_providers_cache_key(tenant_id))
|
||||
except (RedisError, RuntimeError):
|
||||
logger.warning("Failed to invalidate plugin model providers cache for tenant %s.", tenant_id, exc_info=True)
|
||||
|
||||
@classmethod
|
||||
def fetch_plugin_model_providers(
|
||||
cls, *, tenant_id: str, client: PluginModelClient | None = None
|
||||
) -> Sequence[ProviderEntity]:
|
||||
"""
|
||||
Fetch plugin model providers through the tenant-scoped plugin cache.
|
||||
|
||||
Plugin daemon provider discovery and plugin lifecycle cache invalidation
|
||||
are intentionally owned by this service so tenant isolation and cache
|
||||
expiry are handled in one place.
|
||||
"""
|
||||
cached_providers = cls._load_cached_plugin_model_providers(tenant_id)
|
||||
if cached_providers is not None:
|
||||
return cached_providers
|
||||
|
||||
model_client = client or PluginModelClient()
|
||||
providers = tuple(
|
||||
cls._to_provider_entity(provider) for provider in model_client.fetch_model_providers(tenant_id)
|
||||
)
|
||||
cls._store_cached_plugin_model_providers(tenant_id, providers)
|
||||
return providers
|
||||
|
||||
@staticmethod
|
||||
def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]:
|
||||
@ -248,12 +358,18 @@ class PluginService:
|
||||
Fetch plugin installation tasks
|
||||
"""
|
||||
manager = PluginInstaller()
|
||||
return manager.fetch_plugin_installation_tasks(tenant_id, page, page_size)
|
||||
tasks = manager.fetch_plugin_installation_tasks(tenant_id, page, page_size)
|
||||
if any(task.status in PluginService.PLUGIN_INSTALL_TASK_TERMINAL_STATUSES for task in tasks):
|
||||
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
|
||||
return tasks
|
||||
|
||||
@staticmethod
|
||||
def fetch_install_task(tenant_id: str, task_id: str) -> PluginInstallTask:
|
||||
manager = PluginInstaller()
|
||||
return manager.fetch_plugin_installation_task(tenant_id, task_id)
|
||||
task = manager.fetch_plugin_installation_task(tenant_id, task_id)
|
||||
if task.status in PluginService.PLUGIN_INSTALL_TASK_TERMINAL_STATUSES:
|
||||
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
|
||||
return task
|
||||
|
||||
@staticmethod
|
||||
def delete_install_task(tenant_id: str, task_id: str) -> bool:
|
||||
@ -315,7 +431,7 @@ class PluginService:
|
||||
# check if the plugin is available to install
|
||||
PluginService._check_plugin_installation_scope(response.verification)
|
||||
|
||||
return manager.upgrade_plugin(
|
||||
result = manager.upgrade_plugin(
|
||||
tenant_id,
|
||||
original_plugin_unique_identifier,
|
||||
new_plugin_unique_identifier,
|
||||
@ -324,6 +440,8 @@ class PluginService:
|
||||
"plugin_unique_identifier": new_plugin_unique_identifier,
|
||||
},
|
||||
)
|
||||
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def upgrade_plugin_with_github(
|
||||
@ -339,7 +457,7 @@ class PluginService:
|
||||
"""
|
||||
PluginService._check_marketplace_only_permission()
|
||||
manager = PluginInstaller()
|
||||
return manager.upgrade_plugin(
|
||||
result = manager.upgrade_plugin(
|
||||
tenant_id,
|
||||
original_plugin_unique_identifier,
|
||||
new_plugin_unique_identifier,
|
||||
@ -350,6 +468,8 @@ class PluginService:
|
||||
"package": package,
|
||||
},
|
||||
)
|
||||
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def upload_pkg(tenant_id: str, pkg: bytes, verify_signature: bool = False) -> PluginDecodeResponse:
|
||||
@ -415,12 +535,14 @@ class PluginService:
|
||||
resp = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier)
|
||||
PluginService._check_plugin_installation_scope(resp.verification)
|
||||
|
||||
return manager.install_from_identifiers(
|
||||
result = manager.install_from_identifiers(
|
||||
tenant_id,
|
||||
plugin_unique_identifiers,
|
||||
PluginInstallationSource.Package,
|
||||
[{}],
|
||||
)
|
||||
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def install_from_github(tenant_id: str, plugin_unique_identifier: str, repo: str, version: str, package: str):
|
||||
@ -434,7 +556,7 @@ class PluginService:
|
||||
plugin_decode_response = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier)
|
||||
PluginService._check_plugin_installation_scope(plugin_decode_response.verification)
|
||||
|
||||
return manager.install_from_identifiers(
|
||||
result = manager.install_from_identifiers(
|
||||
tenant_id,
|
||||
[plugin_unique_identifier],
|
||||
PluginInstallationSource.Github,
|
||||
@ -446,6 +568,8 @@ class PluginService:
|
||||
}
|
||||
],
|
||||
)
|
||||
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def fetch_marketplace_pkg(tenant_id: str, plugin_unique_identifier: str) -> PluginDeclaration:
|
||||
@ -513,12 +637,14 @@ class PluginService:
|
||||
actual_plugin_unique_identifiers.append(response.unique_identifier)
|
||||
metas.append({"plugin_unique_identifier": response.unique_identifier})
|
||||
|
||||
return manager.install_from_identifiers(
|
||||
result = manager.install_from_identifiers(
|
||||
tenant_id,
|
||||
actual_plugin_unique_identifiers,
|
||||
PluginInstallationSource.Marketplace,
|
||||
metas,
|
||||
)
|
||||
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def uninstall(tenant_id: str, plugin_installation_id: str) -> bool:
|
||||
@ -529,7 +655,10 @@ class PluginService:
|
||||
plugin = next((p for p in plugins if p.installation_id == plugin_installation_id), None)
|
||||
|
||||
if not plugin:
|
||||
return manager.uninstall(tenant_id, plugin_installation_id)
|
||||
result = manager.uninstall(tenant_id, plugin_installation_id)
|
||||
if result:
|
||||
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
|
||||
return result
|
||||
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
PluginManagerService.try_pre_uninstall_plugin(
|
||||
@ -559,37 +688,39 @@ class PluginService:
|
||||
|
||||
if not credential_ids:
|
||||
logger.info("No credentials found for plugin: %s", plugin_id)
|
||||
return manager.uninstall(tenant_id, plugin_installation_id)
|
||||
else:
|
||||
provider_ids = session.scalars(
|
||||
select(Provider.id).where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name.like(f"{plugin_id}/%"),
|
||||
Provider.credential_id.in_(credential_ids),
|
||||
)
|
||||
).all()
|
||||
|
||||
provider_ids = session.scalars(
|
||||
select(Provider.id).where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name.like(f"{plugin_id}/%"),
|
||||
Provider.credential_id.in_(credential_ids),
|
||||
session.execute(update(Provider).where(Provider.id.in_(provider_ids)).values(credential_id=None))
|
||||
|
||||
for provider_id in provider_ids:
|
||||
ProviderCredentialsCache(
|
||||
tenant_id=tenant_id,
|
||||
identity_id=provider_id,
|
||||
cache_type=ProviderCredentialsCacheType.PROVIDER,
|
||||
).delete()
|
||||
|
||||
session.execute(
|
||||
delete(ProviderCredential).where(
|
||||
ProviderCredential.id.in_(credential_ids),
|
||||
)
|
||||
)
|
||||
).all()
|
||||
|
||||
session.execute(update(Provider).where(Provider.id.in_(provider_ids)).values(credential_id=None))
|
||||
|
||||
for provider_id in provider_ids:
|
||||
ProviderCredentialsCache(
|
||||
tenant_id=tenant_id,
|
||||
identity_id=provider_id,
|
||||
cache_type=ProviderCredentialsCacheType.PROVIDER,
|
||||
).delete()
|
||||
|
||||
session.execute(
|
||||
delete(ProviderCredential).where(
|
||||
ProviderCredential.id.in_(credential_ids),
|
||||
logger.info(
|
||||
"Completed deleting credentials and cleaning provider associations for plugin: %s",
|
||||
plugin_id,
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Completed deleting credentials and cleaning provider associations for plugin: %s",
|
||||
plugin_id,
|
||||
)
|
||||
|
||||
return manager.uninstall(tenant_id, plugin_installation_id)
|
||||
result = manager.uninstall(tenant_id, plugin_installation_id)
|
||||
if result:
|
||||
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def check_tools_existence(tenant_id: str, provider_ids: Sequence[GenericProviderID]) -> Sequence[bool]:
|
||||
@ -16,6 +16,7 @@ from core.plugin.entities.request import (
|
||||
TriggerSubscriptionResponse,
|
||||
)
|
||||
from core.plugin.impl.trigger import PluginTriggerClient
|
||||
from core.plugin.plugin_service import PluginService
|
||||
from core.trigger.entities.api_entities import EventApiEntity, TriggerProviderApiEntity
|
||||
from core.trigger.entities.entities import (
|
||||
EventEntity,
|
||||
@ -30,7 +31,6 @@ from core.trigger.entities.entities import (
|
||||
)
|
||||
from core.trigger.errors import TriggerProviderCredentialValidationError
|
||||
from models.provider_ids import TriggerProviderID
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -492,8 +492,8 @@ class App(Base):
|
||||
|
||||
@property
|
||||
def deleted_tools(self) -> list[DeletedToolInfo]:
|
||||
from core.plugin.plugin_service import PluginService
|
||||
from core.tools.tool_manager import ToolManager, ToolProviderType
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
# get agent mode tools
|
||||
app_model_config = self.app_model_config
|
||||
|
||||
@ -14,13 +14,13 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.plugin.plugin_service import PluginService
|
||||
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.model_runtime.entities.provider_entities import FormType
|
||||
from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
|
||||
from models.provider_ids import DatasourceProviderID
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -22,6 +22,7 @@ from core.helper import marketplace
|
||||
from core.plugin.entities.plugin import PluginInstallationSource
|
||||
from core.plugin.entities.plugin_daemon import PluginInstallTaskStatus
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.plugin.plugin_service import PluginService
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from extensions.ext_database import db
|
||||
from models.account import Tenant
|
||||
@ -29,7 +30,6 @@ from models.model import App, AppMode, AppModelConfig
|
||||
from models.provider_ids import ModelProviderID, ToolProviderID
|
||||
from models.tools import BuiltinToolProvider
|
||||
from models.workflow import Workflow
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -389,17 +389,19 @@ class PluginMigration:
|
||||
for plugin_id in batch_plugin_ids
|
||||
if plugin_id not in installed_plugins_ids and plugin_id in plugins["plugins"]
|
||||
]
|
||||
manager.install_from_identifiers(
|
||||
tenant_id,
|
||||
batch_plugin_identifiers,
|
||||
PluginInstallationSource.Marketplace,
|
||||
metas=[
|
||||
{
|
||||
"plugin_unique_identifier": identifier,
|
||||
}
|
||||
for identifier in batch_plugin_identifiers
|
||||
],
|
||||
)
|
||||
if batch_plugin_identifiers:
|
||||
manager.install_from_identifiers(
|
||||
tenant_id,
|
||||
batch_plugin_identifiers,
|
||||
PluginInstallationSource.Marketplace,
|
||||
metas=[
|
||||
{
|
||||
"plugin_unique_identifier": identifier,
|
||||
}
|
||||
for identifier in batch_plugin_identifiers
|
||||
],
|
||||
)
|
||||
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
|
||||
|
||||
with open(extracted_plugins) as f:
|
||||
"""
|
||||
@ -595,6 +597,7 @@ class PluginMigration:
|
||||
for identifier in batch_plugin_identifiers
|
||||
],
|
||||
)
|
||||
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
|
||||
except Exception:
|
||||
# add to failed
|
||||
failed.extend(batch_plugin_identifiers)
|
||||
@ -609,6 +612,7 @@ class PluginMigration:
|
||||
while not done:
|
||||
status = manager.fetch_plugin_installation_task(tenant_id, task_id)
|
||||
if status.status in [PluginInstallTaskStatus.Failed, PluginInstallTaskStatus.Success]:
|
||||
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
|
||||
for plugin in status.plugins:
|
||||
if plugin.status == PluginInstallTaskStatus.Success:
|
||||
success.append(reverse_map[plugin.plugin_unique_identifier])
|
||||
|
||||
@ -12,6 +12,7 @@ from sqlalchemy import select
|
||||
from configs import dify_config
|
||||
from constants import DOCUMENT_EXTENSIONS
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.plugin.plugin_service import PluginService
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
@ -22,7 +23,6 @@ from models.model import UploadFile
|
||||
from models.workflow import Workflow, WorkflowType
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration, RetrievalSetting
|
||||
from services.plugin.plugin_migration import PluginMigration
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ from core.helper.name_generator import generate_incremental_name
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.plugin_service import PluginService
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||
from core.tools.entities.api_entities import (
|
||||
@ -31,7 +32,6 @@ from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider_ids import ToolProviderID
|
||||
from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
|
||||
from services.plugin.plugin_service import PluginService
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -9,6 +9,7 @@ from configs import dify_config
|
||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||
from core.mcp.types import Tool as MCPTool
|
||||
from core.plugin.entities.plugin_daemon import CredentialType, PluginDatasourceProviderEntity
|
||||
from core.plugin.plugin_service import PluginService
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
@ -27,7 +28,6 @@ from core.tools.utils.encryption import create_provider_encrypter, create_tool_p
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -14,6 +14,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.helper.provider_encryption import ProviderConfigEncrypter, create_provider_encrypter
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.plugin.plugin_service import PluginService
|
||||
from core.tools.utils.system_encryption import decrypt_system_params
|
||||
from core.trigger.entities.api_entities import (
|
||||
TriggerProviderApiEntity,
|
||||
@ -37,7 +38,6 @@ from models.trigger import (
|
||||
TriggerSubscription,
|
||||
WorkflowPluginTrigger,
|
||||
)
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -6,11 +6,11 @@ from typing import Any, TypedDict
|
||||
from sqlalchemy import and_, func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.plugin.plugin_service import PluginService
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun
|
||||
from models.enums import AppTriggerType, CreatorUserRole
|
||||
from models.trigger import WorkflowTriggerLog
|
||||
from services.plugin.plugin_service import PluginService
|
||||
from services.workflow.entities import TriggerMetadata
|
||||
|
||||
|
||||
|
||||
@ -9,9 +9,9 @@ from celery import shared_task
|
||||
from core.plugin.entities.marketplace import MarketplacePluginSnapshot
|
||||
from core.plugin.entities.plugin import PluginInstallationSource
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.plugin.plugin_service import PluginService
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import TenantPluginAutoUpgradeStrategy
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
"""Tests for services.plugin.plugin_service.PluginService.
|
||||
"""Tests for core.plugin.plugin_service.PluginService.
|
||||
|
||||
Covers: version caching with Redis, install permission/scope gates,
|
||||
icon URL construction, asset retrieval with MIME guessing, plugin
|
||||
@ -17,11 +17,11 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from core.plugin.entities.plugin import PluginInstallationSource
|
||||
from core.plugin.entities.plugin_daemon import PluginVerification
|
||||
from core.plugin.plugin_service import PluginService
|
||||
from models import ProviderType
|
||||
from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider
|
||||
from services.errors.plugin import PluginInstallationForbiddenError
|
||||
from services.feature_service import PluginInstallationScope
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
|
||||
def _make_features(
|
||||
@ -35,8 +35,8 @@ def _make_features(
|
||||
|
||||
|
||||
class TestFetchLatestPluginVersion:
|
||||
@patch("services.plugin.plugin_service.marketplace")
|
||||
@patch("services.plugin.plugin_service.redis_client")
|
||||
@patch("core.plugin.plugin_service.marketplace")
|
||||
@patch("core.plugin.plugin_service.redis_client")
|
||||
def test_returns_cached_version(self, mock_redis, mock_marketplace):
|
||||
cached_json = PluginService.LatestPluginCache(
|
||||
plugin_id="p1",
|
||||
@ -53,8 +53,8 @@ class TestFetchLatestPluginVersion:
|
||||
assert result["p1"].version == "1.0.0"
|
||||
mock_marketplace.batch_fetch_plugin_manifests.assert_not_called()
|
||||
|
||||
@patch("services.plugin.plugin_service.marketplace")
|
||||
@patch("services.plugin.plugin_service.redis_client")
|
||||
@patch("core.plugin.plugin_service.marketplace")
|
||||
@patch("core.plugin.plugin_service.redis_client")
|
||||
def test_fetches_from_marketplace_on_cache_miss(self, mock_redis, mock_marketplace):
|
||||
mock_redis.get.return_value = None
|
||||
manifest = MagicMock()
|
||||
@ -71,8 +71,8 @@ class TestFetchLatestPluginVersion:
|
||||
assert result["p1"].version == "2.0.0"
|
||||
mock_redis.setex.assert_called_once()
|
||||
|
||||
@patch("services.plugin.plugin_service.marketplace")
|
||||
@patch("services.plugin.plugin_service.redis_client")
|
||||
@patch("core.plugin.plugin_service.marketplace")
|
||||
@patch("core.plugin.plugin_service.redis_client")
|
||||
def test_returns_none_for_unknown_plugin(self, mock_redis, mock_marketplace):
|
||||
mock_redis.get.return_value = None
|
||||
mock_marketplace.batch_fetch_plugin_manifests.return_value = []
|
||||
@ -81,8 +81,8 @@ class TestFetchLatestPluginVersion:
|
||||
|
||||
assert result["unknown"] is None
|
||||
|
||||
@patch("services.plugin.plugin_service.marketplace")
|
||||
@patch("services.plugin.plugin_service.redis_client")
|
||||
@patch("core.plugin.plugin_service.marketplace")
|
||||
@patch("core.plugin.plugin_service.redis_client")
|
||||
def test_handles_marketplace_exception_gracefully(self, mock_redis, mock_marketplace):
|
||||
mock_redis.get.return_value = None
|
||||
mock_marketplace.batch_fetch_plugin_manifests.side_effect = RuntimeError("network error")
|
||||
@ -93,14 +93,14 @@ class TestFetchLatestPluginVersion:
|
||||
|
||||
|
||||
class TestCheckMarketplaceOnlyPermission:
|
||||
@patch("services.plugin.plugin_service.FeatureService")
|
||||
@patch("core.plugin.plugin_service.FeatureService")
|
||||
def test_raises_when_restricted(self, mock_fs):
|
||||
mock_fs.get_system_features.return_value = _make_features(restrict_to_marketplace=True)
|
||||
|
||||
with pytest.raises(PluginInstallationForbiddenError):
|
||||
PluginService._check_marketplace_only_permission()
|
||||
|
||||
@patch("services.plugin.plugin_service.FeatureService")
|
||||
@patch("core.plugin.plugin_service.FeatureService")
|
||||
def test_passes_when_not_restricted(self, mock_fs):
|
||||
mock_fs.get_system_features.return_value = _make_features(restrict_to_marketplace=False)
|
||||
|
||||
@ -108,7 +108,7 @@ class TestCheckMarketplaceOnlyPermission:
|
||||
|
||||
|
||||
class TestCheckPluginInstallationScope:
|
||||
@patch("services.plugin.plugin_service.FeatureService")
|
||||
@patch("core.plugin.plugin_service.FeatureService")
|
||||
def test_official_only_allows_langgenius(self, mock_fs):
|
||||
mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.OFFICIAL_ONLY)
|
||||
verification = MagicMock()
|
||||
@ -116,14 +116,14 @@ class TestCheckPluginInstallationScope:
|
||||
|
||||
PluginService._check_plugin_installation_scope(verification) # should not raise
|
||||
|
||||
@patch("services.plugin.plugin_service.FeatureService")
|
||||
@patch("core.plugin.plugin_service.FeatureService")
|
||||
def test_official_only_rejects_third_party(self, mock_fs):
|
||||
mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.OFFICIAL_ONLY)
|
||||
|
||||
with pytest.raises(PluginInstallationForbiddenError):
|
||||
PluginService._check_plugin_installation_scope(None)
|
||||
|
||||
@patch("services.plugin.plugin_service.FeatureService")
|
||||
@patch("core.plugin.plugin_service.FeatureService")
|
||||
def test_official_and_partners_allows_partner(self, mock_fs):
|
||||
mock_fs.get_system_features.return_value = _make_features(
|
||||
scope=PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS
|
||||
@ -133,7 +133,7 @@ class TestCheckPluginInstallationScope:
|
||||
|
||||
PluginService._check_plugin_installation_scope(verification) # should not raise
|
||||
|
||||
@patch("services.plugin.plugin_service.FeatureService")
|
||||
@patch("core.plugin.plugin_service.FeatureService")
|
||||
def test_official_and_partners_rejects_none(self, mock_fs):
|
||||
mock_fs.get_system_features.return_value = _make_features(
|
||||
scope=PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS
|
||||
@ -142,7 +142,7 @@ class TestCheckPluginInstallationScope:
|
||||
with pytest.raises(PluginInstallationForbiddenError):
|
||||
PluginService._check_plugin_installation_scope(None)
|
||||
|
||||
@patch("services.plugin.plugin_service.FeatureService")
|
||||
@patch("core.plugin.plugin_service.FeatureService")
|
||||
def test_none_scope_always_raises(self, mock_fs):
|
||||
mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.NONE)
|
||||
verification = MagicMock()
|
||||
@ -151,7 +151,7 @@ class TestCheckPluginInstallationScope:
|
||||
with pytest.raises(PluginInstallationForbiddenError):
|
||||
PluginService._check_plugin_installation_scope(verification)
|
||||
|
||||
@patch("services.plugin.plugin_service.FeatureService")
|
||||
@patch("core.plugin.plugin_service.FeatureService")
|
||||
def test_all_scope_passes_any(self, mock_fs):
|
||||
mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.ALL)
|
||||
|
||||
@ -159,7 +159,7 @@ class TestCheckPluginInstallationScope:
|
||||
|
||||
|
||||
class TestGetPluginIconUrl:
|
||||
@patch("services.plugin.plugin_service.dify_config")
|
||||
@patch("core.plugin.plugin_service.dify_config")
|
||||
def test_constructs_url_with_params(self, mock_config):
|
||||
mock_config.CONSOLE_API_URL = "https://console.example.com"
|
||||
|
||||
@ -171,7 +171,7 @@ class TestGetPluginIconUrl:
|
||||
|
||||
|
||||
class TestGetAsset:
|
||||
@patch("services.plugin.plugin_service.PluginAssetManager")
|
||||
@patch("core.plugin.plugin_service.PluginAssetManager")
|
||||
def test_returns_bytes_and_guessed_mime(self, mock_asset_cls):
|
||||
mock_asset_cls.return_value.fetch_asset.return_value = b"<svg/>"
|
||||
|
||||
@ -180,7 +180,7 @@ class TestGetAsset:
|
||||
assert data == b"<svg/>"
|
||||
assert "svg" in mime
|
||||
|
||||
@patch("services.plugin.plugin_service.PluginAssetManager")
|
||||
@patch("core.plugin.plugin_service.PluginAssetManager")
|
||||
def test_fallback_to_octet_stream_for_unknown(self, mock_asset_cls):
|
||||
mock_asset_cls.return_value.fetch_asset.return_value = b"\x00"
|
||||
|
||||
@ -190,13 +190,13 @@ class TestGetAsset:
|
||||
|
||||
|
||||
class TestIsPluginVerified:
|
||||
@patch("services.plugin.plugin_service.PluginInstaller")
|
||||
@patch("core.plugin.plugin_service.PluginInstaller")
|
||||
def test_returns_true_when_verified(self, mock_installer_cls):
|
||||
mock_installer_cls.return_value.fetch_plugin_manifest.return_value.verified = True
|
||||
|
||||
assert PluginService.is_plugin_verified("t1", "uid-1") is True
|
||||
|
||||
@patch("services.plugin.plugin_service.PluginInstaller")
|
||||
@patch("core.plugin.plugin_service.PluginInstaller")
|
||||
def test_returns_false_on_exception(self, mock_installer_cls):
|
||||
mock_installer_cls.return_value.fetch_plugin_manifest.side_effect = RuntimeError("not found")
|
||||
|
||||
@ -204,24 +204,24 @@ class TestIsPluginVerified:
|
||||
|
||||
|
||||
class TestUpgradePluginWithMarketplace:
|
||||
@patch("services.plugin.plugin_service.dify_config")
|
||||
@patch("core.plugin.plugin_service.dify_config")
|
||||
def test_raises_when_marketplace_disabled(self, mock_config):
|
||||
mock_config.MARKETPLACE_ENABLED = False
|
||||
|
||||
with pytest.raises(ValueError, match="marketplace is not enabled"):
|
||||
PluginService.upgrade_plugin_with_marketplace("t1", "old-uid", "new-uid")
|
||||
|
||||
@patch("services.plugin.plugin_service.dify_config")
|
||||
@patch("core.plugin.plugin_service.dify_config")
|
||||
def test_raises_when_same_identifier(self, mock_config):
|
||||
mock_config.MARKETPLACE_ENABLED = True
|
||||
|
||||
with pytest.raises(ValueError, match="same plugin"):
|
||||
PluginService.upgrade_plugin_with_marketplace("t1", "same-uid", "same-uid")
|
||||
|
||||
@patch("services.plugin.plugin_service.marketplace")
|
||||
@patch("services.plugin.plugin_service.FeatureService")
|
||||
@patch("services.plugin.plugin_service.PluginInstaller")
|
||||
@patch("services.plugin.plugin_service.dify_config")
|
||||
@patch("core.plugin.plugin_service.marketplace")
|
||||
@patch("core.plugin.plugin_service.FeatureService")
|
||||
@patch("core.plugin.plugin_service.PluginInstaller")
|
||||
@patch("core.plugin.plugin_service.dify_config")
|
||||
def test_skips_download_when_already_installed(self, mock_config, mock_installer_cls, mock_fs, mock_marketplace):
|
||||
mock_config.MARKETPLACE_ENABLED = True
|
||||
mock_fs.get_system_features.return_value = _make_features()
|
||||
@ -234,10 +234,10 @@ class TestUpgradePluginWithMarketplace:
|
||||
mock_marketplace.record_install_plugin_event.assert_called_once_with("new-uid")
|
||||
installer.upgrade_plugin.assert_called_once()
|
||||
|
||||
@patch("services.plugin.plugin_service.download_plugin_pkg")
|
||||
@patch("services.plugin.plugin_service.FeatureService")
|
||||
@patch("services.plugin.plugin_service.PluginInstaller")
|
||||
@patch("services.plugin.plugin_service.dify_config")
|
||||
@patch("core.plugin.plugin_service.download_plugin_pkg")
|
||||
@patch("core.plugin.plugin_service.FeatureService")
|
||||
@patch("core.plugin.plugin_service.PluginInstaller")
|
||||
@patch("core.plugin.plugin_service.dify_config")
|
||||
def test_downloads_when_not_installed(self, mock_config, mock_installer_cls, mock_fs, mock_download):
|
||||
mock_config.MARKETPLACE_ENABLED = True
|
||||
mock_fs.get_system_features.return_value = _make_features()
|
||||
@ -256,8 +256,8 @@ class TestUpgradePluginWithMarketplace:
|
||||
|
||||
|
||||
class TestUpgradePluginWithGithub:
|
||||
@patch("services.plugin.plugin_service.FeatureService")
|
||||
@patch("services.plugin.plugin_service.PluginInstaller")
|
||||
@patch("core.plugin.plugin_service.FeatureService")
|
||||
@patch("core.plugin.plugin_service.PluginInstaller")
|
||||
def test_checks_marketplace_permission_and_delegates(self, mock_installer_cls, mock_fs):
|
||||
mock_fs.get_system_features.return_value = _make_features()
|
||||
installer = mock_installer_cls.return_value
|
||||
@ -271,8 +271,8 @@ class TestUpgradePluginWithGithub:
|
||||
|
||||
|
||||
class TestUploadPkg:
|
||||
@patch("services.plugin.plugin_service.FeatureService")
|
||||
@patch("services.plugin.plugin_service.PluginInstaller")
|
||||
@patch("core.plugin.plugin_service.FeatureService")
|
||||
@patch("core.plugin.plugin_service.PluginInstaller")
|
||||
def test_runs_permission_and_scope_checks(self, mock_installer_cls, mock_fs):
|
||||
mock_fs.get_system_features.return_value = _make_features()
|
||||
upload_resp = MagicMock()
|
||||
@ -285,17 +285,17 @@ class TestUploadPkg:
|
||||
|
||||
|
||||
class TestInstallFromMarketplacePkg:
|
||||
@patch("services.plugin.plugin_service.dify_config")
|
||||
@patch("core.plugin.plugin_service.dify_config")
|
||||
def test_raises_when_marketplace_disabled(self, mock_config):
|
||||
mock_config.MARKETPLACE_ENABLED = False
|
||||
|
||||
with pytest.raises(ValueError, match="marketplace is not enabled"):
|
||||
PluginService.install_from_marketplace_pkg("t1", ["uid-1"])
|
||||
|
||||
@patch("services.plugin.plugin_service.download_plugin_pkg")
|
||||
@patch("services.plugin.plugin_service.FeatureService")
|
||||
@patch("services.plugin.plugin_service.PluginInstaller")
|
||||
@patch("services.plugin.plugin_service.dify_config")
|
||||
@patch("core.plugin.plugin_service.download_plugin_pkg")
|
||||
@patch("core.plugin.plugin_service.FeatureService")
|
||||
@patch("core.plugin.plugin_service.PluginInstaller")
|
||||
@patch("core.plugin.plugin_service.dify_config")
|
||||
def test_downloads_when_not_cached(self, mock_config, mock_installer_cls, mock_fs, mock_download):
|
||||
mock_config.MARKETPLACE_ENABLED = True
|
||||
mock_fs.get_system_features.return_value = _make_features()
|
||||
@ -315,9 +315,9 @@ class TestInstallFromMarketplacePkg:
|
||||
call_args = installer.install_from_identifiers.call_args[0]
|
||||
assert call_args[1] == ["resolved-uid"]
|
||||
|
||||
@patch("services.plugin.plugin_service.FeatureService")
|
||||
@patch("services.plugin.plugin_service.PluginInstaller")
|
||||
@patch("services.plugin.plugin_service.dify_config")
|
||||
@patch("core.plugin.plugin_service.FeatureService")
|
||||
@patch("core.plugin.plugin_service.PluginInstaller")
|
||||
@patch("core.plugin.plugin_service.dify_config")
|
||||
def test_uses_cached_when_already_downloaded(self, mock_config, mock_installer_cls, mock_fs):
|
||||
mock_config.MARKETPLACE_ENABLED = True
|
||||
mock_fs.get_system_features.return_value = _make_features()
|
||||
@ -336,7 +336,7 @@ class TestInstallFromMarketplacePkg:
|
||||
|
||||
|
||||
class TestUninstall:
|
||||
@patch("services.plugin.plugin_service.PluginInstaller")
|
||||
@patch("core.plugin.plugin_service.PluginInstaller")
|
||||
def test_direct_uninstall_when_plugin_not_found(self, mock_installer_cls):
|
||||
installer = mock_installer_cls.return_value
|
||||
installer.list_plugins.return_value = []
|
||||
@ -347,7 +347,7 @@ class TestUninstall:
|
||||
assert result is True
|
||||
installer.uninstall.assert_called_once_with("t1", "install-1")
|
||||
|
||||
@patch("services.plugin.plugin_service.PluginInstaller")
|
||||
@patch("core.plugin.plugin_service.PluginInstaller")
|
||||
def test_cleans_credentials_when_plugin_found(
|
||||
self, mock_installer_cls, flask_app_with_containers: Flask, db_session_with_containers: Session
|
||||
):
|
||||
@ -389,7 +389,7 @@ class TestUninstall:
|
||||
installer.list_plugins.return_value = [plugin]
|
||||
installer.uninstall.return_value = True
|
||||
|
||||
with patch("services.plugin.plugin_service.dify_config") as mock_config:
|
||||
with patch("core.plugin.plugin_service.dify_config") as mock_config:
|
||||
mock_config.ENTERPRISE_ENABLED = False
|
||||
result = PluginService.uninstall(tenant_id, "install-1")
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.plugin.plugin_service import PluginService
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
@ -20,7 +21,6 @@ from core.tools.entities.tool_entities import (
|
||||
ToolProviderType,
|
||||
)
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||
from services.plugin.plugin_service import PluginService
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
|
||||
@ -31,7 +31,7 @@ class TestToolTransformService:
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with patch("services.tools.tools_transform_service.dify_config") as mock_dify_config:
|
||||
with patch("services.plugin.plugin_service.dify_config", new=mock_dify_config):
|
||||
with patch("core.plugin.plugin_service.dify_config", new=mock_dify_config):
|
||||
# Setup default mock returns
|
||||
mock_dify_config.CONSOLE_API_URL = "https://console.example.com"
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
|
||||
from core.plugin.plugin_service import PluginService
|
||||
|
||||
|
||||
def test_plugin_model_assembly_reuses_single_runtime_across_views():
|
||||
@ -34,3 +35,11 @@ def test_plugin_model_assembly_reuses_single_runtime_across_views():
|
||||
mock_provider_factory_cls.assert_called_once_with(runtime=runtime)
|
||||
mock_provider_manager_cls.assert_called_once_with(model_runtime=runtime)
|
||||
mock_model_manager_cls.assert_called_once_with(provider_manager=provider_manager)
|
||||
|
||||
|
||||
def test_create_plugin_model_runtime_injects_plugin_service():
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime
|
||||
|
||||
runtime = create_plugin_model_runtime(tenant_id="tenant-1", user_id="user-1")
|
||||
|
||||
assert runtime._plugin_service is PluginService
|
||||
|
||||
@ -12,6 +12,7 @@ from core.plugin.impl import model_runtime as model_runtime_module
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
from core.plugin.impl.model_runtime import TENANT_SCOPE_SCHEMA_CACHE_USER_ID, PluginModelRuntime
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime
|
||||
from core.plugin.plugin_service import PluginService
|
||||
from graphon.model_runtime.entities.common_entities import I18nObject
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from graphon.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
@ -19,6 +20,22 @@ from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFr
|
||||
from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||
|
||||
|
||||
class _FakeRedis:
|
||||
def __init__(self) -> None:
|
||||
self._values: dict[str, str] = {}
|
||||
self.setex_calls: list[tuple[str, int, str]] = []
|
||||
|
||||
def get(self, key: str) -> str | None:
|
||||
return self._values.get(key)
|
||||
|
||||
def setex(self, key: str, ttl: int, value: str) -> None:
|
||||
self._values[key] = value
|
||||
self.setex_calls.append((key, ttl, value))
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
self._values.pop(key, None)
|
||||
|
||||
|
||||
def _build_model_schema() -> AIModelEntity:
|
||||
return AIModelEntity(
|
||||
model="gpt-4o-mini",
|
||||
@ -29,6 +46,24 @@ def _build_model_schema() -> AIModelEntity:
|
||||
)
|
||||
|
||||
|
||||
def _build_plugin_model_provider(*, tenant_id: str, provider: str = "openai") -> PluginModelProviderEntity:
|
||||
return PluginModelProviderEntity(
|
||||
id=uuid.uuid4().hex,
|
||||
created_at=datetime.datetime.now(),
|
||||
updated_at=datetime.datetime.now(),
|
||||
provider=provider,
|
||||
tenant_id=tenant_id,
|
||||
plugin_unique_identifier=f"langgenius/{provider}/{provider}",
|
||||
plugin_id=f"langgenius/{provider}",
|
||||
declaration=ProviderEntity(
|
||||
provider=provider,
|
||||
label=I18nObject(en_US=provider.title()),
|
||||
supported_model_types=[],
|
||||
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TestPluginModelRuntime:
|
||||
"""Validate the adapter keeps plugin-specific routing out of the runtime port."""
|
||||
|
||||
@ -51,7 +86,7 @@ class TestPluginModelRuntime:
|
||||
),
|
||||
)
|
||||
]
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService)
|
||||
|
||||
providers = runtime.fetch_model_providers()
|
||||
|
||||
@ -95,7 +130,7 @@ class TestPluginModelRuntime:
|
||||
),
|
||||
),
|
||||
]
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService)
|
||||
|
||||
providers = runtime.fetch_model_providers()
|
||||
|
||||
@ -122,7 +157,7 @@ class TestPluginModelRuntime:
|
||||
),
|
||||
)
|
||||
]
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService)
|
||||
|
||||
providers = runtime.fetch_model_providers()
|
||||
|
||||
@ -131,7 +166,7 @@ class TestPluginModelRuntime:
|
||||
|
||||
def test_validate_provider_credentials_resolves_plugin_fields(self) -> None:
|
||||
client = Mock(spec=PluginModelClient)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService)
|
||||
|
||||
runtime.validate_provider_credentials(
|
||||
provider="langgenius/openai/openai",
|
||||
@ -173,7 +208,7 @@ class TestPluginModelRuntime:
|
||||
),
|
||||
]
|
||||
)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService)
|
||||
|
||||
result = runtime.invoke_llm(
|
||||
provider="langgenius/openai/openai",
|
||||
@ -209,7 +244,7 @@ class TestPluginModelRuntime:
|
||||
client = Mock(spec=PluginModelClient)
|
||||
stream_result = iter([])
|
||||
client.invoke_llm.return_value = stream_result
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService)
|
||||
|
||||
result = runtime.invoke_llm(
|
||||
provider="langgenius/openai/openai",
|
||||
@ -240,7 +275,9 @@ class TestPluginModelRuntime:
|
||||
def test_invoke_llm_rejects_per_call_user_override(self) -> None:
|
||||
client = Mock(spec=PluginModelClient)
|
||||
client.invoke_llm.return_value = sentinel.result
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="bound-user", client=client)
|
||||
runtime = PluginModelRuntime(
|
||||
tenant_id="tenant", user_id="bound-user", client=client, plugin_service=PluginService
|
||||
)
|
||||
|
||||
with pytest.raises(TypeError, match="unexpected keyword argument 'user_id'"):
|
||||
runtime.invoke_llm( # type: ignore[call-arg]
|
||||
@ -260,7 +297,7 @@ class TestPluginModelRuntime:
|
||||
def test_invoke_tts_uses_bound_runtime_user_when_runtime_is_unbound(self) -> None:
|
||||
client = Mock(spec=PluginModelClient)
|
||||
client.invoke_tts.return_value = iter([b"chunk"])
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=client)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=client, plugin_service=PluginService)
|
||||
|
||||
result = runtime.invoke_tts(
|
||||
provider="langgenius/openai/openai",
|
||||
@ -282,15 +319,107 @@ class TestPluginModelRuntime:
|
||||
voice="alloy",
|
||||
)
|
||||
|
||||
def test_fetch_model_providers_uses_bound_runtime_cache(self) -> None:
|
||||
def test_fetch_model_providers_does_not_keep_bound_runtime_cache(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
client = Mock(spec=PluginModelClient)
|
||||
client.fetch_model_providers.return_value = []
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
|
||||
from core.plugin import plugin_service as plugin_service_module
|
||||
|
||||
monkeypatch.setattr(
|
||||
plugin_service_module,
|
||||
"redis_client",
|
||||
SimpleNamespace(
|
||||
get=Mock(return_value=None),
|
||||
delete=Mock(),
|
||||
setex=Mock(),
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(plugin_service_module.dify_config, "PLUGIN_MODEL_PROVIDERS_CACHE_TTL", 300)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService)
|
||||
|
||||
runtime.fetch_model_providers()
|
||||
runtime.fetch_model_providers()
|
||||
|
||||
client.fetch_model_providers.assert_called_once_with("tenant")
|
||||
assert client.fetch_model_providers.call_count == 2
|
||||
|
||||
def test_fetch_model_providers_uses_tenant_ttl_cache_across_runtime_instances(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
redis = _FakeRedis()
|
||||
from core.plugin import plugin_service as plugin_service_module
|
||||
|
||||
monkeypatch.setattr(plugin_service_module, "redis_client", redis)
|
||||
monkeypatch.setattr(plugin_service_module.dify_config, "PLUGIN_MODEL_PROVIDERS_CACHE_TTL", 300)
|
||||
first_client = Mock(spec=PluginModelClient)
|
||||
first_client.fetch_model_providers.return_value = [_build_plugin_model_provider(tenant_id="tenant")]
|
||||
second_client = Mock(spec=PluginModelClient)
|
||||
first_runtime = PluginModelRuntime(
|
||||
tenant_id="tenant", user_id="user-a", client=first_client, plugin_service=PluginService
|
||||
)
|
||||
second_runtime = PluginModelRuntime(
|
||||
tenant_id="tenant", user_id="user-b", client=second_client, plugin_service=PluginService
|
||||
)
|
||||
|
||||
first_providers = first_runtime.fetch_model_providers()
|
||||
second_providers = second_runtime.fetch_model_providers()
|
||||
|
||||
assert [provider.provider for provider in first_providers] == ["langgenius/openai/openai"]
|
||||
assert [provider.provider for provider in second_providers] == ["langgenius/openai/openai"]
|
||||
first_client.fetch_model_providers.assert_called_once_with("tenant")
|
||||
second_client.fetch_model_providers.assert_not_called()
|
||||
assert redis.setex_calls[0][1] == 300
|
||||
|
||||
def test_fetch_model_providers_cache_is_tenant_isolated(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
redis = _FakeRedis()
|
||||
from core.plugin import plugin_service as plugin_service_module
|
||||
|
||||
monkeypatch.setattr(plugin_service_module, "redis_client", redis)
|
||||
monkeypatch.setattr(plugin_service_module.dify_config, "PLUGIN_MODEL_PROVIDERS_CACHE_TTL", 300)
|
||||
first_client = Mock(spec=PluginModelClient)
|
||||
first_client.fetch_model_providers.return_value = [_build_plugin_model_provider(tenant_id="tenant-a")]
|
||||
second_client = Mock(spec=PluginModelClient)
|
||||
second_client.fetch_model_providers.return_value = [_build_plugin_model_provider(tenant_id="tenant-b")]
|
||||
first_runtime = PluginModelRuntime(
|
||||
tenant_id="tenant-a", user_id="user", client=first_client, plugin_service=PluginService
|
||||
)
|
||||
second_runtime = PluginModelRuntime(
|
||||
tenant_id="tenant-b", user_id="user", client=second_client, plugin_service=PluginService
|
||||
)
|
||||
|
||||
first_providers = first_runtime.fetch_model_providers()
|
||||
second_providers = second_runtime.fetch_model_providers()
|
||||
|
||||
assert [provider.provider for provider in first_providers] == ["langgenius/openai/openai"]
|
||||
assert [provider.provider for provider in second_providers] == ["langgenius/openai/openai"]
|
||||
first_client.fetch_model_providers.assert_called_once_with("tenant-a")
|
||||
second_client.fetch_model_providers.assert_called_once_with("tenant-b")
|
||||
assert len(redis.setex_calls) == 2
|
||||
|
||||
def test_fetch_model_providers_delegates_cache_to_injected_plugin_service(self) -> None:
|
||||
client = Mock(spec=PluginModelClient)
|
||||
service_result = [
|
||||
ProviderEntity(
|
||||
provider="langgenius/openai/openai",
|
||||
label=I18nObject(en_US="OpenAI"),
|
||||
supported_model_types=[],
|
||||
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
|
||||
)
|
||||
]
|
||||
fetch_plugin_model_providers = Mock(return_value=service_result)
|
||||
|
||||
class TestPluginService(PluginService):
|
||||
pass
|
||||
|
||||
TestPluginService.fetch_plugin_model_providers = fetch_plugin_model_providers
|
||||
|
||||
runtime = PluginModelRuntime(
|
||||
tenant_id="tenant", user_id="user", client=client, plugin_service=TestPluginService
|
||||
)
|
||||
|
||||
result = runtime.fetch_model_providers()
|
||||
|
||||
assert result is service_result
|
||||
fetch_plugin_model_providers.assert_called_once_with(tenant_id="tenant", client=client)
|
||||
client.fetch_model_providers.assert_not_called()
|
||||
|
||||
|
||||
def test_create_plugin_model_runtime_without_user_context() -> None:
|
||||
@ -301,7 +430,17 @@ def test_create_plugin_model_runtime_without_user_context() -> None:
|
||||
|
||||
def test_plugin_model_runtime_requires_client() -> None:
|
||||
with pytest.raises(ValueError, match="client is required"):
|
||||
PluginModelRuntime(tenant_id="tenant", user_id="user", client=None) # type: ignore[arg-type]
|
||||
PluginModelRuntime(tenant_id="tenant", user_id="user", client=None, plugin_service=PluginService) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_plugin_model_runtime_requires_plugin_service() -> None:
|
||||
with pytest.raises(ValueError, match="plugin_service is required"):
|
||||
PluginModelRuntime(
|
||||
tenant_id="tenant",
|
||||
user_id="user",
|
||||
client=Mock(spec=PluginModelClient),
|
||||
plugin_service=None, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
def test_get_model_schema_uses_cached_schema_without_hitting_client(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@ -317,7 +456,7 @@ def test_get_model_schema_uses_cached_schema_without_hitting_client(monkeypatch:
|
||||
),
|
||||
)
|
||||
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService)
|
||||
result = runtime.get_model_schema(
|
||||
provider="langgenius/openai/openai",
|
||||
model_type=ModelType.LLM,
|
||||
@ -395,7 +534,7 @@ def test_structured_output_adapter_invokes_bound_runtime_non_streaming() -> None
|
||||
|
||||
def test_invoke_llm_with_structured_output_delegates_with_bound_adapter() -> None:
|
||||
client = Mock(spec=PluginModelClient)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService)
|
||||
schema = _build_model_schema()
|
||||
runtime.get_model_schema = Mock(return_value=schema) # type: ignore[method-assign]
|
||||
|
||||
@ -436,7 +575,7 @@ def test_invoke_llm_with_structured_output_delegates_with_bound_adapter() -> Non
|
||||
|
||||
def test_invoke_llm_with_structured_output_raises_when_model_schema_is_missing() -> None:
|
||||
client = Mock(spec=PluginModelClient)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService)
|
||||
runtime.get_model_schema = Mock(return_value=None) # type: ignore[method-assign]
|
||||
|
||||
with pytest.raises(ValueError, match="Model schema not found for gpt-4o-mini"):
|
||||
@ -468,7 +607,7 @@ def test_get_model_schema_deletes_invalid_cache_and_refetches(monkeypatch: pytes
|
||||
)
|
||||
monkeypatch.setattr(model_runtime_module.dify_config, "PLUGIN_MODEL_SCHEMA_CACHE_TTL", 300)
|
||||
client.get_model_schema.return_value = schema
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService)
|
||||
|
||||
result = runtime.get_model_schema(
|
||||
provider="langgenius/openai/openai",
|
||||
@ -494,7 +633,7 @@ def test_get_model_schema_deletes_invalid_cache_and_refetches(monkeypatch: pytes
|
||||
def test_get_llm_num_tokens_returns_zero_when_plugin_counting_is_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
client = Mock(spec=PluginModelClient)
|
||||
monkeypatch.setattr(model_runtime_module.dify_config, "PLUGIN_BASED_TOKEN_COUNTING_ENABLED", False)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService)
|
||||
|
||||
assert (
|
||||
runtime.get_llm_num_tokens(
|
||||
@ -533,7 +672,7 @@ def test_get_provider_icon_reads_requested_variant_and_detects_svg_mime(monkeypa
|
||||
]
|
||||
fetch_asset = Mock(return_value=b"<svg></svg>")
|
||||
monkeypatch.setattr(model_runtime_module.PluginAssetManager, "fetch_asset", fetch_asset)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService)
|
||||
|
||||
icon_bytes, mime_type = runtime.get_provider_icon(
|
||||
provider="langgenius/openai/openai",
|
||||
@ -565,7 +704,7 @@ def test_get_provider_icon_rejects_unsupported_types_and_missing_variants() -> N
|
||||
),
|
||||
)
|
||||
]
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService)
|
||||
|
||||
with pytest.raises(ValueError, match="does not have small dark icon"):
|
||||
runtime.get_provider_icon(
|
||||
@ -583,7 +722,9 @@ def test_get_provider_icon_rejects_unsupported_types_and_missing_variants() -> N
|
||||
|
||||
|
||||
def test_get_schema_cache_key_is_stable_across_credential_order() -> None:
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=Mock(spec=PluginModelClient))
|
||||
runtime = PluginModelRuntime(
|
||||
tenant_id="tenant", user_id="user", client=Mock(spec=PluginModelClient), plugin_service=PluginService
|
||||
)
|
||||
|
||||
first = runtime._get_schema_cache_key(
|
||||
provider="langgenius/openai/openai",
|
||||
@ -602,8 +743,12 @@ def test_get_schema_cache_key_is_stable_across_credential_order() -> None:
|
||||
|
||||
|
||||
def test_get_schema_cache_key_separates_distinct_user_scopes() -> None:
|
||||
first_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-a", client=Mock(spec=PluginModelClient))
|
||||
second_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-b", client=Mock(spec=PluginModelClient))
|
||||
first_runtime = PluginModelRuntime(
|
||||
tenant_id="tenant", user_id="user-a", client=Mock(spec=PluginModelClient), plugin_service=PluginService
|
||||
)
|
||||
second_runtime = PluginModelRuntime(
|
||||
tenant_id="tenant", user_id="user-b", client=Mock(spec=PluginModelClient), plugin_service=PluginService
|
||||
)
|
||||
|
||||
first = first_runtime._get_schema_cache_key(
|
||||
provider="langgenius/openai/openai",
|
||||
@ -622,8 +767,12 @@ def test_get_schema_cache_key_separates_distinct_user_scopes() -> None:
|
||||
|
||||
|
||||
def test_get_schema_cache_key_separates_tenant_scope_from_user_scope() -> None:
|
||||
tenant_runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=Mock(spec=PluginModelClient))
|
||||
user_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-a", client=Mock(spec=PluginModelClient))
|
||||
tenant_runtime = PluginModelRuntime(
|
||||
tenant_id="tenant", user_id=None, client=Mock(spec=PluginModelClient), plugin_service=PluginService
|
||||
)
|
||||
user_runtime = PluginModelRuntime(
|
||||
tenant_id="tenant", user_id="user-a", client=Mock(spec=PluginModelClient), plugin_service=PluginService
|
||||
)
|
||||
|
||||
tenant_key = tenant_runtime._get_schema_cache_key(
|
||||
provider="langgenius/openai/openai",
|
||||
@ -643,8 +792,12 @@ def test_get_schema_cache_key_separates_tenant_scope_from_user_scope() -> None:
|
||||
|
||||
|
||||
def test_get_schema_cache_key_separates_tenant_scope_from_empty_string_user_scope() -> None:
|
||||
tenant_runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=Mock(spec=PluginModelClient))
|
||||
empty_user_runtime = PluginModelRuntime(tenant_id="tenant", user_id="", client=Mock(spec=PluginModelClient))
|
||||
tenant_runtime = PluginModelRuntime(
|
||||
tenant_id="tenant", user_id=None, client=Mock(spec=PluginModelClient), plugin_service=PluginService
|
||||
)
|
||||
empty_user_runtime = PluginModelRuntime(
|
||||
tenant_id="tenant", user_id="", client=Mock(spec=PluginModelClient), plugin_service=PluginService
|
||||
)
|
||||
|
||||
tenant_key = tenant_runtime._get_schema_cache_key(
|
||||
provider="langgenius/openai/openai",
|
||||
@ -683,7 +836,7 @@ def test_get_provider_schema_supports_short_alias_and_rejects_invalid_provider()
|
||||
),
|
||||
)
|
||||
]
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService)
|
||||
|
||||
assert runtime._get_provider_schema("openai").provider == "langgenius/openai/openai"
|
||||
|
||||
|
||||
@ -11,6 +11,7 @@ This test suite covers:
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from decimal import Decimal
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
@ -197,6 +198,55 @@ class TestAppModelValidation:
|
||||
# Assert
|
||||
assert result == AppMode.CHAT
|
||||
|
||||
def test_deleted_tools_checks_plugin_builtin_providers_through_core_plugin_service(self):
|
||||
"""Plugin-backed built-in tools are checked through core PluginService."""
|
||||
# Arrange
|
||||
app = App(
|
||||
tenant_id="tenant-1",
|
||||
name="Test App",
|
||||
mode=AppMode.CHAT,
|
||||
enable_site=True,
|
||||
enable_api=False,
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
app_model_config = AppModelConfig(
|
||||
app_id=str(uuid4()),
|
||||
agent_mode=json.dumps(
|
||||
{
|
||||
"enabled": True,
|
||||
"strategy": "function_call",
|
||||
"tools": [
|
||||
{
|
||||
"provider_type": "builtin",
|
||||
"provider_id": "langgenius/openai/openai",
|
||||
"tool_name": "chat",
|
||||
"tool_parameters": {},
|
||||
}
|
||||
],
|
||||
"prompt": None,
|
||||
}
|
||||
),
|
||||
)
|
||||
session_context = MagicMock()
|
||||
session_context.__enter__.return_value = MagicMock()
|
||||
session_factory = SimpleNamespace(begin=MagicMock(return_value=session_context))
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch.object(App, "app_model_config", new_callable=lambda: property(lambda self: app_model_config)),
|
||||
patch("models.model.db", SimpleNamespace(engine=object())),
|
||||
patch("models.model.sessionmaker", return_value=session_factory),
|
||||
patch("core.tools.tool_manager.ToolManager.get_hardcoded_provider", side_effect=Exception),
|
||||
patch("core.plugin.plugin_service.PluginService.check_tools_existence", return_value=[False]) as exists,
|
||||
):
|
||||
result = app.deleted_tools
|
||||
|
||||
# Assert
|
||||
assert result == [{"type": "builtin", "tool_name": "chat", "provider_id": "langgenius/openai/openai"}]
|
||||
exists.assert_called_once()
|
||||
assert exists.call_args.args[0] == "tenant-1"
|
||||
assert [str(provider_id) for provider_id in exists.call_args.args[1]] == ["langgenius/openai/openai"]
|
||||
|
||||
|
||||
class TestAppModelConfig:
|
||||
"""Test suite for AppModelConfig model."""
|
||||
|
||||
@ -24,7 +24,7 @@ def make_features(
|
||||
def mock_installer(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Patch PluginInstaller at the service import site."""
|
||||
mock = MagicMock()
|
||||
monkeypatch.setattr("services.plugin.plugin_service.PluginInstaller", lambda: mock)
|
||||
monkeypatch.setattr("core.plugin.plugin_service.PluginInstaller", lambda: mock)
|
||||
return mock
|
||||
|
||||
|
||||
@ -34,6 +34,6 @@ def mock_features():
|
||||
from unittest.mock import patch
|
||||
|
||||
features = make_features()
|
||||
with patch("services.plugin.plugin_service.FeatureService") as mock_fs:
|
||||
with patch("core.plugin.plugin_service.FeatureService") as mock_fs:
|
||||
mock_fs.get_system_features.return_value = features
|
||||
yield features
|
||||
|
||||
@ -61,6 +61,7 @@ class TestHandlePluginInstanceInstall:
|
||||
patch(f"{MIGRATION_MODULE}.dify_config") as mock_cfg,
|
||||
patch(f"{MIGRATION_MODULE}.marketplace") as mock_marketplace,
|
||||
patch(f"{MIGRATION_MODULE}.PluginInstaller") as mock_installer_cls,
|
||||
patch(f"{MIGRATION_MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache,
|
||||
):
|
||||
mock_cfg.MARKETPLACE_ENABLED = True
|
||||
mock_marketplace.download_plugin_pkg.return_value = b"pkg_data"
|
||||
@ -73,4 +74,31 @@ class TestHandlePluginInstanceInstall:
|
||||
)
|
||||
|
||||
mock_marketplace.download_plugin_pkg.assert_called_once()
|
||||
invalidate_cache.assert_called_once_with("tenant1")
|
||||
assert "success" in result or "failed" in result
|
||||
|
||||
def test_install_plugins_invalidates_cache_after_direct_tenant_install(self, tmp_path) -> None:
|
||||
extracted_plugins = tmp_path / "plugins.jsonl"
|
||||
output_file = tmp_path / "output.json"
|
||||
extracted_plugins.write_text('{"tenant_id":"tenant1","plugins":["langgenius/openai"]}\n')
|
||||
|
||||
with (
|
||||
patch(
|
||||
f"{MIGRATION_MODULE}.PluginMigration.extract_unique_plugins",
|
||||
return_value={
|
||||
"plugins": {"langgenius/openai": "langgenius/openai:1.0.0@abc"},
|
||||
"plugin_not_exist": [],
|
||||
},
|
||||
),
|
||||
patch(f"{MIGRATION_MODULE}.PluginMigration.handle_plugin_instance_install", return_value={}),
|
||||
patch(f"{MIGRATION_MODULE}.PluginInstaller") as mock_installer_cls,
|
||||
patch(f"{MIGRATION_MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache,
|
||||
):
|
||||
mock_installer = MagicMock()
|
||||
mock_installer.list_plugins.return_value = []
|
||||
mock_installer_cls.return_value = mock_installer
|
||||
|
||||
PluginMigration.install_plugins(str(extracted_plugins), str(output_file), workers=1)
|
||||
|
||||
mock_installer.install_from_identifiers.assert_called_once()
|
||||
invalidate_cache.assert_called_once_with("tenant1")
|
||||
|
||||
@ -1,6 +1,71 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
import datetime
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
MODULE = "services.plugin.plugin_service"
|
||||
from pydantic import TypeAdapter
|
||||
from redis import RedisError
|
||||
|
||||
from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginInstallTaskStatus, PluginModelProviderEntity
|
||||
from graphon.model_runtime.entities.common_entities import I18nObject
|
||||
from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||
|
||||
MODULE = "core.plugin.plugin_service"
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self) -> None:
|
||||
self.execute = Mock()
|
||||
self.scalars = Mock(return_value=SimpleNamespace(all=Mock(return_value=[])))
|
||||
|
||||
def __enter__(self) -> "_FakeSession":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, traceback) -> None:
|
||||
return None
|
||||
|
||||
def begin(self) -> "_FakeSession":
|
||||
return self
|
||||
|
||||
|
||||
def _build_provider_entity(provider: str = "openai") -> ProviderEntity:
|
||||
return ProviderEntity(
|
||||
provider=f"langgenius/{provider}/{provider}",
|
||||
label=I18nObject(en_US=provider.title()),
|
||||
supported_model_types=[],
|
||||
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
|
||||
)
|
||||
|
||||
|
||||
def _build_plugin_model_provider(*, tenant_id: str = "tenant-1", provider: str = "openai") -> PluginModelProviderEntity:
|
||||
return PluginModelProviderEntity(
|
||||
id=uuid.uuid4().hex,
|
||||
created_at=datetime.datetime.now(),
|
||||
updated_at=datetime.datetime.now(),
|
||||
provider=provider,
|
||||
tenant_id=tenant_id,
|
||||
plugin_unique_identifier=f"langgenius/{provider}/{provider}",
|
||||
plugin_id=f"langgenius/{provider}",
|
||||
declaration=ProviderEntity(
|
||||
provider=provider,
|
||||
label=I18nObject(en_US=provider.title()),
|
||||
supported_model_types=[],
|
||||
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _build_install_task(*, task_id: str = "task-1", status: PluginInstallTaskStatus) -> PluginInstallTask:
|
||||
now = datetime.datetime.now()
|
||||
return PluginInstallTask(
|
||||
id=task_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
status=status,
|
||||
total_plugins=1,
|
||||
completed_plugins=1 if status != PluginInstallTaskStatus.Pending else 0,
|
||||
plugins=[],
|
||||
)
|
||||
|
||||
|
||||
class TestFetchLatestPluginVersion:
|
||||
@ -14,7 +79,7 @@ class TestFetchLatestPluginVersion:
|
||||
mock_cfg.MARKETPLACE_ENABLED = False
|
||||
mock_redis.get.return_value = None # all cache misses
|
||||
|
||||
from services.plugin.plugin_service import PluginService
|
||||
from core.plugin.plugin_service import PluginService
|
||||
|
||||
result = PluginService.fetch_latest_plugin_version(["langgenius/openai", "langgenius/anthropic"])
|
||||
|
||||
@ -40,7 +105,7 @@ class TestFetchLatestPluginVersion:
|
||||
mock_redis.get.return_value = None
|
||||
mock_marketplace.batch_fetch_plugin_manifests.return_value = [manifest]
|
||||
|
||||
from services.plugin.plugin_service import PluginService
|
||||
from core.plugin.plugin_service import PluginService
|
||||
|
||||
result = PluginService.fetch_latest_plugin_version(["langgenius/openai"])
|
||||
|
||||
@ -48,3 +113,322 @@ class TestFetchLatestPluginVersion:
|
||||
mock_marketplace.batch_fetch_plugin_manifests.assert_called_once()
|
||||
assert result["langgenius/openai"] is not None
|
||||
assert result["langgenius/openai"].version == "1.0.0"
|
||||
|
||||
|
||||
class TestPluginModelProviderCache:
|
||||
def test_fetch_plugin_model_providers_returns_cached_provider_without_calling_daemon(self) -> None:
|
||||
"""A valid tenant cache entry is reused across runtime calls without plugin daemon access."""
|
||||
cached_provider = _build_provider_entity()
|
||||
cached_payload = TypeAdapter(list[ProviderEntity]).dump_json([cached_provider]).decode("utf-8")
|
||||
|
||||
with patch(f"{MODULE}.redis_client") as redis_client:
|
||||
redis_client.get.return_value = cached_payload
|
||||
|
||||
from core.plugin.plugin_service import PluginService
|
||||
|
||||
client = Mock()
|
||||
result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1", client=client)
|
||||
|
||||
assert [provider.provider for provider in result] == ["langgenius/openai/openai"]
|
||||
client.fetch_model_providers.assert_not_called()
|
||||
redis_client.setex.assert_not_called()
|
||||
|
||||
def test_fetch_plugin_model_providers_deletes_invalid_cache_and_refetches(self) -> None:
|
||||
"""Invalid cache payloads are tenant-scoped invalidated before falling back to the daemon."""
|
||||
with (
|
||||
patch(f"{MODULE}.redis_client") as redis_client,
|
||||
patch(f"{MODULE}.dify_config") as mock_config,
|
||||
):
|
||||
redis_client.get.return_value = "not-json"
|
||||
mock_config.PLUGIN_MODEL_PROVIDERS_CACHE_TTL = 86400
|
||||
client = Mock()
|
||||
client.fetch_model_providers.return_value = [_build_plugin_model_provider()]
|
||||
|
||||
from core.plugin.plugin_service import PluginService
|
||||
|
||||
result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1", client=client)
|
||||
|
||||
cache_key = "plugin_model_providers:tenant_id:tenant-1"
|
||||
redis_client.delete.assert_called_once_with(cache_key)
|
||||
redis_client.setex.assert_called_once()
|
||||
assert redis_client.setex.call_args.args[0] == cache_key
|
||||
assert redis_client.setex.call_args.args[1] == 86400
|
||||
assert [provider.provider for provider in result] == ["langgenius/openai/openai"]
|
||||
|
||||
def test_fetch_plugin_model_providers_refetches_when_cache_read_fails(self) -> None:
|
||||
"""Redis read failures do not block provider discovery for the tenant."""
|
||||
with patch(f"{MODULE}.redis_client") as redis_client:
|
||||
redis_client.get.side_effect = RedisError("redis unavailable")
|
||||
client = Mock()
|
||||
client.fetch_model_providers.return_value = [_build_plugin_model_provider()]
|
||||
|
||||
from core.plugin.plugin_service import PluginService
|
||||
|
||||
result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1", client=client)
|
||||
|
||||
client.fetch_model_providers.assert_called_once_with("tenant-1")
|
||||
assert [provider.provider for provider in result] == ["langgenius/openai/openai"]
|
||||
|
||||
def test_fetch_plugin_model_providers_returns_fresh_result_when_cache_write_fails(self) -> None:
|
||||
"""Redis write failures are non-fatal after fresh provider data has been fetched."""
|
||||
with patch(f"{MODULE}.redis_client") as redis_client:
|
||||
redis_client.get.return_value = None
|
||||
redis_client.setex.side_effect = RedisError("redis unavailable")
|
||||
client = Mock()
|
||||
client.fetch_model_providers.return_value = [_build_plugin_model_provider()]
|
||||
|
||||
from core.plugin.plugin_service import PluginService
|
||||
|
||||
result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1", client=client)
|
||||
|
||||
client.fetch_model_providers.assert_called_once_with("tenant-1")
|
||||
assert [provider.provider for provider in result] == ["langgenius/openai/openai"]
|
||||
|
||||
def test_fetch_plugin_model_providers_creates_default_client_on_cache_miss(self) -> None:
|
||||
"""The service owns plugin daemon access when no runtime-provided client is injected."""
|
||||
with (
|
||||
patch(f"{MODULE}.redis_client") as redis_client,
|
||||
patch(f"{MODULE}.PluginModelClient") as client_cls,
|
||||
):
|
||||
redis_client.get.return_value = None
|
||||
client = client_cls.return_value
|
||||
client.fetch_model_providers.return_value = [_build_plugin_model_provider()]
|
||||
|
||||
from core.plugin.plugin_service import PluginService
|
||||
|
||||
result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1")
|
||||
|
||||
client_cls.assert_called_once_with()
|
||||
client.fetch_model_providers.assert_called_once_with("tenant-1")
|
||||
assert [provider.provider for provider in result] == ["langgenius/openai/openai"]
|
||||
|
||||
def test_invalidate_plugin_model_providers_cache_uses_tenant_cache_key(self) -> None:
|
||||
with patch(f"{MODULE}.redis_client") as redis_client:
|
||||
from core.plugin.plugin_service import PluginService
|
||||
|
||||
PluginService.invalidate_plugin_model_providers_cache("tenant-1")
|
||||
|
||||
redis_client.delete.assert_called_once_with("plugin_model_providers:tenant_id:tenant-1")
|
||||
|
||||
def test_invalidate_plugin_model_providers_cache_ignores_redis_delete_failure(self) -> None:
|
||||
with patch(f"{MODULE}.redis_client") as redis_client:
|
||||
redis_client.delete.side_effect = RedisError("redis unavailable")
|
||||
|
||||
from core.plugin.plugin_service import PluginService
|
||||
|
||||
PluginService.invalidate_plugin_model_providers_cache("tenant-1")
|
||||
|
||||
redis_client.delete.assert_called_once_with("plugin_model_providers:tenant_id:tenant-1")
|
||||
|
||||
|
||||
class TestPluginModelProviderCacheInvalidation:
|
||||
def test_fetch_install_task_invalidates_model_provider_cache_when_finished(self) -> None:
|
||||
"""Finished plugin install tasks invalidate tenant provider cache."""
|
||||
task = _build_install_task(status=PluginInstallTaskStatus.Success)
|
||||
|
||||
with (
|
||||
patch(f"{MODULE}.PluginInstaller") as installer_cls,
|
||||
patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache,
|
||||
):
|
||||
installer_cls.return_value.fetch_plugin_installation_task.return_value = task
|
||||
|
||||
from core.plugin.plugin_service import PluginService
|
||||
|
||||
result = PluginService.fetch_install_task("tenant-1", "task-1")
|
||||
|
||||
assert result is task
|
||||
invalidate_cache.assert_called_once_with("tenant-1")
|
||||
|
||||
def test_fetch_install_tasks_invalidates_model_provider_cache_for_finished_tasks(self) -> None:
|
||||
"""Finished tasks from task list polling also invalidate tenant provider cache."""
|
||||
task = _build_install_task(status=PluginInstallTaskStatus.Success)
|
||||
|
||||
with (
|
||||
patch(f"{MODULE}.PluginInstaller") as installer_cls,
|
||||
patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache,
|
||||
):
|
||||
installer_cls.return_value.fetch_plugin_installation_tasks.return_value = [task]
|
||||
|
||||
from core.plugin.plugin_service import PluginService
|
||||
|
||||
result = PluginService.fetch_install_tasks("tenant-1", 1, 256)
|
||||
|
||||
assert result == [task]
|
||||
invalidate_cache.assert_called_once_with("tenant-1")
|
||||
|
||||
def test_fetch_install_tasks_ignores_running_model_provider_cache_tasks(self) -> None:
|
||||
"""Running plugin install tasks do not invalidate provider cache until they reach a terminal state."""
|
||||
task = _build_install_task(status=PluginInstallTaskStatus.Running)
|
||||
|
||||
with (
|
||||
patch(f"{MODULE}.PluginInstaller") as installer_cls,
|
||||
patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache,
|
||||
):
|
||||
installer_cls.return_value.fetch_plugin_installation_tasks.return_value = [task]
|
||||
|
||||
from core.plugin.plugin_service import PluginService
|
||||
|
||||
result = PluginService.fetch_install_tasks("tenant-1", 1, 256)
|
||||
|
||||
assert result == [task]
|
||||
invalidate_cache.assert_not_called()
|
||||
|
||||
def test_upgrade_plugin_with_marketplace_invalidates_model_provider_cache_for_tenant(self) -> None:
|
||||
"""Marketplace upgrades invalidate only the mutated tenant provider cache."""
|
||||
with (
|
||||
patch(f"{MODULE}.dify_config") as mock_config,
|
||||
patch(f"{MODULE}.FeatureService") as feature_service,
|
||||
patch(f"{MODULE}.PluginInstaller") as installer_cls,
|
||||
patch(f"{MODULE}.marketplace") as marketplace,
|
||||
patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache,
|
||||
):
|
||||
mock_config.MARKETPLACE_ENABLED = True
|
||||
feature_service.get_system_features.return_value = SimpleNamespace(
|
||||
plugin_installation_permission=SimpleNamespace(restrict_to_marketplace_only=False)
|
||||
)
|
||||
installer = installer_cls.return_value
|
||||
installer.fetch_plugin_manifest.return_value = MagicMock()
|
||||
installer.upgrade_plugin.return_value = "task-id"
|
||||
|
||||
from core.plugin.plugin_service import PluginService
|
||||
|
||||
result = PluginService.upgrade_plugin_with_marketplace("tenant-1", "old-uid", "new-uid")
|
||||
|
||||
assert result == "task-id"
|
||||
marketplace.record_install_plugin_event.assert_called_once_with("new-uid")
|
||||
invalidate_cache.assert_called_once_with("tenant-1")
|
||||
|
||||
def test_install_from_local_pkg_invalidates_model_provider_cache_for_tenant(self) -> None:
|
||||
"""Starting a plugin install invalidates only the mutated tenant provider cache."""
|
||||
with (
|
||||
patch(f"{MODULE}.PluginService._check_marketplace_only_permission"),
|
||||
patch(f"{MODULE}.PluginService._check_plugin_installation_scope"),
|
||||
patch(f"{MODULE}.PluginInstaller") as installer_cls,
|
||||
patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache,
|
||||
):
|
||||
installer = installer_cls.return_value
|
||||
decode_response = MagicMock()
|
||||
decode_response.verification = None
|
||||
installer.decode_plugin_from_identifier.return_value = decode_response
|
||||
installer.install_from_identifiers.return_value = "task-id"
|
||||
|
||||
from core.plugin.plugin_service import PluginService
|
||||
|
||||
result = PluginService.install_from_local_pkg("tenant-1", ["langgenius/openai:1.0.0"])
|
||||
|
||||
assert result == "task-id"
|
||||
invalidate_cache.assert_called_once_with("tenant-1")
|
||||
|
||||
def test_upgrade_plugin_with_github_invalidates_model_provider_cache_for_tenant(self) -> None:
|
||||
"""Starting a plugin upgrade invalidates only the mutated tenant provider cache."""
|
||||
with (
|
||||
patch(f"{MODULE}.PluginService._check_marketplace_only_permission"),
|
||||
patch(f"{MODULE}.PluginInstaller") as installer_cls,
|
||||
patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache,
|
||||
):
|
||||
installer = installer_cls.return_value
|
||||
installer.upgrade_plugin.return_value = "task-id"
|
||||
|
||||
from core.plugin.plugin_service import PluginService
|
||||
|
||||
result = PluginService.upgrade_plugin_with_github(
|
||||
"tenant-1", "old-uid", "new-uid", "langgenius/openai", "1.0.0", "openai.difypkg"
|
||||
)
|
||||
|
||||
assert result == "task-id"
|
||||
invalidate_cache.assert_called_once_with("tenant-1")
|
||||
|
||||
def test_install_from_github_invalidates_model_provider_cache_for_tenant(self) -> None:
|
||||
"""GitHub installs invalidate only the mutated tenant provider cache."""
|
||||
with (
|
||||
patch(f"{MODULE}.PluginService._check_marketplace_only_permission"),
|
||||
patch(f"{MODULE}.PluginService._check_plugin_installation_scope"),
|
||||
patch(f"{MODULE}.PluginInstaller") as installer_cls,
|
||||
patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache,
|
||||
):
|
||||
installer = installer_cls.return_value
|
||||
decode_response = MagicMock()
|
||||
decode_response.verification = None
|
||||
installer.decode_plugin_from_identifier.return_value = decode_response
|
||||
installer.install_from_identifiers.return_value = "task-id"
|
||||
|
||||
from core.plugin.plugin_service import PluginService
|
||||
|
||||
result = PluginService.install_from_github(
|
||||
"tenant-1", "langgenius/openai:1.0.0", "langgenius/openai", "1.0.0", "openai.difypkg"
|
||||
)
|
||||
|
||||
assert result == "task-id"
|
||||
invalidate_cache.assert_called_once_with("tenant-1")
|
||||
|
||||
def test_install_from_marketplace_pkg_invalidates_model_provider_cache_for_tenant(self) -> None:
|
||||
"""Marketplace package installs invalidate only the mutated tenant provider cache."""
|
||||
with (
|
||||
patch(f"{MODULE}.dify_config") as mock_config,
|
||||
patch(f"{MODULE}.FeatureService") as feature_service,
|
||||
patch(f"{MODULE}.PluginService._check_plugin_installation_scope"),
|
||||
patch(f"{MODULE}.PluginInstaller") as installer_cls,
|
||||
patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache,
|
||||
):
|
||||
mock_config.MARKETPLACE_ENABLED = True
|
||||
feature_service.get_system_features.return_value = SimpleNamespace(
|
||||
plugin_installation_permission=SimpleNamespace(restrict_to_marketplace_only=False)
|
||||
)
|
||||
installer = installer_cls.return_value
|
||||
installer.fetch_plugin_manifest.return_value = MagicMock()
|
||||
decode_response = MagicMock()
|
||||
decode_response.verification = None
|
||||
installer.decode_plugin_from_identifier.return_value = decode_response
|
||||
installer.install_from_identifiers.return_value = "task-id"
|
||||
|
||||
from core.plugin.plugin_service import PluginService
|
||||
|
||||
result = PluginService.install_from_marketplace_pkg("tenant-1", ["langgenius/openai:1.0.0"])
|
||||
|
||||
assert result == "task-id"
|
||||
invalidate_cache.assert_called_once_with("tenant-1")
|
||||
|
||||
def test_uninstall_invalidates_model_provider_cache_for_tenant(self) -> None:
|
||||
"""Successful uninstall invalidates only the mutated tenant provider cache."""
|
||||
with (
|
||||
patch(f"{MODULE}.PluginInstaller") as installer_cls,
|
||||
patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache,
|
||||
):
|
||||
installer = installer_cls.return_value
|
||||
installer.list_plugins.return_value = []
|
||||
installer.uninstall.return_value = True
|
||||
|
||||
from core.plugin.plugin_service import PluginService
|
||||
|
||||
result = PluginService.uninstall("tenant-1", "installation-1")
|
||||
|
||||
assert result is True
|
||||
invalidate_cache.assert_called_once_with("tenant-1")
|
||||
|
||||
def test_uninstall_existing_plugin_invalidates_cache_after_credential_cleanup(self) -> None:
|
||||
"""Successful uninstall with plugin metadata also invalidates the mutated tenant provider cache."""
|
||||
plugin = SimpleNamespace(
|
||||
installation_id="installation-1",
|
||||
plugin_id="langgenius/openai",
|
||||
plugin_unique_identifier="langgenius/openai:1.0.0",
|
||||
)
|
||||
session = _FakeSession()
|
||||
with (
|
||||
patch(f"{MODULE}.db", SimpleNamespace(engine=object())),
|
||||
patch(f"{MODULE}.dify_config") as mock_config,
|
||||
patch(f"{MODULE}.PluginInstaller") as installer_cls,
|
||||
patch(f"{MODULE}.Session", return_value=session),
|
||||
patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache,
|
||||
):
|
||||
mock_config.ENTERPRISE_ENABLED = False
|
||||
installer = installer_cls.return_value
|
||||
installer.list_plugins.return_value = [plugin]
|
||||
installer.uninstall.return_value = True
|
||||
|
||||
from core.plugin.plugin_service import PluginService
|
||||
|
||||
result = PluginService.uninstall("tenant-1", "installation-1")
|
||||
|
||||
assert result is True
|
||||
installer.uninstall.assert_called_once_with("tenant-1", "installation-1")
|
||||
invalidate_cache.assert_called_once_with("tenant-1")
|
||||
|
||||
@ -60,6 +60,7 @@ SSRF_PROXY_HTTPS_URL=http://ssrf_proxy:3128
|
||||
PGDATA=/var/lib/postgresql/data/pgdata
|
||||
PLUGIN_MAX_PACKAGE_SIZE=52428800
|
||||
PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600
|
||||
PLUGIN_MODEL_PROVIDERS_CACHE_TTL=86400
|
||||
ENDPOINT_URL_TEMPLATE=http://localhost/e/{hook_id}
|
||||
LOG_LEVEL=INFO
|
||||
LOG_OUTPUT_FORMAT=text
|
||||
|
||||
Loading…
Reference in New Issue
Block a user