Fix basedpyright type errors (#25435)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
-LAN- 2025-09-10 01:54:26 +08:00 committed by GitHub
parent 2ac7a9c8fc
commit 08dd3f7b50
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
100 changed files with 847 additions and 497 deletions

View File

@ -511,7 +511,7 @@ def add_qdrant_index(field: str):
from qdrant_client.http.exceptions import UnexpectedResponse from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import PayloadSchemaType from qdrant_client.http.models import PayloadSchemaType
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
for binding in bindings: for binding in bindings:
if dify_config.QDRANT_URL is None: if dify_config.QDRANT_URL is None:
@ -525,7 +525,21 @@ def add_qdrant_index(field: str):
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
) )
try: try:
client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params()) params = qdrant_config.to_qdrant_params()
# Check the type before using
if isinstance(params, PathQdrantParams):
# PathQdrantParams case
client = qdrant_client.QdrantClient(path=params.path)
else:
# UrlQdrantParams case - params is UrlQdrantParams
client = qdrant_client.QdrantClient(
url=params.url,
api_key=params.api_key,
timeout=int(params.timeout),
verify=params.verify,
grpc_port=params.grpc_port,
prefer_grpc=params.prefer_grpc,
)
# create payload index # create payload index
client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD) client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD)
create_count += 1 create_count += 1

View File

@ -16,14 +16,14 @@ AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"]
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS]) AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
_doc_extensions: list[str]
if dify_config.ETL_TYPE == "Unstructured": if dify_config.ETL_TYPE == "Unstructured":
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"] _doc_extensions = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"]
DOCUMENT_EXTENSIONS.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub")) _doc_extensions.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
if dify_config.UNSTRUCTURED_API_URL: if dify_config.UNSTRUCTURED_API_URL:
DOCUMENT_EXTENSIONS.append("ppt") _doc_extensions.append("ppt")
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
else: else:
DOCUMENT_EXTENSIONS = [ _doc_extensions = [
"txt", "txt",
"markdown", "markdown",
"md", "md",
@ -38,4 +38,4 @@ else:
"vtt", "vtt",
"properties", "properties",
] ]
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS]) DOCUMENT_EXTENSIONS = _doc_extensions + [ext.upper() for ext in _doc_extensions]

View File

@ -8,7 +8,6 @@ if TYPE_CHECKING:
from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.entities.model_entities import AIModelEntity
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.provider import PluginToolProviderController
from core.workflow.entities.variable_pool import VariablePool
""" """

View File

@ -43,56 +43,64 @@ api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm"
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies") api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
# Import other controllers # Import other controllers
from . import admin, apikey, extension, feature, ping, setup, version from . import admin, apikey, extension, feature, ping, setup, version # pyright: ignore[reportUnusedImport]
# Import app controllers # Import app controllers
from .app import ( from .app import (
advanced_prompt_template, advanced_prompt_template, # pyright: ignore[reportUnusedImport]
agent, agent, # pyright: ignore[reportUnusedImport]
annotation, annotation, # pyright: ignore[reportUnusedImport]
app, app, # pyright: ignore[reportUnusedImport]
audio, audio, # pyright: ignore[reportUnusedImport]
completion, completion, # pyright: ignore[reportUnusedImport]
conversation, conversation, # pyright: ignore[reportUnusedImport]
conversation_variables, conversation_variables, # pyright: ignore[reportUnusedImport]
generator, generator, # pyright: ignore[reportUnusedImport]
mcp_server, mcp_server, # pyright: ignore[reportUnusedImport]
message, message, # pyright: ignore[reportUnusedImport]
model_config, model_config, # pyright: ignore[reportUnusedImport]
ops_trace, ops_trace, # pyright: ignore[reportUnusedImport]
site, site, # pyright: ignore[reportUnusedImport]
statistic, statistic, # pyright: ignore[reportUnusedImport]
workflow, workflow, # pyright: ignore[reportUnusedImport]
workflow_app_log, workflow_app_log, # pyright: ignore[reportUnusedImport]
workflow_draft_variable, workflow_draft_variable, # pyright: ignore[reportUnusedImport]
workflow_run, workflow_run, # pyright: ignore[reportUnusedImport]
workflow_statistic, workflow_statistic, # pyright: ignore[reportUnusedImport]
) )
# Import auth controllers # Import auth controllers
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth, oauth_server from .auth import (
activate, # pyright: ignore[reportUnusedImport]
data_source_bearer_auth, # pyright: ignore[reportUnusedImport]
data_source_oauth, # pyright: ignore[reportUnusedImport]
forgot_password, # pyright: ignore[reportUnusedImport]
login, # pyright: ignore[reportUnusedImport]
oauth, # pyright: ignore[reportUnusedImport]
oauth_server, # pyright: ignore[reportUnusedImport]
)
# Import billing controllers # Import billing controllers
from .billing import billing, compliance from .billing import billing, compliance # pyright: ignore[reportUnusedImport]
# Import datasets controllers # Import datasets controllers
from .datasets import ( from .datasets import (
data_source, data_source, # pyright: ignore[reportUnusedImport]
datasets, datasets, # pyright: ignore[reportUnusedImport]
datasets_document, datasets_document, # pyright: ignore[reportUnusedImport]
datasets_segments, datasets_segments, # pyright: ignore[reportUnusedImport]
external, external, # pyright: ignore[reportUnusedImport]
hit_testing, hit_testing, # pyright: ignore[reportUnusedImport]
metadata, metadata, # pyright: ignore[reportUnusedImport]
website, website, # pyright: ignore[reportUnusedImport]
) )
# Import explore controllers # Import explore controllers
from .explore import ( from .explore import (
installed_app, installed_app, # pyright: ignore[reportUnusedImport]
parameter, parameter, # pyright: ignore[reportUnusedImport]
recommended_app, recommended_app, # pyright: ignore[reportUnusedImport]
saved_message, saved_message, # pyright: ignore[reportUnusedImport]
) )
# Explore Audio # Explore Audio
@ -167,18 +175,18 @@ api.add_resource(
) )
# Import tag controllers # Import tag controllers
from .tag import tags from .tag import tags # pyright: ignore[reportUnusedImport]
# Import workspace controllers # Import workspace controllers
from .workspace import ( from .workspace import (
account, account, # pyright: ignore[reportUnusedImport]
agent_providers, agent_providers, # pyright: ignore[reportUnusedImport]
endpoint, endpoint, # pyright: ignore[reportUnusedImport]
load_balancing_config, load_balancing_config, # pyright: ignore[reportUnusedImport]
members, members, # pyright: ignore[reportUnusedImport]
model_providers, model_providers, # pyright: ignore[reportUnusedImport]
models, models, # pyright: ignore[reportUnusedImport]
plugin, plugin, # pyright: ignore[reportUnusedImport]
tool_providers, tool_providers, # pyright: ignore[reportUnusedImport]
workspace, workspace, # pyright: ignore[reportUnusedImport]
) )

View File

@ -1,8 +1,9 @@
from typing import Any, Optional from typing import Optional
import flask_restx import flask_restx
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource, fields, marshal_with from flask_restx import Resource, fields, marshal_with
from flask_restx._http import HTTPStatus
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@ -40,7 +41,7 @@ def _get_resource(resource_id, tenant_id, resource_model):
).scalar_one_or_none() ).scalar_one_or_none()
if resource is None: if resource is None:
flask_restx.abort(404, message=f"{resource_model.__name__} not found.") flask_restx.abort(HTTPStatus.NOT_FOUND, message=f"{resource_model.__name__} not found.")
return resource return resource
@ -49,7 +50,7 @@ class BaseApiKeyListResource(Resource):
method_decorators = [account_initialization_required, login_required, setup_required] method_decorators = [account_initialization_required, login_required, setup_required]
resource_type: str | None = None resource_type: str | None = None
resource_model: Optional[Any] = None resource_model: Optional[type] = None
resource_id_field: str | None = None resource_id_field: str | None = None
token_prefix: str | None = None token_prefix: str | None = None
max_keys = 10 max_keys = 10
@ -82,7 +83,7 @@ class BaseApiKeyListResource(Resource):
if current_key_count >= self.max_keys: if current_key_count >= self.max_keys:
flask_restx.abort( flask_restx.abort(
400, HTTPStatus.BAD_REQUEST,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.", message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
custom="max_keys_exceeded", custom="max_keys_exceeded",
) )
@ -102,7 +103,7 @@ class BaseApiKeyResource(Resource):
method_decorators = [account_initialization_required, login_required, setup_required] method_decorators = [account_initialization_required, login_required, setup_required]
resource_type: str | None = None resource_type: str | None = None
resource_model: Optional[Any] = None resource_model: Optional[type] = None
resource_id_field: str | None = None resource_id_field: str | None = None
def delete(self, resource_id, api_key_id): def delete(self, resource_id, api_key_id):
@ -126,7 +127,7 @@ class BaseApiKeyResource(Resource):
) )
if key is None: if key is None:
flask_restx.abort(404, message="API key not found") flask_restx.abort(HTTPStatus.NOT_FOUND, message="API key not found")
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit() db.session.commit()

View File

@ -115,6 +115,10 @@ class AppListApi(Resource):
raise BadRequest("mode is required") raise BadRequest("mode is required")
app_service = AppService() app_service = AppService()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
if current_user.current_tenant_id is None:
raise ValueError("current_user.current_tenant_id cannot be None")
app = app_service.create_app(current_user.current_tenant_id, args, current_user) app = app_service.create_app(current_user.current_tenant_id, args, current_user)
return app, 201 return app, 201
@ -161,14 +165,26 @@ class AppApi(Resource):
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
app_model = app_service.update_app(app_model, args) # Construct ArgsDict from parsed arguments
from services.app_service import AppService as AppServiceType
args_dict: AppServiceType.ArgsDict = {
"name": args["name"],
"description": args.get("description", ""),
"icon_type": args.get("icon_type", ""),
"icon": args.get("icon", ""),
"icon_background": args.get("icon_background", ""),
"use_icon_as_answer_icon": args.get("use_icon_as_answer_icon", False),
"max_active_requests": args.get("max_active_requests", 0),
}
app_model = app_service.update_app(app_model, args_dict)
return app_model return app_model
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def delete(self, app_model): def delete(self, app_model):
"""Delete app""" """Delete app"""
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
@ -224,10 +240,10 @@ class AppCopyApi(Resource):
class AppExportApi(Resource): class AppExportApi(Resource):
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def get(self, app_model): def get(self, app_model):
"""Export app""" """Export app"""
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
@ -263,7 +279,7 @@ class AppNameApi(Resource):
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
app_model = app_service.update_app_name(app_model, args.get("name")) app_model = app_service.update_app_name(app_model, args["name"])
return app_model return app_model
@ -285,7 +301,7 @@ class AppIconApi(Resource):
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
app_model = app_service.update_app_icon(app_model, args.get("icon"), args.get("icon_background")) app_model = app_service.update_app_icon(app_model, args.get("icon") or "", args.get("icon_background") or "")
return app_model return app_model
@ -306,7 +322,7 @@ class AppSiteStatus(Resource):
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
app_model = app_service.update_app_site_status(app_model, args.get("enable_site")) app_model = app_service.update_app_site_status(app_model, args["enable_site"])
return app_model return app_model
@ -327,7 +343,7 @@ class AppApiStatus(Resource):
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
app_model = app_service.update_app_api_status(app_model, args.get("enable_api")) app_model = app_service.update_app_api_status(app_model, args["enable_api"])
return app_model return app_model

View File

@ -77,10 +77,10 @@ class ChatMessageAudioApi(Resource):
class ChatMessageTextApi(Resource): class ChatMessageTextApi(Resource):
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def post(self, app_model: App): def post(self, app_model: App):
try: try:
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -125,10 +125,10 @@ class ChatMessageTextApi(Resource):
class TextModesApi(Resource): class TextModesApi(Resource):
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def get(self, app_model): def get(self, app_model):
try: try:
parser = reqparse.RequestParser() parser = reqparse.RequestParser()

View File

@ -1,6 +1,5 @@
import logging import logging
import flask_login
from flask import request from flask import request
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
@ -29,7 +28,8 @@ from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from libs import helper from libs import helper
from libs.helper import uuid_value from libs.helper import uuid_value
from libs.login import login_required from libs.login import current_user, login_required
from models import Account
from models.model import AppMode from models.model import AppMode
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError from services.errors.llm import InvokeRateLimitError
@ -56,11 +56,11 @@ class CompletionMessageApi(Resource):
streaming = args["response_mode"] != "blocking" streaming = args["response_mode"] != "blocking"
args["auto_generate_name"] = False args["auto_generate_name"] = False
account = flask_login.current_user
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account or EndUser instance")
response = AppGenerateService.generate( response = AppGenerateService.generate(
app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
) )
return helper.compact_generate_response(response) return helper.compact_generate_response(response)
@ -92,9 +92,9 @@ class CompletionMessageStopApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.COMPLETION) @get_app_model(mode=AppMode.COMPLETION)
def post(self, app_model, task_id): def post(self, app_model, task_id):
account = flask_login.current_user if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
return {"result": "success"}, 200 return {"result": "success"}, 200
@ -123,11 +123,11 @@ class ChatMessageApi(Resource):
if external_trace_id: if external_trace_id:
args["external_trace_id"] = external_trace_id args["external_trace_id"] = external_trace_id
account = flask_login.current_user
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account or EndUser instance")
response = AppGenerateService.generate( response = AppGenerateService.generate(
app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
) )
return helper.compact_generate_response(response) return helper.compact_generate_response(response)
@ -161,9 +161,9 @@ class ChatMessageStopApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def post(self, app_model, task_id): def post(self, app_model, task_id):
account = flask_login.current_user if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
return {"result": "success"}, 200 return {"result": "success"}, 200

View File

@ -22,7 +22,7 @@ from fields.conversation_fields import (
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.helper import DatetimeString from libs.helper import DatetimeString
from libs.login import login_required from libs.login import login_required
from models import Conversation, EndUser, Message, MessageAnnotation from models import Account, Conversation, EndUser, Message, MessageAnnotation
from models.model import AppMode from models.model import AppMode
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError from services.errors.conversation import ConversationNotExistsError
@ -124,6 +124,8 @@ class CompletionConversationDetailApi(Resource):
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
ConversationService.delete(app_model, conversation_id, current_user) ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError: except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
@ -282,6 +284,8 @@ class ChatConversationDetailApi(Resource):
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
ConversationService.delete(app_model, conversation_id, current_user) ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError: except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")

View File

@ -1,6 +1,5 @@
import logging import logging
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx.inputs import int_range from flask_restx.inputs import int_range
from sqlalchemy import exists, select from sqlalchemy import exists, select
@ -27,7 +26,8 @@ from extensions.ext_database import db
from fields.conversation_fields import annotation_fields, message_detail_fields from fields.conversation_fields import annotation_fields, message_detail_fields
from libs.helper import uuid_value from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import login_required from libs.login import current_user, login_required
from models.account import Account
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.annotation_service import AppAnnotationService from services.annotation_service import AppAnnotationService
from services.errors.conversation import ConversationNotExistsError from services.errors.conversation import ConversationNotExistsError
@ -118,11 +118,14 @@ class ChatMessageListApi(Resource):
class MessageFeedbackApi(Resource): class MessageFeedbackApi(Resource):
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def post(self, app_model): def post(self, app_model):
if current_user is None:
raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("message_id", required=True, type=uuid_value, location="json") parser.add_argument("message_id", required=True, type=uuid_value, location="json")
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
@ -167,6 +170,8 @@ class MessageAnnotationApi(Resource):
@get_app_model @get_app_model
@marshal_with(annotation_fields) @marshal_with(annotation_fields)
def post(self, app_model): def post(self, app_model):
if not isinstance(current_user, Account):
raise Forbidden()
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
@ -182,10 +187,10 @@ class MessageAnnotationApi(Resource):
class MessageAnnotationCountApi(Resource): class MessageAnnotationCountApi(Resource):
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def get(self, app_model): def get(self, app_model):
count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count() count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count()

View File

@ -10,7 +10,7 @@ from extensions.ext_database import db
from fields.app_fields import app_site_fields from fields.app_fields import app_site_fields
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import login_required from libs.login import login_required
from models import Site from models import Account, Site
def parse_app_site_args(): def parse_app_site_args():
@ -75,6 +75,8 @@ class AppSite(Resource):
if value is not None: if value is not None:
setattr(site, attr_name, value) setattr(site, attr_name, value)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
site.updated_by = current_user.id site.updated_by = current_user.id
site.updated_at = naive_utc_now() site.updated_at = naive_utc_now()
db.session.commit() db.session.commit()
@ -99,6 +101,8 @@ class AppSiteAccessTokenReset(Resource):
raise NotFound raise NotFound
site.code = Site.generate_code(16) site.code = Site.generate_code(16)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
site.updated_by = current_user.id site.updated_by = current_user.id
site.updated_at = naive_utc_now() site.updated_at = naive_utc_now()
db.session.commit() db.session.commit()

View File

@ -18,10 +18,10 @@ from models import AppMode, Message
class DailyMessageStatistic(Resource): class DailyMessageStatistic(Resource):
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def get(self, app_model): def get(self, app_model):
account = current_user account = current_user
@ -75,10 +75,10 @@ WHERE
class DailyConversationStatistic(Resource): class DailyConversationStatistic(Resource):
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def get(self, app_model): def get(self, app_model):
account = current_user account = current_user
@ -127,10 +127,10 @@ class DailyConversationStatistic(Resource):
class DailyTerminalsStatistic(Resource): class DailyTerminalsStatistic(Resource):
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def get(self, app_model): def get(self, app_model):
account = current_user account = current_user
@ -184,10 +184,10 @@ WHERE
class DailyTokenCostStatistic(Resource): class DailyTokenCostStatistic(Resource):
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def get(self, app_model): def get(self, app_model):
account = current_user account = current_user
@ -320,10 +320,10 @@ ORDER BY
class UserSatisfactionRateStatistic(Resource): class UserSatisfactionRateStatistic(Resource):
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def get(self, app_model): def get(self, app_model):
account = current_user account = current_user
@ -443,10 +443,10 @@ WHERE
class TokensPerSecondStatistic(Resource): class TokensPerSecondStatistic(Resource):
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def get(self, app_model): def get(self, app_model):
account = current_user account = current_user

View File

@ -18,10 +18,10 @@ from models.model import AppMode
class WorkflowDailyRunsStatistic(Resource): class WorkflowDailyRunsStatistic(Resource):
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def get(self, app_model): def get(self, app_model):
account = current_user account = current_user
@ -80,10 +80,10 @@ WHERE
class WorkflowDailyTerminalsStatistic(Resource): class WorkflowDailyTerminalsStatistic(Resource):
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def get(self, app_model): def get(self, app_model):
account = current_user account = current_user
@ -142,10 +142,10 @@ WHERE
class WorkflowDailyTokenCostStatistic(Resource): class WorkflowDailyTokenCostStatistic(Resource):
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def get(self, app_model): def get(self, app_model):
account = current_user account = current_user

View File

@ -77,6 +77,9 @@ class OAuthCallback(Resource):
if state: if state:
invite_token = state invite_token = state
if not code:
return {"error": "Authorization code is required"}, 400
try: try:
token = oauth_provider.get_access_token(code) token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token) user_info = oauth_provider.get_user_info(token)
@ -86,7 +89,7 @@ class OAuthCallback(Resource):
return {"error": "OAuth process failed"}, 400 return {"error": "OAuth process failed"}, 400
if invite_token and RegisterService.is_valid_invite_token(invite_token): if invite_token and RegisterService.is_valid_invite_token(invite_token):
invitation = RegisterService._get_invitation_by_token(token=invite_token) invitation = RegisterService.get_invitation_by_token(token=invite_token)
if invitation: if invitation:
invitation_email = invitation.get("email", None) invitation_email = invitation.get("email", None)
if invitation_email != user_info.email: if invitation_email != user_info.email:

View File

@ -1,6 +1,5 @@
import logging import logging
from flask_login import current_user
from flask_restx import reqparse from flask_restx import reqparse
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
@ -28,6 +27,8 @@ from extensions.ext_database import db
from libs import helper from libs import helper
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.helper import uuid_value from libs.helper import uuid_value
from libs.login import current_user
from models import Account
from models.model import AppMode from models.model import AppMode
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError from services.errors.llm import InvokeRateLimitError
@ -57,6 +58,8 @@ class CompletionApi(InstalledAppResource):
db.session.commit() db.session.commit()
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
response = AppGenerateService.generate( response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
) )
@ -90,6 +93,8 @@ class CompletionStopApi(InstalledAppResource):
if app_model.mode != "completion": if app_model.mode != "completion":
raise NotCompletionAppError() raise NotCompletionAppError()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {"result": "success"}, 200 return {"result": "success"}, 200
@ -117,6 +122,8 @@ class ChatApi(InstalledAppResource):
db.session.commit() db.session.commit()
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
response = AppGenerateService.generate( response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
) )
@ -153,6 +160,8 @@ class ChatStopApi(InstalledAppResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {"result": "success"}, 200 return {"result": "success"}, 200

View File

@ -1,4 +1,3 @@
from flask_login import current_user
from flask_restx import marshal_with, reqparse from flask_restx import marshal_with, reqparse
from flask_restx.inputs import int_range from flask_restx.inputs import int_range
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -10,6 +9,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db from extensions.ext_database import db
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value from libs.helper import uuid_value
from libs.login import current_user
from models import Account
from models.model import AppMode from models.model import AppMode
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
@ -35,6 +36,8 @@ class ConversationListApi(InstalledAppResource):
pinned = args["pinned"] == "true" pinned = args["pinned"] == "true"
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
with Session(db.engine) as session: with Session(db.engine) as session:
return WebConversationService.pagination_by_last_id( return WebConversationService.pagination_by_last_id(
session=session, session=session,
@ -58,6 +61,8 @@ class ConversationApi(InstalledAppResource):
conversation_id = str(c_id) conversation_id = str(c_id)
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
ConversationService.delete(app_model, conversation_id, current_user) ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError: except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
@ -81,6 +86,8 @@ class ConversationRenameApi(InstalledAppResource):
args = parser.parse_args() args = parser.parse_args()
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
return ConversationService.rename( return ConversationService.rename(
app_model, conversation_id, current_user, args["name"], args["auto_generate"] app_model, conversation_id, current_user, args["name"], args["auto_generate"]
) )
@ -98,6 +105,8 @@ class ConversationPinApi(InstalledAppResource):
conversation_id = str(c_id) conversation_id = str(c_id)
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
WebConversationService.pin(app_model, conversation_id, current_user) WebConversationService.pin(app_model, conversation_id, current_user)
except ConversationNotExistsError: except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
@ -113,6 +122,8 @@ class ConversationUnPinApi(InstalledAppResource):
raise NotChatAppError() raise NotChatAppError()
conversation_id = str(c_id) conversation_id = str(c_id)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
WebConversationService.unpin(app_model, conversation_id, current_user) WebConversationService.unpin(app_model, conversation_id, current_user)
return {"result": "success"} return {"result": "success"}

View File

@ -2,7 +2,6 @@ import logging
from typing import Any from typing import Any
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource, inputs, marshal_with, reqparse from flask_restx import Resource, inputs, marshal_with, reqparse
from sqlalchemy import and_ from sqlalchemy import and_
from werkzeug.exceptions import BadRequest, Forbidden, NotFound from werkzeug.exceptions import BadRequest, Forbidden, NotFound
@ -13,8 +12,8 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
from extensions.ext_database import db from extensions.ext_database import db
from fields.installed_app_fields import installed_app_list_fields from fields.installed_app_fields import installed_app_list_fields
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import login_required from libs.login import current_user, login_required
from models import App, InstalledApp, RecommendedApp from models import Account, App, InstalledApp, RecommendedApp
from services.account_service import TenantService from services.account_service import TenantService
from services.app_service import AppService from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
@ -29,6 +28,8 @@ class InstalledAppsListApi(Resource):
@marshal_with(installed_app_list_fields) @marshal_with(installed_app_list_fields)
def get(self): def get(self):
app_id = request.args.get("app_id", default=None, type=str) app_id = request.args.get("app_id", default=None, type=str)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
current_tenant_id = current_user.current_tenant_id current_tenant_id = current_user.current_tenant_id
if app_id: if app_id:
@ -40,6 +41,8 @@ class InstalledAppsListApi(Resource):
else: else:
installed_apps = db.session.query(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id).all() installed_apps = db.session.query(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id).all()
if current_user.current_tenant is None:
raise ValueError("current_user.current_tenant must not be None")
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
installed_app_list: list[dict[str, Any]] = [ installed_app_list: list[dict[str, Any]] = [
{ {
@ -115,6 +118,8 @@ class InstalledAppsListApi(Resource):
if recommended_app is None: if recommended_app is None:
raise NotFound("App not found") raise NotFound("App not found")
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
current_tenant_id = current_user.current_tenant_id current_tenant_id = current_user.current_tenant_id
app = db.session.query(App).where(App.id == args["app_id"]).first() app = db.session.query(App).where(App.id == args["app_id"]).first()
@ -154,6 +159,8 @@ class InstalledAppApi(InstalledAppResource):
""" """
def delete(self, installed_app): def delete(self, installed_app):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
if installed_app.app_owner_tenant_id == current_user.current_tenant_id: if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
raise BadRequest("You can't uninstall an app owned by the current tenant") raise BadRequest("You can't uninstall an app owned by the current tenant")

View File

@ -1,6 +1,5 @@
import logging import logging
from flask_login import current_user
from flask_restx import marshal_with, reqparse from flask_restx import marshal_with, reqparse
from flask_restx.inputs import int_range from flask_restx.inputs import int_range
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
@ -24,6 +23,8 @@ from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields from fields.message_fields import message_infinite_scroll_pagination_fields
from libs import helper from libs import helper
from libs.helper import uuid_value from libs.helper import uuid_value
from libs.login import current_user
from models import Account
from models.model import AppMode from models.model import AppMode
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.app import MoreLikeThisDisabledError from services.errors.app import MoreLikeThisDisabledError
@ -54,6 +55,8 @@ class MessageListApi(InstalledAppResource):
args = parser.parse_args() args = parser.parse_args()
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
return MessageService.pagination_by_first_id( return MessageService.pagination_by_first_id(
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"] app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
) )
@ -75,6 +78,8 @@ class MessageFeedbackApi(InstalledAppResource):
args = parser.parse_args() args = parser.parse_args()
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
MessageService.create_feedback( MessageService.create_feedback(
app_model=app_model, app_model=app_model,
message_id=message_id, message_id=message_id,
@ -105,6 +110,8 @@ class MessageMoreLikeThisApi(InstalledAppResource):
streaming = args["response_mode"] == "streaming" streaming = args["response_mode"] == "streaming"
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
response = AppGenerateService.generate_more_like_this( response = AppGenerateService.generate_more_like_this(
app_model=app_model, app_model=app_model,
user=current_user, user=current_user,
@ -142,6 +149,8 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
message_id = str(message_id) message_id = str(message_id)
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
questions = MessageService.get_suggested_questions_after_answer( questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
) )

View File

@ -1,11 +1,10 @@
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with, reqparse
from constants.languages import languages from constants.languages import languages
from controllers.console import api from controllers.console import api
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from libs.helper import AppIconUrlField from libs.helper import AppIconUrlField
from libs.login import login_required from libs.login import current_user, login_required
from services.recommended_app_service import RecommendedAppService from services.recommended_app_service import RecommendedAppService
app_fields = { app_fields = {
@ -46,8 +45,9 @@ class RecommendedAppListApi(Resource):
parser.add_argument("language", type=str, location="args") parser.add_argument("language", type=str, location="args")
args = parser.parse_args() args = parser.parse_args()
if args.get("language") and args.get("language") in languages: language = args.get("language")
language_prefix = args.get("language") if language and language in languages:
language_prefix = language
elif current_user and current_user.interface_language: elif current_user and current_user.interface_language:
language_prefix = current_user.interface_language language_prefix = current_user.interface_language
else: else:

View File

@ -1,4 +1,3 @@
from flask_login import current_user
from flask_restx import fields, marshal_with, reqparse from flask_restx import fields, marshal_with, reqparse
from flask_restx.inputs import int_range from flask_restx.inputs import int_range
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
@ -8,6 +7,8 @@ from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import message_file_fields from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
from libs.login import current_user
from models import Account
from services.errors.message import MessageNotExistsError from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService from services.saved_message_service import SavedMessageService
@ -42,6 +43,8 @@ class SavedMessageListApi(InstalledAppResource):
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args() args = parser.parse_args()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"]) return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])
def post(self, installed_app): def post(self, installed_app):
@ -54,6 +57,8 @@ class SavedMessageListApi(InstalledAppResource):
args = parser.parse_args() args = parser.parse_args()
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
SavedMessageService.save(app_model, current_user, args["message_id"]) SavedMessageService.save(app_model, current_user, args["message_id"])
except MessageNotExistsError: except MessageNotExistsError:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")
@ -70,6 +75,8 @@ class SavedMessageApi(InstalledAppResource):
if app_model.mode != "completion": if app_model.mode != "completion":
raise NotCompletionAppError() raise NotCompletionAppError()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
SavedMessageService.delete(app_model, current_user, message_id) SavedMessageService.delete(app_model, current_user, message_id)
return {"result": "success"}, 204 return {"result": "success"}, 204

View File

@ -22,6 +22,7 @@ from controllers.console.wraps import (
) )
from fields.file_fields import file_fields, upload_config_fields from fields.file_fields import file_fields, upload_config_fields
from libs.login import login_required from libs.login import login_required
from models import Account
from services.file_service import FileService from services.file_service import FileService
PREVIEW_WORDS_LIMIT = 3000 PREVIEW_WORDS_LIMIT = 3000
@ -68,6 +69,8 @@ class FileApi(Resource):
source = None source = None
try: try:
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
upload_file = FileService.upload_file( upload_file = FileService.upload_file(
filename=file.filename, filename=file.filename,
content=file.read(), content=file.read(),

View File

@ -34,14 +34,14 @@ class VersionApi(Resource):
return result return result
try: try:
response = requests.get(check_update_url, {"current_version": args.get("current_version")}, timeout=(3, 10)) response = requests.get(check_update_url, {"current_version": args["current_version"]}, timeout=(3, 10))
except Exception as error: except Exception as error:
logger.warning("Check update version error: %s.", str(error)) logger.warning("Check update version error: %s.", str(error))
result["version"] = args.get("current_version") result["version"] = args["current_version"]
return result return result
content = json.loads(response.content) content = json.loads(response.content)
if _has_new_version(latest_version=content["version"], current_version=f"{args.get('current_version')}"): if _has_new_version(latest_version=content["version"], current_version=f"{args['current_version']}"):
result["version"] = content["version"] result["version"] = content["version"]
result["release_date"] = content["releaseDate"] result["release_date"] = content["releaseDate"]
result["release_notes"] = content["releaseNotes"] result["release_notes"] = content["releaseNotes"]

View File

@ -49,6 +49,8 @@ class AccountInitApi(Resource):
@setup_required @setup_required
@login_required @login_required
def post(self): def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user account = current_user
if account.status == "active": if account.status == "active":
@ -102,6 +104,8 @@ class AccountProfileApi(Resource):
@marshal_with(account_fields) @marshal_with(account_fields)
@enterprise_license_required @enterprise_license_required
def get(self): def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
return current_user return current_user
@ -111,6 +115,8 @@ class AccountNameApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json") parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
@ -130,6 +136,8 @@ class AccountAvatarApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("avatar", type=str, required=True, location="json") parser.add_argument("avatar", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
@ -145,6 +153,8 @@ class AccountInterfaceLanguageApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("interface_language", type=supported_language, required=True, location="json") parser.add_argument("interface_language", type=supported_language, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
@ -160,6 +170,8 @@ class AccountInterfaceThemeApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json") parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
@ -175,6 +187,8 @@ class AccountTimezoneApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("timezone", type=str, required=True, location="json") parser.add_argument("timezone", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
@ -194,6 +208,8 @@ class AccountPasswordApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("password", type=str, required=False, location="json") parser.add_argument("password", type=str, required=False, location="json")
parser.add_argument("new_password", type=str, required=True, location="json") parser.add_argument("new_password", type=str, required=True, location="json")
@ -228,6 +244,8 @@ class AccountIntegrateApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(integrate_list_fields) @marshal_with(integrate_list_fields)
def get(self): def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user account = current_user
account_integrates = db.session.query(AccountIntegrate).where(AccountIntegrate.account_id == account.id).all() account_integrates = db.session.query(AccountIntegrate).where(AccountIntegrate.account_id == account.id).all()
@ -268,6 +286,8 @@ class AccountDeleteVerifyApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user account = current_user
token, code = AccountService.generate_account_deletion_verification_code(account) token, code = AccountService.generate_account_deletion_verification_code(account)
@ -281,6 +301,8 @@ class AccountDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -321,6 +343,8 @@ class EducationVerifyApi(Resource):
@cloud_edition_billing_enabled @cloud_edition_billing_enabled
@marshal_with(verify_fields) @marshal_with(verify_fields)
def get(self): def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user account = current_user
return BillingService.EducationIdentity.verify(account.id, account.email) return BillingService.EducationIdentity.verify(account.id, account.email)
@ -340,6 +364,8 @@ class EducationApi(Resource):
@only_edition_cloud @only_edition_cloud
@cloud_edition_billing_enabled @cloud_edition_billing_enabled
def post(self): def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -357,6 +383,8 @@ class EducationApi(Resource):
@cloud_edition_billing_enabled @cloud_edition_billing_enabled
@marshal_with(status_fields) @marshal_with(status_fields)
def get(self): def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user account = current_user
res = BillingService.EducationIdentity.status(account.id) res = BillingService.EducationIdentity.status(account.id)
@ -421,6 +449,8 @@ class ChangeEmailSendEmailApi(Resource):
raise InvalidTokenError() raise InvalidTokenError()
user_email = reset_data.get("email", "") user_email = reset_data.get("email", "")
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if user_email != current_user.email: if user_email != current_user.email:
raise InvalidEmailError() raise InvalidEmailError()
else: else:
@ -501,6 +531,8 @@ class ChangeEmailResetApi(Resource):
AccountService.revoke_change_email_token(args["token"]) AccountService.revoke_change_email_token(args["token"])
old_email = reset_data.get("old_email", "") old_email = reset_data.get("old_email", "")
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if current_user.email != old_email: if current_user.email != old_email:
raise AccountNotFound() raise AccountNotFound()

View File

@ -1,8 +1,8 @@
from urllib import parse from urllib import parse
from flask import request from flask import abort, request
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource, abort, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
import services import services
from configs import dify_config from configs import dify_config
@ -41,6 +41,10 @@ class MemberListApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_with_role_list_fields) @marshal_with(account_with_role_list_fields)
def get(self): def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_tenant_members(current_user.current_tenant) members = TenantService.get_tenant_members(current_user.current_tenant)
return {"result": "success", "accounts": members}, 200 return {"result": "success", "accounts": members}, 200
@ -65,7 +69,11 @@ class MemberInviteEmailApi(Resource):
if not TenantAccountRole.is_non_owner_role(invitee_role): if not TenantAccountRole.is_non_owner_role(invitee_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400 return {"code": "invalid-role", "message": "Invalid role"}, 400
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
inviter = current_user inviter = current_user
if not inviter.current_tenant:
raise ValueError("No current tenant")
invitation_results = [] invitation_results = []
console_web_url = dify_config.CONSOLE_WEB_URL console_web_url = dify_config.CONSOLE_WEB_URL
@ -76,6 +84,8 @@ class MemberInviteEmailApi(Resource):
for invitee_email in invitee_emails: for invitee_email in invitee_emails:
try: try:
if not inviter.current_tenant:
raise ValueError("No current tenant")
token = RegisterService.invite_new_member( token = RegisterService.invite_new_member(
inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
) )
@ -97,7 +107,7 @@ class MemberInviteEmailApi(Resource):
return { return {
"result": "success", "result": "success",
"invitation_results": invitation_results, "invitation_results": invitation_results,
"tenant_id": str(current_user.current_tenant.id), "tenant_id": str(inviter.current_tenant.id) if inviter.current_tenant else "",
}, 201 }, 201
@ -108,6 +118,10 @@ class MemberCancelInviteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, member_id): def delete(self, member_id):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant:
raise ValueError("No current tenant")
member = db.session.query(Account).where(Account.id == str(member_id)).first() member = db.session.query(Account).where(Account.id == str(member_id)).first()
if member is None: if member is None:
abort(404) abort(404)
@ -123,7 +137,10 @@ class MemberCancelInviteApi(Resource):
except Exception as e: except Exception as e:
raise ValueError(str(e)) raise ValueError(str(e))
return {"result": "success", "tenant_id": str(current_user.current_tenant.id)}, 200 return {
"result": "success",
"tenant_id": str(current_user.current_tenant.id) if current_user.current_tenant else "",
}, 200
class MemberUpdateRoleApi(Resource): class MemberUpdateRoleApi(Resource):
@ -141,6 +158,10 @@ class MemberUpdateRoleApi(Resource):
if not TenantAccountRole.is_valid_role(new_role): if not TenantAccountRole.is_valid_role(new_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400 return {"code": "invalid-role", "message": "Invalid role"}, 400
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant:
raise ValueError("No current tenant")
member = db.session.get(Account, str(member_id)) member = db.session.get(Account, str(member_id))
if not member: if not member:
abort(404) abort(404)
@ -164,6 +185,10 @@ class DatasetOperatorMemberListApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_with_role_list_fields) @marshal_with(account_with_role_list_fields)
def get(self): def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_dataset_operator_members(current_user.current_tenant) members = TenantService.get_dataset_operator_members(current_user.current_tenant)
return {"result": "success", "accounts": members}, 200 return {"result": "success", "accounts": members}, 200
@ -184,6 +209,10 @@ class SendOwnerTransferEmailApi(Resource):
raise EmailSendIpLimitError() raise EmailSendIpLimitError()
# check if the current user is the owner of the workspace # check if the current user is the owner of the workspace
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant): if not TenantService.is_owner(current_user, current_user.current_tenant):
raise NotOwnerError() raise NotOwnerError()
@ -198,7 +227,7 @@ class SendOwnerTransferEmailApi(Resource):
account=current_user, account=current_user,
email=email, email=email,
language=language, language=language,
workspace_name=current_user.current_tenant.name, workspace_name=current_user.current_tenant.name if current_user.current_tenant else "",
) )
return {"result": "success", "data": token} return {"result": "success", "data": token}
@ -215,6 +244,10 @@ class OwnerTransferCheckApi(Resource):
parser.add_argument("token", type=str, required=True, nullable=False, location="json") parser.add_argument("token", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
# check if the current user is the owner of the workspace # check if the current user is the owner of the workspace
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant): if not TenantService.is_owner(current_user, current_user.current_tenant):
raise NotOwnerError() raise NotOwnerError()
@ -256,6 +289,10 @@ class OwnerTransfer(Resource):
args = parser.parse_args() args = parser.parse_args()
# check if the current user is the owner of the workspace # check if the current user is the owner of the workspace
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant): if not TenantService.is_owner(current_user, current_user.current_tenant):
raise NotOwnerError() raise NotOwnerError()
@ -274,9 +311,11 @@ class OwnerTransfer(Resource):
member = db.session.get(Account, str(member_id)) member = db.session.get(Account, str(member_id))
if not member: if not member:
abort(404) abort(404)
else: return # Never reached, but helps type checker
member_account = member
if not TenantService.is_member(member_account, current_user.current_tenant): if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_member(member, current_user.current_tenant):
raise MemberNotInTenantError() raise MemberNotInTenantError()
try: try:
@ -286,13 +325,13 @@ class OwnerTransfer(Resource):
AccountService.send_new_owner_transfer_notify_email( AccountService.send_new_owner_transfer_notify_email(
account=member, account=member,
email=member.email, email=member.email,
workspace_name=current_user.current_tenant.name, workspace_name=current_user.current_tenant.name if current_user.current_tenant else "",
) )
AccountService.send_old_owner_transfer_notify_email( AccountService.send_old_owner_transfer_notify_email(
account=current_user, account=current_user,
email=current_user.email, email=current_user.email,
workspace_name=current_user.current_tenant.name, workspace_name=current_user.current_tenant.name if current_user.current_tenant else "",
new_owner_email=member.email, new_owner_email=member.email,
) )

View File

@ -12,6 +12,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import StrLen, uuid_value from libs.helper import StrLen, uuid_value
from libs.login import login_required from libs.login import login_required
from models.account import Account
from services.billing_service import BillingService from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService from services.model_provider_service import ModelProviderService
@ -21,6 +22,10 @@ class ModelProviderListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -45,6 +50,10 @@ class ModelProviderCredentialApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider: str): def get(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
# if credential_id is not provided, return current used credential # if credential_id is not provided, return current used credential
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -62,6 +71,8 @@ class ModelProviderCredentialApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
@ -72,6 +83,8 @@ class ModelProviderCredentialApi(Resource):
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
try: try:
model_provider_service.create_provider_credential( model_provider_service.create_provider_credential(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
@ -88,6 +101,8 @@ class ModelProviderCredentialApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def put(self, provider: str): def put(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
@ -99,6 +114,8 @@ class ModelProviderCredentialApi(Resource):
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
try: try:
model_provider_service.update_provider_credential( model_provider_service.update_provider_credential(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
@ -116,12 +133,16 @@ class ModelProviderCredentialApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, provider: str): def delete(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
model_provider_service.remove_provider_credential( model_provider_service.remove_provider_credential(
tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"] tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"]
@ -135,12 +156,16 @@ class ModelProviderCredentialSwitchApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
service = ModelProviderService() service = ModelProviderService()
service.switch_active_provider_credential( service.switch_active_provider_credential(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
@ -155,10 +180,14 @@ class ModelProviderValidateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
@ -205,9 +234,13 @@ class PreferredProviderTypeUpdateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -236,7 +269,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
def get(self, provider: str): def get(self, provider: str):
if provider != "anthropic": if provider != "anthropic":
raise ValueError(f"provider name {provider} is invalid") raise ValueError(f"provider name {provider} is invalid")
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
BillingService.is_tenant_owner_or_admin(current_user) BillingService.is_tenant_owner_or_admin(current_user)
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
data = BillingService.get_model_provider_payment_link( data = BillingService.get_model_provider_payment_link(
provider_name=provider, provider_name=provider,
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,

View File

@ -25,7 +25,7 @@ from controllers.console.wraps import (
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import TimestampField from libs.helper import TimestampField
from libs.login import login_required from libs.login import login_required
from models.account import Tenant, TenantStatus from models.account import Account, Tenant, TenantStatus
from services.account_service import TenantService from services.account_service import TenantService
from services.feature_service import FeatureService from services.feature_service import FeatureService
from services.file_service import FileService from services.file_service import FileService
@ -70,6 +70,8 @@ class TenantListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
tenants = TenantService.get_join_tenants(current_user) tenants = TenantService.get_join_tenants(current_user)
tenant_dicts = [] tenant_dicts = []
@ -83,7 +85,7 @@ class TenantListApi(Resource):
"status": tenant.status, "status": tenant.status,
"created_at": tenant.created_at, "created_at": tenant.created_at,
"plan": features.billing.subscription.plan if features.billing.enabled else "sandbox", "plan": features.billing.subscription.plan if features.billing.enabled else "sandbox",
"current": tenant.id == current_user.current_tenant_id, "current": tenant.id == current_user.current_tenant_id if current_user.current_tenant_id else False,
} }
tenant_dicts.append(tenant_dict) tenant_dicts.append(tenant_dict)
@ -125,7 +127,11 @@ class TenantApi(Resource):
if request.path == "/info": if request.path == "/info":
logger.warning("Deprecated URL /info was used.") logger.warning("Deprecated URL /info was used.")
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
tenant = current_user.current_tenant tenant = current_user.current_tenant
if not tenant:
raise ValueError("No current tenant")
if tenant.status == TenantStatus.ARCHIVE: if tenant.status == TenantStatus.ARCHIVE:
tenants = TenantService.get_join_tenants(current_user) tenants = TenantService.get_join_tenants(current_user)
@ -137,6 +143,8 @@ class TenantApi(Resource):
else: else:
raise Unauthorized("workspace is archived") raise Unauthorized("workspace is archived")
if not tenant:
raise ValueError("No tenant available")
return WorkspaceService.get_tenant_info(tenant), 200 return WorkspaceService.get_tenant_info(tenant), 200
@ -145,6 +153,8 @@ class SwitchWorkspaceApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("tenant_id", type=str, required=True, location="json") parser.add_argument("tenant_id", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
@ -168,11 +178,15 @@ class CustomConfigWorkspaceApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("workspace_custom") @cloud_edition_billing_resource_check("workspace_custom")
def post(self): def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("remove_webapp_brand", type=bool, location="json") parser.add_argument("remove_webapp_brand", type=bool, location="json")
parser.add_argument("replace_webapp_logo", type=str, location="json") parser.add_argument("replace_webapp_logo", type=str, location="json")
args = parser.parse_args() args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant = db.get_or_404(Tenant, current_user.current_tenant_id) tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
custom_config_dict = { custom_config_dict = {
@ -194,6 +208,8 @@ class WebappLogoWorkspaceApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("workspace_custom") @cloud_edition_billing_resource_check("workspace_custom")
def post(self): def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
# check file # check file
if "file" not in request.files: if "file" not in request.files:
raise NoFileUploadedError() raise NoFileUploadedError()
@ -232,10 +248,14 @@ class WorkspaceInfoApi(Resource):
@account_initialization_required @account_initialization_required
# Change workspace name # Change workspace name
def post(self): def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json") parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant = db.get_or_404(Tenant, current_user.current_tenant_id) tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
tenant.name = args["name"] tenant.name = args["name"]
db.session.commit() db.session.commit()

View File

@ -15,6 +15,6 @@ api = ExternalApi(
files_ns = Namespace("files", description="File operations", path="/") files_ns = Namespace("files", description="File operations", path="/")
from . import image_preview, tool_files, upload from . import image_preview, tool_files, upload # pyright: ignore[reportUnusedImport]
api.add_namespace(files_ns) api.add_namespace(files_ns)

View File

@ -16,8 +16,8 @@ api = ExternalApi(
# Create namespace # Create namespace
inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/") inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/")
from . import mail from . import mail as _mail # pyright: ignore[reportUnusedImport]
from .plugin import plugin from .plugin import plugin as _plugin # pyright: ignore[reportUnusedImport]
from .workspace import workspace from .workspace import workspace as _workspace # pyright: ignore[reportUnusedImport]
api.add_namespace(inner_api_ns) api.add_namespace(inner_api_ns)

View File

@ -37,9 +37,9 @@ from models.model import EndUser
@inner_api_ns.route("/invoke/llm") @inner_api_ns.route("/invoke/llm")
class PluginInvokeLLMApi(Resource): class PluginInvokeLLMApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeLLM) @plugin_data(payload_type=RequestInvokeLLM)
@inner_api_ns.doc("plugin_invoke_llm") @inner_api_ns.doc("plugin_invoke_llm")
@inner_api_ns.doc(description="Invoke LLM models through plugin interface") @inner_api_ns.doc(description="Invoke LLM models through plugin interface")
@ -60,9 +60,9 @@ class PluginInvokeLLMApi(Resource):
@inner_api_ns.route("/invoke/llm/structured-output") @inner_api_ns.route("/invoke/llm/structured-output")
class PluginInvokeLLMWithStructuredOutputApi(Resource): class PluginInvokeLLMWithStructuredOutputApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeLLMWithStructuredOutput) @plugin_data(payload_type=RequestInvokeLLMWithStructuredOutput)
@inner_api_ns.doc("plugin_invoke_llm_structured") @inner_api_ns.doc("plugin_invoke_llm_structured")
@inner_api_ns.doc(description="Invoke LLM models with structured output through plugin interface") @inner_api_ns.doc(description="Invoke LLM models with structured output through plugin interface")
@ -85,9 +85,9 @@ class PluginInvokeLLMWithStructuredOutputApi(Resource):
@inner_api_ns.route("/invoke/text-embedding") @inner_api_ns.route("/invoke/text-embedding")
class PluginInvokeTextEmbeddingApi(Resource): class PluginInvokeTextEmbeddingApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeTextEmbedding) @plugin_data(payload_type=RequestInvokeTextEmbedding)
@inner_api_ns.doc("plugin_invoke_text_embedding") @inner_api_ns.doc("plugin_invoke_text_embedding")
@inner_api_ns.doc(description="Invoke text embedding models through plugin interface") @inner_api_ns.doc(description="Invoke text embedding models through plugin interface")
@ -115,9 +115,9 @@ class PluginInvokeTextEmbeddingApi(Resource):
@inner_api_ns.route("/invoke/rerank") @inner_api_ns.route("/invoke/rerank")
class PluginInvokeRerankApi(Resource): class PluginInvokeRerankApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeRerank) @plugin_data(payload_type=RequestInvokeRerank)
@inner_api_ns.doc("plugin_invoke_rerank") @inner_api_ns.doc("plugin_invoke_rerank")
@inner_api_ns.doc(description="Invoke rerank models through plugin interface") @inner_api_ns.doc(description="Invoke rerank models through plugin interface")
@ -141,9 +141,9 @@ class PluginInvokeRerankApi(Resource):
@inner_api_ns.route("/invoke/tts") @inner_api_ns.route("/invoke/tts")
class PluginInvokeTTSApi(Resource): class PluginInvokeTTSApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeTTS) @plugin_data(payload_type=RequestInvokeTTS)
@inner_api_ns.doc("plugin_invoke_tts") @inner_api_ns.doc("plugin_invoke_tts")
@inner_api_ns.doc(description="Invoke text-to-speech models through plugin interface") @inner_api_ns.doc(description="Invoke text-to-speech models through plugin interface")
@ -168,9 +168,9 @@ class PluginInvokeTTSApi(Resource):
@inner_api_ns.route("/invoke/speech2text") @inner_api_ns.route("/invoke/speech2text")
class PluginInvokeSpeech2TextApi(Resource): class PluginInvokeSpeech2TextApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeSpeech2Text) @plugin_data(payload_type=RequestInvokeSpeech2Text)
@inner_api_ns.doc("plugin_invoke_speech2text") @inner_api_ns.doc("plugin_invoke_speech2text")
@inner_api_ns.doc(description="Invoke speech-to-text models through plugin interface") @inner_api_ns.doc(description="Invoke speech-to-text models through plugin interface")
@ -194,9 +194,9 @@ class PluginInvokeSpeech2TextApi(Resource):
@inner_api_ns.route("/invoke/moderation") @inner_api_ns.route("/invoke/moderation")
class PluginInvokeModerationApi(Resource): class PluginInvokeModerationApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeModeration) @plugin_data(payload_type=RequestInvokeModeration)
@inner_api_ns.doc("plugin_invoke_moderation") @inner_api_ns.doc("plugin_invoke_moderation")
@inner_api_ns.doc(description="Invoke moderation models through plugin interface") @inner_api_ns.doc(description="Invoke moderation models through plugin interface")
@ -220,9 +220,9 @@ class PluginInvokeModerationApi(Resource):
@inner_api_ns.route("/invoke/tool") @inner_api_ns.route("/invoke/tool")
class PluginInvokeToolApi(Resource): class PluginInvokeToolApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeTool) @plugin_data(payload_type=RequestInvokeTool)
@inner_api_ns.doc("plugin_invoke_tool") @inner_api_ns.doc("plugin_invoke_tool")
@inner_api_ns.doc(description="Invoke tools through plugin interface") @inner_api_ns.doc(description="Invoke tools through plugin interface")
@ -252,9 +252,9 @@ class PluginInvokeToolApi(Resource):
@inner_api_ns.route("/invoke/parameter-extractor") @inner_api_ns.route("/invoke/parameter-extractor")
class PluginInvokeParameterExtractorNodeApi(Resource): class PluginInvokeParameterExtractorNodeApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeParameterExtractorNode) @plugin_data(payload_type=RequestInvokeParameterExtractorNode)
@inner_api_ns.doc("plugin_invoke_parameter_extractor") @inner_api_ns.doc("plugin_invoke_parameter_extractor")
@inner_api_ns.doc(description="Invoke parameter extractor node through plugin interface") @inner_api_ns.doc(description="Invoke parameter extractor node through plugin interface")
@ -285,9 +285,9 @@ class PluginInvokeParameterExtractorNodeApi(Resource):
@inner_api_ns.route("/invoke/question-classifier") @inner_api_ns.route("/invoke/question-classifier")
class PluginInvokeQuestionClassifierNodeApi(Resource): class PluginInvokeQuestionClassifierNodeApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeQuestionClassifierNode) @plugin_data(payload_type=RequestInvokeQuestionClassifierNode)
@inner_api_ns.doc("plugin_invoke_question_classifier") @inner_api_ns.doc("plugin_invoke_question_classifier")
@inner_api_ns.doc(description="Invoke question classifier node through plugin interface") @inner_api_ns.doc(description="Invoke question classifier node through plugin interface")
@ -318,9 +318,9 @@ class PluginInvokeQuestionClassifierNodeApi(Resource):
@inner_api_ns.route("/invoke/app") @inner_api_ns.route("/invoke/app")
class PluginInvokeAppApi(Resource): class PluginInvokeAppApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeApp) @plugin_data(payload_type=RequestInvokeApp)
@inner_api_ns.doc("plugin_invoke_app") @inner_api_ns.doc("plugin_invoke_app")
@inner_api_ns.doc(description="Invoke application through plugin interface") @inner_api_ns.doc(description="Invoke application through plugin interface")
@ -348,9 +348,9 @@ class PluginInvokeAppApi(Resource):
@inner_api_ns.route("/invoke/encrypt") @inner_api_ns.route("/invoke/encrypt")
class PluginInvokeEncryptApi(Resource): class PluginInvokeEncryptApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeEncrypt) @plugin_data(payload_type=RequestInvokeEncrypt)
@inner_api_ns.doc("plugin_invoke_encrypt") @inner_api_ns.doc("plugin_invoke_encrypt")
@inner_api_ns.doc(description="Encrypt or decrypt data through plugin interface") @inner_api_ns.doc(description="Encrypt or decrypt data through plugin interface")
@ -375,9 +375,9 @@ class PluginInvokeEncryptApi(Resource):
@inner_api_ns.route("/invoke/summary") @inner_api_ns.route("/invoke/summary")
class PluginInvokeSummaryApi(Resource): class PluginInvokeSummaryApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeSummary) @plugin_data(payload_type=RequestInvokeSummary)
@inner_api_ns.doc("plugin_invoke_summary") @inner_api_ns.doc("plugin_invoke_summary")
@inner_api_ns.doc(description="Invoke summary functionality through plugin interface") @inner_api_ns.doc(description="Invoke summary functionality through plugin interface")
@ -405,9 +405,9 @@ class PluginInvokeSummaryApi(Resource):
@inner_api_ns.route("/upload/file/request") @inner_api_ns.route("/upload/file/request")
class PluginUploadFileRequestApi(Resource): class PluginUploadFileRequestApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestRequestUploadFile) @plugin_data(payload_type=RequestRequestUploadFile)
@inner_api_ns.doc("plugin_upload_file_request") @inner_api_ns.doc("plugin_upload_file_request")
@inner_api_ns.doc(description="Request signed URL for file upload through plugin interface") @inner_api_ns.doc(description="Request signed URL for file upload through plugin interface")
@ -426,9 +426,9 @@ class PluginUploadFileRequestApi(Resource):
@inner_api_ns.route("/fetch/app/info") @inner_api_ns.route("/fetch/app/info")
class PluginFetchAppInfoApi(Resource): class PluginFetchAppInfoApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestFetchAppInfo) @plugin_data(payload_type=RequestFetchAppInfo)
@inner_api_ns.doc("plugin_fetch_app_info") @inner_api_ns.doc("plugin_fetch_app_info")
@inner_api_ns.doc(description="Fetch application information through plugin interface") @inner_api_ns.doc(description="Fetch application information through plugin interface")

View File

@ -1,6 +1,6 @@
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Optional, ParamSpec, TypeVar from typing import Optional, ParamSpec, TypeVar, cast
from flask import current_app, request from flask import current_app, request
from flask_login import user_logged_in from flask_login import user_logged_in
@ -10,7 +10,7 @@ from sqlalchemy.orm import Session
from core.file.constants import DEFAULT_SERVICE_API_USER_ID from core.file.constants import DEFAULT_SERVICE_API_USER_ID
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import _get_user from libs.login import current_user
from models.account import Tenant from models.account import Tenant
from models.model import EndUser from models.model import EndUser
@ -66,8 +66,8 @@ def get_user_tenant(view: Optional[Callable[P, R]] = None):
p = parser.parse_args() p = parser.parse_args()
user_id: Optional[str] = p.get("user_id") user_id = cast(str, p.get("user_id"))
tenant_id: str = p.get("tenant_id") tenant_id = cast(str, p.get("tenant_id"))
if not tenant_id: if not tenant_id:
raise ValueError("tenant_id is required") raise ValueError("tenant_id is required")
@ -98,7 +98,7 @@ def get_user_tenant(view: Optional[Callable[P, R]] = None):
kwargs["user_model"] = user kwargs["user_model"] = user
current_app.login_manager._update_request_context_with_user(user) # type: ignore current_app.login_manager._update_request_context_with_user(user) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
return view_func(*args, **kwargs) return view_func(*args, **kwargs)

View File

@ -15,6 +15,6 @@ api = ExternalApi(
mcp_ns = Namespace("mcp", description="MCP operations", path="/") mcp_ns = Namespace("mcp", description="MCP operations", path="/")
from . import mcp from . import mcp # pyright: ignore[reportUnusedImport]
api.add_namespace(mcp_ns) api.add_namespace(mcp_ns)

View File

@ -15,9 +15,27 @@ api = ExternalApi(
service_api_ns = Namespace("service_api", description="Service operations", path="/") service_api_ns = Namespace("service_api", description="Service operations", path="/")
from . import index from . import index # pyright: ignore[reportUnusedImport]
from .app import annotation, app, audio, completion, conversation, file, file_preview, message, site, workflow from .app import (
from .dataset import dataset, document, hit_testing, metadata, segment, upload_file annotation, # pyright: ignore[reportUnusedImport]
from .workspace import models app, # pyright: ignore[reportUnusedImport]
audio, # pyright: ignore[reportUnusedImport]
completion, # pyright: ignore[reportUnusedImport]
conversation, # pyright: ignore[reportUnusedImport]
file, # pyright: ignore[reportUnusedImport]
file_preview, # pyright: ignore[reportUnusedImport]
message, # pyright: ignore[reportUnusedImport]
site, # pyright: ignore[reportUnusedImport]
workflow, # pyright: ignore[reportUnusedImport]
)
from .dataset import (
dataset, # pyright: ignore[reportUnusedImport]
document, # pyright: ignore[reportUnusedImport]
hit_testing, # pyright: ignore[reportUnusedImport]
metadata, # pyright: ignore[reportUnusedImport]
segment, # pyright: ignore[reportUnusedImport]
upload_file, # pyright: ignore[reportUnusedImport]
)
from .workspace import models # pyright: ignore[reportUnusedImport]
api.add_namespace(service_api_ns) api.add_namespace(service_api_ns)

View File

@ -1,4 +1,5 @@
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from flask_restx._http import HTTPStatus
from flask_restx.inputs import int_range from flask_restx.inputs import int_range
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound from werkzeug.exceptions import BadRequest, NotFound
@ -121,7 +122,7 @@ class ConversationDetailApi(Resource):
} }
) )
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=204) @service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=HTTPStatus.NO_CONTENT)
def delete(self, app_model: App, end_user: EndUser, c_id): def delete(self, app_model: App, end_user: EndUser, c_id):
"""Delete a specific conversation.""" """Delete a specific conversation."""
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)

View File

@ -30,6 +30,7 @@ from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields from fields.document_fields import document_fields, document_status_fields
from libs.login import current_user from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document, DocumentSegment
from models.model import EndUser
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from services.file_service import FileService from services.file_service import FileService
@ -298,6 +299,9 @@ class DocumentAddByFileApi(DatasetApiResource):
if not file.filename: if not file.filename:
raise FilenameNotExistsError raise FilenameNotExistsError
if not isinstance(current_user, EndUser):
raise ValueError("Invalid user account")
upload_file = FileService.upload_file( upload_file = FileService.upload_file(
filename=file.filename, filename=file.filename,
content=file.read(), content=file.read(),
@ -387,6 +391,8 @@ class DocumentUpdateByFileApi(DatasetApiResource):
raise FilenameNotExistsError raise FilenameNotExistsError
try: try:
if not isinstance(current_user, EndUser):
raise ValueError("Invalid user account")
upload_file = FileService.upload_file( upload_file = FileService.upload_file(
filename=file.filename, filename=file.filename,
content=file.read(), content=file.read(),

View File

@ -17,7 +17,7 @@ from core.file.constants import DEFAULT_SERVICE_API_USER_ID
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import _get_user from libs.login import current_user
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
from models.dataset import Dataset, RateLimitLog from models.dataset import Dataset, RateLimitLog
from models.model import ApiToken, App, EndUser from models.model import ApiToken, App, EndUser
@ -210,7 +210,7 @@ def validate_dataset_token(view: Optional[Callable[Concatenate[T, P], R]] = None
if account: if account:
account.current_tenant = tenant account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account) # type: ignore current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
else: else:
raise Unauthorized("Tenant owner account does not exist.") raise Unauthorized("Tenant owner account does not exist.")
else: else:

View File

@ -17,20 +17,20 @@ api = ExternalApi(
web_ns = Namespace("web", description="Web application API operations", path="/") web_ns = Namespace("web", description="Web application API operations", path="/")
from . import ( from . import (
app, app, # pyright: ignore[reportUnusedImport]
audio, audio, # pyright: ignore[reportUnusedImport]
completion, completion, # pyright: ignore[reportUnusedImport]
conversation, conversation, # pyright: ignore[reportUnusedImport]
feature, feature, # pyright: ignore[reportUnusedImport]
files, files, # pyright: ignore[reportUnusedImport]
forgot_password, forgot_password, # pyright: ignore[reportUnusedImport]
login, login, # pyright: ignore[reportUnusedImport]
message, message, # pyright: ignore[reportUnusedImport]
passport, passport, # pyright: ignore[reportUnusedImport]
remote_files, remote_files, # pyright: ignore[reportUnusedImport]
saved_message, saved_message, # pyright: ignore[reportUnusedImport]
site, site, # pyright: ignore[reportUnusedImport]
workflow, workflow, # pyright: ignore[reportUnusedImport]
) )
api.add_namespace(web_ns) api.add_namespace(web_ns)

View File

@ -1 +0,0 @@
import core.moderation.base

View File

@ -72,6 +72,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
function_call_state = True function_call_state = True
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
final_answer = "" final_answer = ""
prompt_messages: list = [] # Initialize prompt_messages
agent_thought_id = "" # Initialize agent_thought_id
def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage): def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
if not final_llm_usage_dict["usage"]: if not final_llm_usage_dict["usage"]:

View File

@ -54,6 +54,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
function_call_state = True function_call_state = True
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
final_answer = "" final_answer = ""
prompt_messages: list = [] # Initialize prompt_messages
# get tracing instance # get tracing instance
trace_manager = app_generate_entity.trace_manager trace_manager = app_generate_entity.trace_manager

View File

@ -21,7 +21,7 @@ class SensitiveWordAvoidanceConfigManager:
@classmethod @classmethod
def validate_and_set_defaults( def validate_and_set_defaults(
cls, tenant_id, config: dict, only_structure_validate: bool = False cls, tenant_id: str, config: dict, only_structure_validate: bool = False
) -> tuple[dict, list[str]]: ) -> tuple[dict, list[str]]:
if not config.get("sensitive_word_avoidance"): if not config.get("sensitive_word_avoidance"):
config["sensitive_word_avoidance"] = {"enabled": False} config["sensitive_word_avoidance"] = {"enabled": False}
@ -38,7 +38,14 @@ class SensitiveWordAvoidanceConfigManager:
if not only_structure_validate: if not only_structure_validate:
typ = config["sensitive_word_avoidance"]["type"] typ = config["sensitive_word_avoidance"]["type"]
sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"] if not isinstance(typ, str):
raise ValueError("sensitive_word_avoidance.type must be a string")
sensitive_word_avoidance_config = config["sensitive_word_avoidance"].get("config")
if sensitive_word_avoidance_config is None:
sensitive_word_avoidance_config = {}
if not isinstance(sensitive_word_avoidance_config, dict):
raise ValueError("sensitive_word_avoidance.config must be a dict")
ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config) ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config)

View File

@ -25,10 +25,14 @@ class PromptTemplateConfigManager:
if chat_prompt_config: if chat_prompt_config:
chat_prompt_messages = [] chat_prompt_messages = []
for message in chat_prompt_config.get("prompt", []): for message in chat_prompt_config.get("prompt", []):
text = message.get("text")
if not isinstance(text, str):
raise ValueError("message text must be a string")
role = message.get("role")
if not isinstance(role, str):
raise ValueError("message role must be a string")
chat_prompt_messages.append( chat_prompt_messages.append(
AdvancedChatMessageEntity( AdvancedChatMessageEntity(text=text, role=PromptMessageRole.value_of(role))
**{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
)
) )
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages) advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)

View File

@ -71,7 +71,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping" yield "ping"
continue continue
response_chunk = { response_chunk: dict[str, Any] = {
"event": sub_stream_response.event.value, "event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id, "conversation_id": chunk.conversation_id,
"message_id": chunk.message_id, "message_id": chunk.message_id,
@ -82,7 +82,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err) data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data) response_chunk.update(data)
else: else:
response_chunk.update(sub_stream_response.to_dict()) response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk yield response_chunk
@classmethod @classmethod
@ -102,7 +102,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping" yield "ping"
continue continue
response_chunk = { response_chunk: dict[str, Any] = {
"event": sub_stream_response.event.value, "event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id, "conversation_id": chunk.conversation_id,
"message_id": chunk.message_id, "message_id": chunk.message_id,
@ -110,7 +110,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
} }
if isinstance(sub_stream_response, MessageEndStreamResponse): if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict() sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {}) metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict) response_chunk.update(sub_stream_response_dict)
@ -118,8 +118,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err) data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data) response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute] response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else: else:
response_chunk.update(sub_stream_response.to_dict()) response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk yield response_chunk

View File

@ -174,7 +174,7 @@ class AdvancedChatAppGenerateTaskPipeline:
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._base_task_pipeline._stream: if self._base_task_pipeline.stream:
return self._to_stream_response(generator) return self._to_stream_response(generator)
else: else:
return self._to_blocking_response(generator) return self._to_blocking_response(generator)
@ -302,13 +302,13 @@ class AdvancedChatAppGenerateTaskPipeline:
def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]: def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
"""Handle ping events.""" """Handle ping events."""
yield self._base_task_pipeline._ping_stream_response() yield self._base_task_pipeline.ping_stream_response()
def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]: def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
"""Handle error events.""" """Handle error events."""
with self._database_session() as session: with self._database_session() as session:
err = self._base_task_pipeline._handle_error(event=event, session=session, message_id=self._message_id) err = self._base_task_pipeline.handle_error(event=event, session=session, message_id=self._message_id)
yield self._base_task_pipeline._error_to_stream_response(err) yield self._base_task_pipeline.error_to_stream_response(err)
def _handle_workflow_started_event(self, *args, **kwargs) -> Generator[StreamResponse, None, None]: def _handle_workflow_started_event(self, *args, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle workflow started events.""" """Handle workflow started events."""
@ -627,10 +627,10 @@ class AdvancedChatAppGenerateTaskPipeline:
workflow_execution=workflow_execution, workflow_execution=workflow_execution,
) )
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}")) err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}"))
err = self._base_task_pipeline._handle_error(event=err_event, session=session, message_id=self._message_id) err = self._base_task_pipeline.handle_error(event=err_event, session=session, message_id=self._message_id)
yield workflow_finish_resp yield workflow_finish_resp
yield self._base_task_pipeline._error_to_stream_response(err) yield self._base_task_pipeline.error_to_stream_response(err)
def _handle_stop_event( def _handle_stop_event(
self, self,
@ -683,7 +683,7 @@ class AdvancedChatAppGenerateTaskPipeline:
"""Handle advanced chat message end events.""" """Handle advanced chat message end events."""
self._ensure_graph_runtime_initialized(graph_runtime_state) self._ensure_graph_runtime_initialized(graph_runtime_state)
output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished( output_moderation_answer = self._base_task_pipeline.handle_output_moderation_when_task_finished(
self._task_state.answer self._task_state.answer
) )
if output_moderation_answer: if output_moderation_answer:
@ -899,7 +899,7 @@ class AdvancedChatAppGenerateTaskPipeline:
message.answer = answer_text message.answer = answer_text
message.updated_at = naive_utc_now() message.updated_at = naive_utc_now()
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at message.provider_response_latency = time.perf_counter() - self._base_task_pipeline.start_at
message.message_metadata = self._task_state.metadata.model_dump_json() message.message_metadata = self._task_state.metadata.model_dump_json()
message_files = [ message_files = [
MessageFile( MessageFile(
@ -955,9 +955,9 @@ class AdvancedChatAppGenerateTaskPipeline:
:param text: text :param text: text
:return: True if output moderation should direct output, otherwise False :return: True if output moderation should direct output, otherwise False
""" """
if self._base_task_pipeline._output_moderation_handler: if self._base_task_pipeline.output_moderation_handler:
if self._base_task_pipeline._output_moderation_handler.should_direct_output(): if self._base_task_pipeline.output_moderation_handler.should_direct_output():
self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output() self._task_state.answer = self._base_task_pipeline.output_moderation_handler.get_final_output()
self._base_task_pipeline.queue_manager.publish( self._base_task_pipeline.queue_manager.publish(
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
) )
@ -967,7 +967,7 @@ class AdvancedChatAppGenerateTaskPipeline:
) )
return True return True
else: else:
self._base_task_pipeline._output_moderation_handler.append_new_token(text) self._base_task_pipeline.output_moderation_handler.append_new_token(text)
return False return False

View File

@ -1,6 +1,6 @@
import uuid import uuid
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, Optional from typing import Any, Optional, cast
from core.agent.entities import AgentEntity from core.agent.entities import AgentEntity
from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.base_app_config_manager import BaseAppConfigManager
@ -160,7 +160,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
return filtered_config return filtered_config
@classmethod @classmethod
def validate_agent_mode_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: def validate_agent_mode_and_set_defaults(
cls, tenant_id: str, config: dict[str, Any]
) -> tuple[dict[str, Any], list[str]]:
""" """
Validate agent_mode and set defaults for agent feature Validate agent_mode and set defaults for agent feature
@ -170,30 +172,32 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
if not config.get("agent_mode"): if not config.get("agent_mode"):
config["agent_mode"] = {"enabled": False, "tools": []} config["agent_mode"] = {"enabled": False, "tools": []}
if not isinstance(config["agent_mode"], dict): agent_mode = config["agent_mode"]
if not isinstance(agent_mode, dict):
raise ValueError("agent_mode must be of object type") raise ValueError("agent_mode must be of object type")
if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]: # FIXME(-LAN-): Cast needed due to basedpyright limitation with dict type narrowing
config["agent_mode"]["enabled"] = False agent_mode = cast(dict[str, Any], agent_mode)
if not isinstance(config["agent_mode"]["enabled"], bool): if "enabled" not in agent_mode or not agent_mode["enabled"]:
agent_mode["enabled"] = False
if not isinstance(agent_mode["enabled"], bool):
raise ValueError("enabled in agent_mode must be of boolean type") raise ValueError("enabled in agent_mode must be of boolean type")
if not config["agent_mode"].get("strategy"): if not agent_mode.get("strategy"):
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value agent_mode["strategy"] = PlanningStrategy.ROUTER.value
if config["agent_mode"]["strategy"] not in [ if agent_mode["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]:
member.value for member in list(PlanningStrategy.__members__.values())
]:
raise ValueError("strategy in agent_mode must be in the specified strategy list") raise ValueError("strategy in agent_mode must be in the specified strategy list")
if not config["agent_mode"].get("tools"): if not agent_mode.get("tools"):
config["agent_mode"]["tools"] = [] agent_mode["tools"] = []
if not isinstance(config["agent_mode"]["tools"], list): if not isinstance(agent_mode["tools"], list):
raise ValueError("tools in agent_mode must be a list of objects") raise ValueError("tools in agent_mode must be a list of objects")
for tool in config["agent_mode"]["tools"]: for tool in agent_mode["tools"]:
key = list(tool.keys())[0] key = list(tool.keys())[0]
if key in OLD_TOOLS: if key in OLD_TOOLS:
# old style, use tool name as key # old style, use tool name as key

View File

@ -46,7 +46,10 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
response = cls.convert_blocking_full_response(blocking_response) response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get("metadata", {}) metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata) if isinstance(metadata, dict):
response["metadata"] = cls._get_simple_metadata(metadata)
else:
response["metadata"] = {}
return response return response
@ -78,7 +81,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err) data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data) response_chunk.update(data)
else: else:
response_chunk.update(sub_stream_response.to_dict()) response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk yield response_chunk
@classmethod @classmethod
@ -106,7 +109,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
} }
if isinstance(sub_stream_response, MessageEndStreamResponse): if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict() sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {}) metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict) response_chunk.update(sub_stream_response_dict)
@ -114,6 +117,6 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err) data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data) response_chunk.update(data)
else: else:
response_chunk.update(sub_stream_response.to_dict()) response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk yield response_chunk

View File

@ -32,6 +32,7 @@ class AppQueueManager:
self._task_id = task_id self._task_id = task_id
self._user_id = user_id self._user_id = user_id
self._invoke_from = invoke_from self._invoke_from = invoke_from
self.invoke_from = invoke_from # Public accessor for invoke_from
user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user" user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
redis_client.setex( redis_client.setex(

View File

@ -46,7 +46,10 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
response = cls.convert_blocking_full_response(blocking_response) response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get("metadata", {}) metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata) if isinstance(metadata, dict):
response["metadata"] = cls._get_simple_metadata(metadata)
else:
response["metadata"] = {}
return response return response
@ -78,7 +81,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err) data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data) response_chunk.update(data)
else: else:
response_chunk.update(sub_stream_response.to_dict()) response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk yield response_chunk
@classmethod @classmethod
@ -106,7 +109,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
} }
if isinstance(sub_stream_response, MessageEndStreamResponse): if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict() sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {}) metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict) response_chunk.update(sub_stream_response_dict)
@ -114,6 +117,6 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err) data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data) response_chunk.update(data)
else: else:
response_chunk.update(sub_stream_response.to_dict()) response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk yield response_chunk

View File

@ -271,6 +271,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
raise MoreLikeThisDisabledError() raise MoreLikeThisDisabledError()
app_model_config = message.app_model_config app_model_config = message.app_model_config
if not app_model_config:
raise ValueError("Message app_model_config is None")
override_model_config_dict = app_model_config.to_dict() override_model_config_dict = app_model_config.to_dict()
model_dict = override_model_config_dict["model"] model_dict = override_model_config_dict["model"]
completion_params = model_dict.get("completion_params") completion_params = model_dict.get("completion_params")

View File

@ -45,7 +45,10 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
response = cls.convert_blocking_full_response(blocking_response) response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get("metadata", {}) metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata) if isinstance(metadata, dict):
response["metadata"] = cls._get_simple_metadata(metadata)
else:
response["metadata"] = {}
return response return response
@ -76,7 +79,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err) data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data) response_chunk.update(data)
else: else:
response_chunk.update(sub_stream_response.to_dict()) response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk yield response_chunk
@classmethod @classmethod
@ -103,14 +106,16 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
} }
if isinstance(sub_stream_response, MessageEndStreamResponse): if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict() sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {}) metadata = sub_stream_response_dict.get("metadata", {})
if not isinstance(metadata, dict):
metadata = {}
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict) response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse): if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err) data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data) response_chunk.update(data)
else: else:
response_chunk.update(sub_stream_response.to_dict()) response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk yield response_chunk

View File

@ -23,7 +23,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
:param blocking_response: blocking response :param blocking_response: blocking response
:return: :return:
""" """
return dict(blocking_response.to_dict()) return blocking_response.model_dump()
@classmethod @classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override] def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
@ -51,7 +51,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping" yield "ping"
continue continue
response_chunk = { response_chunk: dict[str, object] = {
"event": sub_stream_response.event.value, "event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id, "workflow_run_id": chunk.workflow_run_id,
} }
@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err) data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data) response_chunk.update(data)
else: else:
response_chunk.update(sub_stream_response.to_dict()) response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk yield response_chunk
@classmethod @classmethod
@ -80,7 +80,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping" yield "ping"
continue continue
response_chunk = { response_chunk: dict[str, object] = {
"event": sub_stream_response.event.value, "event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id, "workflow_run_id": chunk.workflow_run_id,
} }
@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute] response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
else: else:
response_chunk.update(sub_stream_response.to_dict()) response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk yield response_chunk

View File

@ -137,7 +137,7 @@ class WorkflowAppGenerateTaskPipeline:
self._application_generate_entity = application_generate_entity self._application_generate_entity = application_generate_entity
self._workflow_features_dict = workflow.features_dict self._workflow_features_dict = workflow.features_dict
self._workflow_run_id = "" self._workflow_run_id = ""
self._invoke_from = queue_manager._invoke_from self._invoke_from = queue_manager.invoke_from
self._draft_var_saver_factory = draft_var_saver_factory self._draft_var_saver_factory = draft_var_saver_factory
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@ -146,7 +146,7 @@ class WorkflowAppGenerateTaskPipeline:
:return: :return:
""" """
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._base_task_pipeline._stream: if self._base_task_pipeline.stream:
return self._to_stream_response(generator) return self._to_stream_response(generator)
else: else:
return self._to_blocking_response(generator) return self._to_blocking_response(generator)
@ -276,12 +276,12 @@ class WorkflowAppGenerateTaskPipeline:
def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]: def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
"""Handle ping events.""" """Handle ping events."""
yield self._base_task_pipeline._ping_stream_response() yield self._base_task_pipeline.ping_stream_response()
def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]: def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
"""Handle error events.""" """Handle error events."""
err = self._base_task_pipeline._handle_error(event=event) err = self._base_task_pipeline.handle_error(event=event)
yield self._base_task_pipeline._error_to_stream_response(err) yield self._base_task_pipeline.error_to_stream_response(err)
def _handle_workflow_started_event( def _handle_workflow_started_event(
self, event: QueueWorkflowStartedEvent, **kwargs self, event: QueueWorkflowStartedEvent, **kwargs

View File

@ -123,7 +123,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
""" """
# app config # app config
app_config: EasyUIBasedAppConfig app_config: EasyUIBasedAppConfig = None # type: ignore
model_conf: ModelConfigWithCredentialsEntity model_conf: ModelConfigWithCredentialsEntity
query: Optional[str] = None query: Optional[str] = None
@ -186,7 +186,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
""" """
# app config # app config
app_config: WorkflowUIBasedAppConfig app_config: WorkflowUIBasedAppConfig = None # type: ignore
workflow_run_id: Optional[str] = None workflow_run_id: Optional[str] = None
query: str query: str
@ -218,7 +218,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
""" """
# app config # app config
app_config: WorkflowUIBasedAppConfig app_config: WorkflowUIBasedAppConfig = None # type: ignore
workflow_execution_id: str workflow_execution_id: str
class SingleIterationRunEntity(BaseModel): class SingleIterationRunEntity(BaseModel):

View File

@ -5,7 +5,6 @@ from typing import Any, Optional
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@ -92,9 +91,6 @@ class StreamResponse(BaseModel):
event: StreamEvent event: StreamEvent
task_id: str task_id: str
def to_dict(self):
return jsonable_encoder(self)
class ErrorStreamResponse(StreamResponse): class ErrorStreamResponse(StreamResponse):
""" """
@ -745,9 +741,6 @@ class AppBlockingResponse(BaseModel):
task_id: str task_id: str
def to_dict(self):
return jsonable_encoder(self)
class ChatbotAppBlockingResponse(AppBlockingResponse): class ChatbotAppBlockingResponse(AppBlockingResponse):
""" """

View File

@ -35,6 +35,9 @@ class AnnotationReplyFeature:
collection_binding_detail = annotation_setting.collection_binding_detail collection_binding_detail = annotation_setting.collection_binding_detail
if not collection_binding_detail:
return None
try: try:
score_threshold = annotation_setting.score_threshold or 1 score_threshold = annotation_setting.score_threshold or 1
embedding_provider_name = collection_binding_detail.provider_name embedding_provider_name = collection_binding_detail.provider_name

View File

@ -1 +1,3 @@
from .rate_limit import RateLimit from .rate_limit import RateLimit
__all__ = ["RateLimit"]

View File

@ -19,7 +19,7 @@ class RateLimit:
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
_instance_dict: dict[str, "RateLimit"] = {} _instance_dict: dict[str, "RateLimit"] = {}
def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int): def __new__(cls, client_id: str, max_active_requests: int):
if client_id not in cls._instance_dict: if client_id not in cls._instance_dict:
instance = super().__new__(cls) instance = super().__new__(cls)
cls._instance_dict[client_id] = instance cls._instance_dict[client_id] = instance

View File

@ -38,11 +38,11 @@ class BasedGenerateTaskPipeline:
): ):
self._application_generate_entity = application_generate_entity self._application_generate_entity = application_generate_entity
self.queue_manager = queue_manager self.queue_manager = queue_manager
self._start_at = time.perf_counter() self.start_at = time.perf_counter()
self._output_moderation_handler = self._init_output_moderation() self.output_moderation_handler = self._init_output_moderation()
self._stream = stream self.stream = stream
def _handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""): def handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""):
logger.debug("error: %s", event.error) logger.debug("error: %s", event.error)
e = event.error e = event.error
err: Exception err: Exception
@ -86,7 +86,7 @@ class BasedGenerateTaskPipeline:
return message return message
def _error_to_stream_response(self, e: Exception): def error_to_stream_response(self, e: Exception):
""" """
Error to stream response. Error to stream response.
:param e: exception :param e: exception
@ -94,7 +94,7 @@ class BasedGenerateTaskPipeline:
""" """
return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e) return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e)
def _ping_stream_response(self) -> PingStreamResponse: def ping_stream_response(self) -> PingStreamResponse:
""" """
Ping stream response. Ping stream response.
:return: :return:
@ -118,21 +118,21 @@ class BasedGenerateTaskPipeline:
) )
return None return None
def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]: def handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:
""" """
Handle output moderation when task finished. Handle output moderation when task finished.
:param completion: completion :param completion: completion
:return: :return:
""" """
# response moderation # response moderation
if self._output_moderation_handler: if self.output_moderation_handler:
self._output_moderation_handler.stop_thread() self.output_moderation_handler.stop_thread()
completion, flagged = self._output_moderation_handler.moderation_completion( completion, flagged = self.output_moderation_handler.moderation_completion(
completion=completion, public_event=False completion=completion, public_event=False
) )
self._output_moderation_handler = None self.output_moderation_handler = None
if flagged: if flagged:
return completion return completion

View File

@ -125,7 +125,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
) )
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream: if self.stream:
return self._to_stream_response(generator) return self._to_stream_response(generator)
else: else:
return self._to_blocking_response(generator) return self._to_blocking_response(generator)
@ -265,9 +265,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if isinstance(event, QueueErrorEvent): if isinstance(event, QueueErrorEvent):
with Session(db.engine) as session: with Session(db.engine) as session:
err = self._handle_error(event=event, session=session, message_id=self._message_id) err = self.handle_error(event=event, session=session, message_id=self._message_id)
session.commit() session.commit()
yield self._error_to_stream_response(err) yield self.error_to_stream_response(err)
break break
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
if isinstance(event, QueueMessageEndEvent): if isinstance(event, QueueMessageEndEvent):
@ -277,7 +277,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
self._handle_stop(event) self._handle_stop(event)
# handle output moderation # handle output moderation
output_moderation_answer = self._handle_output_moderation_when_task_finished( output_moderation_answer = self.handle_output_moderation_when_task_finished(
cast(str, self._task_state.llm_result.message.content) cast(str, self._task_state.llm_result.message.content)
) )
if output_moderation_answer: if output_moderation_answer:
@ -354,7 +354,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
elif isinstance(event, QueueMessageReplaceEvent): elif isinstance(event, QueueMessageReplaceEvent):
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text) yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent): elif isinstance(event, QueuePingEvent):
yield self._ping_stream_response() yield self.ping_stream_response()
else: else:
continue continue
if publisher: if publisher:
@ -394,7 +394,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
message.answer_tokens = usage.completion_tokens message.answer_tokens = usage.completion_tokens
message.answer_unit_price = usage.completion_unit_price message.answer_unit_price = usage.completion_unit_price
message.answer_price_unit = usage.completion_price_unit message.answer_price_unit = usage.completion_price_unit
message.provider_response_latency = time.perf_counter() - self._start_at message.provider_response_latency = time.perf_counter() - self.start_at
message.total_price = usage.total_price message.total_price = usage.total_price
message.currency = usage.currency message.currency = usage.currency
self._task_state.llm_result.usage.latency = message.provider_response_latency self._task_state.llm_result.usage.latency = message.provider_response_latency
@ -438,7 +438,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
# transform usage # transform usage
model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance) model_type_instance = cast(LargeLanguageModel, model_type_instance)
self._task_state.llm_result.usage = model_type_instance._calc_response_usage( self._task_state.llm_result.usage = model_type_instance.calc_response_usage(
model, credentials, prompt_tokens, completion_tokens model, credentials, prompt_tokens, completion_tokens
) )
@ -498,10 +498,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
:param text: text :param text: text
:return: True if output moderation should direct output, otherwise False :return: True if output moderation should direct output, otherwise False
""" """
if self._output_moderation_handler: if self.output_moderation_handler:
if self._output_moderation_handler.should_direct_output(): if self.output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output # stop subscribe new token when output moderation should direct output
self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output() self._task_state.llm_result.message.content = self.output_moderation_handler.get_final_output()
self.queue_manager.publish( self.queue_manager.publish(
QueueLLMChunkEvent( QueueLLMChunkEvent(
chunk=LLMResultChunk( chunk=LLMResultChunk(
@ -521,6 +521,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
) )
return True return True
else: else:
self._output_moderation_handler.append_new_token(text) self.output_moderation_handler.append_new_token(text)
return False return False

View File

@ -72,7 +72,7 @@ class AppGeneratorTTSPublisher:
self.voice = voice self.voice = voice
if not voice or voice not in values: if not voice or voice not in values:
self.voice = self.voices[0].get("value") self.voice = self.voices[0].get("value")
self.MAX_SENTENCE = 2 self.max_sentence = 2
self._last_audio_event: Optional[AudioTrunk] = None self._last_audio_event: Optional[AudioTrunk] = None
# FIXME better way to handle this threading.start # FIXME better way to handle this threading.start
threading.Thread(target=self._runtime).start() threading.Thread(target=self._runtime).start()
@ -113,8 +113,8 @@ class AppGeneratorTTSPublisher:
self.msg_text += message.event.outputs.get("output", "") self.msg_text += message.event.outputs.get("output", "")
self.last_message = message self.last_message = message
sentence_arr, text_tmp = self._extract_sentence(self.msg_text) sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
if len(sentence_arr) >= min(self.MAX_SENTENCE, 7): if len(sentence_arr) >= min(self.max_sentence, 7):
self.MAX_SENTENCE += 1 self.max_sentence += 1
text_content = "".join(sentence_arr) text_content = "".join(sentence_arr)
futures_result = self.executor.submit( futures_result = self.executor.submit(
_invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice _invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice

View File

@ -1840,8 +1840,14 @@ class ProviderConfigurations(BaseModel):
def __setitem__(self, key, value): def __setitem__(self, key, value):
self.configurations[key] = value self.configurations[key] = value
def __contains__(self, key):
if "/" not in key:
key = str(ModelProviderID(key))
return key in self.configurations
def __iter__(self): def __iter__(self):
return iter(self.configurations) # Return an iterator of (key, value) tuples to match BaseModel's __iter__
yield from self.configurations.items()
def values(self) -> Iterator[ProviderConfiguration]: def values(self) -> Iterator[ProviderConfiguration]:
return iter(self.configurations.values()) return iter(self.configurations.values())

View File

@ -98,7 +98,7 @@ def to_prompt_message_content(
def download(f: File, /): def download(f: File, /):
if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE): if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE):
return _download_file_content(f._storage_key) return _download_file_content(f.storage_key)
elif f.transfer_method == FileTransferMethod.REMOTE_URL: elif f.transfer_method == FileTransferMethod.REMOTE_URL:
response = ssrf_proxy.get(f.remote_url, follow_redirects=True) response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
response.raise_for_status() response.raise_for_status()
@ -134,9 +134,9 @@ def _get_encoded_string(f: File, /):
response.raise_for_status() response.raise_for_status()
data = response.content data = response.content
case FileTransferMethod.LOCAL_FILE: case FileTransferMethod.LOCAL_FILE:
data = _download_file_content(f._storage_key) data = _download_file_content(f.storage_key)
case FileTransferMethod.TOOL_FILE: case FileTransferMethod.TOOL_FILE:
data = _download_file_content(f._storage_key) data = _download_file_content(f.storage_key)
encoded_string = base64.b64encode(data).decode("utf-8") encoded_string = base64.b64encode(data).decode("utf-8")
return encoded_string return encoded_string

View File

@ -146,3 +146,11 @@ class File(BaseModel):
if not self.related_id: if not self.related_id:
raise ValueError("Missing file related_id") raise ValueError("Missing file related_id")
return self return self
@property
def storage_key(self) -> str:
return self._storage_key
@storage_key.setter
def storage_key(self, value: str):
self._storage_key = value

View File

@ -13,18 +13,18 @@ logger = logging.getLogger(__name__)
SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
HTTP_REQUEST_NODE_SSL_VERIFY = True # Default value for HTTP_REQUEST_NODE_SSL_VERIFY is True http_request_node_ssl_verify = True # Default value for http_request_node_ssl_verify is True
try: try:
HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY config_value = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
http_request_node_ssl_verify_lower = str(HTTP_REQUEST_NODE_SSL_VERIFY).lower() http_request_node_ssl_verify_lower = str(config_value).lower()
if http_request_node_ssl_verify_lower == "true": if http_request_node_ssl_verify_lower == "true":
HTTP_REQUEST_NODE_SSL_VERIFY = True http_request_node_ssl_verify = True
elif http_request_node_ssl_verify_lower == "false": elif http_request_node_ssl_verify_lower == "false":
HTTP_REQUEST_NODE_SSL_VERIFY = False http_request_node_ssl_verify = False
else: else:
raise ValueError("Invalid value. HTTP_REQUEST_NODE_SSL_VERIFY should be 'True' or 'False'") raise ValueError("Invalid value. HTTP_REQUEST_NODE_SSL_VERIFY should be 'True' or 'False'")
except NameError: except NameError:
HTTP_REQUEST_NODE_SSL_VERIFY = True http_request_node_ssl_verify = True
BACKOFF_FACTOR = 0.5 BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504] STATUS_FORCELIST = [429, 500, 502, 503, 504]
@ -51,7 +51,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
) )
if "ssl_verify" not in kwargs: if "ssl_verify" not in kwargs:
kwargs["ssl_verify"] = HTTP_REQUEST_NODE_SSL_VERIFY kwargs["ssl_verify"] = http_request_node_ssl_verify
ssl_verify = kwargs.pop("ssl_verify") ssl_verify = kwargs.pop("ssl_verify")

View File

@ -529,6 +529,7 @@ class IndexingRunner:
# chunk nodes by chunk size # chunk nodes by chunk size
indexing_start_at = time.perf_counter() indexing_start_at = time.perf_counter()
tokens = 0 tokens = 0
create_keyword_thread = None
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy": if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
# create keyword index # create keyword index
create_keyword_thread = threading.Thread( create_keyword_thread = threading.Thread(
@ -567,7 +568,11 @@ class IndexingRunner:
for future in futures: for future in futures:
tokens += future.result() tokens += future.result()
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy": if (
dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX
and dataset.indexing_technique == "economy"
and create_keyword_thread is not None
):
create_keyword_thread.join() create_keyword_thread.join()
indexing_end_at = time.perf_counter() indexing_end_at = time.perf_counter()

View File

@ -20,7 +20,7 @@ from core.llm_generator.prompts import (
) )
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.entities.trace_entity import TraceTaskName from core.ops.entities.trace_entity import TraceTaskName
@ -313,14 +313,20 @@ class LLMGenerator:
model_type=ModelType.LLM, model_type=ModelType.LLM,
) )
prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)] prompt_messages: list[PromptMessage] = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)]
response: LLMResult = model_instance.invoke_llm( # Explicitly use the non-streaming overload
result = model_instance.invoke_llm(
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
model_parameters={"temperature": 0.01, "max_tokens": 2000}, model_parameters={"temperature": 0.01, "max_tokens": 2000},
stream=False, stream=False,
) )
# Runtime type check since pyright has issues with the overload
if not isinstance(result, LLMResult):
raise TypeError("Expected LLMResult when stream=False")
response = result
answer = cast(str, response.message.content) answer = cast(str, response.message.content)
return answer.strip() return answer.strip()

View File

@ -45,6 +45,7 @@ class SpecialModelType(StrEnum):
@overload @overload
def invoke_llm_with_structured_output( def invoke_llm_with_structured_output(
*,
provider: str, provider: str,
model_schema: AIModelEntity, model_schema: AIModelEntity,
model_instance: ModelInstance, model_instance: ModelInstance,
@ -53,14 +54,13 @@ def invoke_llm_with_structured_output(
model_parameters: Optional[Mapping] = None, model_parameters: Optional[Mapping] = None,
tools: Sequence[PromptMessageTool] | None = None, tools: Sequence[PromptMessageTool] | None = None,
stop: Optional[list[str]] = None, stop: Optional[list[str]] = None,
stream: Literal[True] = True, stream: Literal[True],
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
@overload @overload
def invoke_llm_with_structured_output( def invoke_llm_with_structured_output(
*,
provider: str, provider: str,
model_schema: AIModelEntity, model_schema: AIModelEntity,
model_instance: ModelInstance, model_instance: ModelInstance,
@ -69,14 +69,13 @@ def invoke_llm_with_structured_output(
model_parameters: Optional[Mapping] = None, model_parameters: Optional[Mapping] = None,
tools: Sequence[PromptMessageTool] | None = None, tools: Sequence[PromptMessageTool] | None = None,
stop: Optional[list[str]] = None, stop: Optional[list[str]] = None,
stream: Literal[False] = False, stream: Literal[False],
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
) -> LLMResultWithStructuredOutput: ... ) -> LLMResultWithStructuredOutput: ...
@overload @overload
def invoke_llm_with_structured_output( def invoke_llm_with_structured_output(
*,
provider: str, provider: str,
model_schema: AIModelEntity, model_schema: AIModelEntity,
model_instance: ModelInstance, model_instance: ModelInstance,
@ -89,9 +88,8 @@ def invoke_llm_with_structured_output(
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ... ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
def invoke_llm_with_structured_output( def invoke_llm_with_structured_output(
*,
provider: str, provider: str,
model_schema: AIModelEntity, model_schema: AIModelEntity,
model_instance: ModelInstance, model_instance: ModelInstance,

View File

@ -23,13 +23,13 @@ DEFAULT_QUEUE_READ_TIMEOUT = 3
@final @final
class _StatusReady: class _StatusReady:
def __init__(self, endpoint_url: str): def __init__(self, endpoint_url: str):
self._endpoint_url = endpoint_url self.endpoint_url = endpoint_url
@final @final
class _StatusError: class _StatusError:
def __init__(self, exc: Exception): def __init__(self, exc: Exception):
self._exc = exc self.exc = exc
# Type aliases for better readability # Type aliases for better readability
@ -211,9 +211,9 @@ class SSETransport:
raise ValueError("failed to get endpoint URL") raise ValueError("failed to get endpoint URL")
if isinstance(status, _StatusReady): if isinstance(status, _StatusReady):
return status._endpoint_url return status.endpoint_url
elif isinstance(status, _StatusError): elif isinstance(status, _StatusError):
raise status._exc raise status.exc
else: else:
raise ValueError("failed to get endpoint URL") raise ValueError("failed to get endpoint URL")

View File

@ -38,6 +38,7 @@ def handle_mcp_request(
""" """
request_type = type(request.root) request_type = type(request.root)
request_root = request.root
def create_success_response(result_data: mcp_types.Result) -> mcp_types.JSONRPCResponse: def create_success_response(result_data: mcp_types.Result) -> mcp_types.JSONRPCResponse:
"""Create success response with business result data""" """Create success response with business result data"""
@ -58,21 +59,20 @@ def handle_mcp_request(
error=error_data, error=error_data,
) )
# Request handler mapping using functional approach
request_handlers = {
mcp_types.InitializeRequest: lambda: handle_initialize(mcp_server.description),
mcp_types.ListToolsRequest: lambda: handle_list_tools(
app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict
),
mcp_types.CallToolRequest: lambda: handle_call_tool(app, request, user_input_form, end_user),
mcp_types.PingRequest: lambda: handle_ping(),
}
try: try:
# Dispatch request to appropriate handler # Dispatch request to appropriate handler based on instance type
handler = request_handlers.get(request_type) if isinstance(request_root, mcp_types.InitializeRequest):
if handler: return create_success_response(handle_initialize(mcp_server.description))
return create_success_response(handler()) elif isinstance(request_root, mcp_types.ListToolsRequest):
return create_success_response(
handle_list_tools(
app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict
)
)
elif isinstance(request_root, mcp_types.CallToolRequest):
return create_success_response(handle_call_tool(app, request, user_input_form, end_user))
elif isinstance(request_root, mcp_types.PingRequest):
return create_success_response(handle_ping())
else: else:
return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}") return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}")

View File

@ -81,7 +81,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
self.request_meta = request_meta self.request_meta = request_meta
self.request = request self.request = request
self._session = session self._session = session
self._completed = False self.completed = False
self._on_complete = on_complete self._on_complete = on_complete
self._entered = False # Track if we're in a context manager self._entered = False # Track if we're in a context manager
@ -98,7 +98,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
): ):
"""Exit the context manager, performing cleanup and notifying completion.""" """Exit the context manager, performing cleanup and notifying completion."""
try: try:
if self._completed: if self.completed:
self._on_complete(self) self._on_complete(self)
finally: finally:
self._entered = False self._entered = False
@ -113,9 +113,9 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
""" """
if not self._entered: if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager") raise RuntimeError("RequestResponder must be used as a context manager")
assert not self._completed, "Request already responded to" assert not self.completed, "Request already responded to"
self._completed = True self.completed = True
self._session._send_response(request_id=self.request_id, response=response) self._session._send_response(request_id=self.request_id, response=response)
@ -124,7 +124,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
if not self._entered: if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager") raise RuntimeError("RequestResponder must be used as a context manager")
self._completed = True # Mark as completed so it's removed from in_flight self.completed = True # Mark as completed so it's removed from in_flight
# Send an error response to indicate cancellation # Send an error response to indicate cancellation
self._session._send_response( self._session._send_response(
request_id=self.request_id, request_id=self.request_id,
@ -351,7 +351,7 @@ class BaseSession(
self._in_flight[responder.request_id] = responder self._in_flight[responder.request_id] = responder
self._received_request(responder) self._received_request(responder)
if not responder._completed: if not responder.completed:
self._handle_incoming(responder) self._handle_incoming(responder)
elif isinstance(message.message.root, JSONRPCNotification): elif isinstance(message.message.root, JSONRPCNotification):

View File

@ -354,7 +354,7 @@ class LargeLanguageModel(AIModel):
) )
return 0 return 0
def _calc_response_usage( def calc_response_usage(
self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int
) -> LLMUsage: ) -> LLMUsage:
""" """

View File

@ -1,4 +1,5 @@
import enum import enum
import json
from typing import Any, Optional, Union from typing import Any, Optional, Union
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
@ -162,8 +163,6 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
# Try to parse JSON string for arrays # Try to parse JSON string for arrays
if isinstance(value, str): if isinstance(value, str):
try: try:
import json
parsed_value = json.loads(value) parsed_value = json.loads(value)
if isinstance(parsed_value, list): if isinstance(parsed_value, list):
return parsed_value return parsed_value
@ -176,8 +175,6 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
# Try to parse JSON string for objects # Try to parse JSON string for objects
if isinstance(value, str): if isinstance(value, str):
try: try:
import json
parsed_value = json.loads(value) parsed_value = json.loads(value)
if isinstance(parsed_value, dict): if isinstance(parsed_value, dict):
return parsed_value return parsed_value

View File

@ -82,7 +82,9 @@ def merge_blob_chunks(
message_class = type(resp) message_class = type(resp)
merged_message = message_class( merged_message = message_class(
type=ToolInvokeMessage.MessageType.BLOB, type=ToolInvokeMessage.MessageType.BLOB,
message=ToolInvokeMessage.BlobMessage(blob=files[chunk_id].data[: files[chunk_id].bytes_written]), message=ToolInvokeMessage.BlobMessage(
blob=bytes(files[chunk_id].data[: files[chunk_id].bytes_written])
),
meta=resp.meta, meta=resp.meta,
) )
yield cast(MessageType, merged_message) yield cast(MessageType, merged_message)

View File

@ -101,9 +101,22 @@ class SimplePromptTransform(PromptTransform):
with_memory_prompt=histories is not None, with_memory_prompt=histories is not None,
) )
variables = {k: inputs[k] for k in prompt_template_config["custom_variable_keys"] if k in inputs} custom_variable_keys_obj = prompt_template_config["custom_variable_keys"]
special_variable_keys_obj = prompt_template_config["special_variable_keys"]
for v in prompt_template_config["special_variable_keys"]: # Type check for custom_variable_keys
if not isinstance(custom_variable_keys_obj, list):
raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys_obj)}")
custom_variable_keys = cast(list[str], custom_variable_keys_obj)
# Type check for special_variable_keys
if not isinstance(special_variable_keys_obj, list):
raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys_obj)}")
special_variable_keys = cast(list[str], special_variable_keys_obj)
variables = {k: inputs[k] for k in custom_variable_keys if k in inputs}
for v in special_variable_keys:
# support #context#, #query# and #histories# # support #context#, #query# and #histories#
if v == "#context#": if v == "#context#":
variables["#context#"] = context or "" variables["#context#"] = context or ""
@ -113,9 +126,16 @@ class SimplePromptTransform(PromptTransform):
variables["#histories#"] = histories or "" variables["#histories#"] = histories or ""
prompt_template = prompt_template_config["prompt_template"] prompt_template = prompt_template_config["prompt_template"]
if not isinstance(prompt_template, PromptTemplateParser):
raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template)}")
prompt = prompt_template.format(variables) prompt = prompt_template.format(variables)
return prompt, prompt_template_config["prompt_rules"] prompt_rules = prompt_template_config["prompt_rules"]
if not isinstance(prompt_rules, dict):
raise TypeError(f"Expected dict for prompt_rules, got {type(prompt_rules)}")
return prompt, prompt_rules
def get_prompt_template( def get_prompt_template(
self, self,
@ -126,11 +146,11 @@ class SimplePromptTransform(PromptTransform):
has_context: bool, has_context: bool,
query_in_prompt: bool, query_in_prompt: bool,
with_memory_prompt: bool = False, with_memory_prompt: bool = False,
): ) -> dict[str, object]:
prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model) prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model)
custom_variable_keys = [] custom_variable_keys: list[str] = []
special_variable_keys = [] special_variable_keys: list[str] = []
prompt = "" prompt = ""
for order in prompt_rules["system_prompt_orders"]: for order in prompt_rules["system_prompt_orders"]:

View File

@ -40,6 +40,19 @@ if TYPE_CHECKING:
MetadataFilter = Union[DictFilter, common_types.Filter] MetadataFilter = Union[DictFilter, common_types.Filter]
class PathQdrantParams(BaseModel):
path: str
class UrlQdrantParams(BaseModel):
url: str
api_key: Optional[str]
timeout: float
verify: bool
grpc_port: int
prefer_grpc: bool
class QdrantConfig(BaseModel): class QdrantConfig(BaseModel):
endpoint: str endpoint: str
api_key: Optional[str] = None api_key: Optional[str] = None
@ -50,7 +63,7 @@ class QdrantConfig(BaseModel):
replication_factor: int = 1 replication_factor: int = 1
write_consistency_factor: int = 1 write_consistency_factor: int = 1
def to_qdrant_params(self): def to_qdrant_params(self) -> PathQdrantParams | UrlQdrantParams:
if self.endpoint and self.endpoint.startswith("path:"): if self.endpoint and self.endpoint.startswith("path:"):
path = self.endpoint.replace("path:", "") path = self.endpoint.replace("path:", "")
if not os.path.isabs(path): if not os.path.isabs(path):
@ -58,23 +71,23 @@ class QdrantConfig(BaseModel):
raise ValueError("Root path is not set") raise ValueError("Root path is not set")
path = os.path.join(self.root_path, path) path = os.path.join(self.root_path, path)
return {"path": path} return PathQdrantParams(path=path)
else: else:
return { return UrlQdrantParams(
"url": self.endpoint, url=self.endpoint,
"api_key": self.api_key, api_key=self.api_key,
"timeout": self.timeout, timeout=self.timeout,
"verify": self.endpoint.startswith("https"), verify=self.endpoint.startswith("https"),
"grpc_port": self.grpc_port, grpc_port=self.grpc_port,
"prefer_grpc": self.prefer_grpc, prefer_grpc=self.prefer_grpc,
} )
class QdrantVector(BaseVector): class QdrantVector(BaseVector):
def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = "Cosine"): def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = "Cosine"):
super().__init__(collection_name) super().__init__(collection_name)
self._client_config = config self._client_config = config
self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params()) self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params().model_dump())
self._distance_func = distance_func.upper() self._distance_func = distance_func.upper()
self._group_id = group_id self._group_id = group_id

View File

@ -94,10 +94,10 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
# In-memory cache for workflow node executions # In-memory cache for workflow node executions
self._execution_cache: dict[str, WorkflowNodeExecution] = {} self._execution_cache = {}
# Cache for mapping workflow_execution_ids to execution IDs for efficient retrieval # Cache for mapping workflow_execution_ids to execution IDs for efficient retrieval
self._workflow_execution_mapping: dict[str, list[str]] = {} self._workflow_execution_mapping = {}
logger.info( logger.info(
"Initialized CeleryWorkflowNodeExecutionRepository for tenant %s, app %s, triggered_from %s", "Initialized CeleryWorkflowNodeExecutionRepository for tenant %s, app %s, triggered_from %s",

View File

@ -4,7 +4,7 @@ from .types import SegmentType
class SegmentGroup(Segment): class SegmentGroup(Segment):
value_type: SegmentType = SegmentType.GROUP value_type: SegmentType = SegmentType.GROUP
value: list[Segment] value: list[Segment] = None # type: ignore
@property @property
def text(self): def text(self):

View File

@ -74,12 +74,12 @@ class NoneSegment(Segment):
class StringSegment(Segment): class StringSegment(Segment):
value_type: SegmentType = SegmentType.STRING value_type: SegmentType = SegmentType.STRING
value: str value: str = None # type: ignore
class FloatSegment(Segment): class FloatSegment(Segment):
value_type: SegmentType = SegmentType.FLOAT value_type: SegmentType = SegmentType.FLOAT
value: float value: float = None # type: ignore
# NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems. # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
# The following tests cannot pass. # The following tests cannot pass.
# #
@ -98,12 +98,12 @@ class FloatSegment(Segment):
class IntegerSegment(Segment): class IntegerSegment(Segment):
value_type: SegmentType = SegmentType.INTEGER value_type: SegmentType = SegmentType.INTEGER
value: int value: int = None # type: ignore
class ObjectSegment(Segment): class ObjectSegment(Segment):
value_type: SegmentType = SegmentType.OBJECT value_type: SegmentType = SegmentType.OBJECT
value: Mapping[str, Any] value: Mapping[str, Any] = None # type: ignore
@property @property
def text(self) -> str: def text(self) -> str:
@ -136,7 +136,7 @@ class ArraySegment(Segment):
class FileSegment(Segment): class FileSegment(Segment):
value_type: SegmentType = SegmentType.FILE value_type: SegmentType = SegmentType.FILE
value: File value: File = None # type: ignore
@property @property
def markdown(self) -> str: def markdown(self) -> str:
@ -153,17 +153,17 @@ class FileSegment(Segment):
class BooleanSegment(Segment): class BooleanSegment(Segment):
value_type: SegmentType = SegmentType.BOOLEAN value_type: SegmentType = SegmentType.BOOLEAN
value: bool value: bool = None # type: ignore
class ArrayAnySegment(ArraySegment): class ArrayAnySegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_ANY value_type: SegmentType = SegmentType.ARRAY_ANY
value: Sequence[Any] value: Sequence[Any] = None # type: ignore
class ArrayStringSegment(ArraySegment): class ArrayStringSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_STRING value_type: SegmentType = SegmentType.ARRAY_STRING
value: Sequence[str] value: Sequence[str] = None # type: ignore
@property @property
def text(self) -> str: def text(self) -> str:
@ -175,17 +175,17 @@ class ArrayStringSegment(ArraySegment):
class ArrayNumberSegment(ArraySegment): class ArrayNumberSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_NUMBER value_type: SegmentType = SegmentType.ARRAY_NUMBER
value: Sequence[float | int] value: Sequence[float | int] = None # type: ignore
class ArrayObjectSegment(ArraySegment): class ArrayObjectSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_OBJECT value_type: SegmentType = SegmentType.ARRAY_OBJECT
value: Sequence[Mapping[str, Any]] value: Sequence[Mapping[str, Any]] = None # type: ignore
class ArrayFileSegment(ArraySegment): class ArrayFileSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_FILE value_type: SegmentType = SegmentType.ARRAY_FILE
value: Sequence[File] value: Sequence[File] = None # type: ignore
@property @property
def markdown(self) -> str: def markdown(self) -> str:
@ -205,7 +205,7 @@ class ArrayFileSegment(ArraySegment):
class ArrayBooleanSegment(ArraySegment): class ArrayBooleanSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_BOOLEAN value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
value: Sequence[bool] value: Sequence[bool] = None # type: ignore
def get_segment_discriminator(v: Any) -> SegmentType | None: def get_segment_discriminator(v: Any) -> SegmentType | None:

View File

@ -3,6 +3,6 @@ from core.workflow.nodes.base import BaseNode
class WorkflowNodeRunFailedError(Exception): class WorkflowNodeRunFailedError(Exception):
def __init__(self, node: BaseNode, err_msg: str): def __init__(self, node: BaseNode, err_msg: str):
self._node = node self.node = node
self._error = err_msg self.error = err_msg
super().__init__(f"Node {node.title} run failed: {err_msg}") super().__init__(f"Node {node.title} run failed: {err_msg}")

View File

@ -67,8 +67,8 @@ class ListOperatorNode(BaseNode):
return "1" return "1"
def _run(self): def _run(self):
inputs: dict[str, list] = {} inputs: dict[str, Sequence[object]] = {}
process_data: dict[str, list] = {} process_data: dict[str, Sequence[object]] = {}
outputs: dict[str, Any] = {} outputs: dict[str, Any] = {}
variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable) variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable)

View File

@ -1183,7 +1183,8 @@ def _combine_message_content_with_role(
return AssistantPromptMessage(content=contents) return AssistantPromptMessage(content=contents)
case PromptMessageRole.SYSTEM: case PromptMessageRole.SYSTEM:
return SystemPromptMessage(content=contents) return SystemPromptMessage(content=contents)
raise NotImplementedError(f"Role {role} is not supported") case _:
raise NotImplementedError(f"Role {role} is not supported")
def _render_jinja2_message( def _render_jinja2_message(

View File

@ -462,9 +462,9 @@ class StorageKeyLoader:
upload_file_row = upload_files.get(model_id) upload_file_row = upload_files.get(model_id)
if upload_file_row is None: if upload_file_row is None:
raise ValueError(f"Upload file not found for id: {model_id}") raise ValueError(f"Upload file not found for id: {model_id}")
file._storage_key = upload_file_row.key file.storage_key = upload_file_row.key
elif file.transfer_method == FileTransferMethod.TOOL_FILE: elif file.transfer_method == FileTransferMethod.TOOL_FILE:
tool_file_row = tool_files.get(model_id) tool_file_row = tool_files.get(model_id)
if tool_file_row is None: if tool_file_row is None:
raise ValueError(f"Tool file not found for id: {model_id}") raise ValueError(f"Tool file not found for id: {model_id}")
file._storage_key = tool_file_row.file_key file.storage_key = tool_file_row.file_key

View File

@ -12,4 +12,7 @@ def serialize_value_type(v: _VarTypedDict | Segment) -> str:
if isinstance(v, Segment): if isinstance(v, Segment):
return v.value_type.exposed_type().value return v.value_type.exposed_type().value
else: else:
return v["value_type"].exposed_type().value value_type = v.get("value_type")
if value_type is None:
raise ValueError("value_type is required but not provided")
return value_type.exposed_type().value

View File

@ -69,6 +69,8 @@ def register_external_error_handlers(api: Api):
headers["WWW-Authenticate"] = 'Bearer realm="api"' headers["WWW-Authenticate"] = 'Bearer realm="api"'
return data, status_code, headers return data, status_code, headers
_ = handle_http_exception
@api.errorhandler(ValueError) @api.errorhandler(ValueError)
def handle_value_error(e: ValueError): def handle_value_error(e: ValueError):
got_request_exception.send(current_app, exception=e) got_request_exception.send(current_app, exception=e)
@ -76,6 +78,8 @@ def register_external_error_handlers(api: Api):
data = {"code": "invalid_param", "message": str(e), "status": status_code} data = {"code": "invalid_param", "message": str(e), "status": status_code}
return data, status_code return data, status_code
_ = handle_value_error
@api.errorhandler(AppInvokeQuotaExceededError) @api.errorhandler(AppInvokeQuotaExceededError)
def handle_quota_exceeded(e: AppInvokeQuotaExceededError): def handle_quota_exceeded(e: AppInvokeQuotaExceededError):
got_request_exception.send(current_app, exception=e) got_request_exception.send(current_app, exception=e)
@ -83,15 +87,17 @@ def register_external_error_handlers(api: Api):
data = {"code": "too_many_requests", "message": str(e), "status": status_code} data = {"code": "too_many_requests", "message": str(e), "status": status_code}
return data, status_code return data, status_code
_ = handle_quota_exceeded
@api.errorhandler(Exception) @api.errorhandler(Exception)
def handle_general_exception(e: Exception): def handle_general_exception(e: Exception):
got_request_exception.send(current_app, exception=e) got_request_exception.send(current_app, exception=e)
status_code = 500 status_code = 500
data: dict[str, Any] = getattr(e, "data", {"message": http_status_message(status_code)}) data = getattr(e, "data", {"message": http_status_message(status_code)})
# 🔒 Normalize non-mapping data (e.g., if someone set e.data = Response) # 🔒 Normalize non-mapping data (e.g., if someone set e.data = Response)
if not isinstance(data, Mapping): if not isinstance(data, dict):
data = {"message": str(e)} data = {"message": str(e)}
data.setdefault("code", "unknown") data.setdefault("code", "unknown")
@ -101,10 +107,12 @@ def register_external_error_handlers(api: Api):
exc_info: Any = sys.exc_info() exc_info: Any = sys.exc_info()
if exc_info[1] is None: if exc_info[1] is None:
exc_info = None exc_info = None
current_app.log_exception(exc_info) # ty: ignore [invalid-argument-type] current_app.log_exception(exc_info)
return data, status_code return data, status_code
_ = handle_general_exception
class ExternalApi(Api): class ExternalApi(Api):
_authorizations = { _authorizations = {

View File

@ -167,13 +167,6 @@ class DatetimeString:
return value return value
def _get_float(value):
try:
return float(value)
except (TypeError, ValueError):
raise ValueError(f"{value} is not a valid float")
def timezone(timezone_string): def timezone(timezone_string):
if timezone_string and timezone_string in available_timezones(): if timezone_string and timezone_string in available_timezones():
return timezone_string return timezone_string

View File

@ -1,24 +1,44 @@
{ {
"include": ["."], "include": ["."],
"exclude": [".venv", "tests/", "migrations/"], "exclude": [
"ignore": [ ".venv",
"core/", "tests/",
"controllers/", "migrations/",
"tasks/", "core/rag",
"services/", "extensions",
"schedule/", "libs",
"extensions/", "controllers/console/datasets",
"utils/", "controllers/service_api/dataset",
"repositories/", "core/ops",
"libs/", "core/tools",
"fields/", "core/model_runtime",
"factories/", "core/workflow",
"events/", "core/app/app_config/easy_ui_based_app/dataset"
"contexts/",
"constants/",
"commands.py"
], ],
"typeCheckingMode": "strict", "typeCheckingMode": "strict",
"allowedUntypedLibraries": [
"flask_restx",
"flask_login",
"opentelemetry.instrumentation.celery",
"opentelemetry.instrumentation.flask",
"opentelemetry.instrumentation.requests",
"opentelemetry.instrumentation.sqlalchemy",
"opentelemetry.instrumentation.redis"
],
"reportUnknownMemberType": "hint",
"reportUnknownParameterType": "hint",
"reportUnknownArgumentType": "hint",
"reportUnknownVariableType": "hint",
"reportUnknownLambdaType": "hint",
"reportMissingParameterType": "hint",
"reportMissingTypeArgument": "hint",
"reportUnnecessaryContains": "hint",
"reportUnnecessaryComparison": "hint",
"reportUnnecessaryCast": "hint",
"reportUnnecessaryIsInstance": "hint",
"reportUntypedFunctionDecorator": "hint",
"reportAttributeAccessIssue": "hint",
"pythonVersion": "3.11", "pythonVersion": "3.11",
"pythonPlatform": "All" "pythonPlatform": "All"
} }

View File

@ -1318,7 +1318,7 @@ class RegisterService:
def get_invitation_if_token_valid( def get_invitation_if_token_valid(
cls, workspace_id: Optional[str], email: str, token: str cls, workspace_id: Optional[str], email: str, token: str
) -> Optional[dict[str, Any]]: ) -> Optional[dict[str, Any]]:
invitation_data = cls._get_invitation_by_token(token, workspace_id, email) invitation_data = cls.get_invitation_by_token(token, workspace_id, email)
if not invitation_data: if not invitation_data:
return None return None
@ -1355,7 +1355,7 @@ class RegisterService:
} }
@classmethod @classmethod
def _get_invitation_by_token( def get_invitation_by_token(
cls, token: str, workspace_id: Optional[str] = None, email: Optional[str] = None cls, token: str, workspace_id: Optional[str] = None, email: Optional[str] = None
) -> Optional[dict[str, str]]: ) -> Optional[dict[str, str]]:
if workspace_id is not None and email is not None: if workspace_id is not None and email is not None:

View File

@ -349,7 +349,7 @@ class AppAnnotationService:
try: try:
# Skip the first row # Skip the first row
df = pd.read_csv(file, dtype=str) df = pd.read_csv(file.stream, dtype=str)
result = [] result = []
for _, row in df.iterrows(): for _, row in df.iterrows():
content = {"question": row.iloc[0], "answer": row.iloc[1]} content = {"question": row.iloc[0], "answer": row.iloc[1]}
@ -463,15 +463,23 @@ class AppAnnotationService:
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
if annotation_setting: if annotation_setting:
collection_binding_detail = annotation_setting.collection_binding_detail collection_binding_detail = annotation_setting.collection_binding_detail
return { if collection_binding_detail:
"id": annotation_setting.id, return {
"enabled": True, "id": annotation_setting.id,
"score_threshold": annotation_setting.score_threshold, "enabled": True,
"embedding_model": { "score_threshold": annotation_setting.score_threshold,
"embedding_provider_name": collection_binding_detail.provider_name, "embedding_model": {
"embedding_model_name": collection_binding_detail.model_name, "embedding_provider_name": collection_binding_detail.provider_name,
}, "embedding_model_name": collection_binding_detail.model_name,
} },
}
else:
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {},
}
return {"enabled": False} return {"enabled": False}
@classmethod @classmethod
@ -506,15 +514,23 @@ class AppAnnotationService:
collection_binding_detail = annotation_setting.collection_binding_detail collection_binding_detail = annotation_setting.collection_binding_detail
return { if collection_binding_detail:
"id": annotation_setting.id, return {
"enabled": True, "id": annotation_setting.id,
"score_threshold": annotation_setting.score_threshold, "enabled": True,
"embedding_model": { "score_threshold": annotation_setting.score_threshold,
"embedding_provider_name": collection_binding_detail.provider_name, "embedding_model": {
"embedding_model_name": collection_binding_detail.model_name, "embedding_provider_name": collection_binding_detail.provider_name,
}, "embedding_model_name": collection_binding_detail.model_name,
} },
}
else:
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {},
}
@classmethod @classmethod
def clear_all_annotations(cls, app_id: str): def clear_all_annotations(cls, app_id: str):

View File

@ -407,6 +407,7 @@ class ClearFreePlanTenantExpiredLogs:
datetime.timedelta(hours=1), datetime.timedelta(hours=1),
] ]
tenant_count = 0
for test_interval in test_intervals: for test_interval in test_intervals:
tenant_count = ( tenant_count = (
session.query(Tenant.id) session.query(Tenant.id)

View File

@ -134,11 +134,14 @@ class DatasetService:
# Check if tag_ids is not empty to avoid WHERE false condition # Check if tag_ids is not empty to avoid WHERE false condition
if tag_ids and len(tag_ids) > 0: if tag_ids and len(tag_ids) > 0:
target_ids = TagService.get_target_ids_by_tag_ids( if tenant_id is not None:
"knowledge", target_ids = TagService.get_target_ids_by_tag_ids(
tenant_id, # ty: ignore [invalid-argument-type] "knowledge",
tag_ids, tenant_id,
) tag_ids,
)
else:
target_ids = []
if target_ids and len(target_ids) > 0: if target_ids and len(target_ids) > 0:
query = query.where(Dataset.id.in_(target_ids)) query = query.where(Dataset.id.in_(target_ids))
else: else:
@ -987,7 +990,8 @@ class DocumentService:
for document in documents for document in documents
if document.data_source_type == "upload_file" and document.data_source_info_dict if document.data_source_type == "upload_file" and document.data_source_info_dict
] ]
batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids) if dataset.doc_form is not None:
batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)
for document in documents: for document in documents:
db.session.delete(document) db.session.delete(document)
@ -2688,56 +2692,6 @@ class SegmentService:
return paginated_segments.items, paginated_segments.total return paginated_segments.items, paginated_segments.total
@classmethod
def update_segment_by_id(
cls, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, segment_data: dict, user_id: str
) -> tuple[DocumentSegment, Document]:
"""Update a segment by its ID with validation and checks."""
# check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check embedding model setting if high quality
if dataset.indexing_technique == "high_quality":
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=user_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
# check segment
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")
# validate and update segment
cls.segment_create_args_validate(segment_data, document)
updated_segment = cls.update_segment(SegmentUpdateArgs(**segment_data), segment, document, dataset)
return updated_segment, document
@classmethod @classmethod
def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]: def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]:
"""Get a segment by its ID.""" """Get a segment by its ID."""

View File

@ -181,7 +181,7 @@ class ExternalDatasetService:
do http request depending on api bundle do http request depending on api bundle
""" """
kwargs = { kwargs: dict[str, Any] = {
"url": settings.url, "url": settings.url,
"headers": settings.headers, "headers": settings.headers,
"follow_redirects": True, "follow_redirects": True,

View File

@ -1,7 +1,7 @@
import hashlib import hashlib
import os import os
import uuid import uuid
from typing import Any, Literal, Union from typing import Literal, Union
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
@ -35,7 +35,7 @@ class FileService:
filename: str, filename: str,
content: bytes, content: bytes,
mimetype: str, mimetype: str,
user: Union[Account, EndUser, Any], user: Union[Account, EndUser],
source: Literal["datasets"] | None = None, source: Literal["datasets"] | None = None,
source_url: str = "", source_url: str = "",
) -> UploadFile: ) -> UploadFile:

View File

@ -165,7 +165,7 @@ class ModelLoadBalancingService:
try: try:
if load_balancing_config.encrypted_config: if load_balancing_config.encrypted_config:
credentials = json.loads(load_balancing_config.encrypted_config) credentials: dict[str, object] = json.loads(load_balancing_config.encrypted_config)
else: else:
credentials = {} credentials = {}
except JSONDecodeError: except JSONDecodeError:
@ -180,11 +180,13 @@ class ModelLoadBalancingService:
for variable in credential_secret_variables: for variable in credential_secret_variables:
if variable in credentials: if variable in credentials:
try: try:
credentials[variable] = encrypter.decrypt_token_with_decoding( token_value = credentials.get(variable)
credentials.get(variable), # ty: ignore [invalid-argument-type] if isinstance(token_value, str):
decoding_rsa_key, credentials[variable] = encrypter.decrypt_token_with_decoding(
decoding_cipher_rsa, token_value,
) decoding_rsa_key,
decoding_cipher_rsa,
)
except ValueError: except ValueError:
pass pass
@ -345,8 +347,9 @@ class ModelLoadBalancingService:
credential_id = config.get("credential_id") credential_id = config.get("credential_id")
enabled = config.get("enabled") enabled = config.get("enabled")
credential_record: ProviderCredential | ProviderModelCredential | None = None
if credential_id: if credential_id:
credential_record: ProviderCredential | ProviderModelCredential | None = None
if config_from == "predefined-model": if config_from == "predefined-model":
credential_record = ( credential_record = (
db.session.query(ProviderCredential) db.session.query(ProviderCredential)

View File

@ -99,6 +99,7 @@ class PluginMigration:
datetime.timedelta(hours=1), datetime.timedelta(hours=1),
] ]
tenant_count = 0
for test_interval in test_intervals: for test_interval in test_intervals:
tenant_count = ( tenant_count = (
session.query(Tenant.id) session.query(Tenant.id)

View File

@ -223,8 +223,8 @@ class BuiltinToolManageService:
""" """
add builtin tool provider add builtin tool provider
""" """
try: with Session(db.engine) as session:
with Session(db.engine) as session: try:
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}" lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
with redis_client.lock(lock, timeout=20): with redis_client.lock(lock, timeout=20):
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
@ -285,9 +285,9 @@ class BuiltinToolManageService:
session.add(db_provider) session.add(db_provider)
session.commit() session.commit()
except Exception as e: except Exception as e:
session.rollback() session.rollback()
raise ValueError(str(e)) raise ValueError(str(e))
return {"result": "success"} return {"result": "success"}
@staticmethod @staticmethod

View File

@ -18,6 +18,7 @@ from core.helper import encrypter
from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.simple_prompt_transform import SimplePromptTransform from core.prompt.simple_prompt_transform import SimplePromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType
from events.app_event import app_was_created from events.app_event import app_was_created
from extensions.ext_database import db from extensions.ext_database import db
@ -420,7 +421,11 @@ class WorkflowConverter:
query_in_prompt=False, query_in_prompt=False,
) )
template = prompt_template_config["prompt_template"].template prompt_template_obj = prompt_template_config["prompt_template"]
if not isinstance(prompt_template_obj, PromptTemplateParser):
raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template_obj)}")
template = prompt_template_obj.template
if not template: if not template:
prompts = [] prompts = []
else: else:
@ -457,7 +462,11 @@ class WorkflowConverter:
query_in_prompt=False, query_in_prompt=False,
) )
template = prompt_template_config["prompt_template"].template prompt_template_obj = prompt_template_config["prompt_template"]
if not isinstance(prompt_template_obj, PromptTemplateParser):
raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template_obj)}")
template = prompt_template_obj.template
template = self._replace_template_variables( template = self._replace_template_variables(
template=template, template=template,
variables=start_node["data"]["variables"], variables=start_node["data"]["variables"],
@ -467,6 +476,9 @@ class WorkflowConverter:
prompts = {"text": template} prompts = {"text": template}
prompt_rules = prompt_template_config["prompt_rules"] prompt_rules = prompt_template_config["prompt_rules"]
if not isinstance(prompt_rules, dict):
raise TypeError(f"Expected dict for prompt_rules, got {type(prompt_rules)}")
role_prefix = { role_prefix = {
"user": prompt_rules.get("human_prefix", "Human"), "user": prompt_rules.get("human_prefix", "Human"),
"assistant": prompt_rules.get("assistant_prefix", "Assistant"), "assistant": prompt_rules.get("assistant_prefix", "Assistant"),

View File

@ -769,10 +769,10 @@ class WorkflowService:
) )
error = node_run_result.error if not run_succeeded else None error = node_run_result.error if not run_succeeded else None
except WorkflowNodeRunFailedError as e: except WorkflowNodeRunFailedError as e:
node = e._node node = e.node
run_succeeded = False run_succeeded = False
node_run_result = None node_run_result = None
error = e._error error = e.error
# Create a NodeExecution domain model # Create a NodeExecution domain model
node_execution = WorkflowNodeExecution( node_execution = WorkflowNodeExecution(

View File

@ -12,7 +12,7 @@ class WorkspaceService:
def get_tenant_info(cls, tenant: Tenant): def get_tenant_info(cls, tenant: Tenant):
if not tenant: if not tenant:
return None return None
tenant_info = { tenant_info: dict[str, object] = {
"id": tenant.id, "id": tenant.id,
"name": tenant.name, "name": tenant.name,
"plan": tenant.plan, "plan": tenant.plan,

View File

@ -3278,7 +3278,7 @@ class TestRegisterService:
redis_client.setex(cache_key, 24 * 60 * 60, account_id) redis_client.setex(cache_key, 24 * 60 * 60, account_id)
# Execute invitation retrieval # Execute invitation retrieval
result = RegisterService._get_invitation_by_token( result = RegisterService.get_invitation_by_token(
token=token, token=token,
workspace_id=workspace_id, workspace_id=workspace_id,
email=email, email=email,
@ -3316,7 +3316,7 @@ class TestRegisterService:
redis_client.setex(token_key, 24 * 60 * 60, json.dumps(invitation_data)) redis_client.setex(token_key, 24 * 60 * 60, json.dumps(invitation_data))
# Execute invitation retrieval # Execute invitation retrieval
result = RegisterService._get_invitation_by_token(token=token) result = RegisterService.get_invitation_by_token(token=token)
# Verify result contains expected data # Verify result contains expected data
assert result is not None assert result is not None

View File

@ -14,6 +14,7 @@ from core.app.app_config.entities import (
VariableEntityType, VariableEntityType,
) )
from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.entities.llm_entities import LLMMode
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from models.account import Account, Tenant from models.account import Account, Tenant
from models.api_based_extension import APIBasedExtension from models.api_based_extension import APIBasedExtension
from models.model import App, AppMode, AppModelConfig from models.model import App, AppMode, AppModelConfig
@ -37,7 +38,7 @@ class TestWorkflowConverter:
# Setup default mock returns # Setup default mock returns
mock_encrypter.decrypt_token.return_value = "decrypted_api_key" mock_encrypter.decrypt_token.return_value = "decrypted_api_key"
mock_prompt_transform.return_value.get_prompt_template.return_value = { mock_prompt_transform.return_value.get_prompt_template.return_value = {
"prompt_template": type("obj", (object,), {"template": "You are a helpful assistant {{text_input}}"})(), "prompt_template": PromptTemplateParser(template="You are a helpful assistant {{text_input}}"),
"prompt_rules": {"human_prefix": "Human", "assistant_prefix": "Assistant"}, "prompt_rules": {"human_prefix": "Human", "assistant_prefix": "Assistant"},
} }
mock_agent_chat_config_manager.get_app_config.return_value = self._create_mock_app_config() mock_agent_chat_config_manager.get_app_config.return_value = self._create_mock_app_config()

View File

@ -1370,8 +1370,8 @@ class TestRegisterService:
account_id="user-123", email="test@example.com" account_id="user-123", email="test@example.com"
) )
with patch("services.account_service.RegisterService._get_invitation_by_token") as mock_get_invitation_by_token: with patch("services.account_service.RegisterService.get_invitation_by_token") as mock_get_invitation_by_token:
# Mock the invitation data returned by _get_invitation_by_token # Mock the invitation data returned by get_invitation_by_token
invitation_data = { invitation_data = {
"account_id": "user-123", "account_id": "user-123",
"email": "test@example.com", "email": "test@example.com",
@ -1503,12 +1503,12 @@ class TestRegisterService:
assert result == "member_invite:token:test-token" assert result == "member_invite:token:test-token"
def test_get_invitation_by_token_with_workspace_and_email(self, mock_redis_dependencies): def test_get_invitation_by_token_with_workspace_and_email(self, mock_redis_dependencies):
"""Test _get_invitation_by_token with workspace ID and email.""" """Test get_invitation_by_token with workspace ID and email."""
# Setup mock # Setup mock
mock_redis_dependencies.get.return_value = b"user-123" mock_redis_dependencies.get.return_value = b"user-123"
# Execute test # Execute test
result = RegisterService._get_invitation_by_token("token-123", "workspace-456", "test@example.com") result = RegisterService.get_invitation_by_token("token-123", "workspace-456", "test@example.com")
# Verify results # Verify results
assert result is not None assert result is not None
@ -1517,7 +1517,7 @@ class TestRegisterService:
assert result["workspace_id"] == "workspace-456" assert result["workspace_id"] == "workspace-456"
def test_get_invitation_by_token_without_workspace_and_email(self, mock_redis_dependencies): def test_get_invitation_by_token_without_workspace_and_email(self, mock_redis_dependencies):
"""Test _get_invitation_by_token without workspace ID and email.""" """Test get_invitation_by_token without workspace ID and email."""
# Setup mock # Setup mock
invitation_data = { invitation_data = {
"account_id": "user-123", "account_id": "user-123",
@ -1527,19 +1527,19 @@ class TestRegisterService:
mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode() mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
# Execute test # Execute test
result = RegisterService._get_invitation_by_token("token-123") result = RegisterService.get_invitation_by_token("token-123")
# Verify results # Verify results
assert result is not None assert result is not None
assert result == invitation_data assert result == invitation_data
def test_get_invitation_by_token_no_data(self, mock_redis_dependencies): def test_get_invitation_by_token_no_data(self, mock_redis_dependencies):
"""Test _get_invitation_by_token with no data.""" """Test get_invitation_by_token with no data."""
# Setup mock # Setup mock
mock_redis_dependencies.get.return_value = None mock_redis_dependencies.get.return_value = None
# Execute test # Execute test
result = RegisterService._get_invitation_by_token("token-123") result = RegisterService.get_invitation_by_token("token-123")
# Verify results # Verify results
assert result is None assert result is None