Merge branch 'feat/memory-orchestration-be-dev-env' into deploy/dev

This commit is contained in:
Stream 2025-10-11 16:16:15 +08:00
commit 05751ac492
No known key found for this signature in database
GPG Key ID: 033728094B100D70
156 changed files with 3100 additions and 806 deletions

View File

@ -81,7 +81,6 @@ ignore = [
"SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false
"UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/
]
[lint.per-file-ignores]

View File

@ -90,7 +90,7 @@ class ModelConfigResource(Resource):
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
continue
agent_tool_entity = AgentToolEntity(**tool)
agent_tool_entity = AgentToolEntity.model_validate(tool)
# get tool
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
@ -124,7 +124,7 @@ class ModelConfigResource(Resource):
# encrypt agent tool parameters if it's secret-input
agent_mode = new_app_model_config.agent_mode_dict
for tool in agent_mode.get("tools") or []:
agent_tool_entity = AgentToolEntity(**tool)
agent_tool_entity = AgentToolEntity.model_validate(tool)
# get tool
key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"

View File

@ -15,7 +15,7 @@ from core.datasource.entities.datasource_entities import DatasourceProviderType,
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
@ -257,13 +257,15 @@ class DataSourceNotionApi(Resource):
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
notion_info={
"credential_id": credential_id,
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id,
},
notion_info=NotionInfo.model_validate(
{
"credential_id": credential_id,
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id,
}
),
document_model=args["doc_form"],
)
extract_settings.append(extract_setting)

View File

@ -24,7 +24,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from fields.app_fields import related_app_list
@ -513,13 +513,15 @@ class DatasetIndexingEstimateApi(Resource):
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
notion_info={
"credential_id": credential_id,
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id,
},
notion_info=NotionInfo.model_validate(
{
"credential_id": credential_id,
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id,
}
),
document_model=args["doc_form"],
)
extract_settings.append(extract_setting)
@ -528,14 +530,16 @@ class DatasetIndexingEstimateApi(Resource):
for url in website_info_list["urls"]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value,
website_info={
"provider": website_info_list["provider"],
"job_id": website_info_list["job_id"],
"url": url,
"tenant_id": current_user.current_tenant_id,
"mode": "crawl",
"only_main_content": website_info_list["only_main_content"],
},
website_info=WebsiteInfo.model_validate(
{
"provider": website_info_list["provider"],
"job_id": website_info_list["job_id"],
"url": url,
"tenant_id": current_user.current_tenant_id,
"mode": "crawl",
"only_main_content": website_info_list["only_main_content"],
}
),
document_model=args["doc_form"],
)
extract_settings.append(extract_setting)

View File

@ -44,7 +44,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from extensions.ext_database import db
from fields.document_fields import (
dataset_and_document_fields,
@ -305,7 +305,7 @@ class DatasetDocumentListApi(Resource):
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
args = parser.parse_args()
knowledge_config = KnowledgeConfig(**args)
knowledge_config = KnowledgeConfig.model_validate(args)
if not dataset.indexing_technique and not knowledge_config.indexing_technique:
raise ValueError("indexing_technique is required.")
@ -395,7 +395,7 @@ class DatasetInitApi(Resource):
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
knowledge_config = KnowledgeConfig(**args)
knowledge_config = KnowledgeConfig.model_validate(args)
if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
@ -547,13 +547,15 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
notion_info={
"credential_id": data_source_info["credential_id"],
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"tenant_id": current_user.current_tenant_id,
},
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info["credential_id"],
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"tenant_id": current_user.current_tenant_id,
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)
@ -562,14 +564,16 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value,
website_info={
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"url": data_source_info["url"],
"tenant_id": current_user.current_tenant_id,
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
},
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"url": data_source_info["url"],
"tenant_id": current_user.current_tenant_id,
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)

View File

@ -309,7 +309,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
)
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset)
segment = SegmentService.update_segment(SegmentUpdateArgs.model_validate(args), segment, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@setup_required
@ -564,7 +564,7 @@ class ChildChunkAddApi(Resource):
args = parser.parse_args()
try:
chunks_data = args["chunks"]
chunks = [ChildChunkUpdateArgs(**chunk) for chunk in chunks_data]
chunks = [ChildChunkUpdateArgs.model_validate(chunk) for chunk in chunks_data]
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))

View File

@ -28,7 +28,7 @@ class DatasetMetadataCreateApi(Resource):
parser.add_argument("type", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
metadata_args = MetadataArgs(**args)
metadata_args = MetadataArgs.model_validate(args)
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -137,7 +137,7 @@ class DocumentMetadataEditApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json")
args = parser.parse_args()
metadata_args = MetadataOperationData(**args)
metadata_args = MetadataOperationData.model_validate(args)
MetadataService.update_documents_metadata(dataset, metadata_args)

View File

@ -88,7 +88,7 @@ class CustomizedPipelineTemplateApi(Resource):
nullable=True,
)
args = parser.parse_args()
pipeline_template_info = PipelineTemplateInfoEntity(**args)
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args)
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
return 200

View File

@ -6,7 +6,7 @@ from flask_restx import Resource, inputs, marshal_with, reqparse
from sqlalchemy import and_, select
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
from controllers.console import api
from controllers.console import console_ns
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_database import db
@ -22,6 +22,7 @@ from services.feature_service import FeatureService
logger = logging.getLogger(__name__)
@console_ns.route("/installed-apps")
class InstalledAppsListApi(Resource):
@login_required
@account_initialization_required
@ -154,6 +155,7 @@ class InstalledAppsListApi(Resource):
return {"message": "App installed successfully"}
@console_ns.route("/installed-apps/<uuid:installed_app_id>")
class InstalledAppApi(InstalledAppResource):
"""
update and delete an installed app
@ -185,7 +187,3 @@ class InstalledAppApi(InstalledAppResource):
db.session.commit()
return {"result": "success", "message": "App info updated successfully"}
api.add_resource(InstalledAppsListApi, "/installed-apps")
api.add_resource(InstalledAppApi, "/installed-apps/<uuid:installed_app_id>")

View File

@ -1,7 +1,7 @@
from flask_restx import marshal_with
from controllers.common import fields
from controllers.console import api
from controllers.console import console_ns
from controllers.console.app.error import AppUnavailableError
from controllers.console.explore.wraps import InstalledAppResource
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
@ -9,6 +9,7 @@ from models.model import AppMode, InstalledApp
from services.app_service import AppService
@console_ns.route("/installed-apps/<uuid:installed_app_id>/parameters", endpoint="installed_app_parameters")
class AppParameterApi(InstalledAppResource):
"""Resource for app variables."""
@ -39,6 +40,7 @@ class AppParameterApi(InstalledAppResource):
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
@console_ns.route("/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")
class ExploreAppMetaApi(InstalledAppResource):
def get(self, installed_app: InstalledApp):
"""Get app meta"""
@ -46,9 +48,3 @@ class ExploreAppMetaApi(InstalledAppResource):
if not app_model:
raise ValueError("App not found")
return AppService().get_app_meta(app_model)
api.add_resource(
AppParameterApi, "/installed-apps/<uuid:installed_app_id>/parameters", endpoint="installed_app_parameters"
)
api.add_resource(ExploreAppMetaApi, "/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")

View File

@ -1,7 +1,7 @@
from flask_restx import Resource, fields, marshal_with, reqparse
from constants.languages import languages
from controllers.console import api
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required
from libs.helper import AppIconUrlField
from libs.login import current_user, login_required
@ -36,6 +36,7 @@ recommended_app_list_fields = {
}
@console_ns.route("/explore/apps")
class RecommendedAppListApi(Resource):
@login_required
@account_initialization_required
@ -57,13 +58,10 @@ class RecommendedAppListApi(Resource):
return RecommendedAppService.get_recommended_apps_and_categories(language_prefix)
@console_ns.route("/explore/apps/<uuid:app_id>")
class RecommendedAppApi(Resource):
@login_required
@account_initialization_required
def get(self, app_id):
app_id = str(app_id)
return RecommendedAppService.get_recommend_app_detail(app_id)
api.add_resource(RecommendedAppListApi, "/explore/apps")
api.add_resource(RecommendedAppApi, "/explore/apps/<uuid:app_id>")

View File

@ -2,7 +2,7 @@ from flask_restx import fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import message_file_fields
@ -25,6 +25,7 @@ message_fields = {
}
@console_ns.route("/installed-apps/<uuid:installed_app_id>/saved-messages", endpoint="installed_app_saved_messages")
class SavedMessageListApi(InstalledAppResource):
saved_message_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
@ -66,6 +67,9 @@ class SavedMessageListApi(InstalledAppResource):
return {"result": "success"}
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>", endpoint="installed_app_saved_message"
)
class SavedMessageApi(InstalledAppResource):
def delete(self, installed_app, message_id):
app_model = installed_app.app
@ -80,15 +84,3 @@ class SavedMessageApi(InstalledAppResource):
SavedMessageService.delete(app_model, current_user, message_id)
return {"result": "success"}, 204
api.add_resource(
SavedMessageListApi,
"/installed-apps/<uuid:installed_app_id>/saved-messages",
endpoint="installed_app_saved_messages",
)
api.add_resource(
SavedMessageApi,
"/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>",
endpoint="installed_app_saved_message",
)

View File

@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from constants.languages import supported_language
from controllers.console import api
from controllers.console import console_ns
from controllers.console.auth.error import (
EmailAlreadyInUseError,
EmailChangeLimitError,
@ -46,6 +46,7 @@ from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@console_ns.route("/account/init")
class AccountInitApi(Resource):
@setup_required
@login_required
@ -98,6 +99,7 @@ class AccountInitApi(Resource):
return {"result": "success"}
@console_ns.route("/account/profile")
class AccountProfileApi(Resource):
@setup_required
@login_required
@ -110,6 +112,7 @@ class AccountProfileApi(Resource):
return current_user
@console_ns.route("/account/name")
class AccountNameApi(Resource):
@setup_required
@login_required
@ -131,6 +134,7 @@ class AccountNameApi(Resource):
return updated_account
@console_ns.route("/account/avatar")
class AccountAvatarApi(Resource):
@setup_required
@login_required
@ -159,6 +163,7 @@ class AccountAvatarApi(Resource):
return updated_account
@console_ns.route("/account/interface-language")
class AccountInterfaceLanguageApi(Resource):
@setup_required
@login_required
@ -176,6 +181,7 @@ class AccountInterfaceLanguageApi(Resource):
return updated_account
@console_ns.route("/account/interface-theme")
class AccountInterfaceThemeApi(Resource):
@setup_required
@login_required
@ -193,6 +199,7 @@ class AccountInterfaceThemeApi(Resource):
return updated_account
@console_ns.route("/account/timezone")
class AccountTimezoneApi(Resource):
@setup_required
@login_required
@ -214,6 +221,7 @@ class AccountTimezoneApi(Resource):
return updated_account
@console_ns.route("/account/password")
class AccountPasswordApi(Resource):
@setup_required
@login_required
@ -239,6 +247,7 @@ class AccountPasswordApi(Resource):
return {"result": "success"}
@console_ns.route("/account/integrates")
class AccountIntegrateApi(Resource):
integrate_fields = {
"provider": fields.String,
@ -295,6 +304,7 @@ class AccountIntegrateApi(Resource):
return {"data": integrate_data}
@console_ns.route("/account/delete/verify")
class AccountDeleteVerifyApi(Resource):
@setup_required
@login_required
@ -310,6 +320,7 @@ class AccountDeleteVerifyApi(Resource):
return {"result": "success", "data": token}
@console_ns.route("/account/delete")
class AccountDeleteApi(Resource):
@setup_required
@login_required
@ -332,6 +343,7 @@ class AccountDeleteApi(Resource):
return {"result": "success"}
@console_ns.route("/account/delete/feedback")
class AccountDeleteUpdateFeedbackApi(Resource):
@setup_required
def post(self):
@ -345,6 +357,7 @@ class AccountDeleteUpdateFeedbackApi(Resource):
return {"result": "success"}
@console_ns.route("/account/education/verify")
class EducationVerifyApi(Resource):
verify_fields = {
"token": fields.String,
@ -364,6 +377,7 @@ class EducationVerifyApi(Resource):
return BillingService.EducationIdentity.verify(account.id, account.email)
@console_ns.route("/account/education")
class EducationApi(Resource):
status_fields = {
"result": fields.Boolean,
@ -408,6 +422,7 @@ class EducationApi(Resource):
return res
@console_ns.route("/account/education/autocomplete")
class EducationAutoCompleteApi(Resource):
data_fields = {
"data": fields.List(fields.String),
@ -431,6 +446,7 @@ class EducationAutoCompleteApi(Resource):
return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])
@console_ns.route("/account/change-email")
class ChangeEmailSendEmailApi(Resource):
@enable_change_email
@setup_required
@ -479,6 +495,7 @@ class ChangeEmailSendEmailApi(Resource):
return {"result": "success", "data": token}
@console_ns.route("/account/change-email/validity")
class ChangeEmailCheckApi(Resource):
@enable_change_email
@setup_required
@ -520,6 +537,7 @@ class ChangeEmailCheckApi(Resource):
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@console_ns.route("/account/change-email/reset")
class ChangeEmailResetApi(Resource):
@enable_change_email
@setup_required
@ -559,6 +577,7 @@ class ChangeEmailResetApi(Resource):
return updated_account
@console_ns.route("/account/change-email/check-email-unique")
class CheckEmailUnique(Resource):
@setup_required
def post(self):
@ -570,28 +589,3 @@ class CheckEmailUnique(Resource):
if not AccountService.check_email_unique(args["email"]):
raise EmailAlreadyInUseError()
return {"result": "success"}
# Register API resources
api.add_resource(AccountInitApi, "/account/init")
api.add_resource(AccountProfileApi, "/account/profile")
api.add_resource(AccountNameApi, "/account/name")
api.add_resource(AccountAvatarApi, "/account/avatar")
api.add_resource(AccountInterfaceLanguageApi, "/account/interface-language")
api.add_resource(AccountInterfaceThemeApi, "/account/interface-theme")
api.add_resource(AccountTimezoneApi, "/account/timezone")
api.add_resource(AccountPasswordApi, "/account/password")
api.add_resource(AccountIntegrateApi, "/account/integrates")
api.add_resource(AccountDeleteVerifyApi, "/account/delete/verify")
api.add_resource(AccountDeleteApi, "/account/delete")
api.add_resource(AccountDeleteUpdateFeedbackApi, "/account/delete/feedback")
api.add_resource(EducationVerifyApi, "/account/education/verify")
api.add_resource(EducationApi, "/account/education")
api.add_resource(EducationAutoCompleteApi, "/account/education/autocomplete")
# Change email
api.add_resource(ChangeEmailSendEmailApi, "/account/change-email")
api.add_resource(ChangeEmailCheckApi, "/account/change-email/validity")
api.add_resource(ChangeEmailResetApi, "/account/change-email/reset")
api.add_resource(CheckEmailUnique, "/account/change-email/check-email-unique")
# api.add_resource(AccountEmailApi, '/account/email')
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')

View File

@ -1,7 +1,7 @@
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
@ -10,6 +10,9 @@ from models.account import Account, TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate"
)
class LoadBalancingCredentialsValidateApi(Resource):
@setup_required
@login_required
@ -61,6 +64,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
return response
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate"
)
class LoadBalancingConfigCredentialsValidateApi(Resource):
@setup_required
@login_required
@ -111,15 +117,3 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
response["error"] = error
return response
# Load Balancing Config
api.add_resource(
LoadBalancingCredentialsValidateApi,
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate",
)
api.add_resource(
LoadBalancingConfigCredentialsValidateApi,
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate",
)

View File

@ -6,7 +6,7 @@ from flask_restx import Resource, marshal_with, reqparse
import services
from configs import dify_config
from controllers.console import api
from controllers.console import console_ns
from controllers.console.auth.error import (
CannotTransferOwnerToSelfError,
EmailCodeError,
@ -33,6 +33,7 @@ from services.errors.account import AccountAlreadyInTenantError
from services.feature_service import FeatureService
@console_ns.route("/workspaces/current/members")
class MemberListApi(Resource):
"""List all members of current tenant."""
@ -49,6 +50,7 @@ class MemberListApi(Resource):
return {"result": "success", "accounts": members}, 200
@console_ns.route("/workspaces/current/members/invite-email")
class MemberInviteEmailApi(Resource):
"""Invite a new member by email."""
@ -111,6 +113,7 @@ class MemberInviteEmailApi(Resource):
}, 201
@console_ns.route("/workspaces/current/members/<uuid:member_id>")
class MemberCancelInviteApi(Resource):
"""Cancel an invitation by member id."""
@ -143,6 +146,7 @@ class MemberCancelInviteApi(Resource):
}, 200
@console_ns.route("/workspaces/current/members/<uuid:member_id>/update-role")
class MemberUpdateRoleApi(Resource):
"""Update member role."""
@ -177,6 +181,7 @@ class MemberUpdateRoleApi(Resource):
return {"result": "success"}
@console_ns.route("/workspaces/current/dataset-operators")
class DatasetOperatorMemberListApi(Resource):
"""List all members of current tenant."""
@ -193,6 +198,7 @@ class DatasetOperatorMemberListApi(Resource):
return {"result": "success", "accounts": members}, 200
@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email")
class SendOwnerTransferEmailApi(Resource):
"""Send owner transfer email."""
@ -233,6 +239,7 @@ class SendOwnerTransferEmailApi(Resource):
return {"result": "success", "data": token}
@console_ns.route("/workspaces/current/members/owner-transfer-check")
class OwnerTransferCheckApi(Resource):
@setup_required
@login_required
@ -278,6 +285,7 @@ class OwnerTransferCheckApi(Resource):
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@console_ns.route("/workspaces/current/members/<uuid:member_id>/owner-transfer")
class OwnerTransfer(Resource):
@setup_required
@login_required
@ -339,14 +347,3 @@ class OwnerTransfer(Resource):
raise ValueError(str(e))
return {"result": "success"}
api.add_resource(MemberListApi, "/workspaces/current/members")
api.add_resource(MemberInviteEmailApi, "/workspaces/current/members/invite-email")
api.add_resource(MemberCancelInviteApi, "/workspaces/current/members/<uuid:member_id>")
api.add_resource(MemberUpdateRoleApi, "/workspaces/current/members/<uuid:member_id>/update-role")
api.add_resource(DatasetOperatorMemberListApi, "/workspaces/current/dataset-operators")
# owner transfer
api.add_resource(SendOwnerTransferEmailApi, "/workspaces/current/members/send-owner-transfer-confirm-email")
api.add_resource(OwnerTransferCheckApi, "/workspaces/current/members/owner-transfer-check")
api.add_resource(OwnerTransfer, "/workspaces/current/members/<uuid:member_id>/owner-transfer")

View File

@ -5,7 +5,7 @@ from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
@ -17,6 +17,7 @@ from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService
@console_ns.route("/workspaces/current/model-providers")
class ModelProviderListApi(Resource):
@setup_required
@login_required
@ -45,6 +46,7 @@ class ModelProviderListApi(Resource):
return jsonable_encoder({"data": provider_list})
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials")
class ModelProviderCredentialApi(Resource):
@setup_required
@login_required
@ -151,6 +153,7 @@ class ModelProviderCredentialApi(Resource):
return {"result": "success"}, 204
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch")
class ModelProviderCredentialSwitchApi(Resource):
@setup_required
@login_required
@ -175,6 +178,7 @@ class ModelProviderCredentialSwitchApi(Resource):
return {"result": "success"}
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/validate")
class ModelProviderValidateApi(Resource):
@setup_required
@login_required
@ -211,6 +215,7 @@ class ModelProviderValidateApi(Resource):
return response
@console_ns.route("/workspaces/<string:tenant_id>/model-providers/<path:provider>/<string:icon_type>/<string:lang>")
class ModelProviderIconApi(Resource):
"""
Get model provider icon
@ -229,6 +234,7 @@ class ModelProviderIconApi(Resource):
return send_file(io.BytesIO(icon), mimetype=mimetype)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/preferred-provider-type")
class PreferredProviderTypeUpdateApi(Resource):
@setup_required
@login_required
@ -262,6 +268,7 @@ class PreferredProviderTypeUpdateApi(Resource):
return {"result": "success"}
@console_ns.route("/workspaces/current/model-providers/<path:provider>/checkout-url")
class ModelProviderPaymentCheckoutUrlApi(Resource):
@setup_required
@login_required
@ -281,21 +288,3 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
prefilled_email=current_user.email,
)
return data
api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers")
api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<path:provider>/credentials")
api.add_resource(
ModelProviderCredentialSwitchApi, "/workspaces/current/model-providers/<path:provider>/credentials/switch"
)
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate")
api.add_resource(
PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<path:provider>/preferred-provider-type"
)
api.add_resource(ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<path:provider>/checkout-url")
api.add_resource(
ModelProviderIconApi,
"/workspaces/<string:tenant_id>/model-providers/<path:provider>/<string:icon_type>/<string:lang>",
)

View File

@ -4,7 +4,7 @@ from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
@ -17,6 +17,7 @@ from services.model_provider_service import ModelProviderService
logger = logging.getLogger(__name__)
@console_ns.route("/workspaces/current/default-model")
class DefaultModelApi(Resource):
@setup_required
@login_required
@ -85,6 +86,7 @@ class DefaultModelApi(Resource):
return {"result": "success"}
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models")
class ModelProviderModelApi(Resource):
@setup_required
@login_required
@ -187,6 +189,7 @@ class ModelProviderModelApi(Resource):
return {"result": "success"}, 204
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials")
class ModelProviderModelCredentialApi(Resource):
@setup_required
@login_required
@ -364,6 +367,7 @@ class ModelProviderModelCredentialApi(Resource):
return {"result": "success"}, 204
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch")
class ModelProviderModelCredentialSwitchApi(Resource):
@setup_required
@login_required
@ -395,6 +399,9 @@ class ModelProviderModelCredentialSwitchApi(Resource):
return {"result": "success"}
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/enable", endpoint="model-provider-model-enable"
)
class ModelProviderModelEnableApi(Resource):
@setup_required
@login_required
@ -422,6 +429,9 @@ class ModelProviderModelEnableApi(Resource):
return {"result": "success"}
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/disable", endpoint="model-provider-model-disable"
)
class ModelProviderModelDisableApi(Resource):
@setup_required
@login_required
@ -449,6 +459,7 @@ class ModelProviderModelDisableApi(Resource):
return {"result": "success"}
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/validate")
class ModelProviderModelValidateApi(Resource):
@setup_required
@login_required
@ -494,6 +505,7 @@ class ModelProviderModelValidateApi(Resource):
return response
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/parameter-rules")
class ModelProviderModelParameterRuleApi(Resource):
@setup_required
@login_required
@ -513,6 +525,7 @@ class ModelProviderModelParameterRuleApi(Resource):
return jsonable_encoder({"data": parameter_rules})
@console_ns.route("/workspaces/current/models/model-types/<string:model_type>")
class ModelProviderAvailableModelApi(Resource):
@setup_required
@login_required
@ -524,32 +537,3 @@ class ModelProviderAvailableModelApi(Resource):
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)
return jsonable_encoder({"data": models})
api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<path:provider>/models")
api.add_resource(
ModelProviderModelEnableApi,
"/workspaces/current/model-providers/<path:provider>/models/enable",
endpoint="model-provider-model-enable",
)
api.add_resource(
ModelProviderModelDisableApi,
"/workspaces/current/model-providers/<path:provider>/models/disable",
endpoint="model-provider-model-disable",
)
api.add_resource(
ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<path:provider>/models/credentials"
)
api.add_resource(
ModelProviderModelCredentialSwitchApi,
"/workspaces/current/model-providers/<path:provider>/models/credentials/switch",
)
api.add_resource(
ModelProviderModelValidateApi, "/workspaces/current/model-providers/<path:provider>/models/credentials/validate"
)
api.add_resource(
ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<path:provider>/models/parameter-rules"
)
api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")
api.add_resource(DefaultModelApi, "/workspaces/current/default-model")

View File

@ -6,7 +6,7 @@ from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console import console_ns
from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
@ -19,6 +19,7 @@ from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService
@console_ns.route("/workspaces/current/plugin/debugging-key")
class PluginDebuggingKeyApi(Resource):
@setup_required
@login_required
@ -37,6 +38,7 @@ class PluginDebuggingKeyApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/list")
class PluginListApi(Resource):
@setup_required
@login_required
@ -55,6 +57,7 @@ class PluginListApi(Resource):
return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
@console_ns.route("/workspaces/current/plugin/list/latest-versions")
class PluginListLatestVersionsApi(Resource):
@setup_required
@login_required
@ -72,6 +75,7 @@ class PluginListLatestVersionsApi(Resource):
return jsonable_encoder({"versions": versions})
@console_ns.route("/workspaces/current/plugin/list/installations/ids")
class PluginListInstallationsFromIdsApi(Resource):
@setup_required
@login_required
@ -91,6 +95,7 @@ class PluginListInstallationsFromIdsApi(Resource):
return jsonable_encoder({"plugins": plugins})
@console_ns.route("/workspaces/current/plugin/icon")
class PluginIconApi(Resource):
@setup_required
def get(self):
@ -108,6 +113,7 @@ class PluginIconApi(Resource):
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
@console_ns.route("/workspaces/current/plugin/upload/pkg")
class PluginUploadFromPkgApi(Resource):
@setup_required
@login_required
@ -131,6 +137,7 @@ class PluginUploadFromPkgApi(Resource):
return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/upload/github")
class PluginUploadFromGithubApi(Resource):
@setup_required
@login_required
@ -153,6 +160,7 @@ class PluginUploadFromGithubApi(Resource):
return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/upload/bundle")
class PluginUploadFromBundleApi(Resource):
@setup_required
@login_required
@ -176,6 +184,7 @@ class PluginUploadFromBundleApi(Resource):
return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/install/pkg")
class PluginInstallFromPkgApi(Resource):
@setup_required
@login_required
@ -201,6 +210,7 @@ class PluginInstallFromPkgApi(Resource):
return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/install/github")
class PluginInstallFromGithubApi(Resource):
@setup_required
@login_required
@ -230,6 +240,7 @@ class PluginInstallFromGithubApi(Resource):
return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/install/marketplace")
class PluginInstallFromMarketplaceApi(Resource):
@setup_required
@login_required
@ -255,6 +266,7 @@ class PluginInstallFromMarketplaceApi(Resource):
return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/marketplace/pkg")
class PluginFetchMarketplacePkgApi(Resource):
@setup_required
@login_required
@ -280,6 +292,7 @@ class PluginFetchMarketplacePkgApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/fetch-manifest")
class PluginFetchManifestApi(Resource):
@setup_required
@login_required
@ -304,6 +317,7 @@ class PluginFetchManifestApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/tasks")
class PluginFetchInstallTasksApi(Resource):
@setup_required
@login_required
@ -325,6 +339,7 @@ class PluginFetchInstallTasksApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/tasks/<task_id>")
class PluginFetchInstallTaskApi(Resource):
@setup_required
@login_required
@ -339,6 +354,7 @@ class PluginFetchInstallTaskApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/tasks/<task_id>/delete")
class PluginDeleteInstallTaskApi(Resource):
@setup_required
@login_required
@ -353,6 +369,7 @@ class PluginDeleteInstallTaskApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/tasks/delete_all")
class PluginDeleteAllInstallTaskItemsApi(Resource):
@setup_required
@login_required
@ -367,6 +384,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
class PluginDeleteInstallTaskItemApi(Resource):
@setup_required
@login_required
@ -381,6 +399,7 @@ class PluginDeleteInstallTaskItemApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/upgrade/marketplace")
class PluginUpgradeFromMarketplaceApi(Resource):
@setup_required
@login_required
@ -404,6 +423,7 @@ class PluginUpgradeFromMarketplaceApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/upgrade/github")
class PluginUpgradeFromGithubApi(Resource):
@setup_required
@login_required
@ -435,6 +455,7 @@ class PluginUpgradeFromGithubApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/uninstall")
class PluginUninstallApi(Resource):
@setup_required
@login_required
@ -453,6 +474,7 @@ class PluginUninstallApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/permission/change")
class PluginChangePermissionApi(Resource):
@setup_required
@login_required
@ -475,6 +497,7 @@ class PluginChangePermissionApi(Resource):
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
@console_ns.route("/workspaces/current/plugin/permission/fetch")
class PluginFetchPermissionApi(Resource):
@setup_required
@login_required
@ -499,6 +522,7 @@ class PluginFetchPermissionApi(Resource):
)
@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options")
class PluginFetchDynamicSelectOptionsApi(Resource):
@setup_required
@login_required
@ -535,6 +559,7 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
return jsonable_encoder({"options": options})
@console_ns.route("/workspaces/current/plugin/preferences/change")
class PluginChangePreferencesApi(Resource):
@setup_required
@login_required
@ -590,6 +615,7 @@ class PluginChangePreferencesApi(Resource):
return jsonable_encoder({"success": True})
@console_ns.route("/workspaces/current/plugin/preferences/fetch")
class PluginFetchPreferencesApi(Resource):
@setup_required
@login_required
@ -628,6 +654,7 @@ class PluginFetchPreferencesApi(Resource):
return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict})
@console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude")
class PluginAutoUpgradeExcludePluginApi(Resource):
@setup_required
@login_required
@ -641,35 +668,3 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
args = req.parse_args()
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})
api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
api.add_resource(PluginListApi, "/workspaces/current/plugin/list")
api.add_resource(PluginListLatestVersionsApi, "/workspaces/current/plugin/list/latest-versions")
api.add_resource(PluginListInstallationsFromIdsApi, "/workspaces/current/plugin/list/installations/ids")
api.add_resource(PluginIconApi, "/workspaces/current/plugin/icon")
api.add_resource(PluginUploadFromPkgApi, "/workspaces/current/plugin/upload/pkg")
api.add_resource(PluginUploadFromGithubApi, "/workspaces/current/plugin/upload/github")
api.add_resource(PluginUploadFromBundleApi, "/workspaces/current/plugin/upload/bundle")
api.add_resource(PluginInstallFromPkgApi, "/workspaces/current/plugin/install/pkg")
api.add_resource(PluginInstallFromGithubApi, "/workspaces/current/plugin/install/github")
api.add_resource(PluginUpgradeFromMarketplaceApi, "/workspaces/current/plugin/upgrade/marketplace")
api.add_resource(PluginUpgradeFromGithubApi, "/workspaces/current/plugin/upgrade/github")
api.add_resource(PluginInstallFromMarketplaceApi, "/workspaces/current/plugin/install/marketplace")
api.add_resource(PluginFetchManifestApi, "/workspaces/current/plugin/fetch-manifest")
api.add_resource(PluginFetchInstallTasksApi, "/workspaces/current/plugin/tasks")
api.add_resource(PluginFetchInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>")
api.add_resource(PluginDeleteInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>/delete")
api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin/tasks/delete_all")
api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall")
api.add_resource(PluginFetchMarketplacePkgApi, "/workspaces/current/plugin/marketplace/pkg")
api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change")
api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")
api.add_resource(PluginFetchDynamicSelectOptionsApi, "/workspaces/current/plugin/parameters/dynamic-options")
api.add_resource(PluginFetchPreferencesApi, "/workspaces/current/plugin/preferences/fetch")
api.add_resource(PluginChangePreferencesApi, "/workspaces/current/plugin/preferences/change")
api.add_resource(PluginAutoUpgradeExcludePluginApi, "/workspaces/current/plugin/preferences/autoupgrade/exclude")

View File

@ -10,7 +10,7 @@ from flask_restx import (
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
enterprise_license_required,
@ -47,6 +47,7 @@ def is_valid_url(url: str) -> bool:
return False
@console_ns.route("/workspaces/current/tool-providers")
class ToolProviderListApi(Resource):
@setup_required
@login_required
@ -71,6 +72,7 @@ class ToolProviderListApi(Resource):
return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None))
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/tools")
class ToolBuiltinProviderListToolsApi(Resource):
@setup_required
@login_required
@ -88,6 +90,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/info")
class ToolBuiltinProviderInfoApi(Resource):
@setup_required
@login_required
@ -100,6 +103,7 @@ class ToolBuiltinProviderInfoApi(Resource):
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/delete")
class ToolBuiltinProviderDeleteApi(Resource):
@setup_required
@login_required
@ -121,6 +125,7 @@ class ToolBuiltinProviderDeleteApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/add")
class ToolBuiltinProviderAddApi(Resource):
@setup_required
@login_required
@ -150,6 +155,7 @@ class ToolBuiltinProviderAddApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/update")
class ToolBuiltinProviderUpdateApi(Resource):
@setup_required
@login_required
@ -181,6 +187,7 @@ class ToolBuiltinProviderUpdateApi(Resource):
return result
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/credentials")
class ToolBuiltinProviderGetCredentialsApi(Resource):
@setup_required
@login_required
@ -196,6 +203,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/icon")
class ToolBuiltinProviderIconApi(Resource):
@setup_required
def get(self, provider):
@ -204,6 +212,7 @@ class ToolBuiltinProviderIconApi(Resource):
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
@console_ns.route("/workspaces/current/tool-provider/api/add")
class ToolApiProviderAddApi(Resource):
@setup_required
@login_required
@ -243,6 +252,7 @@ class ToolApiProviderAddApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/api/remote")
class ToolApiProviderGetRemoteSchemaApi(Resource):
@setup_required
@login_required
@ -266,6 +276,7 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/api/tools")
class ToolApiProviderListToolsApi(Resource):
@setup_required
@login_required
@ -291,6 +302,7 @@ class ToolApiProviderListToolsApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/api/update")
class ToolApiProviderUpdateApi(Resource):
@setup_required
@login_required
@ -332,6 +344,7 @@ class ToolApiProviderUpdateApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/api/delete")
class ToolApiProviderDeleteApi(Resource):
@setup_required
@login_required
@ -358,6 +371,7 @@ class ToolApiProviderDeleteApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/api/get")
class ToolApiProviderGetApi(Resource):
@setup_required
@login_required
@ -381,6 +395,7 @@ class ToolApiProviderGetApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/credential/schema/<path:credential_type>")
class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@setup_required
@login_required
@ -396,6 +411,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/api/schema")
class ToolApiProviderSchemaApi(Resource):
@setup_required
@login_required
@ -412,6 +428,7 @@ class ToolApiProviderSchemaApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/api/test/pre")
class ToolApiProviderPreviousTestApi(Resource):
@setup_required
@login_required
@ -439,6 +456,7 @@ class ToolApiProviderPreviousTestApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/workflow/create")
class ToolWorkflowProviderCreateApi(Resource):
@setup_required
@login_required
@ -478,6 +496,7 @@ class ToolWorkflowProviderCreateApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/workflow/update")
class ToolWorkflowProviderUpdateApi(Resource):
@setup_required
@login_required
@ -520,6 +539,7 @@ class ToolWorkflowProviderUpdateApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/workflow/delete")
class ToolWorkflowProviderDeleteApi(Resource):
@setup_required
@login_required
@ -545,6 +565,7 @@ class ToolWorkflowProviderDeleteApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/workflow/get")
class ToolWorkflowProviderGetApi(Resource):
@setup_required
@login_required
@ -579,6 +600,7 @@ class ToolWorkflowProviderGetApi(Resource):
return jsonable_encoder(tool)
@console_ns.route("/workspaces/current/tool-provider/workflow/tools")
class ToolWorkflowProviderListToolApi(Resource):
@setup_required
@login_required
@ -603,6 +625,7 @@ class ToolWorkflowProviderListToolApi(Resource):
)
@console_ns.route("/workspaces/current/tools/builtin")
class ToolBuiltinListApi(Resource):
@setup_required
@login_required
@ -624,6 +647,7 @@ class ToolBuiltinListApi(Resource):
)
@console_ns.route("/workspaces/current/tools/api")
class ToolApiListApi(Resource):
@setup_required
@login_required
@ -642,6 +666,7 @@ class ToolApiListApi(Resource):
)
@console_ns.route("/workspaces/current/tools/workflow")
class ToolWorkflowListApi(Resource):
@setup_required
@login_required
@ -663,6 +688,7 @@ class ToolWorkflowListApi(Resource):
)
@console_ns.route("/workspaces/current/tool-labels")
class ToolLabelsApi(Resource):
@setup_required
@login_required
@ -672,6 +698,7 @@ class ToolLabelsApi(Resource):
return jsonable_encoder(ToolLabelsService.list_tool_labels())
@console_ns.route("/oauth/plugin/<path:provider>/tool/authorization-url")
class ToolPluginOAuthApi(Resource):
@setup_required
@login_required
@ -716,6 +743,7 @@ class ToolPluginOAuthApi(Resource):
return response
@console_ns.route("/oauth/plugin/<path:provider>/tool/callback")
class ToolOAuthCallback(Resource):
@setup_required
def get(self, provider):
@ -766,6 +794,7 @@ class ToolOAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/default-credential")
class ToolBuiltinProviderSetDefaultApi(Resource):
@setup_required
@login_required
@ -779,6 +808,7 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
class ToolOAuthCustomClient(Resource):
@setup_required
@login_required
@ -822,6 +852,7 @@ class ToolOAuthCustomClient(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/oauth/client-schema")
class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
@setup_required
@login_required
@ -834,6 +865,7 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/credential/info")
class ToolBuiltinProviderGetCredentialInfoApi(Resource):
@setup_required
@login_required
@ -849,6 +881,7 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/mcp")
class ToolProviderMCPApi(Resource):
@setup_required
@login_required
@ -933,6 +966,7 @@ class ToolProviderMCPApi(Resource):
return {"result": "success"}
@console_ns.route("/workspaces/current/tool-provider/mcp/auth")
class ToolMCPAuthApi(Resource):
@setup_required
@login_required
@ -978,6 +1012,7 @@ class ToolMCPAuthApi(Resource):
raise ValueError(f"Failed to connect to MCP server: {e}") from e
@console_ns.route("/workspaces/current/tool-provider/mcp/tools/<path:provider_id>")
class ToolMCPDetailApi(Resource):
@setup_required
@login_required
@ -988,6 +1023,7 @@ class ToolMCPDetailApi(Resource):
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
@console_ns.route("/workspaces/current/tools/mcp")
class ToolMCPListAllApi(Resource):
@setup_required
@login_required
@ -1001,6 +1037,7 @@ class ToolMCPListAllApi(Resource):
return [tool.to_dict() for tool in tools]
@console_ns.route("/workspaces/current/tool-provider/mcp/update/<path:provider_id>")
class ToolMCPUpdateApi(Resource):
@setup_required
@login_required
@ -1014,6 +1051,7 @@ class ToolMCPUpdateApi(Resource):
return jsonable_encoder(tools)
@console_ns.route("/mcp/oauth/callback")
class ToolMCPCallbackApi(Resource):
def get(self):
parser = reqparse.RequestParser()
@ -1024,67 +1062,3 @@ class ToolMCPCallbackApi(Resource):
authorization_code = args["code"]
handle_callback(state_key, authorization_code)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
# tool provider
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
# tool oauth
api.add_resource(ToolPluginOAuthApi, "/oauth/plugin/<path:provider>/tool/authorization-url")
api.add_resource(ToolOAuthCallback, "/oauth/plugin/<path:provider>/tool/callback")
api.add_resource(ToolOAuthCustomClient, "/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
# builtin tool provider
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools")
api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info")
api.add_resource(ToolBuiltinProviderAddApi, "/workspaces/current/tool-provider/builtin/<path:provider>/add")
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
api.add_resource(
ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin/<path:provider>/default-credential"
)
api.add_resource(
ToolBuiltinProviderGetCredentialInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credential/info"
)
api.add_resource(
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
)
api.add_resource(
ToolBuiltinProviderCredentialsSchemaApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/credential/schema/<path:credential_type>",
)
api.add_resource(
ToolBuiltinProviderGetOauthClientSchemaApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/oauth/client-schema",
)
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")
# api tool provider
api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add")
api.add_resource(ToolApiProviderGetRemoteSchemaApi, "/workspaces/current/tool-provider/api/remote")
api.add_resource(ToolApiProviderListToolsApi, "/workspaces/current/tool-provider/api/tools")
api.add_resource(ToolApiProviderUpdateApi, "/workspaces/current/tool-provider/api/update")
api.add_resource(ToolApiProviderDeleteApi, "/workspaces/current/tool-provider/api/delete")
api.add_resource(ToolApiProviderGetApi, "/workspaces/current/tool-provider/api/get")
api.add_resource(ToolApiProviderSchemaApi, "/workspaces/current/tool-provider/api/schema")
api.add_resource(ToolApiProviderPreviousTestApi, "/workspaces/current/tool-provider/api/test/pre")
# workflow tool provider
api.add_resource(ToolWorkflowProviderCreateApi, "/workspaces/current/tool-provider/workflow/create")
api.add_resource(ToolWorkflowProviderUpdateApi, "/workspaces/current/tool-provider/workflow/update")
api.add_resource(ToolWorkflowProviderDeleteApi, "/workspaces/current/tool-provider/workflow/delete")
api.add_resource(ToolWorkflowProviderGetApi, "/workspaces/current/tool-provider/workflow/get")
api.add_resource(ToolWorkflowProviderListToolApi, "/workspaces/current/tool-provider/workflow/tools")
# mcp tool provider
api.add_resource(ToolMCPDetailApi, "/workspaces/current/tool-provider/mcp/tools/<path:provider_id>")
api.add_resource(ToolProviderMCPApi, "/workspaces/current/tool-provider/mcp")
api.add_resource(ToolMCPUpdateApi, "/workspaces/current/tool-provider/mcp/update/<path:provider_id>")
api.add_resource(ToolMCPAuthApi, "/workspaces/current/tool-provider/mcp/auth")
api.add_resource(ToolMCPCallbackApi, "/mcp/oauth/callback")
api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin")
api.add_resource(ToolApiListApi, "/workspaces/current/tools/api")
api.add_resource(ToolMCPListAllApi, "/workspaces/current/tools/mcp")
api.add_resource(ToolWorkflowListApi, "/workspaces/current/tools/workflow")
api.add_resource(ToolLabelsApi, "/workspaces/current/tool-labels")

View File

@ -14,7 +14,7 @@ from controllers.common.errors import (
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.console import api
from controllers.console import console_ns
from controllers.console.admin import admin_required
from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import (
@ -68,6 +68,7 @@ tenants_fields = {
workspace_fields = {"id": fields.String, "name": fields.String, "status": fields.String, "created_at": TimestampField}
@console_ns.route("/workspaces")
class TenantListApi(Resource):
@setup_required
@login_required
@ -96,6 +97,7 @@ class TenantListApi(Resource):
return {"workspaces": marshal(tenant_dicts, tenants_fields)}, 200
@console_ns.route("/all-workspaces")
class WorkspaceListApi(Resource):
@setup_required
@admin_required
@ -121,6 +123,8 @@ class WorkspaceListApi(Resource):
}, 200
@console_ns.route("/workspaces/current", endpoint="workspaces_current")
@console_ns.route("/info", endpoint="info") # Deprecated
class TenantApi(Resource):
@setup_required
@login_required
@ -146,11 +150,10 @@ class TenantApi(Resource):
else:
raise Unauthorized("workspace is archived")
if not tenant:
raise ValueError("No tenant available")
return WorkspaceService.get_tenant_info(tenant), 200
@console_ns.route("/workspaces/switch")
class SwitchWorkspaceApi(Resource):
@setup_required
@login_required
@ -175,6 +178,7 @@ class SwitchWorkspaceApi(Resource):
return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)}
@console_ns.route("/workspaces/custom-config")
class CustomConfigWorkspaceApi(Resource):
@setup_required
@login_required
@ -205,6 +209,7 @@ class CustomConfigWorkspaceApi(Resource):
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
@console_ns.route("/workspaces/custom-config/webapp-logo/upload")
class WebappLogoWorkspaceApi(Resource):
@setup_required
@login_required
@ -245,6 +250,7 @@ class WebappLogoWorkspaceApi(Resource):
return {"id": upload_file.id}, 201
@console_ns.route("/workspaces/info")
class WorkspaceInfoApi(Resource):
@setup_required
@login_required
@ -264,13 +270,3 @@ class WorkspaceInfoApi(Resource):
db.session.commit()
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
api.add_resource(TenantListApi, "/workspaces") # GET for getting all tenants
api.add_resource(WorkspaceListApi, "/all-workspaces") # GET for getting all tenants
api.add_resource(TenantApi, "/workspaces/current", endpoint="workspaces_current") # GET for getting current tenant info
api.add_resource(TenantApi, "/info", endpoint="info") # Deprecated
api.add_resource(SwitchWorkspaceApi, "/workspaces/switch") # POST for switching tenant
api.add_resource(CustomConfigWorkspaceApi, "/workspaces/custom-config")
api.add_resource(WebappLogoWorkspaceApi, "/workspaces/custom-config/webapp-logo/upload")
api.add_resource(WorkspaceInfoApi, "/workspaces/info") # POST for changing workspace info

View File

@ -128,7 +128,7 @@ def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseMo
raise ValueError("invalid json")
try:
payload = payload_type(**data)
payload = payload_type.model_validate(data)
except Exception as e:
raise ValueError(f"invalid payload: {str(e)}")

View File

@ -19,6 +19,7 @@ from .app import (
annotation,
app,
audio,
chatflow_memory,
completion,
conversation,
file,
@ -40,6 +41,7 @@ __all__ = [
"annotation",
"app",
"audio",
"chatflow_memory",
"completion",
"conversation",
"dataset",

View File

@ -0,0 +1,124 @@
from flask_restx import Resource, reqparse
from controllers.service_api import api
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.memory.entities import MemoryBlock, MemoryCreatedBy
from core.workflow.entities.variable_pool import VariablePool
from models import App, EndUser
from services.chatflow_memory_service import ChatflowMemoryService
from services.workflow_service import WorkflowService
class MemoryListApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def get(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument("conversation_id", required=False, type=str | None, default=None)
parser.add_argument("memory_id", required=False, type=str | None, default=None)
parser.add_argument("version", required=False, type=int | None, default=None)
args = parser.parse_args()
conversation_id: str | None = args.get("conversation_id")
memory_id = args.get("memory_id")
version = args.get("version")
if conversation_id:
result = ChatflowMemoryService.get_persistent_memories_with_conversation(
app_model,
MemoryCreatedBy(end_user_id=end_user.id),
conversation_id,
version
)
session_memories = ChatflowMemoryService.get_session_memories_with_conversation(
app_model,
MemoryCreatedBy(end_user_id=end_user.id),
conversation_id,
version
)
result = [*result, *session_memories]
else:
result = ChatflowMemoryService.get_persistent_memories(
app_model,
MemoryCreatedBy(end_user_id=end_user.id),
version
)
if memory_id:
result = [it for it in result if it.spec.id == memory_id]
return [it for it in result if it.spec.end_user_visible]
class MemoryEditApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def put(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument('id', type=str, required=True)
parser.add_argument("conversation_id", type=str | None, required=False, default=None)
parser.add_argument('node_id', type=str | None, required=False, default=None)
parser.add_argument('update', type=str, required=True)
args = parser.parse_args()
workflow = WorkflowService().get_published_workflow(app_model)
update = args.get("update")
conversation_id = args.get("conversation_id")
node_id = args.get("node_id")
if not isinstance(update, str):
return {'error': 'Invalid update'}, 400
if not workflow:
return {'error': 'Workflow not found'}, 404
memory_spec = next((it for it in workflow.memory_blocks if it.id == args['id']), None)
if not memory_spec:
return {'error': 'Memory not found'}, 404
# First get existing memory
existing_memory = ChatflowMemoryService.get_memory_by_spec(
spec=memory_spec,
tenant_id=app_model.tenant_id,
app_id=app_model.id,
created_by=MemoryCreatedBy(end_user_id=end_user.id),
conversation_id=conversation_id,
node_id=node_id,
is_draft=False
)
# Create updated memory instance with incremented version
updated_memory = MemoryBlock(
spec=existing_memory.spec,
tenant_id=existing_memory.tenant_id,
app_id=existing_memory.app_id,
conversation_id=existing_memory.conversation_id,
node_id=existing_memory.node_id,
value=update, # New value
version=existing_memory.version + 1, # Increment version for update
edited_by_user=True,
created_by=existing_memory.created_by,
)
ChatflowMemoryService.save_memory(updated_memory, VariablePool(), False)
return '', 204
class MemoryDeleteApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def delete(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument('id', type=str, required=False, default=None)
args = parser.parse_args()
memory_id = args.get('id')
if memory_id:
ChatflowMemoryService.delete_memory(
app_model,
memory_id,
MemoryCreatedBy(end_user_id=end_user.id)
)
return '', 204
else:
ChatflowMemoryService.delete_all_user_memories(
app_model,
MemoryCreatedBy(end_user_id=end_user.id)
)
return '', 200
api.add_resource(MemoryListApi, '/memories')
api.add_resource(MemoryEditApi, '/memory-edit')
api.add_resource(MemoryDeleteApi, '/memories')

View File

@ -280,7 +280,7 @@ class DatasetListApi(DatasetApiResource):
external_knowledge_id=args["external_knowledge_id"],
embedding_model_provider=args["embedding_model_provider"],
embedding_model_name=args["embedding_model"],
retrieval_model=RetrievalModel(**args["retrieval_model"])
retrieval_model=RetrievalModel.model_validate(args["retrieval_model"])
if args["retrieval_model"] is not None
else None,
)

View File

@ -136,7 +136,7 @@ class DocumentAddByTextApi(DatasetApiResource):
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
knowledge_config = KnowledgeConfig(**args)
knowledge_config = KnowledgeConfig.model_validate(args)
# validate args
DocumentService.document_create_args_validate(knowledge_config)
@ -221,7 +221,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
args["data_source"] = data_source
# validate args
args["original_document_id"] = str(document_id)
knowledge_config = KnowledgeConfig(**args)
knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config)
try:
@ -328,7 +328,7 @@ class DocumentAddByFileApi(DatasetApiResource):
}
args["data_source"] = data_source
# validate args
knowledge_config = KnowledgeConfig(**args)
knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config)
dataset_process_rule = dataset.latest_process_rule if "process_rule" not in args else None
@ -426,7 +426,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
# validate args
args["original_document_id"] = str(document_id)
knowledge_config = KnowledgeConfig(**args)
knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config)
try:

View File

@ -51,7 +51,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
def post(self, tenant_id, dataset_id):
"""Create metadata for a dataset."""
args = metadata_create_parser.parse_args()
metadata_args = MetadataArgs(**args)
metadata_args = MetadataArgs.model_validate(args)
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -200,7 +200,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
DatasetService.check_dataset_permission(dataset, current_user)
args = document_metadata_parser.parse_args()
metadata_args = MetadataOperationData(**args)
metadata_args = MetadataOperationData.model_validate(args)
MetadataService.update_documents_metadata(dataset, metadata_args)

View File

@ -98,7 +98,7 @@ class DatasourceNodeRunApi(DatasetApiResource):
parser.add_argument("is_published", type=bool, required=True, location="json")
args: ParseResult = parser.parse_args()
datasource_node_run_api_entity: DatasourceNodeRunApiEntity = DatasourceNodeRunApiEntity(**args)
datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(args)
assert isinstance(current_user, Account)
rag_pipeline_service: RagPipelineService = RagPipelineService()
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)

View File

@ -252,7 +252,7 @@ class DatasetSegmentApi(DatasetApiResource):
args = segment_update_parser.parse_args()
updated_segment = SegmentService.update_segment(
SegmentUpdateArgs(**args["segment"]), segment, document, dataset
SegmentUpdateArgs.model_validate(args["segment"]), segment, document, dataset
)
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200

View File

@ -18,6 +18,7 @@ web_ns = Namespace("web", description="Web application API operations", path="/"
from . import (
app,
audio,
chatflow_memory,
completion,
conversation,
feature,
@ -39,6 +40,7 @@ __all__ = [
"app",
"audio",
"bp",
"chatflow_memory",
"completion",
"conversation",
"feature",

View File

@ -0,0 +1,123 @@
from flask_restx import reqparse
from controllers.web import api
from controllers.web.wraps import WebApiResource
from core.memory.entities import MemoryBlock, MemoryCreatedBy
from core.workflow.entities.variable_pool import VariablePool
from models import App, EndUser
from services.chatflow_memory_service import ChatflowMemoryService
from services.workflow_service import WorkflowService
class MemoryListApi(WebApiResource):
def get(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument("conversation_id", required=False, type=str | None, default=None)
parser.add_argument("memory_id", required=False, type=str | None, default=None)
parser.add_argument("version", required=False, type=int | None, default=None)
args = parser.parse_args()
conversation_id: str | None = args.get("conversation_id")
memory_id = args.get("memory_id")
version = args.get("version")
if conversation_id:
result = ChatflowMemoryService.get_persistent_memories_with_conversation(
app_model,
MemoryCreatedBy(end_user_id=end_user.id),
conversation_id,
version
)
session_memories = ChatflowMemoryService.get_session_memories_with_conversation(
app_model,
MemoryCreatedBy(end_user_id=end_user.id),
conversation_id,
version
)
result = [*result, *session_memories]
else:
result = ChatflowMemoryService.get_persistent_memories(
app_model,
MemoryCreatedBy(end_user_id=end_user.id),
version
)
if memory_id:
result = [it for it in result if it.spec.id == memory_id]
return [it for it in result if it.spec.end_user_visible]
class MemoryEditApi(WebApiResource):
def put(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument('id', type=str, required=True)
parser.add_argument("conversation_id", type=str | None, required=False, default=None)
parser.add_argument('node_id', type=str | None, required=False, default=None)
parser.add_argument('update', type=str, required=True)
args = parser.parse_args()
workflow = WorkflowService().get_published_workflow(app_model)
update = args.get("update")
conversation_id = args.get("conversation_id")
node_id = args.get("node_id")
if not isinstance(update, str):
return {'error': 'Update must be a string'}, 400
if not workflow:
return {'error': 'Workflow not found'}, 404
memory_spec = next((it for it in workflow.memory_blocks if it.id == args['id']), None)
if not memory_spec:
return {'error': 'Memory not found'}, 404
if not memory_spec.end_user_editable:
return {'error': 'Memory not editable'}, 403
# First get existing memory
existing_memory = ChatflowMemoryService.get_memory_by_spec(
spec=memory_spec,
tenant_id=app_model.tenant_id,
app_id=app_model.id,
created_by=MemoryCreatedBy(end_user_id=end_user.id),
conversation_id=conversation_id,
node_id=node_id,
is_draft=False
)
# Create updated memory instance with incremented version
updated_memory = MemoryBlock(
spec=existing_memory.spec,
tenant_id=existing_memory.tenant_id,
app_id=existing_memory.app_id,
conversation_id=existing_memory.conversation_id,
node_id=existing_memory.node_id,
value=update, # New value
version=existing_memory.version + 1, # Increment version for update
edited_by_user=True,
created_by=existing_memory.created_by,
)
ChatflowMemoryService.save_memory(updated_memory, VariablePool(), False)
return '', 204
class MemoryDeleteApi(WebApiResource):
def delete(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument('id', type=str, required=False, default=None)
args = parser.parse_args()
memory_id = args.get('id')
if memory_id:
ChatflowMemoryService.delete_memory(
app_model,
memory_id,
MemoryCreatedBy(end_user_id=end_user.id)
)
return '', 204
else:
ChatflowMemoryService.delete_all_user_memories(
app_model,
MemoryCreatedBy(end_user_id=end_user.id)
)
return '', 200
api.add_resource(MemoryListApi, '/memories')
api.add_resource(MemoryEditApi, '/memory-edit')
api.add_resource(MemoryDeleteApi, '/memories')

View File

@ -126,6 +126,8 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
end_user_id = enterprise_user_decoded.get("end_user_id")
session_id = enterprise_user_decoded.get("session_id")
user_auth_type = enterprise_user_decoded.get("auth_type")
exchanged_token_expires_unix = enterprise_user_decoded.get("exp")
if not user_auth_type:
raise Unauthorized("Missing auth_type in the token.")
@ -169,8 +171,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
)
db.session.add(end_user)
db.session.commit()
exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
exp = int(exp_dt.timestamp())
exp = int((datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)).timestamp())
if exchanged_token_expires_unix:
exp = int(exchanged_token_expires_unix)
payload = {
"iss": site.id,
"sub": "Web API Passport",

View File

@ -40,7 +40,7 @@ class AgentConfigManager:
"credential_id": tool.get("credential_id", None),
}
agent_tools.append(AgentToolEntity(**agent_tool_properties))
agent_tools.append(AgentToolEntity.model_validate(agent_tool_properties))
if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in {
"react_router",

View File

@ -1,10 +1,11 @@
import logging
import time
from collections.abc import Mapping
from collections.abc import Mapping, MutableMapping
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from typing_extensions import override
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
@ -20,11 +21,14 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent,
)
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.memory.entities import MemoryCreatedBy, MemoryScope
from core.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_events import GraphRunSucceededEvent
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
@ -34,6 +38,8 @@ from models import Workflow
from models.enums import UserFrom
from models.model import App, Conversation, Message, MessageAnnotation
from models.workflow import ConversationVariable
from services.chatflow_history_service import ChatflowHistoryService
from services.chatflow_memory_service import ChatflowMemoryService
logger = logging.getLogger(__name__)
@ -70,6 +76,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self._app = app
def run(self):
ChatflowMemoryService.wait_for_sync_memory_completion(
workflow=self._workflow,
conversation_id=self.conversation.id
)
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
@ -133,6 +144,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
conversation_variables=conversation_variables,
memory_blocks=self._fetch_memory_blocks(),
)
# init graph
@ -177,6 +189,31 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
for event in generator:
self._handle_event(workflow_entry, event)
try:
self._check_app_memory_updates(variable_pool)
except Exception as e:
logger.exception("Failed to check app memory updates", exc_info=e)
@override
def _handle_event(self, workflow_entry: WorkflowEntry, event: Any) -> None:
super()._handle_event(workflow_entry, event)
if isinstance(event, GraphRunSucceededEvent):
workflow_outputs = event.outputs
if not workflow_outputs:
logger.warning("Chatflow output is empty.")
return
assistant_message = workflow_outputs.get('answer')
if not assistant_message:
logger.warning("Chatflow output does not contain 'answer'.")
return
if not isinstance(assistant_message, str):
logger.warning("Chatflow output 'answer' is not a string.")
return
try:
self._sync_conversation_to_chatflow_tables(assistant_message)
except Exception as e:
logger.exception("Failed to sync conversation to memory tables", exc_info=e)
def handle_input_moderation(
self,
app_record: App,
@ -374,3 +411,67 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
# Return combined list
return existing_variables + new_variables
def _fetch_memory_blocks(self) -> Mapping[str, str]:
"""fetch all memory blocks for current app"""
memory_blocks_dict: MutableMapping[str, str] = {}
is_draft = (self.application_generate_entity.invoke_from == InvokeFrom.DEBUGGER)
conversation_id = self.conversation.id
memory_block_specs = self._workflow.memory_blocks
# Get runtime memory values
memories = ChatflowMemoryService.get_memories_by_specs(
memory_block_specs=memory_block_specs,
tenant_id=self._workflow.tenant_id,
app_id=self._workflow.app_id,
node_id=None,
conversation_id=conversation_id,
is_draft=is_draft,
created_by=self._get_created_by(),
)
# Build memory_id -> value mapping
for memory in memories:
if memory.spec.scope == MemoryScope.APP:
# App level: use memory_id directly
memory_blocks_dict[memory.spec.id] = memory.value
else: # NODE scope
node_id = memory.node_id
if not node_id:
logger.warning("Memory block %s has no node_id, skip.", memory.spec.id)
continue
key = f"{node_id}.{memory.spec.id}"
memory_blocks_dict[key] = memory.value
return memory_blocks_dict
def _sync_conversation_to_chatflow_tables(self, assistant_message: str):
ChatflowHistoryService.save_app_message(
prompt_message=UserPromptMessage(content=(self.application_generate_entity.query)),
conversation_id=self.conversation.id,
app_id=self._workflow.app_id,
tenant_id=self._workflow.tenant_id
)
ChatflowHistoryService.save_app_message(
prompt_message=AssistantPromptMessage(content=assistant_message),
conversation_id=self.conversation.id,
app_id=self._workflow.app_id,
tenant_id=self._workflow.tenant_id
)
def _check_app_memory_updates(self, variable_pool: VariablePool):
is_draft = (self.application_generate_entity.invoke_from == InvokeFrom.DEBUGGER)
ChatflowMemoryService.update_app_memory_if_needed(
workflow=self._workflow,
conversation_id=self.conversation.id,
variable_pool=variable_pool,
is_draft=is_draft,
created_by=self._get_created_by()
)
def _get_created_by(self) -> MemoryCreatedBy:
if self.application_generate_entity.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}:
return MemoryCreatedBy(account_id=self.application_generate_entity.user_id)
else:
return MemoryCreatedBy(end_user_id=self.application_generate_entity.user_id)

View File

@ -61,9 +61,6 @@ class AppRunner:
if model_context_tokens is None:
return -1
if max_tokens is None:
max_tokens = 0
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
if prompt_tokens + max_tokens > model_context_tokens:

View File

@ -116,7 +116,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
rag_pipeline_variables = []
if workflow.rag_pipeline_variables:
for v in workflow.rag_pipeline_variables:
rag_pipeline_variable = RAGPipelineVariable(**v)
rag_pipeline_variable = RAGPipelineVariable.model_validate(v)
if (
rag_pipeline_variable.belong_to_node_id
in (self.application_generate_entity.start_node_id, "shared")

View File

@ -1,4 +1,4 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
class I18nObject(BaseModel):
@ -11,11 +11,12 @@ class I18nObject(BaseModel):
pt_BR: str | None = Field(default=None)
ja_JP: str | None = Field(default=None)
def __init__(self, **data):
super().__init__(**data)
@model_validator(mode="after")
def _(self):
self.zh_Hans = self.zh_Hans or self.en_US
self.pt_BR = self.pt_BR or self.en_US
self.ja_JP = self.ja_JP or self.en_US
return self
def to_dict(self) -> dict:
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}

View File

@ -5,7 +5,7 @@ from collections import defaultdict
from collections.abc import Iterator, Sequence
from json import JSONDecodeError
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, model_validator
from sqlalchemy import func, select
from sqlalchemy.orm import Session
@ -73,9 +73,8 @@ class ProviderConfiguration(BaseModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def __init__(self, **data):
super().__init__(**data)
@model_validator(mode="after")
def _(self):
if self.provider.provider not in original_provider_configurate_methods:
original_provider_configurate_methods[self.provider.provider] = []
for configurate_method in self.provider.configurate_methods:
@ -90,6 +89,7 @@ class ProviderConfiguration(BaseModel):
and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods
):
self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
return self
def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None:
"""

View File

@ -131,7 +131,7 @@ class CodeExecutor:
if (code := response_data.get("code")) != 0:
raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response_data.get('message')}")
response_code = CodeExecutionResponse(**response_data)
response_code = CodeExecutionResponse.model_validate(response_data)
if response_code.data.error:
raise CodeExecutionError(response_code.data.error)

View File

@ -26,7 +26,7 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP
response = httpx.post(url, json={"plugin_ids": plugin_ids}, headers={"X-Dify-Version": dify_config.project.version})
response.raise_for_status()
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
return [MarketplacePluginDeclaration.model_validate(plugin) for plugin in response.json()["data"]["plugins"]]
def batch_fetch_plugin_manifests_ignore_deserialization_error(
@ -41,7 +41,7 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
result: list[MarketplacePluginDeclaration] = []
for plugin in response.json()["data"]["plugins"]:
try:
result.append(MarketplacePluginDeclaration(**plugin))
result.append(MarketplacePluginDeclaration.model_validate(plugin))
except Exception:
pass

View File

@ -20,7 +20,7 @@ from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@ -357,14 +357,16 @@ class IndexingRunner:
raise ValueError("no notion import info found")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
notion_info={
"credential_id": data_source_info["credential_id"],
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"document": dataset_document,
"tenant_id": dataset_document.tenant_id,
},
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info["credential_id"],
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"document": dataset_document,
"tenant_id": dataset_document.tenant_id,
}
),
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
@ -378,14 +380,16 @@ class IndexingRunner:
raise ValueError("no website import info found")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value,
website_info={
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"tenant_id": dataset_document.tenant_id,
"url": data_source_info["url"],
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
},
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"tenant_id": dataset_document.tenant_id,
"url": data_source_info["url"],
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
),
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])

View File

@ -14,10 +14,12 @@ from core.llm_generator.prompts import (
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
LLM_MODIFY_CODE_SYSTEM,
LLM_MODIFY_PROMPT_SYSTEM,
MEMORY_UPDATE_PROMPT,
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
SYSTEM_STRUCTURED_OUTPUT_GENERATE,
WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
)
from core.memory.entities import MemoryBlock, MemoryBlockSpec
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
@ -27,6 +29,7 @@ from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from extensions.ext_storage import storage
@ -560,3 +563,35 @@ class LLMGenerator:
"Failed to invoke LLM model, model: %s", json.dumps(model_config.get("name")), exc_info=True
)
return {"error": f"An unexpected error occurred: {str(e)}"}
@staticmethod
def update_memory_block(
tenant_id: str,
visible_history: Sequence[tuple[str, str]],
variable_pool: VariablePool,
memory_block: MemoryBlock,
memory_spec: MemoryBlockSpec
) -> str:
model_instance = ModelManager().get_model_instance(
tenant_id=tenant_id,
provider=memory_spec.model.provider,
model=memory_spec.model.name,
model_type=ModelType.LLM,
)
formatted_history = ""
for sender, message in visible_history:
formatted_history += f"{sender}: {message}\n"
filled_instruction = variable_pool.convert_template(memory_spec.instruction).text
formatted_prompt = PromptTemplateParser(MEMORY_UPDATE_PROMPT).format(
inputs={
"formatted_history": formatted_history,
"current_value": memory_block.value,
"instruction": filled_instruction,
}
)
llm_result = model_instance.invoke_llm(
prompt_messages=[UserPromptMessage(content=formatted_prompt)],
model_parameters=memory_spec.model.completion_params,
stream=False,
)
return llm_result.message.get_text_content()

View File

@ -422,3 +422,18 @@ INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as ex
You should edit the prompt according to the IDEAL OUTPUT."""
INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}."""
MEMORY_UPDATE_PROMPT = """
Based on the following conversation history, update the memory content:
Conversation history:
{{formatted_history}}
Current memory:
{{current_value}}
Update instruction:
{{instruction}}
Please output only the updated memory content, no other text like greeting:
"""

View File

@ -294,7 +294,7 @@ class ClientSession(
method="completion/complete",
params=types.CompleteRequestParams(
ref=ref,
argument=types.CompletionArgument(**argument),
argument=types.CompletionArgument.model_validate(argument),
),
)
),

119
api/core/memory/entities.py Normal file
View File

@ -0,0 +1,119 @@
from __future__ import annotations
from enum import StrEnum
from typing import TYPE_CHECKING, Optional
from uuid import uuid4
from pydantic import BaseModel, Field
if TYPE_CHECKING:
from core.app.app_config.entities import ModelConfig
class MemoryScope(StrEnum):
"""Memory scope determined by node_id field"""
APP = "app" # node_id is None
NODE = "node" # node_id is not None
class MemoryTerm(StrEnum):
"""Memory term determined by conversation_id field"""
SESSION = "session" # conversation_id is not None
PERSISTENT = "persistent" # conversation_id is None
class MemoryStrategy(StrEnum):
ON_TURNS = "on_turns"
class MemoryScheduleMode(StrEnum):
SYNC = "sync"
ASYNC = "async"
class MemoryBlockSpec(BaseModel):
"""Memory block specification for workflow configuration"""
id: str = Field(
default_factory=lambda: str(uuid4()),
description="Unique identifier for the memory block",
)
name: str = Field(description="Display name of the memory block")
description: str = Field(default="", description="Description of the memory block")
template: str = Field(description="Initial template content for the memory")
instruction: str = Field(description="Instructions for updating the memory")
scope: MemoryScope = Field(description="Scope of the memory (app or node level)")
term: MemoryTerm = Field(description="Term of the memory (session or persistent)")
strategy: MemoryStrategy = Field(description="Update strategy for the memory")
update_turns: int = Field(gt=0, description="Number of turns between updates")
preserved_turns: int = Field(gt=0, description="Number of conversation turns to preserve")
schedule_mode: MemoryScheduleMode = Field(description="Synchronous or asynchronous update mode")
model: ModelConfig = Field(description="Model configuration for memory updates")
end_user_visible: bool = Field(default=False, description="Whether memory is visible to end users")
end_user_editable: bool = Field(default=False, description="Whether memory is editable by end users")
class MemoryCreatedBy(BaseModel):
end_user_id: str | None = None
account_id: str | None = None
class MemoryBlock(BaseModel):
"""Runtime memory block instance
Design Rules:
- app_id = None: Global memory (future feature, not implemented yet)
- app_id = str: App-specific memory
- conversation_id = None: Persistent memory (cross-conversation)
- conversation_id = str: Session memory (conversation-specific)
- node_id = None: App-level scope
- node_id = str: Node-level scope
These rules implicitly determine scope and term without redundant storage.
"""
spec: MemoryBlockSpec
tenant_id: str
value: str
app_id: str
conversation_id: Optional[str] = None
node_id: Optional[str] = None
edited_by_user: bool = False
created_by: MemoryCreatedBy
version: int = Field(description="Memory block version number")
class MemoryValueData(BaseModel):
value: str
edited_by_user: bool = False
class ChatflowConversationMetadata(BaseModel):
"""Metadata for chatflow conversation with visible message count"""
type: str = "mutable_visible_window"
visible_count: int = Field(gt=0, description="Number of visible messages to keep")
class MemoryBlockWithConversation(MemoryBlock):
"""MemoryBlock with optional conversation metadata for session memories"""
conversation_metadata: ChatflowConversationMetadata = Field(
description="Conversation metadata, only present for session memories"
)
@classmethod
def from_memory_block(
cls,
memory_block: MemoryBlock,
conversation_metadata: ChatflowConversationMetadata
) -> MemoryBlockWithConversation:
"""Create MemoryBlockWithConversation from MemoryBlock"""
return cls(
spec=memory_block.spec,
tenant_id=memory_block.tenant_id,
value=memory_block.value,
app_id=memory_block.app_id,
conversation_id=memory_block.conversation_id,
node_id=memory_block.node_id,
edited_by_user=memory_block.edited_by_user,
created_by=memory_block.created_by,
version=memory_block.version,
conversation_metadata=conversation_metadata
)

View File

@ -0,0 +1,6 @@
class MemorySyncTimeoutError(Exception):
def __init__(self, app_id: str, conversation_id: str):
self.app_id = app_id
self.conversation_id = conversation_id
self.message = "Memory synchronization timeout after 50 seconds"
super().__init__(self.message)

View File

@ -1,4 +1,4 @@
from pydantic import BaseModel
from pydantic import BaseModel, model_validator
class I18nObject(BaseModel):
@ -9,7 +9,8 @@ class I18nObject(BaseModel):
zh_Hans: str | None = None
en_US: str
def __init__(self, **data):
super().__init__(**data)
@model_validator(mode="after")
def _(self):
if not self.zh_Hans:
self.zh_Hans = self.en_US
return self

View File

@ -1,7 +1,7 @@
from collections.abc import Sequence
from enum import Enum, StrEnum, auto
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
@ -46,10 +46,11 @@ class FormOption(BaseModel):
value: str
show_on: list[FormShowOnObject] = []
def __init__(self, **data):
super().__init__(**data)
@model_validator(mode="after")
def _(self):
if not self.label:
self.label = I18nObject(en_US=self.value)
return self
class CredentialFormSchema(BaseModel):

View File

@ -269,17 +269,17 @@ class ModelProviderFactory:
}
if model_type == ModelType.LLM:
return LargeLanguageModel(**init_params) # type: ignore
return LargeLanguageModel.model_validate(init_params)
elif model_type == ModelType.TEXT_EMBEDDING:
return TextEmbeddingModel(**init_params) # type: ignore
return TextEmbeddingModel.model_validate(init_params)
elif model_type == ModelType.RERANK:
return RerankModel(**init_params) # type: ignore
return RerankModel.model_validate(init_params)
elif model_type == ModelType.SPEECH2TEXT:
return Speech2TextModel(**init_params) # type: ignore
return Speech2TextModel.model_validate(init_params)
elif model_type == ModelType.MODERATION:
return ModerationModel(**init_params) # type: ignore
return ModerationModel.model_validate(init_params)
elif model_type == ModelType.TTS:
return TTSModel(**init_params) # type: ignore
return TTSModel.model_validate(init_params)
def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
"""

View File

@ -51,7 +51,7 @@ class ApiModeration(Moderation):
params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query)
result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.model_dump())
return ModerationInputsResult(**result)
return ModerationInputsResult.model_validate(result)
return ModerationInputsResult(
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
@ -67,7 +67,7 @@ class ApiModeration(Moderation):
params = ModerationOutputParams(app_id=self.app_id, text=text)
result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.model_dump())
return ModerationOutputsResult(**result)
return ModerationOutputsResult.model_validate(result)
return ModerationOutputsResult(
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response

View File

@ -84,15 +84,15 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
for i in range(len(v)):
if v[i]["role"] == PromptMessageRole.USER.value:
v[i] = UserPromptMessage(**v[i])
v[i] = UserPromptMessage.model_validate(v[i])
elif v[i]["role"] == PromptMessageRole.ASSISTANT.value:
v[i] = AssistantPromptMessage(**v[i])
v[i] = AssistantPromptMessage.model_validate(v[i])
elif v[i]["role"] == PromptMessageRole.SYSTEM.value:
v[i] = SystemPromptMessage(**v[i])
v[i] = SystemPromptMessage.model_validate(v[i])
elif v[i]["role"] == PromptMessageRole.TOOL.value:
v[i] = ToolPromptMessage(**v[i])
v[i] = ToolPromptMessage.model_validate(v[i])
else:
v[i] = PromptMessage(**v[i])
v[i] = PromptMessage.model_validate(v[i])
return v

View File

@ -94,7 +94,7 @@ class BasePluginClient:
self,
method: str,
path: str,
type: type[T],
type_: type[T],
headers: dict | None = None,
data: bytes | dict | None = None,
params: dict | None = None,
@ -104,13 +104,13 @@ class BasePluginClient:
Make a stream request to the plugin daemon inner API and yield the response as a model.
"""
for line in self._stream_request(method, path, params, headers, data, files):
yield type(**json.loads(line)) # type: ignore
yield type_(**json.loads(line)) # type: ignore
def _request_with_model(
self,
method: str,
path: str,
type: type[T],
type_: type[T],
headers: dict | None = None,
data: bytes | None = None,
params: dict | None = None,
@ -120,13 +120,13 @@ class BasePluginClient:
Make a request to the plugin daemon inner API and return the response as a model.
"""
response = self._request(method, path, headers, data, params, files)
return type(**response.json()) # type: ignore
return type_(**response.json()) # type: ignore
def _request_with_plugin_daemon_response(
self,
method: str,
path: str,
type: type[T],
type_: type[T],
headers: dict | None = None,
data: bytes | dict | None = None,
params: dict | None = None,
@ -140,22 +140,22 @@ class BasePluginClient:
response = self._request(method, path, headers, data, params, files)
response.raise_for_status()
except HTTPError as e:
msg = f"Failed to request plugin daemon, status: {e.response.status_code}, url: {path}"
logger.exception(msg)
logger.exception("Failed to request plugin daemon, status: %s, url: %s", e.response.status_code, path)
raise e
except Exception as e:
msg = f"Failed to request plugin daemon, url: {path}"
logger.exception(msg)
logger.exception("Failed to request plugin daemon, url: %s", path)
raise ValueError(msg) from e
try:
json_response = response.json()
if transformer:
json_response = transformer(json_response)
rep = PluginDaemonBasicResponse[type](**json_response) # type: ignore
# https://stackoverflow.com/questions/59634937/variable-foo-class-is-not-valid-as-type-but-why
rep = PluginDaemonBasicResponse[type_].model_validate(json_response) # type: ignore
except Exception:
msg = (
f"Failed to parse response from plugin daemon to PluginDaemonBasicResponse [{str(type.__name__)}],"
f"Failed to parse response from plugin daemon to PluginDaemonBasicResponse [{str(type_.__name__)}],"
f" url: {path}"
)
logger.exception(msg)
@ -163,7 +163,7 @@ class BasePluginClient:
if rep.code != 0:
try:
error = PluginDaemonError(**json.loads(rep.message))
error = PluginDaemonError.model_validate(json.loads(rep.message))
except Exception:
raise ValueError(f"{rep.message}, code: {rep.code}")
@ -178,7 +178,7 @@ class BasePluginClient:
self,
method: str,
path: str,
type: type[T],
type_: type[T],
headers: dict | None = None,
data: bytes | dict | None = None,
params: dict | None = None,
@ -189,7 +189,7 @@ class BasePluginClient:
"""
for line in self._stream_request(method, path, params, headers, data, files):
try:
rep = PluginDaemonBasicResponse[type].model_validate_json(line) # type: ignore
rep = PluginDaemonBasicResponse[type_].model_validate_json(line) # type: ignore
except (ValueError, TypeError):
# TODO modify this when line_data has code and message
try:
@ -204,7 +204,7 @@ class BasePluginClient:
if rep.code != 0:
if rep.code == -500:
try:
error = PluginDaemonError(**json.loads(rep.message))
error = PluginDaemonError.model_validate(json.loads(rep.message))
except Exception:
raise PluginDaemonInnerError(code=rep.code, message=rep.message)

View File

@ -46,7 +46,9 @@ class PluginDatasourceManager(BasePluginClient):
params={"page": 1, "page_size": 256},
transformer=transformer,
)
local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
local_file_datasource_provider = PluginDatasourceProviderEntity.model_validate(
self._get_local_file_datasource_provider()
)
for provider in response:
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider)
@ -104,7 +106,7 @@ class PluginDatasourceManager(BasePluginClient):
Fetch datasource provider for the given tenant and plugin.
"""
if provider_id == "langgenius/file/file":
return PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
return PluginDatasourceProviderEntity.model_validate(self._get_local_file_datasource_provider())
tool_provider_id = DatasourceProviderID(provider_id)

View File

@ -162,7 +162,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/llm/invoke",
type=LLMResultChunk,
type_=LLMResultChunk,
data=jsonable_encoder(
{
"user_id": user_id,
@ -208,7 +208,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",
type=PluginLLMNumTokensResponse,
type_=PluginLLMNumTokensResponse,
data=jsonable_encoder(
{
"user_id": user_id,
@ -250,7 +250,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
type=TextEmbeddingResult,
type_=TextEmbeddingResult,
data=jsonable_encoder(
{
"user_id": user_id,
@ -291,7 +291,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens",
type=PluginTextEmbeddingNumTokensResponse,
type_=PluginTextEmbeddingNumTokensResponse,
data=jsonable_encoder(
{
"user_id": user_id,
@ -334,7 +334,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/rerank/invoke",
type=RerankResult,
type_=RerankResult,
data=jsonable_encoder(
{
"user_id": user_id,
@ -378,7 +378,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/tts/invoke",
type=PluginStringResultResponse,
type_=PluginStringResultResponse,
data=jsonable_encoder(
{
"user_id": user_id,
@ -422,7 +422,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/tts/model/voices",
type=PluginVoicesResponse,
type_=PluginVoicesResponse,
data=jsonable_encoder(
{
"user_id": user_id,
@ -466,7 +466,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/speech2text/invoke",
type=PluginStringResultResponse,
type_=PluginStringResultResponse,
data=jsonable_encoder(
{
"user_id": user_id,
@ -506,7 +506,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/moderation/invoke",
type=PluginBasicBooleanResponse,
type_=PluginBasicBooleanResponse,
data=jsonable_encoder(
{
"user_id": user_id,

View File

@ -45,6 +45,12 @@ class MemoryConfig(BaseModel):
enabled: bool
size: int | None = None
mode: Literal["linear", "block"] | None = "linear"
block_id: list[str] | None = None
role_prefix: RolePrefix | None = None
window: WindowConfig
query_prompt_template: str | None = None
@property
def is_block_mode(self) -> bool:
return self.mode == "block" and bool(self.block_id)

View File

@ -134,7 +134,7 @@ class RetrievalService:
if not dataset:
return []
metadata_condition = (
MetadataCondition(**metadata_filtering_conditions) if metadata_filtering_conditions else None
MetadataCondition.model_validate(metadata_filtering_conditions) if metadata_filtering_conditions else None
)
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
dataset.tenant_id,

View File

@ -17,9 +17,6 @@ class NotionInfo(BaseModel):
tenant_id: str
model_config = ConfigDict(arbitrary_types_allowed=True)
def __init__(self, **data):
super().__init__(**data)
class WebsiteInfo(BaseModel):
"""
@ -47,6 +44,3 @@ class ExtractSetting(BaseModel):
website_info: WebsiteInfo | None = None
document_model: str | None = None
model_config = ConfigDict(arbitrary_types_allowed=True)
def __init__(self, **data):
super().__init__(**data)

View File

@ -38,11 +38,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
raise ValueError("No process rule found.")
if process_rule.get("mode") == "automatic":
automatic_rule = DatasetProcessRule.AUTOMATIC_RULES
rules = Rule(**automatic_rule)
rules = Rule.model_validate(automatic_rule)
else:
if not process_rule.get("rules"):
raise ValueError("No rules found in process rule.")
rules = Rule(**process_rule.get("rules"))
rules = Rule.model_validate(process_rule.get("rules"))
# Split the text documents into nodes.
if not rules.segmentation:
raise ValueError("No segmentation found in rules.")

View File

@ -40,7 +40,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
raise ValueError("No process rule found.")
if not process_rule.get("rules"):
raise ValueError("No rules found in process rule.")
rules = Rule(**process_rule.get("rules"))
rules = Rule.model_validate(process_rule.get("rules"))
all_documents: list[Document] = []
if rules.parent_mode == ParentMode.PARAGRAPH:
# Split the text documents into nodes.
@ -110,7 +110,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
child_documents = document.children
if child_documents:
formatted_child_documents = [
Document(**child_document.model_dump()) for child_document in child_documents
Document.model_validate(child_document.model_dump()) for child_document in child_documents
]
vector.create(formatted_child_documents)
@ -224,7 +224,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
return child_nodes
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
parent_childs = ParentChildStructureChunk(**chunks)
parent_childs = ParentChildStructureChunk.model_validate(chunks)
documents = []
for parent_child in parent_childs.parent_child_chunks:
metadata = {
@ -274,7 +274,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
vector.create(all_child_documents)
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
parent_childs = ParentChildStructureChunk(**chunks)
parent_childs = ParentChildStructureChunk.model_validate(chunks)
preview = []
for parent_child in parent_childs.parent_child_chunks:
preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})

View File

@ -47,7 +47,7 @@ class QAIndexProcessor(BaseIndexProcessor):
raise ValueError("No process rule found.")
if not process_rule.get("rules"):
raise ValueError("No rules found in process rule.")
rules = Rule(**process_rule.get("rules"))
rules = Rule.model_validate(process_rule.get("rules"))
splitter = self._get_splitter(
processing_rule_mode=process_rule.get("mode"),
max_tokens=rules.segmentation.max_tokens if rules.segmentation else 0,
@ -168,7 +168,7 @@ class QAIndexProcessor(BaseIndexProcessor):
return docs
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
qa_chunks = QAStructureChunk(**chunks)
qa_chunks = QAStructureChunk.model_validate(chunks)
documents = []
for qa_chunk in qa_chunks.qa_chunks:
metadata = {
@ -191,7 +191,7 @@ class QAIndexProcessor(BaseIndexProcessor):
raise ValueError("Indexing technique must be high quality.")
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
qa_chunks = QAStructureChunk(**chunks)
qa_chunks = QAStructureChunk.model_validate(chunks)
preview = []
for qa_chunk in qa_chunks.qa_chunks:
preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer})

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import re
from typing import Any
from core.model_manager import ModelInstance
@ -52,7 +53,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
"""Create a new TextSplitter."""
super().__init__(**kwargs)
self._fixed_separator = fixed_separator
self._separators = separators or ["\n\n", "\n", " ", ""]
self._separators = separators or ["\n\n", "\n", "", ". ", " ", ""]
def split_text(self, text: str) -> list[str]:
"""Split incoming text and return chunks."""
@ -90,16 +91,19 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
# Now that we have the separator, split the text
if separator:
if separator == " ":
splits = text.split()
splits = re.split(r" +", text)
else:
splits = text.split(separator)
splits = [item + separator if i < len(splits) else item for i, item in enumerate(splits)]
else:
splits = list(text)
splits = [s for s in splits if (s not in {"", "\n"})]
if separator == "\n":
splits = [s for s in splits if s != ""]
else:
splits = [s for s in splits if (s not in {"", "\n"})]
_good_splits = []
_good_splits_lengths = [] # cache the lengths of the splits
_separator = "" if self._keep_separator else separator
_separator = separator if self._keep_separator else ""
s_lens = self._length_function(splits)
if separator != "":
for s, s_len in zip(splits, s_lens):

View File

@ -90,7 +90,7 @@ class BuiltinToolProviderController(ToolProviderController):
tools.append(
assistant_tool_class(
provider=provider,
entity=ToolEntity(**tool),
entity=ToolEntity.model_validate(tool),
runtime=ToolRuntime(tenant_id=""),
)
)

View File

@ -1,4 +1,4 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
class I18nObject(BaseModel):
@ -11,11 +11,12 @@ class I18nObject(BaseModel):
pt_BR: str | None = Field(default=None)
ja_JP: str | None = Field(default=None)
def __init__(self, **data):
super().__init__(**data)
@model_validator(mode="after")
def _populate_missing_locales(self):
self.zh_Hans = self.zh_Hans or self.en_US
self.pt_BR = self.pt_BR or self.en_US
self.ja_JP = self.ja_JP or self.en_US
return self
def to_dict(self):
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}

View File

@ -54,7 +54,7 @@ class MCPToolProviderController(ToolProviderController):
"""
tools = []
tools_data = json.loads(db_provider.tools)
remote_mcp_tools = [RemoteMCPTool(**tool) for tool in tools_data]
remote_mcp_tools = [RemoteMCPTool.model_validate(tool) for tool in tools_data]
user = db_provider.load_user()
tools = [
ToolEntity(

View File

@ -1008,7 +1008,7 @@ class ToolManager:
config = tool_configurations.get(parameter.name, {})
if not (config and isinstance(config, dict) and config.get("value") is not None):
continue
tool_input = ToolNodeData.ToolInput(**tool_configurations.get(parameter.name, {}))
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
if tool_input.type == "variable":
variable = variable_pool.get(tool_input.value)
if variable is None:

View File

@ -126,7 +126,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
data_source_type=document.data_source_type,
segment_id=segment.id,
retriever_from=self.retriever_from,
score=document_score_list.get(segment.index_node_id, None),
score=document_score_list.get(segment.index_node_id),
doc_metadata=document.doc_metadata,
)

View File

@ -203,6 +203,48 @@ class ArrayFileSegment(ArraySegment):
return ""
class VersionedMemoryValue(BaseModel):
current_value: str = None # type: ignore
versions: Mapping[str, str] = {}
model_config = ConfigDict(frozen=True)
def add_version(
self,
new_value: str,
version_name: str | None = None
) -> "VersionedMemoryValue":
if version_name is None:
version_name = str(len(self.versions) + 1)
if version_name in self.versions:
raise ValueError(f"Version '{version_name}' already exists.")
self.current_value = new_value
return VersionedMemoryValue(
current_value=new_value,
versions={
version_name: new_value,
**self.versions,
}
)
class VersionedMemorySegment(Segment):
value_type: SegmentType = SegmentType.VERSIONED_MEMORY
value: VersionedMemoryValue = None # type: ignore
@property
def text(self) -> str:
return self.value.current_value
@property
def log(self) -> str:
return self.value.current_value
@property
def markdown(self) -> str:
return self.value.current_value
class ArrayBooleanSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
value: Sequence[bool] = None # type: ignore
@ -248,6 +290,7 @@ SegmentUnion: TypeAlias = Annotated[
| Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
| Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
| Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)]
| Annotated[VersionedMemorySegment, Tag(SegmentType.VERSIONED_MEMORY)]
),
Discriminator(get_segment_discriminator),
]

View File

@ -41,6 +41,8 @@ class SegmentType(StrEnum):
ARRAY_FILE = "array[file]"
ARRAY_BOOLEAN = "array[boolean]"
VERSIONED_MEMORY = "versioned_memory"
NONE = "none"
GROUP = "group"

View File

@ -22,6 +22,7 @@ from .segments import (
ObjectSegment,
Segment,
StringSegment,
VersionedMemorySegment,
get_segment_discriminator,
)
from .types import SegmentType
@ -106,6 +107,10 @@ class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
pass
class VersionedMemoryVariable(VersionedMemorySegment, Variable):
pass
class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable):
pass
@ -161,6 +166,7 @@ VariableUnion: TypeAlias = Annotated[
| Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)]
| Annotated[ArrayBooleanVariable, Tag(SegmentType.ARRAY_BOOLEAN)]
| Annotated[SecretVariable, Tag(SegmentType.SECRET)]
| Annotated[VersionedMemoryVariable, Tag(SegmentType.VERSIONED_MEMORY)]
),
Discriminator(get_segment_discriminator),
]

View File

@ -1,4 +1,5 @@
SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"
MEMORY_BLOCK_VARIABLE_NODE_ID = "memory_block"
RAG_PIPELINE_VARIABLE_NODE_ID = "rag"

View File

@ -55,6 +55,10 @@ class VariablePool(BaseModel):
description="RAG pipeline variables.",
default_factory=list,
)
memory_blocks: Mapping[str, str] = Field(
description="Memory blocks.",
default_factory=dict,
)
def model_post_init(self, context: Any, /):
# Create a mapping from field names to SystemVariableKey enum values
@ -75,6 +79,9 @@ class VariablePool(BaseModel):
rag_pipeline_variables_map[node_id][key] = value
for key, value in rag_pipeline_variables_map.items():
self.add((RAG_PIPELINE_VARIABLE_NODE_ID, key), value)
# Add memory blocks to the variable pool
for memory_id, memory_value in self.memory_blocks.items():
self.add([CONVERSATION_VARIABLE_NODE_ID, memory_id], memory_value)
def add(self, selector: Sequence[str], value: Any, /):
"""

View File

@ -105,10 +105,10 @@ class RedisChannel:
command_type = CommandType(command_type_value)
if command_type == CommandType.ABORT:
return AbortCommand(**data)
return AbortCommand.model_validate(data)
else:
# For other command types, use base class
return GraphEngineCommand(**data)
return GraphEngineCommand.model_validate(data)
except (ValueError, TypeError):
return None

View File

@ -16,7 +16,7 @@ class EndNode(Node):
_node_data: EndNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = EndNodeData(**data)
self._node_data = EndNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

View File

@ -342,10 +342,13 @@ class IterationNode(Node):
iterator_list_value: Sequence[object],
iter_run_map: dict[str, float],
) -> Generator[NodeEventBase, None, None]:
# Flatten the list of lists if all outputs are lists
flattened_outputs = self._flatten_outputs_if_needed(outputs)
yield IterationSucceededEvent(
start_at=started_at,
inputs=inputs,
outputs={"output": outputs},
outputs={"output": flattened_outputs},
steps=len(iterator_list_value),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
@ -357,13 +360,39 @@ class IterationNode(Node):
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": outputs},
outputs={"output": flattened_outputs},
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
},
)
)
def _flatten_outputs_if_needed(self, outputs: list[object]) -> list[object]:
"""
Flatten the outputs list if all elements are lists.
This maintains backward compatibility with version 1.8.1 behavior.
"""
if not outputs:
return outputs
# Check if all non-None outputs are lists
non_none_outputs = [output for output in outputs if output is not None]
if not non_none_outputs:
return outputs
if all(isinstance(output, list) for output in non_none_outputs):
# Flatten the list of lists
flattened: list[Any] = []
for output in outputs:
if isinstance(output, list):
flattened.extend(output)
elif output is not None:
# This shouldn't happen based on our check, but handle it gracefully
flattened.append(output)
return flattened
return outputs
def _handle_iteration_failure(
self,
started_at: datetime,
@ -373,10 +402,13 @@ class IterationNode(Node):
iter_run_map: dict[str, float],
error: IterationNodeError,
) -> Generator[NodeEventBase, None, None]:
# Flatten the list of lists if all outputs are lists (even in failure case)
flattened_outputs = self._flatten_outputs_if_needed(outputs)
yield IterationFailedEvent(
start_at=started_at,
inputs=inputs,
outputs={"output": outputs},
outputs={"output": flattened_outputs},
steps=len(iterator_list_value),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,

View File

@ -18,7 +18,7 @@ class IterationStartNode(Node):
_node_data: IterationStartNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = IterationStartNodeData(**data)
self._node_data = IterationStartNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

View File

@ -41,7 +41,7 @@ class ListOperatorNode(Node):
_node_data: ListOperatorNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = ListOperatorNodeData(**data)
self._node_data = ListOperatorNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

View File

@ -6,11 +6,15 @@ import re
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.file import FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.memory.entities import MemoryCreatedBy, MemoryScope
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities import (
@ -71,6 +75,8 @@ from core.workflow.node_events import (
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from models import UserFrom, Workflow
from models.engine import db
from . import llm_utils
from .entities import (
@ -315,6 +321,11 @@ class LLMNode(Node):
if self._file_outputs:
outputs["files"] = ArrayFileSegment(value=self._file_outputs)
try:
self._handle_chatflow_memory(result_text, variable_pool)
except Exception as e:
logger.warning("Memory orchestration failed for node %s: %s", self.node_id, str(e))
# Send final chunk event to indicate streaming is complete
yield StreamChunkEvent(
selector=[self._node_id, "text"],
@ -1184,6 +1195,79 @@ class LLMNode(Node):
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled
def _handle_chatflow_memory(self, llm_output: str, variable_pool: VariablePool):
if not self._node_data.memory or self._node_data.memory.mode != "block":
return
conversation_id_segment = variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.CONVERSATION_ID))
if not conversation_id_segment:
raise ValueError("Conversation ID not found in variable pool.")
conversation_id = conversation_id_segment.text
user_query_segment = variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
if not user_query_segment:
raise ValueError("User query not found in variable pool.")
user_query = user_query_segment.text
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
from services.chatflow_history_service import ChatflowHistoryService
ChatflowHistoryService.save_node_message(
prompt_message=(UserPromptMessage(content=user_query)),
node_id=self.node_id,
conversation_id=conversation_id,
app_id=self.app_id,
tenant_id=self.tenant_id
)
ChatflowHistoryService.save_node_message(
prompt_message=(AssistantPromptMessage(content=llm_output)),
node_id=self.node_id,
conversation_id=conversation_id,
app_id=self.app_id,
tenant_id=self.tenant_id
)
memory_config = self._node_data.memory
if not memory_config:
return
block_ids = memory_config.block_id
if not block_ids:
return
# FIXME: This is dirty workaround and may cause incorrect resolution for workflow version
with Session(db.engine) as session:
stmt = select(Workflow).where(
Workflow.tenant_id == self.tenant_id,
Workflow.app_id == self.app_id
)
workflow = session.scalars(stmt).first()
if not workflow:
raise ValueError("Workflow not found.")
memory_blocks = workflow.memory_blocks
for block_id in block_ids:
memory_block_spec = next((block for block in memory_blocks if block.id == block_id), None)
if memory_block_spec and memory_block_spec.scope == MemoryScope.NODE:
is_draft = (self.invoke_from == InvokeFrom.DEBUGGER)
from services.chatflow_memory_service import ChatflowMemoryService
ChatflowMemoryService.update_node_memory_if_needed(
tenant_id=self.tenant_id,
app_id=self.app_id,
node_id=self.id,
conversation_id=conversation_id,
memory_block_spec=memory_block_spec,
variable_pool=variable_pool,
is_draft=is_draft,
created_by=self._get_user_from_context()
)
def _get_user_from_context(self) -> MemoryCreatedBy:
if self.user_from == UserFrom.ACCOUNT:
return MemoryCreatedBy(account_id=self.user_id)
else:
return MemoryCreatedBy(end_user_id=self.user_id)
def _combine_message_content_with_role(
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole

View File

@ -18,7 +18,7 @@ class LoopEndNode(Node):
_node_data: LoopEndNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopEndNodeData(**data)
self._node_data = LoopEndNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

View File

@ -18,7 +18,7 @@ class LoopStartNode(Node):
_node_data: LoopStartNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopStartNodeData(**data)
self._node_data = LoopStartNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

View File

@ -16,7 +16,7 @@ class StartNode(Node):
_node_data: StartNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = StartNodeData(**data)
self._node_data = StartNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

View File

@ -15,7 +15,7 @@ class VariableAggregatorNode(Node):
_node_data: VariableAssignerNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerNodeData(**data)
self._node_data = VariableAssignerNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

View File

@ -14,7 +14,7 @@ def handle(sender, **kwargs):
for node_data in synced_draft_workflow.graph_dict.get("nodes", []):
if node_data.get("data", {}).get("type") == NodeType.TOOL.value:
try:
tool_entity = ToolEntity(**node_data["data"])
tool_entity = ToolEntity.model_validate(node_data["data"])
tool_runtime = ToolManager.get_tool_runtime(
provider_type=tool_entity.provider_type,
provider_id=tool_entity.provider_id,

View File

@ -61,7 +61,7 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set[str]:
for node in knowledge_retrieval_nodes:
try:
node_data = KnowledgeRetrievalNodeData(**node.get("data", {}))
node_data = KnowledgeRetrievalNodeData.model_validate(node.get("data", {}))
dataset_ids.update(dataset_id for dataset_id in node_data.dataset_ids)
except Exception:
continue

View File

@ -21,6 +21,8 @@ from core.variables.segments import (
ObjectSegment,
Segment,
StringSegment,
VersionedMemorySegment,
VersionedMemoryValue,
)
from core.variables.types import SegmentType
from core.variables.variables import (
@ -39,6 +41,7 @@ from core.variables.variables import (
SecretVariable,
StringVariable,
Variable,
VersionedMemoryVariable,
)
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
@ -69,6 +72,7 @@ SEGMENT_TO_VARIABLE_MAP = {
NoneSegment: NoneVariable,
ObjectSegment: ObjectVariable,
StringSegment: StringVariable,
VersionedMemorySegment: VersionedMemoryVariable
}
@ -193,6 +197,7 @@ _segment_factory: Mapping[SegmentType, type[Segment]] = {
SegmentType.FILE: FileSegment,
SegmentType.BOOLEAN: BooleanSegment,
SegmentType.OBJECT: ObjectSegment,
SegmentType.VERSIONED_MEMORY: VersionedMemorySegment,
# Array types
SegmentType.ARRAY_ANY: ArrayAnySegment,
SegmentType.ARRAY_STRING: ArrayStringSegment,
@ -259,6 +264,12 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
else:
raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list")
if segment_type == SegmentType.VERSIONED_MEMORY:
return VersionedMemorySegment(
value_type=segment_type,
value=VersionedMemoryValue.model_validate(value)
)
inferred_type = SegmentType.infer_segment_type(value)
# Type compatibility checking
if inferred_type is None:

View File

@ -0,0 +1,104 @@
"""add_chatflow_memory_tables
Revision ID: d00b2b40ea3e
Revises: 68519ad5cd18
Create Date: 2025-10-11 15:29:20.244675
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'd00b2b40ea3e'
down_revision = '68519ad5cd18'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('chatflow_conversations',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('node_id', sa.Text(), nullable=True),
sa.Column('original_conversation_id', models.types.StringUUID(), nullable=True),
sa.Column('conversation_metadata', sa.Text(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='chatflow_conversations_pkey')
)
with op.batch_alter_table('chatflow_conversations', schema=None) as batch_op:
batch_op.create_index('chatflow_conversations_original_conversation_id_idx', ['tenant_id', 'app_id', 'node_id', 'original_conversation_id'], unique=False)
op.create_table('chatflow_memory_variables',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=True),
sa.Column('conversation_id', models.types.StringUUID(), nullable=True),
sa.Column('node_id', sa.Text(), nullable=True),
sa.Column('memory_id', sa.Text(), nullable=False),
sa.Column('value', sa.Text(), nullable=False),
sa.Column('name', sa.Text(), nullable=False),
sa.Column('scope', sa.String(length=10), nullable=False),
sa.Column('term', sa.String(length=20), nullable=False),
sa.Column('version', sa.Integer(), nullable=False),
sa.Column('created_by_role', sa.String(length=20), nullable=False),
sa.Column('created_by', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='chatflow_memory_variables_pkey')
)
with op.batch_alter_table('chatflow_memory_variables', schema=None) as batch_op:
batch_op.create_index('chatflow_memory_variables_memory_id_idx', ['tenant_id', 'app_id', 'node_id', 'memory_id'], unique=False)
op.create_table('chatflow_messages',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('conversation_id', models.types.StringUUID(), nullable=False),
sa.Column('index', sa.Integer(), nullable=False),
sa.Column('version', sa.Integer(), nullable=False),
sa.Column('data', sa.Text(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='chatflow_messages_pkey')
)
with op.batch_alter_table('chatflow_messages', schema=None) as batch_op:
batch_op.create_index('chatflow_messages_version_idx', ['conversation_id', 'index', 'version'], unique=False)
with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
batch_op.alter_column('avatar_url',
existing_type=sa.TEXT(),
type_=sa.String(length=255),
existing_nullable=True)
with op.batch_alter_table('providers', schema=None) as batch_op:
batch_op.drop_column('credential_status')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('credential_status', sa.VARCHAR(length=20), server_default=sa.text("'active'::character varying"), autoincrement=False, nullable=True))
with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
batch_op.alter_column('avatar_url',
existing_type=sa.String(length=255),
type_=sa.TEXT(),
existing_nullable=True)
with op.batch_alter_table('chatflow_messages', schema=None) as batch_op:
batch_op.drop_index('chatflow_messages_version_idx')
op.drop_table('chatflow_messages')
with op.batch_alter_table('chatflow_memory_variables', schema=None) as batch_op:
batch_op.drop_index('chatflow_memory_variables_memory_id_idx')
op.drop_table('chatflow_memory_variables')
with op.batch_alter_table('chatflow_conversations', schema=None) as batch_op:
batch_op.drop_index('chatflow_conversations_original_conversation_id_idx')
op.drop_table('chatflow_conversations')
# ### end Alembic commands ###

View File

@ -9,6 +9,7 @@ from .account import (
TenantStatus,
)
from .api_based_extension import APIBasedExtension, APIBasedExtensionPoint
from .chatflow_memory import ChatflowConversation, ChatflowMemoryVariable, ChatflowMessage
from .comment import (
WorkflowComment,
WorkflowCommentMention,
@ -121,6 +122,9 @@ __all__ = [
"BuiltinToolProvider",
"CeleryTask",
"CeleryTaskSet",
"ChatflowConversation",
"ChatflowMemoryVariable",
"ChatflowMessage",
"Conversation",
"ConversationVariable",
"CreatorUserRole",

View File

@ -1,15 +1,16 @@
import enum
import json
from dataclasses import field
from datetime import datetime
from typing import Any, Optional
import sqlalchemy as sa
from flask_login import UserMixin # type: ignore[import-untyped]
from sqlalchemy import DateTime, String, func, select
from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor
from sqlalchemy.orm import Mapped, Session, mapped_column
from typing_extensions import deprecated
from models.base import Base
from models.base import TypeBase
from .engine import db
from .types import StringUUID
@ -83,31 +84,37 @@ class AccountStatus(enum.StrEnum):
CLOSED = "closed"
class Account(UserMixin, Base):
class Account(UserMixin, TypeBase):
__tablename__ = "accounts"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="account_pkey"), sa.Index("account_email_idx", "email"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
name: Mapped[str] = mapped_column(String(255))
email: Mapped[str] = mapped_column(String(255))
password: Mapped[str | None] = mapped_column(String(255))
password_salt: Mapped[str | None] = mapped_column(String(255))
avatar: Mapped[str | None] = mapped_column(String(255), nullable=True)
interface_language: Mapped[str | None] = mapped_column(String(255))
interface_theme: Mapped[str | None] = mapped_column(String(255), nullable=True)
timezone: Mapped[str | None] = mapped_column(String(255))
last_login_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
last_login_ip: Mapped[str | None] = mapped_column(String(255), nullable=True)
last_active_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'active'::character varying"))
initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
password: Mapped[str | None] = mapped_column(String(255), default=None)
password_salt: Mapped[str | None] = mapped_column(String(255), default=None)
avatar: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
interface_language: Mapped[str | None] = mapped_column(String(255), default=None)
interface_theme: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
timezone: Mapped[str | None] = mapped_column(String(255), default=None)
last_login_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
last_login_ip: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
last_active_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
status: Mapped[str] = mapped_column(
String(16), server_default=sa.text("'active'::character varying"), default="active"
)
initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
@reconstructor
def init_on_load(self):
self.role: TenantAccountRole | None = None
self._current_tenant: Tenant | None = None
role: TenantAccountRole | None = field(default=None, init=False)
_current_tenant: "Tenant | None" = field(default=None, init=False)
@property
def is_password_set(self):
@ -226,18 +233,24 @@ class TenantStatus(enum.StrEnum):
ARCHIVE = "archive"
class Tenant(Base):
class Tenant(TypeBase):
__tablename__ = "tenants"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="tenant_pkey"),)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
name: Mapped[str] = mapped_column(String(255))
encrypt_public_key: Mapped[str | None] = mapped_column(sa.Text)
plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'::character varying"))
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying"))
custom_config: Mapped[str | None] = mapped_column(sa.Text)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
encrypt_public_key: Mapped[str | None] = mapped_column(sa.Text, default=None)
plan: Mapped[str] = mapped_column(
String(255), server_default=sa.text("'basic'::character varying"), default="basic"
)
status: Mapped[str] = mapped_column(
String(255), server_default=sa.text("'normal'::character varying"), default="normal"
)
custom_config: Mapped[str | None] = mapped_column(sa.Text, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), init=False)
def get_accounts(self) -> list[Account]:
return list(
@ -257,7 +270,7 @@ class Tenant(Base):
self.custom_config = json.dumps(value)
class TenantAccountJoin(Base):
class TenantAccountJoin(TypeBase):
__tablename__ = "tenant_account_joins"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"),
@ -266,17 +279,21 @@ class TenantAccountJoin(Base):
sa.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID)
account_id: Mapped[str] = mapped_column(StringUUID)
current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))
role: Mapped[str] = mapped_column(String(16), server_default="normal")
invited_by: Mapped[str | None] = mapped_column(StringUUID)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"), default=False)
role: Mapped[str] = mapped_column(String(16), server_default="normal", default="normal")
invited_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
class AccountIntegrate(Base):
class AccountIntegrate(TypeBase):
__tablename__ = "account_integrates"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="account_integrate_pkey"),
@ -284,16 +301,20 @@ class AccountIntegrate(Base):
sa.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
account_id: Mapped[str] = mapped_column(StringUUID)
provider: Mapped[str] = mapped_column(String(16))
open_id: Mapped[str] = mapped_column(String(255))
encrypted_token: Mapped[str] = mapped_column(String(255))
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
class InvitationCode(Base):
class InvitationCode(TypeBase):
__tablename__ = "invitation_codes"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="invitation_code_pkey"),
@ -301,18 +322,22 @@ class InvitationCode(Base):
sa.Index("invitation_codes_code_idx", "code", "status"),
)
id: Mapped[int] = mapped_column(sa.Integer)
id: Mapped[int] = mapped_column(sa.Integer, init=False)
batch: Mapped[str] = mapped_column(String(255))
code: Mapped[str] = mapped_column(String(32))
status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'unused'::character varying"))
used_at: Mapped[datetime | None] = mapped_column(DateTime)
used_by_tenant_id: Mapped[str | None] = mapped_column(StringUUID)
used_by_account_id: Mapped[str | None] = mapped_column(StringUUID)
deprecated_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
status: Mapped[str] = mapped_column(
String(16), server_default=sa.text("'unused'::character varying"), default="unused"
)
used_at: Mapped[datetime | None] = mapped_column(DateTime, default=None)
used_by_tenant_id: Mapped[str | None] = mapped_column(StringUUID, default=None)
used_by_account_id: Mapped[str | None] = mapped_column(StringUUID, default=None)
deprecated_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)"), nullable=False, init=False
)
class TenantPluginPermission(Base):
class TenantPluginPermission(TypeBase):
class InstallPermission(enum.StrEnum):
EVERYONE = "everyone"
ADMINS = "admins"
@ -329,13 +354,17 @@ class TenantPluginPermission(Base):
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
install_permission: Mapped[InstallPermission] = mapped_column(String(16), nullable=False, server_default="everyone")
debug_permission: Mapped[DebugPermission] = mapped_column(String(16), nullable=False, server_default="noone")
install_permission: Mapped[InstallPermission] = mapped_column(
String(16), nullable=False, server_default="everyone", default=InstallPermission.EVERYONE
)
debug_permission: Mapped[DebugPermission] = mapped_column(
String(16), nullable=False, server_default="noone", default=DebugPermission.NOBODY
)
class TenantPluginAutoUpgradeStrategy(Base):
class TenantPluginAutoUpgradeStrategy(TypeBase):
class StrategySetting(enum.StrEnum):
DISABLED = "disabled"
FIX_ONLY = "fix_only"
@ -352,12 +381,20 @@ class TenantPluginAutoUpgradeStrategy(Base):
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
strategy_setting: Mapped[StrategySetting] = mapped_column(String(16), nullable=False, server_default="fix_only")
upgrade_time_of_day: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) # seconds of the day
upgrade_mode: Mapped[UpgradeMode] = mapped_column(String(16), nullable=False, server_default="exclude")
exclude_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name)
include_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
strategy_setting: Mapped[StrategySetting] = mapped_column(
String(16), nullable=False, server_default="fix_only", default=StrategySetting.FIX_ONLY
)
upgrade_mode: Mapped[UpgradeMode] = mapped_column(
String(16), nullable=False, server_default="exclude", default=UpgradeMode.EXCLUDE
)
exclude_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False, default_factory=list)
include_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False, default_factory=list)
upgrade_time_of_day: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)

View File

@ -0,0 +1,76 @@
from datetime import datetime
import sqlalchemy as sa
from sqlalchemy import DateTime, func
from sqlalchemy.orm import Mapped, mapped_column
from .base import Base
from .types import StringUUID
class ChatflowMemoryVariable(Base):
__tablename__ = "chatflow_memory_variables"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="chatflow_memory_variables_pkey"),
sa.Index("chatflow_memory_variables_memory_id_idx", "tenant_id", "app_id", "node_id", "memory_id"),
)
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
conversation_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
node_id: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
memory_id: Mapped[str] = mapped_column(sa.Text, nullable=False)
value: Mapped[str] = mapped_column(sa.Text, nullable=False)
name: Mapped[str] = mapped_column(sa.Text, nullable=False)
scope: Mapped[str] = mapped_column(sa.String(10), nullable=False) # 'app' or 'node'
term: Mapped[str] = mapped_column(sa.String(20), nullable=False) # 'session' or 'persistent'
version: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=1)
created_by_role: Mapped[str] = mapped_column(sa.String(20)) # 'end_user' or 'account`
created_by: Mapped[str] = mapped_column(StringUUID)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
class ChatflowConversation(Base):
__tablename__ = "chatflow_conversations"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="chatflow_conversations_pkey"),
sa.Index(
"chatflow_conversations_original_conversation_id_idx",
"tenant_id", "app_id", "node_id", "original_conversation_id"
),
)
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
node_id: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
original_conversation_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
conversation_metadata: Mapped[str] = mapped_column(sa.Text, nullable=False) # JSON
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
class ChatflowMessage(Base):
__tablename__ = "chatflow_messages"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="chatflow_messages_pkey"),
sa.Index("chatflow_messages_version_idx", "conversation_id", "index", "version"),
)
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
index: Mapped[int] = mapped_column(sa.Integer, nullable=False) # This index starts from 0
version: Mapped[int] = mapped_column(sa.Integer, nullable=False)
data: Mapped[str] = mapped_column(sa.Text, nullable=False) # Serialized PromptMessage JSON
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)

View File

@ -754,7 +754,7 @@ class DocumentSegment(Base):
if process_rule and process_rule.mode == "hierarchical":
rules_dict = process_rule.rules_dict
if rules_dict:
rules = Rule(**rules_dict)
rules = Rule.model_validate(rules_dict)
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
child_chunks = (
db.session.query(ChildChunk)
@ -772,7 +772,7 @@ class DocumentSegment(Base):
if process_rule and process_rule.mode == "hierarchical":
rules_dict = process_rule.rules_dict
if rules_dict:
rules = Rule(**rules_dict)
rules = Rule.model_validate(rules_dict)
if rules.parent_mode:
child_chunks = (
db.session.query(ChildChunk)

View File

@ -23,6 +23,7 @@ class DraftVariableType(StrEnum):
NODE = "node"
SYS = "sys"
CONVERSATION = "conversation"
MEMORY_BLOCK = "memory_block"
class MessageStatus(StrEnum):

View File

@ -152,7 +152,7 @@ class ApiToolProvider(Base):
def tools(self) -> list["ApiToolBundle"]:
from core.tools.entities.tool_bundle import ApiToolBundle
return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)]
return [ApiToolBundle.model_validate(tool) for tool in json.loads(self.tools_str)]
@property
def credentials(self) -> dict[str, Any]:
@ -242,7 +242,10 @@ class WorkflowToolProvider(Base):
def parameter_configurations(self) -> list["WorkflowToolParameterConfiguration"]:
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
return [WorkflowToolParameterConfiguration(**config) for config in json.loads(self.parameter_configuration)]
return [
WorkflowToolParameterConfiguration.model_validate(config)
for config in json.loads(self.parameter_configuration)
]
@property
def app(self) -> App | None:
@ -312,7 +315,7 @@ class MCPToolProvider(Base):
def mcp_tools(self) -> list["MCPTool"]:
from core.mcp.types import Tool as MCPTool
return [MCPTool(**tool) for tool in json.loads(self.tools)]
return [MCPTool.model_validate(tool) for tool in json.loads(self.tools)]
@property
def provider_icon(self) -> Mapping[str, str] | str:
@ -552,4 +555,4 @@ class DeprecatedPublishedAppTool(Base):
def description_i18n(self) -> "I18nObject":
from core.tools.entities.common_entities import I18nObject
return I18nObject(**json.loads(self.description))
return I18nObject.model_validate(json.loads(self.description))

View File

@ -1,5 +1,6 @@
import json
import logging
import uuid
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import StrEnum
@ -11,9 +12,15 @@ from sqlalchemy import DateTime, Select, exists, orm, select
from core.file.constants import maybe_file_object
from core.file.models import File
from core.memory.entities import MemoryBlockSpec
from core.variables import utils as variable_utils
from core.variables.segments import VersionedMemoryValue
from core.variables.variables import FloatVariable, IntegerVariable, StringVariable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
MEMORY_BLOCK_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
)
from core.workflow.enums import NodeType
from extensions.ext_storage import Storage
from factories.variable_factory import TypeMismatchError, build_segment_with_type
@ -360,7 +367,9 @@ class Workflow(Base):
@property
def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
# _environment_variables is guaranteed to be non-None due to server_default="{}"
# TODO: find some way to init `self._environment_variables` when instance created.
if self._environment_variables is None:
self._environment_variables = "{}"
# Use workflow.tenant_id to avoid relying on request user in background threads
tenant_id = self.tenant_id
@ -438,13 +447,14 @@ class Workflow(Base):
"features": self.features_dict,
"environment_variables": [var.model_dump(mode="json") for var in environment_variables],
"conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables],
"rag_pipeline_variables": self.rag_pipeline_variables,
}
return result
@property
def conversation_variables(self) -> Sequence[Variable]:
# _conversation_variables is guaranteed to be non-None due to server_default="{}"
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._conversation_variables is None:
self._conversation_variables = "{}"
variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()]
@ -474,6 +484,17 @@ class Workflow(Base):
ensure_ascii=False,
)
@property
def memory_blocks(self) -> Sequence[MemoryBlockSpec]:
"""Memory blocks configuration from graph"""
if not self.graph_dict:
return []
memory_blocks_config = self.graph_dict.get('memory_blocks', [])
results = [MemoryBlockSpec.model_validate(config) for config in memory_blocks_config]
return results
@staticmethod
def version_from_datetime(d: datetime) -> str:
return str(d)
@ -1485,6 +1506,31 @@ class WorkflowDraftVariable(Base):
variable.editable = editable
return variable
@staticmethod
def new_memory_block_variable(
*,
app_id: str,
node_id: str | None = None,
memory_id: str,
name: str,
value: VersionedMemoryValue,
description: str = "",
) -> "WorkflowDraftVariable":
"""Create a new memory block draft variable."""
return WorkflowDraftVariable(
id=str(uuid.uuid4()),
app_id=app_id,
node_id=MEMORY_BLOCK_VARIABLE_NODE_ID,
name=name,
value=value.model_dump_json(),
description=description,
selector=[MEMORY_BLOCK_VARIABLE_NODE_ID, memory_id] if node_id is None else
[MEMORY_BLOCK_VARIABLE_NODE_ID, memory_id, node_id],
value_type=SegmentType.VERSIONED_MEMORY,
visible=True,
editable=True,
)
@property
def edited(self):
return self.last_edited_at is not None

View File

@ -113,7 +113,7 @@ dev = [
"lxml-stubs~=0.5.1",
"ty~=0.0.1a19",
"basedpyright~=1.31.0",
"ruff~=0.12.3",
"ruff~=0.14.0",
"pytest~=8.3.2",
"pytest-benchmark~=4.0.0",
"pytest-cov~=4.1.0",

View File

@ -246,10 +246,8 @@ class AccountService:
)
)
account = Account()
account.email = email
account.name = name
password_to_set = None
salt_to_set = None
if password:
valid_password(password)
@ -261,14 +259,18 @@ class AccountService:
password_hashed = hash_password(password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
password_to_set = base64_password_hashed
salt_to_set = base64_salt
account.interface_language = interface_language
account.interface_theme = interface_theme
# Set timezone based on language
account.timezone = language_timezone_mapping.get(interface_language, "UTC")
account = Account(
name=name,
email=email,
password=password_to_set,
password_salt=salt_to_set,
interface_language=interface_language,
interface_theme=interface_theme,
timezone=language_timezone_mapping.get(interface_language, "UTC"),
)
db.session.add(account)
db.session.commit()

View File

@ -659,31 +659,31 @@ class AppDslService:
typ = node.get("data", {}).get("type")
match typ:
case NodeType.TOOL.value:
tool_entity = ToolNodeData(**node["data"])
tool_entity = ToolNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id),
)
case NodeType.LLM.value:
llm_entity = LLMNodeData(**node["data"])
llm_entity = LLMNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider),
)
case NodeType.QUESTION_CLASSIFIER.value:
question_classifier_entity = QuestionClassifierNodeData(**node["data"])
question_classifier_entity = QuestionClassifierNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
question_classifier_entity.model.provider
),
)
case NodeType.PARAMETER_EXTRACTOR.value:
parameter_extractor_entity = ParameterExtractorNodeData(**node["data"])
parameter_extractor_entity = ParameterExtractorNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
parameter_extractor_entity.model.provider
),
)
case NodeType.KNOWLEDGE_RETRIEVAL.value:
knowledge_retrieval_entity = KnowledgeRetrievalNodeData(**node["data"])
knowledge_retrieval_entity = KnowledgeRetrievalNodeData.model_validate(node["data"])
if knowledge_retrieval_entity.retrieval_mode == "multiple":
if knowledge_retrieval_entity.multiple_retrieval_config:
if (
@ -773,7 +773,7 @@ class AppDslService:
"""
Returns the leaked dependencies in current workspace
"""
dependencies = [PluginDependency(**dep) for dep in dsl_dependencies]
dependencies = [PluginDependency.model_validate(dep) for dep in dsl_dependencies]
if not dependencies:
return []

View File

@ -0,0 +1,237 @@
import json
from collections.abc import MutableMapping, Sequence
from typing import Literal, Optional, overload
from sqlalchemy import Row, Select, and_, func, select
from sqlalchemy.orm import Session
from core.memory.entities import ChatflowConversationMetadata
from core.model_runtime.entities.message_entities import (
PromptMessage,
)
from extensions.ext_database import db
from models.chatflow_memory import ChatflowConversation, ChatflowMessage
class ChatflowHistoryService:
@staticmethod
def get_visible_chat_history(
conversation_id: str,
app_id: str,
tenant_id: str,
node_id: Optional[str] = None,
max_visible_count: Optional[int] = None
) -> Sequence[PromptMessage]:
with Session(db.engine) as session:
chatflow_conv = ChatflowHistoryService._get_or_create_chatflow_conversation(
session, conversation_id, app_id, tenant_id, node_id, create_if_missing=False
)
if not chatflow_conv:
return []
metadata = ChatflowConversationMetadata.model_validate_json(chatflow_conv.conversation_metadata)
visible_count: int = max_visible_count or metadata.visible_count
stmt = select(ChatflowMessage).where(
ChatflowMessage.conversation_id == chatflow_conv.id
).order_by(ChatflowMessage.index.asc(), ChatflowMessage.version.desc())
raw_messages: Sequence[Row[tuple[ChatflowMessage]]] = session.execute(stmt).all()
sorted_messages = ChatflowHistoryService._filter_latest_messages(
[it[0] for it in raw_messages]
)
visible_count = min(visible_count, len(sorted_messages))
visible_messages = sorted_messages[-visible_count:]
return [PromptMessage.model_validate_json(it.data) for it in visible_messages]
@staticmethod
def save_message(
prompt_message: PromptMessage,
conversation_id: str,
app_id: str,
tenant_id: str,
node_id: Optional[str] = None
) -> None:
with Session(db.engine) as session:
chatflow_conv = ChatflowHistoryService._get_or_create_chatflow_conversation(
session, conversation_id, app_id, tenant_id, node_id, create_if_missing=True
)
# Get next index
max_index = session.execute(
select(func.max(ChatflowMessage.index)).where(
ChatflowMessage.conversation_id == chatflow_conv.id
)
).scalar() or -1
next_index = max_index + 1
# Save new message to append-only table
new_message = ChatflowMessage(
conversation_id=chatflow_conv.id,
index=next_index,
version=1,
data=json.dumps(prompt_message)
)
session.add(new_message)
session.commit()
# 添加每次保存消息后简单增长visible_count
current_metadata = ChatflowConversationMetadata.model_validate_json(chatflow_conv.conversation_metadata)
new_visible_count = current_metadata.visible_count + 1
new_metadata = ChatflowConversationMetadata(visible_count=new_visible_count)
chatflow_conv.conversation_metadata = new_metadata.model_dump_json()
@staticmethod
def save_app_message(
prompt_message: PromptMessage,
conversation_id: str,
app_id: str,
tenant_id: str
) -> None:
"""Save PromptMessage to app-level chatflow conversation."""
ChatflowHistoryService.save_message(
prompt_message=prompt_message,
conversation_id=conversation_id,
app_id=app_id,
tenant_id=tenant_id,
node_id=None
)
@staticmethod
def save_node_message(
prompt_message: PromptMessage,
node_id: str,
conversation_id: str,
app_id: str,
tenant_id: str
) -> None:
ChatflowHistoryService.save_message(
prompt_message=prompt_message,
conversation_id=conversation_id,
app_id=app_id,
tenant_id=tenant_id,
node_id=node_id
)
@staticmethod
def update_visible_count(
conversation_id: str,
node_id: Optional[str],
new_visible_count: int,
app_id: str,
tenant_id: str
) -> None:
with Session(db.engine) as session:
chatflow_conv = ChatflowHistoryService._get_or_create_chatflow_conversation(
session, conversation_id, app_id, tenant_id, node_id, create_if_missing=True
)
# Only update visible_count in metadata, do not delete any data
new_metadata = ChatflowConversationMetadata(visible_count=new_visible_count)
chatflow_conv.conversation_metadata = new_metadata.model_dump_json()
session.commit()
@staticmethod
def get_conversation_metadata(
tenant_id: str,
app_id: str,
conversation_id: str,
node_id: Optional[str]
) -> ChatflowConversationMetadata:
with Session(db.engine) as session:
chatflow_conv = ChatflowHistoryService._get_or_create_chatflow_conversation(
session, conversation_id, app_id, tenant_id, node_id, create_if_missing=False
)
if not chatflow_conv:
raise ValueError(f"Conversation not found: {conversation_id}")
return ChatflowConversationMetadata.model_validate_json(chatflow_conv.conversation_metadata)
@staticmethod
def _filter_latest_messages(raw_messages: Sequence[ChatflowMessage]) -> Sequence[ChatflowMessage]:
index_to_message: MutableMapping[int, ChatflowMessage] = {}
for msg in raw_messages:
index = msg.index
if index not in index_to_message or msg.version > index_to_message[index].version:
index_to_message[index] = msg
sorted_messages = sorted(index_to_message.values(), key=lambda m: m.index)
return sorted_messages
@overload
@staticmethod
def _get_or_create_chatflow_conversation(
session: Session,
conversation_id: str,
app_id: str,
tenant_id: str,
node_id: Optional[str] = None,
create_if_missing: Literal[True] = True
) -> ChatflowConversation: ...
@overload
@staticmethod
def _get_or_create_chatflow_conversation(
session: Session,
conversation_id: str,
app_id: str,
tenant_id: str,
node_id: Optional[str] = None,
create_if_missing: Literal[False] = False
) -> Optional[ChatflowConversation]: ...
@overload
@staticmethod
def _get_or_create_chatflow_conversation(
session: Session,
conversation_id: str,
app_id: str,
tenant_id: str,
node_id: Optional[str] = None,
create_if_missing: bool = False
) -> Optional[ChatflowConversation]: ...
@staticmethod
def _get_or_create_chatflow_conversation(
session: Session,
conversation_id: str,
app_id: str,
tenant_id: str,
node_id: Optional[str] = None,
create_if_missing: bool = False
) -> Optional[ChatflowConversation]:
"""Get existing chatflow conversation or optionally create new one"""
stmt: Select[tuple[ChatflowConversation]] = select(ChatflowConversation).where(
and_(
ChatflowConversation.original_conversation_id == conversation_id,
ChatflowConversation.tenant_id == tenant_id,
ChatflowConversation.app_id == app_id
)
)
if node_id:
stmt = stmt.where(ChatflowConversation.node_id == node_id)
else:
stmt = stmt.where(ChatflowConversation.node_id.is_(None))
chatflow_conv: Row[tuple[ChatflowConversation]] | None = session.execute(stmt).first()
if chatflow_conv:
result: ChatflowConversation = chatflow_conv[0] # Extract the ChatflowConversation object
return result
else:
if create_if_missing:
# Create a new chatflow conversation
default_metadata = ChatflowConversationMetadata(visible_count=0)
new_chatflow_conv = ChatflowConversation(
tenant_id=tenant_id,
app_id=app_id,
node_id=node_id,
original_conversation_id=conversation_id,
conversation_metadata=default_metadata.model_dump_json(),
)
session.add(new_chatflow_conv)
session.flush() # Obtain ID
return new_chatflow_conv
return None

View File

@ -0,0 +1,679 @@
import logging
import threading
import time
from collections.abc import Sequence
from typing import Optional
from sqlalchemy import and_, delete, select
from sqlalchemy.orm import Session
from core.llm_generator.llm_generator import LLMGenerator
from core.memory.entities import (
MemoryBlock,
MemoryBlockSpec,
MemoryBlockWithConversation,
MemoryCreatedBy,
MemoryScheduleMode,
MemoryScope,
MemoryTerm,
MemoryValueData,
)
from core.memory.errors import MemorySyncTimeoutError
from core.model_runtime.entities.message_entities import PromptMessage
from core.variables.segments import VersionedMemoryValue
from core.workflow.constants import MEMORY_BLOCK_VARIABLE_NODE_ID
from core.workflow.entities.variable_pool import VariablePool
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models import App, CreatorUserRole
from models.chatflow_memory import ChatflowMemoryVariable
from models.workflow import Workflow, WorkflowDraftVariable
from services.chatflow_history_service import ChatflowHistoryService
from services.workflow_draft_variable_service import WorkflowDraftVariableService
from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
class ChatflowMemoryService:
@staticmethod
def get_persistent_memories(
app: App,
created_by: MemoryCreatedBy,
version: int | None = None
) -> Sequence[MemoryBlock]:
if created_by.account_id:
created_by_role = CreatorUserRole.ACCOUNT
created_by_id = created_by.account_id
else:
created_by_role = CreatorUserRole.END_USER
created_by_id = created_by.id
if version is None:
# If version not specified, get the latest version
stmt = select(ChatflowMemoryVariable).distinct(ChatflowMemoryVariable.memory_id).where(
and_(
ChatflowMemoryVariable.tenant_id == app.tenant_id,
ChatflowMemoryVariable.app_id == app.id,
ChatflowMemoryVariable.conversation_id == None,
ChatflowMemoryVariable.created_by_role == created_by_role,
ChatflowMemoryVariable.created_by == created_by_id,
)
).order_by(ChatflowMemoryVariable.version.desc())
else:
stmt = select(ChatflowMemoryVariable).where(
and_(
ChatflowMemoryVariable.tenant_id == app.tenant_id,
ChatflowMemoryVariable.app_id == app.id,
ChatflowMemoryVariable.conversation_id == None,
ChatflowMemoryVariable.created_by_role == created_by_role,
ChatflowMemoryVariable.created_by == created_by_id,
ChatflowMemoryVariable.version == version
)
)
with Session(db.engine) as session:
db_results = session.execute(stmt).all()
return ChatflowMemoryService._convert_to_memory_blocks(app, created_by, [result[0] for result in db_results])
@staticmethod
def get_session_memories(
app: App,
created_by: MemoryCreatedBy,
conversation_id: str,
version: int | None = None
) -> Sequence[MemoryBlock]:
if version is None:
# If version not specified, get the latest version
stmt = select(ChatflowMemoryVariable).distinct(ChatflowMemoryVariable.memory_id).where(
and_(
ChatflowMemoryVariable.tenant_id == app.tenant_id,
ChatflowMemoryVariable.app_id == app.id,
ChatflowMemoryVariable.conversation_id == conversation_id
)
).order_by(ChatflowMemoryVariable.version.desc())
else:
stmt = select(ChatflowMemoryVariable).where(
and_(
ChatflowMemoryVariable.tenant_id == app.tenant_id,
ChatflowMemoryVariable.app_id == app.id,
ChatflowMemoryVariable.conversation_id == conversation_id,
ChatflowMemoryVariable.version == version
)
)
with Session(db.engine) as session:
db_results = session.execute(stmt).all()
return ChatflowMemoryService._convert_to_memory_blocks(app, created_by, [result[0] for result in db_results])
@staticmethod
def save_memory(memory: MemoryBlock, variable_pool: VariablePool, is_draft: bool) -> None:
key = f"{memory.node_id}.{memory.spec.id}" if memory.node_id else memory.spec.id
variable_pool.add([MEMORY_BLOCK_VARIABLE_NODE_ID, key], memory.value)
if memory.created_by.account_id:
created_by_role = CreatorUserRole.ACCOUNT
created_by = memory.created_by.account_id
else:
created_by_role = CreatorUserRole.END_USER
created_by = memory.created_by.id
with Session(db.engine) as session:
session.add(
ChatflowMemoryVariable(
memory_id=memory.spec.id,
tenant_id=memory.tenant_id,
app_id=memory.app_id,
node_id=memory.node_id,
conversation_id=memory.conversation_id,
name=memory.spec.name,
value=MemoryValueData(
value=memory.value,
edited_by_user=memory.edited_by_user
).model_dump_json(),
term=memory.spec.term,
scope=memory.spec.scope,
version=memory.version, # Use version from MemoryBlock directly
created_by_role=created_by_role,
created_by=created_by,
)
)
session.commit()
if is_draft:
with Session(bind=db.engine) as session:
draft_var_service = WorkflowDraftVariableService(session)
existing_vars = draft_var_service.get_draft_variables_by_selectors(
app_id=memory.app_id,
selectors=[['memory_block', memory.spec.id]]
)
if existing_vars:
draft_var = existing_vars[0]
draft_var.value = VersionedMemoryValue.model_validate_json(draft_var.value)\
.add_version(memory.value)\
.model_dump_json()
else:
draft_var = WorkflowDraftVariable.new_memory_block_variable(
app_id=memory.app_id,
memory_id=memory.spec.id,
name=memory.spec.name,
value=VersionedMemoryValue().add_version(memory.value),
description=memory.spec.description
)
session.add(draft_var)
session.commit()
@staticmethod
def get_memories_by_specs(
memory_block_specs: Sequence[MemoryBlockSpec],
tenant_id: str, app_id: str,
created_by: MemoryCreatedBy,
conversation_id: Optional[str],
node_id: Optional[str],
is_draft: bool
) -> Sequence[MemoryBlock]:
return [ChatflowMemoryService.get_memory_by_spec(
spec, tenant_id, app_id, created_by, conversation_id, node_id, is_draft
) for spec in memory_block_specs]
@staticmethod
def get_memory_by_spec(
spec: MemoryBlockSpec,
tenant_id: str,
app_id: str,
created_by: MemoryCreatedBy,
conversation_id: Optional[str],
node_id: Optional[str],
is_draft: bool
) -> MemoryBlock:
with Session(db.engine) as session:
if is_draft:
draft_var_service = WorkflowDraftVariableService(session)
selector = [MEMORY_BLOCK_VARIABLE_NODE_ID, f"{spec.id}.{node_id}"] \
if node_id else [MEMORY_BLOCK_VARIABLE_NODE_ID, spec.id]
draft_vars = draft_var_service.get_draft_variables_by_selectors(
app_id=app_id,
selectors=[selector]
)
if draft_vars:
draft_var = draft_vars[0]
return MemoryBlock(
value=draft_var.get_value().text,
tenant_id=tenant_id,
app_id=app_id,
conversation_id=conversation_id,
node_id=node_id,
spec=spec,
created_by=created_by,
version=1,
)
stmt = select(ChatflowMemoryVariable).where(
and_(
ChatflowMemoryVariable.memory_id == spec.id,
ChatflowMemoryVariable.tenant_id == tenant_id,
ChatflowMemoryVariable.app_id == app_id,
ChatflowMemoryVariable.node_id ==
(node_id if spec.scope == MemoryScope.NODE else None),
ChatflowMemoryVariable.conversation_id ==
(conversation_id if spec.term == MemoryTerm.SESSION else None),
)
).order_by(ChatflowMemoryVariable.version.desc()).limit(1)
result = session.execute(stmt).scalar()
if result:
memory_value_data = MemoryValueData.model_validate_json(result.value)
return MemoryBlock(
value=memory_value_data.value,
tenant_id=tenant_id,
app_id=app_id,
conversation_id=conversation_id,
node_id=node_id,
spec=spec,
edited_by_user=memory_value_data.edited_by_user,
created_by=created_by,
version=result.version,
)
return MemoryBlock(
tenant_id=tenant_id,
value=spec.template,
app_id=app_id,
conversation_id=conversation_id,
node_id=node_id,
spec=spec,
created_by=created_by,
version=1,
)
@staticmethod
def update_app_memory_if_needed(
workflow: Workflow,
conversation_id: str,
variable_pool: VariablePool,
created_by: MemoryCreatedBy,
is_draft: bool
):
visible_messages = ChatflowHistoryService.get_visible_chat_history(
conversation_id=conversation_id,
app_id=workflow.app_id,
tenant_id=workflow.tenant_id,
node_id=None,
)
sync_blocks: list[MemoryBlock] = []
async_blocks: list[MemoryBlock] = []
for memory_spec in workflow.memory_blocks:
if memory_spec.scope == MemoryScope.APP:
memory = ChatflowMemoryService.get_memory_by_spec(
spec=memory_spec,
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
conversation_id=conversation_id,
node_id=None,
is_draft=is_draft,
created_by=created_by,
)
if ChatflowMemoryService._should_update_memory(memory, visible_messages):
if memory.spec.schedule_mode == MemoryScheduleMode.SYNC:
sync_blocks.append(memory)
else:
async_blocks.append(memory)
if not sync_blocks and not async_blocks:
return
# async mode: submit individual async tasks directly
for memory_block in async_blocks:
ChatflowMemoryService._app_submit_async_memory_update(
block=memory_block,
is_draft=is_draft,
variable_pool=variable_pool,
visible_messages=visible_messages,
conversation_id=conversation_id,
)
# sync mode: submit a batch update task
if sync_blocks:
ChatflowMemoryService._app_submit_sync_memory_batch_update(
sync_blocks=sync_blocks,
is_draft=is_draft,
conversation_id=conversation_id,
app_id=workflow.app_id,
visible_messages=visible_messages,
variable_pool=variable_pool
)
@staticmethod
def update_node_memory_if_needed(
tenant_id: str,
app_id: str,
node_id: str,
created_by: MemoryCreatedBy,
conversation_id: str,
memory_block_spec: MemoryBlockSpec,
variable_pool: VariablePool,
is_draft: bool
) -> bool:
visible_messages = ChatflowHistoryService.get_visible_chat_history(
conversation_id=conversation_id,
app_id=app_id,
tenant_id=tenant_id,
node_id=node_id,
)
memory_block = ChatflowMemoryService.get_memory_by_spec(
spec=memory_block_spec,
tenant_id=tenant_id,
app_id=app_id,
conversation_id=conversation_id,
node_id=node_id,
is_draft=is_draft,
created_by=created_by,
)
if not ChatflowMemoryService._should_update_memory(
memory_block=memory_block,
visible_history=visible_messages
):
return False
if memory_block_spec.schedule_mode == MemoryScheduleMode.SYNC:
# Node-level sync: blocking execution
ChatflowMemoryService._update_node_memory_sync(
visible_messages=visible_messages,
memory_block=memory_block,
variable_pool=variable_pool,
is_draft=is_draft,
conversation_id=conversation_id
)
else:
# Node-level async: execute asynchronously
ChatflowMemoryService._update_node_memory_async(
memory_block=memory_block,
visible_messages=visible_messages,
variable_pool=variable_pool,
is_draft=is_draft,
conversation_id=conversation_id
)
return True
@staticmethod
def wait_for_sync_memory_completion(workflow: Workflow, conversation_id: str):
"""Wait for sync memory update to complete, maximum 50 seconds"""
memory_blocks = workflow.memory_blocks
sync_memory_blocks = [
block for block in memory_blocks
if block.scope == MemoryScope.APP and block.schedule_mode == MemoryScheduleMode.SYNC
]
if not sync_memory_blocks:
return
lock_key = _get_memory_sync_lock_key(workflow.app_id, conversation_id)
# Retry up to 10 times, wait 5 seconds each time, total 50 seconds
max_retries = 10
retry_interval = 5
for i in range(max_retries):
if not redis_client.exists(lock_key):
# Lock doesn't exist, can continue
return
if i < max_retries - 1:
# Still have retry attempts, wait
time.sleep(retry_interval)
else:
# Maximum retry attempts reached, raise exception
raise MemorySyncTimeoutError(
app_id=workflow.app_id,
conversation_id=conversation_id
)
@staticmethod
def _convert_to_memory_blocks(
app: App,
created_by: MemoryCreatedBy,
raw_results: Sequence[ChatflowMemoryVariable]
) -> Sequence[MemoryBlock]:
workflow = WorkflowService().get_published_workflow(app)
if not workflow:
return []
results = []
for chatflow_memory_variable in raw_results:
spec = next(
(spec for spec in workflow.memory_blocks if spec.id == chatflow_memory_variable.memory_id),
None
)
if spec and chatflow_memory_variable.app_id:
memory_value_data = MemoryValueData.model_validate_json(chatflow_memory_variable.value)
results.append(
MemoryBlock(
spec=spec,
tenant_id=chatflow_memory_variable.tenant_id,
value=memory_value_data.value,
app_id=chatflow_memory_variable.app_id,
conversation_id=chatflow_memory_variable.conversation_id,
node_id=chatflow_memory_variable.node_id,
edited_by_user=memory_value_data.edited_by_user,
created_by=created_by,
version=chatflow_memory_variable.version,
)
)
return results
@staticmethod
def _should_update_memory(
memory_block: MemoryBlock,
visible_history: Sequence[PromptMessage]
) -> bool:
return len(visible_history) >= memory_block.spec.update_turns
@staticmethod
def _app_submit_async_memory_update(
block: MemoryBlock,
visible_messages: Sequence[PromptMessage],
variable_pool: VariablePool,
conversation_id: str,
is_draft: bool
):
thread = threading.Thread(
target=ChatflowMemoryService._perform_memory_update,
kwargs={
'memory_block': block,
'visible_messages': visible_messages,
'variable_pool': variable_pool,
'is_draft': is_draft,
'conversation_id': conversation_id
},
)
thread.start()
@staticmethod
def _app_submit_sync_memory_batch_update(
sync_blocks: Sequence[MemoryBlock],
app_id: str,
conversation_id: str,
visible_messages: Sequence[PromptMessage],
variable_pool: VariablePool,
is_draft: bool
):
"""Submit sync memory batch update task"""
thread = threading.Thread(
target=ChatflowMemoryService._batch_update_sync_memory,
kwargs={
'sync_blocks': sync_blocks,
'app_id': app_id,
'conversation_id': conversation_id,
'visible_messages': visible_messages,
'variable_pool': variable_pool,
'is_draft': is_draft
},
)
thread.start()
@staticmethod
def _batch_update_sync_memory(
sync_blocks: Sequence[MemoryBlock],
app_id: str,
conversation_id: str,
visible_messages: Sequence[PromptMessage],
variable_pool: VariablePool,
is_draft: bool
):
try:
lock_key = _get_memory_sync_lock_key(app_id, conversation_id)
with redis_client.lock(lock_key, timeout=120):
threads = []
for block in sync_blocks:
thread = threading.Thread(
target=ChatflowMemoryService._perform_memory_update,
kwargs={
'memory_block': block,
'visible_messages': visible_messages,
'variable_pool': variable_pool,
'is_draft': is_draft,
'conversation_id': conversation_id,
},
)
threads.append(thread)
for thread in threads:
thread.start()
for thread in threads:
thread.join()
except Exception as e:
logger.exception("Error batch updating memory", exc_info=e)
@staticmethod
def _update_node_memory_sync(
memory_block: MemoryBlock,
visible_messages: Sequence[PromptMessage],
variable_pool: VariablePool,
conversation_id: str,
is_draft: bool
):
ChatflowMemoryService._perform_memory_update(
memory_block=memory_block,
visible_messages=visible_messages,
variable_pool=variable_pool,
is_draft=is_draft,
conversation_id=conversation_id
)
@staticmethod
def _update_node_memory_async(
memory_block: MemoryBlock,
visible_messages: Sequence[PromptMessage],
variable_pool: VariablePool,
conversation_id: str,
is_draft: bool = False
):
thread = threading.Thread(
target=ChatflowMemoryService._perform_memory_update,
kwargs={
'memory_block': memory_block,
'visible_messages': visible_messages,
'variable_pool': variable_pool,
'is_draft': is_draft,
'conversation_id': conversation_id,
},
daemon=True
)
thread.start()
@staticmethod
def _perform_memory_update(
memory_block: MemoryBlock,
variable_pool: VariablePool,
conversation_id: str,
visible_messages: Sequence[PromptMessage],
is_draft: bool
):
updated_value = LLMGenerator.update_memory_block(
tenant_id=memory_block.tenant_id,
visible_history=ChatflowMemoryService._format_chat_history(visible_messages),
variable_pool=variable_pool,
memory_block=memory_block,
memory_spec=memory_block.spec,
)
updated_memory = MemoryBlock(
tenant_id=memory_block.tenant_id,
value=updated_value,
spec=memory_block.spec,
app_id=memory_block.app_id,
conversation_id=conversation_id,
node_id=memory_block.node_id,
edited_by_user=False,
created_by=memory_block.created_by,
version=memory_block.version + 1, # Increment version for business logic update
)
ChatflowMemoryService.save_memory(updated_memory, variable_pool, is_draft)
ChatflowHistoryService.update_visible_count(
conversation_id=conversation_id,
node_id=memory_block.node_id,
new_visible_count=memory_block.spec.preserved_turns,
app_id=memory_block.app_id,
tenant_id=memory_block.tenant_id
)
@staticmethod
def delete_memory(app: App, memory_id: str, created_by: MemoryCreatedBy):
workflow = WorkflowService().get_published_workflow(app)
if not workflow:
raise ValueError("Workflow not found")
memory_spec = next((it for it in workflow.memory_blocks if it.id == memory_id), None)
if not memory_spec or not memory_spec.end_user_editable:
raise ValueError("Memory not found or not deletable")
if created_by.account_id:
created_by_role = CreatorUserRole.ACCOUNT
created_by_id = created_by.account_id
else:
created_by_role = CreatorUserRole.END_USER
created_by_id = created_by.id
with Session(db.engine) as session:
stmt = delete(ChatflowMemoryVariable).where(
and_(
ChatflowMemoryVariable.tenant_id == app.tenant_id,
ChatflowMemoryVariable.app_id == app.id,
ChatflowMemoryVariable.memory_id == memory_id,
ChatflowMemoryVariable.created_by_role == created_by_role,
ChatflowMemoryVariable.created_by == created_by_id
)
)
session.execute(stmt)
session.commit()
@staticmethod
def delete_all_user_memories(app: App, created_by: MemoryCreatedBy):
if created_by.account_id:
created_by_role = CreatorUserRole.ACCOUNT
created_by_id = created_by.account_id
else:
created_by_role = CreatorUserRole.END_USER
created_by_id = created_by.id
with Session(db.engine) as session:
stmt = delete(ChatflowMemoryVariable).where(
and_(
ChatflowMemoryVariable.tenant_id == app.tenant_id,
ChatflowMemoryVariable.app_id == app.id,
ChatflowMemoryVariable.created_by_role == created_by_role,
ChatflowMemoryVariable.created_by == created_by_id
)
)
session.execute(stmt)
session.commit()
@staticmethod
def get_persistent_memories_with_conversation(
app: App,
created_by: MemoryCreatedBy,
conversation_id: str,
version: int | None = None
) -> Sequence[MemoryBlockWithConversation]:
"""Get persistent memories with conversation metadata (always None for persistent)"""
memory_blocks = ChatflowMemoryService.get_persistent_memories(app, created_by, version)
return [
MemoryBlockWithConversation.from_memory_block(
block,
ChatflowHistoryService.get_conversation_metadata(
app.tenant_id, app.id, conversation_id, block.node_id
)
)
for block in memory_blocks
]
@staticmethod
def get_session_memories_with_conversation(
app: App,
created_by: MemoryCreatedBy,
conversation_id: str,
version: int | None = None
) -> Sequence[MemoryBlockWithConversation]:
"""Get session memories with conversation metadata"""
memory_blocks = ChatflowMemoryService.get_session_memories(app, created_by, conversation_id, version)
return [
MemoryBlockWithConversation.from_memory_block(
block,
ChatflowHistoryService.get_conversation_metadata(
app.tenant_id, app.id, conversation_id, block.node_id
)
)
for block in memory_blocks
]
@staticmethod
def _format_chat_history(messages: Sequence[PromptMessage]) -> Sequence[tuple[str, str]]:
result = []
for message in messages:
result.append((str(message.role.value), message.get_text_content()))
return result
def _get_memory_sync_lock_key(app_id: str, conversation_id: str) -> str:
"""Generate Redis lock key for memory sync updates
Args:
app_id: Application ID
conversation_id: Conversation ID
Returns:
Formatted lock key
"""
return f"memory_sync_update:{app_id}:{conversation_id}"

View File

@ -70,7 +70,7 @@ class EnterpriseService:
data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/id", params=params)
if not data:
raise ValueError("No data found.")
return WebAppSettings(**data)
return WebAppSettings.model_validate(data)
@classmethod
def batch_get_app_access_mode_by_id(cls, app_ids: list[str]) -> dict[str, WebAppSettings]:
@ -100,7 +100,7 @@ class EnterpriseService:
data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/code", params=params)
if not data:
raise ValueError("No data found.")
return WebAppSettings(**data)
return WebAppSettings.model_validate(data)
@classmethod
def update_app_access_mode(cls, app_id: str, access_mode: str):

View File

@ -1,6 +1,7 @@
from collections.abc import Sequence
from enum import Enum
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, model_validator
from configs import dify_config
from core.entities.model_entities import (
@ -71,7 +72,7 @@ class ProviderResponse(BaseModel):
icon_large: I18nObject | None = None
background: str | None = None
help: ProviderHelpEntity | None = None
supported_model_types: list[ModelType]
supported_model_types: Sequence[ModelType]
configurate_methods: list[ConfigurateMethod]
provider_credential_schema: ProviderCredentialSchema | None = None
model_credential_schema: ModelCredentialSchema | None = None
@ -82,9 +83,8 @@ class ProviderResponse(BaseModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def __init__(self, **data):
super().__init__(**data)
@model_validator(mode="after")
def _(self):
url_prefix = (
dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
)
@ -97,6 +97,7 @@ class ProviderResponse(BaseModel):
self.icon_large = I18nObject(
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
)
return self
class ProviderWithModelsResponse(BaseModel):
@ -112,9 +113,8 @@ class ProviderWithModelsResponse(BaseModel):
status: CustomConfigurationStatus
models: list[ProviderModelWithStatusEntity]
def __init__(self, **data):
super().__init__(**data)
@model_validator(mode="after")
def _(self):
url_prefix = (
dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
)
@ -127,6 +127,7 @@ class ProviderWithModelsResponse(BaseModel):
self.icon_large = I18nObject(
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
)
return self
class SimpleProviderEntityResponse(SimpleProviderEntity):
@ -136,9 +137,8 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
tenant_id: str
def __init__(self, **data):
super().__init__(**data)
@model_validator(mode="after")
def _(self):
url_prefix = (
dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
)
@ -151,6 +151,7 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
self.icon_large = I18nObject(
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
)
return self
class DefaultModelResponse(BaseModel):

View File

@ -46,7 +46,7 @@ class HitTestingService:
from core.app.app_config.entities import MetadataFilteringCondition
metadata_filtering_conditions = MetadataFilteringCondition(**metadata_filtering_conditions)
metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
dataset_ids=[dataset.id],

View File

@ -123,7 +123,7 @@ class OpsService:
config_class: type[BaseTracingConfig] = provider_config["config_class"]
other_keys: list[str] = provider_config["other_keys"]
default_config_instance: BaseTracingConfig = config_class(**tracing_config)
default_config_instance = config_class.model_validate(tracing_config)
for key in other_keys:
if key in tracing_config and tracing_config[key] == "":
tracing_config[key] = getattr(default_config_instance, key, None)

Some files were not shown because too many files have changed in this diff Show More