diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 5bc453420d..a6ac0faf44 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -9,6 +9,7 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAp from core.entities.provider_configuration import ProviderModelBundle from core.file import File, FileUploadConfig from core.model_runtime.entities.model_entities import AIModelEntity +from models.enums import CreatorUserRole if TYPE_CHECKING: from core.ops.ops_trace_manager import TraceQueueManager @@ -80,6 +81,11 @@ class InvokeFrom(StrEnum): return "dev" + def to_creator_user_role(self) -> CreatorUserRole: + if self in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}: + return CreatorUserRole.ACCOUNT + return CreatorUserRole.END_USER + class ModelConfigWithCredentialsEntity(BaseModel): """ diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index f8f85d141a..99ec6cb5f3 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -63,6 +63,7 @@ from libs.json_in_md_parser import parse_and_check_json_markdown from models import UploadFile from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument +from models.enums import CreatorUserRole from services.external_knowledge_service import ExternalDatasetService default_retrieval_model: dict[str, Any] = { @@ -176,13 +177,13 @@ class DatasetRetrieval: ) all_documents = [] - user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user" + creator_user_role = invoke_from.to_creator_user_role() if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: all_documents = self.single_retrieve( app_id, tenant_id, user_id, - user_from, + creator_user_role, query, available_datasets, model_instance, @@ -197,7 +198,7 @@ class DatasetRetrieval: app_id, tenant_id, user_id, - user_from, + creator_user_role, available_datasets, query, retrieve_config.top_k or 0, @@ -334,7 +335,7 @@ class DatasetRetrieval: app_id: str, tenant_id: str, user_id: str, - user_from: str, + creator_user_role: CreatorUserRole, query: str, available_datasets: list, model_instance: ModelInstance, @@ -444,7 +445,7 @@ class DatasetRetrieval: weights=retrieval_model_config.get("weights", None), document_ids_filter=document_ids_filter, ) - self._on_query(query, None, [dataset_id], app_id, user_from, user_id) + self._on_query(query, None, [dataset_id], app_id, creator_user_role, user_id) if results: thread = threading.Thread( @@ -466,7 +467,7 @@ class DatasetRetrieval: app_id: str, tenant_id: str, user_id: str, - user_from: str, + creator_user_role: CreatorUserRole, available_datasets: list, query: str | None, top_k: int, @@ -584,7 +585,7 @@ class DatasetRetrieval: if thread_exceptions: raise thread_exceptions[0] - self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id) + self._on_query(query, attachment_ids, dataset_ids, app_id, creator_user_role, user_id) if all_documents: # add thread to call _on_retrieval_end @@ -733,7 +734,7 @@ class DatasetRetrieval: attachment_ids: list[str] | None, dataset_ids: list[str], app_id: str, - user_from: str, + creator_user_role: CreatorUserRole, user_id: str, ): """ @@ -755,7 +756,7 @@ class DatasetRetrieval: content=json.dumps(contents), source="app", source_app_id=app_id, - created_by_role=user_from, + created_by_role=creator_user_role, created_by=user_id, ) dataset_queries.append(dataset_query) diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 8670a71aa3..dbd340374f 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -55,6 +55,7 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.json_in_md_parser import parse_and_check_json_markdown from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog +from models.enums import CreatorUserRole, UserFrom from services.feature_service import FeatureService from .entities import KnowledgeRetrievalNodeData @@ -268,6 +269,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD usage = self._merge_usage(usage, metadata_usage) all_documents = [] dataset_retrieval = DatasetRetrieval() + creator_user_role = CreatorUserRole.ACCOUNT if self.user_from == UserFrom.ACCOUNT else CreatorUserRole.END_USER if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query: # fetch model config if node_data.single_retrieval_config is None: @@ -292,7 +294,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD tenant_id=self.tenant_id, user_id=self.user_id, app_id=self.app_id, - user_from=self.user_from.value, + creator_user_role=creator_user_role, query=query, model_config=model_config, model_instance=model_instance, @@ -334,7 +336,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD app_id=self.app_id, tenant_id=self.tenant_id, user_id=self.user_id, - user_from=self.user_from.value, + creator_user_role=creator_user_role, available_datasets=available_datasets, query=query, top_k=node_data.multiple_retrieval_config.top_k,