diff --git a/api/constants/model_template.py b/api/constants/model_template.py index 61aab64d8a..c8aaba23cb 100644 --- a/api/constants/model_template.py +++ b/api/constants/model_template.py @@ -7,8 +7,7 @@ default_app_templates = { 'mode': AppMode.WORKFLOW.value, 'enable_site': True, 'enable_api': True - }, - 'model_config': {} + } }, # chat default mode @@ -34,14 +33,6 @@ default_app_templates = { 'mode': AppMode.ADVANCED_CHAT.value, 'enable_site': True, 'enable_api': True - }, - 'model_config': { - 'model': { - "provider": "openai", - "name": "gpt-4", - "mode": "chat", - "completion_params": {} - } } }, diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 4fcf8daf6e..54585d8519 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -41,10 +41,16 @@ class DraftWorkflowApi(Resource): """ parser = reqparse.RequestParser() parser.add_argument('graph', type=dict, required=True, nullable=False, location='json') + parser.add_argument('features', type=dict, required=True, nullable=False, location='json') args = parser.parse_args() workflow_service = WorkflowService() - workflow_service.sync_draft_workflow(app_model=app_model, graph=args.get('graph'), account=current_user) + workflow_service.sync_draft_workflow( + app_model=app_model, + graph=args.get('graph'), + features=args.get('features'), + account=current_user + ) return { "result": "success" diff --git a/api/core/app/chat/app_runner.py b/api/core/app/chat/app_runner.py index a1eccab13a..4c8018572e 100644 --- a/api/core/app/chat/app_runner.py +++ b/api/core/app/chat/app_runner.py @@ -1,21 +1,17 @@ import logging -from typing import Optional from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.base_app_runner import AppRunner from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.application_entities import ( ApplicationGenerateEntity, - DatasetEntity, - InvokeFrom, - ModelConfigEntity, ) from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationException from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db -from models.model import App, AppMode, Conversation, Message +from models.model import App, Conversation, Message logger = logging.getLogger(__name__) @@ -145,18 +141,23 @@ class ChatAppRunner(AppRunner): # get context from datasets context = None if app_orchestration_config.dataset and app_orchestration_config.dataset.dataset_ids: - context = self.retrieve_dataset_context( + hit_callback = DatasetIndexToolCallbackHandler( + queue_manager, + app_record.id, + message.id, + application_generate_entity.user_id, + application_generate_entity.invoke_from + ) + + dataset_retrieval = DatasetRetrieval() + context = dataset_retrieval.retrieve( tenant_id=app_record.tenant_id, - app_record=app_record, - queue_manager=queue_manager, model_config=app_orchestration_config.model_config, - show_retrieve_source=app_orchestration_config.show_retrieve_source, - dataset_config=app_orchestration_config.dataset, - message=message, - inputs=inputs, + config=app_orchestration_config.dataset, query=query, - user_id=application_generate_entity.user_id, invoke_from=application_generate_entity.invoke_from, + show_retrieve_source=app_orchestration_config.show_retrieve_source, + hit_callback=hit_callback, memory=memory ) @@ -212,57 +213,3 @@ class ChatAppRunner(AppRunner): queue_manager=queue_manager, stream=application_generate_entity.stream ) - - def retrieve_dataset_context(self, tenant_id: str, - app_record: App, - queue_manager: AppQueueManager, - model_config: ModelConfigEntity, - dataset_config: DatasetEntity, - show_retrieve_source: bool, - message: Message, - inputs: dict, - query: str, - user_id: str, - invoke_from: InvokeFrom, - memory: Optional[TokenBufferMemory] = None) -> Optional[str]: - """ - Retrieve dataset context - :param tenant_id: tenant id - :param app_record: app record - :param queue_manager: queue manager - :param model_config: model config - :param dataset_config: dataset config - :param show_retrieve_source: show retrieve source - :param message: message - :param inputs: inputs - :param query: query - :param user_id: user id - :param invoke_from: invoke from - :param memory: memory - :return: - """ - hit_callback = DatasetIndexToolCallbackHandler( - queue_manager, - app_record.id, - message.id, - user_id, - invoke_from - ) - - # TODO - if (app_record.mode == AppMode.COMPLETION.value and dataset_config - and dataset_config.retrieve_config.query_variable): - query = inputs.get(dataset_config.retrieve_config.query_variable, "") - - dataset_retrieval = DatasetRetrieval() - return dataset_retrieval.retrieve( - tenant_id=tenant_id, - model_config=model_config, - config=dataset_config, - query=query, - invoke_from=invoke_from, - show_retrieve_source=show_retrieve_source, - hit_callback=hit_callback, - memory=memory - ) - \ No newline at end of file diff --git a/api/core/app/completion/app_runner.py b/api/core/app/completion/app_runner.py index 3ac182b34e..ab2f40ad9a 100644 --- a/api/core/app/completion/app_runner.py +++ b/api/core/app/completion/app_runner.py @@ -1,21 +1,16 @@ import logging -from typing import Optional -from core.app.app_queue_manager import AppQueueManager, PublishFrom +from core.app.app_queue_manager import AppQueueManager from core.app.base_app_runner import AppRunner from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.application_entities import ( ApplicationGenerateEntity, - DatasetEntity, - InvokeFrom, - ModelConfigEntity, ) -from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationException from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db -from models.model import App, AppMode, Conversation, Message +from models.model import App, Message logger = logging.getLogger(__name__) @@ -27,13 +22,11 @@ class CompletionAppRunner(AppRunner): def run(self, application_generate_entity: ApplicationGenerateEntity, queue_manager: AppQueueManager, - conversation: Conversation, message: Message) -> None: """ Run application :param application_generate_entity: application generate entity :param queue_manager: application queue manager - :param conversation: conversation :param message: message :return: """ @@ -61,30 +54,15 @@ class CompletionAppRunner(AppRunner): query=query ) - memory = None - if application_generate_entity.conversation_id: - # get memory of conversation (read-only) - model_instance = ModelInstance( - provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, - model=app_orchestration_config.model_config.model - ) - - memory = TokenBufferMemory( - conversation=conversation, - model_instance=model_instance - ) - # organize all inputs and template to prompt messages # Include: prompt template, inputs, query(optional), files(optional) - # memory(optional) prompt_messages, stop = self.organize_prompt_messages( app_record=app_record, model_config=app_orchestration_config.model_config, prompt_template_entity=app_orchestration_config.prompt_template, inputs=inputs, files=files, - query=query, - memory=memory + query=query ) # moderation @@ -107,30 +85,6 @@ class CompletionAppRunner(AppRunner): ) return - if query: - # annotation reply - annotation_reply = self.query_app_annotations_to_reply( - app_record=app_record, - message=message, - query=query, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from - ) - - if annotation_reply: - queue_manager.publish_annotation_reply( - message_annotation_id=annotation_reply.id, - pub_from=PublishFrom.APPLICATION_MANAGER - ) - self.direct_output( - queue_manager=queue_manager, - app_orchestration_config=app_orchestration_config, - prompt_messages=prompt_messages, - text=annotation_reply.content, - stream=application_generate_entity.stream - ) - return - # fill in variable inputs from external data tools if exists external_data_tools = app_orchestration_config.external_data_variables if external_data_tools: @@ -145,19 +99,27 @@ class CompletionAppRunner(AppRunner): # get context from datasets context = None if app_orchestration_config.dataset and app_orchestration_config.dataset.dataset_ids: - context = self.retrieve_dataset_context( + hit_callback = DatasetIndexToolCallbackHandler( + queue_manager, + app_record.id, + message.id, + application_generate_entity.user_id, + application_generate_entity.invoke_from + ) + + dataset_config = app_orchestration_config.dataset + if dataset_config and dataset_config.retrieve_config.query_variable: + query = inputs.get(dataset_config.retrieve_config.query_variable, "") + + dataset_retrieval = DatasetRetrieval() + context = dataset_retrieval.retrieve( tenant_id=app_record.tenant_id, - app_record=app_record, - queue_manager=queue_manager, model_config=app_orchestration_config.model_config, - show_retrieve_source=app_orchestration_config.show_retrieve_source, - dataset_config=app_orchestration_config.dataset, - message=message, - inputs=inputs, + config=dataset_config, query=query, - user_id=application_generate_entity.user_id, invoke_from=application_generate_entity.invoke_from, - memory=memory + show_retrieve_source=app_orchestration_config.show_retrieve_source, + hit_callback=hit_callback ) # reorganize all inputs and template to prompt messages @@ -170,8 +132,7 @@ class CompletionAppRunner(AppRunner): inputs=inputs, files=files, query=query, - context=context, - memory=memory + context=context ) # check hosting moderation @@ -210,57 +171,4 @@ class CompletionAppRunner(AppRunner): queue_manager=queue_manager, stream=application_generate_entity.stream ) - - def retrieve_dataset_context(self, tenant_id: str, - app_record: App, - queue_manager: AppQueueManager, - model_config: ModelConfigEntity, - dataset_config: DatasetEntity, - show_retrieve_source: bool, - message: Message, - inputs: dict, - query: str, - user_id: str, - invoke_from: InvokeFrom, - memory: Optional[TokenBufferMemory] = None) -> Optional[str]: - """ - Retrieve dataset context - :param tenant_id: tenant id - :param app_record: app record - :param queue_manager: queue manager - :param model_config: model config - :param dataset_config: dataset config - :param show_retrieve_source: show retrieve source - :param message: message - :param inputs: inputs - :param query: query - :param user_id: user id - :param invoke_from: invoke from - :param memory: memory - :return: - """ - hit_callback = DatasetIndexToolCallbackHandler( - queue_manager, - app_record.id, - message.id, - user_id, - invoke_from - ) - - # TODO - if (app_record.mode == AppMode.COMPLETION.value and dataset_config - and dataset_config.retrieve_config.query_variable): - query = inputs.get(dataset_config.retrieve_config.query_variable, "") - - dataset_retrieval = DatasetRetrieval() - return dataset_retrieval.retrieve( - tenant_id=tenant_id, - model_config=model_config, - config=dataset_config, - query=query, - invoke_from=invoke_from, - show_retrieve_source=show_retrieve_source, - hit_callback=hit_callback, - memory=memory - ) \ No newline at end of file diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index decdc0567f..bcb2c318c6 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,5 +1,3 @@ -import json - from flask_restful import fields from fields.member_fields import simple_account_fields @@ -7,7 +5,8 @@ from libs.helper import TimestampField workflow_fields = { 'id': fields.String, - 'graph': fields.Raw(attribute=lambda x: json.loads(x.graph) if hasattr(x, 'graph') else None), + 'graph': fields.Nested(simple_account_fields, attribute='graph_dict'), + 'features': fields.Nested(simple_account_fields, attribute='features_dict'), 'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'), 'created_at': TimestampField, 'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True), diff --git a/api/migrations/versions/b289e2408ee2_add_workflow.py b/api/migrations/versions/b289e2408ee2_add_workflow.py index 5f7ddc7d68..5ae1e65611 100644 --- a/api/migrations/versions/b289e2408ee2_add_workflow.py +++ b/api/migrations/versions/b289e2408ee2_add_workflow.py @@ -97,6 +97,7 @@ def upgrade(): sa.Column('type', sa.String(length=255), nullable=False), sa.Column('version', sa.String(length=255), nullable=False), sa.Column('graph', sa.Text(), nullable=True), + sa.Column('features', sa.Text(), nullable=True), sa.Column('created_by', postgresql.UUID(), nullable=False), sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), sa.Column('updated_by', postgresql.UUID(), nullable=True), @@ -106,7 +107,7 @@ def upgrade(): with op.batch_alter_table('workflows', schema=None) as batch_op: batch_op.create_index('workflow_version_idx', ['tenant_id', 'app_id', 'version'], unique=False) - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + with op.batch_alter_table('apps', schema=None) as batch_op: batch_op.add_column(sa.Column('workflow_id', postgresql.UUID(), nullable=True)) with op.batch_alter_table('messages', schema=None) as batch_op: @@ -120,7 +121,7 @@ def downgrade(): with op.batch_alter_table('messages', schema=None) as batch_op: batch_op.drop_column('workflow_run_id') - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + with op.batch_alter_table('apps', schema=None) as batch_op: batch_op.drop_column('workflow_id') with op.batch_alter_table('workflows', schema=None) as batch_op: diff --git a/api/models/model.py b/api/models/model.py index 6708898b51..b8723dd443 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -63,6 +63,7 @@ class App(db.Model): icon = db.Column(db.String(255)) icon_background = db.Column(db.String(255)) app_model_config_id = db.Column(UUID, nullable=True) + workflow_id = db.Column(UUID, nullable=True) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) enable_site = db.Column(db.Boolean, nullable=False) enable_api = db.Column(db.Boolean, nullable=False) @@ -85,6 +86,14 @@ class App(db.Model): AppModelConfig.id == self.app_model_config_id).first() return app_model_config + @property + def workflow(self): + if self.workflow_id: + from api.models.workflow import Workflow + return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first() + + return None + @property def api_base_url(self): return (current_app.config['SERVICE_API_URL'] if current_app.config['SERVICE_API_URL'] @@ -176,7 +185,6 @@ class AppModelConfig(db.Model): dataset_configs = db.Column(db.Text) external_data_tools = db.Column(db.Text) file_upload = db.Column(db.Text) - workflow_id = db.Column(UUID) @property def app(self): @@ -276,14 +284,6 @@ class AppModelConfig(db.Model): "image": {"enabled": False, "number_limits": 3, "detail": "high", "transfer_methods": ["remote_url", "local_file"]}} - @property - def workflow(self): - if self.workflow_id: - from api.models.workflow import Workflow - return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first() - - return None - def to_dict(self) -> dict: return { "opening_statement": self.opening_statement, @@ -343,7 +343,6 @@ class AppModelConfig(db.Model): if model_config.get('dataset_configs') else None self.file_upload = json.dumps(model_config.get('file_upload')) \ if model_config.get('file_upload') else None - self.workflow_id = model_config.get('workflow_id') return self def copy(self): @@ -368,8 +367,7 @@ class AppModelConfig(db.Model): chat_prompt_config=self.chat_prompt_config, completion_prompt_config=self.completion_prompt_config, dataset_configs=self.dataset_configs, - file_upload=self.file_upload, - workflow_id=self.workflow_id + file_upload=self.file_upload ) return new_app_model_config diff --git a/api/models/workflow.py b/api/models/workflow.py index 316d3e623e..c38c1dd610 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,3 +1,4 @@ +import json from enum import Enum from typing import Union @@ -106,6 +107,7 @@ class Workflow(db.Model): type = db.Column(db.String(255), nullable=False) version = db.Column(db.String(255), nullable=False) graph = db.Column(db.Text) + features = db.Column(db.Text) created_by = db.Column(UUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_by = db.Column(UUID) @@ -119,6 +121,14 @@ class Workflow(db.Model): def updated_by_account(self): return Account.query.get(self.updated_by) + @property + def graph_dict(self): + return self.graph if not self.graph else json.loads(self.graph) + + @property + def features_dict(self): + return self.features if not self.features else json.loads(self.features) + class WorkflowRunTriggeredFrom(Enum): """ diff --git a/api/services/app_service.py b/api/services/app_service.py index 374727d2d4..7dd5d770ea 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -64,8 +64,8 @@ class AppService: app_template = default_app_templates[app_mode] # get model config - default_model_config = app_template['model_config'] - if 'model' in default_model_config: + default_model_config = app_template.get('model_config') + if default_model_config and 'model' in default_model_config: # get model provider model_manager = ModelManager() @@ -110,12 +110,15 @@ class AppService: db.session.add(app) db.session.flush() - app_model_config = AppModelConfig(**default_model_config) - app_model_config.app_id = app.id - db.session.add(app_model_config) - db.session.flush() + if default_model_config: + app_model_config = AppModelConfig(**default_model_config) + app_model_config.app_id = app.id + db.session.add(app_model_config) + db.session.flush() - app.app_model_config_id = app_model_config.id + app.app_model_config_id = app_model_config.id + + db.session.commit() app_was_created.send(app, account=account) @@ -135,16 +138,22 @@ class AppService: app_data = import_data.get('app') model_config_data = import_data.get('model_config') - workflow_graph = import_data.get('workflow_graph') + workflow = import_data.get('workflow') - if not app_data or not model_config_data: - raise ValueError("Missing app or model_config in data argument") + if not app_data: + raise ValueError("Missing app in data argument") app_mode = AppMode.value_of(app_data.get('mode')) if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: - if not workflow_graph: - raise ValueError("Missing workflow_graph in data argument " - "when mode is advanced-chat or workflow") + if not workflow: + raise ValueError("Missing workflow in data argument " + "when app mode is advanced-chat or workflow") + elif app_mode in [AppMode.CHAT, AppMode.AGENT_CHAT]: + if not model_config_data: + raise ValueError("Missing model_config in data argument " + "when app mode is chat or agent-chat") + else: + raise ValueError("Invalid app mode") app = App( tenant_id=tenant_id, @@ -161,26 +170,32 @@ class AppService: db.session.add(app) db.session.commit() - if workflow_graph: - # init draft workflow - workflow_service = WorkflowService() - workflow_service.sync_draft_workflow(app, workflow_graph, account) - - app_model_config = AppModelConfig() - app_model_config = app_model_config.from_model_config_dict(model_config_data) - app_model_config.app_id = app.id - - db.session.add(app_model_config) - db.session.commit() - - app.app_model_config_id = app_model_config.id - app_was_created.send(app, account=account) - app_model_config_was_updated.send( - app, - app_model_config=app_model_config - ) + if workflow: + # init draft workflow + workflow_service = WorkflowService() + workflow_service.sync_draft_workflow( + app_model=app, + graph=workflow.get('graph'), + features=workflow.get('features'), + account=account + ) + + if model_config_data: + app_model_config = AppModelConfig() + app_model_config = app_model_config.from_model_config_dict(model_config_data) + app_model_config.app_id = app.id + + db.session.add(app_model_config) + db.session.commit() + + app.app_model_config_id = app_model_config.id + + app_model_config_was_updated.send( + app, + app_model_config=app_model_config + ) return app @@ -190,7 +205,7 @@ class AppService: :param app: App instance :return: """ - app_model_config = app.app_model_config + app_mode = AppMode.value_of(app.mode) export_data = { "app": { @@ -198,16 +213,27 @@ class AppService: "mode": app.mode, "icon": app.icon, "icon_background": app.icon_background - }, - "model_config": app_model_config.to_dict(), + } } - if app_model_config.workflow_id: - export_data['workflow_graph'] = json.loads(app_model_config.workflow.graph) + if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: + if app.workflow_id: + workflow = app.workflow + export_data['workflow'] = { + "graph": workflow.graph_dict, + "features": workflow.features_dict + } + else: + workflow_service = WorkflowService() + workflow = workflow_service.get_draft_workflow(app) + export_data['workflow'] = { + "graph": workflow.graph_dict, + "features": workflow.features_dict + } else: - workflow_service = WorkflowService() - workflow = workflow_service.get_draft_workflow(app) - export_data['workflow_graph'] = json.loads(workflow.graph) + app_model_config = app.app_model_config + + export_data['model_config'] = app_model_config.to_dict() return yaml.dump(export_data) diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index f384855e7a..6c0182dd9e 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -44,13 +44,10 @@ class WorkflowConverter: :param account: Account :return: new App instance """ - # get original app config - app_model_config = app_model.app_model_config - # convert app model config workflow = self.convert_app_model_config_to_workflow( app_model=app_model, - app_model_config=app_model_config, + app_model_config=app_model.app_model_config, account_id=account.id ) @@ -58,8 +55,9 @@ class WorkflowConverter: new_app = App() new_app.tenant_id = app_model.tenant_id new_app.name = app_model.name + '(workflow)' - new_app.mode = AppMode.CHAT.value \ + new_app.mode = AppMode.ADVANCED_CHAT.value \ if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value + new_app.workflow_id = workflow.id new_app.icon = app_model.icon new_app.icon_background = app_model.icon_background new_app.enable_site = app_model.enable_site @@ -69,28 +67,6 @@ class WorkflowConverter: new_app.is_demo = False new_app.is_public = app_model.is_public db.session.add(new_app) - db.session.flush() - - # create new app model config record - new_app_model_config = app_model_config.copy() - new_app_model_config.id = None - new_app_model_config.app_id = new_app.id - new_app_model_config.external_data_tools = '' - new_app_model_config.model = '' - new_app_model_config.user_input_form = '' - new_app_model_config.dataset_query_variable = None - new_app_model_config.pre_prompt = None - new_app_model_config.agent_mode = '' - new_app_model_config.prompt_type = 'simple' - new_app_model_config.chat_prompt_config = '' - new_app_model_config.completion_prompt_config = '' - new_app_model_config.dataset_configs = '' - new_app_model_config.workflow_id = workflow.id - - db.session.add(new_app_model_config) - db.session.flush() - - new_app.app_model_config_id = new_app_model_config.id db.session.commit() app_was_created.send(new_app, account=account) @@ -110,11 +86,13 @@ class WorkflowConverter: # get new app mode new_app_mode = self._get_new_app_mode(app_model) + app_model_config_dict = app_model_config.to_dict() + # convert app model config application_manager = AppManager() app_orchestration_config_entity = application_manager.convert_from_app_model_config_dict( tenant_id=app_model.tenant_id, - app_model_config_dict=app_model_config.to_dict(), + app_model_config_dict=app_model_config_dict, skip_check=True ) @@ -177,6 +155,25 @@ class WorkflowConverter: graph = self._append_node(graph, end_node) + # features + if new_app_mode == AppMode.ADVANCED_CHAT: + features = { + "opening_statement": app_model_config_dict.get("opening_statement"), + "suggested_questions": app_model_config_dict.get("suggested_questions"), + "suggested_questions_after_answer": app_model_config_dict.get("suggested_questions_after_answer"), + "speech_to_text": app_model_config_dict.get("speech_to_text"), + "text_to_speech": app_model_config_dict.get("text_to_speech"), + "file_upload": app_model_config_dict.get("file_upload"), + "sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"), + "retriever_resource": app_model_config_dict.get("retriever_resource"), + } + else: + features = { + "text_to_speech": app_model_config_dict.get("text_to_speech"), + "file_upload": app_model_config_dict.get("file_upload"), + "sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"), + } + # create workflow record workflow = Workflow( tenant_id=app_model.tenant_id, @@ -184,6 +181,7 @@ class WorkflowConverter: type=WorkflowType.from_app_mode(new_app_mode).value, version='draft', graph=json.dumps(graph), + features=json.dumps(features), created_by=account_id, created_at=app_model_config.created_at ) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 5a9234c70a..006bc44e41 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -33,29 +33,31 @@ class WorkflowService: """ Get published workflow """ - app_model_config = app_model.app_model_config - - if not app_model_config.workflow_id: + if not app_model.workflow_id: return None # fetch published workflow by workflow_id workflow = db.session.query(Workflow).filter( Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, - Workflow.id == app_model_config.workflow_id + Workflow.id == app_model.workflow_id ).first() # return published workflow return workflow - - def sync_draft_workflow(self, app_model: App, graph: dict, account: Account) -> Workflow: + def sync_draft_workflow(self, app_model: App, + graph: dict, + features: dict, + account: Account) -> Workflow: """ Sync draft workflow """ # fetch draft workflow by app_model workflow = self.get_draft_workflow(app_model=app_model) + # TODO validate features + # create draft workflow if not found if not workflow: workflow = Workflow( @@ -64,12 +66,14 @@ class WorkflowService: type=WorkflowType.from_app_mode(app_model.mode).value, version='draft', graph=json.dumps(graph), + features=json.dumps(features), created_by=account.id ) db.session.add(workflow) # update draft workflow if found else: workflow.graph = json.dumps(graph) + workflow.features = json.dumps(features) workflow.updated_by = account.id workflow.updated_at = datetime.utcnow() @@ -112,28 +116,7 @@ class WorkflowService: db.session.add(workflow) db.session.commit() - app_model_config = app_model.app_model_config - - # create new app model config record - new_app_model_config = app_model_config.copy() - new_app_model_config.id = None - new_app_model_config.app_id = app_model.id - new_app_model_config.external_data_tools = '' - new_app_model_config.model = '' - new_app_model_config.user_input_form = '' - new_app_model_config.dataset_query_variable = None - new_app_model_config.pre_prompt = None - new_app_model_config.agent_mode = '' - new_app_model_config.prompt_type = 'simple' - new_app_model_config.chat_prompt_config = '' - new_app_model_config.completion_prompt_config = '' - new_app_model_config.dataset_configs = '' - new_app_model_config.workflow_id = workflow.id - - db.session.add(new_app_model_config) - db.session.flush() - - app_model.app_model_config_id = new_app_model_config.id + app_model.workflow_id = workflow.id db.session.commit() # TODO update app related datasets