diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index c42620b92f..5f5fd7010c 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -8,10 +8,12 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, + InvokeFrom, ) from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent from core.moderation.base import ModerationException from core.workflow.entities.node_entities import SystemVariable +from core.workflow.nodes.base_node import UserFrom from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db from models.model import App, Conversation, Message @@ -78,6 +80,10 @@ class AdvancedChatAppRunner(AppRunner): workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.run_workflow( workflow=workflow, + user_id=application_generate_entity.user_id, + user_from=UserFrom.ACCOUNT + if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] + else UserFrom.END_USER, user_inputs=inputs, system_inputs={ SystemVariable.QUERY: query, diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 2d032fcdcb..922c3003bf 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -7,12 +7,14 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfig from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback from core.app.entities.app_invoke_entities import ( AppGenerateEntity, + InvokeFrom, WorkflowAppGenerateEntity, ) from core.app.entities.queue_entities import QueueStopEvent, QueueTextChunkEvent from core.moderation.base import ModerationException from core.moderation.input_moderation import InputModeration from core.workflow.entities.node_entities import SystemVariable +from core.workflow.nodes.base_node import UserFrom from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db from models.model import App @@ -63,6 +65,10 @@ class WorkflowAppRunner: workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.run_workflow( workflow=workflow, + user_id=application_generate_entity.user_id, + user_from=UserFrom.ACCOUNT + if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] + else UserFrom.END_USER, user_inputs=inputs, system_inputs={ SystemVariable.FILES: files diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py index 91f9ef95fe..a78bf09a53 100644 --- a/api/core/workflow/entities/workflow_entities.py +++ b/api/core/workflow/entities/workflow_entities.py @@ -2,7 +2,7 @@ from typing import Optional from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.base_node import BaseNode, UserFrom from models.workflow import Workflow, WorkflowType @@ -20,6 +20,8 @@ class WorkflowRunState: app_id: str workflow_id: str workflow_type: WorkflowType + user_id: str + user_from: UserFrom start_at: float variable_pool: VariablePool @@ -28,11 +30,17 @@ class WorkflowRunState: workflow_nodes_and_results: list[WorkflowNodeAndResult] = [] - def __init__(self, workflow: Workflow, start_at: float, variable_pool: VariablePool): + def __init__(self, workflow: Workflow, + start_at: float, + variable_pool: VariablePool, + user_id: str, + user_from: UserFrom): self.workflow_id = workflow.id self.tenant_id = workflow.tenant_id self.app_id = workflow.app_id self.workflow_type = WorkflowType.value_of(workflow.type) + self.user_id = user_id + self.user_from = user_from self.start_at = start_at self.variable_pool = variable_pool diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 6db25bea7e..a603f484ef 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from enum import Enum from typing import Optional from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback @@ -8,6 +9,26 @@ from core.workflow.entities.variable_pool import VariablePool from models.workflow import WorkflowNodeExecutionStatus +class UserFrom(Enum): + """ + User from + """ + ACCOUNT = "account" + END_USER = "end-user" + + @classmethod + def value_of(cls, value: str) -> "UserFrom": + """ + Value of + :param value: value + :return: + """ + for item in cls: + if item.value == value: + return item + raise ValueError(f"Invalid value: {value}") + + class BaseNode(ABC): _node_data_cls: type[BaseNodeData] _node_type: NodeType @@ -15,6 +36,8 @@ class BaseNode(ABC): tenant_id: str app_id: str workflow_id: str + user_id: str + user_from: UserFrom node_id: str node_data: BaseNodeData @@ -25,11 +48,15 @@ class BaseNode(ABC): def __init__(self, tenant_id: str, app_id: str, workflow_id: str, + user_id: str, + user_from: UserFrom, config: dict, callbacks: list[BaseWorkflowCallback] = None) -> None: self.tenant_id = tenant_id self.app_id = app_id self.workflow_id = workflow_id + self.user_id = user_id + self.user_from = user_from self.node_id = config.get("id") if not self.node_id: diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index d01746ceb8..0bc13cbb5a 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -6,7 +6,7 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState -from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.base_node import BaseNode, UserFrom from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode from core.workflow.nodes.end.end_node import EndNode @@ -76,12 +76,16 @@ class WorkflowEngineManager: return default_config def run_workflow(self, workflow: Workflow, + user_id: str, + user_from: UserFrom, user_inputs: dict, system_inputs: Optional[dict] = None, callbacks: list[BaseWorkflowCallback] = None) -> None: """ Run workflow :param workflow: Workflow instance + :param user_id: user id + :param user_from: user from :param user_inputs: user variables inputs :param system_inputs: system inputs, like: query, files :param callbacks: workflow callbacks @@ -113,7 +117,9 @@ class WorkflowEngineManager: variable_pool=VariablePool( system_variables=system_inputs, user_inputs=user_inputs - ) + ), + user_id=user_id, + user_from=user_from ) try: @@ -222,6 +228,8 @@ class WorkflowEngineManager: tenant_id=workflow_run_state.tenant_id, app_id=workflow_run_state.app_id, workflow_id=workflow_run_state.workflow_id, + user_id=workflow_run_state.user_id, + user_from=workflow_run_state.user_from, config=node_config, callbacks=callbacks ) @@ -267,6 +275,8 @@ class WorkflowEngineManager: tenant_id=workflow_run_state.tenant_id, app_id=workflow_run_state.app_id, workflow_id=workflow_run_state.workflow_id, + user_id=workflow_run_state.user_id, + user_from=workflow_run_state.user_from, config=target_node_config, callbacks=callbacks ) diff --git a/api/tests/unit_tests/core/workflow/__init__.py b/api/tests/unit_tests/core/workflow/__init__.py new file mode 100644 index 0000000000..e69de29bb2