diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index b4f2ef0ba8..f3534b7e9a 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -46,6 +46,8 @@ from models.workflow import Workflow from services.app_generate_service import AppGenerateService from services.errors.app import WorkflowHashNotEqualError from services.errors.llm import InvokeRateLimitError +from services.workflow.entities import MentionGraphRequest, MentionParameterSchema +from services.workflow.mention_graph_service import MentionGraphService from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService logger = logging.getLogger(__name__) @@ -188,6 +190,15 @@ class DraftWorkflowTriggerRunAllPayload(BaseModel): node_ids: list[str] +class MentionGraphPayload(BaseModel): + """Request payload for generating mention graph.""" + + parent_node_id: str = Field(description="ID of the parent node that uses the extracted value") + parameter_key: str = Field(description="Key of the parameter being extracted") + context_source: list[str] = Field(description="Variable selector for the context source") + parameter_schema: dict[str, Any] = Field(description="Schema of the parameter to extract") + + def reg(cls: type[BaseModel]): console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) @@ -205,6 +216,7 @@ reg(WorkflowListQuery) reg(WorkflowUpdatePayload) reg(DraftWorkflowTriggerRunPayload) reg(DraftWorkflowTriggerRunAllPayload) +reg(MentionGraphPayload) # TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing @@ -1166,3 +1178,54 @@ class DraftWorkflowTriggerRunAllApi(Resource): "status": "error", } ), 400 + + +@console_ns.route("/apps//workflows/draft/mention-graph") +class MentionGraphApi(Resource): + """ + API for generating Mention LLM node graph structures. + + This endpoint creates a complete graph structure containing an LLM node + configured to extract values from list[PromptMessage] variables. + """ + + @console_ns.doc("generate_mention_graph") + @console_ns.doc(description="Generate a Mention LLM node graph structure") + @console_ns.doc(params={"app_id": "Application ID"}) + @console_ns.expect(console_ns.models[MentionGraphPayload.__name__]) + @console_ns.response(200, "Mention graph generated successfully") + @console_ns.response(400, "Invalid request parameters") + @console_ns.response(403, "Permission denied") + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @edit_permission_required + def post(self, app_model: App): + """ + Generate a Mention LLM node graph structure. + + Returns a complete graph structure containing a single LLM node + configured for extracting values from list[PromptMessage] context. + """ + + payload = MentionGraphPayload.model_validate(console_ns.payload or {}) + + parameter_schema = MentionParameterSchema( + name=payload.parameter_schema.get("name", payload.parameter_key), + type=payload.parameter_schema.get("type", "string"), + description=payload.parameter_schema.get("description", ""), + ) + + request = MentionGraphRequest( + parent_node_id=payload.parent_node_id, + parameter_key=payload.parameter_key, + context_source=payload.context_source, + parameter_schema=parameter_schema, + ) + + with Session(db.engine) as session: + service = MentionGraphService(session) + response = service.generate_mention_graph(tenant_id=app_model.tenant_id, request=request) + + return response.model_dump() diff --git a/api/services/workflow/entities.py b/api/services/workflow/entities.py index 70ec8d6e2a..cf5519527d 100644 --- a/api/services/workflow/entities.py +++ b/api/services/workflow/entities.py @@ -163,3 +163,29 @@ class WorkflowScheduleCFSPlanEntity(BaseModel): schedule_strategy: Strategy granularity: int = Field(default=-1) # -1 means infinite + + +# ========== Mention Graph Entities ========== + + +class MentionParameterSchema(BaseModel): + """Schema for the parameter to be extracted from mention context.""" + + name: str = Field(description="Parameter name (e.g., 'query')") + type: str = Field(default="string", description="Parameter type (e.g., 'string', 'number')") + description: str = Field(default="", description="Parameter description for LLM") + + +class MentionGraphRequest(BaseModel): + """Request payload for generating mention graph.""" + + parent_node_id: str = Field(description="ID of the parent node that uses the extracted value") + parameter_key: str = Field(description="Key of the parameter being extracted") + context_source: list[str] = Field(description="Variable selector for the context source") + parameter_schema: MentionParameterSchema = Field(description="Schema of the parameter to extract") + + +class MentionGraphResponse(BaseModel): + """Response containing the generated mention graph.""" + + graph: Mapping[str, Any] = Field(description="Complete graph structure with nodes, edges, viewport") diff --git a/api/services/workflow/mention_graph_service.py b/api/services/workflow/mention_graph_service.py new file mode 100644 index 0000000000..05b5c303cc --- /dev/null +++ b/api/services/workflow/mention_graph_service.py @@ -0,0 +1,140 @@ +""" +Service for generating Mention LLM node graph structures. + +This service creates graph structures containing LLM nodes configured for +extracting values from list[PromptMessage] variables. +""" + +from typing import Any + +from sqlalchemy.orm import Session + +from core.model_runtime.entities import LLMMode +from core.workflow.enums import NodeType +from services.model_provider_service import ModelProviderService +from services.workflow.entities import MentionGraphRequest, MentionGraphResponse, MentionParameterSchema + + +class MentionGraphService: + """Service for generating Mention LLM node graph structures.""" + + def __init__(self, session: Session): + self._session = session + + def generate_mention_node_id(self, node_id: str, parameter_name: str) -> str: + """Generate mention node ID following the naming convention. + + Format: {node_id}_ext_{parameter_name} + """ + return f"{node_id}_ext_{parameter_name}" + + def generate_mention_graph(self, tenant_id: str, request: MentionGraphRequest) -> MentionGraphResponse: + """Generate a complete graph structure containing a Mention LLM node. + + Args: + tenant_id: The tenant ID for fetching default model config + request: The mention graph generation request + + Returns: + Complete graph structure with nodes, edges, and viewport + """ + node_id = self.generate_mention_node_id(request.parent_node_id, request.parameter_key) + model_config = self._get_default_model_config(tenant_id) + node = self._build_mention_llm_node( + node_id=node_id, + parent_node_id=request.parent_node_id, + context_source=request.context_source, + parameter_schema=request.parameter_schema, + model_config=model_config, + ) + + graph = { + "nodes": [node], + "edges": [], + "viewport": {}, + } + + return MentionGraphResponse(graph=graph) + + def _get_default_model_config(self, tenant_id: str) -> dict[str, Any]: + """Get the default LLM model configuration for the tenant.""" + model_provider_service = ModelProviderService() + default_model = model_provider_service.get_default_model_of_model_type( + tenant_id=tenant_id, + model_type="llm", + ) + + if default_model: + return { + "provider": default_model.provider.provider, + "name": default_model.model, + "mode": LLMMode.CHAT.value, + "completion_params": {}, + } + + # Fallback to empty config if no default model is configured + return { + "provider": "", + "name": "", + "mode": LLMMode.CHAT.value, + "completion_params": {}, + } + + def _build_mention_llm_node( + self, + *, + node_id: str, + parent_node_id: str, + context_source: list[str], + parameter_schema: MentionParameterSchema, + model_config: dict[str, Any], + ) -> dict[str, Any]: + """Build the Mention LLM node structure. + + The node uses: + - $context in prompt_template to reference the PromptMessage list + - structured_output for extracting the specific parameter + - parent_node_id to associate with the parent node + """ + prompt_template = [ + { + "role": "system", + "text": "Extract the required parameter value from the conversation context above.", + }, + {"$context": context_source}, + {"role": "user", "text": ""}, + ] + + structured_output = { + "type": "object", + "properties": { + parameter_schema.name: { + "type": parameter_schema.type, + "description": parameter_schema.description, + } + }, + "required": [parameter_schema.name], + } + + return { + "id": node_id, + "position": {"x": 0, "y": 0}, + "data": { + "type": NodeType.LLM.value, + "title": f"Mention: {parameter_schema.name}", + "desc": f"Extract {parameter_schema.name} from conversation context", + "parent_node_id": parent_node_id, + "model": model_config, + "prompt_template": prompt_template, + "context": { + "enabled": False, + "variable_selector": None, + }, + "vision": { + "enabled": False, + }, + "memory": None, + "structured_output_enabled": True, + "structured_output": structured_output, + }, + }