mirror of
https://github.com/langgenius/dify.git
synced 2026-04-23 16:37:44 +08:00
refactor: use EnumText for Conversation/Message invoke_from and from_source (#33901)
This commit is contained in:
parent
6ecf89e262
commit
2b6f761dfe
@ -50,7 +50,7 @@ class BuiltinTool(Tool):
|
|||||||
return ModelInvocationUtils.invoke(
|
return ModelInvocationUtils.invoke(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
tenant_id=self.runtime.tenant_id or "",
|
tenant_id=self.runtime.tenant_id or "",
|
||||||
tool_type="builtin",
|
tool_type=ToolProviderType.BUILT_IN,
|
||||||
tool_name=self.entity.identity.name,
|
tool_name=self.entity.identity.name,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -38,7 +38,7 @@ class ToolLabelManager:
|
|||||||
db.session.add(
|
db.session.add(
|
||||||
ToolLabelBinding(
|
ToolLabelBinding(
|
||||||
tool_id=provider_id,
|
tool_id=provider_id,
|
||||||
tool_type=controller.provider_type.value,
|
tool_type=controller.provider_type,
|
||||||
label_name=label,
|
label_name=label,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -58,7 +58,7 @@ class ToolLabelManager:
|
|||||||
raise ValueError("Unsupported tool type")
|
raise ValueError("Unsupported tool type")
|
||||||
stmt = select(ToolLabelBinding.label_name).where(
|
stmt = select(ToolLabelBinding.label_name).where(
|
||||||
ToolLabelBinding.tool_id == provider_id,
|
ToolLabelBinding.tool_id == provider_id,
|
||||||
ToolLabelBinding.tool_type == controller.provider_type.value,
|
ToolLabelBinding.tool_type == controller.provider_type,
|
||||||
)
|
)
|
||||||
labels = db.session.scalars(stmt).all()
|
labels = db.session.scalars(stmt).all()
|
||||||
|
|
||||||
|
|||||||
@ -9,6 +9,7 @@ from decimal import Decimal
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from core.model_manager import ModelManager
|
from core.model_manager import ModelManager
|
||||||
|
from core.tools.entities.tool_entities import ToolProviderType
|
||||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult
|
from dify_graph.model_runtime.entities.llm_entities import LLMResult
|
||||||
from dify_graph.model_runtime.entities.message_entities import PromptMessage
|
from dify_graph.model_runtime.entities.message_entities import PromptMessage
|
||||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||||
@ -78,7 +79,7 @@ class ModelInvocationUtils:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def invoke(
|
def invoke(
|
||||||
user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage]
|
user_id: str, tenant_id: str, tool_type: ToolProviderType, tool_name: str, prompt_messages: list[PromptMessage]
|
||||||
) -> LLMResult:
|
) -> LLMResult:
|
||||||
"""
|
"""
|
||||||
invoke model with parameters in user's own context
|
invoke model with parameters in user's own context
|
||||||
|
|||||||
@ -43,6 +43,7 @@ from .enums import (
|
|||||||
MessageChainType,
|
MessageChainType,
|
||||||
MessageFileBelongsTo,
|
MessageFileBelongsTo,
|
||||||
MessageStatus,
|
MessageStatus,
|
||||||
|
TagType,
|
||||||
)
|
)
|
||||||
from .provider_ids import GenericProviderID
|
from .provider_ids import GenericProviderID
|
||||||
from .types import EnumText, LongText, StringUUID
|
from .types import EnumText, LongText, StringUUID
|
||||||
@ -2404,7 +2405,7 @@ class Tag(TypeBase):
|
|||||||
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
|
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
|
||||||
)
|
)
|
||||||
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||||
type: Mapped[str] = mapped_column(String(16), nullable=False)
|
type: Mapped[TagType] = mapped_column(EnumText(TagType, length=16), nullable=False)
|
||||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
|||||||
@ -13,12 +13,16 @@ from sqlalchemy.orm import Mapped, mapped_column
|
|||||||
|
|
||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject
|
||||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
|
from core.tools.entities.tool_entities import (
|
||||||
|
ApiProviderSchemaType,
|
||||||
|
ToolProviderType,
|
||||||
|
WorkflowToolParameterConfiguration,
|
||||||
|
)
|
||||||
|
|
||||||
from .base import TypeBase
|
from .base import TypeBase
|
||||||
from .engine import db
|
from .engine import db
|
||||||
from .model import Account, App, Tenant
|
from .model import Account, App, Tenant
|
||||||
from .types import LongText, StringUUID
|
from .types import EnumText, LongText, StringUUID
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from core.entities.mcp_provider import MCPProviderEntity
|
from core.entities.mcp_provider import MCPProviderEntity
|
||||||
@ -208,7 +212,7 @@ class ToolLabelBinding(TypeBase):
|
|||||||
# tool id
|
# tool id
|
||||||
tool_id: Mapped[str] = mapped_column(String(64), nullable=False)
|
tool_id: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
# tool type
|
# tool type
|
||||||
tool_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False)
|
||||||
# label name
|
# label name
|
||||||
label_name: Mapped[str] = mapped_column(String(40), nullable=False)
|
label_name: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||||
|
|
||||||
@ -386,7 +390,7 @@ class ToolModelInvoke(TypeBase):
|
|||||||
# provider
|
# provider
|
||||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
# type
|
# type
|
||||||
tool_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False)
|
||||||
# tool name
|
# tool name
|
||||||
tool_name: Mapped[str] = mapped_column(String(128), nullable=False)
|
tool_name: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||||
# invoke parameters
|
# invoke parameters
|
||||||
|
|||||||
@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound
|
|||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.dataset import Dataset
|
from models.dataset import Dataset
|
||||||
|
from models.enums import TagType
|
||||||
from models.model import App, Tag, TagBinding
|
from models.model import App, Tag, TagBinding
|
||||||
|
|
||||||
|
|
||||||
@ -83,7 +84,7 @@ class TagService:
|
|||||||
raise ValueError("Tag name already exists")
|
raise ValueError("Tag name already exists")
|
||||||
tag = Tag(
|
tag = Tag(
|
||||||
name=args["name"],
|
name=args["name"],
|
||||||
type=args["type"],
|
type=TagType(args["type"]),
|
||||||
created_by=current_user.id,
|
created_by=current_user.id,
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from werkzeug.exceptions import NotFound
|
|||||||
|
|
||||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||||
from models.dataset import Dataset
|
from models.dataset import Dataset
|
||||||
from models.enums import DataSourceType
|
from models.enums import DataSourceType, TagType
|
||||||
from models.model import App, Tag, TagBinding
|
from models.model import App, Tag, TagBinding
|
||||||
from services.tag_service import TagService
|
from services.tag_service import TagService
|
||||||
|
|
||||||
@ -547,7 +547,7 @@ class TestTagService:
|
|||||||
assert result is not None
|
assert result is not None
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
assert result[0].name == "python_tag"
|
assert result[0].name == "python_tag"
|
||||||
assert result[0].type == "app"
|
assert result[0].type == TagType.APP
|
||||||
assert result[0].tenant_id == tenant.id
|
assert result[0].tenant_id == tenant.id
|
||||||
|
|
||||||
def test_get_tag_by_tag_name_no_matches(
|
def test_get_tag_by_tag_name_no_matches(
|
||||||
@ -638,7 +638,7 @@ class TestTagService:
|
|||||||
|
|
||||||
# Verify all tags are returned
|
# Verify all tags are returned
|
||||||
for tag in result:
|
for tag in result:
|
||||||
assert tag.type == "app"
|
assert tag.type == TagType.APP
|
||||||
assert tag.tenant_id == tenant.id
|
assert tag.tenant_id == tenant.id
|
||||||
assert tag.id in [t.id for t in tags]
|
assert tag.id in [t.id for t in tags]
|
||||||
|
|
||||||
|
|||||||
@ -11,6 +11,7 @@ from controllers.console.tag.tags import (
|
|||||||
TagListApi,
|
TagListApi,
|
||||||
TagUpdateDeleteApi,
|
TagUpdateDeleteApi,
|
||||||
)
|
)
|
||||||
|
from models.enums import TagType
|
||||||
|
|
||||||
|
|
||||||
def unwrap(func):
|
def unwrap(func):
|
||||||
@ -52,7 +53,7 @@ def tag():
|
|||||||
tag = MagicMock()
|
tag = MagicMock()
|
||||||
tag.id = "tag-1"
|
tag.id = "tag-1"
|
||||||
tag.name = "test-tag"
|
tag.name = "test-tag"
|
||||||
tag.type = "knowledge"
|
tag.type = TagType.KNOWLEDGE
|
||||||
return tag
|
return tag
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -35,6 +35,7 @@ from controllers.service_api.dataset.dataset import (
|
|||||||
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
|
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.dataset import DatasetPermissionEnum
|
from models.dataset import DatasetPermissionEnum
|
||||||
|
from models.enums import TagType
|
||||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||||
from services.tag_service import TagService
|
from services.tag_service import TagService
|
||||||
|
|
||||||
@ -277,7 +278,7 @@ class TestDatasetTagsApi:
|
|||||||
mock_tag = Mock()
|
mock_tag = Mock()
|
||||||
mock_tag.id = "tag_1"
|
mock_tag.id = "tag_1"
|
||||||
mock_tag.name = "Test Tag"
|
mock_tag.name = "Test Tag"
|
||||||
mock_tag.type = "knowledge"
|
mock_tag.type = TagType.KNOWLEDGE
|
||||||
mock_tag.binding_count = "0" # Required for Pydantic validation - must be string
|
mock_tag.binding_count = "0" # Required for Pydantic validation - must be string
|
||||||
mock_tag_service.get_tags.return_value = [mock_tag]
|
mock_tag_service.get_tags.return_value = [mock_tag]
|
||||||
|
|
||||||
@ -316,7 +317,7 @@ class TestDatasetTagsApi:
|
|||||||
mock_tag = Mock()
|
mock_tag = Mock()
|
||||||
mock_tag.id = "new_tag_1"
|
mock_tag.id = "new_tag_1"
|
||||||
mock_tag.name = "New Tag"
|
mock_tag.name = "New Tag"
|
||||||
mock_tag.type = "knowledge"
|
mock_tag.type = TagType.KNOWLEDGE
|
||||||
mock_tag_service.save_tags.return_value = mock_tag
|
mock_tag_service.save_tags.return_value = mock_tag
|
||||||
mock_service_api_ns.payload = {"name": "New Tag"}
|
mock_service_api_ns.payload = {"name": "New Tag"}
|
||||||
|
|
||||||
@ -378,7 +379,7 @@ class TestDatasetTagsApi:
|
|||||||
mock_tag = Mock()
|
mock_tag = Mock()
|
||||||
mock_tag.id = "tag_1"
|
mock_tag.id = "tag_1"
|
||||||
mock_tag.name = "Updated Tag"
|
mock_tag.name = "Updated Tag"
|
||||||
mock_tag.type = "knowledge"
|
mock_tag.type = TagType.KNOWLEDGE
|
||||||
mock_tag.binding_count = "5"
|
mock_tag.binding_count = "5"
|
||||||
mock_tag_service.update_tags.return_value = mock_tag
|
mock_tag_service.update_tags.return_value = mock_tag
|
||||||
mock_tag_service.get_tag_binding_count.return_value = 5
|
mock_tag_service.get_tag_binding_count.return_value = 5
|
||||||
@ -866,7 +867,7 @@ class TestTagService:
|
|||||||
mock_tag = Mock()
|
mock_tag = Mock()
|
||||||
mock_tag.id = str(uuid.uuid4())
|
mock_tag.id = str(uuid.uuid4())
|
||||||
mock_tag.name = "New Tag"
|
mock_tag.name = "New Tag"
|
||||||
mock_tag.type = "knowledge"
|
mock_tag.type = TagType.KNOWLEDGE
|
||||||
mock_save.return_value = mock_tag
|
mock_save.return_value = mock_tag
|
||||||
|
|
||||||
result = TagService.save_tags({"name": "New Tag", "type": "knowledge"})
|
result = TagService.save_tags({"name": "New Tag", "type": "knowledge"})
|
||||||
|
|||||||
@ -12,7 +12,7 @@ This test suite covers:
|
|||||||
import json
|
import json
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from core.tools.entities.tool_entities import ApiProviderSchemaType
|
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolProviderType
|
||||||
from models.tools import (
|
from models.tools import (
|
||||||
ApiToolProvider,
|
ApiToolProvider,
|
||||||
BuiltinToolProvider,
|
BuiltinToolProvider,
|
||||||
@ -631,7 +631,7 @@ class TestToolLabelBinding:
|
|||||||
"""Test creating a tool label binding."""
|
"""Test creating a tool label binding."""
|
||||||
# Arrange
|
# Arrange
|
||||||
tool_id = "google.search"
|
tool_id = "google.search"
|
||||||
tool_type = "builtin"
|
tool_type = ToolProviderType.BUILT_IN
|
||||||
label_name = "search"
|
label_name = "search"
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
@ -655,7 +655,7 @@ class TestToolLabelBinding:
|
|||||||
# Act
|
# Act
|
||||||
label_binding = ToolLabelBinding(
|
label_binding = ToolLabelBinding(
|
||||||
tool_id=tool_id,
|
tool_id=tool_id,
|
||||||
tool_type="builtin",
|
tool_type=ToolProviderType.BUILT_IN,
|
||||||
label_name=label_name,
|
label_name=label_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -667,7 +667,7 @@ class TestToolLabelBinding:
|
|||||||
"""Test multiple labels can be bound to the same tool."""
|
"""Test multiple labels can be bound to the same tool."""
|
||||||
# Arrange
|
# Arrange
|
||||||
tool_id = "google.search"
|
tool_id = "google.search"
|
||||||
tool_type = "builtin"
|
tool_type = ToolProviderType.BUILT_IN
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
binding1 = ToolLabelBinding(
|
binding1 = ToolLabelBinding(
|
||||||
@ -688,7 +688,7 @@ class TestToolLabelBinding:
|
|||||||
def test_tool_label_binding_different_tool_types(self):
|
def test_tool_label_binding_different_tool_types(self):
|
||||||
"""Test label bindings for different tool types."""
|
"""Test label bindings for different tool types."""
|
||||||
# Arrange
|
# Arrange
|
||||||
tool_types = ["builtin", "api", "workflow"]
|
tool_types = [ToolProviderType.BUILT_IN, ToolProviderType.API, ToolProviderType.WORKFLOW]
|
||||||
|
|
||||||
# Act & Assert
|
# Act & Assert
|
||||||
for tool_type in tool_types:
|
for tool_type in tool_types:
|
||||||
@ -951,12 +951,12 @@ class TestToolProviderRelationships:
|
|||||||
# Act
|
# Act
|
||||||
binding1 = ToolLabelBinding(
|
binding1 = ToolLabelBinding(
|
||||||
tool_id=tool_id,
|
tool_id=tool_id,
|
||||||
tool_type="builtin",
|
tool_type=ToolProviderType.BUILT_IN,
|
||||||
label_name="search",
|
label_name="search",
|
||||||
)
|
)
|
||||||
binding2 = ToolLabelBinding(
|
binding2 = ToolLabelBinding(
|
||||||
tool_id=tool_id,
|
tool_id=tool_id,
|
||||||
tool_type="builtin",
|
tool_type=ToolProviderType.BUILT_IN,
|
||||||
label_name="web",
|
label_name="web",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -75,6 +75,7 @@ import pytest
|
|||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
from models.dataset import Dataset
|
from models.dataset import Dataset
|
||||||
|
from models.enums import TagType
|
||||||
from models.model import App, Tag, TagBinding
|
from models.model import App, Tag, TagBinding
|
||||||
from services.tag_service import TagService
|
from services.tag_service import TagService
|
||||||
|
|
||||||
@ -102,7 +103,7 @@ class TagServiceTestDataFactory:
|
|||||||
def create_tag_mock(
|
def create_tag_mock(
|
||||||
tag_id: str = "tag-123",
|
tag_id: str = "tag-123",
|
||||||
name: str = "Test Tag",
|
name: str = "Test Tag",
|
||||||
tag_type: str = "app",
|
tag_type: TagType = TagType.APP,
|
||||||
tenant_id: str = "tenant-123",
|
tenant_id: str = "tenant-123",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Mock:
|
) -> Mock:
|
||||||
@ -705,7 +706,7 @@ class TestTagServiceCRUD:
|
|||||||
# Verify tag attributes
|
# Verify tag attributes
|
||||||
added_tag = mock_db_session.add.call_args[0][0]
|
added_tag = mock_db_session.add.call_args[0][0]
|
||||||
assert added_tag.name == "New Tag", "Tag name should match"
|
assert added_tag.name == "New Tag", "Tag name should match"
|
||||||
assert added_tag.type == "app", "Tag type should match"
|
assert added_tag.type == TagType.APP, "Tag type should match"
|
||||||
assert added_tag.created_by == "user-123", "Created by should match current user"
|
assert added_tag.created_by == "user-123", "Created by should match current user"
|
||||||
assert added_tag.tenant_id == "tenant-123", "Tenant ID should match current tenant"
|
assert added_tag.tenant_id == "tenant-123", "Tenant ID should match current tenant"
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user