refactor: fix basedpyright/ruff errors

This commit is contained in:
Stream 2025-09-22 15:17:19 +08:00
parent e9313b9c1b
commit 394b7d09b8
No known key found for this signature in database
GPG Key ID: 033728094B100D70
11 changed files with 37 additions and 24 deletions

View File

@ -1,13 +1,9 @@
from flask_restx import Resource, reqparse
from sqlalchemy.orm import Session
from controllers.service_api import api
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.memory.entities import MemoryBlock
from core.workflow.entities.variable_pool import VariablePool
from libs.helper import uuid_value
from models import db
from models.chatflow_memory import ChatflowMemoryVariable
from services.chatflow_memory_service import ChatflowMemoryService
from services.workflow_service import WorkflowService
@ -31,6 +27,7 @@ class MemoryListApi(Resource):
result = [it for it in result if it.memory_id == memory_id]
return [it for it in result if it.end_user_visible]
class MemoryEditApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def put(self, app_model):
@ -44,6 +41,8 @@ class MemoryEditApi(Resource):
update = args.get("update")
conversation_id = args.get("conversation_id")
node_id = args.get("node_id")
if not isinstance(update, str):
return {'error': 'Invalid update'}, 400
if not workflow:
return {'error': 'Workflow not found'}, 404
memory_spec = next((it for it in workflow.memory_blocks if it.id == args['id']), None)
@ -63,5 +62,6 @@ class MemoryEditApi(Resource):
)
return '', 204
api.add_resource(MemoryListApi, '/memories')
api.add_resource(MemoryEditApi, '/memory-edit')

View File

@ -1,14 +1,9 @@
from flask_restx import reqparse
from sqlalchemy.orm.session import Session
from sympy import false
from controllers.web import api
from controllers.web.wraps import WebApiResource
from core.memory.entities import MemoryBlock
from core.workflow.entities.variable_pool import VariablePool
from libs.helper import uuid_value
from models import db
from models.chatflow_memory import ChatflowMemoryVariable
from services.chatflow_memory_service import ChatflowMemoryService
from services.workflow_service import WorkflowService
@ -31,6 +26,7 @@ class MemoryListApi(WebApiResource):
result = [it for it in result if it.memory_id == memory_id]
return [it for it in result if it.end_user_visible]
class MemoryEditApi(WebApiResource):
def put(self, app_model):
parser = reqparse.RequestParser()
@ -43,6 +39,8 @@ class MemoryEditApi(WebApiResource):
update = args.get("update")
conversation_id = args.get("conversation_id")
node_id = args.get("node_id")
if not isinstance(update, str):
return {'error': 'Update must be a string'}, 400
if not workflow:
return {'error': 'Workflow not found'}, 404
memory_spec = next((it for it in workflow.memory_blocks if it.id == args['id']), None)
@ -64,5 +62,6 @@ class MemoryEditApi(WebApiResource):
)
return '', 204
api.add_resource(MemoryListApi, '/memories')
api.add_resource(MemoryEditApi, '/memory-edit')

View File

@ -28,7 +28,7 @@ from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.entities.event import GraphRunSucceededEvent
from core.workflow.graph_events import GraphRunSucceededEvent
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
@ -223,6 +223,9 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
if not assistant_message:
logger.warning("Chatflow output does not contain 'answer'.")
return
if not isinstance(assistant_message, str):
logger.warning("Chatflow output 'answer' is not a string.")
return
try:
self._sync_conversation_to_chatflow_tables(assistant_message)
except Exception as e:

View File

@ -1,4 +1,4 @@
from enum import Enum
from enum import StrEnum
from typing import Optional
from uuid import uuid4
@ -7,23 +7,23 @@ from pydantic import BaseModel, Field
from core.app.app_config.entities import ModelConfig
class MemoryScope(str, Enum):
class MemoryScope(StrEnum):
"""Memory scope determined by node_id field"""
APP = "app" # node_id is None
NODE = "node" # node_id is not None
class MemoryTerm(str, Enum):
class MemoryTerm(StrEnum):
"""Memory term determined by conversation_id field"""
SESSION = "session" # conversation_id is not None
PERSISTENT = "persistent" # conversation_id is None
class MemoryStrategy(str, Enum):
class MemoryStrategy(StrEnum):
ON_TURNS = "on_turns"
class MemoryScheduleMode(str, Enum):
class MemoryScheduleMode(StrEnum):
SYNC = "sync"
ASYNC = "async"
@ -69,6 +69,7 @@ class MemoryBlock(BaseModel):
conversation_id: Optional[str] = None
node_id: Optional[str] = None
class MemoryBlockWithVisibility(BaseModel):
id: str
name: str

View File

@ -202,9 +202,10 @@ class ArrayFileSegment(ArraySegment):
def text(self) -> str:
return ""
class VersionedMemoryValue(BaseModel):
current_value: str = None # type: ignore
versions: Mapping[str, str] = dict()
versions: Mapping[str, str] = {}
model_config = ConfigDict(frozen=True)
@ -215,7 +216,7 @@ class VersionedMemoryValue(BaseModel):
) -> "VersionedMemoryValue":
if version_name is None:
version_name = str(len(self.versions) + 1)
if version_name in self.versions.keys():
if version_name in self.versions:
raise ValueError(f"Version '{version_name}' already exists.")
self.current_value = new_value
return VersionedMemoryValue(
@ -226,6 +227,7 @@ class VersionedMemoryValue(BaseModel):
}
)
class VersionedMemorySegment(Segment):
value_type: SegmentType = SegmentType.VERSIONED_MEMORY
value: VersionedMemoryValue = None # type: ignore

View File

@ -22,7 +22,8 @@ from .segments import (
ObjectSegment,
Segment,
StringSegment,
get_segment_discriminator, VersionedMemorySegment,
VersionedMemorySegment,
get_segment_discriminator,
)
from .types import SegmentType
@ -105,9 +106,11 @@ class BooleanVariable(BooleanSegment, Variable):
class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
pass
class VersionedMemoryVariable(VersionedMemorySegment, Variable):
pass
class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable):
pass

View File

@ -83,7 +83,6 @@ class VariablePool(BaseModel):
for memory_id, memory_value in self.memory_blocks.items():
self.add([CONVERSATION_VARIABLE_NODE_ID, memory_id], memory_value)
def add(self, selector: Sequence[str], value: Any, /):
"""
Add a variable to the variable pool.

View File

@ -1235,7 +1235,7 @@ class LLMNode(Node):
memory_blocks = workflow.memory_blocks
for block_id in block_ids:
memory_block_spec = next((block for block in memory_blocks if block.id == block_id),None)
memory_block_spec = next((block for block in memory_blocks if block.id == block_id), None)
if memory_block_spec and memory_block_spec.scope == MemoryScope.NODE:
is_draft = (self.invoke_from == InvokeFrom.DEBUGGER)

View File

@ -20,7 +20,9 @@ from core.variables.segments import (
NoneSegment,
ObjectSegment,
Segment,
StringSegment, VersionedMemorySegment, VersionedMemoryValue,
StringSegment,
VersionedMemorySegment,
VersionedMemoryValue,
)
from core.variables.types import SegmentType
from core.variables.variables import (
@ -38,7 +40,8 @@ from core.variables.variables import (
ObjectVariable,
SecretVariable,
StringVariable,
Variable, VersionedMemoryVariable,
Variable,
VersionedMemoryVariable,
)
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,

View File

@ -32,6 +32,7 @@ from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
class ChatflowMemoryService:
@staticmethod
def get_persistent_memories(
@ -186,9 +187,9 @@ class ChatflowMemoryService:
ChatflowMemoryVariable.memory_id == spec.id,
ChatflowMemoryVariable.tenant_id == tenant_id,
ChatflowMemoryVariable.app_id == app_id,
ChatflowMemoryVariable.node_id == \
ChatflowMemoryVariable.node_id ==
(node_id if spec.scope == MemoryScope.NODE else None),
ChatflowMemoryVariable.conversation_id == \
ChatflowMemoryVariable.conversation_id ==
(conversation_id if spec.term == MemoryTerm.SESSION else None),
)
).order_by(ChatflowMemoryVariable.version.desc()).limit(1)
@ -517,6 +518,7 @@ class ChatflowMemoryService:
result.append((str(message.role.value), message.get_text_content()))
return result
def _get_memory_sync_lock_key(app_id: str, conversation_id: str) -> str:
"""Generate Redis lock key for memory sync updates

View File

@ -1055,6 +1055,7 @@ def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: Varia
else:
raise Exception("unreachable")
def _fetch_memory_blocks(workflow: Workflow, conversation_id: str, is_draft: bool) -> Mapping[str, str]:
memory_blocks = {}
memory_block_specs = workflow.memory_blocks