diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 6d290a46ee..34a10ddd04 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -63,6 +63,7 @@ from .app import ( model_config, ops_trace, site, + skills, statistic, workflow, workflow_app_log, @@ -206,6 +207,7 @@ __all__ = [ "saved_message", "setup", "site", + "skills", "spec", "statistic", "tags", diff --git a/api/controllers/console/app/skills.py b/api/controllers/console/app/skills.py new file mode 100644 index 0000000000..0da818a8a0 --- /dev/null +++ b/api/controllers/console/app/skills.py @@ -0,0 +1,72 @@ +from flask_restx import Resource + +from controllers.console import console_ns +from controllers.console.app.error import DraftWorkflowNotExist +from controllers.console.app.wraps import get_app_model +from controllers.console.wraps import account_initialization_required, setup_required +from libs.login import login_required +from models import App +from models.model import AppMode +from services.skill_service import SkillService +from services.workflow_service import WorkflowService + + +@console_ns.route("/apps//workflows/draft/nodes//skills") +class NodeSkillsApi(Resource): + """API for retrieving skill references for a specific workflow node.""" + + @console_ns.doc("get_node_skills") + @console_ns.doc(description="Get skill references for a specific node in the draft workflow") + @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) + @console_ns.response(200, "Node skills retrieved successfully") + @console_ns.response(404, "Workflow or node not found") + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def get(self, app_model: App, node_id: str): + """ + Get skill information for a specific node in the draft workflow. + + Returns information about skill references in the node, including: + - skill_references: List of prompt messages marked as skills + - tool_references: Aggregated tool references from all skill prompts + - file_references: Aggregated file references from all skill prompts + """ + workflow_service = WorkflowService() + workflow = workflow_service.get_draft_workflow(app_model=app_model) + + if not workflow: + raise DraftWorkflowNotExist() + + skill_info = SkillService.get_node_skill_info(workflow=workflow, node_id=node_id) + return skill_info.model_dump() + + +@console_ns.route("/apps//workflows/draft/skills") +class WorkflowSkillsApi(Resource): + """API for retrieving all skill references in a workflow.""" + + @console_ns.doc("get_workflow_skills") + @console_ns.doc(description="Get all skill references in the draft workflow") + @console_ns.doc(params={"app_id": "Application ID"}) + @console_ns.response(200, "Workflow skills retrieved successfully") + @console_ns.response(404, "Workflow not found") + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def get(self, app_model: App): + """ + Get skill information for all nodes in the draft workflow that have skill references. + + Returns a list of nodes with their skill information. + """ + workflow_service = WorkflowService() + workflow = workflow_service.get_draft_workflow(app_model=app_model) + + if not workflow: + raise DraftWorkflowNotExist() + + skills_info = SkillService.get_workflow_skills(workflow=workflow) + return {"nodes": [info.model_dump() for info in skills_info]} diff --git a/api/core/skill/entities/api_entities.py b/api/core/skill/entities/api_entities.py new file mode 100644 index 0000000000..6c0e37011d --- /dev/null +++ b/api/core/skill/entities/api_entities.py @@ -0,0 +1,12 @@ +from pydantic import BaseModel, Field + +from core.skill.entities.tool_dependencies import ToolDependency + + +class NodeSkillInfo(BaseModel): + """Information about skills referenced by a workflow node.""" + + node_id: str = Field(description="The node ID") + tool_dependencies: list[ToolDependency] = Field( + default_factory=list, description="Tool dependencies extracted from skill prompts" + ) diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index e6b25be026..f92f108e9e 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -356,6 +356,15 @@ PromptTemplateItem: TypeAlias = Annotated[ ] +class ToolSetting(BaseModel): + model_config = ConfigDict(extra="forbid") + + type: ToolProviderType + provider: str + tool_name: str + enabled: bool = Field(default=True, description="Whether the tool is enabled") + + class LLMNodeData(BaseNodeData): model: ModelConfig prompt_template: Sequence[PromptTemplateItem] | LLMNodeCompletionModelPromptTemplate @@ -386,6 +395,7 @@ class LLMNodeData(BaseNodeData): # Tool support tools: Sequence[ToolMetadata] = Field(default_factory=list) + tool_settings: Sequence[ToolSetting] = Field(default_factory=list) max_iterations: int | None = Field(default=100, description="Maximum number of iterations for the LLM node") @field_validator("prompt_config", mode="before") diff --git a/api/services/skill_service.py b/api/services/skill_service.py new file mode 100644 index 0000000000..bf4a45c061 --- /dev/null +++ b/api/services/skill_service.py @@ -0,0 +1,108 @@ +import logging + +from core.skill.entities.api_entities import NodeSkillInfo +from core.skill.entities.skill_metadata import ToolReference +from core.skill.entities.tool_dependencies import ToolDependency +from core.workflow.enums import NodeType +from models.workflow import Workflow + +logger = logging.getLogger(__name__) + + +class SkillService: + """ + Service for managing and retrieving skill information from workflows. + """ + + @staticmethod + def get_node_skill_info(workflow: Workflow, node_id: str) -> NodeSkillInfo: + """ + Get skill information for a specific node in a workflow. + + Args: + workflow: The workflow containing the node + node_id: The ID of the node to get skill info for + + Returns: + NodeSkillInfo containing tool dependencies for the node + """ + node_config = workflow.get_node_config_by_id(node_id) + node_data = node_config.get("data", {}) + node_type = node_data.get("type", "") + + # Only LLM nodes support skills currently + if node_type != NodeType.LLM.value: + return NodeSkillInfo(node_id=node_id) + + tool_dependencies = SkillService._extract_tool_dependencies(node_data) + + return NodeSkillInfo( + node_id=node_id, + tool_dependencies=tool_dependencies, + ) + + @staticmethod + def get_workflow_skills(workflow: Workflow) -> list[NodeSkillInfo]: + """ + Get skill information for all nodes in a workflow that have skill references. + + Args: + workflow: The workflow to scan for skills + + Returns: + List of NodeSkillInfo for nodes that have skill references + """ + result: list[NodeSkillInfo] = [] + + # Only scan LLM nodes since they're the only ones that support skills + for node_id, node_data in workflow.walk_nodes(specific_node_type=NodeType.LLM): + has_skill = SkillService._has_skill(node_data) + + if has_skill: + tool_dependencies = SkillService._extract_tool_dependencies(node_data) + result.append( + NodeSkillInfo( + node_id=node_id, + tool_dependencies=tool_dependencies, + ) + ) + + return result + + @staticmethod + def _has_skill(node_data: dict) -> bool: + """Check if node has any skill prompts.""" + prompt_template = node_data.get("prompt_template", []) + if isinstance(prompt_template, list): + for prompt in prompt_template: + if isinstance(prompt, dict) and prompt.get("skill", False): + return True + return False + + @staticmethod + def _extract_tool_dependencies(node_data: dict) -> list[ToolDependency]: + """Extract deduplicated tool dependencies from node data.""" + dependencies: dict[str, ToolDependency] = {} + + prompt_template = node_data.get("prompt_template", []) + if isinstance(prompt_template, list): + for prompt in prompt_template: + if isinstance(prompt, dict) and prompt.get("skill", False): + metadata_dict = prompt.get("metadata") or {} + tools_dict = metadata_dict.get("tools", {}) + + for uuid, tool_data in tools_dict.items(): + if isinstance(tool_data, dict): + try: + ref = ToolReference.model_validate({"uuid": uuid, **tool_data}) + key = f"{ref.provider}.{ref.tool_name}" + if key not in dependencies: + dependencies[key] = ToolDependency( + type=ref.type, + provider=ref.provider, + tool_name=ref.tool_name, + ) + except Exception: + logger.debug("Skipping invalid tool reference: uuid=%s", uuid) + + return list(dependencies.values())