diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index bcc88c848d..230ccdca15 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -261,6 +261,7 @@ class InstructionGenerateApi(Resource): instruction=args["instruction"], model_config=args["model_config"], ideal_output=args["ideal_output"], + workflow_service=WorkflowService(), ) return {"error": "incompatible parameters"}, 400 except ProviderTokenNotInitError as ex: diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 6eb91d515c..e07d0ec14e 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -2,7 +2,7 @@ import json import logging import re from collections.abc import Sequence -from typing import cast +from typing import Protocol, cast import json_repair @@ -32,10 +32,19 @@ from core.workflow.node_events import AgentLogEvent from extensions.ext_database import db from extensions.ext_storage import storage from models import App, Message, WorkflowNodeExecutionModel +from models.workflow import Workflow logger = logging.getLogger(__name__) +class WorkflowServiceInterface(Protocol): + def get_draft_workflow(self, app_model: App, workflow_id: str | None = None) -> Workflow | None: + pass + + def get_node_last_run(self, app_model: App, workflow: Workflow, node_id: str) -> WorkflowNodeExecutionModel | None: + pass + + class LLMGenerator: @classmethod def generate_conversation_name( @@ -419,18 +428,17 @@ class LLMGenerator: instruction: str, model_config: dict, ideal_output: str | None, + workflow_service: WorkflowServiceInterface, ): - from services.workflow_service import WorkflowService - session = db.session() app: App | None = session.query(App).where(App.id == flow_id).first() if not app: raise ValueError("App not found.") - workflow = WorkflowService().get_draft_workflow(app_model=app) + workflow = workflow_service.get_draft_workflow(app_model=app) if not workflow: raise ValueError("Workflow not found for the given app model.") - last_run = WorkflowService().get_node_last_run(app_model=app, workflow=workflow, node_id=node_id) + last_run = workflow_service.get_node_last_run(app_model=app, workflow=workflow, node_id=node_id) try: node_type = cast(WorkflowNodeExecutionModel, last_run).node_type except Exception: