fix retriever resource

This commit is contained in:
takatost 2024-03-18 16:38:39 +08:00
parent 5ed181dd42
commit 61b41ca04b
5 changed files with 77 additions and 4 deletions

View File

@ -2,6 +2,7 @@ from typing import Optional
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
@ -128,3 +129,12 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
}
), PublishFrom.APPLICATION_MANAGER
)
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event
"""
self._queue_manager.publish(
event,
PublishFrom.APPLICATION_MANAGER
)

View File

@ -2,6 +2,7 @@ from typing import Optional
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
@ -119,3 +120,9 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
Publish text chunk
"""
pass
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event
"""
pass

View File

@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from typing import Optional
from core.app.entities.queue_entities import AppQueueEvent
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
@ -70,3 +71,10 @@ class BaseWorkflowCallback(ABC):
Publish text chunk
"""
raise NotImplementedError
@abstractmethod
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event
"""
raise NotImplementedError

View File

@ -133,11 +133,13 @@ class KnowledgeRetrievalNode(BaseNode):
Document.enabled == True,
Document.archived == False,
).first()
resource_number = 1
if dataset and document:
source = {
'metadata': {
'_source': 'knowledge',
'position': resource_number,
'dataset_id': dataset.id,
'dataset_name': dataset.name,
'document_id': document.id,
@ -148,14 +150,17 @@ class KnowledgeRetrievalNode(BaseNode):
'score': document_score_list.get(segment.index_node_id, None),
'segment_hit_count': segment.hit_count,
'segment_word_count': segment.word_count,
'segment_position': segment.position
}
'segment_position': segment.position,
'segment_index_node_hash': segment.index_node_hash,
},
'title': document.name
}
if segment.answer:
source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
else:
source['content'] = segment.content
context_list.append(source)
resource_number += 1
return context_list

View File

@ -2,6 +2,7 @@ from collections.abc import Generator
from typing import Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
@ -220,8 +221,8 @@ class LLMNode(BaseNode):
:param variable_pool: variable pool
:return:
"""
if not node_data.context.enabled:
return None
# if not node_data.context.enabled:
# return None
context_value = variable_pool.get_variable_value(node_data.context.variable_selector)
if context_value:
@ -229,16 +230,58 @@ class LLMNode(BaseNode):
return context_value
elif isinstance(context_value, list):
context_str = ''
original_retriever_resource = []
for item in context_value:
if 'content' not in item:
raise ValueError(f'Invalid context structure: {item}')
context_str += item['content'] + '\n'
retriever_resource = self._convert_to_original_retriever_resource(item)
if retriever_resource:
original_retriever_resource.append(retriever_resource)
if self.callbacks:
for callback in self.callbacks:
callback.on_event(
event=QueueRetrieverResourcesEvent(
retriever_resources=original_retriever_resource
)
)
return context_str.strip()
return None
def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
"""
Convert to original retriever resource, temp.
:param context_dict: context dict
:return:
"""
if '_source' in context_dict and context_dict['_source'] == 'knowledge':
metadata = context_dict.get('metadata', {})
source = {
'position': metadata.get('position'),
'dataset_id': metadata.get('dataset_id'),
'dataset_name': metadata.get('dataset_name'),
'document_id': metadata.get('document_id'),
'document_name': metadata.get('document_name'),
'data_source_type': metadata.get('document_data_source_type'),
'segment_id': metadata.get('segment_id'),
'retriever_from': metadata.get('retriever_from'),
'score': metadata.get('score'),
'hit_count': metadata.get('segment_hit_count'),
'word_count': metadata.get('segment_word_count'),
'segment_position': metadata.get('segment_position'),
'index_node_hash': metadata.get('segment_index_node_hash'),
'content': context_dict.get('content'),
}
return source
return None
def _fetch_model_config(self, node_data: LLMNodeData) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
"""
Fetch model config