refactor: use EnumText for Conversation/Message invoke_from and from_source (#33901)

This commit is contained in:
tmimmanuel 2026-03-23 08:03:35 +01:00 committed by GitHub
parent 6ecf89e262
commit 2b6f761dfe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 37 additions and 27 deletions

View File

@ -50,7 +50,7 @@ class BuiltinTool(Tool):
return ModelInvocationUtils.invoke( return ModelInvocationUtils.invoke(
user_id=user_id, user_id=user_id,
tenant_id=self.runtime.tenant_id or "", tenant_id=self.runtime.tenant_id or "",
tool_type="builtin", tool_type=ToolProviderType.BUILT_IN,
tool_name=self.entity.identity.name, tool_name=self.entity.identity.name,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
) )

View File

@ -38,7 +38,7 @@ class ToolLabelManager:
db.session.add( db.session.add(
ToolLabelBinding( ToolLabelBinding(
tool_id=provider_id, tool_id=provider_id,
tool_type=controller.provider_type.value, tool_type=controller.provider_type,
label_name=label, label_name=label,
) )
) )
@ -58,7 +58,7 @@ class ToolLabelManager:
raise ValueError("Unsupported tool type") raise ValueError("Unsupported tool type")
stmt = select(ToolLabelBinding.label_name).where( stmt = select(ToolLabelBinding.label_name).where(
ToolLabelBinding.tool_id == provider_id, ToolLabelBinding.tool_id == provider_id,
ToolLabelBinding.tool_type == controller.provider_type.value, ToolLabelBinding.tool_type == controller.provider_type,
) )
labels = db.session.scalars(stmt).all() labels = db.session.scalars(stmt).all()

View File

@ -9,6 +9,7 @@ from decimal import Decimal
from typing import cast from typing import cast
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.tools.entities.tool_entities import ToolProviderType
from dify_graph.model_runtime.entities.llm_entities import LLMResult from dify_graph.model_runtime.entities.llm_entities import LLMResult
from dify_graph.model_runtime.entities.message_entities import PromptMessage from dify_graph.model_runtime.entities.message_entities import PromptMessage
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
@ -78,7 +79,7 @@ class ModelInvocationUtils:
@staticmethod @staticmethod
def invoke( def invoke(
user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage] user_id: str, tenant_id: str, tool_type: ToolProviderType, tool_name: str, prompt_messages: list[PromptMessage]
) -> LLMResult: ) -> LLMResult:
""" """
invoke model with parameters in user's own context invoke model with parameters in user's own context

View File

@ -43,6 +43,7 @@ from .enums import (
MessageChainType, MessageChainType,
MessageFileBelongsTo, MessageFileBelongsTo,
MessageStatus, MessageStatus,
TagType,
) )
from .provider_ids import GenericProviderID from .provider_ids import GenericProviderID
from .types import EnumText, LongText, StringUUID from .types import EnumText, LongText, StringUUID
@ -2404,7 +2405,7 @@ class Tag(TypeBase):
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
) )
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
type: Mapped[str] = mapped_column(String(16), nullable=False) type: Mapped[TagType] = mapped_column(EnumText(TagType, length=16), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(

View File

@ -13,12 +13,16 @@ from sqlalchemy.orm import Mapped, mapped_column
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration from core.tools.entities.tool_entities import (
ApiProviderSchemaType,
ToolProviderType,
WorkflowToolParameterConfiguration,
)
from .base import TypeBase from .base import TypeBase
from .engine import db from .engine import db
from .model import Account, App, Tenant from .model import Account, App, Tenant
from .types import LongText, StringUUID from .types import EnumText, LongText, StringUUID
if TYPE_CHECKING: if TYPE_CHECKING:
from core.entities.mcp_provider import MCPProviderEntity from core.entities.mcp_provider import MCPProviderEntity
@ -208,7 +212,7 @@ class ToolLabelBinding(TypeBase):
# tool id # tool id
tool_id: Mapped[str] = mapped_column(String(64), nullable=False) tool_id: Mapped[str] = mapped_column(String(64), nullable=False)
# tool type # tool type
tool_type: Mapped[str] = mapped_column(String(40), nullable=False) tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False)
# label name # label name
label_name: Mapped[str] = mapped_column(String(40), nullable=False) label_name: Mapped[str] = mapped_column(String(40), nullable=False)
@ -386,7 +390,7 @@ class ToolModelInvoke(TypeBase):
# provider # provider
provider: Mapped[str] = mapped_column(String(255), nullable=False) provider: Mapped[str] = mapped_column(String(255), nullable=False)
# type # type
tool_type: Mapped[str] = mapped_column(String(40), nullable=False) tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False)
# tool name # tool name
tool_name: Mapped[str] = mapped_column(String(128), nullable=False) tool_name: Mapped[str] = mapped_column(String(128), nullable=False)
# invoke parameters # invoke parameters

View File

@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset from models.dataset import Dataset
from models.enums import TagType
from models.model import App, Tag, TagBinding from models.model import App, Tag, TagBinding
@ -83,7 +84,7 @@ class TagService:
raise ValueError("Tag name already exists") raise ValueError("Tag name already exists")
tag = Tag( tag = Tag(
name=args["name"], name=args["name"],
type=args["type"], type=TagType(args["type"]),
created_by=current_user.id, created_by=current_user.id,
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
) )

View File

@ -9,7 +9,7 @@ from werkzeug.exceptions import NotFound
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset from models.dataset import Dataset
from models.enums import DataSourceType from models.enums import DataSourceType, TagType
from models.model import App, Tag, TagBinding from models.model import App, Tag, TagBinding
from services.tag_service import TagService from services.tag_service import TagService
@ -547,7 +547,7 @@ class TestTagService:
assert result is not None assert result is not None
assert len(result) == 1 assert len(result) == 1
assert result[0].name == "python_tag" assert result[0].name == "python_tag"
assert result[0].type == "app" assert result[0].type == TagType.APP
assert result[0].tenant_id == tenant.id assert result[0].tenant_id == tenant.id
def test_get_tag_by_tag_name_no_matches( def test_get_tag_by_tag_name_no_matches(
@ -638,7 +638,7 @@ class TestTagService:
# Verify all tags are returned # Verify all tags are returned
for tag in result: for tag in result:
assert tag.type == "app" assert tag.type == TagType.APP
assert tag.tenant_id == tenant.id assert tag.tenant_id == tenant.id
assert tag.id in [t.id for t in tags] assert tag.id in [t.id for t in tags]

View File

@ -11,6 +11,7 @@ from controllers.console.tag.tags import (
TagListApi, TagListApi,
TagUpdateDeleteApi, TagUpdateDeleteApi,
) )
from models.enums import TagType
def unwrap(func): def unwrap(func):
@ -52,7 +53,7 @@ def tag():
tag = MagicMock() tag = MagicMock()
tag.id = "tag-1" tag.id = "tag-1"
tag.name = "test-tag" tag.name = "test-tag"
tag.type = "knowledge" tag.type = TagType.KNOWLEDGE
return tag return tag

View File

@ -35,6 +35,7 @@ from controllers.service_api.dataset.dataset import (
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
from models.account import Account from models.account import Account
from models.dataset import DatasetPermissionEnum from models.dataset import DatasetPermissionEnum
from models.enums import TagType
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.tag_service import TagService from services.tag_service import TagService
@ -277,7 +278,7 @@ class TestDatasetTagsApi:
mock_tag = Mock() mock_tag = Mock()
mock_tag.id = "tag_1" mock_tag.id = "tag_1"
mock_tag.name = "Test Tag" mock_tag.name = "Test Tag"
mock_tag.type = "knowledge" mock_tag.type = TagType.KNOWLEDGE
mock_tag.binding_count = "0" # Required for Pydantic validation - must be string mock_tag.binding_count = "0" # Required for Pydantic validation - must be string
mock_tag_service.get_tags.return_value = [mock_tag] mock_tag_service.get_tags.return_value = [mock_tag]
@ -316,7 +317,7 @@ class TestDatasetTagsApi:
mock_tag = Mock() mock_tag = Mock()
mock_tag.id = "new_tag_1" mock_tag.id = "new_tag_1"
mock_tag.name = "New Tag" mock_tag.name = "New Tag"
mock_tag.type = "knowledge" mock_tag.type = TagType.KNOWLEDGE
mock_tag_service.save_tags.return_value = mock_tag mock_tag_service.save_tags.return_value = mock_tag
mock_service_api_ns.payload = {"name": "New Tag"} mock_service_api_ns.payload = {"name": "New Tag"}
@ -378,7 +379,7 @@ class TestDatasetTagsApi:
mock_tag = Mock() mock_tag = Mock()
mock_tag.id = "tag_1" mock_tag.id = "tag_1"
mock_tag.name = "Updated Tag" mock_tag.name = "Updated Tag"
mock_tag.type = "knowledge" mock_tag.type = TagType.KNOWLEDGE
mock_tag.binding_count = "5" mock_tag.binding_count = "5"
mock_tag_service.update_tags.return_value = mock_tag mock_tag_service.update_tags.return_value = mock_tag
mock_tag_service.get_tag_binding_count.return_value = 5 mock_tag_service.get_tag_binding_count.return_value = 5
@ -866,7 +867,7 @@ class TestTagService:
mock_tag = Mock() mock_tag = Mock()
mock_tag.id = str(uuid.uuid4()) mock_tag.id = str(uuid.uuid4())
mock_tag.name = "New Tag" mock_tag.name = "New Tag"
mock_tag.type = "knowledge" mock_tag.type = TagType.KNOWLEDGE
mock_save.return_value = mock_tag mock_save.return_value = mock_tag
result = TagService.save_tags({"name": "New Tag", "type": "knowledge"}) result = TagService.save_tags({"name": "New Tag", "type": "knowledge"})

View File

@ -12,7 +12,7 @@ This test suite covers:
import json import json
from uuid import uuid4 from uuid import uuid4
from core.tools.entities.tool_entities import ApiProviderSchemaType from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolProviderType
from models.tools import ( from models.tools import (
ApiToolProvider, ApiToolProvider,
BuiltinToolProvider, BuiltinToolProvider,
@ -631,7 +631,7 @@ class TestToolLabelBinding:
"""Test creating a tool label binding.""" """Test creating a tool label binding."""
# Arrange # Arrange
tool_id = "google.search" tool_id = "google.search"
tool_type = "builtin" tool_type = ToolProviderType.BUILT_IN
label_name = "search" label_name = "search"
# Act # Act
@ -655,7 +655,7 @@ class TestToolLabelBinding:
# Act # Act
label_binding = ToolLabelBinding( label_binding = ToolLabelBinding(
tool_id=tool_id, tool_id=tool_id,
tool_type="builtin", tool_type=ToolProviderType.BUILT_IN,
label_name=label_name, label_name=label_name,
) )
@ -667,7 +667,7 @@ class TestToolLabelBinding:
"""Test multiple labels can be bound to the same tool.""" """Test multiple labels can be bound to the same tool."""
# Arrange # Arrange
tool_id = "google.search" tool_id = "google.search"
tool_type = "builtin" tool_type = ToolProviderType.BUILT_IN
# Act # Act
binding1 = ToolLabelBinding( binding1 = ToolLabelBinding(
@ -688,7 +688,7 @@ class TestToolLabelBinding:
def test_tool_label_binding_different_tool_types(self): def test_tool_label_binding_different_tool_types(self):
"""Test label bindings for different tool types.""" """Test label bindings for different tool types."""
# Arrange # Arrange
tool_types = ["builtin", "api", "workflow"] tool_types = [ToolProviderType.BUILT_IN, ToolProviderType.API, ToolProviderType.WORKFLOW]
# Act & Assert # Act & Assert
for tool_type in tool_types: for tool_type in tool_types:
@ -951,12 +951,12 @@ class TestToolProviderRelationships:
# Act # Act
binding1 = ToolLabelBinding( binding1 = ToolLabelBinding(
tool_id=tool_id, tool_id=tool_id,
tool_type="builtin", tool_type=ToolProviderType.BUILT_IN,
label_name="search", label_name="search",
) )
binding2 = ToolLabelBinding( binding2 = ToolLabelBinding(
tool_id=tool_id, tool_id=tool_id,
tool_type="builtin", tool_type=ToolProviderType.BUILT_IN,
label_name="web", label_name="web",
) )

View File

@ -75,6 +75,7 @@ import pytest
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from models.dataset import Dataset from models.dataset import Dataset
from models.enums import TagType
from models.model import App, Tag, TagBinding from models.model import App, Tag, TagBinding
from services.tag_service import TagService from services.tag_service import TagService
@ -102,7 +103,7 @@ class TagServiceTestDataFactory:
def create_tag_mock( def create_tag_mock(
tag_id: str = "tag-123", tag_id: str = "tag-123",
name: str = "Test Tag", name: str = "Test Tag",
tag_type: str = "app", tag_type: TagType = TagType.APP,
tenant_id: str = "tenant-123", tenant_id: str = "tenant-123",
**kwargs, **kwargs,
) -> Mock: ) -> Mock:
@ -705,7 +706,7 @@ class TestTagServiceCRUD:
# Verify tag attributes # Verify tag attributes
added_tag = mock_db_session.add.call_args[0][0] added_tag = mock_db_session.add.call_args[0][0]
assert added_tag.name == "New Tag", "Tag name should match" assert added_tag.name == "New Tag", "Tag name should match"
assert added_tag.type == "app", "Tag type should match" assert added_tag.type == TagType.APP, "Tag type should match"
assert added_tag.created_by == "user-123", "Created by should match current user" assert added_tag.created_by == "user-123", "Created by should match current user"
assert added_tag.tenant_id == "tenant-123", "Tenant ID should match current tenant" assert added_tag.tenant_id == "tenant-123", "Tenant ID should match current tenant"