diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index ef17dbd288..f7c35fa64e 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -4,7 +4,7 @@ from datetime import datetime from graphon.model_runtime.utils.encoders import jsonable_encoder from sqlalchemy import delete, or_, select -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from core.tools.__base.tool_provider import ToolProviderController from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity @@ -42,32 +42,43 @@ class WorkflowToolManageService: labels: list[str] | None = None, ): # check if the name is unique - existing_workflow_tool_provider = db.session.scalar( - select(WorkflowToolProvider) - .where( - WorkflowToolProvider.tenant_id == tenant_id, - # name or app_id - or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id), + existing_workflow_tool_provider: WorkflowToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + # query if the name or app_id exists + existing_workflow_tool_provider = _session.scalar( + select(WorkflowToolProvider) + .where( + WorkflowToolProvider.tenant_id == tenant_id, + # name or app_id + or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id), + ) + .limit(1) ) - .limit(1) - ) + # if the name or app_id exists raise error if existing_workflow_tool_provider is not None: raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists") - app: App | None = db.session.scalar( - select(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).limit(1) - ) + # query the app + app: App | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + app = _session.scalar(select(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).limit(1)) + # if not found raise error if app is None: raise ValueError(f"App {workflow_app_id} not found") + # query the workflow workflow: Workflow | None = app.workflow + + # if not found raise error if workflow is None: raise ValueError(f"Workflow not found for app {workflow_app_id}") + # check if workflow configuration is synced WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(workflow.graph_dict) + # create workflow tool provider workflow_tool_provider = WorkflowToolProvider( tenant_id=tenant_id, user_id=user_id, @@ -87,13 +98,15 @@ class WorkflowToolManageService: logger.warning(e, exc_info=True) raise ValueError(str(e)) - with Session(db.engine, expire_on_commit=False) as session, session.begin(): - session.add(workflow_tool_provider) + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + _session.add(workflow_tool_provider) + # keep the session open to make orm instances in the same session if labels is not None: ToolLabelManager.update_tool_labels( ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels ) + return {"result": "success"} @classmethod @@ -112,6 +125,7 @@ class WorkflowToolManageService: ): """ Update a workflow tool. + :param user_id: the user id :param tenant_id: the tenant id :param workflow_tool_id: workflow tool id @@ -187,28 +201,32 @@ class WorkflowToolManageService: def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]: """ List workflow tools. + :param user_id: the user id :param tenant_id: the tenant id :return: the list of tools """ - db_tools = db.session.scalars( - select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id) - ).all() + + providers: list[WorkflowToolProvider] = [] + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + providers = list( + _session.scalars(select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)).all() + ) # Create a mapping from provider_id to app_id - provider_id_to_app_id = {provider.id: provider.app_id for provider in db_tools} + provider_id_to_app_id = {provider.id: provider.app_id for provider in providers} tools: list[WorkflowToolProviderController] = [] - for provider in db_tools: + for provider in providers: try: tools.append(ToolTransformService.workflow_provider_to_controller(provider)) except Exception: # skip deleted tools logger.exception("Failed to load workflow tool provider %s", provider.id) - labels = ToolLabelManager.get_tools_labels([t for t in tools if isinstance(t, ToolProviderController)]) + labels = ToolLabelManager.get_tools_labels([tool for tool in tools if isinstance(tool, ToolProviderController)]) - result = [] + result: list[ToolProviderApiEntity] = [] for tool in tools: workflow_app_id = provider_id_to_app_id.get(tool.provider_id) @@ -233,17 +251,18 @@ class WorkflowToolManageService: def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str): """ Delete a workflow tool. + :param user_id: the user id :param tenant_id: the tenant id :param workflow_tool_id: the workflow tool id """ - db.session.execute( - delete(WorkflowToolProvider).where( - WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id - ) - ) - db.session.commit() + with sessionmaker(db.engine).begin() as _session: + _ = _session.execute( + delete(WorkflowToolProvider).where( + WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id + ) + ) return {"result": "success"} @@ -251,47 +270,59 @@ class WorkflowToolManageService: def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str): """ Get a workflow tool. + :param user_id: the user id :param tenant_id: the tenant id :param workflow_tool_id: the workflow tool id :return: the tool """ - db_tool: WorkflowToolProvider | None = db.session.scalar( - select(WorkflowToolProvider) - .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) - .limit(1) - ) - return cls._get_workflow_tool(tenant_id, db_tool) + + tool_provider: WorkflowToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + tool_provider = _session.scalar( + select(WorkflowToolProvider) + .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .limit(1) + ) + + return cls._get_workflow_tool(tenant_id, tool_provider) @classmethod def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str): """ Get a workflow tool. + :param user_id: the user id :param tenant_id: the tenant id :param workflow_app_id: the workflow app id :return: the tool """ - db_tool: WorkflowToolProvider | None = db.session.scalar( - select(WorkflowToolProvider) - .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id) - .limit(1) - ) - return cls._get_workflow_tool(tenant_id, db_tool) + + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + tool_provider: WorkflowToolProvider | None = _session.scalar( + select(WorkflowToolProvider) + .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id) + .limit(1) + ) + + return cls._get_workflow_tool(tenant_id, tool_provider) @classmethod def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None): """ Get a workflow tool. + :db_tool: the database tool :return: the tool """ if db_tool is None: raise ValueError("Tool not found") - workflow_app: App | None = db.session.scalar( - select(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).limit(1) - ) + workflow_app: App | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + workflow_app = _session.scalar( + select(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).limit(1) + ) if workflow_app is None: raise ValueError(f"App {db_tool.app_id} not found") @@ -331,28 +362,32 @@ class WorkflowToolManageService: def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[ToolApiEntity]: """ List workflow tool provider tools. + :param user_id: the user id :param tenant_id: the tenant id :param workflow_tool_id: the workflow tool id :return: the list of tools """ - db_tool: WorkflowToolProvider | None = db.session.scalar( - select(WorkflowToolProvider) - .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) - .limit(1) - ) - if db_tool is None: + provider: WorkflowToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + provider = _session.scalar( + select(WorkflowToolProvider) + .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .limit(1) + ) + + if provider is None: raise ValueError(f"Tool {workflow_tool_id} not found") - tool = ToolTransformService.workflow_provider_to_controller(db_tool) + tool = ToolTransformService.workflow_provider_to_controller(provider) workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id) if len(workflow_tools) == 0: raise ValueError(f"Tool {workflow_tool_id} not found") return [ ToolTransformService.convert_tool_entity_to_api_entity( - tool=tool.get_tools(db_tool.tenant_id)[0], + tool=tool.get_tools(provider.tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool), tenant_id=tenant_id, )