mirror of
https://github.com/langgenius/dify.git
synced 2026-05-10 05:56:31 +08:00
feat(skills): implement API endpoints for retrieving skill references in workflows and add related data models
This commit is contained in:
parent
a4a85f7168
commit
951af125af
@ -63,6 +63,7 @@ from .app import (
|
||||
model_config,
|
||||
ops_trace,
|
||||
site,
|
||||
skills,
|
||||
statistic,
|
||||
workflow,
|
||||
workflow_app_log,
|
||||
@ -206,6 +207,7 @@ __all__ = [
|
||||
"saved_message",
|
||||
"setup",
|
||||
"site",
|
||||
"skills",
|
||||
"spec",
|
||||
"statistic",
|
||||
"tags",
|
||||
|
||||
72
api/controllers/console/app/skills.py
Normal file
72
api/controllers/console/app/skills.py
Normal file
@ -0,0 +1,72 @@
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import DraftWorkflowNotExist
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from services.skill_service import SkillService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/skills")
|
||||
class NodeSkillsApi(Resource):
|
||||
"""API for retrieving skill references for a specific workflow node."""
|
||||
|
||||
@console_ns.doc("get_node_skills")
|
||||
@console_ns.doc(description="Get skill references for a specific node in the draft workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.response(200, "Node skills retrieved successfully")
|
||||
@console_ns.response(404, "Workflow or node not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Get skill information for a specific node in the draft workflow.
|
||||
|
||||
Returns information about skill references in the node, including:
|
||||
- skill_references: List of prompt messages marked as skills
|
||||
- tool_references: Aggregated tool references from all skill prompts
|
||||
- file_references: Aggregated file references from all skill prompts
|
||||
"""
|
||||
workflow_service = WorkflowService()
|
||||
workflow = workflow_service.get_draft_workflow(app_model=app_model)
|
||||
|
||||
if not workflow:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
skill_info = SkillService.get_node_skill_info(workflow=workflow, node_id=node_id)
|
||||
return skill_info.model_dump()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/skills")
|
||||
class WorkflowSkillsApi(Resource):
|
||||
"""API for retrieving all skill references in a workflow."""
|
||||
|
||||
@console_ns.doc("get_workflow_skills")
|
||||
@console_ns.doc(description="Get all skill references in the draft workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Workflow skills retrieved successfully")
|
||||
@console_ns.response(404, "Workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get skill information for all nodes in the draft workflow that have skill references.
|
||||
|
||||
Returns a list of nodes with their skill information.
|
||||
"""
|
||||
workflow_service = WorkflowService()
|
||||
workflow = workflow_service.get_draft_workflow(app_model=app_model)
|
||||
|
||||
if not workflow:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
skills_info = SkillService.get_workflow_skills(workflow=workflow)
|
||||
return {"nodes": [info.model_dump() for info in skills_info]}
|
||||
12
api/core/skill/entities/api_entities.py
Normal file
12
api/core/skill/entities/api_entities.py
Normal file
@ -0,0 +1,12 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.skill.entities.tool_dependencies import ToolDependency
|
||||
|
||||
|
||||
class NodeSkillInfo(BaseModel):
|
||||
"""Information about skills referenced by a workflow node."""
|
||||
|
||||
node_id: str = Field(description="The node ID")
|
||||
tool_dependencies: list[ToolDependency] = Field(
|
||||
default_factory=list, description="Tool dependencies extracted from skill prompts"
|
||||
)
|
||||
@ -356,6 +356,15 @@ PromptTemplateItem: TypeAlias = Annotated[
|
||||
]
|
||||
|
||||
|
||||
class ToolSetting(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
type: ToolProviderType
|
||||
provider: str
|
||||
tool_name: str
|
||||
enabled: bool = Field(default=True, description="Whether the tool is enabled")
|
||||
|
||||
|
||||
class LLMNodeData(BaseNodeData):
|
||||
model: ModelConfig
|
||||
prompt_template: Sequence[PromptTemplateItem] | LLMNodeCompletionModelPromptTemplate
|
||||
@ -386,6 +395,7 @@ class LLMNodeData(BaseNodeData):
|
||||
|
||||
# Tool support
|
||||
tools: Sequence[ToolMetadata] = Field(default_factory=list)
|
||||
tool_settings: Sequence[ToolSetting] = Field(default_factory=list)
|
||||
max_iterations: int | None = Field(default=100, description="Maximum number of iterations for the LLM node")
|
||||
|
||||
@field_validator("prompt_config", mode="before")
|
||||
|
||||
108
api/services/skill_service.py
Normal file
108
api/services/skill_service.py
Normal file
@ -0,0 +1,108 @@
|
||||
import logging
|
||||
|
||||
from core.skill.entities.api_entities import NodeSkillInfo
|
||||
from core.skill.entities.skill_metadata import ToolReference
|
||||
from core.skill.entities.tool_dependencies import ToolDependency
|
||||
from core.workflow.enums import NodeType
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SkillService:
|
||||
"""
|
||||
Service for managing and retrieving skill information from workflows.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_node_skill_info(workflow: Workflow, node_id: str) -> NodeSkillInfo:
|
||||
"""
|
||||
Get skill information for a specific node in a workflow.
|
||||
|
||||
Args:
|
||||
workflow: The workflow containing the node
|
||||
node_id: The ID of the node to get skill info for
|
||||
|
||||
Returns:
|
||||
NodeSkillInfo containing tool dependencies for the node
|
||||
"""
|
||||
node_config = workflow.get_node_config_by_id(node_id)
|
||||
node_data = node_config.get("data", {})
|
||||
node_type = node_data.get("type", "")
|
||||
|
||||
# Only LLM nodes support skills currently
|
||||
if node_type != NodeType.LLM.value:
|
||||
return NodeSkillInfo(node_id=node_id)
|
||||
|
||||
tool_dependencies = SkillService._extract_tool_dependencies(node_data)
|
||||
|
||||
return NodeSkillInfo(
|
||||
node_id=node_id,
|
||||
tool_dependencies=tool_dependencies,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_workflow_skills(workflow: Workflow) -> list[NodeSkillInfo]:
|
||||
"""
|
||||
Get skill information for all nodes in a workflow that have skill references.
|
||||
|
||||
Args:
|
||||
workflow: The workflow to scan for skills
|
||||
|
||||
Returns:
|
||||
List of NodeSkillInfo for nodes that have skill references
|
||||
"""
|
||||
result: list[NodeSkillInfo] = []
|
||||
|
||||
# Only scan LLM nodes since they're the only ones that support skills
|
||||
for node_id, node_data in workflow.walk_nodes(specific_node_type=NodeType.LLM):
|
||||
has_skill = SkillService._has_skill(node_data)
|
||||
|
||||
if has_skill:
|
||||
tool_dependencies = SkillService._extract_tool_dependencies(node_data)
|
||||
result.append(
|
||||
NodeSkillInfo(
|
||||
node_id=node_id,
|
||||
tool_dependencies=tool_dependencies,
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _has_skill(node_data: dict) -> bool:
|
||||
"""Check if node has any skill prompts."""
|
||||
prompt_template = node_data.get("prompt_template", [])
|
||||
if isinstance(prompt_template, list):
|
||||
for prompt in prompt_template:
|
||||
if isinstance(prompt, dict) and prompt.get("skill", False):
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _extract_tool_dependencies(node_data: dict) -> list[ToolDependency]:
|
||||
"""Extract deduplicated tool dependencies from node data."""
|
||||
dependencies: dict[str, ToolDependency] = {}
|
||||
|
||||
prompt_template = node_data.get("prompt_template", [])
|
||||
if isinstance(prompt_template, list):
|
||||
for prompt in prompt_template:
|
||||
if isinstance(prompt, dict) and prompt.get("skill", False):
|
||||
metadata_dict = prompt.get("metadata") or {}
|
||||
tools_dict = metadata_dict.get("tools", {})
|
||||
|
||||
for uuid, tool_data in tools_dict.items():
|
||||
if isinstance(tool_data, dict):
|
||||
try:
|
||||
ref = ToolReference.model_validate({"uuid": uuid, **tool_data})
|
||||
key = f"{ref.provider}.{ref.tool_name}"
|
||||
if key not in dependencies:
|
||||
dependencies[key] = ToolDependency(
|
||||
type=ref.type,
|
||||
provider=ref.provider,
|
||||
tool_name=ref.tool_name,
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Skipping invalid tool reference: uuid=%s", uuid)
|
||||
|
||||
return list(dependencies.values())
|
||||
Loading…
Reference in New Issue
Block a user