Tighten phase 3 runtime typing

This commit is contained in:
Yanli 盐粒 2026-03-17 18:49:14 +08:00
parent a717519822
commit 9f0d79b8b0
15 changed files with 88 additions and 57 deletions

View File

@ -5,7 +5,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload
from typing import TYPE_CHECKING, Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@ -47,7 +47,6 @@ from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.base import Base
from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService
from services.workflow_draft_variable_service import (
@ -524,6 +523,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
with Session(bind=db.engine, expire_on_commit=False) as session:
workflow = _refresh_model(session, workflow)
message = _refresh_model(session, message)
assert message is not None
# workflow_ = session.get(Workflow, workflow.id)
# assert workflow_ is not None
# workflow = workflow_
@ -690,11 +690,22 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
raise e
_T = TypeVar("_T", bound=Base)
@overload
def _refresh_model(session: object, model: Workflow) -> Workflow: ...
def _refresh_model(session, model: _T) -> _T:
with Session(bind=db.engine, expire_on_commit=False) as session:
detach_model = session.get(type(model), model.id)
assert detach_model is not None
return detach_model
@overload
def _refresh_model(session: object, model: Message) -> Message: ...
def _refresh_model(session: object, model: Workflow | Message) -> Workflow | Message:
_ = session
with Session(bind=db.engine, expire_on_commit=False) as db_session:
if isinstance(model, Workflow):
detached_workflow = db_session.get(Workflow, model.id)
assert detached_workflow is not None
return detached_workflow
detached_message = db_session.get(Message, model.id)
assert detached_message is not None
return detached_message

View File

@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Generator, Iterator
from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@ -56,7 +56,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, Any, None]:
"""
Convert stream full response.
@ -87,7 +87,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, Any, None]:
"""
Convert stream simple response.

View File

@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Generator, Iterator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@ -55,7 +55,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@ -86,7 +86,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@ -1,7 +1,7 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping
from typing import Any, Union
from collections.abc import Generator, Iterator, Mapping
from typing import Any
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
@ -16,24 +16,26 @@ class AppGenerateResponseConverter(ABC):
@classmethod
def convert(
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
cls, response: AppBlockingResponse | Iterator[AppStreamResponse], invoke_from: InvokeFrom
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)
else:
stream_response = response
def _generate_full_response() -> Generator[dict | str, Any, None]:
yield from cls.convert_stream_full_response(response)
def _generate_full_response() -> Generator[dict[str, Any] | str, None, None]:
yield from cls.convert_stream_full_response(stream_response)
return _generate_full_response()
else:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_simple_response(response)
else:
stream_response = response
def _generate_simple_response() -> Generator[dict | str, Any, None]:
yield from cls.convert_stream_simple_response(response)
def _generate_simple_response() -> Generator[dict[str, Any] | str, None, None]:
yield from cls.convert_stream_simple_response(stream_response)
return _generate_simple_response()
@ -50,14 +52,14 @@ class AppGenerateResponseConverter(ABC):
@classmethod
@abstractmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
raise NotImplementedError
@classmethod
@abstractmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
raise NotImplementedError

View File

@ -224,6 +224,7 @@ class BaseAppGenerator:
def _get_draft_var_saver_factory(invoke_from: InvokeFrom, account: Account | EndUser) -> DraftVariableSaverFactory:
if invoke_from == InvokeFrom.DEBUGGER:
assert isinstance(account, Account)
debug_account = account
def draft_var_saver_factory(
session: Session,
@ -240,7 +241,7 @@ class BaseAppGenerator:
node_type=node_type,
node_execution_id=node_execution_id,
enclosing_node_id=enclosing_node_id,
user=account,
user=debug_account,
)
else:

View File

@ -166,15 +166,19 @@ class ChatAppGenerator(MessageBasedAppGenerator):
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
assert conversation is not None
assert message is not None
generated_conversation_id = str(conversation.id)
generated_message_id = str(message.id)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
conversation_id=generated_conversation_id,
app_mode=conversation.mode,
message_id=message.id,
message_id=generated_message_id,
)
# new thread with request context
@ -184,8 +188,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
flask_app=current_app._get_current_object(), # type: ignore
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation_id=conversation.id,
message_id=message.id,
conversation_id=generated_conversation_id,
message_id=generated_message_id,
)
worker_thread = threading.Thread(target=worker_with_context)

View File

@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Generator, Iterator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@ -55,7 +55,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@ -86,7 +86,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@ -149,6 +149,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity)
assert conversation is not None
assert message is not None
# init queue manager
queue_manager = MessageBasedAppQueueManager(
@ -312,15 +314,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity)
assert conversation is not None
assert message is not None
conversation_id = str(conversation.id)
message_id = str(message.id)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
conversation_id=conversation_id,
app_mode=conversation.mode,
message_id=message.id,
message_id=message_id,
)
# new thread with request context
@ -330,7 +336,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
flask_app=current_app._get_current_object(), # type: ignore
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
message_id=message.id,
message_id=message_id,
)
worker_thread = threading.Thread(target=worker_with_context)

View File

@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Generator, Iterator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@ -54,7 +54,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@ -84,7 +84,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Generator, Iterator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@ -36,7 +36,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@ -65,7 +65,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Generator, Iterator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@ -36,7 +36,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@ -65,7 +65,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@ -133,6 +133,8 @@ class ExecutionLimitsLayer(GraphEngineLayer):
elif limit_type == LimitType.TIME_LIMIT:
elapsed_time = time.time() - self.start_time if self.start_time else 0
reason = f"Maximum execution time exceeded: {elapsed_time:.2f}s > {self.max_time}s"
else:
return
self.logger.warning("Execution limit exceeded: %s", reason)

View File

@ -336,16 +336,13 @@ class Node(Generic[NodeDataT]):
def _restore_execution_id_from_runtime_state(self) -> str | None:
graph_execution = self.graph_runtime_state.graph_execution
try:
node_executions = graph_execution.node_executions
except AttributeError:
return None
node_executions = getattr(graph_execution, "node_executions", None)
if not isinstance(node_executions, dict):
return None
node_execution = node_executions.get(self._node_id)
if node_execution is None:
return None
execution_id = node_execution.execution_id
execution_id = getattr(node_execution, "execution_id", None)
if not execution_id:
return None
return str(execution_id)
@ -395,8 +392,7 @@ class Node(Generic[NodeDataT]):
if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance]
yield self._dispatch(event)
elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance]
event.id = self.execution_id
yield event
yield event.model_copy(update={"id": self.execution_id})
else:
yield event
except Exception as e:

View File

@ -443,7 +443,10 @@ def _extract_text_from_docx(file_content: bytes) -> str:
# Keep track of paragraph and table positions
content_items: list[tuple[int, str, Table | Paragraph]] = []
it = iter(doc.element.body)
doc_body = getattr(doc.element, "body", None)
if doc_body is None:
raise TextExtractionError("DOCX body not found")
it = iter(doc_body)
part = next(it, None)
i = 0
while part is not None:

View File

@ -9,7 +9,7 @@ from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.variable_assigner.common import helpers as common_helpers
from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from dify_graph.variables import SegmentType, VariableBase
from dify_graph.variables import Segment, SegmentType, VariableBase
from .node_data import VariableAssignerData, WriteMode
@ -74,23 +74,29 @@ class VariableAssignerNode(Node[VariableAssignerData]):
if not isinstance(original_variable, VariableBase):
raise VariableOperatorNodeError("assigned variable not found")
income_value: Segment
updated_variable: VariableBase
match self.node_data.write_mode:
case WriteMode.OVER_WRITE:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value:
input_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if input_value is None:
raise VariableOperatorNodeError("input value not found")
income_value = input_value
updated_variable = original_variable.model_copy(update={"value": income_value.value})
case WriteMode.APPEND:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value:
input_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if input_value is None:
raise VariableOperatorNodeError("input value not found")
income_value = input_value
updated_value = original_variable.value + [income_value.value]
updated_variable = original_variable.model_copy(update={"value": updated_value})
case WriteMode.CLEAR:
income_value = SegmentType.get_zero_value(original_variable.value_type)
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
case _:
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
# Over write the variable.
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)