From 1f986a3abbef7ae2cbcbdf0cd05acebeb48baeca Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 7 Mar 2024 19:45:02 +0800 Subject: [PATCH] fix bugs --- api/controllers/console/app/workflow.py | 28 ++++-- .../advanced_chat/generate_task_pipeline.py | 2 +- .../workflow_event_trigger_callback.py | 2 +- api/core/app/apps/chat/app_config_manager.py | 2 +- .../workflow_event_trigger_callback.py | 2 +- api/core/workflow/workflow_engine_manager.py | 99 +++++++++---------- .../versions/b289e2408ee2_add_workflow.py | 4 +- ...29b71023c_messages_columns_set_nullable.py | 41 ++++++++ api/models/model.py | 4 +- api/models/workflow.py | 6 +- 10 files changed, 118 insertions(+), 72 deletions(-) create mode 100644 api/migrations/versions/b5429b71023c_messages_columns_set_nullable.py diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 5d70076821..8a68cafad8 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -1,6 +1,7 @@ import json import logging from collections.abc import Generator +from typing import Union from flask import Response, stream_with_context from flask_restful import Resource, marshal_with, reqparse @@ -79,9 +80,9 @@ class AdvancedChatDraftWorkflowRunApi(Resource): Run draft workflow """ parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, location='json', default='') - parser.add_argument('files', type=list, required=False, location='json') + parser.add_argument('inputs', type=dict, location='json') + parser.add_argument('query', type=str, required=True, location='json', default='') + parser.add_argument('files', type=list, location='json') parser.add_argument('conversation_id', type=uuid_value, location='json') args = parser.parse_args() @@ -93,6 +94,8 @@ class AdvancedChatDraftWorkflowRunApi(Resource): args=args, invoke_from=InvokeFrom.DEBUGGER ) + + return compact_response(response) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.conversation.ConversationCompletedError: @@ -103,12 +106,6 @@ class AdvancedChatDraftWorkflowRunApi(Resource): logging.exception("internal server error.") raise InternalServerError() - def generate() -> Generator: - yield from response - - return Response(stream_with_context(generate()), status=200, - mimetype='text/event-stream') - class DraftWorkflowRunApi(Resource): @setup_required @@ -120,7 +117,7 @@ class DraftWorkflowRunApi(Resource): Run draft workflow """ parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') + parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') args = parser.parse_args() workflow_service = WorkflowService() @@ -280,6 +277,17 @@ class ConvertToWorkflowApi(Resource): return workflow +def compact_response(response: Union[dict, Generator]) -> Response: + if isinstance(response, dict): + return Response(response=json.dumps(response), status=200, mimetype='application/json') + else: + def generate() -> Generator: + yield from response + + return Response(stream_with_context(generate()), status=200, + mimetype='text/event-stream') + + api.add_resource(DraftWorkflowApi, '/apps//workflows/draft') api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps//advanced-chat/workflows/draft/run') api.add_resource(DraftWorkflowRunApi, '/apps//workflows/draft/run') diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 84352f16c7..624a0f430a 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -174,7 +174,7 @@ class AdvancedChatAppGenerateTaskPipeline: response = { 'event': 'workflow_started', 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': event.workflow_run_id, + 'workflow_run_id': workflow_run.id, 'data': { 'id': workflow_run.id, 'workflow_id': workflow_run.workflow_id, diff --git a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py index 44fb5905b0..5d99ce6297 100644 --- a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py +++ b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py @@ -15,7 +15,7 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): self._queue_manager = queue_manager - self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph) + self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict) def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: """ diff --git a/api/core/app/apps/chat/app_config_manager.py b/api/core/app/apps/chat/app_config_manager.py index ac69a92823..553cf34ee9 100644 --- a/api/core/app/apps/chat/app_config_manager.py +++ b/api/core/app/apps/chat/app_config_manager.py @@ -46,7 +46,7 @@ class ChatAppConfigManager(BaseAppConfigManager): else: config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG - if override_config_dict != EasyUIBasedAppModelConfigFrom.ARGS: + if config_from != EasyUIBasedAppModelConfigFrom.ARGS: app_model_config_dict = app_model_config.to_dict() config_dict = app_model_config_dict.copy() else: diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py index 57775f2cce..3d7a4035e7 100644 --- a/api/core/app/apps/workflow/workflow_event_trigger_callback.py +++ b/api/core/app/apps/workflow/workflow_event_trigger_callback.py @@ -15,7 +15,7 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): self._queue_manager = queue_manager - self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph) + self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict) def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: """ diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 5423546957..05a784c221 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -5,7 +5,7 @@ from typing import Optional, Union from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback -from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.entities.workflow_entities import WorkflowRunState from core.workflow.nodes.base_node import BaseNode @@ -122,10 +122,10 @@ class WorkflowEngineManager: if 'nodes' not in graph or 'edges' not in graph: raise ValueError('nodes or edges not found in workflow graph') - if isinstance(graph.get('nodes'), list): + if not isinstance(graph.get('nodes'), list): raise ValueError('nodes in workflow graph must be a list') - if isinstance(graph.get('edges'), list): + if not isinstance(graph.get('edges'), list): raise ValueError('edges in workflow graph must be a list') # init workflow run @@ -150,6 +150,7 @@ class WorkflowEngineManager: try: predecessor_node = None + has_entry_node = False while True: # get next node, multiple target nodes in the future next_node = self._get_next_node( @@ -161,6 +162,8 @@ class WorkflowEngineManager: if not next_node: break + has_entry_node = True + # max steps 30 reached if len(workflow_run_state.workflow_node_executions) > 30: raise ValueError('Max steps 30 reached.') @@ -182,7 +185,7 @@ class WorkflowEngineManager: predecessor_node = next_node - if not predecessor_node and not next_node: + if not has_entry_node: self._workflow_run_failed( workflow_run_state=workflow_run_state, error='Start node not found in workflow graph.', @@ -219,38 +222,31 @@ class WorkflowEngineManager: :param callbacks: workflow callbacks :return: """ - try: - db.session.begin() + max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \ + .filter(WorkflowRun.tenant_id == workflow.tenant_id) \ + .filter(WorkflowRun.app_id == workflow.app_id) \ + .scalar() or 0 + new_sequence_number = max_sequence + 1 - max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \ - .filter(WorkflowRun.tenant_id == workflow.tenant_id) \ - .filter(WorkflowRun.app_id == workflow.app_id) \ - .for_update() \ - .scalar() or 0 - new_sequence_number = max_sequence + 1 + # init workflow run + workflow_run = WorkflowRun( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + sequence_number=new_sequence_number, + workflow_id=workflow.id, + type=workflow.type, + triggered_from=triggered_from.value, + version=workflow.version, + graph=workflow.graph, + inputs=json.dumps({**user_inputs, **jsonable_encoder(system_inputs)}), + status=WorkflowRunStatus.RUNNING.value, + created_by_role=(CreatedByRole.ACCOUNT.value + if isinstance(user, Account) else CreatedByRole.END_USER.value), + created_by=user.id + ) - # init workflow run - workflow_run = WorkflowRun( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - sequence_number=new_sequence_number, - workflow_id=workflow.id, - type=workflow.type, - triggered_from=triggered_from.value, - version=workflow.version, - graph=workflow.graph, - inputs=json.dumps({**user_inputs, **system_inputs}), - status=WorkflowRunStatus.RUNNING.value, - created_by_role=(CreatedByRole.ACCOUNT.value - if isinstance(user, Account) else CreatedByRole.END_USER.value), - created_by=user.id - ) - - db.session.add(workflow_run) - db.session.commit() - except: - db.session.rollback() - raise + db.session.add(workflow_run) + db.session.commit() if callbacks: for callback in callbacks: @@ -330,7 +326,7 @@ class WorkflowEngineManager: if not predecessor_node: for node_config in nodes: - if node_config.get('type') == NodeType.START.value: + if node_config.get('data', {}).get('type', '') == NodeType.START.value: return StartNode(config=node_config) else: edges = graph.get('edges') @@ -368,7 +364,7 @@ class WorkflowEngineManager: return None # get next node - target_node = node_classes.get(NodeType.value_of(target_node_config.get('type'))) + target_node = node_classes.get(NodeType.value_of(target_node_config.get('data', {}).get('type'))) return target_node( config=target_node_config, @@ -424,17 +420,18 @@ class WorkflowEngineManager: callbacks=callbacks ) - for variable_key, variable_value in node_run_result.outputs.items(): - # append variables to variable pool recursively - self._append_variables_recursively( - variable_pool=workflow_run_state.variable_pool, - node_id=node.node_id, - variable_key_list=[variable_key], - variable_value=variable_value - ) + if node_run_result.outputs: + for variable_key, variable_value in node_run_result.outputs.items(): + # append variables to variable pool recursively + self._append_variables_recursively( + variable_pool=workflow_run_state.variable_pool, + node_id=node.node_id, + variable_key_list=[variable_key], + variable_value=variable_value + ) - if node_run_result.metadata.get('total_tokens'): - workflow_run_state.total_tokens += int(node_run_result.metadata.get('total_tokens')) + if node_run_result.metadata and node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): + workflow_run_state.total_tokens += int(node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)) return workflow_node_execution @@ -464,7 +461,6 @@ class WorkflowEngineManager: node_id=node.node_id, node_type=node.node_type.value, title=node.node_data.title, - type=node.node_type.value, status=WorkflowNodeExecutionStatus.RUNNING.value, created_by_role=workflow_run.created_by_role, created_by=workflow_run.created_by @@ -493,10 +489,11 @@ class WorkflowEngineManager: """ workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value workflow_node_execution.elapsed_time = time.perf_counter() - start_at - workflow_node_execution.inputs = json.dumps(result.inputs) - workflow_node_execution.process_data = json.dumps(result.process_data) - workflow_node_execution.outputs = json.dumps(result.outputs) - workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(result.metadata)) + workflow_node_execution.inputs = json.dumps(result.inputs) if result.inputs else None + workflow_node_execution.process_data = json.dumps(result.process_data) if result.process_data else None + workflow_node_execution.outputs = json.dumps(result.outputs) if result.outputs else None + workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(result.metadata)) \ + if result.metadata else None workflow_node_execution.finished_at = datetime.utcnow() db.session.commit() diff --git a/api/migrations/versions/b289e2408ee2_add_workflow.py b/api/migrations/versions/b289e2408ee2_add_workflow.py index cf8530dc67..8fadf2dc6c 100644 --- a/api/migrations/versions/b289e2408ee2_add_workflow.py +++ b/api/migrations/versions/b289e2408ee2_add_workflow.py @@ -45,8 +45,8 @@ def upgrade(): sa.Column('node_id', sa.String(length=255), nullable=False), sa.Column('node_type', sa.String(length=255), nullable=False), sa.Column('title', sa.String(length=255), nullable=False), - sa.Column('inputs', sa.Text(), nullable=False), - sa.Column('process_data', sa.Text(), nullable=False), + sa.Column('inputs', sa.Text(), nullable=True), + sa.Column('process_data', sa.Text(), nullable=True), sa.Column('outputs', sa.Text(), nullable=True), sa.Column('status', sa.String(length=255), nullable=False), sa.Column('error', sa.Text(), nullable=True), diff --git a/api/migrations/versions/b5429b71023c_messages_columns_set_nullable.py b/api/migrations/versions/b5429b71023c_messages_columns_set_nullable.py new file mode 100644 index 0000000000..ee81fdab28 --- /dev/null +++ b/api/migrations/versions/b5429b71023c_messages_columns_set_nullable.py @@ -0,0 +1,41 @@ +"""messages columns set nullable + +Revision ID: b5429b71023c +Revises: 42e85ed5564d +Create Date: 2024-03-07 09:52:00.846136 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = 'b5429b71023c' +down_revision = '42e85ed5564d' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.alter_column('model_provider', + existing_type=sa.VARCHAR(length=255), + nullable=True) + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=True) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=False) + batch_op.alter_column('model_provider', + existing_type=sa.VARCHAR(length=255), + nullable=False) + + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index c579c3dee8..6856c4e1b0 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -585,8 +585,8 @@ class Message(db.Model): id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) app_id = db.Column(UUID, nullable=False) - model_provider = db.Column(db.String(255), nullable=False) - model_id = db.Column(db.String(255), nullable=False) + model_provider = db.Column(db.String(255), nullable=True) + model_id = db.Column(db.String(255), nullable=True) override_model_configs = db.Column(db.Text) conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=False) inputs = db.Column(db.JSON) diff --git a/api/models/workflow.py b/api/models/workflow.py index 032134a0d1..0883d0ef13 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -138,7 +138,7 @@ class Workflow(db.Model): if 'nodes' not in graph_dict: return [] - start_node = next((node for node in graph_dict['nodes'] if node['type'] == 'start'), None) + start_node = next((node for node in graph_dict['nodes'] if node['data']['type'] == 'start'), None) if not start_node: return [] @@ -392,8 +392,8 @@ class WorkflowNodeExecution(db.Model): node_id = db.Column(db.String(255), nullable=False) node_type = db.Column(db.String(255), nullable=False) title = db.Column(db.String(255), nullable=False) - inputs = db.Column(db.Text, nullable=False) - process_data = db.Column(db.Text, nullable=False) + inputs = db.Column(db.Text) + process_data = db.Column(db.Text) outputs = db.Column(db.Text) status = db.Column(db.String(255), nullable=False) error = db.Column(db.Text)