fix: merge issues

This commit is contained in:
takatost 2024-08-21 17:25:26 +08:00
parent 35be41b337
commit e34497ded1
13 changed files with 56 additions and 105 deletions

View File

@ -228,62 +228,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
message_id=message.id
)
# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
# Create conversation variables if they don't exist.
conversation_variables = [
ConversationVariable.from_variable(
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
# Convert database entities to variables.
conversation_variables = [item.to_variable() for item in conversation_variables]
session.commit()
# Increment dialogue count.
conversation.dialogue_count += 1
conversation_id = conversation.id
conversation_dialogue_count = conversation.dialogue_count
db.session.commit()
db.session.refresh(conversation)
inputs = application_generate_entity.inputs
query = application_generate_entity.query
files = application_generate_entity.files
user_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = application_generate_entity.user_id
# Create a variable pool.
system_inputs = {
SystemVariableKey.QUERY: query,
SystemVariableKey.FILES: files,
SystemVariableKey.CONVERSATION_ID: conversation_id,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
)
contexts.workflow_variable_pool.set(variable_pool)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(), # type: ignore

View File

@ -23,11 +23,11 @@ from core.moderation.base import ModerationException
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App, Conversation, EndUser, Message
from models.workflow import ConversationVariable
from models.workflow import ConversationVariable, WorkflowType
logger = logging.getLogger(__name__)
@ -124,6 +124,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
# Create conversation variables if they don't exist.
conversation_variables = [
ConversationVariable.from_variable(
app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
@ -131,16 +132,24 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
session.commit()
# Convert database entities to variables
# Convert database entities to variables.
conversation_variables = [item.to_variable() for item in conversation_variables]
session.commit()
# Increment dialogue count.
self.conversation.dialogue_count += 1
conversation_dialogue_count = self.conversation.dialogue_count
db.session.commit()
# Create a variable pool.
system_inputs = {
SystemVariable.QUERY: query,
SystemVariable.FILES: files,
SystemVariable.CONVERSATION_ID: self.conversation.id,
SystemVariable.USER_ID: user_id,
SystemVariableKey.QUERY: query,
SystemVariableKey.FILES: files,
SystemVariableKey.CONVERSATION_ID: self.conversation.id,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
}
# init variable pool
@ -159,7 +168,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_type=workflow.type,
workflow_type=WorkflowType.value_of(workflow.type),
graph=graph,
graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id,

View File

@ -41,9 +41,7 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.end.end_node import EndNode
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser

View File

@ -33,7 +33,7 @@ from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
@ -56,7 +56,7 @@ class WorkflowCycleManage:
_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: WorkflowTaskState
_workflow_system_variables: dict[SystemVariable, Any]
_workflow_system_variables: dict[SystemVariableKey, Any]
def _handle_workflow_run_start(self) -> WorkflowRun:
max_sequence = (

View File

@ -3,7 +3,7 @@ from collections.abc import Mapping, Sequence
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID, VariablePool
from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.start.entities import StartNodeData
from models.workflow import WorkflowNodeExecutionStatus

View File

@ -19,23 +19,23 @@ class VariableAssignerNode(BaseNode):
_node_data_cls: type[BaseNodeData] = VariableAssignerData
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
data = cast(VariableAssignerData, self.node_data)
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = variable_pool.get(data.assigned_variable_selector)
original_variable = self.graph_runtime_state.variable_pool.get(data.assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableAssignerNodeError('assigned variable not found')
match data.write_mode:
case WriteMode.OVER_WRITE:
income_value = variable_pool.get(data.input_variable_selector)
income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError('input value not found')
updated_variable = original_variable.model_copy(update={'value': income_value.value})
case WriteMode.APPEND:
income_value = variable_pool.get(data.input_variable_selector)
income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError('input value not found')
updated_value = original_variable.value + [income_value.value]
@ -49,11 +49,11 @@ class VariableAssignerNode(BaseNode):
raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}')
# Over write the variable.
variable_pool.add(data.assigned_variable_selector, updated_variable)
self.graph_runtime_state.variable_pool.add(data.assigned_variable_selector, updated_variable)
# TODO: Move database operation to the pipeline.
# Update conversation variable.
conversation_id = variable_pool.get(['sys', 'conversation_id'])
conversation_id = self.graph_runtime_state.variable_pool.get(['sys', 'conversation_id'])
if not conversation_id:
raise VariableAssignerNodeError('conversation_id not found')
update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)

View File

@ -1,7 +1,7 @@
"""add node_execution_id into node_executions
Revision ID: 675b5321501b
Revises: 8782057ff0dc
Revises: 2dbe42621d96
Create Date: 2024-08-12 10:54:02.259331
"""
@ -12,7 +12,7 @@ import models as models
# revision identifiers, used by Alembic.
revision = '675b5321501b'
down_revision = '8782057ff0dc'
down_revision = '2dbe42621d96'
branch_labels = None
depends_on = None

View File

@ -3,7 +3,7 @@ from unittest.mock import patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
BaseNodeEvent,
GraphRunFailedEvent,
@ -201,8 +201,8 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
)
variable_pool = VariablePool(system_variables={
SystemVariable.FILES: [],
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.FILES: [],
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={
"query": "hi"
})
@ -363,10 +363,10 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
)
variable_pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather in SF',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.QUERY: 'what\'s the weather in SF',
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={})
graph_engine = GraphEngine(
@ -521,10 +521,10 @@ def test_run_branch(mock_close, mock_remove):
)
variable_pool = VariablePool(system_variables={
SystemVariable.QUERY: 'hi',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.QUERY: 'hi',
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={
"uid": "takato"
})

View File

@ -5,7 +5,7 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
@ -57,8 +57,8 @@ def test_execute_answer():
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.FILES: [],
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.FILES: [],
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
pool.add(['start', 'weather'], 'sunny')
pool.add(['llm', 'text'], 'You are a helpful AI.')

View File

@ -4,7 +4,7 @@ from datetime import datetime, timezone
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
NodeRunStartedEvent,
@ -198,10 +198,10 @@ def test_process():
)
variable_pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather in SF',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.QUERY: 'what\'s the weather in SF',
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={})
answer_stream_processor = AnswerStreamProcessor(

View File

@ -5,7 +5,7 @@ from unittest.mock import patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunResult, UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
@ -155,10 +155,10 @@ def test_run():
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.QUERY: 'dify',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: '1'
SystemVariableKey.QUERY: 'dify',
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariableKey.USER_ID: '1'
}, user_inputs={}, environment_variables=[])
pool.add(['pe', 'list_output'], ["dify-1", "dify-2"])

View File

@ -90,8 +90,8 @@ def test_execute_answer():
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.FILES: [],
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.FILES: [],
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
pool.add(['start', 'weather'], 'sunny')
pool.add(['llm', 'text'], 'You are a helpful AI.')

View File

@ -274,7 +274,7 @@ def test_clear_array():
},
)
with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run:
with mock.patch('core.workflow.nodes.variable_assigner.node.update_conversation_variable') as mock_run:
list(node.run())
mock_run.assert_called_once()