diff --git a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py index 45d0e94bfb..fef719a086 100644 --- a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py +++ b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py @@ -2,6 +2,7 @@ from typing import Optional from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.queue_entities import ( + AppQueueEvent, QueueNodeFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, @@ -128,3 +129,12 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): } ), PublishFrom.APPLICATION_MANAGER ) + + def on_event(self, event: AppQueueEvent) -> None: + """ + Publish event + """ + self._queue_manager.publish( + event, + PublishFrom.APPLICATION_MANAGER + ) diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py index e15ebd5548..eea456e151 100644 --- a/api/core/app/apps/workflow/workflow_event_trigger_callback.py +++ b/api/core/app/apps/workflow/workflow_event_trigger_callback.py @@ -2,6 +2,7 @@ from typing import Optional from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.queue_entities import ( + AppQueueEvent, QueueNodeFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, @@ -119,3 +120,9 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): Publish text chunk """ pass + + def on_event(self, event: AppQueueEvent) -> None: + """ + Publish event + """ + pass diff --git a/api/core/workflow/callbacks/base_workflow_callback.py b/api/core/workflow/callbacks/base_workflow_callback.py index c2546050c5..dd5a30f611 100644 --- a/api/core/workflow/callbacks/base_workflow_callback.py +++ b/api/core/workflow/callbacks/base_workflow_callback.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import Optional +from core.app.entities.queue_entities import AppQueueEvent from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeType @@ -70,3 +71,10 @@ class BaseWorkflowCallback(ABC): Publish text chunk """ raise NotImplementedError + + @abstractmethod + def on_event(self, event: AppQueueEvent) -> None: + """ + Publish event + """ + raise NotImplementedError diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 4d5970aaef..87ba4239f8 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -133,11 +133,13 @@ class KnowledgeRetrievalNode(BaseNode): Document.enabled == True, Document.archived == False, ).first() + resource_number = 1 if dataset and document: source = { 'metadata': { '_source': 'knowledge', + 'position': resource_number, 'dataset_id': dataset.id, 'dataset_name': dataset.name, 'document_id': document.id, @@ -148,14 +150,17 @@ class KnowledgeRetrievalNode(BaseNode): 'score': document_score_list.get(segment.index_node_id, None), 'segment_hit_count': segment.hit_count, 'segment_word_count': segment.word_count, - 'segment_position': segment.position - } + 'segment_position': segment.position, + 'segment_index_node_hash': segment.index_node_hash, + }, + 'title': document.name } if segment.answer: source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' else: source['content'] = segment.content context_list.append(source) + resource_number += 1 return context_list diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index 0d860f5dd6..596e439a7a 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -2,6 +2,7 @@ from collections.abc import Generator from typing import Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.app.entities.queue_entities import QueueRetrieverResourcesEvent from core.entities.model_entities import ModelStatus from core.entities.provider_entities import QuotaUnit from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError @@ -220,8 +221,8 @@ class LLMNode(BaseNode): :param variable_pool: variable pool :return: """ - if not node_data.context.enabled: - return None + # if not node_data.context.enabled: + # return None context_value = variable_pool.get_variable_value(node_data.context.variable_selector) if context_value: @@ -229,16 +230,58 @@ class LLMNode(BaseNode): return context_value elif isinstance(context_value, list): context_str = '' + original_retriever_resource = [] for item in context_value: if 'content' not in item: raise ValueError(f'Invalid context structure: {item}') context_str += item['content'] + '\n' + retriever_resource = self._convert_to_original_retriever_resource(item) + if retriever_resource: + original_retriever_resource.append(retriever_resource) + + if self.callbacks: + for callback in self.callbacks: + callback.on_event( + event=QueueRetrieverResourcesEvent( + retriever_resources=original_retriever_resource + ) + ) + return context_str.strip() return None + def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]: + """ + Convert to original retriever resource, temp. + :param context_dict: context dict + :return: + """ + if '_source' in context_dict and context_dict['_source'] == 'knowledge': + metadata = context_dict.get('metadata', {}) + source = { + 'position': metadata.get('position'), + 'dataset_id': metadata.get('dataset_id'), + 'dataset_name': metadata.get('dataset_name'), + 'document_id': metadata.get('document_id'), + 'document_name': metadata.get('document_name'), + 'data_source_type': metadata.get('document_data_source_type'), + 'segment_id': metadata.get('segment_id'), + 'retriever_from': metadata.get('retriever_from'), + 'score': metadata.get('score'), + 'hit_count': metadata.get('segment_hit_count'), + 'word_count': metadata.get('segment_word_count'), + 'segment_position': metadata.get('segment_position'), + 'index_node_hash': metadata.get('segment_index_node_hash'), + 'content': context_dict.get('content'), + } + + return source + + return None + def _fetch_model_config(self, node_data: LLMNodeData) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: """ Fetch model config