# Mirror of https://github.com/langgenius/dify.git (synced 2026-06-26 06:41:10 +08:00).
# Source file: Python, 725 lines, 26 KiB.
import contextlib
|
|
import logging
|
|
import uuid
|
|
from collections.abc import Generator, Mapping
|
|
from enum import StrEnum
|
|
from typing import Annotated, Any
|
|
|
|
from celery import shared_task
|
|
from flask import current_app, json
|
|
from pydantic import BaseModel, Discriminator, Field, Tag
|
|
from sqlalchemy import Engine, select
|
|
from sqlalchemy.orm import Session, sessionmaker
|
|
|
|
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
|
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
|
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
|
from core.app.entities.app_invoke_entities import (
|
|
AdvancedChatAppGenerateEntity,
|
|
InvokeFrom,
|
|
WorkflowAppGenerateEntity,
|
|
)
|
|
from core.app.entities.task_entities import WorkflowFinishStreamResponse, WorkflowStartStreamResponse
|
|
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext
|
|
from core.repositories import DifyCoreRepositoryFactory
|
|
from extensions.ext_database import db
|
|
from graphon.entities import WorkflowStartReason
|
|
from graphon.enums import WorkflowExecutionStatus
|
|
from graphon.runtime import GraphRuntimeState
|
|
from libs.datetime_utils import naive_utc_now
|
|
from libs.flask_utils import set_login_user
|
|
from libs.helper import to_timestamp
|
|
from models.account import Account
|
|
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
|
|
from models.model import App, AppMode, Conversation, EndUser, Message
|
|
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom, WorkflowRun
|
|
from repositories.factory import DifyAPIRepositoryFactory
|
|
|
|
# Module-level logger shared by every helper and task in this module.
logger: logging.Logger = logging.getLogger(name=__name__)

# Celery queue that the execution and resume tasks in this module are routed to.
WORKFLOW_BASED_APP_EXECUTION_QUEUE: str = "workflow_based_app_execution"
|
|
|
|
|
|
class _UserType(StrEnum):
|
|
ACCOUNT = "account"
|
|
END_USER = "end_user"
|
|
|
|
|
|
class _Account(BaseModel):
    """Serializable reference to an ``Account`` user, embedded in ``AppExecutionParams``."""

    # Discriminator tag; a regular defaulted field, so it is part of the dumped payload
    # and can be read back by ``_get_user_type_descriminator``.
    TYPE: _UserType = Field(default=_UserType.ACCOUNT)

    # Primary key of the ``Account`` row.
    user_id: str
|
|
|
|
|
|
class _EndUser(BaseModel):
    """Serializable reference to an ``EndUser``, embedded in ``AppExecutionParams``."""

    # Discriminator tag; a regular defaulted field, so it is part of the dumped payload.
    TYPE: _UserType = Field(default=_UserType.END_USER)
    # Primary key of the ``EndUser`` row.
    end_user_id: str
|
|
|
|
|
|
def _get_user_type_descriminator(value: Any):
    """Return the ``_UserType`` tag of a ``User`` value, or ``None`` when it cannot be told.

    Used as the callable ``Discriminator`` of the ``User`` union. It accepts either an
    already-built ``_Account`` / ``_EndUser`` instance (its ``TYPE`` is returned) or a raw
    ``dict`` whose ``"TYPE"`` entry must be a valid ``_UserType`` value. Anything else,
    a missing ``"TYPE"`` key, or an unknown tag yields ``None``.
    """
    if isinstance(value, (_Account, _EndUser)):
        return value.TYPE

    if not isinstance(value, dict):
        # return None if the discriminator value isn't found
        return None

    raw_tag = value.get("TYPE")
    if raw_tag is None:
        return None

    try:
        return _UserType(raw_tag)
    except ValueError:
        # Unknown tag string: let pydantic report the union as unresolvable.
        return None
|
|
|
|
|
|
type User = Annotated[
|
|
(Annotated[_Account, Tag(_UserType.ACCOUNT)] | Annotated[_EndUser, Tag(_UserType.END_USER)]),
|
|
Discriminator(_get_user_type_descriminator),
|
|
]
|
|
|
|
|
|
class AppExecutionParams(BaseModel):
    """JSON-serializable description of a single workflow-based app execution.

    Built from ORM objects with :meth:`new` and re-validated from JSON by
    ``workflow_based_app_execution_task`` before being handed to ``_AppRunner``.
    """

    app_id: str
    workflow_id: str
    tenant_id: str
    app_mode: AppMode = AppMode.ADVANCED_CHAT
    # Lightweight reference to the triggering user; resolved back to a row by the runner.
    user: User
    args: Mapping[str, Any]

    invoke_from: InvokeFrom
    streaming: bool = True
    call_depth: int = 0
    root_node_id: str | None = None
    # Generated when not supplied; it also names the response topic used for streaming.
    workflow_run_id: str = Field(default_factory=lambda: str(uuid.uuid4()))

    @classmethod
    def new(
        cls,
        app_model: App,
        workflow: Workflow,
        user: Account | EndUser,
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: bool = True,
        call_depth: int = 0,
        root_node_id: str | None = None,
        workflow_run_id: str | None = None,
    ):
        """Build execution params from ORM objects.

        ``user`` is reduced to an ``_Account`` / ``_EndUser`` id reference so the params
        stay serializable. A falsy ``workflow_run_id`` is replaced by a fresh UUID4 string.

        Raises:
            AssertionError: if ``user`` is neither an ``Account`` nor an ``EndUser``.
        """
        if isinstance(user, Account):
            user_ref: _Account | _EndUser = _Account(user_id=user.id)
        elif isinstance(user, EndUser):
            user_ref = _EndUser(end_user_id=user.id)
        else:
            raise AssertionError("this statement should be unreachable.")

        return cls(
            app_id=app_model.id,
            workflow_id=workflow.id,
            tenant_id=app_model.tenant_id,
            app_mode=AppMode.value_of(app_model.mode),
            user=user_ref,
            args=args,
            invoke_from=invoke_from,
            streaming=streaming,
            call_depth=call_depth,
            root_node_id=root_node_id,
            workflow_run_id=workflow_run_id or str(uuid.uuid4()),
        )
|
|
|
|
|
|
class _AppRunner:
|
|
def __init__(self, session_factory: sessionmaker | Engine, exec_params: AppExecutionParams):
|
|
if isinstance(session_factory, Engine):
|
|
session_factory = sessionmaker(bind=session_factory)
|
|
self._session_factory = session_factory
|
|
self._exec_params = exec_params
|
|
|
|
@contextlib.contextmanager
|
|
def _session(self):
|
|
with self._session_factory(expire_on_commit=False) as session, session.begin():
|
|
yield session
|
|
|
|
@contextlib.contextmanager
|
|
def _setup_flask_context(self, user: Account | EndUser):
|
|
flask_app = current_app._get_current_object() # type: ignore
|
|
with flask_app.app_context():
|
|
set_login_user(user)
|
|
yield
|
|
|
|
def run(self):
|
|
exec_params = self._exec_params
|
|
with self._session() as session:
|
|
workflow = session.get(Workflow, exec_params.workflow_id)
|
|
if workflow is None:
|
|
logger.warning("Workflow %s not found for execution", exec_params.workflow_id)
|
|
return None
|
|
app = session.get(App, workflow.app_id)
|
|
if app is None:
|
|
logger.warning("App %s not found for workflow %s", workflow.app_id, exec_params.workflow_id)
|
|
return None
|
|
|
|
pause_config = PauseStateLayerConfig(
|
|
session_factory=self._session_factory,
|
|
state_owner_user_id=workflow.created_by,
|
|
)
|
|
|
|
user = self._resolve_user()
|
|
|
|
with self._setup_flask_context(user):
|
|
try:
|
|
response = self._run_app(
|
|
app=app,
|
|
workflow=workflow,
|
|
user=user,
|
|
pause_state_config=pause_config,
|
|
)
|
|
except Exception as exc:
|
|
if exec_params.streaming:
|
|
_publish_failed_workflow_terminal_events(
|
|
exc=exc,
|
|
exec_params=exec_params,
|
|
)
|
|
raise
|
|
|
|
if not exec_params.streaming:
|
|
return response
|
|
|
|
assert isinstance(response, Generator)
|
|
_publish_streaming_response(
|
|
response,
|
|
exec_params.workflow_run_id,
|
|
exec_params.app_mode,
|
|
exec_params.workflow_id,
|
|
exec_params.args.get("inputs", {}),
|
|
WorkflowStartReason.INITIAL,
|
|
)
|
|
|
|
def _run_app(
|
|
self,
|
|
*,
|
|
app: App,
|
|
workflow: Workflow,
|
|
user: Account | EndUser,
|
|
pause_state_config: PauseStateLayerConfig,
|
|
):
|
|
exec_params = self._exec_params
|
|
if exec_params.app_mode == AppMode.ADVANCED_CHAT:
|
|
return AdvancedChatAppGenerator().generate(
|
|
app_model=app,
|
|
workflow=workflow,
|
|
user=user,
|
|
args=exec_params.args,
|
|
invoke_from=exec_params.invoke_from,
|
|
streaming=exec_params.streaming,
|
|
workflow_run_id=exec_params.workflow_run_id,
|
|
pause_state_config=pause_state_config,
|
|
)
|
|
if exec_params.app_mode == AppMode.WORKFLOW:
|
|
return WorkflowAppGenerator().generate(
|
|
app_model=app,
|
|
workflow=workflow,
|
|
user=user,
|
|
args=exec_params.args,
|
|
invoke_from=exec_params.invoke_from,
|
|
streaming=exec_params.streaming,
|
|
call_depth=exec_params.call_depth,
|
|
root_node_id=exec_params.root_node_id,
|
|
workflow_run_id=exec_params.workflow_run_id,
|
|
pause_state_config=pause_state_config,
|
|
)
|
|
|
|
logger.error("Unsupported app mode for execution: %s", exec_params.app_mode)
|
|
return None
|
|
|
|
def _resolve_user(self) -> Account | EndUser:
|
|
user_params = self._exec_params.user
|
|
|
|
match user_params:
|
|
case _EndUser():
|
|
with self._session() as session:
|
|
return session.get(EndUser, user_params.end_user_id)
|
|
case _Account():
|
|
with self._session() as session:
|
|
user: Account = session.get(Account, user_params.user_id)
|
|
user.set_tenant_id(self._exec_params.tenant_id)
|
|
return user
|
|
case _:
|
|
raise AssertionError(f"user should only be _Account or _EndUser, got {type(user_params)}")
|
|
|
|
|
|
def _resolve_user_for_run(session: Session, workflow_run: WorkflowRun) -> Account | EndUser | None:
    """Load the user that created ``workflow_run``; ``None`` if that row no longer exists.

    Accounts found are scoped to the run's tenant via ``set_tenant_id``. Any role other
    than ``ACCOUNT`` is looked up as an ``EndUser``. ``CreatorUserRole(...)`` raises
    ``ValueError`` for an unrecognised role value.
    """
    creator_id = workflow_run.created_by
    if CreatorUserRole(workflow_run.created_by_role) != CreatorUserRole.ACCOUNT:
        return session.get(EndUser, creator_id)

    account = session.get(Account, creator_id)
    if account:
        account.set_tenant_id(workflow_run.tenant_id)
    return account
|
|
|
|
|
|
def _publish_failed_workflow_terminal_events(exc: Exception, exec_params: AppExecutionParams) -> None:
    """Publish a synthetic ``workflow_started`` -> ``workflow_finished(failed)`` pair.

    Used for failures that happen before the app generator has created a task entity
    or emitted any workflow queue events. SSE consumers still need a normal terminal
    event to close their state machines, so both lifecycle events are published to the
    run's response topic.

    No application task id exists yet on this path, so ``workflow_run_id`` doubles as
    the synthetic ``task_id``.

    Args:
        exc: The failure; ``str(exc)`` becomes the ``error`` of the finished event.
        exec_params: The failed execution (run id, workflow id, inputs, app mode).
    """
    run_id = exec_params.workflow_run_id
    now = to_timestamp(naive_utc_now())
    assert now is not None

    topic = MessageBasedAppGenerator.get_response_topic(exec_params.app_mode, run_id)

    def _send(event: BaseModel) -> None:
        # Stream events go out as UTF-8 JSON, non-ASCII kept as-is.
        encoded = json.dumps(event.model_dump(mode="json"), ensure_ascii=False)
        topic.publish(encoded.encode())

    _send(
        WorkflowStartStreamResponse(
            task_id=run_id,
            workflow_run_id=run_id,
            data=WorkflowStartStreamResponse.Data(
                id=run_id,
                workflow_id=exec_params.workflow_id,
                inputs=exec_params.args.get("inputs", {}),
                created_at=now,
                reason=WorkflowStartReason.INITIAL,
            ),
        )
    )

    _send(
        WorkflowFinishStreamResponse(
            task_id=run_id,
            workflow_run_id=run_id,
            data=WorkflowFinishStreamResponse.Data(
                id=run_id,
                workflow_id=exec_params.workflow_id,
                status=WorkflowExecutionStatus.FAILED,
                outputs=None,
                error=str(exc),
                elapsed_time=0.0,
                total_tokens=0,
                total_steps=0,
                created_by={},
                created_at=now,
                finished_at=now,
                exceptions_count=1,
                files=[],
            ),
        )
    )
|
|
|
|
|
|
def _get_event_name(event: str | Mapping[str, Any] | BaseModel) -> str | None:
|
|
if isinstance(event, BaseModel):
|
|
# Temporary compatibility for legacy BaseModel stream events; remove after confirming generators always emit
|
|
# str / Mapping responses.
|
|
event_name = getattr(event, "event", None)
|
|
elif isinstance(event, Mapping):
|
|
event_name = event.get("event")
|
|
else:
|
|
return None
|
|
|
|
if event_name is None:
|
|
return None
|
|
return str(event_name)
|
|
|
|
|
|
def _get_task_id(event: str | Mapping[str, Any] | BaseModel) -> str | None:
    """Return the ``task_id`` of a stream item when it is a non-empty string, else ``None``."""
    candidate: Any = None
    if isinstance(event, BaseModel):
        # Temporary compatibility for legacy BaseModel stream events; remove after confirming generators always emit
        # str / Mapping responses.
        candidate = getattr(event, "task_id", None)
    elif isinstance(event, Mapping):
        candidate = event.get("task_id")

    if isinstance(candidate, str) and candidate:
        return candidate
    return None
|
|
|
|
|
|
def _publish_streaming_response(
    response_stream: Generator[str | Mapping[str, Any] | BaseModel, None, None],
    workflow_run_id: str | uuid.UUID,
    app_mode: AppMode,
    workflow_id: str,
    inputs: Mapping[str, Any],
    started_reason: WorkflowStartReason,
) -> None:
    """Publish workflow stream events and close broken streams with a failed terminal event.

    `_AppRunner.run()` only handles failures before the generator is returned.
    Once we start iterating the runtime stream, this helper becomes the last
    place that can guarantee SSE consumers eventually see a terminal workflow
    lifecycle event (``workflow_finished`` or ``workflow_paused``):

    * If the stream raises before a terminal event was relayed, a failed
      ``workflow_finished`` is published (preceded by a synthetic ``workflow_started``
      when none was relayed) and the ORIGINAL exception is re-raised. A failure while
      publishing that fallback is logged, not raised, so it cannot mask the original.
    * If the stream ends without a terminal event, the same fallback is published.

    Events that cannot be JSON-encoded are logged and skipped.
    """
    normalized_workflow_run_id = str(workflow_run_id)

    def _publish_failed_terminal_event(error_message: str, task_id: str, publish_started: bool) -> None:
        """Publish a failed ``workflow_finished``, optionally preceded by ``workflow_started``."""
        timestamp = to_timestamp(naive_utc_now())
        assert timestamp is not None

        if publish_started:
            started_payload = WorkflowStartStreamResponse(
                task_id=task_id,
                workflow_run_id=normalized_workflow_run_id,
                data=WorkflowStartStreamResponse.Data(
                    id=normalized_workflow_run_id,
                    workflow_id=workflow_id,
                    inputs=inputs,
                    created_at=timestamp,
                    reason=started_reason,
                ),
            )
            topic.publish(
                json.dumps(
                    # ``fallback=str``: presumably ``inputs`` may hold values pydantic cannot
                    # serialize natively (e.g. on resume) — stringify them rather than fail.
                    started_payload.model_dump(mode="json", fallback=str),
                    ensure_ascii=False,
                ).encode()
            )

        finished_payload = WorkflowFinishStreamResponse(
            task_id=task_id,
            workflow_run_id=normalized_workflow_run_id,
            data=WorkflowFinishStreamResponse.Data(
                id=normalized_workflow_run_id,
                workflow_id=workflow_id,
                status=WorkflowExecutionStatus.FAILED,
                outputs=None,
                error=error_message,
                elapsed_time=0.0,
                total_tokens=0,
                total_steps=0,
                created_by={},
                created_at=timestamp,
                finished_at=timestamp,
                exceptions_count=1,
                files=[],
            ),
        )
        topic.publish(json.dumps(finished_payload.model_dump(mode="json"), ensure_ascii=False).encode())

    terminal_events = {"workflow_finished", "workflow_paused"}
    unexpected_stream_end_message = "Workflow stream ended without a terminal event"
    topic = MessageBasedAppGenerator.get_response_topic(app_mode, normalized_workflow_run_id)
    started_published = False
    terminal_published = False
    # Until the stream reports a task id, the run id stands in for it.
    last_task_id = normalized_workflow_run_id

    try:
        for event in response_stream:
            event_name = _get_event_name(event)
            task_id = _get_task_id(event)
            if task_id is not None:
                last_task_id = task_id

            try:
                if isinstance(event, BaseModel):
                    payload = json.dumps(event.model_dump(mode="json"), ensure_ascii=False)
                else:
                    payload = json.dumps(event, ensure_ascii=False, default=str)
            except (TypeError, ValueError):
                logger.exception("error while encoding event")
                continue

            topic.publish(payload.encode())

            # Only events actually published count towards the lifecycle bookkeeping.
            if event_name == "workflow_started":
                started_published = True
            elif event_name in terminal_events:
                terminal_published = True
    except Exception as exc:
        if not terminal_published:
            logger.exception(
                "Workflow stream for run %s failed before terminal event; publishing fallback terminal event",
                normalized_workflow_run_id,
            )
            try:
                _publish_failed_terminal_event(
                    error_message=str(exc) or exc.__class__.__name__,
                    task_id=last_task_id,
                    publish_started=not started_published,
                )
            except Exception:
                # The fallback may hit the same transport failure that broke the stream;
                # log it so the bare ``raise`` below still propagates the original error.
                logger.exception(
                    "Failed to publish fallback terminal event for workflow run %s",
                    normalized_workflow_run_id,
                )
        raise

    if not terminal_published:
        logger.warning(
            "Workflow stream for run %s ended without a terminal event; publishing fallback terminal event",
            normalized_workflow_run_id,
        )
        _publish_failed_terminal_event(
            error_message=unexpected_stream_end_message,
            task_id=last_task_id,
            publish_started=not started_published,
        )
|
|
|
|
|
|
@shared_task(queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE)
def workflow_based_app_execution_task(
    payload: str,
) -> Generator[Mapping[str, Any] | str, None, None] | Mapping[str, Any] | None:
    """Celery entry point: run a workflow-based app from a JSON ``AppExecutionParams`` payload.

    Returns whatever ``_AppRunner.run`` returns: the blocking response for non-streaming
    runs, ``None`` for streaming runs or when the workflow/app is missing.
    """
    params = AppExecutionParams.model_validate_json(payload)

    logger.info("workflow_based_app_execution_task run with params: %s", params)

    return _AppRunner(db.engine, exec_params=params).run()
|
|
|
|
|
|
def _load_chat_records(
    session: Session,
    generate_entity: AdvancedChatAppGenerateEntity,
    workflow_run_id: str,
) -> tuple[Conversation, Message] | None:
    """Load the conversation and latest message of a paused chatflow run; ``None`` if missing (logged)."""
    conversation_id = generate_entity.conversation_id
    if conversation_id is None:
        logger.warning("Conversation id missing in resumption context for workflow run %s", workflow_run_id)
        return None

    conversation = session.get(Conversation, conversation_id)
    if conversation is None:
        logger.warning("Conversation %s not found for workflow run %s", conversation_id, workflow_run_id)
        return None

    latest_message = session.scalar(
        select(Message)
        .where(
            Message.conversation_id == conversation.id,
            Message.workflow_run_id == workflow_run_id,
        )
        .order_by(Message.created_at.desc())
        .limit(1)
    )
    if latest_message is None:
        logger.warning("Message not found for workflow run %s", workflow_run_id)
        return None

    return conversation, latest_message


def _resume_app_execution(payload: dict[str, Any]) -> None:
    """Resume the paused workflow run named by ``payload["workflow_run_id"]``.

    Loads the stored pause and its resumption context, re-reads the run, workflow, app
    and creating user (plus, for chatflows, the conversation and latest message of the
    run), calls ``resume_workflow_pause`` on the run repository, then hands off to
    ``_resume_advanced_chat`` or ``_resume_workflow``.

    Missing records, an unreadable resumption context, or an unsupported generate entity
    are logged and end the task without raising. A graph-runtime snapshot that fails to
    deserialize is not caught.
    """
    workflow_run_id = payload["workflow_run_id"]

    session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
    workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker=session_factory)

    pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id)
    if pause_entity is None:
        logger.warning("No pause entity found for workflow run %s", workflow_run_id)
        return

    try:
        resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode())
    except Exception:
        logger.exception("Failed to load resumption context for workflow run %s", workflow_run_id)
        return

    generate_entity = resumption_context.get_generate_entity()
    graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state)

    chat_records: tuple[Conversation, Message] | None = None
    with Session(db.engine, expire_on_commit=False) as session:
        workflow_run = session.get(WorkflowRun, workflow_run_id)
        if workflow_run is None:
            logger.warning("Workflow run %s not found during resume", workflow_run_id)
            return

        workflow = session.get(Workflow, workflow_run.workflow_id)
        if workflow is None:
            logger.warning("Workflow %s not found during resume", workflow_run.workflow_id)
            return

        app_model = session.get(App, workflow_run.app_id)
        if app_model is None:
            logger.warning("App %s not found during resume", workflow_run.app_id)
            return

        user = _resolve_user_for_run(session, workflow_run)
        if user is None:
            logger.warning("User %s not found for workflow run %s", workflow_run.created_by, workflow_run_id)
            return

        if isinstance(generate_entity, AdvancedChatAppGenerateEntity):
            chat_records = _load_chat_records(session, generate_entity, workflow_run_id)
            if chat_records is None:
                return

    if not isinstance(generate_entity, (AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity)):
        logger.error(
            "Unsupported resumption entity for workflow run %s (found %s)",
            workflow_run_id,
            type(generate_entity),
        )
        return

    workflow_run_repo.resume_workflow_pause(workflow_run_id, pause_entity)

    pause_config = PauseStateLayerConfig(
        session_factory=session_factory,
        state_owner_user_id=workflow.created_by,
    )

    if isinstance(generate_entity, AdvancedChatAppGenerateEntity):
        assert chat_records is not None
        conversation, message = chat_records
        _resume_advanced_chat(
            app_model=app_model,
            workflow=workflow,
            user=user,
            conversation=conversation,
            message=message,
            generate_entity=generate_entity,
            graph_runtime_state=graph_runtime_state,
            session_factory=session_factory,
            pause_state_config=pause_config,
            workflow_run_id=workflow_run_id,
            workflow_run=workflow_run,
        )
    elif isinstance(generate_entity, WorkflowAppGenerateEntity):
        _resume_workflow(
            app_model=app_model,
            workflow=workflow,
            user=user,
            generate_entity=generate_entity,
            graph_runtime_state=graph_runtime_state,
            session_factory=session_factory,
            pause_state_config=pause_config,
            workflow_run_id=workflow_run_id,
            workflow_run=workflow_run,
            workflow_run_repo=workflow_run_repo,
            pause_entity=pause_entity,
        )
|
|
|
|
|
|
def _resume_advanced_chat(
    *,
    app_model: App,
    workflow: Workflow,
    user: Account | EndUser,
    conversation: Conversation,
    message: Message,
    generate_entity: AdvancedChatAppGenerateEntity,
    graph_runtime_state: GraphRuntimeState,
    session_factory: sessionmaker,
    pause_state_config: PauseStateLayerConfig,
    workflow_run_id: str,
    workflow_run: WorkflowRun,
) -> None:
    """Resume a paused chatflow (advanced-chat) run and relay its events.

    The resumed run is always streamed: ``stream=True`` is forced on a copy of
    ``generate_entity``, and the resulting stream is published to the ``ADVANCED_CHAT``
    response topic of ``workflow_run_id`` with a ``RESUMPTION`` start reason.

    NOTE(review): an exception from ``AdvancedChatAppGenerator.resume`` is logged and
    re-raised without publishing any terminal event to the response topic (unlike the
    initial-run path); confirm SSE consumers of resumed runs do not need one.
    """
    streaming_entity = generate_entity.model_copy(update={"stream": True})

    # Unknown ``triggered_from`` values on the stored run fall back to APP_RUN.
    try:
        run_trigger = WorkflowRunTriggeredFrom(workflow_run.triggered_from)
    except ValueError:
        run_trigger = WorkflowRunTriggeredFrom.APP_RUN

    execution_repo = DifyCoreRepositoryFactory.create_workflow_execution_repository(
        session_factory=session_factory,
        user=user,
        app_id=app_model.id,
        triggered_from=run_trigger,
    )
    node_execution_repo = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
        session_factory=session_factory,
        user=user,
        app_id=app_model.id,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )

    chat_generator = AdvancedChatAppGenerator()
    try:
        event_stream = chat_generator.resume(
            app_model=app_model,
            workflow=workflow,
            user=user,
            conversation=conversation,
            message=message,
            application_generate_entity=streaming_entity,
            workflow_execution_repository=execution_repo,
            workflow_node_execution_repository=node_execution_repo,
            graph_runtime_state=graph_runtime_state,
            pause_state_config=pause_state_config,
        )
    except Exception:
        logger.exception("Failed to resume chatflow execution for workflow run %s", workflow_run_id)
        raise

    assert isinstance(event_stream, Generator)
    _publish_streaming_response(
        event_stream,
        workflow_run_id,
        AppMode.ADVANCED_CHAT,
        workflow.id,
        generate_entity.inputs,
        WorkflowStartReason.RESUMPTION,
    )
|
|
|
|
|
|
def _resume_workflow(
    *,
    app_model: App,
    workflow: Workflow,
    user: Account | EndUser,
    generate_entity: WorkflowAppGenerateEntity,
    graph_runtime_state: GraphRuntimeState,
    session_factory: sessionmaker,
    pause_state_config: PauseStateLayerConfig,
    workflow_run_id: str,
    workflow_run: WorkflowRun,
    workflow_run_repo,
    pause_entity,
) -> None:
    """Resume a paused workflow-app run, relay its events, then delete the pause record.

    The resumed run is always streamed (``stream=True`` is forced on a copy of
    ``generate_entity``) to the ``WORKFLOW`` response topic of ``workflow_run_id`` with a
    ``RESUMPTION`` start reason. The pause is deleted only after streaming returns
    normally. A ``_WorkflowRunError`` saying "WorkflowPause not found" is treated as
    "already replaced or removed" and only logged; any other deletion error propagates.
    """
    streaming_entity = generate_entity.model_copy(update={"stream": True})

    # Unknown ``triggered_from`` values on the stored run fall back to APP_RUN.
    try:
        run_trigger = WorkflowRunTriggeredFrom(workflow_run.triggered_from)
    except ValueError:
        run_trigger = WorkflowRunTriggeredFrom.APP_RUN

    execution_repo = DifyCoreRepositoryFactory.create_workflow_execution_repository(
        session_factory=session_factory,
        user=user,
        app_id=app_model.id,
        triggered_from=run_trigger,
    )
    node_execution_repo = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
        session_factory=session_factory,
        user=user,
        app_id=app_model.id,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )

    workflow_generator = WorkflowAppGenerator()
    try:
        event_stream = workflow_generator.resume(
            app_model=app_model,
            workflow=workflow,
            user=user,
            application_generate_entity=streaming_entity,
            graph_runtime_state=graph_runtime_state,
            workflow_execution_repository=execution_repo,
            workflow_node_execution_repository=node_execution_repo,
            pause_state_config=pause_state_config,
        )
    except Exception:
        logger.exception("Failed to resume workflow execution for workflow run %s", workflow_run_id)
        raise

    assert isinstance(event_stream, Generator)
    _publish_streaming_response(
        event_stream,
        workflow_run_id,
        AppMode.WORKFLOW,
        workflow.id,
        generate_entity.inputs,
        WorkflowStartReason.RESUMPTION,
    )

    try:
        workflow_run_repo.delete_workflow_pause(pause_entity)
    except Exception as exc:
        # The error class is matched by name (it is not importable here).
        pause_already_gone = exc.__class__.__name__ == "_WorkflowRunError" and "WorkflowPause not found" in str(exc)
        if not pause_already_gone:
            raise
        logger.info(
            "Skipped deleting workflow pause %s after resume because it was already replaced or removed",
            pause_entity.id,
        )
|
|
|
|
|
|
@shared_task(queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE, name="resume_app_execution")
def resume_app_execution(payload: dict[str, Any]) -> None:
    """Celery entry point that resumes the paused run named by ``payload["workflow_run_id"]``."""
    return _resume_app_execution(payload)
|