mirror of
https://github.com/langgenius/dify.git
synced 2026-06-13 04:01:12 +08:00
fix(agent-v2): filter workflow invite options (#37368)
Co-authored-by: Yansong Zhang <916125788@qq.com>
This commit is contained in:
parent
514fddb60c
commit
5e8c182970
@ -0,0 +1,40 @@
|
||||
"""add agent active config has model
|
||||
|
||||
Revision ID: 9f4b7c2d1a80
|
||||
Revises: 0b2f2c8a9d1e
|
||||
Create Date: 2026-06-12 16:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9f4b7c2d1a80"
|
||||
down_revision = "0b2f2c8a9d1e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
with op.batch_alter_table("agents", schema=None) as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column(
|
||||
"active_config_has_model",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("false"),
|
||||
nullable=False,
|
||||
)
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
"agent_tenant_invitable_idx",
|
||||
"agents",
|
||||
["tenant_id", "scope", "status", "active_config_has_model", "updated_at"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_index("agent_tenant_invitable_idx", table_name="agents")
|
||||
with op.batch_alter_table("agents", schema=None) as batch_op:
|
||||
batch_op.drop_column("active_config_has_model")
|
||||
@ -131,6 +131,14 @@ class Agent(DefaultFieldsMixin, Base):
|
||||
Index("agent_tenant_workflow_id_idx", "tenant_id", "workflow_id"),
|
||||
Index("agent_tenant_app_id_idx", "tenant_id", "app_id"),
|
||||
Index("agent_active_config_snapshot_id_idx", "active_config_snapshot_id"),
|
||||
Index(
|
||||
"agent_tenant_invitable_idx",
|
||||
"tenant_id",
|
||||
"scope",
|
||||
"status",
|
||||
"active_config_has_model",
|
||||
"updated_at",
|
||||
),
|
||||
)
|
||||
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
@ -153,6 +161,9 @@ class Agent(DefaultFieldsMixin, Base):
|
||||
workflow_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
workflow_node_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
active_config_snapshot_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
active_config_has_model: Mapped[bool] = mapped_column(
|
||||
sa.Boolean, nullable=False, default=False, server_default=sa.text("false")
|
||||
)
|
||||
status: Mapped[AgentStatus] = mapped_column(
|
||||
EnumText(AgentStatus, length=32), nullable=False, default=AgentStatus.ACTIVE
|
||||
)
|
||||
|
||||
6
api/services/agent/agent_soul_state.py
Normal file
6
api/services/agent/agent_soul_state.py
Normal file
@ -0,0 +1,6 @@
|
||||
from models.agent_config_entities import AgentSoulConfig
|
||||
|
||||
|
||||
def agent_soul_has_model(agent_soul: AgentSoulConfig) -> bool:
|
||||
"""Return whether the Agent Soul has the minimum model config required for runtime."""
|
||||
return agent_soul.model is not None
|
||||
@ -26,6 +26,7 @@ from models.agent_config_entities import (
|
||||
effective_declared_outputs as _effective_declared_outputs,
|
||||
)
|
||||
from models.workflow import Workflow
|
||||
from services.agent.agent_soul_state import agent_soul_has_model
|
||||
from services.agent.composer_validator import ComposerConfigValidator
|
||||
from services.agent.errors import AgentNameConflictError, AgentNotFoundError, AgentVersionNotFoundError
|
||||
from services.entities.agent_entities import (
|
||||
@ -229,6 +230,7 @@ class AgentComposerService:
|
||||
version_note=payload.version_note,
|
||||
)
|
||||
agent.active_config_snapshot_id = version.id
|
||||
agent.active_config_has_model = agent_soul_has_model(payload.agent_soul)
|
||||
else:
|
||||
current_snapshot = cls._require_version(
|
||||
tenant_id=tenant_id, agent_id=agent.id, version_id=agent.active_config_snapshot_id
|
||||
@ -241,6 +243,7 @@ class AgentComposerService:
|
||||
version_note=payload.version_note,
|
||||
)
|
||||
agent.active_config_snapshot_id = version.id
|
||||
agent.active_config_has_model = agent_soul_has_model(payload.agent_soul)
|
||||
agent.updated_by = account_id
|
||||
|
||||
db.session.commit()
|
||||
@ -605,6 +608,7 @@ class AgentComposerService:
|
||||
)
|
||||
agent = cls._require_agent(tenant_id=tenant_id, agent_id=binding.agent_id)
|
||||
agent.active_config_snapshot_id = version.id
|
||||
agent.active_config_has_model = agent_soul_has_model(payload.agent_soul)
|
||||
agent.updated_by = account_id
|
||||
binding.current_snapshot_id = version.id
|
||||
if payload.node_job is not None:
|
||||
@ -634,6 +638,7 @@ class AgentComposerService:
|
||||
)
|
||||
agent = cls._require_agent(tenant_id=tenant_id, agent_id=binding.agent_id)
|
||||
agent.active_config_snapshot_id = version.id
|
||||
agent.active_config_has_model = agent_soul_has_model(payload.agent_soul)
|
||||
agent.updated_by = account_id
|
||||
binding.current_snapshot_id = version.id
|
||||
binding.updated_by = account_id
|
||||
@ -753,6 +758,7 @@ class AgentComposerService:
|
||||
version_note=None,
|
||||
)
|
||||
agent.active_config_snapshot_id = version.id
|
||||
agent.active_config_has_model = agent_soul_has_model(agent_soul)
|
||||
return agent
|
||||
|
||||
@classmethod
|
||||
@ -792,6 +798,7 @@ class AgentComposerService:
|
||||
version_note=version_note,
|
||||
)
|
||||
agent.active_config_snapshot_id = version.id
|
||||
agent.active_config_has_model = agent_soul_has_model(agent_soul)
|
||||
return agent
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -21,6 +21,7 @@ from models.agent_config_entities import AgentSoulConfig
|
||||
from models.enums import AppStatus
|
||||
from models.model import App
|
||||
from models.workflow import Workflow
|
||||
from services.agent.agent_soul_state import agent_soul_has_model
|
||||
from services.agent.composer_validator import ComposerConfigValidator
|
||||
from services.agent.errors import (
|
||||
AgentArchivedError,
|
||||
@ -95,9 +96,8 @@ class AgentRosterService:
|
||||
"created_at": to_timestamp(version.created_at),
|
||||
}
|
||||
|
||||
def list_roster_agents(
|
||||
self, *, tenant_id: str, page: int = 1, limit: int = 20, keyword: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
@staticmethod
|
||||
def _build_roster_agents_stmt(*, tenant_id: str, keyword: str | None = None):
|
||||
stmt = select(Agent).where(
|
||||
Agent.tenant_id == tenant_id,
|
||||
Agent.scope == AgentScope.ROSTER,
|
||||
@ -108,7 +108,12 @@ class AgentRosterService:
|
||||
|
||||
escaped_keyword = escape_like_pattern(keyword)
|
||||
stmt = stmt.where(Agent.name.ilike(f"%{escaped_keyword}%", escape="\\"))
|
||||
stmt = stmt.order_by(Agent.updated_at.desc())
|
||||
return stmt.order_by(Agent.updated_at.desc())
|
||||
|
||||
def list_roster_agents(
|
||||
self, *, tenant_id: str, page: int = 1, limit: int = 20, keyword: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
stmt = self._build_roster_agents_stmt(tenant_id=tenant_id, keyword=keyword)
|
||||
|
||||
total = self._session.scalar(select(func.count()).select_from(stmt.subquery())) or 0
|
||||
agents = list(self._session.scalars(stmt.offset((page - 1) * limit).limit(limit)).all())
|
||||
@ -144,7 +149,26 @@ class AgentRosterService:
|
||||
def list_invite_options(
|
||||
self, *, tenant_id: str, page: int = 1, limit: int = 20, keyword: str | None = None, app_id: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
result = self.list_roster_agents(tenant_id=tenant_id, page=page, limit=limit, keyword=keyword)
|
||||
stmt = self._build_roster_agents_stmt(tenant_id=tenant_id, keyword=keyword).where(
|
||||
Agent.active_config_has_model.is_(True)
|
||||
)
|
||||
total = self._session.scalar(select(func.count()).select_from(stmt.subquery())) or 0
|
||||
agents = list(self._session.scalars(stmt.offset((page - 1) * limit).limit(limit)).all())
|
||||
versions_by_id = self._load_versions_by_id(
|
||||
[agent.active_config_snapshot_id for agent in agents if agent.active_config_snapshot_id]
|
||||
)
|
||||
published_references_by_agent_id = self._load_published_references_by_agent_id(
|
||||
tenant_id=tenant_id,
|
||||
agent_ids=[agent.id for agent in agents],
|
||||
)
|
||||
data = [
|
||||
self.serialize_agent(
|
||||
agent,
|
||||
versions_by_id.get(agent.active_config_snapshot_id) if agent.active_config_snapshot_id else None,
|
||||
published_references_by_agent_id.get(agent.id, []),
|
||||
)
|
||||
for agent in agents
|
||||
]
|
||||
usage_by_agent_id: dict[str, list[str]] = {}
|
||||
if app_id:
|
||||
draft_workflow = self._session.scalar(
|
||||
@ -157,7 +181,7 @@ class AgentRosterService:
|
||||
.limit(1)
|
||||
)
|
||||
if draft_workflow:
|
||||
agent_ids = [item["id"] for item in result["data"]]
|
||||
agent_ids = [item["id"] for item in data]
|
||||
if agent_ids:
|
||||
bindings = self._session.scalars(
|
||||
select(WorkflowAgentNodeBinding).where(
|
||||
@ -170,12 +194,18 @@ class AgentRosterService:
|
||||
if binding.agent_id:
|
||||
usage_by_agent_id.setdefault(binding.agent_id, []).append(binding.node_id)
|
||||
|
||||
for item in result["data"]:
|
||||
for item in data:
|
||||
existing_node_ids = usage_by_agent_id.get(item["id"], [])
|
||||
item["is_in_current_workflow"] = bool(existing_node_ids)
|
||||
item["in_current_workflow_count"] = len(existing_node_ids)
|
||||
item["existing_node_ids"] = existing_node_ids
|
||||
return result
|
||||
return {
|
||||
"data": data,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"total": total,
|
||||
"has_more": page * limit < total,
|
||||
}
|
||||
|
||||
def create_roster_agent(
|
||||
self,
|
||||
@ -231,6 +261,7 @@ class AgentRosterService:
|
||||
)
|
||||
self._session.add(revision)
|
||||
agent.active_config_snapshot_id = version.id
|
||||
agent.active_config_has_model = agent_soul_has_model(payload.agent_soul)
|
||||
|
||||
try:
|
||||
self._session.commit()
|
||||
@ -302,6 +333,7 @@ class AgentRosterService:
|
||||
)
|
||||
self._session.add(revision)
|
||||
agent.active_config_snapshot_id = version.id
|
||||
agent.active_config_has_model = agent_soul_has_model(AgentSoulConfig())
|
||||
self._session.flush()
|
||||
return agent
|
||||
|
||||
|
||||
@ -17,6 +17,7 @@ from models.agent import (
|
||||
from models.agent_config_entities import WorkflowNodeJobConfig
|
||||
from models.workflow import Workflow
|
||||
from services.agent import composer_service, roster_service
|
||||
from services.agent.agent_soul_state import agent_soul_has_model
|
||||
from services.agent.composer_service import AgentComposerService
|
||||
from services.agent.composer_validator import ComposerConfigValidator
|
||||
from services.agent.errors import InvalidComposerConfigError
|
||||
@ -72,6 +73,23 @@ class FakeSession:
|
||||
self.rollbacks += 1
|
||||
|
||||
|
||||
def _agent_soul_with_model() -> AgentSoulConfig:
|
||||
return AgentSoulConfig.model_validate(
|
||||
{
|
||||
"model": {
|
||||
"plugin_id": "langgenius/openai/openai",
|
||||
"model_provider": "openai",
|
||||
"model": "gpt-4o",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_agent_soul_has_model():
|
||||
assert agent_soul_has_model(_agent_soul_with_model()) is True
|
||||
assert agent_soul_has_model(AgentSoulConfig()) is False
|
||||
|
||||
|
||||
def test_load_workflow_composer_returns_empty_state(monkeypatch):
|
||||
monkeypatch.setattr(AgentComposerService, "_get_draft_workflow", lambda **kwargs: SimpleNamespace(id="workflow-1"))
|
||||
monkeypatch.setattr(AgentComposerService, "_get_workflow_binding", lambda **kwargs: None)
|
||||
@ -217,13 +235,13 @@ def test_save_agent_app_composer_creates_agent_when_missing(monkeypatch):
|
||||
assert result == {"loaded": True}
|
||||
assert fake_session.added[0].name == "Analyst"
|
||||
assert fake_session.added[0].active_config_snapshot_id == "version-1"
|
||||
assert fake_session.added[0].active_config_has_model is False
|
||||
assert fake_session.commits == 1
|
||||
|
||||
|
||||
def test_save_agent_app_composer_updates_current_version(monkeypatch):
|
||||
fake_session = FakeSession(
|
||||
scalar=[SimpleNamespace(id="agent-1", active_config_snapshot_id="version-1", updated_by=None)]
|
||||
)
|
||||
agent = SimpleNamespace(id="agent-1", active_config_snapshot_id="version-1", updated_by=None)
|
||||
fake_session = FakeSession(scalar=[agent])
|
||||
updated = {}
|
||||
|
||||
monkeypatch.setattr(composer_service.db, "session", fake_session)
|
||||
@ -239,7 +257,7 @@ def test_save_agent_app_composer_updates_current_version(monkeypatch):
|
||||
{
|
||||
"variant": ComposerVariant.AGENT_APP.value,
|
||||
"save_strategy": ComposerSaveStrategy.SAVE_TO_CURRENT_VERSION.value,
|
||||
"agent_soul": {"prompt": {"system_prompt": "updated"}},
|
||||
"agent_soul": _agent_soul_with_model().model_dump(mode="json"),
|
||||
}
|
||||
)
|
||||
|
||||
@ -250,6 +268,7 @@ def test_save_agent_app_composer_updates_current_version(monkeypatch):
|
||||
assert result.pop("validation") == {"warnings": [], "knowledge_retrieval_placeholder": []}
|
||||
assert result == {"loaded": True}
|
||||
assert updated["operation"].value == "save_current_version"
|
||||
assert agent.active_config_has_model is True
|
||||
assert fake_session._scalar == []
|
||||
assert fake_session.commits == 1
|
||||
|
||||
@ -431,6 +450,38 @@ def test_composer_save_helpers_create_and_rebind_agents(monkeypatch):
|
||||
assert new_version_binding.current_snapshot_id == "new-version-1"
|
||||
|
||||
|
||||
def test_composer_create_agents_syncs_active_config_has_model(monkeypatch):
|
||||
fake_session = FakeSession()
|
||||
monkeypatch.setattr(composer_service.db, "session", fake_session)
|
||||
monkeypatch.setattr(
|
||||
AgentComposerService,
|
||||
"_create_config_version",
|
||||
lambda **kwargs: SimpleNamespace(id="version-with-model"),
|
||||
)
|
||||
|
||||
workflow_agent = AgentComposerService._create_workflow_only_agent(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
workflow_id="workflow-1",
|
||||
node_id="node-1",
|
||||
account_id="account-1",
|
||||
agent_soul=_agent_soul_with_model(),
|
||||
)
|
||||
roster_agent = AgentComposerService._create_roster_agent_for_composer(
|
||||
tenant_id="tenant-1",
|
||||
account_id="account-1",
|
||||
name="Ready Agent",
|
||||
agent_soul=_agent_soul_with_model(),
|
||||
operation=AgentConfigRevisionOperation.CREATE_VERSION,
|
||||
version_note=None,
|
||||
)
|
||||
|
||||
assert workflow_agent.active_config_snapshot_id == "version-with-model"
|
||||
assert workflow_agent.active_config_has_model is True
|
||||
assert roster_agent.active_config_snapshot_id == "version-with-model"
|
||||
assert roster_agent.active_config_has_model is True
|
||||
|
||||
|
||||
def test_composer_version_helpers_and_lookup_errors(monkeypatch):
|
||||
fake_session = FakeSession(
|
||||
scalar=[
|
||||
@ -554,20 +605,50 @@ def test_roster_list_and_invite_options(monkeypatch):
|
||||
)
|
||||
agent.created_at = created_at
|
||||
agent.updated_at = updated_at
|
||||
version = AgentConfigSnapshot(id="version-1", agent_id="agent-1", version=1)
|
||||
version = AgentConfigSnapshot(
|
||||
id="version-1", agent_id="agent-1", version=1, config_snapshot=_agent_soul_with_model()
|
||||
)
|
||||
version.created_at = version_created_at
|
||||
agent.active_config_snapshot_id = "version-1"
|
||||
agent.active_config_has_model = True
|
||||
unconfigured_agent = Agent(
|
||||
id="agent-2",
|
||||
tenant_id="tenant-1",
|
||||
name="Draft Agent",
|
||||
description="",
|
||||
role="draft",
|
||||
agent_kind=AgentKind.DIFY_AGENT,
|
||||
scope=AgentScope.ROSTER,
|
||||
source=AgentSource.AGENT_APP,
|
||||
status=AgentStatus.ACTIVE,
|
||||
)
|
||||
unconfigured_agent.active_config_snapshot_id = "version-2"
|
||||
unconfigured_agent.active_config_has_model = False
|
||||
unconfigured_version = AgentConfigSnapshot(
|
||||
id="version-2", agent_id="agent-2", version=1, config_snapshot=AgentSoulConfig()
|
||||
)
|
||||
fake_session = FakeSession(
|
||||
scalar=[1, 1, SimpleNamespace(id="workflow-1")],
|
||||
scalars=[[agent], [agent], [SimpleNamespace(agent_id="agent-1", node_id="node-1")]],
|
||||
scalar=[2, 1, SimpleNamespace(id="workflow-1")],
|
||||
scalars=[
|
||||
[agent, unconfigured_agent],
|
||||
[agent],
|
||||
[SimpleNamespace(agent_id="agent-1", node_id="node-1")],
|
||||
],
|
||||
)
|
||||
service = AgentRosterService(fake_session)
|
||||
monkeypatch.setattr(service, "_load_versions_by_id", lambda version_ids: {"version-1": version})
|
||||
monkeypatch.setattr(
|
||||
service,
|
||||
"_load_versions_by_id",
|
||||
lambda version_ids: {"version-1": version, "version-2": unconfigured_version},
|
||||
)
|
||||
monkeypatch.setattr(service, "_load_published_references_by_agent_id", lambda **kwargs: {})
|
||||
|
||||
listed = service.list_roster_agents(tenant_id="tenant-1", page=1, limit=20)
|
||||
invited = service.list_invite_options(tenant_id="tenant-1", page=1, limit=20, app_id="app-1")
|
||||
|
||||
assert [item["id"] for item in listed["data"]] == ["agent-1", "agent-2"]
|
||||
assert [item["id"] for item in invited["data"]] == ["agent-1"]
|
||||
assert invited["total"] == 1
|
||||
assert listed["data"][0]["active_config_snapshot"]["id"] == "version-1"
|
||||
assert listed["data"][0]["role"] == "researcher"
|
||||
assert listed["data"][0]["created_at"] == int(created_at.timestamp())
|
||||
@ -577,6 +658,39 @@ def test_roster_list_and_invite_options(monkeypatch):
|
||||
assert invited["data"][0]["existing_node_ids"] == ["node-1"]
|
||||
|
||||
|
||||
def test_invite_options_uses_db_filtered_pagination(monkeypatch):
|
||||
configured_agent = Agent(
|
||||
id="agent-2",
|
||||
tenant_id="tenant-1",
|
||||
name="Ready Agent",
|
||||
description="",
|
||||
agent_kind=AgentKind.DIFY_AGENT,
|
||||
scope=AgentScope.ROSTER,
|
||||
source=AgentSource.AGENT_APP,
|
||||
status=AgentStatus.ACTIVE,
|
||||
active_config_snapshot_id="version-2",
|
||||
active_config_has_model=True,
|
||||
)
|
||||
fake_session = FakeSession(scalar=[1], scalars=[[configured_agent]])
|
||||
service = AgentRosterService(fake_session)
|
||||
monkeypatch.setattr(
|
||||
service,
|
||||
"_load_versions_by_id",
|
||||
lambda version_ids: {
|
||||
"version-2": AgentConfigSnapshot(
|
||||
id="version-2", agent_id="agent-2", version=1, config_snapshot=_agent_soul_with_model()
|
||||
)
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(service, "_load_published_references_by_agent_id", lambda **kwargs: {})
|
||||
|
||||
result = service.list_invite_options(tenant_id="tenant-1", page=1, limit=1)
|
||||
|
||||
assert result["total"] == 1
|
||||
assert result["has_more"] is False
|
||||
assert [item["id"] for item in result["data"]] == ["agent-2"]
|
||||
|
||||
|
||||
def test_roster_update_archive_versions_and_detail(monkeypatch):
|
||||
listed_version = AgentConfigSnapshot(id="version-2", agent_id="agent-1", version=2)
|
||||
listed_version_created_at = datetime(2026, 1, 5, 3, 4, 5, tzinfo=UTC)
|
||||
@ -657,6 +771,12 @@ def test_roster_create_detail_and_lookup_helpers(monkeypatch):
|
||||
)
|
||||
|
||||
created = service.create_roster_agent(tenant_id="tenant-1", account_id="account-1", payload=payload)
|
||||
backing_agent = service.create_backing_agent_for_app(
|
||||
tenant_id="tenant-1",
|
||||
account_id="account-1",
|
||||
app_id="app-1",
|
||||
name="Backing Agent",
|
||||
)
|
||||
found_agent = service._get_agent(tenant_id="tenant-1", agent_id="agent-1")
|
||||
with pytest.raises(roster_service.AgentNotFoundError):
|
||||
service._get_agent(tenant_id="tenant-1", agent_id="missing")
|
||||
@ -668,6 +788,9 @@ def test_roster_create_detail_and_lookup_helpers(monkeypatch):
|
||||
|
||||
assert created.name == "Analyst"
|
||||
assert created.active_config_snapshot_id is not None
|
||||
assert created.active_config_has_model is False
|
||||
assert backing_agent.active_config_snapshot_id is not None
|
||||
assert backing_agent.active_config_has_model is False
|
||||
assert found_agent.id == "agent-1"
|
||||
assert found_version.id == "version-1"
|
||||
assert loaded_versions["version-1"].agent_id == "agent-1"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user