mirror of https://github.com/langgenius/dify.git
fix retriever resource
This commit is contained in:
parent
5ed181dd42
commit
61b41ca04b
|
|
@ -2,6 +2,7 @@ from typing import Optional
|
|||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
|
|
@ -128,3 +129,12 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
|
|||
}
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_event(self, event: AppQueueEvent) -> None:
|
||||
"""
|
||||
Publish event
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
event,
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from typing import Optional
|
|||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
|
|
@ -119,3 +120,9 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
|
|||
Publish text chunk
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_event(self, event: AppQueueEvent) -> None:
|
||||
"""
|
||||
Publish event
|
||||
"""
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from core.app.entities.queue_entities import AppQueueEvent
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
|
||||
|
|
@ -70,3 +71,10 @@ class BaseWorkflowCallback(ABC):
|
|||
Publish text chunk
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_event(self, event: AppQueueEvent) -> None:
|
||||
"""
|
||||
Publish event
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -133,11 +133,13 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
).first()
|
||||
resource_number = 1
|
||||
if dataset and document:
|
||||
|
||||
source = {
|
||||
'metadata': {
|
||||
'_source': 'knowledge',
|
||||
'position': resource_number,
|
||||
'dataset_id': dataset.id,
|
||||
'dataset_name': dataset.name,
|
||||
'document_id': document.id,
|
||||
|
|
@ -148,14 +150,17 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||
'score': document_score_list.get(segment.index_node_id, None),
|
||||
'segment_hit_count': segment.hit_count,
|
||||
'segment_word_count': segment.word_count,
|
||||
'segment_position': segment.position
|
||||
}
|
||||
'segment_position': segment.position,
|
||||
'segment_index_node_hash': segment.index_node_hash,
|
||||
},
|
||||
'title': document.name
|
||||
}
|
||||
if segment.answer:
|
||||
source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
|
||||
else:
|
||||
source['content'] = segment.content
|
||||
context_list.append(source)
|
||||
resource_number += 1
|
||||
|
||||
return context_list
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from collections.abc import Generator
|
|||
from typing import Optional, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.entities.provider_entities import QuotaUnit
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
|
|
@ -220,8 +221,8 @@ class LLMNode(BaseNode):
|
|||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
if not node_data.context.enabled:
|
||||
return None
|
||||
# if not node_data.context.enabled:
|
||||
# return None
|
||||
|
||||
context_value = variable_pool.get_variable_value(node_data.context.variable_selector)
|
||||
if context_value:
|
||||
|
|
@ -229,16 +230,58 @@ class LLMNode(BaseNode):
|
|||
return context_value
|
||||
elif isinstance(context_value, list):
|
||||
context_str = ''
|
||||
original_retriever_resource = []
|
||||
for item in context_value:
|
||||
if 'content' not in item:
|
||||
raise ValueError(f'Invalid context structure: {item}')
|
||||
|
||||
context_str += item['content'] + '\n'
|
||||
|
||||
retriever_resource = self._convert_to_original_retriever_resource(item)
|
||||
if retriever_resource:
|
||||
original_retriever_resource.append(retriever_resource)
|
||||
|
||||
if self.callbacks:
|
||||
for callback in self.callbacks:
|
||||
callback.on_event(
|
||||
event=QueueRetrieverResourcesEvent(
|
||||
retriever_resources=original_retriever_resource
|
||||
)
|
||||
)
|
||||
|
||||
return context_str.strip()
|
||||
|
||||
return None
|
||||
|
||||
def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
|
||||
"""
|
||||
Convert to original retriever resource, temp.
|
||||
:param context_dict: context dict
|
||||
:return:
|
||||
"""
|
||||
if '_source' in context_dict and context_dict['_source'] == 'knowledge':
|
||||
metadata = context_dict.get('metadata', {})
|
||||
source = {
|
||||
'position': metadata.get('position'),
|
||||
'dataset_id': metadata.get('dataset_id'),
|
||||
'dataset_name': metadata.get('dataset_name'),
|
||||
'document_id': metadata.get('document_id'),
|
||||
'document_name': metadata.get('document_name'),
|
||||
'data_source_type': metadata.get('document_data_source_type'),
|
||||
'segment_id': metadata.get('segment_id'),
|
||||
'retriever_from': metadata.get('retriever_from'),
|
||||
'score': metadata.get('score'),
|
||||
'hit_count': metadata.get('segment_hit_count'),
|
||||
'word_count': metadata.get('segment_word_count'),
|
||||
'segment_position': metadata.get('segment_position'),
|
||||
'index_node_hash': metadata.get('segment_index_node_hash'),
|
||||
'content': context_dict.get('content'),
|
||||
}
|
||||
|
||||
return source
|
||||
|
||||
return None
|
||||
|
||||
def _fetch_model_config(self, node_data: LLMNodeData) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
"""
|
||||
Fetch model config
|
||||
|
|
|
|||
Loading…
Reference in New Issue