"""Import and export customized snippets as YAML DSL documents."""

import json
import logging
import uuid
from collections.abc import Mapping
from datetime import UTC, datetime
from enum import StrEnum
from urllib.parse import urlparse

import yaml
from packaging import version
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session

from core.helper import ssrf_proxy
from core.plugin.entities.plugin import PluginDependency
from extensions.ext_redis import redis_client
from graphon.enums import BuiltinNodeTypes
from graphon.model_runtime.utils.encoders import jsonable_encoder
from models import Account
from models.snippet import CustomizedSnippet, SnippetType
from models.workflow import Workflow
from services.plugin.dependencies_analysis import DependenciesAnalysisService
from services.snippet_service import SNIPPET_FORBIDDEN_NODE_TYPES, SnippetService

logger = logging.getLogger(__name__)

IMPORT_INFO_REDIS_KEY_PREFIX = "snippet_import_info:"
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "snippet_check_dependencies:"
IMPORT_INFO_REDIS_EXPIRY = 10 * 60  # 10 minutes
DSL_MAX_SIZE = 10 * 1024 * 1024  # 10MB
CURRENT_DSL_VERSION = "0.1.0"


class ImportMode(StrEnum):
    """Where the DSL document of an import request comes from."""

    YAML_CONTENT = "yaml-content"
    YAML_URL = "yaml-url"


class ImportStatus(StrEnum):
    """Outcome of an import attempt."""

    COMPLETED = "completed"
    COMPLETED_WITH_WARNINGS = "completed-with-warnings"
    PENDING = "pending"
    FAILED = "failed"


class SnippetImportInfo(BaseModel):
    """Result returned to the caller for every import / confirm request."""

    id: str
    status: ImportStatus
    snippet_id: str | None = None
    current_dsl_version: str = CURRENT_DSL_VERSION
    imported_dsl_version: str = ""
    error: str = ""


class CheckDependenciesResult(BaseModel):
    """Result of a dependency check for a snippet."""

    leaked_dependencies: list[PluginDependency] = Field(default_factory=list)


def _failed_import(import_id: str, error: str) -> SnippetImportInfo:
    """Build a FAILED ``SnippetImportInfo`` carrying *error* as its message."""
    return SnippetImportInfo(id=import_id, status=ImportStatus.FAILED, error=error)


def _exceeds_dsl_max_size(content: str) -> bool:
    """Return True when *content* is larger than ``DSL_MAX_SIZE`` bytes as UTF-8."""
    # A string never encodes to fewer bytes than it has characters, so the
    # character count rejects clearly oversized payloads without encoding them.
    if len(content) > DSL_MAX_SIZE:
        return True
    # ``surrogatepass`` keeps lone surrogates from raising while we only measure.
    return len(content.encode("utf-8", errors="surrogatepass")) > DSL_MAX_SIZE


def _check_version_compatibility(imported_version: str) -> ImportStatus:
    """Determine import status based on version comparison.

    :param imported_version: version string read from the imported DSL
    :return: FAILED for an unparsable version, PENDING when the import needs
        confirmation, COMPLETED_WITH_WARNINGS for an older minor version,
        otherwise COMPLETED
    """
    try:
        current_ver = version.parse(CURRENT_DSL_VERSION)
        imported_ver = version.parse(imported_version)
    except version.InvalidVersion:
        return ImportStatus.FAILED

    # If imported version is newer than current, always return PENDING
    if imported_ver > current_ver:
        return ImportStatus.PENDING

    # If imported version is older than current's major, return PENDING
    if imported_ver.major < current_ver.major:
        return ImportStatus.PENDING

    # If imported version is older than current's minor, return COMPLETED_WITH_WARNINGS
    if imported_ver.minor < current_ver.minor:
        return ImportStatus.COMPLETED_WITH_WARNINGS

    # If imported version equals or is older than current's micro, return COMPLETED
    return ImportStatus.COMPLETED


class SnippetPendingData(BaseModel):
    """Import request parked in Redis until the user confirms it."""

    import_mode: str
    yaml_content: str
    name: str | None = None
    description: str | None = None
    snippet_id: str | None


class CheckDependenciesPendingData(BaseModel):
    """Dependencies recorded for a snippet pending a dependency check."""

    dependencies: list[PluginDependency]
    snippet_id: str | None


class SnippetDslService:
    """Import snippets from, and export snippets to, the YAML DSL format."""

    def __init__(self, session: Session):
        self._session = session

    def _snippet_service(self) -> SnippetService:
        """Return a ``SnippetService`` bound to this service's session."""
        return SnippetService(session=self._session)

    def _load_yaml_content(
        self,
        *,
        mode: ImportMode,
        yaml_content: str | None,
        yaml_url: str | None,
    ) -> tuple[str, str | None]:
        """Resolve the raw DSL text for an import request.

        :return: ``(content, error)``; ``error`` is ``None`` on success and a
            user-facing message otherwise (``content`` is then empty)
        """
        content: str = ""
        if mode == ImportMode.YAML_URL:
            if not yaml_url:
                return "", "yaml_url is required when import_mode is yaml-url"
            try:
                parsed_url = urlparse(yaml_url)
                if parsed_url.scheme not in ("http", "https"):
                    return "", "Invalid URL scheme, only http and https are allowed"

                response = ssrf_proxy.get(yaml_url, timeout=(10, 30))
                if response.status_code != 200:
                    return "", f"Failed to fetch YAML from URL: {response.status_code}"
                content = response.text
            except Exception as e:
                # Boundary handler: any fetch problem becomes a FAILED import
                # result instead of an exception for the caller.
                logger.exception("Failed to fetch YAML from URL")
                return "", f"Failed to fetch YAML from URL: {str(e)}"
        elif mode == ImportMode.YAML_CONTENT:
            if not yaml_content:
                return "", "yaml_content is required when import_mode is yaml-content"
            content = yaml_content

        # The limit (and the message) are expressed in bytes, so measure the
        # encoded size rather than the number of characters.
        if _exceeds_dsl_max_size(content):
            return "", f"YAML content size exceeds maximum limit of {DSL_MAX_SIZE} bytes"
        return content, None

    def import_snippet(
        self,
        *,
        account: Account,
        import_mode: str,
        yaml_content: str | None = None,
        yaml_url: str | None = None,
        snippet_id: str | None = None,
        name: str | None = None,
        description: str | None = None,
    ) -> SnippetImportInfo:
        """Import a snippet from YAML content or URL.

        :param account: account performing the import; its current tenant scopes
            the lookup of ``snippet_id`` and owns a newly created snippet
        :param import_mode: one of the ``ImportMode`` values
        :param yaml_content: DSL text, required for ``yaml-content``
        :param yaml_url: http(s) URL of the DSL, required for ``yaml-url``
        :param snippet_id: existing snippet to overwrite; a new one is created when omitted
        :param name: overrides the name found in the DSL
        :param description: overrides the description found in the DSL
        :return: import result; validation and processing problems are reported
            as ``ImportStatus.FAILED`` with ``error`` set
        :raises ValueError: if ``import_mode`` is not a valid ``ImportMode``
        """
        import_id = str(uuid.uuid4())

        # Validate import mode
        try:
            mode = ImportMode(import_mode)
        except ValueError as exc:
            raise ValueError(f"Invalid import_mode: {import_mode}") from exc

        # Get YAML content
        content, load_error = self._load_yaml_content(mode=mode, yaml_content=yaml_content, yaml_url=yaml_url)
        if load_error is not None:
            return _failed_import(import_id, load_error)

        try:
            # Parse YAML
            data = yaml.safe_load(content)
            if not isinstance(data, dict):
                return _failed_import(import_id, "Invalid YAML format: expected a dictionary")

            # Validate and fix DSL version
            if not data.get("version"):
                data["version"] = "0.1.0"

            # Strictly validate kind field
            kind = data.get("kind")
            if not kind:
                return _failed_import(import_id, "Missing 'kind' field in DSL. Expected 'kind: snippet'.")
            if kind != "snippet":
                return _failed_import(
                    import_id,
                    f"Invalid DSL kind: expected 'snippet', got '{kind}'. This DSL is for {kind}, not snippet.",
                )

            imported_version = data.get("version", "0.1.0")
            if not isinstance(imported_version, str):
                raise ValueError(f"Invalid version type, expected str, got {type(imported_version)}")
            status = _check_version_compatibility(imported_version)

            # Extract snippet data
            snippet_data = data.get("snippet")
            if not snippet_data:
                return _failed_import(import_id, "Missing snippet data in YAML content")

            # Validate workflow nodes - check for forbidden node types
            workflow_data = data.get("workflow", {})
            if workflow_data:
                graph = workflow_data.get("graph", {})
                nodes = graph.get("nodes", [])
                forbidden_nodes_found = []
                for node in nodes:
                    node_data = node.get("data", {})
                    if not node_data:
                        continue
                    node_type = node_data.get("type", "")
                    if node_type in SNIPPET_FORBIDDEN_NODE_TYPES:
                        forbidden_nodes_found.append(node_type)

                if forbidden_nodes_found:
                    # Sorted so the message is stable across runs.
                    forbidden_types_str = ", ".join(sorted(set(forbidden_nodes_found)))
                    return _failed_import(
                        import_id,
                        f"Snippet cannot contain the following node types: {forbidden_types_str}",
                    )

            # If snippet_id is provided, check if it exists
            snippet = None
            if snippet_id:
                stmt = select(CustomizedSnippet).where(
                    CustomizedSnippet.id == snippet_id,
                    CustomizedSnippet.tenant_id == account.current_tenant_id,
                )
                snippet = self._session.scalar(stmt)
                if not snippet:
                    return _failed_import(import_id, "Snippet not found")

            # If major version mismatch, store import info in Redis
            if status == ImportStatus.PENDING:
                pending_data = SnippetPendingData(
                    import_mode=import_mode,
                    yaml_content=content,
                    name=name,
                    description=description,
                    snippet_id=snippet_id,
                )
                redis_client.setex(
                    f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}",
                    IMPORT_INFO_REDIS_EXPIRY,
                    pending_data.model_dump_json(),
                )
                return SnippetImportInfo(
                    id=import_id,
                    status=status,
                    snippet_id=snippet_id,
                    imported_dsl_version=imported_version,
                )

            # Extract dependencies
            dependencies = data.get("dependencies", [])
            check_dependencies_pending_data = None
            if dependencies:
                check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies]

            # Create or update snippet
            snippet = self._create_or_update_snippet(
                snippet=snippet,
                data=data,
                account=account,
                name=name,
                description=description,
                dependencies=check_dependencies_pending_data,
            )

            return SnippetImportInfo(
                id=import_id,
                status=status,
                snippet_id=snippet.id,
                imported_dsl_version=imported_version,
            )

        except yaml.YAMLError as e:
            return _failed_import(import_id, f"Invalid YAML format: {str(e)}")
        except Exception as e:
            # Boundary handler: report any other failure as a FAILED import.
            logger.exception("Failed to import snippet")
            return _failed_import(import_id, str(e))

    def confirm_import(self, *, import_id: str, account: Account) -> SnippetImportInfo:
        """
        Confirm an import that requires confirmation

        Loads the pending request stored in Redis by ``import_snippet``, applies
        it, and removes the Redis entry on success.

        :param import_id: id returned by the PENDING ``import_snippet`` call
        :param account: account confirming the import
        :return: COMPLETED result on success, FAILED (with ``error``) otherwise
        """
        redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}"
        pending_data = redis_client.get(redis_key)

        if not pending_data:
            return _failed_import(import_id, "Import information expired or does not exist")

        try:
            if not isinstance(pending_data, str | bytes):
                return _failed_import(import_id, "Invalid import information")

            pending_data_str = pending_data.decode("utf-8") if isinstance(pending_data, bytes) else pending_data
            pending = SnippetPendingData.model_validate_json(pending_data_str)

            data = yaml.safe_load(pending.yaml_content)
            if not isinstance(data, dict):
                return _failed_import(import_id, "Invalid YAML format: expected a dictionary")

            snippet = None
            if pending.snippet_id:
                stmt = select(CustomizedSnippet).where(
                    CustomizedSnippet.id == pending.snippet_id,
                    CustomizedSnippet.tenant_id == account.current_tenant_id,
                )
                snippet = self._session.scalar(stmt)
                # Same rule as import_snippet: an explicit target that cannot be
                # resolved is an error, not a request to create a new snippet.
                if not snippet:
                    return _failed_import(import_id, "Snippet not found")

            snippet = self._create_or_update_snippet(
                snippet=snippet,
                data=data,
                account=account,
                name=pending.name,
                description=pending.description,
            )

            redis_client.delete(redis_key)

            return SnippetImportInfo(
                id=import_id,
                status=ImportStatus.COMPLETED,
                snippet_id=snippet.id,
                imported_dsl_version=data.get("version", "0.1.0"),
            )

        except Exception as e:
            # Boundary handler: report any failure as a FAILED import.
            logger.exception("Failed to confirm import")
            return _failed_import(import_id, str(e))

    def check_dependencies(self, snippet: CustomizedSnippet) -> CheckDependenciesResult:
        """
        Check dependencies for a snippet

        :param snippet: snippet whose draft workflow is inspected
        :return: dependencies produced by ``DependenciesAnalysisService`` for the
            draft workflow; empty when the snippet has no draft workflow
        """
        snippet_service = self._snippet_service()
        workflow = snippet_service.get_draft_workflow(snippet=snippet)
        if not workflow:
            return CheckDependenciesResult(leaked_dependencies=[])

        dependencies = self._extract_dependencies_from_workflow(workflow)
        leaked_dependencies = DependenciesAnalysisService.generate_dependencies(
            tenant_id=snippet.tenant_id, dependencies=dependencies
        )
        return CheckDependenciesResult(leaked_dependencies=leaked_dependencies)

    def _create_or_update_snippet(
        self,
        *,
        snippet: CustomizedSnippet | None,
        data: dict,
        account: Account,
        name: str | None = None,
        description: str | None = None,
        dependencies: list[PluginDependency] | None = None,
    ) -> CustomizedSnippet:
        """
        Create or update snippet from DSL data

        :param snippet: existing snippet to update, or ``None`` to create one
        :param data: parsed DSL document
        :param account: account recorded as creator / updater
        :param name: overrides the DSL snippet name when truthy
        :param description: overrides the DSL snippet description when truthy
        :param dependencies: accepted for callers; not used by this method
        :return: the created or updated snippet (the session is committed)
        """
        snippet_data = data.get("snippet", {})
        workflow_data = data.get("workflow", {})

        # Extract snippet info
        snippet_name = name or snippet_data.get("name") or "Untitled Snippet"
        snippet_description = description or snippet_data.get("description") or ""
        snippet_type_str = snippet_data.get("type", "node")
        try:
            snippet_type = SnippetType(snippet_type_str)
        except ValueError:
            # Unknown types fall back to a node snippet.
            snippet_type = SnippetType.NODE

        icon_info = snippet_data.get("icon_info", {})
        input_fields = snippet_data.get("input_fields", [])

        # Create or update snippet
        if snippet:
            # Update existing snippet
            snippet.name = snippet_name
            snippet.description = snippet_description
            snippet.type = snippet_type.value
            snippet.icon_info = icon_info or None
            snippet.input_fields = json.dumps(input_fields) if input_fields else None
            snippet.updated_by = account.id
            snippet.updated_at = datetime.now(UTC).replace(tzinfo=None)
        else:
            # Create new snippet
            snippet = CustomizedSnippet(
                tenant_id=account.current_tenant_id,
                name=snippet_name,
                description=snippet_description,
                type=snippet_type.value,
                icon_info=icon_info or None,
                input_fields=json.dumps(input_fields) if input_fields else None,
                created_by=account.id,
            )
            self._session.add(snippet)
            self._session.flush()

        # Create or update draft workflow
        if workflow_data:
            graph = workflow_data.get("graph", {})
            snippet_service = self._snippet_service()

            # Get existing workflow hash if exists
            existing_workflow = snippet_service.get_draft_workflow(snippet=snippet)
            unique_hash = existing_workflow.unique_hash if existing_workflow else None

            snippet_service.sync_draft_workflow(
                snippet=snippet,
                graph=graph,
                unique_hash=unique_hash,
                account=account,
                input_fields=input_fields,
            )

        self._session.commit()
        return snippet

    def export_snippet_dsl(self, snippet: CustomizedSnippet, include_secret: bool = False) -> str:
        """
        Export snippet as DSL

        :param snippet: CustomizedSnippet instance
        :param include_secret: Whether include secret variable
        :return: YAML string
        :raises ValueError: if the snippet has no draft workflow
        """
        snippet_service = self._snippet_service()
        workflow = snippet_service.get_draft_workflow(snippet=snippet)
        if not workflow:
            raise ValueError("Missing draft workflow configuration, please check.")

        icon_info = snippet.icon_info or {}
        export_data = {
            "version": CURRENT_DSL_VERSION,
            "kind": "snippet",
            "snippet": {
                "name": snippet.name,
                "description": snippet.description or "",
                "type": snippet.type,
                "icon_info": icon_info,
                "input_fields": snippet.input_fields_list,
            },
        }

        self._append_workflow_export_data(
            export_data=export_data, snippet=snippet, workflow=workflow, include_secret=include_secret
        )

        return yaml.dump(export_data, allow_unicode=True)

    def _append_workflow_export_data(
        self, *, export_data: dict, snippet: CustomizedSnippet, workflow: Workflow, include_secret: bool
    ) -> None:
        """
        Append workflow export data

        Adds the ``workflow`` and ``dependencies`` entries to ``export_data`` in place.

        :param export_data: export document being built (mutated)
        :param snippet: snippet being exported
        :param workflow: draft workflow of the snippet
        :param include_secret: when False, credential ids are stripped from tool
            and agent nodes
        """
        workflow_dict = workflow.to_dict(include_secret=include_secret)

        # Filter workspace related data from nodes
        workflow_dict["environment_variables"] = []
        workflow_dict["conversation_variables"] = []

        for node in workflow_dict.get("graph", {}).get("nodes", []):
            node_data = node.get("data", {})
            if not node_data:
                continue
            data_type = node_data.get("type", "")
            if data_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
                dataset_ids = node_data.get("dataset_ids", [])
                node["data"]["dataset_ids"] = [
                    self._encrypt_dataset_id(dataset_id=dataset_id, tenant_id=snippet.tenant_id)
                    for dataset_id in dataset_ids
                ]
            # filter credential id from tool node
            if not include_secret and data_type == BuiltinNodeTypes.TOOL:
                node_data.pop("credential_id", None)
            # filter credential id from agent node
            if not include_secret and data_type == BuiltinNodeTypes.AGENT:
                for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []):
                    tool.pop("credential_id", None)

        export_data["workflow"] = workflow_dict
        dependencies = self._extract_dependencies_from_workflow(workflow)
        export_data["dependencies"] = [
            jsonable_encoder(d.model_dump())
            for d in DependenciesAnalysisService.generate_dependencies(
                tenant_id=snippet.tenant_id, dependencies=dependencies
            )
        ]

    def _encrypt_dataset_id(self, *, dataset_id: str, tenant_id: str) -> str:
        """
        Encrypt dataset ID for export

        Currently a pass-through: the id is returned unchanged.
        """
        # For now, just return the dataset_id as-is
        # In the future, we might want to encrypt it
        return dataset_id

    def _extract_dependencies_from_workflow(self, workflow: Workflow) -> list[str]:
        """
        Extract dependencies from workflow

        :param workflow: Workflow instance
        :return: dependencies list format like ["langgenius/google"]
        """
        graph = workflow.graph_dict
        dependencies = self._extract_dependencies_from_workflow_graph(graph)
        return dependencies

    def _extract_dependencies_from_workflow_graph(self, graph: Mapping) -> list[str]:
        """
        Extract dependencies from workflow graph

        Only tool nodes and the tools of agent nodes contribute entries.

        :param graph: Workflow graph
        :return: dependencies list format like ["langgenius/google"]
        """
        dependencies = []
        for node in graph.get("nodes", []):
            node_data = node.get("data", {})
            if not node_data:
                continue

            data_type = node_data.get("type", "")
            if data_type == BuiltinNodeTypes.TOOL:
                tool_config = node_data.get("tool_configurations", {})
                provider_type = tool_config.get("provider_type")
                provider_name = tool_config.get("provider")
                if provider_type and provider_name:
                    # NOTE(review): the provider name is used for both halves of
                    # the identifier; confirm this matches the format expected by
                    # DependenciesAnalysisService.generate_dependencies.
                    dependencies.append(f"{provider_name}/{provider_name}")
            elif data_type == BuiltinNodeTypes.AGENT:
                agent_parameters = node_data.get("agent_parameters", {})
                tools = agent_parameters.get("tools", {}).get("value", [])
                for tool in tools:
                    provider_type = tool.get("provider_type")
                    provider_name = tool.get("provider")
                    if provider_type and provider_name:
                        # NOTE(review): same identifier format as for tool nodes above.
                        dependencies.append(f"{provider_name}/{provider_name}")

        return dependencies