mirror of https://github.com/langgenius/dify.git
refactor: fix basedpyright/ruff errors
This commit is contained in:
parent
e9313b9c1b
commit
394b7d09b8
|
|
@ -1,13 +1,9 @@
|
|||
from flask_restx import Resource, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.memory.entities import MemoryBlock
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from libs.helper import uuid_value
|
||||
from models import db
|
||||
from models.chatflow_memory import ChatflowMemoryVariable
|
||||
from services.chatflow_memory_service import ChatflowMemoryService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
|
@ -31,6 +27,7 @@ class MemoryListApi(Resource):
|
|||
result = [it for it in result if it.memory_id == memory_id]
|
||||
return [it for it in result if it.end_user_visible]
|
||||
|
||||
|
||||
class MemoryEditApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def put(self, app_model):
|
||||
|
|
@ -44,6 +41,8 @@ class MemoryEditApi(Resource):
|
|||
update = args.get("update")
|
||||
conversation_id = args.get("conversation_id")
|
||||
node_id = args.get("node_id")
|
||||
if not isinstance(update, str):
|
||||
return {'error': 'Invalid update'}, 400
|
||||
if not workflow:
|
||||
return {'error': 'Workflow not found'}, 404
|
||||
memory_spec = next((it for it in workflow.memory_blocks if it.id == args['id']), None)
|
||||
|
|
@ -63,5 +62,6 @@ class MemoryEditApi(Resource):
|
|||
)
|
||||
return '', 204
|
||||
|
||||
|
||||
api.add_resource(MemoryListApi, '/memories')
|
||||
api.add_resource(MemoryEditApi, '/memory-edit')
|
||||
|
|
|
|||
|
|
@ -1,14 +1,9 @@
|
|||
from flask_restx import reqparse
|
||||
from sqlalchemy.orm.session import Session
|
||||
from sympy import false
|
||||
|
||||
from controllers.web import api
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.memory.entities import MemoryBlock
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from libs.helper import uuid_value
|
||||
from models import db
|
||||
from models.chatflow_memory import ChatflowMemoryVariable
|
||||
from services.chatflow_memory_service import ChatflowMemoryService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
|
@ -31,6 +26,7 @@ class MemoryListApi(WebApiResource):
|
|||
result = [it for it in result if it.memory_id == memory_id]
|
||||
return [it for it in result if it.end_user_visible]
|
||||
|
||||
|
||||
class MemoryEditApi(WebApiResource):
|
||||
def put(self, app_model):
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -43,6 +39,8 @@ class MemoryEditApi(WebApiResource):
|
|||
update = args.get("update")
|
||||
conversation_id = args.get("conversation_id")
|
||||
node_id = args.get("node_id")
|
||||
if not isinstance(update, str):
|
||||
return {'error': 'Update must be a string'}, 400
|
||||
if not workflow:
|
||||
return {'error': 'Workflow not found'}, 404
|
||||
memory_spec = next((it for it in workflow.memory_blocks if it.id == args['id']), None)
|
||||
|
|
@ -64,5 +62,6 @@ class MemoryEditApi(WebApiResource):
|
|||
)
|
||||
return '', 204
|
||||
|
||||
|
||||
api.add_resource(MemoryListApi, '/memories')
|
||||
api.add_resource(MemoryEditApi, '/memory-edit')
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ from core.moderation.input_moderation import InputModeration
|
|||
from core.variables.variables import VariableUnion
|
||||
from core.workflow.entities import GraphRuntimeState, VariablePool
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.entities.event import GraphRunSucceededEvent
|
||||
from core.workflow.graph_events import GraphRunSucceededEvent
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
|
|
@ -223,6 +223,9 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
if not assistant_message:
|
||||
logger.warning("Chatflow output does not contain 'answer'.")
|
||||
return
|
||||
if not isinstance(assistant_message, str):
|
||||
logger.warning("Chatflow output 'answer' is not a string.")
|
||||
return
|
||||
try:
|
||||
self._sync_conversation_to_chatflow_tables(assistant_message)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
|
|
@ -7,23 +7,23 @@ from pydantic import BaseModel, Field
|
|||
from core.app.app_config.entities import ModelConfig
|
||||
|
||||
|
||||
class MemoryScope(str, Enum):
|
||||
class MemoryScope(StrEnum):
|
||||
"""Memory scope determined by node_id field"""
|
||||
APP = "app" # node_id is None
|
||||
NODE = "node" # node_id is not None
|
||||
|
||||
|
||||
class MemoryTerm(str, Enum):
|
||||
class MemoryTerm(StrEnum):
|
||||
"""Memory term determined by conversation_id field"""
|
||||
SESSION = "session" # conversation_id is not None
|
||||
PERSISTENT = "persistent" # conversation_id is None
|
||||
|
||||
|
||||
class MemoryStrategy(str, Enum):
|
||||
class MemoryStrategy(StrEnum):
|
||||
ON_TURNS = "on_turns"
|
||||
|
||||
|
||||
class MemoryScheduleMode(str, Enum):
|
||||
class MemoryScheduleMode(StrEnum):
|
||||
SYNC = "sync"
|
||||
ASYNC = "async"
|
||||
|
||||
|
|
@ -69,6 +69,7 @@ class MemoryBlock(BaseModel):
|
|||
conversation_id: Optional[str] = None
|
||||
node_id: Optional[str] = None
|
||||
|
||||
|
||||
class MemoryBlockWithVisibility(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
|
|
|
|||
|
|
@ -202,9 +202,10 @@ class ArrayFileSegment(ArraySegment):
|
|||
def text(self) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
class VersionedMemoryValue(BaseModel):
|
||||
current_value: str = None # type: ignore
|
||||
versions: Mapping[str, str] = dict()
|
||||
versions: Mapping[str, str] = {}
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
|
|
@ -215,7 +216,7 @@ class VersionedMemoryValue(BaseModel):
|
|||
) -> "VersionedMemoryValue":
|
||||
if version_name is None:
|
||||
version_name = str(len(self.versions) + 1)
|
||||
if version_name in self.versions.keys():
|
||||
if version_name in self.versions:
|
||||
raise ValueError(f"Version '{version_name}' already exists.")
|
||||
self.current_value = new_value
|
||||
return VersionedMemoryValue(
|
||||
|
|
@ -226,6 +227,7 @@ class VersionedMemoryValue(BaseModel):
|
|||
}
|
||||
)
|
||||
|
||||
|
||||
class VersionedMemorySegment(Segment):
|
||||
value_type: SegmentType = SegmentType.VERSIONED_MEMORY
|
||||
value: VersionedMemoryValue = None # type: ignore
|
||||
|
|
|
|||
|
|
@ -22,7 +22,8 @@ from .segments import (
|
|||
ObjectSegment,
|
||||
Segment,
|
||||
StringSegment,
|
||||
get_segment_discriminator, VersionedMemorySegment,
|
||||
VersionedMemorySegment,
|
||||
get_segment_discriminator,
|
||||
)
|
||||
from .types import SegmentType
|
||||
|
||||
|
|
@ -105,9 +106,11 @@ class BooleanVariable(BooleanSegment, Variable):
|
|||
class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
|
||||
pass
|
||||
|
||||
|
||||
class VersionedMemoryVariable(VersionedMemorySegment, Variable):
|
||||
pass
|
||||
|
||||
|
||||
class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable):
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -83,7 +83,6 @@ class VariablePool(BaseModel):
|
|||
for memory_id, memory_value in self.memory_blocks.items():
|
||||
self.add([CONVERSATION_VARIABLE_NODE_ID, memory_id], memory_value)
|
||||
|
||||
|
||||
def add(self, selector: Sequence[str], value: Any, /):
|
||||
"""
|
||||
Add a variable to the variable pool.
|
||||
|
|
|
|||
|
|
@ -1235,7 +1235,7 @@ class LLMNode(Node):
|
|||
memory_blocks = workflow.memory_blocks
|
||||
|
||||
for block_id in block_ids:
|
||||
memory_block_spec = next((block for block in memory_blocks if block.id == block_id),None)
|
||||
memory_block_spec = next((block for block in memory_blocks if block.id == block_id), None)
|
||||
|
||||
if memory_block_spec and memory_block_spec.scope == MemoryScope.NODE:
|
||||
is_draft = (self.invoke_from == InvokeFrom.DEBUGGER)
|
||||
|
|
|
|||
|
|
@ -20,7 +20,9 @@ from core.variables.segments import (
|
|||
NoneSegment,
|
||||
ObjectSegment,
|
||||
Segment,
|
||||
StringSegment, VersionedMemorySegment, VersionedMemoryValue,
|
||||
StringSegment,
|
||||
VersionedMemorySegment,
|
||||
VersionedMemoryValue,
|
||||
)
|
||||
from core.variables.types import SegmentType
|
||||
from core.variables.variables import (
|
||||
|
|
@ -38,7 +40,8 @@ from core.variables.variables import (
|
|||
ObjectVariable,
|
||||
SecretVariable,
|
||||
StringVariable,
|
||||
Variable, VersionedMemoryVariable,
|
||||
Variable,
|
||||
VersionedMemoryVariable,
|
||||
)
|
||||
from core.workflow.constants import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ from services.workflow_service import WorkflowService
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatflowMemoryService:
|
||||
@staticmethod
|
||||
def get_persistent_memories(
|
||||
|
|
@ -186,9 +187,9 @@ class ChatflowMemoryService:
|
|||
ChatflowMemoryVariable.memory_id == spec.id,
|
||||
ChatflowMemoryVariable.tenant_id == tenant_id,
|
||||
ChatflowMemoryVariable.app_id == app_id,
|
||||
ChatflowMemoryVariable.node_id == \
|
||||
ChatflowMemoryVariable.node_id ==
|
||||
(node_id if spec.scope == MemoryScope.NODE else None),
|
||||
ChatflowMemoryVariable.conversation_id == \
|
||||
ChatflowMemoryVariable.conversation_id ==
|
||||
(conversation_id if spec.term == MemoryTerm.SESSION else None),
|
||||
)
|
||||
).order_by(ChatflowMemoryVariable.version.desc()).limit(1)
|
||||
|
|
@ -517,6 +518,7 @@ class ChatflowMemoryService:
|
|||
result.append((str(message.role.value), message.get_text_content()))
|
||||
return result
|
||||
|
||||
|
||||
def _get_memory_sync_lock_key(app_id: str, conversation_id: str) -> str:
|
||||
"""Generate Redis lock key for memory sync updates
|
||||
|
||||
|
|
|
|||
|
|
@ -1055,6 +1055,7 @@ def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: Varia
|
|||
else:
|
||||
raise Exception("unreachable")
|
||||
|
||||
|
||||
def _fetch_memory_blocks(workflow: Workflow, conversation_id: str, is_draft: bool) -> Mapping[str, str]:
|
||||
memory_blocks = {}
|
||||
memory_block_specs = workflow.memory_blocks
|
||||
|
|
|
|||
Loading…
Reference in New Issue