mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into feat/r2
# Conflicts: # api/core/repositories/sqlalchemy_workflow_node_execution_repository.py # api/core/workflow/entities/node_entities.py # api/core/workflow/enums.py
This commit is contained in:
commit
309fffd1e4
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
npm add -g pnpm@10.8.0
|
||||
npm add -g pnpm@10.11.1
|
||||
cd web && pnpm install
|
||||
pipx install uv
|
||||
|
||||
|
|
|
|||
|
|
@ -846,6 +846,9 @@ def clear_orphaned_file_records(force: bool):
|
|||
{"type": "text", "table": "workflow_node_executions", "column": "outputs"},
|
||||
{"type": "text", "table": "conversations", "column": "introduction"},
|
||||
{"type": "text", "table": "conversations", "column": "system_instruction"},
|
||||
{"type": "text", "table": "accounts", "column": "avatar"},
|
||||
{"type": "text", "table": "apps", "column": "icon"},
|
||||
{"type": "text", "table": "sites", "column": "icon"},
|
||||
{"type": "json", "table": "messages", "column": "inputs"},
|
||||
{"type": "json", "table": "messages", "column": "message"},
|
||||
]
|
||||
|
|
|
|||
|
|
@ -60,8 +60,7 @@ class NacosHttpClient:
|
|||
sign_str = tenant + "+"
|
||||
if group:
|
||||
sign_str = sign_str + group + "+"
|
||||
if sign_str:
|
||||
sign_str += ts
|
||||
sign_str += ts # Directly concatenate ts without conditional checks, because the nacos auth header forced it.
|
||||
return sign_str
|
||||
|
||||
def get_access_token(self, force_refresh=False):
|
||||
|
|
|
|||
|
|
@ -6,12 +6,12 @@ from sqlalchemy.orm import Session
|
|||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRunStatus
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
|
||||
|
|
@ -38,7 +38,7 @@ class WorkflowAppLogApi(Resource):
|
|||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
args.status = WorkflowRunStatus(args.status) if args.status else None
|
||||
args.status = WorkflowExecutionStatus(args.status) if args.status else None
|
||||
if args.created_at__before:
|
||||
args.created_at__before = isoparse(args.created_at__before)
|
||||
|
||||
|
|
|
|||
|
|
@ -24,12 +24,13 @@ from core.errors.error import (
|
|||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField
|
||||
from models.model import App, AppMode, EndUser
|
||||
from models.workflow import WorkflowRun, WorkflowRunStatus
|
||||
from models.workflow import WorkflowRun
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
|
|
@ -138,7 +139,7 @@ class WorkflowAppLogApi(Resource):
|
|||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
args.status = WorkflowRunStatus(args.status) if args.status else None
|
||||
args.status = WorkflowExecutionStatus(args.status) if args.status else None
|
||||
if args.created_at__before:
|
||||
args.created_at__before = isoparse(args.created_at__before)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,19 +1,21 @@
|
|||
from flask import request
|
||||
from flask_restful import marshal, reqparse
|
||||
from flask_restful import marshal, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services.dataset_service
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
from controllers.service_api.wraps import DatasetApiResource, validate_dataset_token
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.provider_manager import ProviderManager
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from fields.tag_fields import tag_fields
|
||||
from libs.login import current_user
|
||||
from models.dataset import Dataset, DatasetPermissionEnum
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||
from services.tag_service import TagService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
|
|
@ -320,5 +322,134 @@ class DatasetApi(DatasetApiResource):
|
|||
raise DatasetInUseError()
|
||||
|
||||
|
||||
class DatasetTagsApi(DatasetApiResource):
|
||||
@validate_dataset_token
|
||||
@marshal_with(tag_fields)
|
||||
def get(self, _, dataset_id):
|
||||
"""Get all knowledge type tags."""
|
||||
tags = TagService.get_tags("knowledge", current_user.current_tenant_id)
|
||||
|
||||
return tags, 200
|
||||
|
||||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
"""Add a knowledge type tag."""
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 50 characters.",
|
||||
type=DatasetTagsApi._validate_tag_name,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args["type"] = "knowledge"
|
||||
tag = TagService.save_tags(args)
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
|
||||
return response, 200
|
||||
|
||||
@validate_dataset_token
|
||||
def patch(self, _, dataset_id):
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 50 characters.",
|
||||
type=DatasetTagsApi._validate_tag_name,
|
||||
)
|
||||
parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
|
||||
args = parser.parse_args()
|
||||
tag = TagService.update_tags(args, args.get("tag_id"))
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(args.get("tag_id"))
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
|
||||
return response, 200
|
||||
|
||||
@validate_dataset_token
|
||||
def delete(self, _, dataset_id):
|
||||
"""Delete a knowledge type tag."""
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
|
||||
args = parser.parse_args()
|
||||
TagService.delete_tag(args.get("tag_id"))
|
||||
|
||||
return 204
|
||||
|
||||
@staticmethod
|
||||
def _validate_tag_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 50:
|
||||
raise ValueError("Name must be between 1 to 50 characters.")
|
||||
return name
|
||||
|
||||
|
||||
class DatasetTagBindingApi(DatasetApiResource):
|
||||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
|
||||
)
|
||||
parser.add_argument(
|
||||
"target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args["type"] = "knowledge"
|
||||
TagService.save_tag_binding(args)
|
||||
|
||||
return 204
|
||||
|
||||
|
||||
class DatasetTagUnbindingApi(DatasetApiResource):
|
||||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
|
||||
parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
|
||||
|
||||
args = parser.parse_args()
|
||||
args["type"] = "knowledge"
|
||||
TagService.delete_tag_binding(args)
|
||||
|
||||
return 204
|
||||
|
||||
|
||||
class DatasetTagsBindingStatusApi(DatasetApiResource):
|
||||
@validate_dataset_token
|
||||
def get(self, _, *args, **kwargs):
|
||||
"""Get all knowledge type tags."""
|
||||
dataset_id = kwargs.get("dataset_id")
|
||||
tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id))
|
||||
tags_list = [{"id": tag.id, "name": tag.name} for tag in tags]
|
||||
response = {"data": tags_list, "total": len(tags)}
|
||||
return response, 200
|
||||
|
||||
|
||||
api.add_resource(DatasetListApi, "/datasets")
|
||||
api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
|
||||
api.add_resource(DatasetTagsApi, "/datasets/tags")
|
||||
api.add_resource(DatasetTagBindingApi, "/datasets/tags/binding")
|
||||
api.add_resource(DatasetTagUnbindingApi, "/datasets/tags/unbinding")
|
||||
api.add_resource(DatasetTagsBindingStatusApi, "/datasets/<uuid:dataset_id>/tags")
|
||||
|
|
|
|||
|
|
@ -208,6 +208,28 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||
)
|
||||
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
def get(self, tenant_id, dataset_id, document_id, segment_id):
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
|
||||
class ChildChunkApi(DatasetApiResource):
|
||||
"""Resource for child chunks."""
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ class ModelConfigConverter:
|
|||
if not model_mode:
|
||||
model_mode = LLMMode.CHAT.value
|
||||
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
|
||||
model_mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE]).value
|
||||
model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE]).value
|
||||
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
|
|
|
|||
|
|
@ -27,8 +27,8 @@ from core.ops.ops_trace_manager import TraceQueueManager
|
|||
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
|
|
|
|||
|
|
@ -140,7 +140,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
SystemVariableKey.DIALOGUE_COUNT: self._dialogue_count,
|
||||
SystemVariableKey.APP_ID: app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
|
||||
SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id,
|
||||
SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_run_id,
|
||||
}
|
||||
|
||||
# init variable pool
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator, Mapping
|
||||
|
|
@ -57,26 +56,23 @@ from core.app.entities.task_entities import (
|
|||
WorkflowTaskState,
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
|
||||
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from models import Conversation, EndUser, Message, MessageFile
|
||||
from models.account import Account
|
||||
from models.enums import CreatorUserRole
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -126,8 +122,14 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
|
||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
||||
SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_run_id,
|
||||
},
|
||||
workflow_info=CycleManagerWorkflowInfo(
|
||||
workflow_id=workflow.id,
|
||||
workflow_type=WorkflowType(workflow.type),
|
||||
version=workflow.version,
|
||||
graph_data=workflow.graph_dict,
|
||||
),
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
)
|
||||
|
|
@ -137,7 +139,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
)
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._message_cycle_manager = MessageCycleManage(
|
||||
self._message_cycle_manager = MessageCycleManager(
|
||||
application_generate_entity=application_generate_entity, task_state=self._task_state
|
||||
)
|
||||
|
||||
|
|
@ -158,7 +160,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
:return:
|
||||
"""
|
||||
# start generate conversation name thread
|
||||
self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name(
|
||||
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
|
||||
conversation_id=self._conversation_id, query=self._application_generate_entity.query
|
||||
)
|
||||
|
||||
|
|
@ -302,15 +304,12 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# init workflow run
|
||||
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start(
|
||||
session=session,
|
||||
workflow_id=self._workflow_id,
|
||||
)
|
||||
self._workflow_run_id = workflow_execution.id
|
||||
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
|
||||
self._workflow_run_id = workflow_execution.id_
|
||||
message = self._get_message(session=session)
|
||||
if not message:
|
||||
raise ValueError(f"Message not found: {self._message_id}")
|
||||
message.workflow_run_id = workflow_execution.id
|
||||
message.workflow_run_id = workflow_execution.id_
|
||||
workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution=workflow_execution,
|
||||
|
|
@ -550,7 +549,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
workflow_run_id=self._workflow_run_id,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
status=WorkflowRunStatus.FAILED,
|
||||
status=WorkflowExecutionStatus.FAILED,
|
||||
error_message=event.error,
|
||||
conversation_id=self._conversation_id,
|
||||
trace_manager=trace_manager,
|
||||
|
|
@ -576,7 +575,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
workflow_run_id=self._workflow_run_id,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
status=WorkflowRunStatus.STOPPED,
|
||||
status=WorkflowExecutionStatus.STOPPED,
|
||||
error_message=event.get_stop_reason(),
|
||||
conversation_id=self._conversation_id,
|
||||
trace_manager=trace_manager,
|
||||
|
|
@ -604,22 +603,18 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
yield self._message_end_to_stream_response()
|
||||
break
|
||||
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||
self._message_cycle_manager._handle_retriever_resources(event)
|
||||
self._message_cycle_manager.handle_retriever_resources(event)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
message = self._get_message(session=session)
|
||||
message.message_metadata = (
|
||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||
)
|
||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||
session.commit()
|
||||
elif isinstance(event, QueueAnnotationReplyEvent):
|
||||
self._message_cycle_manager._handle_annotation_reply(event)
|
||||
self._message_cycle_manager.handle_annotation_reply(event)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
message = self._get_message(session=session)
|
||||
message.message_metadata = (
|
||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||
)
|
||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||
session.commit()
|
||||
elif isinstance(event, QueueTextChunkEvent):
|
||||
delta_text = event.text
|
||||
|
|
@ -636,12 +631,12 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
tts_publisher.publish(queue_message)
|
||||
|
||||
self._task_state.answer += delta_text
|
||||
yield self._message_cycle_manager._message_to_stream_response(
|
||||
yield self._message_cycle_manager.message_to_stream_response(
|
||||
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
|
||||
)
|
||||
elif isinstance(event, QueueMessageReplaceEvent):
|
||||
# published by moderation
|
||||
yield self._message_cycle_manager._message_replace_to_stream_response(
|
||||
yield self._message_cycle_manager.message_replace_to_stream_response(
|
||||
answer=event.text, reason=event.reason
|
||||
)
|
||||
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
|
||||
|
|
@ -653,7 +648,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
)
|
||||
if output_moderation_answer:
|
||||
self._task_state.answer = output_moderation_answer
|
||||
yield self._message_cycle_manager._message_replace_to_stream_response(
|
||||
yield self._message_cycle_manager.message_replace_to_stream_response(
|
||||
answer=output_moderation_answer,
|
||||
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
|
||||
)
|
||||
|
|
@ -682,9 +677,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
message = self._get_message(session=session)
|
||||
message.answer = self._task_state.answer
|
||||
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
|
||||
message.message_metadata = (
|
||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||
)
|
||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||
message_files = [
|
||||
MessageFile(
|
||||
message_id=message.id,
|
||||
|
|
@ -712,9 +705,9 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
message.answer_price_unit = usage.completion_price_unit
|
||||
message.total_price = usage.total_price
|
||||
message.currency = usage.currency
|
||||
self._task_state.metadata["usage"] = jsonable_encoder(usage)
|
||||
self._task_state.metadata.usage = usage
|
||||
else:
|
||||
self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage())
|
||||
self._task_state.metadata.usage = LLMUsage.empty_usage()
|
||||
message_was_created.send(
|
||||
message,
|
||||
application_generate_entity=self._application_generate_entity,
|
||||
|
|
@ -725,18 +718,16 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
Message end to stream response.
|
||||
:return:
|
||||
"""
|
||||
extras = {}
|
||||
if self._task_state.metadata:
|
||||
extras["metadata"] = self._task_state.metadata.copy()
|
||||
extras = self._task_state.metadata.model_dump()
|
||||
|
||||
if "annotation_reply" in extras["metadata"]:
|
||||
del extras["metadata"]["annotation_reply"]
|
||||
if self._task_state.metadata.annotation_reply:
|
||||
del extras["annotation_reply"]
|
||||
|
||||
return MessageEndStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=self._message_id,
|
||||
files=self._recorded_files,
|
||||
metadata=extras.get("metadata", {}),
|
||||
metadata=extras,
|
||||
)
|
||||
|
||||
def _handle_output_moderation_chunk(self, text: str) -> bool:
|
||||
|
|
|
|||
|
|
@ -44,15 +44,14 @@ from core.app.entities.task_entities import (
|
|||
)
|
||||
from core.file import FILE_MODEL_IDENTITY, File
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.workflow.entities.node_execution_entities import NodeExecution
|
||||
from core.workflow.entities.workflow_execution_entities import WorkflowExecution
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from models import (
|
||||
Account,
|
||||
CreatorUserRole,
|
||||
EndUser,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowRun,
|
||||
)
|
||||
|
||||
|
|
@ -73,11 +72,10 @@ class WorkflowResponseConverter:
|
|||
) -> WorkflowStartStreamResponse:
|
||||
return WorkflowStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_execution.id,
|
||||
workflow_run_id=workflow_execution.id_,
|
||||
data=WorkflowStartStreamResponse.Data(
|
||||
id=workflow_execution.id,
|
||||
id=workflow_execution.id_,
|
||||
workflow_id=workflow_execution.workflow_id,
|
||||
sequence_number=workflow_execution.sequence_number,
|
||||
inputs=workflow_execution.inputs,
|
||||
created_at=int(workflow_execution.started_at.timestamp()),
|
||||
),
|
||||
|
|
@ -91,7 +89,7 @@ class WorkflowResponseConverter:
|
|||
workflow_execution: WorkflowExecution,
|
||||
) -> WorkflowFinishStreamResponse:
|
||||
created_by = None
|
||||
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id))
|
||||
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_))
|
||||
assert workflow_run is not None
|
||||
if workflow_run.created_by_role == CreatorUserRole.ACCOUNT:
|
||||
stmt = select(Account).where(Account.id == workflow_run.created_by)
|
||||
|
|
@ -122,11 +120,10 @@ class WorkflowResponseConverter:
|
|||
|
||||
return WorkflowFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_execution.id,
|
||||
workflow_run_id=workflow_execution.id_,
|
||||
data=WorkflowFinishStreamResponse.Data(
|
||||
id=workflow_execution.id,
|
||||
id=workflow_execution.id_,
|
||||
workflow_id=workflow_execution.workflow_id,
|
||||
sequence_number=workflow_execution.sequence_number,
|
||||
status=workflow_execution.status,
|
||||
outputs=workflow_execution.outputs,
|
||||
error=workflow_execution.error_message,
|
||||
|
|
@ -146,16 +143,16 @@ class WorkflowResponseConverter:
|
|||
*,
|
||||
event: QueueNodeStartedEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: NodeExecution,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> Optional[NodeStartStreamResponse]:
|
||||
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
|
||||
return None
|
||||
if not workflow_node_execution.workflow_run_id:
|
||||
if not workflow_node_execution.workflow_execution_id:
|
||||
return None
|
||||
|
||||
response = NodeStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_run_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_execution_id,
|
||||
data=NodeStartStreamResponse.Data(
|
||||
id=workflow_node_execution.id,
|
||||
node_id=workflow_node_execution.node_id,
|
||||
|
|
@ -196,18 +193,18 @@ class WorkflowResponseConverter:
|
|||
| QueueNodeInLoopFailedEvent
|
||||
| QueueNodeExceptionEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: NodeExecution,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> Optional[NodeFinishStreamResponse]:
|
||||
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
|
||||
return None
|
||||
if not workflow_node_execution.workflow_run_id:
|
||||
if not workflow_node_execution.workflow_execution_id:
|
||||
return None
|
||||
if not workflow_node_execution.finished_at:
|
||||
return None
|
||||
|
||||
return NodeFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_run_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_execution_id,
|
||||
data=NodeFinishStreamResponse.Data(
|
||||
id=workflow_node_execution.id,
|
||||
node_id=workflow_node_execution.node_id,
|
||||
|
|
@ -239,18 +236,18 @@ class WorkflowResponseConverter:
|
|||
*,
|
||||
event: QueueNodeRetryEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: NodeExecution,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
|
||||
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
|
||||
return None
|
||||
if not workflow_node_execution.workflow_run_id:
|
||||
if not workflow_node_execution.workflow_execution_id:
|
||||
return None
|
||||
if not workflow_node_execution.finished_at:
|
||||
return None
|
||||
|
||||
return NodeRetryStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_run_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_execution_id,
|
||||
data=NodeRetryStreamResponse.Data(
|
||||
id=workflow_node_execution.id,
|
||||
node_id=workflow_node_execution.node_id,
|
||||
|
|
|
|||
|
|
@ -25,8 +25,8 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
|||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
|
|
@ -132,7 +132,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
invoke_from=invoke_from,
|
||||
call_depth=call_depth,
|
||||
trace_manager=trace_manager,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_execution_id=workflow_run_id,
|
||||
)
|
||||
|
||||
contexts.plugin_tool_providers.set({})
|
||||
|
|
@ -279,7 +279,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
|
||||
node_id=node_id, inputs=args["inputs"]
|
||||
),
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
workflow_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
|
@ -355,7 +355,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras={"auto_generate_conversation_name": False},
|
||||
single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
workflow_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
|
|
|||
|
|
@ -95,7 +95,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||
SystemVariableKey.USER_ID: user_id,
|
||||
SystemVariableKey.APP_ID: app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
|
||||
SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id,
|
||||
SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_execution_id,
|
||||
}
|
||||
|
||||
variable_pool = VariablePool(
|
||||
|
|
|
|||
|
|
@ -50,16 +50,15 @@ from core.app.entities.task_entities import (
|
|||
WorkflowAppStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.workflow_execution_entities import WorkflowExecution
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.enums import CreatorUserRole
|
||||
|
|
@ -69,7 +68,6 @@ from models.workflow import (
|
|||
WorkflowAppLog,
|
||||
WorkflowAppLogCreatedFrom,
|
||||
WorkflowRun,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -114,8 +112,14 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
SystemVariableKey.USER_ID: user_session_id,
|
||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
||||
SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_execution_id,
|
||||
},
|
||||
workflow_info=CycleManagerWorkflowInfo(
|
||||
workflow_id=workflow.id,
|
||||
workflow_type=WorkflowType(workflow.type),
|
||||
version=workflow.version,
|
||||
graph_data=workflow.graph_dict,
|
||||
),
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
)
|
||||
|
|
@ -125,9 +129,7 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
)
|
||||
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_id = workflow.id
|
||||
self._workflow_features_dict = workflow.features_dict
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._workflow_run_id = ""
|
||||
|
||||
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
|
|
@ -266,17 +268,13 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
# override graph runtime state
|
||||
graph_runtime_state = event.graph_runtime_state
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# init workflow run
|
||||
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start(
|
||||
session=session,
|
||||
workflow_id=self._workflow_id,
|
||||
)
|
||||
self._workflow_run_id = workflow_execution.id
|
||||
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution=workflow_execution,
|
||||
)
|
||||
# init workflow run
|
||||
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
|
||||
self._workflow_run_id = workflow_execution.id_
|
||||
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution=workflow_execution,
|
||||
)
|
||||
|
||||
yield start_resp
|
||||
elif isinstance(
|
||||
|
|
@ -511,9 +509,9 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
workflow_run_id=self._workflow_run_id,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
status=WorkflowRunStatus.FAILED
|
||||
status=WorkflowExecutionStatus.FAILED
|
||||
if isinstance(event, QueueWorkflowFailedEvent)
|
||||
else WorkflowRunStatus.STOPPED,
|
||||
else WorkflowExecutionStatus.STOPPED,
|
||||
error_message=event.error
|
||||
if isinstance(event, QueueWorkflowFailedEvent)
|
||||
else event.get_stop_reason(),
|
||||
|
|
@ -542,7 +540,6 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
if tts_publisher:
|
||||
tts_publisher.publish(queue_message)
|
||||
|
||||
self._task_state.answer += delta_text
|
||||
yield self._text_chunk_to_stream_response(
|
||||
delta_text, from_variable_selector=event.from_variable_selector
|
||||
)
|
||||
|
|
@ -557,7 +554,7 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
tts_publisher.publish(None)
|
||||
|
||||
def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None:
|
||||
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id))
|
||||
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_))
|
||||
assert workflow_run is not None
|
||||
invoke_from = self._application_generate_entity.invoke_from
|
||||
if invoke_from == InvokeFrom.SERVICE_API:
|
||||
|
|
|
|||
|
|
@ -29,8 +29,8 @@ from core.app.entities.queue_entities import (
|
|||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
AgentLogEvent,
|
||||
GraphEngineEvent,
|
||||
|
|
@ -295,7 +295,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
inputs: Mapping[str, Any] | None = {}
|
||||
process_data: Mapping[str, Any] | None = {}
|
||||
outputs: Mapping[str, Any] | None = {}
|
||||
execution_metadata: Mapping[NodeRunMetadataKey, Any] | None = {}
|
||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = {}
|
||||
if node_run_result:
|
||||
inputs = node_run_result.inputs
|
||||
process_data = node_run_result.process_data
|
||||
|
|
|
|||
|
|
@ -77,6 +77,8 @@ class AppGenerateEntity(BaseModel):
|
|||
App Generate Entity.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
task_id: str
|
||||
|
||||
# app config
|
||||
|
|
@ -100,9 +102,6 @@ class AppGenerateEntity(BaseModel):
|
|||
# tracing instance
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
|
||||
"""
|
||||
|
|
@ -206,7 +205,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
|||
|
||||
# app config
|
||||
app_config: WorkflowUIBasedAppConfig
|
||||
workflow_run_id: str
|
||||
workflow_execution_id: str
|
||||
|
||||
class SingleIterationRunEntity(BaseModel):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Any, Optional
|
||||
|
|
@ -6,7 +6,9 @@ from typing import Any, Optional
|
|||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
|
@ -282,7 +284,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
|
|||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
|
||||
retriever_resources: list[dict]
|
||||
retriever_resources: Sequence[RetrievalSourceMetadata]
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
|
|
@ -412,7 +414,7 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
|||
inputs: Optional[Mapping[str, Any]] = None
|
||||
process_data: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
|
||||
error: Optional[str] = None
|
||||
"""single iteration duration map"""
|
||||
|
|
@ -446,7 +448,7 @@ class QueueNodeRetryEvent(QueueNodeStartedEvent):
|
|||
inputs: Optional[Mapping[str, Any]] = None
|
||||
process_data: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
|
||||
error: str
|
||||
retry_index: int # retry index
|
||||
|
|
@ -480,7 +482,7 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent):
|
|||
inputs: Optional[Mapping[str, Any]] = None
|
||||
process_data: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
|
||||
error: str
|
||||
|
||||
|
|
@ -513,7 +515,7 @@ class QueueNodeInLoopFailedEvent(AppQueueEvent):
|
|||
inputs: Optional[Mapping[str, Any]] = None
|
||||
process_data: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
|
||||
error: str
|
||||
|
||||
|
|
@ -546,7 +548,7 @@ class QueueNodeExceptionEvent(AppQueueEvent):
|
|||
inputs: Optional[Mapping[str, Any]] = None
|
||||
process_data: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
|
||||
error: str
|
||||
|
||||
|
|
@ -579,7 +581,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
|
|||
inputs: Optional[Mapping[str, Any]] = None
|
||||
process_data: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
|
||||
error: str
|
||||
|
||||
|
|
|
|||
|
|
@ -2,12 +2,29 @@ from collections.abc import Mapping, Sequence
|
|||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class AnnotationReplyAccount(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
class AnnotationReply(BaseModel):
|
||||
id: str
|
||||
account: AnnotationReplyAccount
|
||||
|
||||
|
||||
class TaskStateMetadata(BaseModel):
|
||||
annotation_reply: AnnotationReply | None = None
|
||||
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(default_factory=list)
|
||||
usage: LLMUsage | None = None
|
||||
|
||||
|
||||
class TaskState(BaseModel):
|
||||
|
|
@ -15,7 +32,7 @@ class TaskState(BaseModel):
|
|||
TaskState entity
|
||||
"""
|
||||
|
||||
metadata: dict = {}
|
||||
metadata: TaskStateMetadata = Field(default_factory=TaskStateMetadata)
|
||||
|
||||
|
||||
class EasyUITaskState(TaskState):
|
||||
|
|
@ -189,7 +206,6 @@ class WorkflowStartStreamResponse(StreamResponse):
|
|||
|
||||
id: str
|
||||
workflow_id: str
|
||||
sequence_number: int
|
||||
inputs: Mapping[str, Any]
|
||||
created_at: int
|
||||
|
||||
|
|
@ -210,7 +226,6 @@ class WorkflowFinishStreamResponse(StreamResponse):
|
|||
|
||||
id: str
|
||||
workflow_id: str
|
||||
sequence_number: int
|
||||
status: str
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
error: Optional[str] = None
|
||||
|
|
@ -307,7 +322,7 @@ class NodeFinishStreamResponse(StreamResponse):
|
|||
status: str
|
||||
error: Optional[str] = None
|
||||
elapsed_time: float
|
||||
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
created_at: int
|
||||
finished_at: int
|
||||
files: Optional[Sequence[Mapping[str, Any]]] = []
|
||||
|
|
@ -376,7 +391,7 @@ class NodeRetryStreamResponse(StreamResponse):
|
|||
status: str
|
||||
error: Optional[str] = None
|
||||
elapsed_time: float
|
||||
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
created_at: int
|
||||
finished_at: int
|
||||
files: Optional[Sequence[Mapping[str, Any]]] = []
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
|
|
@ -43,7 +42,7 @@ from core.app.entities.task_entities import (
|
|||
StreamResponse,
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
|
||||
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
|
|
@ -51,7 +50,6 @@ from core.model_runtime.entities.message_entities import (
|
|||
AssistantPromptMessage,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
|
|
@ -63,7 +61,7 @@ from models.model import AppMode, Conversation, Message, MessageAgentThought
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleManage):
|
||||
class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
"""
|
||||
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
|
|
@ -104,6 +102,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||
)
|
||||
)
|
||||
|
||||
self._message_cycle_manager = MessageCycleManager(
|
||||
application_generate_entity=application_generate_entity,
|
||||
task_state=self._task_state,
|
||||
)
|
||||
|
||||
self._conversation_name_generate_thread: Optional[Thread] = None
|
||||
|
||||
def process(
|
||||
|
|
@ -115,7 +118,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||
]:
|
||||
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
|
||||
# start generate conversation name thread
|
||||
self._conversation_name_generate_thread = self._generate_conversation_name(
|
||||
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
|
||||
conversation_id=self._conversation_id, query=self._application_generate_entity.query or ""
|
||||
)
|
||||
|
||||
|
|
@ -136,9 +139,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||
if isinstance(stream_response, ErrorStreamResponse):
|
||||
raise stream_response.err
|
||||
elif isinstance(stream_response, MessageEndStreamResponse):
|
||||
extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)}
|
||||
extras = {"usage": self._task_state.llm_result.usage.model_dump()}
|
||||
if self._task_state.metadata:
|
||||
extras["metadata"] = self._task_state.metadata
|
||||
extras["metadata"] = self._task_state.metadata.model_dump()
|
||||
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
|
||||
if self._conversation_mode == AppMode.COMPLETION.value:
|
||||
response = CompletionAppBlockingResponse(
|
||||
|
|
@ -277,7 +280,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||
)
|
||||
if output_moderation_answer:
|
||||
self._task_state.llm_result.message.content = output_moderation_answer
|
||||
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
|
||||
yield self._message_cycle_manager.message_replace_to_stream_response(
|
||||
answer=output_moderation_answer
|
||||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Save message
|
||||
|
|
@ -286,9 +291,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||
message_end_resp = self._message_end_to_stream_response()
|
||||
yield message_end_resp
|
||||
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||
self._handle_retriever_resources(event)
|
||||
self._message_cycle_manager.handle_retriever_resources(event)
|
||||
elif isinstance(event, QueueAnnotationReplyEvent):
|
||||
annotation = self._handle_annotation_reply(event)
|
||||
annotation = self._message_cycle_manager.handle_annotation_reply(event)
|
||||
if annotation:
|
||||
self._task_state.llm_result.message.content = annotation.content
|
||||
elif isinstance(event, QueueAgentThoughtEvent):
|
||||
|
|
@ -296,7 +301,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||
if agent_thought_response is not None:
|
||||
yield agent_thought_response
|
||||
elif isinstance(event, QueueMessageFileEvent):
|
||||
response = self._message_file_to_stream_response(event)
|
||||
response = self._message_cycle_manager.message_file_to_stream_response(event)
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent):
|
||||
|
|
@ -318,7 +323,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||
self._task_state.llm_result.message.content = current_content
|
||||
|
||||
if isinstance(event, QueueLLMChunkEvent):
|
||||
yield self._message_to_stream_response(
|
||||
yield self._message_cycle_manager.message_to_stream_response(
|
||||
answer=cast(str, delta_text),
|
||||
message_id=self._message_id,
|
||||
)
|
||||
|
|
@ -328,7 +333,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||
message_id=self._message_id,
|
||||
)
|
||||
elif isinstance(event, QueueMessageReplaceEvent):
|
||||
yield self._message_replace_to_stream_response(answer=event.text)
|
||||
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text)
|
||||
elif isinstance(event, QueuePingEvent):
|
||||
yield self._ping_stream_response()
|
||||
else:
|
||||
|
|
@ -372,9 +377,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||
message.provider_response_latency = time.perf_counter() - self._start_at
|
||||
message.total_price = usage.total_price
|
||||
message.currency = usage.currency
|
||||
message.message_metadata = (
|
||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||
)
|
||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
|
|
@ -423,16 +426,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||
Message end to stream response.
|
||||
:return:
|
||||
"""
|
||||
self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage)
|
||||
|
||||
extras = {}
|
||||
if self._task_state.metadata:
|
||||
extras["metadata"] = self._task_state.metadata
|
||||
|
||||
self._task_state.metadata.usage = self._task_state.llm_result.usage
|
||||
metadata_dict = self._task_state.metadata.model_dump()
|
||||
return MessageEndStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=self._message_id,
|
||||
metadata=extras.get("metadata", {}),
|
||||
metadata=metadata_dict,
|
||||
)
|
||||
|
||||
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
|
||||
|
|
|
|||
|
|
@ -17,6 +17,8 @@ from core.app.entities.queue_entities import (
|
|||
QueueRetrieverResourcesEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AnnotationReply,
|
||||
AnnotationReplyAccount,
|
||||
EasyUITaskState,
|
||||
MessageFileStreamResponse,
|
||||
MessageReplaceStreamResponse,
|
||||
|
|
@ -30,7 +32,7 @@ from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
|
|||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
|
||||
class MessageCycleManage:
|
||||
class MessageCycleManager:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
|
|
@ -45,7 +47,7 @@ class MessageCycleManage:
|
|||
self._application_generate_entity = application_generate_entity
|
||||
self._task_state = task_state
|
||||
|
||||
def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
|
||||
def generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
|
||||
"""
|
||||
Generate conversation name.
|
||||
:param conversation_id: conversation id
|
||||
|
|
@ -102,7 +104,7 @@ class MessageCycleManage:
|
|||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
|
||||
def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
|
||||
"""
|
||||
Handle annotation reply.
|
||||
:param event: event
|
||||
|
|
@ -111,25 +113,28 @@ class MessageCycleManage:
|
|||
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
|
||||
if annotation:
|
||||
account = annotation.account
|
||||
self._task_state.metadata["annotation_reply"] = {
|
||||
"id": annotation.id,
|
||||
"account": {"id": annotation.account_id, "name": account.name if account else "Dify user"},
|
||||
}
|
||||
self._task_state.metadata.annotation_reply = AnnotationReply(
|
||||
id=annotation.id,
|
||||
account=AnnotationReplyAccount(
|
||||
id=annotation.account_id,
|
||||
name=account.name if account else "Dify user",
|
||||
),
|
||||
)
|
||||
|
||||
return annotation
|
||||
|
||||
return None
|
||||
|
||||
def _handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None:
|
||||
def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None:
|
||||
"""
|
||||
Handle retriever resources.
|
||||
:param event: event
|
||||
:return:
|
||||
"""
|
||||
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
|
||||
self._task_state.metadata["retriever_resources"] = event.retriever_resources
|
||||
self._task_state.metadata.retriever_resources = event.retriever_resources
|
||||
|
||||
def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
|
||||
def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
|
||||
"""
|
||||
Message file to stream response.
|
||||
:param event: event
|
||||
|
|
@ -166,7 +171,7 @@ class MessageCycleManage:
|
|||
|
||||
return None
|
||||
|
||||
def _message_to_stream_response(
|
||||
def message_to_stream_response(
|
||||
self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None
|
||||
) -> MessageStreamResponse:
|
||||
"""
|
||||
|
|
@ -182,7 +187,7 @@ class MessageCycleManage:
|
|||
from_variable_selector=from_variable_selector,
|
||||
)
|
||||
|
||||
def _message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
|
||||
def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
|
||||
"""
|
||||
Message replace to stream response.
|
||||
:param answer: answer
|
||||
|
|
@ -1,8 +1,10 @@
|
|||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -85,7 +87,8 @@ class DatasetIndexToolCallbackHandler:
|
|||
|
||||
db.session.commit()
|
||||
|
||||
def return_retriever_resource_info(self, resource: list):
|
||||
# TODO(-LAN-): Improve type check
|
||||
def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]):
|
||||
"""Handle return_retriever_resource_info."""
|
||||
self._queue_manager.publish(
|
||||
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from core.helper.code_executor.python3.python3_transformer import Python3Templat
|
|||
from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT))
|
||||
|
||||
|
||||
class CodeExecutionError(Exception):
|
||||
|
|
@ -64,7 +65,7 @@ class CodeExecutor:
|
|||
:param code: code
|
||||
:return:
|
||||
"""
|
||||
url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / "v1" / "sandbox" / "run"
|
||||
url = code_execution_endpoint_url / "v1" / "sandbox" / "run"
|
||||
|
||||
headers = {"X-Api-Key": dify_config.CODE_EXECUTION_API_KEY}
|
||||
|
||||
|
|
|
|||
|
|
@ -7,29 +7,28 @@ from configs import dify_config
|
|||
from core.helper.download import download_with_size_limit
|
||||
from core.plugin.entities.marketplace import MarketplacePluginDeclaration
|
||||
|
||||
marketplace_api_url = URL(str(dify_config.MARKETPLACE_API_URL))
|
||||
|
||||
def get_plugin_pkg_url(plugin_unique_identifier: str):
|
||||
return (URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/plugins/download").with_query(
|
||||
unique_identifier=plugin_unique_identifier
|
||||
)
|
||||
|
||||
def get_plugin_pkg_url(plugin_unique_identifier: str) -> str:
|
||||
return str((marketplace_api_url / "api/v1/plugins/download").with_query(unique_identifier=plugin_unique_identifier))
|
||||
|
||||
|
||||
def download_plugin_pkg(plugin_unique_identifier: str):
|
||||
url = str(get_plugin_pkg_url(plugin_unique_identifier))
|
||||
return download_with_size_limit(url, dify_config.PLUGIN_MAX_PACKAGE_SIZE)
|
||||
return download_with_size_limit(get_plugin_pkg_url(plugin_unique_identifier), dify_config.PLUGIN_MAX_PACKAGE_SIZE)
|
||||
|
||||
|
||||
def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplacePluginDeclaration]:
|
||||
if len(plugin_ids) == 0:
|
||||
return []
|
||||
|
||||
url = str(URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/plugins/batch")
|
||||
url = str(marketplace_api_url / "api/v1/plugins/batch")
|
||||
response = requests.post(url, json={"plugin_ids": plugin_ids})
|
||||
response.raise_for_status()
|
||||
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
|
||||
|
||||
|
||||
def record_install_plugin_event(plugin_unique_identifier: str):
|
||||
url = str(URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/stats/plugins/install_count")
|
||||
url = str(marketplace_api_url / "api/v1/stats/plugins/install_count")
|
||||
response = requests.post(url, json={"unique_identifier": plugin_unique_identifier})
|
||||
response.raise_for_status()
|
||||
|
|
|
|||
|
|
@ -1,61 +1,20 @@
|
|||
# Written by YORKI MINAKO🤡, Edited by Xiaoyi
|
||||
CONVERSATION_TITLE_PROMPT = """You need to decompose the user's input into "subject" and "intention" in order to accurately figure out what the user's input language actually is.
|
||||
Notice: the language type user uses could be diverse, which can be English, Chinese, Italian, Español, Arabic, Japanese, French, and etc.
|
||||
ENSURE your output is in the SAME language as the user's input!
|
||||
Your output is restricted only to: (Input language) Intention + Subject(short as possible)
|
||||
Your output MUST be a valid JSON.
|
||||
# Written by YORKI MINAKO🤡, Edited by Xiaoyi, Edited by yasu-oh
|
||||
CONVERSATION_TITLE_PROMPT = """You are asked to generate a concise chat title by decomposing the user’s input into two parts: “Intention” and “Subject”.
|
||||
|
||||
Tip: When the user's question is directed at you (the language model), you can add an emoji to make it more fun.
|
||||
1. Detect Input Language
|
||||
Automatically identify the language of the user’s input (e.g. English, Chinese, Italian, Español, Arabic, Japanese, French, and etc.).
|
||||
|
||||
2. Generate Title
|
||||
- Combine Intention + Subject into a single, as-short-as-possible phrase.
|
||||
- The title must be natural, friendly, and in the same language as the input.
|
||||
- If the input is a direct question to the model, you may add an emoji at the end.
|
||||
|
||||
example 1:
|
||||
User Input: hi, yesterday i had some burgers.
|
||||
3. Output Format
|
||||
Return **only** a valid JSON object with these exact keys and no additional text:
|
||||
{
|
||||
"Language Type": "The user's input is pure English",
|
||||
"Your Reasoning": "The language of my output must be pure English.",
|
||||
"Your Output": "sharing yesterday's food"
|
||||
}
|
||||
|
||||
example 2:
|
||||
User Input: hello
|
||||
{
|
||||
"Language Type": "The user's input is pure English",
|
||||
"Your Reasoning": "The language of my output must be pure English.",
|
||||
"Your Output": "Greeting myself☺️"
|
||||
}
|
||||
|
||||
|
||||
example 3:
|
||||
User Input: why mmap file: oom
|
||||
{
|
||||
"Language Type": "The user's input is written in pure English",
|
||||
"Your Reasoning": "The language of my output must be pure English.",
|
||||
"Your Output": "Asking about the reason for mmap file: oom"
|
||||
}
|
||||
|
||||
|
||||
example 4:
|
||||
User Input: www.convinceme.yesterday-you-ate-seafood.tv讲了什么?
|
||||
{
|
||||
"Language Type": "The user's input English-Chinese mixed",
|
||||
"Your Reasoning": "The English-part is an URL, the main intention is still written in Chinese, so the language of my output must be using Chinese.",
|
||||
"Your Output": "询问网站www.convinceme.yesterday-you-ate-seafood.tv"
|
||||
}
|
||||
|
||||
example 5:
|
||||
User Input: why小红的年龄is老than小明?
|
||||
{
|
||||
"Language Type": "The user's input is English-Chinese mixed",
|
||||
"Your Reasoning": "The English parts are filler words, the main intention is written in Chinese, besides, Chinese occupies a greater \"actual meaning\" than English, so the language of my output must be using Chinese.",
|
||||
"Your Output": "询问小红和小明的年龄"
|
||||
}
|
||||
|
||||
example 6:
|
||||
User Input: yo, 你今天咋样?
|
||||
{
|
||||
"Language Type": "The user's input is English-Chinese mixed",
|
||||
"Your Reasoning": "The English-part is a subjective particle, the main intention is written in Chinese, so the language of my output must be using Chinese.",
|
||||
"Your Output": "查询今日我的状态☺️"
|
||||
"Language Type": "<Detected language>",
|
||||
"Your Reasoning": "<Brief explanation in that language>",
|
||||
"Your Output": "<Intention + Subject>"
|
||||
}
|
||||
|
||||
User Input:
|
||||
|
|
|
|||
|
|
@ -17,19 +17,6 @@ class LLMMode(StrEnum):
|
|||
COMPLETION = "completion"
|
||||
CHAT = "chat"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "LLMMode":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
|
||||
class LLMUsage(ModelUsage):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -129,17 +129,18 @@ def jsonable_encoder(
|
|||
sqlalchemy_safe=sqlalchemy_safe,
|
||||
)
|
||||
if dataclasses.is_dataclass(obj):
|
||||
# FIXME: mypy error, try to fix it instead of using type: ignore
|
||||
obj_dict = dataclasses.asdict(obj) # type: ignore
|
||||
return jsonable_encoder(
|
||||
obj_dict,
|
||||
by_alias=by_alias,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
exclude_none=exclude_none,
|
||||
custom_encoder=custom_encoder,
|
||||
sqlalchemy_safe=sqlalchemy_safe,
|
||||
)
|
||||
# Ensure obj is a dataclass instance, not a dataclass type
|
||||
if not isinstance(obj, type):
|
||||
obj_dict = dataclasses.asdict(obj)
|
||||
return jsonable_encoder(
|
||||
obj_dict,
|
||||
by_alias=by_alias,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
exclude_none=exclude_none,
|
||||
custom_encoder=custom_encoder,
|
||||
sqlalchemy_safe=sqlalchemy_safe,
|
||||
)
|
||||
if isinstance(obj, Enum):
|
||||
return obj.value
|
||||
if isinstance(obj, PurePath):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,11 @@
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.ops.entities.config_entity import BaseTracingConfig
|
||||
from core.ops.entities.trace_entity import BaseTraceInfo
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App, TenantAccountJoin
|
||||
|
||||
|
||||
class BaseTraceInstance(ABC):
|
||||
|
|
@ -24,3 +28,38 @@ class BaseTraceInstance(ABC):
|
|||
Subclasses must implement specific tracing logic for activities.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_service_account_with_tenant(self, app_id: str) -> Account:
|
||||
"""
|
||||
Get service account for an app and set up its tenant.
|
||||
|
||||
Args:
|
||||
app_id: The ID of the app
|
||||
|
||||
Returns:
|
||||
Account: The service account with tenant set up
|
||||
|
||||
Raises:
|
||||
ValueError: If app, creator account or tenant cannot be found
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get the app to find its creator
|
||||
app = session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||
|
||||
service_account = session.query(Account).filter(Account.id == app.created_by).first()
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||
|
||||
current_tenant = (
|
||||
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
|
||||
)
|
||||
if not current_tenant:
|
||||
raise ValueError(f"Current tenant not found for account {service_account.id}")
|
||||
service_account.set_tenant_id(current_tenant.tenant_id)
|
||||
|
||||
return service_account
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from datetime import datetime
|
|||
from enum import StrEnum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
|
||||
|
||||
|
||||
class BaseTraceInfo(BaseModel):
|
||||
|
|
@ -24,10 +24,13 @@ class BaseTraceInfo(BaseModel):
|
|||
return v
|
||||
return ""
|
||||
|
||||
class Config:
|
||||
json_encoders = {
|
||||
datetime: lambda v: v.isoformat(),
|
||||
}
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@field_serializer("start_time", "end_time")
|
||||
def serialize_datetime(self, dt: datetime | None) -> str | None:
|
||||
if dt is None:
|
||||
return None
|
||||
return dt.isoformat()
|
||||
|
||||
|
||||
class WorkflowTraceInfo(BaseTraceInfo):
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from datetime import datetime, timedelta
|
|||
from typing import Optional
|
||||
|
||||
from langfuse import Langfuse # type: ignore
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import LangfuseConfig
|
||||
|
|
@ -31,7 +31,7 @@ from core.ops.utils import filter_none_values
|
|||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App, EndUser, WorkflowNodeExecutionTriggeredFrom
|
||||
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -114,22 +114,11 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
# through workflow_run_id get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
# Find the app's creator account
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get the app to find its creator
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
app = session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||
|
||||
service_account = session.query(Account).filter(Account.id == app.created_by).first()
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from typing import Optional, cast
|
|||
|
||||
from langsmith import Client
|
||||
from langsmith.schemas import RunBase
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import LangSmithConfig
|
||||
|
|
@ -28,10 +28,10 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
|
|||
)
|
||||
from core.ops.utils import filter_none_values, generate_dotted_order
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -139,22 +139,11 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||
# through workflow_run_id get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
# Find the app's creator account
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get the app to find its creator
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
app = session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||
|
||||
service_account = session.query(Account).filter(Account.id == app.created_by).first()
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
|
|
@ -185,7 +174,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
execution_metadata = node_execution.metadata if node_execution.metadata else {}
|
||||
node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
|
||||
node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
|
||||
metadata = {str(key): value for key, value in execution_metadata.items()}
|
||||
metadata.update(
|
||||
{
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from typing import Optional, cast
|
|||
|
||||
from opik import Opik, Trace
|
||||
from opik.id_helpers import uuid4_to_uuid7
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import OpikConfig
|
||||
|
|
@ -22,10 +22,10 @@ from core.ops.entities.trace_entity import (
|
|||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -154,22 +154,11 @@ class OpikDataTrace(BaseTraceInstance):
|
|||
# through workflow_run_id get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
# Find the app's creator account
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get the app to find its creator
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
app = session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||
|
||||
service_account = session.query(Account).filter(Account.id == app.created_by).first()
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
|
|
@ -246,7 +235,7 @@ class OpikDataTrace(BaseTraceInstance):
|
|||
parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id
|
||||
|
||||
if not total_tokens:
|
||||
total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
|
||||
total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
|
||||
|
||||
span_data = {
|
||||
"trace_id": opik_trace_id,
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ from core.ops.entities.trace_entity import (
|
|||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.utils import get_message_data
|
||||
from core.workflow.entities.workflow_execution_entities import WorkflowExecution
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
|
||||
|
|
@ -386,7 +386,7 @@ class TraceTask:
|
|||
):
|
||||
self.trace_type = trace_type
|
||||
self.message_id = message_id
|
||||
self.workflow_run_id = workflow_execution.id if workflow_execution else None
|
||||
self.workflow_run_id = workflow_execution.id_ if workflow_execution else None
|
||||
self.conversation_id = conversation_id
|
||||
self.user_id = user_id
|
||||
self.timer = timer
|
||||
|
|
@ -487,6 +487,7 @@ class TraceTask:
|
|||
"file_list": file_list,
|
||||
"triggered_from": workflow_run.triggered_from,
|
||||
"user_id": user_id,
|
||||
"app_id": workflow_run.app_id,
|
||||
}
|
||||
|
||||
workflow_trace_info = WorkflowTraceInfo(
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from typing import Any, Optional, cast
|
|||
|
||||
import wandb
|
||||
import weave
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import WeaveConfig
|
||||
|
|
@ -23,10 +23,10 @@ from core.ops.entities.trace_entity import (
|
|||
)
|
||||
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -133,22 +133,11 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
# through workflow_run_id get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
# Find the app's creator account
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get the app to find its creator
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
app = session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||
|
||||
service_account = session.query(Account).filter(Account.id == app.created_by).first()
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
|
|
@ -179,7 +168,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
execution_metadata = node_execution.metadata if node_execution.metadata else {}
|
||||
node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
|
||||
node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
|
||||
attributes = {str(k): v for k, v in execution_metadata.items()}
|
||||
attributes.update(
|
||||
{
|
||||
|
|
|
|||
|
|
@ -31,8 +31,7 @@ from core.plugin.impl.exc import (
|
|||
PluginUniqueIdentifierError,
|
||||
)
|
||||
|
||||
plugin_daemon_inner_api_baseurl = dify_config.PLUGIN_DAEMON_URL
|
||||
plugin_daemon_inner_api_key = dify_config.PLUGIN_DAEMON_KEY
|
||||
plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL))
|
||||
|
||||
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
|
||||
|
||||
|
|
@ -53,9 +52,9 @@ class BasePluginClient:
|
|||
"""
|
||||
Make a request to the plugin daemon inner API.
|
||||
"""
|
||||
url = URL(str(plugin_daemon_inner_api_baseurl)) / path
|
||||
url = plugin_daemon_inner_api_baseurl / path
|
||||
headers = headers or {}
|
||||
headers["X-Api-Key"] = plugin_daemon_inner_api_key
|
||||
headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
|
||||
headers["Accept-Encoding"] = "gzip, deflate, br"
|
||||
|
||||
if headers.get("Content-Type") == "application/json" and isinstance(data, dict):
|
||||
|
|
|
|||
|
|
@ -85,7 +85,6 @@ class BaiduVector(BaseVector):
|
|||
end = min(start + batch_size, total_count)
|
||||
rows = []
|
||||
assert len(metadatas) == total_count, "metadatas length should be equal to total_count"
|
||||
# FIXME do you need this assert?
|
||||
for i in range(start, end, 1):
|
||||
row = Row(
|
||||
id=metadatas[i].get("doc_id", str(uuid.uuid4())),
|
||||
|
|
|
|||
|
|
@ -142,7 +142,7 @@ class ElasticSearchVector(BaseVector):
|
|||
if score > score_threshold:
|
||||
if doc.metadata is not None:
|
||||
doc.metadata["score"] = score
|
||||
docs.append(doc)
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
||||
|
|
|
|||
|
|
@ -97,6 +97,10 @@ class MilvusVector(BaseVector):
|
|||
|
||||
try:
|
||||
milvus_version = self._client.get_server_version()
|
||||
# Check if it's Zilliz Cloud - it supports full-text search with Milvus 2.5 compatibility
|
||||
if "Zilliz Cloud" in milvus_version:
|
||||
return True
|
||||
# For standard Milvus installations, check version number
|
||||
return version.parse(milvus_version).base_version >= version.parse("2.5.0").base_version
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to check Milvus version: {str(e)}. Disabling hybrid search.")
|
||||
|
|
|
|||
|
|
@ -245,4 +245,4 @@ class TidbService:
|
|||
return cluster_infos
|
||||
else:
|
||||
response.raise_for_status()
|
||||
return [] # FIXME for mypy, This line will not be reached as raise_for_status() will raise an exception
|
||||
return []
|
||||
|
|
|
|||
|
|
@ -0,0 +1,23 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RetrievalSourceMetadata(BaseModel):
|
||||
position: Optional[int] = None
|
||||
dataset_id: Optional[str] = None
|
||||
dataset_name: Optional[str] = None
|
||||
document_id: Optional[str] = None
|
||||
document_name: Optional[str] = None
|
||||
data_source_type: Optional[str] = None
|
||||
segment_id: Optional[str] = None
|
||||
retriever_from: Optional[str] = None
|
||||
score: Optional[float] = None
|
||||
hit_count: Optional[int] = None
|
||||
word_count: Optional[int] = None
|
||||
segment_position: Optional[int] = None
|
||||
index_node_hash: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
page: Optional[int] = None
|
||||
doc_metadata: Optional[dict[str, Any]] = None
|
||||
title: Optional[str] = None
|
||||
|
|
@ -27,6 +27,8 @@ class WebsiteInfo(BaseModel):
|
|||
website import info.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
provider: str
|
||||
job_id: str
|
||||
url: str
|
||||
|
|
@ -34,12 +36,6 @@ class WebsiteInfo(BaseModel):
|
|||
tenant_id: str
|
||||
only_main_content: bool = False
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
|
||||
|
||||
class ExtractSetting(BaseModel):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -70,13 +70,12 @@ class BaseDocumentTransformer(ABC):
|
|||
.. code-block:: python
|
||||
|
||||
class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
embeddings: Embeddings
|
||||
similarity_fn: Callable = cosine_similarity
|
||||
similarity_threshold: float = 0.95
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def transform_documents(
|
||||
self, documents: Sequence[Document], **kwargs: Any
|
||||
) -> Sequence[Document]:
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ from core.prompt.simple_prompt_transform import ModelMode
|
|||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.entities.metadata_entities import Condition, MetadataCondition
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
|
|
@ -198,21 +199,21 @@ class DatasetRetrieval:
|
|||
|
||||
dify_documents = [item for item in all_documents if item.provider == "dify"]
|
||||
external_documents = [item for item in all_documents if item.provider == "external"]
|
||||
document_context_list = []
|
||||
retrieval_resource_list = []
|
||||
document_context_list: list[DocumentContext] = []
|
||||
retrieval_resource_list: list[RetrievalSourceMetadata] = []
|
||||
# deal with external documents
|
||||
for item in external_documents:
|
||||
document_context_list.append(DocumentContext(content=item.page_content, score=item.metadata.get("score")))
|
||||
source = {
|
||||
"dataset_id": item.metadata.get("dataset_id"),
|
||||
"dataset_name": item.metadata.get("dataset_name"),
|
||||
"document_id": item.metadata.get("document_id") or item.metadata.get("title"),
|
||||
"document_name": item.metadata.get("title"),
|
||||
"data_source_type": "external",
|
||||
"retriever_from": invoke_from.to_source(),
|
||||
"score": item.metadata.get("score"),
|
||||
"content": item.page_content,
|
||||
}
|
||||
source = RetrievalSourceMetadata(
|
||||
dataset_id=item.metadata.get("dataset_id"),
|
||||
dataset_name=item.metadata.get("dataset_name"),
|
||||
document_id=item.metadata.get("document_id") or item.metadata.get("title"),
|
||||
document_name=item.metadata.get("title"),
|
||||
data_source_type="external",
|
||||
retriever_from=invoke_from.to_source(),
|
||||
score=item.metadata.get("score"),
|
||||
content=item.page_content,
|
||||
)
|
||||
retrieval_resource_list.append(source)
|
||||
# deal with dify documents
|
||||
if dify_documents:
|
||||
|
|
@ -248,32 +249,32 @@ class DatasetRetrieval:
|
|||
.first()
|
||||
)
|
||||
if dataset and document:
|
||||
source = {
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"document_id": document.id,
|
||||
"document_name": document.name,
|
||||
"data_source_type": document.data_source_type,
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": invoke_from.to_source(),
|
||||
"score": record.score or 0.0,
|
||||
"doc_metadata": document.doc_metadata,
|
||||
}
|
||||
source = RetrievalSourceMetadata(
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
document_id=document.id,
|
||||
document_name=document.name,
|
||||
data_source_type=document.data_source_type,
|
||||
segment_id=segment.id,
|
||||
retriever_from=invoke_from.to_source(),
|
||||
score=record.score or 0.0,
|
||||
doc_metadata=document.doc_metadata,
|
||||
)
|
||||
|
||||
if invoke_from.to_source() == "dev":
|
||||
source["hit_count"] = segment.hit_count
|
||||
source["word_count"] = segment.word_count
|
||||
source["segment_position"] = segment.position
|
||||
source["index_node_hash"] = segment.index_node_hash
|
||||
source.hit_count = segment.hit_count
|
||||
source.word_count = segment.word_count
|
||||
source.segment_position = segment.position
|
||||
source.index_node_hash = segment.index_node_hash
|
||||
if segment.answer:
|
||||
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
else:
|
||||
source["content"] = segment.content
|
||||
source.content = segment.content
|
||||
retrieval_resource_list.append(source)
|
||||
if hit_callback and retrieval_resource_list:
|
||||
retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.get("score") or 0.0, reverse=True)
|
||||
retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.score or 0.0, reverse=True)
|
||||
for position, item in enumerate(retrieval_resource_list, start=1):
|
||||
item["position"] = position
|
||||
item.position = position
|
||||
hit_callback.return_retriever_resource_info(retrieval_resource_list)
|
||||
if document_context_list:
|
||||
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
|
||||
|
|
@ -936,6 +937,9 @@ class DatasetRetrieval:
|
|||
return metadata_filter_document_ids, metadata_condition
|
||||
|
||||
def _replace_metadata_filter_value(self, text: str, inputs: dict) -> str:
|
||||
if not inputs:
|
||||
return text
|
||||
|
||||
def replacer(match):
|
||||
key = match.group(1)
|
||||
return str(inputs.get(key, f"{{{{{key}}}}}"))
|
||||
|
|
|
|||
|
|
@ -10,12 +10,12 @@ from sqlalchemy import select
|
|||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.workflow.entities.workflow_execution_entities import (
|
||||
from core.workflow.entities.workflow_execution import (
|
||||
WorkflowExecution,
|
||||
WorkflowExecutionStatus,
|
||||
WorkflowType,
|
||||
)
|
||||
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from models import (
|
||||
Account,
|
||||
CreatorUserRole,
|
||||
|
|
@ -104,10 +104,9 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
|
|||
status = WorkflowExecutionStatus(db_model.status)
|
||||
|
||||
return WorkflowExecution(
|
||||
id=db_model.id,
|
||||
id_=db_model.id,
|
||||
workflow_id=db_model.workflow_id,
|
||||
sequence_number=db_model.sequence_number,
|
||||
type=WorkflowType(db_model.type),
|
||||
workflow_type=WorkflowType(db_model.type),
|
||||
workflow_version=db_model.version,
|
||||
graph=graph,
|
||||
inputs=inputs,
|
||||
|
|
@ -140,14 +139,29 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
|
|||
raise ValueError("created_by_role is required in repository constructor")
|
||||
|
||||
db_model = WorkflowRun()
|
||||
db_model.id = domain_model.id
|
||||
db_model.id = domain_model.id_
|
||||
db_model.tenant_id = self._tenant_id
|
||||
if self._app_id is not None:
|
||||
db_model.app_id = self._app_id
|
||||
db_model.workflow_id = domain_model.workflow_id
|
||||
db_model.triggered_from = self._triggered_from
|
||||
db_model.sequence_number = domain_model.sequence_number
|
||||
db_model.type = domain_model.type
|
||||
|
||||
# Check if this is a new record
|
||||
with self._session_factory() as session:
|
||||
existing = session.scalar(select(WorkflowRun).where(WorkflowRun.id == domain_model.id_))
|
||||
if not existing:
|
||||
# For new records, get the next sequence number
|
||||
stmt = select(WorkflowRun.sequence_number).where(
|
||||
WorkflowRun.app_id == self._app_id,
|
||||
WorkflowRun.tenant_id == self._tenant_id,
|
||||
)
|
||||
max_sequence = session.scalar(stmt.order_by(WorkflowRun.sequence_number.desc()))
|
||||
db_model.sequence_number = (max_sequence or 0) + 1
|
||||
else:
|
||||
# For updates, keep the existing sequence number
|
||||
db_model.sequence_number = existing.sequence_number
|
||||
|
||||
db_model.type = domain_model.workflow_type
|
||||
db_model.version = domain_model.workflow_version
|
||||
db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None
|
||||
db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None
|
||||
|
|
|
|||
|
|
@ -12,19 +12,18 @@ from sqlalchemy.engine import Engine
|
|||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.entities.node_execution_entities import (
|
||||
NodeExecution,
|
||||
NodeExecutionStatus,
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
|
||||
from models import (
|
||||
Account,
|
||||
CreatorUserRole,
|
||||
EndUser,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowNodeExecutionModel,
|
||||
WorkflowNodeExecutionTriggeredFrom,
|
||||
)
|
||||
|
||||
|
|
@ -87,9 +86,9 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
|
||||
# Initialize in-memory cache for node executions
|
||||
# Key: node_execution_id, Value: WorkflowNodeExecution (DB model)
|
||||
self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}
|
||||
self._node_execution_cache: dict[str, WorkflowNodeExecutionModel] = {}
|
||||
|
||||
def _to_domain_model(self, db_model: WorkflowNodeExecution) -> NodeExecution:
|
||||
def _to_domain_model(self, db_model: WorkflowNodeExecutionModel) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Convert a database model to a domain model.
|
||||
|
||||
|
|
@ -103,16 +102,16 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
inputs = db_model.inputs_dict
|
||||
process_data = db_model.process_data_dict
|
||||
outputs = db_model.outputs_dict
|
||||
metadata = {NodeRunMetadataKey(k): v for k, v in db_model.execution_metadata_dict.items()}
|
||||
metadata = {WorkflowNodeExecutionMetadataKey(k): v for k, v in db_model.execution_metadata_dict.items()}
|
||||
|
||||
# Convert status to domain enum
|
||||
status = NodeExecutionStatus(db_model.status)
|
||||
status = WorkflowNodeExecutionStatus(db_model.status)
|
||||
|
||||
return NodeExecution(
|
||||
return WorkflowNodeExecution(
|
||||
id=db_model.id,
|
||||
node_execution_id=db_model.node_execution_id,
|
||||
workflow_id=db_model.workflow_id,
|
||||
workflow_run_id=db_model.workflow_run_id,
|
||||
workflow_execution_id=db_model.workflow_run_id,
|
||||
index=db_model.index,
|
||||
predecessor_node_id=db_model.predecessor_node_id,
|
||||
node_id=db_model.node_id,
|
||||
|
|
@ -129,7 +128,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
finished_at=db_model.finished_at,
|
||||
)
|
||||
|
||||
def to_db_model(self, domain_model: NodeExecution) -> WorkflowNodeExecution:
|
||||
def to_db_model(self, domain_model: WorkflowNodeExecution) -> WorkflowNodeExecutionModel:
|
||||
"""
|
||||
Convert a domain model to a database model.
|
||||
|
||||
|
|
@ -147,14 +146,14 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
if not self._creator_user_role:
|
||||
raise ValueError("created_by_role is required in repository constructor")
|
||||
|
||||
db_model = WorkflowNodeExecution()
|
||||
db_model = WorkflowNodeExecutionModel()
|
||||
db_model.id = domain_model.id
|
||||
db_model.tenant_id = self._tenant_id
|
||||
if self._app_id is not None:
|
||||
db_model.app_id = self._app_id
|
||||
db_model.workflow_id = domain_model.workflow_id
|
||||
db_model.triggered_from = self._triggered_from
|
||||
db_model.workflow_run_id = domain_model.workflow_run_id
|
||||
db_model.workflow_run_id = domain_model.workflow_execution_id
|
||||
db_model.index = domain_model.index
|
||||
db_model.predecessor_node_id = domain_model.predecessor_node_id
|
||||
db_model.node_execution_id = domain_model.node_execution_id
|
||||
|
|
@ -176,7 +175,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
db_model.finished_at = domain_model.finished_at
|
||||
return db_model
|
||||
|
||||
def save(self, execution: NodeExecution) -> None:
|
||||
def save(self, execution: WorkflowNodeExecution) -> None:
|
||||
"""
|
||||
Save or update a NodeExecution domain entity to the database.
|
||||
|
||||
|
|
@ -208,7 +207,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
logger.debug(f"Updating cache for node_execution_id: {db_model.node_execution_id}")
|
||||
self._node_execution_cache[db_model.node_execution_id] = db_model
|
||||
|
||||
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]:
|
||||
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve a NodeExecution by its node_execution_id.
|
||||
|
||||
|
|
@ -231,13 +230,13 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
# If not in cache, query the database
|
||||
logger.debug(f"Cache miss for node_execution_id: {node_execution_id}, querying database")
|
||||
with self._session_factory() as session:
|
||||
stmt = select(WorkflowNodeExecution).where(
|
||||
WorkflowNodeExecution.node_execution_id == node_execution_id,
|
||||
WorkflowNodeExecution.tenant_id == self._tenant_id,
|
||||
stmt = select(WorkflowNodeExecutionModel).where(
|
||||
WorkflowNodeExecutionModel.node_execution_id == node_execution_id,
|
||||
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
|
||||
)
|
||||
|
||||
if self._app_id:
|
||||
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
|
||||
|
||||
db_model = session.scalar(stmt)
|
||||
if db_model:
|
||||
|
|
@ -254,7 +253,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
workflow_run_id: str,
|
||||
order_config: Optional[OrderConfig] = None,
|
||||
triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
) -> Sequence[WorkflowNodeExecution]:
|
||||
) -> Sequence[WorkflowNodeExecutionModel]:
|
||||
"""
|
||||
Retrieve all WorkflowNodeExecution database models for a specific workflow run.
|
||||
|
||||
|
|
@ -272,20 +271,20 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
A list of WorkflowNodeExecution database models
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
stmt = select(WorkflowNodeExecution).where(
|
||||
WorkflowNodeExecution.workflow_run_id == workflow_run_id,
|
||||
WorkflowNodeExecution.tenant_id == self._tenant_id,
|
||||
WorkflowNodeExecution.triggered_from == triggered_from,
|
||||
stmt = select(WorkflowNodeExecutionModel).where(
|
||||
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
|
||||
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
|
||||
WorkflowNodeExecutionModel.triggered_from == triggered_from,
|
||||
)
|
||||
|
||||
if self._app_id:
|
||||
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
|
||||
|
||||
# Apply ordering if provided
|
||||
if order_config and order_config.order_by:
|
||||
order_columns: list[UnaryExpression] = []
|
||||
for field in order_config.order_by:
|
||||
column = getattr(WorkflowNodeExecution, field, None)
|
||||
column = getattr(WorkflowNodeExecutionModel, field, None)
|
||||
if not column:
|
||||
continue
|
||||
if order_config.order_direction == "desc":
|
||||
|
|
@ -310,7 +309,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
workflow_run_id: str,
|
||||
order_config: Optional[OrderConfig] = None,
|
||||
triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
) -> Sequence[NodeExecution]:
|
||||
) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all NodeExecution instances for a specific workflow run.
|
||||
|
||||
|
|
@ -337,7 +336,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
|
||||
return domain_models
|
||||
|
||||
def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]:
|
||||
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all running NodeExecution instances for a specific workflow run.
|
||||
|
||||
|
|
@ -351,15 +350,15 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
A list of running NodeExecution instances
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
stmt = select(WorkflowNodeExecution).where(
|
||||
WorkflowNodeExecution.workflow_run_id == workflow_run_id,
|
||||
WorkflowNodeExecution.tenant_id == self._tenant_id,
|
||||
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING,
|
||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
stmt = select(WorkflowNodeExecutionModel).where(
|
||||
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
|
||||
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
|
||||
WorkflowNodeExecutionModel.status == WorkflowNodeExecutionStatus.RUNNING,
|
||||
WorkflowNodeExecutionModel.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
if self._app_id:
|
||||
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
|
||||
|
||||
db_models = session.scalars(stmt).all()
|
||||
domain_models = []
|
||||
|
|
@ -384,10 +383,10 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
It also clears the in-memory cache.
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id)
|
||||
stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.tenant_id == self._tenant_id)
|
||||
|
||||
if self._app_id:
|
||||
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
|
||||
|
||||
result = session.execute(stmt)
|
||||
session.commit()
|
||||
|
|
|
|||
|
|
@ -168,7 +168,7 @@ class ApiTool(Tool):
|
|||
cookies[parameter["name"]] = value
|
||||
|
||||
elif parameter["in"] == "header":
|
||||
headers[parameter["name"]] = value
|
||||
headers[parameter["name"]] = str(value)
|
||||
|
||||
# check if there is a request body and handle it
|
||||
if "requestBody" in self.api_bundle.openapi and self.api_bundle.openapi["requestBody"] is not None:
|
||||
|
|
|
|||
|
|
@ -279,7 +279,6 @@ class ToolParameter(PluginParameter):
|
|||
:param options: the options of the parameter
|
||||
"""
|
||||
# convert options to ToolParameterOption
|
||||
# FIXME fix the type error
|
||||
if options:
|
||||
option_objs = [
|
||||
PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
|
|||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.models.document import Document as RagDocument
|
||||
from core.rag.rerank.rerank_model import RerankModelRunner
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
|
|
@ -107,7 +108,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
|||
else:
|
||||
document_context_list.append(segment.get_sign_content())
|
||||
if self.return_resource:
|
||||
context_list = []
|
||||
context_list: list[RetrievalSourceMetadata] = []
|
||||
resource_number = 1
|
||||
for segment in sorted_segments:
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
||||
|
|
@ -121,28 +122,28 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
|||
.first()
|
||||
)
|
||||
if dataset and document:
|
||||
source = {
|
||||
"position": resource_number,
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"document_id": document.id,
|
||||
"document_name": document.name,
|
||||
"data_source_type": document.data_source_type,
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": self.retriever_from,
|
||||
"score": document_score_list.get(segment.index_node_id, None),
|
||||
"doc_metadata": document.doc_metadata,
|
||||
}
|
||||
source = RetrievalSourceMetadata(
|
||||
position=resource_number,
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
document_id=document.id,
|
||||
document_name=document.name,
|
||||
data_source_type=document.data_source_type,
|
||||
segment_id=segment.id,
|
||||
retriever_from=self.retriever_from,
|
||||
score=document_score_list.get(segment.index_node_id, None),
|
||||
doc_metadata=document.doc_metadata,
|
||||
)
|
||||
|
||||
if self.retriever_from == "dev":
|
||||
source["hit_count"] = segment.hit_count
|
||||
source["word_count"] = segment.word_count
|
||||
source["segment_position"] = segment.position
|
||||
source["index_node_hash"] = segment.index_node_hash
|
||||
source.hit_count = segment.hit_count
|
||||
source.word_count = segment.word_count
|
||||
source.segment_position = segment.position
|
||||
source.index_node_hash = segment.index_node_hash
|
||||
if segment.answer:
|
||||
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
else:
|
||||
source["content"] = segment.content
|
||||
source.content = segment.content
|
||||
context_list.append(source)
|
||||
resource_number += 1
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from pydantic import BaseModel, Field
|
|||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.models.document import Document as RetrievalDocument
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
|
|
@ -14,7 +15,7 @@ from models.dataset import Dataset
|
|||
from models.dataset import Document as DatasetDocument
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
default_retrieval_model = {
|
||||
default_retrieval_model: dict[str, Any] = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
|
|
@ -79,7 +80,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||
else:
|
||||
document_ids_filter = None
|
||||
if dataset.provider == "external":
|
||||
results = []
|
||||
results: list[RetrievalDocument] = []
|
||||
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
|
|
@ -100,21 +101,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||
document.metadata["dataset_name"] = dataset.name
|
||||
results.append(document)
|
||||
# deal with external documents
|
||||
context_list = []
|
||||
context_list: list[RetrievalSourceMetadata] = []
|
||||
for position, item in enumerate(results, start=1):
|
||||
if item.metadata is not None:
|
||||
source = {
|
||||
"position": position,
|
||||
"dataset_id": item.metadata.get("dataset_id"),
|
||||
"dataset_name": item.metadata.get("dataset_name"),
|
||||
"document_id": item.metadata.get("document_id") or item.metadata.get("title"),
|
||||
"document_name": item.metadata.get("title"),
|
||||
"data_source_type": "external",
|
||||
"retriever_from": self.retriever_from,
|
||||
"score": item.metadata.get("score"),
|
||||
"title": item.metadata.get("title"),
|
||||
"content": item.page_content,
|
||||
}
|
||||
source = RetrievalSourceMetadata(
|
||||
position=position,
|
||||
dataset_id=item.metadata.get("dataset_id"),
|
||||
dataset_name=item.metadata.get("dataset_name"),
|
||||
document_id=item.metadata.get("document_id") or item.metadata.get("title"),
|
||||
document_name=item.metadata.get("title"),
|
||||
data_source_type="external",
|
||||
retriever_from=self.retriever_from,
|
||||
score=item.metadata.get("score"),
|
||||
title=item.metadata.get("title"),
|
||||
content=item.page_content,
|
||||
)
|
||||
context_list.append(source)
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.return_retriever_resource_info(context_list)
|
||||
|
|
@ -125,7 +126,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||
return ""
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
|
||||
retrieval_resource_list = []
|
||||
retrieval_resource_list: list[RetrievalSourceMetadata] = []
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
|
|
@ -163,7 +164,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||
for item in documents:
|
||||
if item.metadata is not None and item.metadata.get("score"):
|
||||
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
||||
document_context_list = []
|
||||
document_context_list: list[DocumentContext] = []
|
||||
records = RetrievalService.format_retrieval_documents(documents)
|
||||
if records:
|
||||
for record in records:
|
||||
|
|
@ -197,37 +198,37 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||
.first()
|
||||
)
|
||||
if dataset and document:
|
||||
source = {
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"document_id": document.id, # type: ignore
|
||||
"document_name": document.name, # type: ignore
|
||||
"data_source_type": document.data_source_type, # type: ignore
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": self.retriever_from,
|
||||
"score": record.score or 0.0,
|
||||
"doc_metadata": document.doc_metadata, # type: ignore
|
||||
}
|
||||
source = RetrievalSourceMetadata(
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
document_id=document.id, # type: ignore
|
||||
document_name=document.name, # type: ignore
|
||||
data_source_type=document.data_source_type, # type: ignore
|
||||
segment_id=segment.id,
|
||||
retriever_from=self.retriever_from,
|
||||
score=record.score or 0.0,
|
||||
doc_metadata=document.doc_metadata, # type: ignore
|
||||
)
|
||||
|
||||
if self.retriever_from == "dev":
|
||||
source["hit_count"] = segment.hit_count
|
||||
source["word_count"] = segment.word_count
|
||||
source["segment_position"] = segment.position
|
||||
source["index_node_hash"] = segment.index_node_hash
|
||||
source.hit_count = segment.hit_count
|
||||
source.word_count = segment.word_count
|
||||
source.segment_position = segment.position
|
||||
source.index_node_hash = segment.index_node_hash
|
||||
if segment.answer:
|
||||
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
else:
|
||||
source["content"] = segment.content
|
||||
source.content = segment.content
|
||||
retrieval_resource_list.append(source)
|
||||
|
||||
if self.return_resource and retrieval_resource_list:
|
||||
retrieval_resource_list = sorted(
|
||||
retrieval_resource_list,
|
||||
key=lambda x: x.get("score") or 0.0,
|
||||
key=lambda x: x.score or 0.0,
|
||||
reverse=True,
|
||||
)
|
||||
for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore
|
||||
item["position"] = position # type: ignore
|
||||
item.position = position # type: ignore
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.return_retriever_resource_info(retrieval_resource_list)
|
||||
if document_context_list:
|
||||
|
|
|
|||
|
|
@ -66,7 +66,6 @@ class ToolFileMessageTransformer:
|
|||
if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
|
||||
raise ValueError("unexpected message type")
|
||||
|
||||
# FIXME: should do a type check here.
|
||||
assert isinstance(message.message.blob, bytes)
|
||||
tool_file_manager = ToolFileManager()
|
||||
file = tool_file_manager.create_file_by_raw(
|
||||
|
|
|
|||
|
|
@ -55,6 +55,13 @@ class ApiBasedToolSchemaParser:
|
|||
# convert parameters
|
||||
parameters = []
|
||||
if "parameters" in interface["operation"]:
|
||||
for i, parameter in enumerate(interface["operation"]["parameters"]):
|
||||
if "$ref" in parameter:
|
||||
root = openapi
|
||||
reference = parameter["$ref"].split("/")[1:]
|
||||
for ref in reference:
|
||||
root = root[ref]
|
||||
interface["operation"]["parameters"][i] = root
|
||||
for parameter in interface["operation"]["parameters"]:
|
||||
tool_parameter = ToolParameter(
|
||||
name=parameter["name"],
|
||||
|
|
|
|||
|
|
@ -1,37 +1,10 @@
|
|||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class NodeRunMetadataKey(StrEnum):
|
||||
"""
|
||||
Node Run Metadata Key.
|
||||
"""
|
||||
|
||||
TOTAL_TOKENS = "total_tokens"
|
||||
TOTAL_PRICE = "total_price"
|
||||
CURRENCY = "currency"
|
||||
TOOL_INFO = "tool_info"
|
||||
DATASOURCE_INFO = "datasource_info"
|
||||
AGENT_LOG = "agent_log"
|
||||
ITERATION_ID = "iteration_id"
|
||||
ITERATION_INDEX = "iteration_index"
|
||||
LOOP_ID = "loop_id"
|
||||
LOOP_INDEX = "loop_index"
|
||||
PARALLEL_ID = "parallel_id"
|
||||
PARALLEL_START_NODE_ID = "parallel_start_node_id"
|
||||
PARENT_PARALLEL_ID = "parent_parallel_id"
|
||||
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
|
||||
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
|
||||
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
|
||||
LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs
|
||||
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
|
||||
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class NodeRunResult(BaseModel):
|
||||
|
|
@ -44,7 +17,7 @@ class NodeRunResult(BaseModel):
|
|||
inputs: Optional[Mapping[str, Any]] = None # node inputs
|
||||
process_data: Optional[Mapping[str, Any]] = None # process data
|
||||
outputs: Optional[Mapping[str, Any]] = None # node outputs
|
||||
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # node metadata
|
||||
metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None # node metadata
|
||||
llm_usage: Optional[LLMUsage] = None # llm usage
|
||||
|
||||
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
|
||||
|
|
|
|||
|
|
@ -37,12 +37,10 @@ class WorkflowExecution(BaseModel):
|
|||
user, tenant, and app attributes.
|
||||
"""
|
||||
|
||||
id: str = Field(...)
|
||||
id_: str = Field(...)
|
||||
workflow_id: str = Field(...)
|
||||
workflow_version: str = Field(...)
|
||||
sequence_number: int = Field(...)
|
||||
|
||||
type: WorkflowType = Field(...)
|
||||
workflow_type: WorkflowType = Field(...)
|
||||
graph: Mapping[str, Any] = Field(...)
|
||||
|
||||
inputs: Mapping[str, Any] = Field(...)
|
||||
|
|
@ -70,20 +68,18 @@ class WorkflowExecution(BaseModel):
|
|||
def new(
|
||||
cls,
|
||||
*,
|
||||
id: str,
|
||||
id_: str,
|
||||
workflow_id: str,
|
||||
sequence_number: int,
|
||||
type: WorkflowType,
|
||||
workflow_type: WorkflowType,
|
||||
workflow_version: str,
|
||||
graph: Mapping[str, Any],
|
||||
inputs: Mapping[str, Any],
|
||||
started_at: datetime,
|
||||
) -> "WorkflowExecution":
|
||||
return WorkflowExecution(
|
||||
id=id,
|
||||
id_=id_,
|
||||
workflow_id=workflow_id,
|
||||
sequence_number=sequence_number,
|
||||
type=type,
|
||||
workflow_type=workflow_type,
|
||||
workflow_version=workflow_version,
|
||||
graph=graph,
|
||||
inputs=inputs,
|
||||
|
|
@ -13,11 +13,35 @@ from typing import Any, Optional
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
|
||||
|
||||
class NodeExecutionStatus(StrEnum):
|
||||
class WorkflowNodeExecutionMetadataKey(StrEnum):
|
||||
"""
|
||||
Node Run Metadata Key.
|
||||
"""
|
||||
|
||||
TOTAL_TOKENS = "total_tokens"
|
||||
TOTAL_PRICE = "total_price"
|
||||
CURRENCY = "currency"
|
||||
TOOL_INFO = "tool_info"
|
||||
AGENT_LOG = "agent_log"
|
||||
ITERATION_ID = "iteration_id"
|
||||
ITERATION_INDEX = "iteration_index"
|
||||
LOOP_ID = "loop_id"
|
||||
LOOP_INDEX = "loop_index"
|
||||
PARALLEL_ID = "parallel_id"
|
||||
PARALLEL_START_NODE_ID = "parallel_start_node_id"
|
||||
PARENT_PARALLEL_ID = "parent_parallel_id"
|
||||
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
|
||||
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
|
||||
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
|
||||
LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs
|
||||
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
|
||||
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
|
||||
|
||||
|
||||
class WorkflowNodeExecutionStatus(StrEnum):
|
||||
"""
|
||||
Node Execution Status Enum.
|
||||
"""
|
||||
|
|
@ -29,7 +53,7 @@ class NodeExecutionStatus(StrEnum):
|
|||
RETRY = "retry"
|
||||
|
||||
|
||||
class NodeExecution(BaseModel):
|
||||
class WorkflowNodeExecution(BaseModel):
|
||||
"""
|
||||
Domain model for workflow node execution.
|
||||
|
||||
|
|
@ -46,7 +70,7 @@ class NodeExecution(BaseModel):
|
|||
id: str # Unique identifier for this execution record
|
||||
node_execution_id: Optional[str] = None # Optional secondary ID for cross-referencing
|
||||
workflow_id: str # ID of the workflow this node belongs to
|
||||
workflow_run_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging)
|
||||
workflow_execution_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging)
|
||||
|
||||
# Execution positioning and flow
|
||||
index: int # Sequence number for ordering in trace visualization
|
||||
|
|
@ -61,12 +85,12 @@ class NodeExecution(BaseModel):
|
|||
outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node
|
||||
|
||||
# Execution state
|
||||
status: NodeExecutionStatus = NodeExecutionStatus.RUNNING # Current execution status
|
||||
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING # Current execution status
|
||||
error: Optional[str] = None # Error message if execution failed
|
||||
elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds
|
||||
|
||||
# Additional metadata
|
||||
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.)
|
||||
metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.)
|
||||
|
||||
# Timing information
|
||||
created_at: datetime # When execution started
|
||||
|
|
@ -77,7 +101,7 @@ class NodeExecution(BaseModel):
|
|||
inputs: Optional[Mapping[str, Any]] = None,
|
||||
process_data: Optional[Mapping[str, Any]] = None,
|
||||
outputs: Optional[Mapping[str, Any]] = None,
|
||||
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None,
|
||||
metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Update the model from mappings.
|
||||
|
|
@ -13,7 +13,7 @@ class SystemVariableKey(StrEnum):
|
|||
DIALOGUE_COUNT = "dialogue_count"
|
||||
APP_ID = "app_id"
|
||||
WORKFLOW_ID = "workflow_id"
|
||||
WORKFLOW_RUN_ID = "workflow_run_id"
|
||||
WORKFLOW_EXECUTION_ID = "workflow_run_id"
|
||||
# RAG Pipeline
|
||||
DOCUMENT_ID = "document_id"
|
||||
BATCH = "batch"
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.nodes import NodeType
|
||||
|
|
@ -82,7 +83,7 @@ class NodeRunStreamChunkEvent(BaseNodeEvent):
|
|||
|
||||
|
||||
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
|
||||
retriever_resources: list[dict] = Field(..., description="retriever resources")
|
||||
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
|
||||
context: str = Field(..., description="context")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from typing import Optional
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class RouteNodeState(BaseModel):
|
||||
|
|
|
|||
|
|
@ -14,8 +14,9 @@ from flask import Flask, current_app, has_request_context
|
|||
from configs import dify_config
|
||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
BaseAgentEvent,
|
||||
|
|
@ -52,9 +53,8 @@ from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
|
|||
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -606,8 +606,6 @@ class GraphEngine:
|
|||
error=str(e),
|
||||
)
|
||||
)
|
||||
finally:
|
||||
db.session.remove()
|
||||
|
||||
def _run_node(
|
||||
self,
|
||||
|
|
@ -645,7 +643,6 @@ class GraphEngine:
|
|||
agent_strategy=agent_strategy,
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
max_retries = node_instance.node_data.retry_config.max_retries
|
||||
retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
|
||||
retries = 0
|
||||
|
|
@ -759,10 +756,12 @@ class GraphEngine:
|
|||
and node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH
|
||||
):
|
||||
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
|
||||
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
||||
if run_result.metadata and run_result.metadata.get(
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS
|
||||
):
|
||||
# plus state total_tokens
|
||||
self.graph_runtime_state.total_tokens += int(
|
||||
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
|
||||
run_result.metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
if run_result.llm_usage:
|
||||
|
|
@ -785,13 +784,17 @@ class GraphEngine:
|
|||
|
||||
if parallel_id and parallel_start_node_id:
|
||||
metadata_dict = dict(run_result.metadata)
|
||||
metadata_dict[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
|
||||
metadata_dict[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
|
||||
metadata_dict[WorkflowNodeExecutionMetadataKey.PARALLEL_ID] = parallel_id
|
||||
metadata_dict[WorkflowNodeExecutionMetadataKey.PARALLEL_START_NODE_ID] = (
|
||||
parallel_start_node_id
|
||||
)
|
||||
if parent_parallel_id and parent_parallel_start_node_id:
|
||||
metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
|
||||
metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
|
||||
parent_parallel_start_node_id
|
||||
metadata_dict[WorkflowNodeExecutionMetadataKey.PARENT_PARALLEL_ID] = (
|
||||
parent_parallel_id
|
||||
)
|
||||
metadata_dict[
|
||||
WorkflowNodeExecutionMetadataKey.PARENT_PARALLEL_START_NODE_ID
|
||||
] = parent_parallel_start_node_id
|
||||
run_result.metadata = metadata_dict
|
||||
|
||||
yield NodeRunSucceededEvent(
|
||||
|
|
@ -856,8 +859,6 @@ class GraphEngine:
|
|||
except Exception as e:
|
||||
logger.exception(f"Node {node_instance.node_data.title} run failed")
|
||||
raise e
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
|
||||
"""
|
||||
|
|
@ -923,7 +924,7 @@ class GraphEngine:
|
|||
"error": error_result.error,
|
||||
"inputs": error_result.inputs,
|
||||
"metadata": {
|
||||
NodeRunMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy,
|
||||
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy,
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,9 @@ import json
|
|||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
|
|
@ -15,6 +18,7 @@ from core.tools.tool_manager import ToolManager
|
|||
from core.variables.segments import StringSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
|
|
@ -25,7 +29,6 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
|||
from extensions.ext_database import db
|
||||
from factories.agent_factory import get_plugin_agent_strategy
|
||||
from models.model import Conversation
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class AgentNode(ToolNode):
|
||||
|
|
@ -320,15 +323,12 @@ class AgentNode(ToolNode):
|
|||
return None
|
||||
conversation_id = conversation_id_variable.value
|
||||
|
||||
# get conversation
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
|
||||
conversation = session.scalar(stmt)
|
||||
|
||||
if not conversation:
|
||||
return None
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from typing import Any, cast
|
|||
|
||||
from core.variables import ArrayFileSegment, FileSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
|
||||
from core.workflow.nodes.answer.entities import (
|
||||
AnswerNodeData,
|
||||
|
|
@ -13,7 +14,6 @@ from core.workflow.nodes.answer.entities import (
|
|||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class AnswerNode(BaseNode[AnswerNodeData]):
|
||||
|
|
|
|||
|
|
@ -4,9 +4,9 @@ from collections.abc import Generator, Mapping, Sequence
|
|||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import BaseNodeData
|
||||
|
||||
|
|
|
|||
|
|
@ -8,10 +8,10 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
|
|||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.variables.segments import ArrayFileSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.code.entities import CodeNodeData
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .exc import (
|
||||
CodeNodeError,
|
||||
|
|
|
|||
|
|
@ -26,9 +26,9 @@ from core.helper import ssrf_proxy
|
|||
from core.variables import ArrayFileSegment
|
||||
from core.variables.segments import FileSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import DocumentExtractorNodeData
|
||||
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class EndNode(BaseNode[EndNodeData]):
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class RunCompletedEvent(BaseModel):
|
||||
|
|
@ -17,7 +19,7 @@ class RunStreamChunkEvent(BaseModel):
|
|||
|
||||
|
||||
class RunRetrieverResourceEvent(BaseModel):
|
||||
retriever_resources: list[dict] = Field(..., description="retriever resources")
|
||||
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
|
||||
context: str = Field(..., description="context")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -8,12 +8,12 @@ from core.file import File, FileTransferMethod
|
|||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.http_request.executor import Executor
|
||||
from core.workflow.utils import variable_template_parser
|
||||
from factories import file_factory
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import (
|
||||
HttpRequestNodeData,
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@ from typing_extensions import deprecated
|
|||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.if_else.entities import IfElseNodeData
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class IfElseNode(BaseNode[IfElseNodeData]):
|
||||
|
|
|
|||
|
|
@ -12,10 +12,10 @@ from flask import Flask, current_app, has_request_context
|
|||
from configs import dify_config
|
||||
from core.variables import ArrayVariable, IntegerVariable, NoneVariable
|
||||
from core.workflow.entities.node_entities import (
|
||||
NodeRunMetadataKey,
|
||||
NodeRunResult,
|
||||
)
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
BaseGraphEvent,
|
||||
BaseNodeEvent,
|
||||
|
|
@ -37,7 +37,6 @@ from core.workflow.nodes.base import BaseNode
|
|||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .exc import (
|
||||
InvalidIteratorValueError,
|
||||
|
|
@ -249,8 +248,8 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"output": outputs},
|
||||
metadata={
|
||||
NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
|
@ -361,16 +360,16 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|||
event.parallel_mode_run_id = parallel_mode_run_id
|
||||
|
||||
iter_metadata = {
|
||||
NodeRunMetadataKey.ITERATION_ID: self.node_id,
|
||||
NodeRunMetadataKey.ITERATION_INDEX: iter_run_index,
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_ID: self.node_id,
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_INDEX: iter_run_index,
|
||||
}
|
||||
if parallel_mode_run_id:
|
||||
# for parallel, the specific branch ID is more important than the sequential index
|
||||
iter_metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
|
||||
iter_metadata[WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
|
||||
|
||||
if event.route_node_state.node_run_result:
|
||||
current_metadata = event.route_node_state.node_run_result.metadata or {}
|
||||
if NodeRunMetadataKey.ITERATION_ID not in current_metadata:
|
||||
if WorkflowNodeExecutionMetadataKey.ITERATION_ID not in current_metadata:
|
||||
event.route_node_state.node_run_result.metadata = {**current_metadata, **iter_metadata}
|
||||
|
||||
return event
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.iteration.entities import IterationStartNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class IterationStartNode(BaseNode[IterationStartNodeData]):
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from typing import Any, Optional, cast
|
|||
|
||||
from sqlalchemy import Float, and_, func, or_, text
|
||||
from sqlalchemy import cast as sqlalchemy_cast
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
|
|
@ -24,6 +25,7 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
|||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.variables import StringSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event.event import ModelInvokeCompletedEvent
|
||||
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
|
||||
|
|
@ -41,7 +43,6 @@ from extensions.ext_database import db
|
|||
from extensions.ext_redis import redis_client
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from .entities import KnowledgeRetrievalNodeData, ModelConfig
|
||||
|
|
@ -95,14 +96,15 @@ class KnowledgeRetrievalNode(LLMNode):
|
|||
redis_client.zremrangebyscore(key, 0, current_time - 60000)
|
||||
request_count = redis_client.zcard(key)
|
||||
if request_count > knowledge_rate_limit.limit:
|
||||
# add ratelimit record
|
||||
rate_limit_log = RateLimitLog(
|
||||
tenant_id=self.tenant_id,
|
||||
subscription_plan=knowledge_rate_limit.subscription_plan,
|
||||
operation="knowledge",
|
||||
)
|
||||
db.session.add(rate_limit_log)
|
||||
db.session.commit()
|
||||
with Session(db.engine) as session:
|
||||
# add ratelimit record
|
||||
rate_limit_log = RateLimitLog(
|
||||
tenant_id=self.tenant_id,
|
||||
subscription_plan=knowledge_rate_limit.subscription_plan,
|
||||
operation="knowledge",
|
||||
)
|
||||
session.add(rate_limit_log)
|
||||
session.commit()
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
|
|
@ -173,7 +175,9 @@ class KnowledgeRetrievalNode(LLMNode):
|
|||
dataset_retrieval = DatasetRetrieval()
|
||||
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
|
||||
# fetch model config
|
||||
model_instance, model_config = self._fetch_model_config(node_data.single_retrieval_config.model) # type: ignore
|
||||
if node_data.single_retrieval_config is None:
|
||||
raise ValueError("single_retrieval_config is required")
|
||||
model_instance, model_config = self.get_model_config(node_data.single_retrieval_config.model)
|
||||
# check model is support tool calling
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
|
@ -424,7 +428,7 @@ class KnowledgeRetrievalNode(LLMNode):
|
|||
raise ValueError("metadata_model_config is required")
|
||||
# get metadata model instance
|
||||
# fetch model config
|
||||
model_instance, model_config = self._fetch_model_config(node_data.metadata_model_config) # type: ignore
|
||||
model_instance, model_config = self.get_model_config(metadata_model_config)
|
||||
# fetch prompt messages
|
||||
prompt_template = self._get_prompt_template(
|
||||
node_data=node_data,
|
||||
|
|
@ -550,14 +554,7 @@ class KnowledgeRetrievalNode(LLMNode):
|
|||
variable_mapping[node_id + ".query"] = node_data.query_variable_selector
|
||||
return variable_mapping
|
||||
|
||||
def _fetch_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: # type: ignore
|
||||
"""
|
||||
Fetch model config
|
||||
:param model: model
|
||||
:return:
|
||||
"""
|
||||
if model is None:
|
||||
raise ValueError("model is required")
|
||||
def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
model_name = model.name
|
||||
provider_name = model.provider
|
||||
|
||||
|
|
|
|||
|
|
@ -4,9 +4,9 @@ from typing import Any, Literal, Union
|
|||
from core.file import File
|
||||
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import ListOperatorNodeData
|
||||
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ from datetime import UTC, datetime
|
|||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
|
||||
import json_repair
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
|
|
@ -43,6 +45,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
|||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.variables import (
|
||||
ArrayAnySegment,
|
||||
ArrayFileSegment,
|
||||
|
|
@ -53,9 +56,10 @@ from core.variables import (
|
|||
StringSegment,
|
||||
)
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
|
|
@ -77,7 +81,6 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
|||
from extensions.ext_database import db
|
||||
from models.model import Conversation
|
||||
from models.provider import Provider, ProviderType
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import (
|
||||
LLMNodeChatModelMessage,
|
||||
|
|
@ -267,9 +270,9 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
NodeRunMetadataKey.CURRENCY: usage.currency,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
},
|
||||
llm_usage=usage,
|
||||
)
|
||||
|
|
@ -302,8 +305,6 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
prompt_messages: Sequence[PromptMessage],
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
) -> Generator[NodeEvent, None, None]:
|
||||
db.session.close()
|
||||
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters=node_data_model.completion_params,
|
||||
|
|
@ -474,7 +475,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value)
|
||||
elif isinstance(context_value_variable, ArraySegment):
|
||||
context_str = ""
|
||||
original_retriever_resource = []
|
||||
original_retriever_resource: list[RetrievalSourceMetadata] = []
|
||||
for item in context_value_variable.value:
|
||||
if isinstance(item, str):
|
||||
context_str += item + "\n"
|
||||
|
|
@ -492,7 +493,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
retriever_resources=original_retriever_resource, context=context_str.strip()
|
||||
)
|
||||
|
||||
def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
|
||||
def _convert_to_original_retriever_resource(self, context_dict: dict):
|
||||
if (
|
||||
"metadata" in context_dict
|
||||
and "_source" in context_dict["metadata"]
|
||||
|
|
@ -500,24 +501,24 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
):
|
||||
metadata = context_dict.get("metadata", {})
|
||||
|
||||
source = {
|
||||
"position": metadata.get("position"),
|
||||
"dataset_id": metadata.get("dataset_id"),
|
||||
"dataset_name": metadata.get("dataset_name"),
|
||||
"document_id": metadata.get("document_id"),
|
||||
"document_name": metadata.get("document_name"),
|
||||
"data_source_type": metadata.get("data_source_type"),
|
||||
"segment_id": metadata.get("segment_id"),
|
||||
"retriever_from": metadata.get("retriever_from"),
|
||||
"score": metadata.get("score"),
|
||||
"hit_count": metadata.get("segment_hit_count"),
|
||||
"word_count": metadata.get("segment_word_count"),
|
||||
"segment_position": metadata.get("segment_position"),
|
||||
"index_node_hash": metadata.get("segment_index_node_hash"),
|
||||
"content": context_dict.get("content"),
|
||||
"page": metadata.get("page"),
|
||||
"doc_metadata": metadata.get("doc_metadata"),
|
||||
}
|
||||
source = RetrievalSourceMetadata(
|
||||
position=metadata.get("position"),
|
||||
dataset_id=metadata.get("dataset_id"),
|
||||
dataset_name=metadata.get("dataset_name"),
|
||||
document_id=metadata.get("document_id"),
|
||||
document_name=metadata.get("document_name"),
|
||||
data_source_type=metadata.get("data_source_type"),
|
||||
segment_id=metadata.get("segment_id"),
|
||||
retriever_from=metadata.get("retriever_from"),
|
||||
score=metadata.get("score"),
|
||||
hit_count=metadata.get("segment_hit_count"),
|
||||
word_count=metadata.get("segment_word_count"),
|
||||
segment_position=metadata.get("segment_position"),
|
||||
index_node_hash=metadata.get("segment_index_node_hash"),
|
||||
content=context_dict.get("content"),
|
||||
page=metadata.get("page"),
|
||||
doc_metadata=metadata.get("doc_metadata"),
|
||||
)
|
||||
|
||||
return source
|
||||
|
||||
|
|
@ -602,15 +603,11 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
return None
|
||||
conversation_id = conversation_id_variable.value
|
||||
|
||||
# get conversation
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
return None
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
|
||||
conversation = session.scalar(stmt)
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
|
|
@ -846,20 +843,24 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
used_quota = 1
|
||||
|
||||
if used_quota is not None and system_configuration.current_quota_type is not None:
|
||||
db.session.query(Provider).filter(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
).update(
|
||||
{
|
||||
"quota_used": Provider.quota_used + used_quota,
|
||||
"last_used": datetime.now(tz=UTC).replace(tzinfo=None),
|
||||
}
|
||||
)
|
||||
db.session.commit()
|
||||
with Session(db.engine) as session:
|
||||
stmt = (
|
||||
update(Provider)
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
quota_used=Provider.quota_used + used_quota,
|
||||
last_used=datetime.now(tz=UTC).replace(tzinfo=None),
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.loop.entities import LoopEndNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class LoopEndNode(BaseNode[LoopEndNodeData]):
|
||||
|
|
|
|||
|
|
@ -15,7 +15,8 @@ from core.variables import (
|
|||
SegmentType,
|
||||
StringSegment,
|
||||
)
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
BaseGraphEvent,
|
||||
BaseNodeEvent,
|
||||
|
|
@ -37,7 +38,6 @@ from core.workflow.nodes.enums import NodeType
|
|||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from core.workflow.nodes.loop.entities import LoopNodeData
|
||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
|
@ -187,10 +187,10 @@ class LoopNode(BaseNode[LoopNodeData]):
|
|||
outputs=self.node_data.outputs,
|
||||
steps=loop_count,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
"completed_reason": "loop_break" if check_break_result else "loop_completed",
|
||||
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -198,9 +198,9 @@ class LoopNode(BaseNode[LoopNodeData]):
|
|||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
outputs=self.node_data.outputs,
|
||||
inputs=inputs,
|
||||
|
|
@ -221,8 +221,8 @@ class LoopNode(BaseNode[LoopNodeData]):
|
|||
metadata={
|
||||
"total_tokens": graph_engine.graph_runtime_state.total_tokens,
|
||||
"completed_reason": "error",
|
||||
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
error=str(e),
|
||||
)
|
||||
|
|
@ -232,9 +232,9 @@ class LoopNode(BaseNode[LoopNodeData]):
|
|||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
|
@ -322,7 +322,9 @@ class LoopNode(BaseNode[LoopNodeData]):
|
|||
inputs=inputs,
|
||||
steps=current_index,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: (
|
||||
graph_engine.graph_runtime_state.total_tokens
|
||||
),
|
||||
"completed_reason": "error",
|
||||
},
|
||||
error=event.error,
|
||||
|
|
@ -331,7 +333,11 @@ class LoopNode(BaseNode[LoopNodeData]):
|
|||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: (
|
||||
graph_engine.graph_runtime_state.total_tokens
|
||||
)
|
||||
},
|
||||
)
|
||||
)
|
||||
return {"check_break_result": True}
|
||||
|
|
@ -347,7 +353,7 @@ class LoopNode(BaseNode[LoopNodeData]):
|
|||
inputs=inputs,
|
||||
steps=current_index,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
"completed_reason": "error",
|
||||
},
|
||||
error=event.error,
|
||||
|
|
@ -356,7 +362,9 @@ class LoopNode(BaseNode[LoopNodeData]):
|
|||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens
|
||||
},
|
||||
)
|
||||
)
|
||||
return {"check_break_result": True}
|
||||
|
|
@ -411,11 +419,11 @@ class LoopNode(BaseNode[LoopNodeData]):
|
|||
metadata = event.route_node_state.node_run_result.metadata
|
||||
if not metadata:
|
||||
metadata = {}
|
||||
if NodeRunMetadataKey.LOOP_ID not in metadata:
|
||||
if WorkflowNodeExecutionMetadataKey.LOOP_ID not in metadata:
|
||||
metadata = {
|
||||
**metadata,
|
||||
NodeRunMetadataKey.LOOP_ID: self.node_id,
|
||||
NodeRunMetadataKey.LOOP_INDEX: iter_run_index,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_ID: self.node_id,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_INDEX: iter_run_index,
|
||||
}
|
||||
event.route_node_state.node_run_result.metadata = metadata
|
||||
return event
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.loop.entities import LoopStartNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class LoopStartNode(BaseNode[LoopStartNodeData]):
|
||||
|
|
|
|||
|
|
@ -25,13 +25,12 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
|||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.llm import LLMNode, ModelConfig
|
||||
from core.workflow.utils import variable_template_parser
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import ParameterExtractorNodeData
|
||||
from .exc import (
|
||||
|
|
@ -244,9 +243,9 @@ class ParameterExtractorNode(LLMNode):
|
|||
process_data=process_data,
|
||||
outputs={"__is_success": 1 if not error else 0, "__reason": error, **result},
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
NodeRunMetadataKey.CURRENCY: usage.currency,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
},
|
||||
llm_usage=usage,
|
||||
)
|
||||
|
|
@ -259,8 +258,6 @@ class ParameterExtractorNode(LLMNode):
|
|||
tools: list[PromptMessageTool],
|
||||
stop: list[str],
|
||||
) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]:
|
||||
db.session.close()
|
||||
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=node_data_model.completion_params,
|
||||
|
|
@ -816,7 +813,6 @@ class ParameterExtractorNode(LLMNode):
|
|||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
# FIXME: fix the type error later
|
||||
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
|
||||
|
||||
if node_data.instruction:
|
||||
|
|
|
|||
|
|
@ -10,7 +10,8 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
|||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event import ModelInvokeCompletedEvent
|
||||
from core.workflow.nodes.llm import (
|
||||
|
|
@ -20,7 +21,6 @@ from core.workflow.nodes.llm import (
|
|||
)
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import QuestionClassifierNodeData
|
||||
from .exc import InvalidModelTypeError
|
||||
|
|
@ -79,9 +79,13 @@ class QuestionClassifierNode(LLMNode):
|
|||
memory=memory,
|
||||
max_token_limit=rest_token,
|
||||
)
|
||||
# Some models (e.g. Gemma, Mistral) force roles alternation (user/assistant/user/assistant...).
|
||||
# If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt,
|
||||
# two consecutive user prompts will be generated, causing model's error.
|
||||
# To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end.
|
||||
prompt_messages, stop = self._fetch_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
sys_query=query,
|
||||
sys_query="",
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
sys_files=files,
|
||||
|
|
@ -142,9 +146,9 @@ class QuestionClassifierNode(LLMNode):
|
|||
outputs=outputs,
|
||||
edge_source_handle=category_id,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
NodeRunMetadataKey.CURRENCY: usage.currency,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
},
|
||||
llm_usage=usage,
|
||||
)
|
||||
|
|
@ -154,9 +158,9 @@ class QuestionClassifierNode(LLMNode):
|
|||
inputs=variables,
|
||||
error=str(e),
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
NodeRunMetadataKey.CURRENCY: usage.currency,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
},
|
||||
llm_usage=usage,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class StartNode(BaseNode[StartNodeData]):
|
||||
|
|
|
|||
|
|
@ -4,10 +4,10 @@ from typing import Any, Optional
|
|||
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000"))
|
||||
|
||||
|
|
|
|||
|
|
@ -14,8 +14,9 @@ from core.tools.tool_engine import ToolEngine
|
|||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.variables.segments import ArrayAnySegment
|
||||
from core.variables.variables import ArrayAnyVariable
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import AgentLogEvent
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
|
|
@ -25,7 +26,6 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
|||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import ToolFile
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
from .entities import ToolNodeData
|
||||
|
|
@ -70,7 +70,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
|
||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||
error=f"Failed to get tool runtime: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
|
|
@ -110,7 +110,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
|
||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||
error=f"Failed to invoke tool: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
|
|
@ -125,7 +125,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
|
||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||
error=f"Failed to transform tool message: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
|
|
@ -201,7 +201,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
json: list[dict] = []
|
||||
|
||||
agent_logs: list[AgentLogEvent] = []
|
||||
agent_execution_metadata: Mapping[NodeRunMetadataKey, Any] = {}
|
||||
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
|
||||
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
|
|
@ -274,7 +274,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
agent_execution_metadata = {
|
||||
key: value
|
||||
for key, value in msg_metadata.items()
|
||||
if key in NodeRunMetadataKey.__members__.values()
|
||||
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
|
||||
}
|
||||
json.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
|
|
@ -366,8 +366,8 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
outputs={"text": text, "files": files, "json": json, **variables},
|
||||
metadata={
|
||||
**agent_execution_metadata,
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info,
|
||||
NodeRunMetadataKey.AGENT_LOG: agent_logs,
|
||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
||||
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
|
||||
},
|
||||
inputs=parameters_for_log,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from typing import Literal, Optional
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
|
|
@ -17,7 +18,7 @@ class AdvancedSettings(BaseModel):
|
|||
Group.
|
||||
"""
|
||||
|
||||
output_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
|
||||
output_type: SegmentType
|
||||
variables: list[list[str]]
|
||||
group_name: str
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
from core.variables import SegmentType, Variable
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from factories import variable_factory
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .node_data import VariableAssignerData, WriteMode
|
||||
|
||||
|
|
|
|||
|
|
@ -6,11 +6,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
|||
from core.variables import SegmentType, Variable
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from . import helpers
|
||||
from .constants import EMPTY_VALUE_MAPPING
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ for accessing and manipulating data, regardless of the underlying
|
|||
storage mechanism.
|
||||
"""
|
||||
|
||||
from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
|
||||
|
||||
__all__ = [
|
||||
"OrderConfig",
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Optional, Protocol
|
||||
|
||||
from core.workflow.entities.workflow_execution_entities import WorkflowExecution
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution
|
||||
|
||||
|
||||
class WorkflowExecutionRepository(Protocol):
|
||||
|
|
@ -2,7 +2,7 @@ from collections.abc import Sequence
|
|||
from dataclasses import dataclass
|
||||
from typing import Literal, Optional, Protocol
|
||||
|
||||
from core.workflow.entities.node_execution_entities import NodeExecution
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -26,7 +26,7 @@ class WorkflowNodeExecutionRepository(Protocol):
|
|||
application domains or deployment scenarios.
|
||||
"""
|
||||
|
||||
def save(self, execution: NodeExecution) -> None:
|
||||
def save(self, execution: WorkflowNodeExecution) -> None:
|
||||
"""
|
||||
Save or update a NodeExecution instance.
|
||||
|
||||
|
|
@ -39,7 +39,7 @@ class WorkflowNodeExecutionRepository(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]:
|
||||
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve a NodeExecution by its node_execution_id.
|
||||
|
||||
|
|
@ -55,7 +55,7 @@ class WorkflowNodeExecutionRepository(Protocol):
|
|||
self,
|
||||
workflow_run_id: str,
|
||||
order_config: Optional[OrderConfig] = None,
|
||||
) -> Sequence[NodeExecution]:
|
||||
) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all NodeExecution instances for a specific workflow run.
|
||||
|
||||
|
|
@ -70,7 +70,7 @@ class WorkflowNodeExecutionRepository(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]:
|
||||
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all running NodeExecution instances for a specific workflow run.
|
||||
|
||||
|
|
@ -1,11 +1,9 @@
|
|||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueNodeExceptionEvent,
|
||||
|
|
@ -19,21 +17,24 @@ from core.app.entities.queue_entities import (
|
|||
from core.app.task_pipeline.exc import WorkflowRunNotFoundError
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.entities.node_execution_entities import (
|
||||
NodeExecution,
|
||||
NodeExecutionStatus,
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.entities.workflow_execution_entities import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from models import (
|
||||
Workflow,
|
||||
WorkflowRun,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CycleManagerWorkflowInfo:
|
||||
workflow_id: str
|
||||
workflow_type: WorkflowType
|
||||
version: str
|
||||
graph_data: Mapping[str, Any]
|
||||
|
||||
|
||||
class WorkflowCycleManager:
|
||||
|
|
@ -42,32 +43,17 @@ class WorkflowCycleManager:
|
|||
*,
|
||||
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
|
||||
workflow_system_variables: dict[SystemVariableKey, Any],
|
||||
workflow_info: CycleManagerWorkflowInfo,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
) -> None:
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_system_variables = workflow_system_variables
|
||||
self._workflow_info = workflow_info
|
||||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
|
||||
def handle_workflow_run_start(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
workflow_id: str,
|
||||
) -> WorkflowExecution:
|
||||
workflow_stmt = select(Workflow).where(Workflow.id == workflow_id)
|
||||
workflow = session.scalar(workflow_stmt)
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow not found: {workflow_id}")
|
||||
|
||||
max_sequence_stmt = select(func.max(WorkflowRun.sequence_number)).where(
|
||||
WorkflowRun.tenant_id == workflow.tenant_id,
|
||||
WorkflowRun.app_id == workflow.app_id,
|
||||
)
|
||||
max_sequence = session.scalar(max_sequence_stmt) or 0
|
||||
new_sequence_number = max_sequence + 1
|
||||
|
||||
def handle_workflow_run_start(self) -> WorkflowExecution:
|
||||
inputs = {**self._application_generate_entity.inputs}
|
||||
for key, value in (self._workflow_system_variables or {}).items():
|
||||
if key.value == "conversation":
|
||||
|
|
@ -79,14 +65,13 @@ class WorkflowCycleManager:
|
|||
|
||||
# init workflow run
|
||||
# TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this
|
||||
execution_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4())
|
||||
execution_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_EXECUTION_ID) or uuid4())
|
||||
execution = WorkflowExecution.new(
|
||||
id=execution_id,
|
||||
workflow_id=workflow.id,
|
||||
sequence_number=new_sequence_number,
|
||||
type=WorkflowType(workflow.type),
|
||||
workflow_version=workflow.version,
|
||||
graph=workflow.graph_dict,
|
||||
id_=execution_id,
|
||||
workflow_id=self._workflow_info.workflow_id,
|
||||
workflow_type=self._workflow_info.workflow_type,
|
||||
workflow_version=self._workflow_info.version,
|
||||
graph=self._workflow_info.graph_data,
|
||||
inputs=inputs,
|
||||
started_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
)
|
||||
|
|
@ -168,7 +153,7 @@ class WorkflowCycleManager:
|
|||
workflow_run_id: str,
|
||||
total_tokens: int,
|
||||
total_steps: int,
|
||||
status: WorkflowRunStatus,
|
||||
status: WorkflowExecutionStatus,
|
||||
error_message: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
|
|
@ -185,7 +170,7 @@ class WorkflowCycleManager:
|
|||
|
||||
# Use the instance repository to find running executions for a workflow run
|
||||
running_node_executions = self._workflow_node_execution_repository.get_running_executions(
|
||||
workflow_run_id=workflow_execution.id
|
||||
workflow_run_id=workflow_execution.id_
|
||||
)
|
||||
|
||||
# Update the domain models
|
||||
|
|
@ -193,7 +178,7 @@ class WorkflowCycleManager:
|
|||
for node_execution in running_node_executions:
|
||||
if node_execution.node_execution_id:
|
||||
# Update the domain model
|
||||
node_execution.status = NodeExecutionStatus.FAILED
|
||||
node_execution.status = WorkflowNodeExecutionStatus.FAILED
|
||||
node_execution.error = error_message
|
||||
node_execution.finished_at = now
|
||||
node_execution.elapsed_time = (now - node_execution.created_at).total_seconds()
|
||||
|
|
@ -219,28 +204,28 @@ class WorkflowCycleManager:
|
|||
*,
|
||||
workflow_execution_id: str,
|
||||
event: QueueNodeStartedEvent,
|
||||
) -> NodeExecution:
|
||||
) -> WorkflowNodeExecution:
|
||||
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
|
||||
|
||||
# Create a domain model
|
||||
created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
metadata = {
|
||||
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
|
||||
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
|
||||
NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
|
||||
WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
|
||||
}
|
||||
|
||||
domain_execution = NodeExecution(
|
||||
domain_execution = WorkflowNodeExecution(
|
||||
id=str(uuid4()),
|
||||
workflow_id=workflow_execution.workflow_id,
|
||||
workflow_run_id=workflow_execution.id,
|
||||
workflow_execution_id=workflow_execution.id_,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
index=event.node_run_index,
|
||||
node_execution_id=event.node_execution_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
title=event.node_data.title,
|
||||
status=NodeExecutionStatus.RUNNING,
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
metadata=metadata,
|
||||
created_at=created_at,
|
||||
)
|
||||
|
|
@ -250,7 +235,7 @@ class WorkflowCycleManager:
|
|||
|
||||
return domain_execution
|
||||
|
||||
def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> NodeExecution:
|
||||
def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
|
||||
# Get the domain model from repository
|
||||
domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id)
|
||||
if not domain_execution:
|
||||
|
|
@ -271,7 +256,7 @@ class WorkflowCycleManager:
|
|||
elapsed_time = (finished_at - event.start_at).total_seconds()
|
||||
|
||||
# Update domain model
|
||||
domain_execution.status = NodeExecutionStatus.SUCCEEDED
|
||||
domain_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
domain_execution.update_from_mapping(
|
||||
inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict
|
||||
)
|
||||
|
|
@ -290,7 +275,7 @@ class WorkflowCycleManager:
|
|||
| QueueNodeInIterationFailedEvent
|
||||
| QueueNodeInLoopFailedEvent
|
||||
| QueueNodeExceptionEvent,
|
||||
) -> NodeExecution:
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Workflow node execution failed
|
||||
:param event: queue node failed event
|
||||
|
|
@ -317,9 +302,9 @@ class WorkflowCycleManager:
|
|||
|
||||
# Update domain model
|
||||
domain_execution.status = (
|
||||
NodeExecutionStatus.FAILED
|
||||
WorkflowNodeExecutionStatus.FAILED
|
||||
if not isinstance(event, QueueNodeExceptionEvent)
|
||||
else NodeExecutionStatus.EXCEPTION
|
||||
else WorkflowNodeExecutionStatus.EXCEPTION
|
||||
)
|
||||
domain_execution.error = event.error
|
||||
domain_execution.update_from_mapping(
|
||||
|
|
@ -335,7 +320,7 @@ class WorkflowCycleManager:
|
|||
|
||||
def handle_workflow_node_execution_retried(
|
||||
self, *, workflow_execution_id: str, event: QueueNodeRetryEvent
|
||||
) -> NodeExecution:
|
||||
) -> WorkflowNodeExecution:
|
||||
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
|
||||
created_at = event.start_at
|
||||
finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
|
@ -345,13 +330,13 @@ class WorkflowCycleManager:
|
|||
|
||||
# Convert metadata keys to strings
|
||||
origin_metadata = {
|
||||
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
|
||||
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
|
||||
NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
|
||||
WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
|
||||
}
|
||||
|
||||
# Convert execution metadata keys to strings
|
||||
execution_metadata_dict: dict[NodeRunMetadataKey, str | None] = {}
|
||||
execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, str | None] = {}
|
||||
if event.execution_metadata:
|
||||
for key, value in event.execution_metadata.items():
|
||||
execution_metadata_dict[key] = value
|
||||
|
|
@ -359,16 +344,16 @@ class WorkflowCycleManager:
|
|||
merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata
|
||||
|
||||
# Create a domain model
|
||||
domain_execution = NodeExecution(
|
||||
domain_execution = WorkflowNodeExecution(
|
||||
id=str(uuid4()),
|
||||
workflow_id=workflow_execution.workflow_id,
|
||||
workflow_run_id=workflow_execution.id,
|
||||
workflow_execution_id=workflow_execution.id_,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
node_execution_id=event.node_execution_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
title=event.node_data.title,
|
||||
status=NodeExecutionStatus.RETRY,
|
||||
status=WorkflowNodeExecutionStatus.RETRY,
|
||||
created_at=created_at,
|
||||
finished_at=finished_at,
|
||||
elapsed_time=elapsed_time,
|
||||
|
|
|
|||
|
|
@ -93,8 +93,8 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
|
|||
raise VariableError("missing value type")
|
||||
if (value := mapping.get("value")) is None:
|
||||
raise VariableError("missing value")
|
||||
# FIXME: using Any here, fix it later
|
||||
result: Any
|
||||
|
||||
result: Variable
|
||||
match value_type:
|
||||
case SegmentType.STRING:
|
||||
result = StringVariable.model_validate(mapping)
|
||||
|
|
|
|||
|
|
@ -28,7 +28,8 @@ class SMTPClient:
|
|||
else:
|
||||
smtp = smtplib.SMTP(self.server, self.port, timeout=10)
|
||||
|
||||
if self.username and self.password:
|
||||
# Only authenticate if both username and password are non-empty
|
||||
if self.username and self.password and self.username.strip() and self.password.strip():
|
||||
smtp.login(self.username, self.password)
|
||||
|
||||
msg = MIMEMultipart()
|
||||
|
|
|
|||
|
|
@ -85,11 +85,9 @@ from .workflow import (
|
|||
Workflow,
|
||||
WorkflowAppLog,
|
||||
WorkflowAppLogCreatedFrom,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowNodeExecutionModel,
|
||||
WorkflowNodeExecutionTriggeredFrom,
|
||||
WorkflowRun,
|
||||
WorkflowRunStatus,
|
||||
WorkflowType,
|
||||
)
|
||||
|
||||
|
|
@ -101,14 +99,14 @@ __all__ = [
|
|||
"AccountStatus",
|
||||
"ApiRequest",
|
||||
"ApiToken",
|
||||
"ApiToolProvider", # Added
|
||||
"ApiToolProvider",
|
||||
"App",
|
||||
"AppAnnotationHitHistory",
|
||||
"AppAnnotationSetting",
|
||||
"AppDatasetJoin",
|
||||
"AppMode",
|
||||
"AppModelConfig",
|
||||
"BuiltinToolProvider", # Added
|
||||
"BuiltinToolProvider",
|
||||
"CeleryTask",
|
||||
"CeleryTaskSet",
|
||||
"Conversation",
|
||||
|
|
@ -174,11 +172,9 @@ __all__ = [
|
|||
"Workflow",
|
||||
"WorkflowAppLog",
|
||||
"WorkflowAppLogCreatedFrom",
|
||||
"WorkflowNodeExecution",
|
||||
"WorkflowNodeExecutionStatus",
|
||||
"WorkflowNodeExecutionModel",
|
||||
"WorkflowNodeExecutionTriggeredFrom",
|
||||
"WorkflowRun",
|
||||
"WorkflowRunStatus",
|
||||
"WorkflowRunTriggeredFrom",
|
||||
"WorkflowToolProvider",
|
||||
"WorkflowType",
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
|||
from core.plugin.entities.plugin import GenericProviderID
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.signature import sign_tool_file
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -31,7 +32,6 @@ from .base import Base
|
|||
from .engine import db
|
||||
from .enums import CreatorUserRole
|
||||
from .types import StringUUID
|
||||
from .workflow import WorkflowRunStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .workflow import Workflow
|
||||
|
|
@ -795,22 +795,22 @@ class Conversation(Base):
|
|||
def status_count(self):
|
||||
messages = db.session.query(Message).filter(Message.conversation_id == self.id).all()
|
||||
status_counts = {
|
||||
WorkflowRunStatus.RUNNING: 0,
|
||||
WorkflowRunStatus.SUCCEEDED: 0,
|
||||
WorkflowRunStatus.FAILED: 0,
|
||||
WorkflowRunStatus.STOPPED: 0,
|
||||
WorkflowRunStatus.PARTIAL_SUCCEEDED: 0,
|
||||
WorkflowExecutionStatus.RUNNING: 0,
|
||||
WorkflowExecutionStatus.SUCCEEDED: 0,
|
||||
WorkflowExecutionStatus.FAILED: 0,
|
||||
WorkflowExecutionStatus.STOPPED: 0,
|
||||
WorkflowExecutionStatus.PARTIAL_SUCCEEDED: 0,
|
||||
}
|
||||
|
||||
for message in messages:
|
||||
if message.workflow_run:
|
||||
status_counts[WorkflowRunStatus(message.workflow_run.status)] += 1
|
||||
status_counts[WorkflowExecutionStatus(message.workflow_run.status)] += 1
|
||||
|
||||
return (
|
||||
{
|
||||
"success": status_counts[WorkflowRunStatus.SUCCEEDED],
|
||||
"failed": status_counts[WorkflowRunStatus.FAILED],
|
||||
"partial_success": status_counts[WorkflowRunStatus.PARTIAL_SUCCEEDED],
|
||||
"success": status_counts[WorkflowExecutionStatus.SUCCEEDED],
|
||||
"failed": status_counts[WorkflowExecutionStatus.FAILED],
|
||||
"partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED],
|
||||
}
|
||||
if messages
|
||||
else None
|
||||
|
|
|
|||
|
|
@ -401,18 +401,6 @@ class Workflow(Base):
|
|||
)
|
||||
|
||||
|
||||
class WorkflowRunStatus(StrEnum):
|
||||
"""
|
||||
Workflow Run Status Enum
|
||||
"""
|
||||
|
||||
RUNNING = "running"
|
||||
SUCCEEDED = "succeeded"
|
||||
FAILED = "failed"
|
||||
STOPPED = "stopped"
|
||||
PARTIAL_SUCCEEDED = "partial-succeeded"
|
||||
|
||||
|
||||
class WorkflowRun(Base):
|
||||
"""
|
||||
Workflow Run
|
||||
|
|
@ -473,12 +461,12 @@ class WorkflowRun(Base):
|
|||
error: Mapped[Optional[str]] = mapped_column(db.Text)
|
||||
elapsed_time: Mapped[float] = mapped_column(db.Float, nullable=False, server_default=sa.text("0"))
|
||||
total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
|
||||
total_steps: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"))
|
||||
total_steps: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True)
|
||||
created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
|
||||
exceptions_count: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"))
|
||||
exceptions_count: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True)
|
||||
|
||||
@property
|
||||
def created_by_account(self):
|
||||
|
|
@ -578,19 +566,7 @@ class WorkflowNodeExecutionTriggeredFrom(StrEnum):
|
|||
RAG_PIPELINE_RUN = "rag-pipeline-run"
|
||||
|
||||
|
||||
class WorkflowNodeExecutionStatus(StrEnum):
|
||||
"""
|
||||
Workflow Node Execution Status Enum
|
||||
"""
|
||||
|
||||
RUNNING = "running"
|
||||
SUCCEEDED = "succeeded"
|
||||
FAILED = "failed"
|
||||
EXCEPTION = "exception"
|
||||
RETRY = "retry"
|
||||
|
||||
|
||||
class WorkflowNodeExecution(Base):
|
||||
class WorkflowNodeExecutionModel(Base):
|
||||
"""
|
||||
Workflow Node Execution
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ dependencies = [
|
|||
"chardet~=5.1.0",
|
||||
"flask~=3.1.0",
|
||||
"flask-compress~=1.17",
|
||||
"flask-cors~=5.0.0",
|
||||
"flask-cors~=6.0.0",
|
||||
"flask-login~=0.6.3",
|
||||
"flask-migrate~=4.0.7",
|
||||
"flask-restful~=0.3.10",
|
||||
|
|
@ -36,7 +36,6 @@ dependencies = [
|
|||
"mailchimp-transactional~=1.0.50",
|
||||
"markdown~=3.5.1",
|
||||
"numpy~=1.26.4",
|
||||
"oci~=2.135.1",
|
||||
"openai~=1.61.0",
|
||||
"openpyxl~=3.1.5",
|
||||
"opik~=1.7.25",
|
||||
|
|
@ -143,13 +142,16 @@ dev = [
|
|||
"types-requests~=2.32.0",
|
||||
"types-requests-oauthlib~=2.0.0",
|
||||
"types-shapely~=2.0.0",
|
||||
"types-simplejson~=3.20.0",
|
||||
"types-six~=1.17.0",
|
||||
"types-tensorflow~=2.18.0",
|
||||
"types-tqdm~=4.67.0",
|
||||
"types-ujson~=5.10.0",
|
||||
"types-simplejson>=3.20.0",
|
||||
"types-six>=1.17.0",
|
||||
"types-tensorflow>=2.18.0",
|
||||
"types-tqdm>=4.67.0",
|
||||
"types-ujson>=5.10.0",
|
||||
"boto3-stubs>=1.38.20",
|
||||
"types-jmespath>=1.0.2.20240106",
|
||||
"types_pyOpenSSL>=24.1.0",
|
||||
"types_cffi>=1.17.0",
|
||||
"types_setuptools>=80.9.0",
|
||||
]
|
||||
|
||||
############################################################
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
[pytest]
|
||||
continue-on-collection-errors = true
|
||||
addopts = --cov=./api --cov-report=json --cov-report=xml
|
||||
env =
|
||||
ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
|
||||
|
|
|
|||
|
|
@ -34,9 +34,8 @@ def clean_messages():
|
|||
while True:
|
||||
try:
|
||||
# Main query with join and filter
|
||||
# FIXME:for mypy no paginate method error
|
||||
messages = (
|
||||
db.session.query(Message) # type: ignore
|
||||
db.session.query(Message)
|
||||
.filter(Message.created_at < plan_sandbox_clean_message_day)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(100)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from extensions.ext_database import db
|
|||
from extensions.ext_storage import storage
|
||||
from models.account import Tenant
|
||||
from models.model import App, Conversation, Message
|
||||
from models.workflow import WorkflowNodeExecution, WorkflowRun
|
||||
from models.workflow import WorkflowNodeExecutionModel, WorkflowRun
|
||||
from services.billing_service import BillingService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -108,10 +108,11 @@ class ClearFreePlanTenantExpiredLogs:
|
|||
while True:
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
workflow_node_executions = (
|
||||
session.query(WorkflowNodeExecution)
|
||||
session.query(WorkflowNodeExecutionModel)
|
||||
.filter(
|
||||
WorkflowNodeExecution.tenant_id == tenant_id,
|
||||
WorkflowNodeExecution.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
|
||||
WorkflowNodeExecutionModel.tenant_id == tenant_id,
|
||||
WorkflowNodeExecutionModel.created_at
|
||||
< datetime.datetime.now() - datetime.timedelta(days=days),
|
||||
)
|
||||
.limit(batch)
|
||||
.all()
|
||||
|
|
@ -135,8 +136,8 @@ class ClearFreePlanTenantExpiredLogs:
|
|||
]
|
||||
|
||||
# delete workflow node executions
|
||||
session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id.in_(workflow_node_execution_ids),
|
||||
session.query(WorkflowNodeExecutionModel).filter(
|
||||
WorkflowNodeExecutionModel.id.in_(workflow_node_execution_ids),
|
||||
).delete(synchronize_session=False)
|
||||
session.commit()
|
||||
|
||||
|
|
|
|||
|
|
@ -2,8 +2,11 @@ import logging
|
|||
import time
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
from core.model_runtime.entities import LLMMode
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
|
|
@ -34,7 +37,29 @@ class HitTestingService:
|
|||
# get retrieval model , if the model is not setting , using default
|
||||
if not retrieval_model:
|
||||
retrieval_model = dataset.retrieval_model or default_retrieval_model
|
||||
document_ids_filter = None
|
||||
metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {})
|
||||
if metadata_filtering_conditions:
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
|
||||
from core.app.app_config.entities import MetadataFilteringCondition
|
||||
|
||||
metadata_filtering_conditions = MetadataFilteringCondition(**metadata_filtering_conditions)
|
||||
|
||||
metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
|
||||
dataset_ids=[dataset.id],
|
||||
query=query,
|
||||
metadata_filtering_mode="manual",
|
||||
metadata_filtering_conditions=metadata_filtering_conditions,
|
||||
inputs={},
|
||||
tenant_id="",
|
||||
user_id="",
|
||||
metadata_model_config=ModelConfig(provider="", name="", mode=LLMMode.CHAT, completion_params={}),
|
||||
)
|
||||
if metadata_filter_document_ids:
|
||||
document_ids_filter = metadata_filter_document_ids.get(dataset.id, [])
|
||||
if metadata_condition and not document_ids_filter:
|
||||
return cls.compact_retrieve_response(query, [])
|
||||
all_documents = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
|
||||
dataset_id=dataset.id,
|
||||
|
|
@ -48,6 +73,7 @@ class HitTestingService:
|
|||
else None,
|
||||
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
|
||||
weights=retrieval_model.get("weights", None),
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
|
||||
end = time.perf_counter()
|
||||
|
|
@ -99,7 +125,7 @@ class HitTestingService:
|
|||
return dict(cls.compact_external_retrieve_response(dataset, query, all_documents))
|
||||
|
||||
@classmethod
|
||||
def compact_retrieve_response(cls, query: str, documents: list[Document]):
|
||||
def compact_retrieve_response(cls, query: str, documents: list[Document]) -> dict[Any, Any]:
|
||||
records = RetrievalService.format_retrieval_documents(documents)
|
||||
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.ops.entities.config_entity import BaseTracingConfig
|
||||
from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, TraceAppConfig
|
||||
|
|
@ -92,13 +93,12 @@ class OpsService:
|
|||
except KeyError:
|
||||
return {"error": f"Invalid tracing provider: {tracing_provider}"}
|
||||
|
||||
config_class, other_keys = (
|
||||
provider_config_map[tracing_provider]["config_class"],
|
||||
provider_config_map[tracing_provider]["other_keys"],
|
||||
)
|
||||
# FIXME: ignore type error
|
||||
default_config_instance = config_class(**tracing_config) # type: ignore
|
||||
for key in other_keys: # type: ignore
|
||||
provider_config: dict[str, Any] = provider_config_map[tracing_provider]
|
||||
config_class: type[BaseTracingConfig] = provider_config["config_class"]
|
||||
other_keys: list[str] = provider_config["other_keys"]
|
||||
|
||||
default_config_instance: BaseTracingConfig = config_class(**tracing_config)
|
||||
for key in other_keys:
|
||||
if key in tracing_config and tracing_config[key] == "":
|
||||
tracing_config[key] = getattr(default_config_instance, key, None)
|
||||
|
||||
|
|
|
|||
|
|
@ -44,6 +44,17 @@ class TagService:
|
|||
results = [tag_binding.target_id for tag_binding in tag_bindings]
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str) -> list:
|
||||
tags = (
|
||||
db.session.query(Tag)
|
||||
.filter(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
|
||||
.all()
|
||||
)
|
||||
if not tags:
|
||||
return []
|
||||
return tags
|
||||
|
||||
@staticmethod
|
||||
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
|
||||
tags = (
|
||||
|
|
@ -62,6 +73,8 @@ class TagService:
|
|||
|
||||
@staticmethod
|
||||
def save_tags(args: dict) -> Tag:
|
||||
if TagService.get_tag_by_tag_name(args["type"], current_user.current_tenant_id, args["name"]):
|
||||
raise ValueError("Tag name already exists")
|
||||
tag = Tag(
|
||||
id=str(uuid.uuid4()),
|
||||
name=args["name"],
|
||||
|
|
@ -75,6 +88,8 @@ class TagService:
|
|||
|
||||
@staticmethod
|
||||
def update_tags(args: dict, tag_id: str) -> Tag:
|
||||
if TagService.get_tag_by_tag_name(args["type"], current_user.current_tenant_id, args["name"]):
|
||||
raise ValueError("Tag name already exists")
|
||||
tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
|
||||
if not tag:
|
||||
raise NotFound("Tag not found")
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue