diff --git a/api/services/agent/workflow_publish_service.py b/api/services/agent/workflow_publish_service.py index 13927026c4..9dd7531167 100644 --- a/api/services/agent/workflow_publish_service.py +++ b/api/services/agent/workflow_publish_service.py @@ -3,12 +3,13 @@ from __future__ import annotations from collections.abc import Mapping from typing import Any +from pydantic import ValidationError from sqlalchemy import select from sqlalchemy.orm import Session from core.workflow.nodes.agent_v2.validators import WorkflowAgentNodeValidator from models.agent import Agent, AgentScope, AgentStatus, WorkflowAgentBindingType, WorkflowAgentNodeBinding -from models.agent_config_entities import WorkflowNodeJobConfig +from models.agent_config_entities import DeclaredOutputConfig, WorkflowNodeJobConfig from models.workflow import Workflow @@ -17,6 +18,8 @@ class WorkflowAgentPublishService: _DRAFT_WORKFLOW_VERSION = Workflow.VERSION_DRAFT _AGENT_BINDING_KEY = "agent_binding" + _AGENT_TASK_KEY = "agent_task" + _AGENT_DECLARED_OUTPUTS_KEY = "agent_declared_outputs" @classmethod def validate_agent_nodes_for_publish(cls, *, session: Session, draft_workflow: Workflow) -> None: @@ -61,6 +64,7 @@ class WorkflowAgentPublishService: session=session, draft_workflow=draft_workflow, node_id=node_id, + node_data=node_data, node_binding=binding_payload, existing_binding=existing_by_node_id.get(node_id), account_id=account_id, @@ -74,6 +78,7 @@ class WorkflowAgentPublishService: session: Session, draft_workflow: Workflow, node_id: str, + node_data: Mapping[str, Any], node_binding: Mapping[str, Any], existing_binding: WorkflowAgentNodeBinding | None, account_id: str, @@ -101,6 +106,10 @@ class WorkflowAgentPublishService: raise ValueError(f"Workflow Agent node {node_id} roster agent has no active config snapshot.") binding = existing_binding + node_job_config = cls._node_job_config_from_node_data( + existing_binding=existing_binding, + node_data=node_data, + ) if binding is None: binding = WorkflowAgentNodeBinding( tenant_id=draft_workflow.tenant_id, @@ -108,18 +117,47 @@ class WorkflowAgentPublishService: workflow_id=draft_workflow.id, workflow_version=cls._DRAFT_WORKFLOW_VERSION, node_id=node_id, - node_job_config=WorkflowNodeJobConfig(), + node_job_config=node_job_config, created_by=account_id, ) session.add(binding) - elif not binding.node_job_config: - binding.node_job_config = WorkflowNodeJobConfig() + else: + binding.node_job_config = node_job_config binding.binding_type = WorkflowAgentBindingType.ROSTER_AGENT binding.agent_id = agent.id binding.current_snapshot_id = agent.active_config_snapshot_id binding.updated_by = account_id + @classmethod + def _node_job_config_from_node_data( + cls, + *, + existing_binding: WorkflowAgentNodeBinding | None, + node_data: Mapping[str, Any], + ) -> WorkflowNodeJobConfig: + if existing_binding and existing_binding.node_job_config: + node_job = WorkflowNodeJobConfig.model_validate(existing_binding.node_job_config_dict) + else: + node_job = WorkflowNodeJobConfig() + + agent_task = node_data.get(cls._AGENT_TASK_KEY) + if isinstance(agent_task, str): + node_job.workflow_prompt = agent_task + + declared_outputs_payload = node_data.get(cls._AGENT_DECLARED_OUTPUTS_KEY) + if declared_outputs_payload is not None: + if not isinstance(declared_outputs_payload, list): + raise ValueError("Workflow Agent node agent_declared_outputs must be a list.") + try: + node_job.declared_outputs = [ + DeclaredOutputConfig.model_validate(output) for output in declared_outputs_payload + ] + except ValidationError as exc: + raise ValueError("Workflow Agent node has invalid agent_declared_outputs.") from exc + + return node_job + @classmethod def copy_agent_node_bindings_to_published( cls, diff --git a/api/tests/unit_tests/services/agent/test_agent_services.py b/api/tests/unit_tests/services/agent/test_agent_services.py index b64fc56170..5e85ea4970 100644 --- a/api/tests/unit_tests/services/agent/test_agent_services.py +++ b/api/tests/unit_tests/services/agent/test_agent_services.py @@ -1,3 +1,4 @@ +import json from datetime import UTC, datetime from types import SimpleNamespace @@ -14,7 +15,12 @@ from models.agent import ( WorkflowAgentBindingType, WorkflowAgentNodeBinding, ) -from models.agent_config_entities import AgentFileRefConfig, WorkflowNodeJobConfig +from models.agent_config_entities import ( + AgentFileRefConfig, + DeclaredOutputConfig, + DeclaredOutputType, + WorkflowNodeJobConfig, +) from models.workflow import Workflow from services.agent import composer_service, roster_service from services.agent.agent_soul_state import agent_soul_has_model @@ -1115,7 +1121,31 @@ class TestWorkflowAgentDraftBindingSync: tenant_id="tenant-1", app_id="app-1", version=Workflow.VERSION_DRAFT, - graph='{"nodes":[{"id":"agent-node","data":{"type":"agent","version":"2","agent_binding":{"binding_type":"roster_agent","agent_id":"agent-1"}}}]}', + graph=json.dumps( + { + "nodes": [ + { + "id": "agent-node", + "data": { + "type": "agent", + "version": "2", + "agent_task": "Summarize the upstream result.", + "agent_declared_outputs": [ + { + "name": "summary", + "type": "string", + "description": "Short summary", + } + ], + "agent_binding": { + "binding_type": "roster_agent", + "agent_id": "agent-1", + }, + }, + } + ] + } + ), ) agent = Agent( id="agent-1", @@ -1139,7 +1169,151 @@ class TestWorkflowAgentDraftBindingSync: assert binding.binding_type == WorkflowAgentBindingType.ROSTER_AGENT assert binding.agent_id == "agent-1" assert binding.current_snapshot_id == "snapshot-2" - assert binding.node_job_config_dict == WorkflowNodeJobConfig().model_dump(mode="json") + assert binding.node_job_config_dict == WorkflowNodeJobConfig( + workflow_prompt="Summarize the upstream result.", + declared_outputs=[ + DeclaredOutputConfig( + name="summary", + type=DeclaredOutputType.STRING, + description="Short summary", + ) + ], + ).model_dump(mode="json") + + def test_updates_existing_roster_binding_prompt_from_agent_node_graph(self): + workflow = Workflow( + id="workflow-1", + tenant_id="tenant-1", + app_id="app-1", + version=Workflow.VERSION_DRAFT, + graph=json.dumps( + { + "nodes": [ + { + "id": "agent-node", + "data": { + "type": "agent", + "version": "2", + "agent_task": "Use the latest tender context.", + "agent_binding": { + "binding_type": "roster_agent", + "agent_id": "agent-1", + }, + }, + } + ] + } + ), + ) + agent = Agent( + id="agent-1", + tenant_id="tenant-1", + name="Agent", + agent_kind=AgentKind.DIFY_AGENT, + scope=AgentScope.ROSTER, + source=AgentSource.AGENT_APP, + status=AgentStatus.ACTIVE, + active_config_snapshot_id="snapshot-2", + ) + existing_binding = WorkflowAgentNodeBinding( + id="binding-1", + tenant_id="tenant-1", + app_id="app-1", + workflow_id="workflow-1", + workflow_version=Workflow.VERSION_DRAFT, + node_id="agent-node", + binding_type=WorkflowAgentBindingType.ROSTER_AGENT, + agent_id="agent-1", + current_snapshot_id="snapshot-1", + node_job_config=WorkflowNodeJobConfig( + workflow_prompt="Old prompt", + declared_outputs=[ + DeclaredOutputConfig(name="summary", type=DeclaredOutputType.STRING, description="Short summary") + ], + ), + ) + session = FakeSession(scalar=[agent], scalars=[[existing_binding]]) + + WorkflowAgentPublishService.sync_roster_agent_bindings_for_draft( + session=session, + draft_workflow=workflow, + account_id="account-1", + ) + + node_job = WorkflowNodeJobConfig.model_validate(existing_binding.node_job_config_dict) + assert node_job.workflow_prompt == "Use the latest tender context." + assert [output.name for output in node_job.declared_outputs] == ["summary"] + assert existing_binding.current_snapshot_id == "snapshot-2" + + def test_updates_existing_roster_binding_declared_outputs_from_agent_node_graph(self): + workflow = Workflow( + id="workflow-1", + tenant_id="tenant-1", + app_id="app-1", + version=Workflow.VERSION_DRAFT, + graph=json.dumps( + { + "nodes": [ + { + "id": "agent-node", + "data": { + "type": "agent", + "version": "2", + "agent_task": "Keep the prompt.", + "agent_declared_outputs": [], + "agent_binding": { + "binding_type": "roster_agent", + "agent_id": "agent-1", + }, + }, + } + ] + } + ), + ) + agent = Agent( + id="agent-1", + tenant_id="tenant-1", + name="Agent", + agent_kind=AgentKind.DIFY_AGENT, + scope=AgentScope.ROSTER, + source=AgentSource.AGENT_APP, + status=AgentStatus.ACTIVE, + active_config_snapshot_id="snapshot-2", + ) + existing_binding = WorkflowAgentNodeBinding( + id="binding-1", + tenant_id="tenant-1", + app_id="app-1", + workflow_id="workflow-1", + workflow_version=Workflow.VERSION_DRAFT, + node_id="agent-node", + binding_type=WorkflowAgentBindingType.ROSTER_AGENT, + agent_id="agent-1", + current_snapshot_id="snapshot-1", + node_job_config=WorkflowNodeJobConfig( + workflow_prompt="Old prompt", + declared_outputs=[ + DeclaredOutputConfig( + name="summary", + type=DeclaredOutputType.STRING, + description="Short summary", + ) + ], + ), + ) + session = FakeSession(scalar=[agent], scalars=[[existing_binding]]) + + WorkflowAgentPublishService.sync_roster_agent_bindings_for_draft( + session=session, + draft_workflow=workflow, + account_id="account-1", + ) + + node_job = WorkflowNodeJobConfig.model_validate(existing_binding.node_job_config_dict) + assert node_job.workflow_prompt == "Keep the prompt." + assert node_job.declared_outputs == [] + assert existing_binding.current_snapshot_id == "snapshot-2" def test_deletes_draft_binding_when_agent_node_removed(self): workflow = Workflow(