mirror of https://github.com/langgenius/dify.git
Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine
This commit is contained in:
commit
2a97a69825
|
|
@ -511,7 +511,7 @@ def add_qdrant_index(field: str):
|
|||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
from qdrant_client.http.models import PayloadSchemaType
|
||||
|
||||
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig
|
||||
from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
|
||||
|
||||
for binding in bindings:
|
||||
if dify_config.QDRANT_URL is None:
|
||||
|
|
@ -525,7 +525,21 @@ def add_qdrant_index(field: str):
|
|||
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
|
||||
)
|
||||
try:
|
||||
client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params())
|
||||
params = qdrant_config.to_qdrant_params()
|
||||
# Check the type before using
|
||||
if isinstance(params, PathQdrantParams):
|
||||
# PathQdrantParams case
|
||||
client = qdrant_client.QdrantClient(path=params.path)
|
||||
else:
|
||||
# UrlQdrantParams case - params is UrlQdrantParams
|
||||
client = qdrant_client.QdrantClient(
|
||||
url=params.url,
|
||||
api_key=params.api_key,
|
||||
timeout=int(params.timeout),
|
||||
verify=params.verify,
|
||||
grpc_port=params.grpc_port,
|
||||
prefer_grpc=params.prefer_grpc,
|
||||
)
|
||||
# create payload index
|
||||
client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD)
|
||||
create_count += 1
|
||||
|
|
|
|||
|
|
@ -16,14 +16,14 @@ AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"]
|
|||
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
|
||||
|
||||
|
||||
_doc_extensions: list[str]
|
||||
if dify_config.ETL_TYPE == "Unstructured":
|
||||
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"]
|
||||
DOCUMENT_EXTENSIONS.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
|
||||
_doc_extensions = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"]
|
||||
_doc_extensions.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
|
||||
if dify_config.UNSTRUCTURED_API_URL:
|
||||
DOCUMENT_EXTENSIONS.append("ppt")
|
||||
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
|
||||
_doc_extensions.append("ppt")
|
||||
else:
|
||||
DOCUMENT_EXTENSIONS = [
|
||||
_doc_extensions = [
|
||||
"txt",
|
||||
"markdown",
|
||||
"md",
|
||||
|
|
@ -38,4 +38,4 @@ else:
|
|||
"vtt",
|
||||
"properties",
|
||||
]
|
||||
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
|
||||
DOCUMENT_EXTENSIONS = _doc_extensions + [ext.upper() for ext in _doc_extensions]
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ if TYPE_CHECKING:
|
|||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -43,56 +43,64 @@ api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm"
|
|||
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
|
||||
|
||||
# Import other controllers
|
||||
from . import admin, apikey, extension, feature, ping, setup, version
|
||||
from . import admin, apikey, extension, feature, ping, setup, version # pyright: ignore[reportUnusedImport]
|
||||
|
||||
# Import app controllers
|
||||
from .app import (
|
||||
advanced_prompt_template,
|
||||
agent,
|
||||
annotation,
|
||||
app,
|
||||
audio,
|
||||
completion,
|
||||
conversation,
|
||||
conversation_variables,
|
||||
generator,
|
||||
mcp_server,
|
||||
message,
|
||||
model_config,
|
||||
ops_trace,
|
||||
site,
|
||||
statistic,
|
||||
workflow,
|
||||
workflow_app_log,
|
||||
workflow_draft_variable,
|
||||
workflow_run,
|
||||
workflow_statistic,
|
||||
advanced_prompt_template, # pyright: ignore[reportUnusedImport]
|
||||
agent, # pyright: ignore[reportUnusedImport]
|
||||
annotation, # pyright: ignore[reportUnusedImport]
|
||||
app, # pyright: ignore[reportUnusedImport]
|
||||
audio, # pyright: ignore[reportUnusedImport]
|
||||
completion, # pyright: ignore[reportUnusedImport]
|
||||
conversation, # pyright: ignore[reportUnusedImport]
|
||||
conversation_variables, # pyright: ignore[reportUnusedImport]
|
||||
generator, # pyright: ignore[reportUnusedImport]
|
||||
mcp_server, # pyright: ignore[reportUnusedImport]
|
||||
message, # pyright: ignore[reportUnusedImport]
|
||||
model_config, # pyright: ignore[reportUnusedImport]
|
||||
ops_trace, # pyright: ignore[reportUnusedImport]
|
||||
site, # pyright: ignore[reportUnusedImport]
|
||||
statistic, # pyright: ignore[reportUnusedImport]
|
||||
workflow, # pyright: ignore[reportUnusedImport]
|
||||
workflow_app_log, # pyright: ignore[reportUnusedImport]
|
||||
workflow_draft_variable, # pyright: ignore[reportUnusedImport]
|
||||
workflow_run, # pyright: ignore[reportUnusedImport]
|
||||
workflow_statistic, # pyright: ignore[reportUnusedImport]
|
||||
)
|
||||
|
||||
# Import auth controllers
|
||||
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth, oauth_server
|
||||
from .auth import (
|
||||
activate, # pyright: ignore[reportUnusedImport]
|
||||
data_source_bearer_auth, # pyright: ignore[reportUnusedImport]
|
||||
data_source_oauth, # pyright: ignore[reportUnusedImport]
|
||||
forgot_password, # pyright: ignore[reportUnusedImport]
|
||||
login, # pyright: ignore[reportUnusedImport]
|
||||
oauth, # pyright: ignore[reportUnusedImport]
|
||||
oauth_server, # pyright: ignore[reportUnusedImport]
|
||||
)
|
||||
|
||||
# Import billing controllers
|
||||
from .billing import billing, compliance
|
||||
from .billing import billing, compliance # pyright: ignore[reportUnusedImport]
|
||||
|
||||
# Import datasets controllers
|
||||
from .datasets import (
|
||||
data_source,
|
||||
datasets,
|
||||
datasets_document,
|
||||
datasets_segments,
|
||||
external,
|
||||
hit_testing,
|
||||
metadata,
|
||||
website,
|
||||
data_source, # pyright: ignore[reportUnusedImport]
|
||||
datasets, # pyright: ignore[reportUnusedImport]
|
||||
datasets_document, # pyright: ignore[reportUnusedImport]
|
||||
datasets_segments, # pyright: ignore[reportUnusedImport]
|
||||
external, # pyright: ignore[reportUnusedImport]
|
||||
hit_testing, # pyright: ignore[reportUnusedImport]
|
||||
metadata, # pyright: ignore[reportUnusedImport]
|
||||
website, # pyright: ignore[reportUnusedImport]
|
||||
)
|
||||
|
||||
# Import explore controllers
|
||||
from .explore import (
|
||||
installed_app,
|
||||
parameter,
|
||||
recommended_app,
|
||||
saved_message,
|
||||
installed_app, # pyright: ignore[reportUnusedImport]
|
||||
parameter, # pyright: ignore[reportUnusedImport]
|
||||
recommended_app, # pyright: ignore[reportUnusedImport]
|
||||
saved_message, # pyright: ignore[reportUnusedImport]
|
||||
)
|
||||
|
||||
# Explore Audio
|
||||
|
|
@ -167,18 +175,18 @@ api.add_resource(
|
|||
)
|
||||
|
||||
# Import tag controllers
|
||||
from .tag import tags
|
||||
from .tag import tags # pyright: ignore[reportUnusedImport]
|
||||
|
||||
# Import workspace controllers
|
||||
from .workspace import (
|
||||
account,
|
||||
agent_providers,
|
||||
endpoint,
|
||||
load_balancing_config,
|
||||
members,
|
||||
model_providers,
|
||||
models,
|
||||
plugin,
|
||||
tool_providers,
|
||||
workspace,
|
||||
account, # pyright: ignore[reportUnusedImport]
|
||||
agent_providers, # pyright: ignore[reportUnusedImport]
|
||||
endpoint, # pyright: ignore[reportUnusedImport]
|
||||
load_balancing_config, # pyright: ignore[reportUnusedImport]
|
||||
members, # pyright: ignore[reportUnusedImport]
|
||||
model_providers, # pyright: ignore[reportUnusedImport]
|
||||
models, # pyright: ignore[reportUnusedImport]
|
||||
plugin, # pyright: ignore[reportUnusedImport]
|
||||
tool_providers, # pyright: ignore[reportUnusedImport]
|
||||
workspace, # pyright: ignore[reportUnusedImport]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
import flask_restx
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx._http import HTTPStatus
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
|
@ -40,7 +41,7 @@ def _get_resource(resource_id, tenant_id, resource_model):
|
|||
).scalar_one_or_none()
|
||||
|
||||
if resource is None:
|
||||
flask_restx.abort(404, message=f"{resource_model.__name__} not found.")
|
||||
flask_restx.abort(HTTPStatus.NOT_FOUND, message=f"{resource_model.__name__} not found.")
|
||||
|
||||
return resource
|
||||
|
||||
|
|
@ -49,7 +50,7 @@ class BaseApiKeyListResource(Resource):
|
|||
method_decorators = [account_initialization_required, login_required, setup_required]
|
||||
|
||||
resource_type: str | None = None
|
||||
resource_model: Optional[Any] = None
|
||||
resource_model: Optional[type] = None
|
||||
resource_id_field: str | None = None
|
||||
token_prefix: str | None = None
|
||||
max_keys = 10
|
||||
|
|
@ -82,7 +83,7 @@ class BaseApiKeyListResource(Resource):
|
|||
|
||||
if current_key_count >= self.max_keys:
|
||||
flask_restx.abort(
|
||||
400,
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
|
||||
custom="max_keys_exceeded",
|
||||
)
|
||||
|
|
@ -102,7 +103,7 @@ class BaseApiKeyResource(Resource):
|
|||
method_decorators = [account_initialization_required, login_required, setup_required]
|
||||
|
||||
resource_type: str | None = None
|
||||
resource_model: Optional[Any] = None
|
||||
resource_model: Optional[type] = None
|
||||
resource_id_field: str | None = None
|
||||
|
||||
def delete(self, resource_id, api_key_id):
|
||||
|
|
@ -126,7 +127,7 @@ class BaseApiKeyResource(Resource):
|
|||
)
|
||||
|
||||
if key is None:
|
||||
flask_restx.abort(404, message="API key not found")
|
||||
flask_restx.abort(HTTPStatus.NOT_FOUND, message="API key not found")
|
||||
|
||||
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -115,6 +115,10 @@ class AppListApi(Resource):
|
|||
raise BadRequest("mode is required")
|
||||
|
||||
app_service = AppService()
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
if current_user.current_tenant_id is None:
|
||||
raise ValueError("current_user.current_tenant_id cannot be None")
|
||||
app = app_service.create_app(current_user.current_tenant_id, args, current_user)
|
||||
|
||||
return app, 201
|
||||
|
|
@ -161,14 +165,26 @@ class AppApi(Resource):
|
|||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app(app_model, args)
|
||||
# Construct ArgsDict from parsed arguments
|
||||
from services.app_service import AppService as AppServiceType
|
||||
|
||||
args_dict: AppServiceType.ArgsDict = {
|
||||
"name": args["name"],
|
||||
"description": args.get("description", ""),
|
||||
"icon_type": args.get("icon_type", ""),
|
||||
"icon": args.get("icon", ""),
|
||||
"icon_background": args.get("icon_background", ""),
|
||||
"use_icon_as_answer_icon": args.get("use_icon_as_answer_icon", False),
|
||||
"max_active_requests": args.get("max_active_requests", 0),
|
||||
}
|
||||
app_model = app_service.update_app(app_model, args_dict)
|
||||
|
||||
return app_model
|
||||
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def delete(self, app_model):
|
||||
"""Delete app"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
|
|
@ -224,10 +240,10 @@ class AppCopyApi(Resource):
|
|||
|
||||
|
||||
class AppExportApi(Resource):
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def get(self, app_model):
|
||||
"""Export app"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
|
|
@ -263,7 +279,7 @@ class AppNameApi(Resource):
|
|||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_name(app_model, args.get("name"))
|
||||
app_model = app_service.update_app_name(app_model, args["name"])
|
||||
|
||||
return app_model
|
||||
|
||||
|
|
@ -285,7 +301,7 @@ class AppIconApi(Resource):
|
|||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_icon(app_model, args.get("icon"), args.get("icon_background"))
|
||||
app_model = app_service.update_app_icon(app_model, args.get("icon") or "", args.get("icon_background") or "")
|
||||
|
||||
return app_model
|
||||
|
||||
|
|
@ -306,7 +322,7 @@ class AppSiteStatus(Resource):
|
|||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_site_status(app_model, args.get("enable_site"))
|
||||
app_model = app_service.update_app_site_status(app_model, args["enable_site"])
|
||||
|
||||
return app_model
|
||||
|
||||
|
|
@ -327,7 +343,7 @@ class AppApiStatus(Resource):
|
|||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_api_status(app_model, args.get("enable_api"))
|
||||
app_model = app_service.update_app_api_status(app_model, args["enable_api"])
|
||||
|
||||
return app_model
|
||||
|
||||
|
|
|
|||
|
|
@ -77,10 +77,10 @@ class ChatMessageAudioApi(Resource):
|
|||
|
||||
|
||||
class ChatMessageTextApi(Resource):
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def post(self, app_model: App):
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -125,10 +125,10 @@ class ChatMessageTextApi(Resource):
|
|||
|
||||
|
||||
class TextModesApi(Resource):
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def get(self, app_model):
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import logging
|
||||
|
||||
import flask_login
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
|
@ -29,7 +28,8 @@ from core.helper.trace_id_helper import get_external_trace_id
|
|||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
|
@ -56,11 +56,11 @@ class CompletionMessageApi(Resource):
|
|||
streaming = args["response_mode"] != "blocking"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
account = flask_login.current_user
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account or EndUser instance")
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
|
|
@ -92,9 +92,9 @@ class CompletionMessageStopApi(Resource):
|
|||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def post(self, app_model, task_id):
|
||||
account = flask_login.current_user
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
|
@ -123,11 +123,11 @@ class ChatMessageApi(Resource):
|
|||
if external_trace_id:
|
||||
args["external_trace_id"] = external_trace_id
|
||||
|
||||
account = flask_login.current_user
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account or EndUser instance")
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
|
|
@ -161,9 +161,9 @@ class ChatMessageStopApi(Resource):
|
|||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model, task_id):
|
||||
account = flask_login.current_user
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from fields.conversation_fields import (
|
|||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
from models import Conversation, EndUser, Message, MessageAnnotation
|
||||
from models import Account, Conversation, EndUser, Message, MessageAnnotation
|
||||
from models.model import AppMode
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
|
|
@ -124,6 +124,8 @@ class CompletionConversationDetailApi(Resource):
|
|||
conversation_id = str(conversation_id)
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
ConversationService.delete(app_model, conversation_id, current_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
|
@ -282,6 +284,8 @@ class ChatConversationDetailApi(Resource):
|
|||
conversation_id = str(conversation_id)
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
ConversationService.delete(app_model, conversation_id, current_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import logging
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from sqlalchemy import exists, select
|
||||
|
|
@ -27,7 +26,8 @@ from extensions.ext_database import db
|
|||
from fields.conversation_fields import annotation_fields, message_detail_fields
|
||||
from libs.helper import uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
|
|
@ -118,11 +118,14 @@ class ChatMessageListApi(Resource):
|
|||
|
||||
|
||||
class MessageFeedbackApi(Resource):
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def post(self, app_model):
|
||||
if current_user is None:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", required=True, type=uuid_value, location="json")
|
||||
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
||||
|
|
@ -167,6 +170,8 @@ class MessageAnnotationApi(Resource):
|
|||
@get_app_model
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_model):
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -182,10 +187,10 @@ class MessageAnnotationApi(Resource):
|
|||
|
||||
|
||||
class MessageAnnotationCountApi(Resource):
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def get(self, app_model):
|
||||
count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count()
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from extensions.ext_database import db
|
|||
from fields.app_fields import app_site_fields
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import login_required
|
||||
from models import Site
|
||||
from models import Account, Site
|
||||
|
||||
|
||||
def parse_app_site_args():
|
||||
|
|
@ -75,6 +75,8 @@ class AppSite(Resource):
|
|||
if value is not None:
|
||||
setattr(site, attr_name, value)
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
site.updated_by = current_user.id
|
||||
site.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
|
@ -99,6 +101,8 @@ class AppSiteAccessTokenReset(Resource):
|
|||
raise NotFound
|
||||
|
||||
site.code = Site.generate_code(16)
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
site.updated_by = current_user.id
|
||||
site.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -18,10 +18,10 @@ from models import AppMode, Message
|
|||
|
||||
|
||||
class DailyMessageStatistic(Resource):
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
|
||||
|
|
@ -75,10 +75,10 @@ WHERE
|
|||
|
||||
|
||||
class DailyConversationStatistic(Resource):
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
|
||||
|
|
@ -127,10 +127,10 @@ class DailyConversationStatistic(Resource):
|
|||
|
||||
|
||||
class DailyTerminalsStatistic(Resource):
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
|
||||
|
|
@ -184,10 +184,10 @@ WHERE
|
|||
|
||||
|
||||
class DailyTokenCostStatistic(Resource):
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
|
||||
|
|
@ -320,10 +320,10 @@ ORDER BY
|
|||
|
||||
|
||||
class UserSatisfactionRateStatistic(Resource):
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
|
||||
|
|
@ -443,10 +443,10 @@ WHERE
|
|||
|
||||
|
||||
class TokensPerSecondStatistic(Resource):
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
|
||||
|
|
|
|||
|
|
@ -18,10 +18,10 @@ from models.model import AppMode
|
|||
|
||||
|
||||
class WorkflowDailyRunsStatistic(Resource):
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
|
||||
|
|
@ -80,10 +80,10 @@ WHERE
|
|||
|
||||
|
||||
class WorkflowDailyTerminalsStatistic(Resource):
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
|
||||
|
|
@ -142,10 +142,10 @@ WHERE
|
|||
|
||||
|
||||
class WorkflowDailyTokenCostStatistic(Resource):
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
|
||||
|
|
|
|||
|
|
@ -77,6 +77,9 @@ class OAuthCallback(Resource):
|
|||
if state:
|
||||
invite_token = state
|
||||
|
||||
if not code:
|
||||
return {"error": "Authorization code is required"}, 400
|
||||
|
||||
try:
|
||||
token = oauth_provider.get_access_token(code)
|
||||
user_info = oauth_provider.get_user_info(token)
|
||||
|
|
@ -86,7 +89,7 @@ class OAuthCallback(Resource):
|
|||
return {"error": "OAuth process failed"}, 400
|
||||
|
||||
if invite_token and RegisterService.is_valid_invite_token(invite_token):
|
||||
invitation = RegisterService._get_invitation_by_token(token=invite_token)
|
||||
invitation = RegisterService.get_invitation_by_token(token=invite_token)
|
||||
if invitation:
|
||||
invitation_email = invitation.get("email", None)
|
||||
if invitation_email != user_info.email:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import logging
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import reqparse
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
|
|
@ -28,6 +27,8 @@ from extensions.ext_database import db
|
|||
from libs import helper
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
|
@ -57,6 +58,8 @@ class CompletionApi(InstalledAppResource):
|
|||
db.session.commit()
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
|
||||
)
|
||||
|
|
@ -90,6 +93,8 @@ class CompletionStopApi(InstalledAppResource):
|
|||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
|
@ -117,6 +122,8 @@ class ChatApi(InstalledAppResource):
|
|||
db.session.commit()
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||
)
|
||||
|
|
@ -153,6 +160,8 @@ class ChatStopApi(InstalledAppResource):
|
|||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from sqlalchemy.orm import Session
|
||||
|
|
@ -10,6 +9,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
|||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.model import AppMode
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
|
||||
|
|
@ -35,6 +36,8 @@ class ConversationListApi(InstalledAppResource):
|
|||
pinned = args["pinned"] == "true"
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
with Session(db.engine) as session:
|
||||
return WebConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
|
|
@ -58,6 +61,8 @@ class ConversationApi(InstalledAppResource):
|
|||
|
||||
conversation_id = str(c_id)
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
ConversationService.delete(app_model, conversation_id, current_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
|
@ -81,6 +86,8 @@ class ConversationRenameApi(InstalledAppResource):
|
|||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
return ConversationService.rename(
|
||||
app_model, conversation_id, current_user, args["name"], args["auto_generate"]
|
||||
)
|
||||
|
|
@ -98,6 +105,8 @@ class ConversationPinApi(InstalledAppResource):
|
|||
conversation_id = str(c_id)
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
WebConversationService.pin(app_model, conversation_id, current_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
|
@ -113,6 +122,8 @@ class ConversationUnPinApi(InstalledAppResource):
|
|||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
WebConversationService.unpin(app_model, conversation_id, current_user)
|
||||
|
||||
return {"result": "success"}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import logging
|
|||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, inputs, marshal_with, reqparse
|
||||
from sqlalchemy import and_
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
|
|
@ -13,8 +12,8 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
|
|||
from extensions.ext_database import db
|
||||
from fields.installed_app_fields import installed_app_list_fields
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import login_required
|
||||
from models import App, InstalledApp, RecommendedApp
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account, App, InstalledApp, RecommendedApp
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
|
|
@ -29,6 +28,8 @@ class InstalledAppsListApi(Resource):
|
|||
@marshal_with(installed_app_list_fields)
|
||||
def get(self):
|
||||
app_id = request.args.get("app_id", default=None, type=str)
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
current_tenant_id = current_user.current_tenant_id
|
||||
|
||||
if app_id:
|
||||
|
|
@ -40,6 +41,8 @@ class InstalledAppsListApi(Resource):
|
|||
else:
|
||||
installed_apps = db.session.query(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id).all()
|
||||
|
||||
if current_user.current_tenant is None:
|
||||
raise ValueError("current_user.current_tenant must not be None")
|
||||
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
|
||||
installed_app_list: list[dict[str, Any]] = [
|
||||
{
|
||||
|
|
@ -115,6 +118,8 @@ class InstalledAppsListApi(Resource):
|
|||
if recommended_app is None:
|
||||
raise NotFound("App not found")
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
current_tenant_id = current_user.current_tenant_id
|
||||
app = db.session.query(App).where(App.id == args["app_id"]).first()
|
||||
|
||||
|
|
@ -154,6 +159,8 @@ class InstalledAppApi(InstalledAppResource):
|
|||
"""
|
||||
|
||||
def delete(self, installed_app):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
|
||||
raise BadRequest("You can't uninstall an app owned by the current tenant")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import logging
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
|
@ -24,6 +23,8 @@ from core.model_runtime.errors.invoke import InvokeError
|
|||
from fields.message_fields import message_infinite_scroll_pagination_fields
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
|
|
@ -54,6 +55,8 @@ class MessageListApi(InstalledAppResource):
|
|||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
return MessageService.pagination_by_first_id(
|
||||
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
|
||||
)
|
||||
|
|
@ -75,6 +78,8 @@ class MessageFeedbackApi(InstalledAppResource):
|
|||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
MessageService.create_feedback(
|
||||
app_model=app_model,
|
||||
message_id=message_id,
|
||||
|
|
@ -105,6 +110,8 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
|||
streaming = args["response_mode"] == "streaming"
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
response = AppGenerateService.generate_more_like_this(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
|
|
@ -142,6 +149,8 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
|
|||
message_id = str(message_id)
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
questions = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
|
||||
from constants.languages import languages
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from libs.helper import AppIconUrlField
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from services.recommended_app_service import RecommendedAppService
|
||||
|
||||
app_fields = {
|
||||
|
|
@ -46,8 +45,9 @@ class RecommendedAppListApi(Resource):
|
|||
parser.add_argument("language", type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.get("language") and args.get("language") in languages:
|
||||
language_prefix = args.get("language")
|
||||
language = args.get("language")
|
||||
if language and language in languages:
|
||||
language_prefix = language
|
||||
elif current_user and current_user.interface_language:
|
||||
language_prefix = current_user.interface_language
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import fields, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
|
@ -8,6 +7,8 @@ from controllers.console.explore.error import NotCompletionAppError
|
|||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
|
|
@ -42,6 +43,8 @@ class SavedMessageListApi(InstalledAppResource):
|
|||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])
|
||||
|
||||
def post(self, installed_app):
|
||||
|
|
@ -54,6 +57,8 @@ class SavedMessageListApi(InstalledAppResource):
|
|||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
SavedMessageService.save(app_model, current_user, args["message_id"])
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
|
@ -70,6 +75,8 @@ class SavedMessageApi(InstalledAppResource):
|
|||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
SavedMessageService.delete(app_model, current_user, message_id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from controllers.console.wraps import (
|
|||
)
|
||||
from fields.file_fields import file_fields, upload_config_fields
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.file_service import FileService
|
||||
|
||||
PREVIEW_WORDS_LIMIT = 3000
|
||||
|
|
@ -68,6 +69,8 @@ class FileApi(Resource):
|
|||
source = None
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
upload_file = FileService.upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
|
|
|
|||
|
|
@ -34,14 +34,14 @@ class VersionApi(Resource):
|
|||
return result
|
||||
|
||||
try:
|
||||
response = requests.get(check_update_url, {"current_version": args.get("current_version")}, timeout=(3, 10))
|
||||
response = requests.get(check_update_url, {"current_version": args["current_version"]}, timeout=(3, 10))
|
||||
except Exception as error:
|
||||
logger.warning("Check update version error: %s.", str(error))
|
||||
result["version"] = args.get("current_version")
|
||||
result["version"] = args["current_version"]
|
||||
return result
|
||||
|
||||
content = json.loads(response.content)
|
||||
if _has_new_version(latest_version=content["version"], current_version=f"{args.get('current_version')}"):
|
||||
if _has_new_version(latest_version=content["version"], current_version=f"{args['current_version']}"):
|
||||
result["version"] = content["version"]
|
||||
result["release_date"] = content["releaseDate"]
|
||||
result["release_notes"] = content["releaseNotes"]
|
||||
|
|
|
|||
|
|
@ -49,6 +49,8 @@ class AccountInitApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
def post(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
account = current_user
|
||||
|
||||
if account.status == "active":
|
||||
|
|
@ -102,6 +104,8 @@ class AccountProfileApi(Resource):
|
|||
@marshal_with(account_fields)
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
return current_user
|
||||
|
||||
|
||||
|
|
@ -111,6 +115,8 @@ class AccountNameApi(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
def post(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
|
@ -130,6 +136,8 @@ class AccountAvatarApi(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
def post(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("avatar", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
|
@ -145,6 +153,8 @@ class AccountInterfaceLanguageApi(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
def post(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("interface_language", type=supported_language, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
|
@ -160,6 +170,8 @@ class AccountInterfaceThemeApi(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
def post(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
|
@ -175,6 +187,8 @@ class AccountTimezoneApi(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
def post(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("timezone", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
|
@ -194,6 +208,8 @@ class AccountPasswordApi(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
def post(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("password", type=str, required=False, location="json")
|
||||
parser.add_argument("new_password", type=str, required=True, location="json")
|
||||
|
|
@ -228,6 +244,8 @@ class AccountIntegrateApi(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(integrate_list_fields)
|
||||
def get(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
account = current_user
|
||||
|
||||
account_integrates = db.session.query(AccountIntegrate).where(AccountIntegrate.account_id == account.id).all()
|
||||
|
|
@ -268,6 +286,8 @@ class AccountDeleteVerifyApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
account = current_user
|
||||
|
||||
token, code = AccountService.generate_account_deletion_verification_code(account)
|
||||
|
|
@ -281,6 +301,8 @@ class AccountDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -321,6 +343,8 @@ class EducationVerifyApi(Resource):
|
|||
@cloud_edition_billing_enabled
|
||||
@marshal_with(verify_fields)
|
||||
def get(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
account = current_user
|
||||
|
||||
return BillingService.EducationIdentity.verify(account.id, account.email)
|
||||
|
|
@ -340,6 +364,8 @@ class EducationApi(Resource):
|
|||
@only_edition_cloud
|
||||
@cloud_edition_billing_enabled
|
||||
def post(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -357,6 +383,8 @@ class EducationApi(Resource):
|
|||
@cloud_edition_billing_enabled
|
||||
@marshal_with(status_fields)
|
||||
def get(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
account = current_user
|
||||
|
||||
res = BillingService.EducationIdentity.status(account.id)
|
||||
|
|
@ -421,6 +449,8 @@ class ChangeEmailSendEmailApi(Resource):
|
|||
raise InvalidTokenError()
|
||||
user_email = reset_data.get("email", "")
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
if user_email != current_user.email:
|
||||
raise InvalidEmailError()
|
||||
else:
|
||||
|
|
@ -501,6 +531,8 @@ class ChangeEmailResetApi(Resource):
|
|||
AccountService.revoke_change_email_token(args["token"])
|
||||
|
||||
old_email = reset_data.get("old_email", "")
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
if current_user.email != old_email:
|
||||
raise AccountNotFound()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from urllib import parse
|
||||
|
||||
from flask import request
|
||||
from flask import abort, request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, abort, marshal_with, reqparse
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
|
|
@ -41,6 +41,10 @@ class MemberListApi(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(account_with_role_list_fields)
|
||||
def get(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
return {"result": "success", "accounts": members}, 200
|
||||
|
||||
|
|
@ -65,7 +69,11 @@ class MemberInviteEmailApi(Resource):
|
|||
if not TenantAccountRole.is_non_owner_role(invitee_role):
|
||||
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
inviter = current_user
|
||||
if not inviter.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
invitation_results = []
|
||||
console_web_url = dify_config.CONSOLE_WEB_URL
|
||||
|
||||
|
|
@ -76,6 +84,8 @@ class MemberInviteEmailApi(Resource):
|
|||
|
||||
for invitee_email in invitee_emails:
|
||||
try:
|
||||
if not inviter.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
token = RegisterService.invite_new_member(
|
||||
inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
|
||||
)
|
||||
|
|
@ -97,7 +107,7 @@ class MemberInviteEmailApi(Resource):
|
|||
return {
|
||||
"result": "success",
|
||||
"invitation_results": invitation_results,
|
||||
"tenant_id": str(current_user.current_tenant.id),
|
||||
"tenant_id": str(inviter.current_tenant.id) if inviter.current_tenant else "",
|
||||
}, 201
|
||||
|
||||
|
||||
|
|
@ -108,6 +118,10 @@ class MemberCancelInviteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, member_id):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
member = db.session.query(Account).where(Account.id == str(member_id)).first()
|
||||
if member is None:
|
||||
abort(404)
|
||||
|
|
@ -123,7 +137,10 @@ class MemberCancelInviteApi(Resource):
|
|||
except Exception as e:
|
||||
raise ValueError(str(e))
|
||||
|
||||
return {"result": "success", "tenant_id": str(current_user.current_tenant.id)}, 200
|
||||
return {
|
||||
"result": "success",
|
||||
"tenant_id": str(current_user.current_tenant.id) if current_user.current_tenant else "",
|
||||
}, 200
|
||||
|
||||
|
||||
class MemberUpdateRoleApi(Resource):
|
||||
|
|
@ -141,6 +158,10 @@ class MemberUpdateRoleApi(Resource):
|
|||
if not TenantAccountRole.is_valid_role(new_role):
|
||||
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
member = db.session.get(Account, str(member_id))
|
||||
if not member:
|
||||
abort(404)
|
||||
|
|
@ -164,6 +185,10 @@ class DatasetOperatorMemberListApi(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(account_with_role_list_fields)
|
||||
def get(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
|
||||
return {"result": "success", "accounts": members}, 200
|
||||
|
||||
|
|
@ -184,6 +209,10 @@ class SendOwnerTransferEmailApi(Resource):
|
|||
raise EmailSendIpLimitError()
|
||||
|
||||
# check if the current user is the owner of the workspace
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||
raise NotOwnerError()
|
||||
|
||||
|
|
@ -198,7 +227,7 @@ class SendOwnerTransferEmailApi(Resource):
|
|||
account=current_user,
|
||||
email=email,
|
||||
language=language,
|
||||
workspace_name=current_user.current_tenant.name,
|
||||
workspace_name=current_user.current_tenant.name if current_user.current_tenant else "",
|
||||
)
|
||||
|
||||
return {"result": "success", "data": token}
|
||||
|
|
@ -215,6 +244,10 @@ class OwnerTransferCheckApi(Resource):
|
|||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
# check if the current user is the owner of the workspace
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||
raise NotOwnerError()
|
||||
|
||||
|
|
@ -256,6 +289,10 @@ class OwnerTransfer(Resource):
|
|||
args = parser.parse_args()
|
||||
|
||||
# check if the current user is the owner of the workspace
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||
raise NotOwnerError()
|
||||
|
||||
|
|
@ -274,9 +311,11 @@ class OwnerTransfer(Resource):
|
|||
member = db.session.get(Account, str(member_id))
|
||||
if not member:
|
||||
abort(404)
|
||||
else:
|
||||
member_account = member
|
||||
if not TenantService.is_member(member_account, current_user.current_tenant):
|
||||
return # Never reached, but helps type checker
|
||||
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
if not TenantService.is_member(member, current_user.current_tenant):
|
||||
raise MemberNotInTenantError()
|
||||
|
||||
try:
|
||||
|
|
@ -286,13 +325,13 @@ class OwnerTransfer(Resource):
|
|||
AccountService.send_new_owner_transfer_notify_email(
|
||||
account=member,
|
||||
email=member.email,
|
||||
workspace_name=current_user.current_tenant.name,
|
||||
workspace_name=current_user.current_tenant.name if current_user.current_tenant else "",
|
||||
)
|
||||
|
||||
AccountService.send_old_owner_transfer_notify_email(
|
||||
account=current_user,
|
||||
email=current_user.email,
|
||||
workspace_name=current_user.current_tenant.name,
|
||||
workspace_name=current_user.current_tenant.name if current_user.current_tenant else "",
|
||||
new_owner_email=member.email,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import StrLen, uuid_value
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from services.billing_service import BillingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
|
@ -21,6 +22,10 @@ class ModelProviderListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -45,6 +50,10 @@ class ModelProviderCredentialApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
tenant_id = current_user.current_tenant_id
|
||||
# if credential_id is not provided, return current used credential
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -62,6 +71,8 @@ class ModelProviderCredentialApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -72,6 +83,8 @@ class ModelProviderCredentialApi(Resource):
|
|||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
try:
|
||||
model_provider_service.create_provider_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
|
|
@ -88,6 +101,8 @@ class ModelProviderCredentialApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, provider: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -99,6 +114,8 @@ class ModelProviderCredentialApi(Resource):
|
|||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
try:
|
||||
model_provider_service.update_provider_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
|
|
@ -116,12 +133,16 @@ class ModelProviderCredentialApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_provider_credential(
|
||||
tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"]
|
||||
|
|
@ -135,12 +156,16 @@ class ModelProviderCredentialSwitchApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
service = ModelProviderService()
|
||||
service.switch_active_provider_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
|
|
@ -155,10 +180,14 @@ class ModelProviderValidateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
|
@ -205,9 +234,13 @@ class PreferredProviderTypeUpdateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -236,7 +269,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
|
|||
def get(self, provider: str):
|
||||
if provider != "anthropic":
|
||||
raise ValueError(f"provider name {provider} is invalid")
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
data = BillingService.get_model_provider_payment_link(
|
||||
provider_name=provider,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ from controllers.console.wraps import (
|
|||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import login_required
|
||||
from models.account import Tenant, TenantStatus
|
||||
from models.account import Account, Tenant, TenantStatus
|
||||
from services.account_service import TenantService
|
||||
from services.feature_service import FeatureService
|
||||
from services.file_service import FileService
|
||||
|
|
@ -70,6 +70,8 @@ class TenantListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
tenants = TenantService.get_join_tenants(current_user)
|
||||
tenant_dicts = []
|
||||
|
||||
|
|
@ -83,7 +85,7 @@ class TenantListApi(Resource):
|
|||
"status": tenant.status,
|
||||
"created_at": tenant.created_at,
|
||||
"plan": features.billing.subscription.plan if features.billing.enabled else "sandbox",
|
||||
"current": tenant.id == current_user.current_tenant_id,
|
||||
"current": tenant.id == current_user.current_tenant_id if current_user.current_tenant_id else False,
|
||||
}
|
||||
|
||||
tenant_dicts.append(tenant_dict)
|
||||
|
|
@ -125,7 +127,11 @@ class TenantApi(Resource):
|
|||
if request.path == "/info":
|
||||
logger.warning("Deprecated URL /info was used.")
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
tenant = current_user.current_tenant
|
||||
if not tenant:
|
||||
raise ValueError("No current tenant")
|
||||
|
||||
if tenant.status == TenantStatus.ARCHIVE:
|
||||
tenants = TenantService.get_join_tenants(current_user)
|
||||
|
|
@ -137,6 +143,8 @@ class TenantApi(Resource):
|
|||
else:
|
||||
raise Unauthorized("workspace is archived")
|
||||
|
||||
if not tenant:
|
||||
raise ValueError("No tenant available")
|
||||
return WorkspaceService.get_tenant_info(tenant), 200
|
||||
|
||||
|
||||
|
|
@ -145,6 +153,8 @@ class SwitchWorkspaceApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("tenant_id", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
|
@ -168,11 +178,15 @@ class CustomConfigWorkspaceApi(Resource):
|
|||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("workspace_custom")
|
||||
def post(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("remove_webapp_brand", type=bool, location="json")
|
||||
parser.add_argument("replace_webapp_logo", type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
|
||||
|
||||
custom_config_dict = {
|
||||
|
|
@ -194,6 +208,8 @@ class WebappLogoWorkspaceApi(Resource):
|
|||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("workspace_custom")
|
||||
def post(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
# check file
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
|
@ -232,10 +248,14 @@ class WorkspaceInfoApi(Resource):
|
|||
@account_initialization_required
|
||||
# Change workspace name
|
||||
def post(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
|
||||
tenant.name = args["name"]
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -15,6 +15,6 @@ api = ExternalApi(
|
|||
|
||||
files_ns = Namespace("files", description="File operations", path="/")
|
||||
|
||||
from . import image_preview, tool_files, upload
|
||||
from . import image_preview, tool_files, upload # pyright: ignore[reportUnusedImport]
|
||||
|
||||
api.add_namespace(files_ns)
|
||||
|
|
|
|||
|
|
@ -16,8 +16,8 @@ api = ExternalApi(
|
|||
# Create namespace
|
||||
inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/")
|
||||
|
||||
from . import mail
|
||||
from .plugin import plugin
|
||||
from .workspace import workspace
|
||||
from . import mail as _mail # pyright: ignore[reportUnusedImport]
|
||||
from .plugin import plugin as _plugin # pyright: ignore[reportUnusedImport]
|
||||
from .workspace import workspace as _workspace # pyright: ignore[reportUnusedImport]
|
||||
|
||||
api.add_namespace(inner_api_ns)
|
||||
|
|
|
|||
|
|
@ -37,9 +37,9 @@ from models.model import EndUser
|
|||
|
||||
@inner_api_ns.route("/invoke/llm")
|
||||
class PluginInvokeLLMApi(Resource):
|
||||
@get_user_tenant
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeLLM)
|
||||
@inner_api_ns.doc("plugin_invoke_llm")
|
||||
@inner_api_ns.doc(description="Invoke LLM models through plugin interface")
|
||||
|
|
@ -60,9 +60,9 @@ class PluginInvokeLLMApi(Resource):
|
|||
|
||||
@inner_api_ns.route("/invoke/llm/structured-output")
|
||||
class PluginInvokeLLMWithStructuredOutputApi(Resource):
|
||||
@get_user_tenant
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeLLMWithStructuredOutput)
|
||||
@inner_api_ns.doc("plugin_invoke_llm_structured")
|
||||
@inner_api_ns.doc(description="Invoke LLM models with structured output through plugin interface")
|
||||
|
|
@ -85,9 +85,9 @@ class PluginInvokeLLMWithStructuredOutputApi(Resource):
|
|||
|
||||
@inner_api_ns.route("/invoke/text-embedding")
|
||||
class PluginInvokeTextEmbeddingApi(Resource):
|
||||
@get_user_tenant
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeTextEmbedding)
|
||||
@inner_api_ns.doc("plugin_invoke_text_embedding")
|
||||
@inner_api_ns.doc(description="Invoke text embedding models through plugin interface")
|
||||
|
|
@ -115,9 +115,9 @@ class PluginInvokeTextEmbeddingApi(Resource):
|
|||
|
||||
@inner_api_ns.route("/invoke/rerank")
|
||||
class PluginInvokeRerankApi(Resource):
|
||||
@get_user_tenant
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeRerank)
|
||||
@inner_api_ns.doc("plugin_invoke_rerank")
|
||||
@inner_api_ns.doc(description="Invoke rerank models through plugin interface")
|
||||
|
|
@ -141,9 +141,9 @@ class PluginInvokeRerankApi(Resource):
|
|||
|
||||
@inner_api_ns.route("/invoke/tts")
|
||||
class PluginInvokeTTSApi(Resource):
|
||||
@get_user_tenant
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeTTS)
|
||||
@inner_api_ns.doc("plugin_invoke_tts")
|
||||
@inner_api_ns.doc(description="Invoke text-to-speech models through plugin interface")
|
||||
|
|
@ -168,9 +168,9 @@ class PluginInvokeTTSApi(Resource):
|
|||
|
||||
@inner_api_ns.route("/invoke/speech2text")
|
||||
class PluginInvokeSpeech2TextApi(Resource):
|
||||
@get_user_tenant
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeSpeech2Text)
|
||||
@inner_api_ns.doc("plugin_invoke_speech2text")
|
||||
@inner_api_ns.doc(description="Invoke speech-to-text models through plugin interface")
|
||||
|
|
@ -194,9 +194,9 @@ class PluginInvokeSpeech2TextApi(Resource):
|
|||
|
||||
@inner_api_ns.route("/invoke/moderation")
|
||||
class PluginInvokeModerationApi(Resource):
|
||||
@get_user_tenant
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeModeration)
|
||||
@inner_api_ns.doc("plugin_invoke_moderation")
|
||||
@inner_api_ns.doc(description="Invoke moderation models through plugin interface")
|
||||
|
|
@ -220,9 +220,9 @@ class PluginInvokeModerationApi(Resource):
|
|||
|
||||
@inner_api_ns.route("/invoke/tool")
|
||||
class PluginInvokeToolApi(Resource):
|
||||
@get_user_tenant
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeTool)
|
||||
@inner_api_ns.doc("plugin_invoke_tool")
|
||||
@inner_api_ns.doc(description="Invoke tools through plugin interface")
|
||||
|
|
@ -252,9 +252,9 @@ class PluginInvokeToolApi(Resource):
|
|||
|
||||
@inner_api_ns.route("/invoke/parameter-extractor")
|
||||
class PluginInvokeParameterExtractorNodeApi(Resource):
|
||||
@get_user_tenant
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeParameterExtractorNode)
|
||||
@inner_api_ns.doc("plugin_invoke_parameter_extractor")
|
||||
@inner_api_ns.doc(description="Invoke parameter extractor node through plugin interface")
|
||||
|
|
@ -285,9 +285,9 @@ class PluginInvokeParameterExtractorNodeApi(Resource):
|
|||
|
||||
@inner_api_ns.route("/invoke/question-classifier")
|
||||
class PluginInvokeQuestionClassifierNodeApi(Resource):
|
||||
@get_user_tenant
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeQuestionClassifierNode)
|
||||
@inner_api_ns.doc("plugin_invoke_question_classifier")
|
||||
@inner_api_ns.doc(description="Invoke question classifier node through plugin interface")
|
||||
|
|
@ -318,9 +318,9 @@ class PluginInvokeQuestionClassifierNodeApi(Resource):
|
|||
|
||||
@inner_api_ns.route("/invoke/app")
|
||||
class PluginInvokeAppApi(Resource):
|
||||
@get_user_tenant
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeApp)
|
||||
@inner_api_ns.doc("plugin_invoke_app")
|
||||
@inner_api_ns.doc(description="Invoke application through plugin interface")
|
||||
|
|
@ -348,9 +348,9 @@ class PluginInvokeAppApi(Resource):
|
|||
|
||||
@inner_api_ns.route("/invoke/encrypt")
|
||||
class PluginInvokeEncryptApi(Resource):
|
||||
@get_user_tenant
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeEncrypt)
|
||||
@inner_api_ns.doc("plugin_invoke_encrypt")
|
||||
@inner_api_ns.doc(description="Encrypt or decrypt data through plugin interface")
|
||||
|
|
@ -375,9 +375,9 @@ class PluginInvokeEncryptApi(Resource):
|
|||
|
||||
@inner_api_ns.route("/invoke/summary")
|
||||
class PluginInvokeSummaryApi(Resource):
|
||||
@get_user_tenant
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeSummary)
|
||||
@inner_api_ns.doc("plugin_invoke_summary")
|
||||
@inner_api_ns.doc(description="Invoke summary functionality through plugin interface")
|
||||
|
|
@ -405,9 +405,9 @@ class PluginInvokeSummaryApi(Resource):
|
|||
|
||||
@inner_api_ns.route("/upload/file/request")
|
||||
class PluginUploadFileRequestApi(Resource):
|
||||
@get_user_tenant
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestRequestUploadFile)
|
||||
@inner_api_ns.doc("plugin_upload_file_request")
|
||||
@inner_api_ns.doc(description="Request signed URL for file upload through plugin interface")
|
||||
|
|
@ -426,9 +426,9 @@ class PluginUploadFileRequestApi(Resource):
|
|||
|
||||
@inner_api_ns.route("/fetch/app/info")
|
||||
class PluginFetchAppInfoApi(Resource):
|
||||
@get_user_tenant
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestFetchAppInfo)
|
||||
@inner_api_ns.doc("plugin_fetch_app_info")
|
||||
@inner_api_ns.doc(description="Fetch application information through plugin interface")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Optional, ParamSpec, TypeVar
|
||||
from typing import Optional, ParamSpec, TypeVar, cast
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
|
|
@ -10,7 +10,7 @@ from sqlalchemy.orm import Session
|
|||
|
||||
from core.file.constants import DEFAULT_SERVICE_API_USER_ID
|
||||
from extensions.ext_database import db
|
||||
from libs.login import _get_user
|
||||
from libs.login import current_user
|
||||
from models.account import Tenant
|
||||
from models.model import EndUser
|
||||
|
||||
|
|
@ -66,8 +66,8 @@ def get_user_tenant(view: Optional[Callable[P, R]] = None):
|
|||
|
||||
p = parser.parse_args()
|
||||
|
||||
user_id: Optional[str] = p.get("user_id")
|
||||
tenant_id: str = p.get("tenant_id")
|
||||
user_id = cast(str, p.get("user_id"))
|
||||
tenant_id = cast(str, p.get("tenant_id"))
|
||||
|
||||
if not tenant_id:
|
||||
raise ValueError("tenant_id is required")
|
||||
|
|
@ -98,7 +98,7 @@ def get_user_tenant(view: Optional[Callable[P, R]] = None):
|
|||
kwargs["user_model"] = user
|
||||
|
||||
current_app.login_manager._update_request_context_with_user(user) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,6 @@ api = ExternalApi(
|
|||
|
||||
mcp_ns = Namespace("mcp", description="MCP operations", path="/")
|
||||
|
||||
from . import mcp
|
||||
from . import mcp # pyright: ignore[reportUnusedImport]
|
||||
|
||||
api.add_namespace(mcp_ns)
|
||||
|
|
|
|||
|
|
@ -15,9 +15,27 @@ api = ExternalApi(
|
|||
|
||||
service_api_ns = Namespace("service_api", description="Service operations", path="/")
|
||||
|
||||
from . import index
|
||||
from .app import annotation, app, audio, completion, conversation, file, file_preview, message, site, workflow
|
||||
from .dataset import dataset, document, hit_testing, metadata, segment, upload_file
|
||||
from .workspace import models
|
||||
from . import index # pyright: ignore[reportUnusedImport]
|
||||
from .app import (
|
||||
annotation, # pyright: ignore[reportUnusedImport]
|
||||
app, # pyright: ignore[reportUnusedImport]
|
||||
audio, # pyright: ignore[reportUnusedImport]
|
||||
completion, # pyright: ignore[reportUnusedImport]
|
||||
conversation, # pyright: ignore[reportUnusedImport]
|
||||
file, # pyright: ignore[reportUnusedImport]
|
||||
file_preview, # pyright: ignore[reportUnusedImport]
|
||||
message, # pyright: ignore[reportUnusedImport]
|
||||
site, # pyright: ignore[reportUnusedImport]
|
||||
workflow, # pyright: ignore[reportUnusedImport]
|
||||
)
|
||||
from .dataset import (
|
||||
dataset, # pyright: ignore[reportUnusedImport]
|
||||
document, # pyright: ignore[reportUnusedImport]
|
||||
hit_testing, # pyright: ignore[reportUnusedImport]
|
||||
metadata, # pyright: ignore[reportUnusedImport]
|
||||
segment, # pyright: ignore[reportUnusedImport]
|
||||
upload_file, # pyright: ignore[reportUnusedImport]
|
||||
)
|
||||
from .workspace import models # pyright: ignore[reportUnusedImport]
|
||||
|
||||
api.add_namespace(service_api_ns)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from flask_restx import Resource, reqparse
|
||||
from flask_restx._http import HTTPStatus
|
||||
from flask_restx.inputs import int_range
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
|
@ -121,7 +122,7 @@ class ConversationDetailApi(Resource):
|
|||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
@service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=204)
|
||||
@service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=HTTPStatus.NO_CONTENT)
|
||||
def delete(self, app_model: App, end_user: EndUser, c_id):
|
||||
"""Delete a specific conversation."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from extensions.ext_database import db
|
|||
from fields.document_fields import document_fields, document_status_fields
|
||||
from libs.login import current_user
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.model import EndUser
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
|
||||
from services.file_service import FileService
|
||||
|
|
@ -298,6 +299,9 @@ class DocumentAddByFileApi(DatasetApiResource):
|
|||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
if not isinstance(current_user, EndUser):
|
||||
raise ValueError("Invalid user account")
|
||||
|
||||
upload_file = FileService.upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
|
|
@ -387,6 +391,8 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
|||
raise FilenameNotExistsError
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, EndUser):
|
||||
raise ValueError("Invalid user account")
|
||||
upload_file = FileService.upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from core.file.constants import DEFAULT_SERVICE_API_USER_ID
|
|||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import _get_user
|
||||
from libs.login import current_user
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
|
||||
from models.dataset import Dataset, RateLimitLog
|
||||
from models.model import ApiToken, App, EndUser
|
||||
|
|
@ -210,7 +210,7 @@ def validate_dataset_token(view: Optional[Callable[Concatenate[T, P], R]] = None
|
|||
if account:
|
||||
account.current_tenant = tenant
|
||||
current_app.login_manager._update_request_context_with_user(account) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
|
||||
else:
|
||||
raise Unauthorized("Tenant owner account does not exist.")
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -17,20 +17,20 @@ api = ExternalApi(
|
|||
web_ns = Namespace("web", description="Web application API operations", path="/")
|
||||
|
||||
from . import (
|
||||
app,
|
||||
audio,
|
||||
completion,
|
||||
conversation,
|
||||
feature,
|
||||
files,
|
||||
forgot_password,
|
||||
login,
|
||||
message,
|
||||
passport,
|
||||
remote_files,
|
||||
saved_message,
|
||||
site,
|
||||
workflow,
|
||||
app, # pyright: ignore[reportUnusedImport]
|
||||
audio, # pyright: ignore[reportUnusedImport]
|
||||
completion, # pyright: ignore[reportUnusedImport]
|
||||
conversation, # pyright: ignore[reportUnusedImport]
|
||||
feature, # pyright: ignore[reportUnusedImport]
|
||||
files, # pyright: ignore[reportUnusedImport]
|
||||
forgot_password, # pyright: ignore[reportUnusedImport]
|
||||
login, # pyright: ignore[reportUnusedImport]
|
||||
message, # pyright: ignore[reportUnusedImport]
|
||||
passport, # pyright: ignore[reportUnusedImport]
|
||||
remote_files, # pyright: ignore[reportUnusedImport]
|
||||
saved_message, # pyright: ignore[reportUnusedImport]
|
||||
site, # pyright: ignore[reportUnusedImport]
|
||||
workflow, # pyright: ignore[reportUnusedImport]
|
||||
)
|
||||
|
||||
api.add_namespace(web_ns)
|
||||
|
|
|
|||
|
|
@ -1 +0,0 @@
|
|||
import core.moderation.base
|
||||
|
|
@ -72,6 +72,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
|||
function_call_state = True
|
||||
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
|
||||
final_answer = ""
|
||||
prompt_messages: list = [] # Initialize prompt_messages
|
||||
agent_thought_id = "" # Initialize agent_thought_id
|
||||
|
||||
def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
|
||||
if not final_llm_usage_dict["usage"]:
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
function_call_state = True
|
||||
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
|
||||
final_answer = ""
|
||||
prompt_messages: list = [] # Initialize prompt_messages
|
||||
|
||||
# get tracing instance
|
||||
trace_manager = app_generate_entity.trace_manager
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ class SensitiveWordAvoidanceConfigManager:
|
|||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(
|
||||
cls, tenant_id, config: dict, only_structure_validate: bool = False
|
||||
cls, tenant_id: str, config: dict, only_structure_validate: bool = False
|
||||
) -> tuple[dict, list[str]]:
|
||||
if not config.get("sensitive_word_avoidance"):
|
||||
config["sensitive_word_avoidance"] = {"enabled": False}
|
||||
|
|
@ -38,7 +38,14 @@ class SensitiveWordAvoidanceConfigManager:
|
|||
|
||||
if not only_structure_validate:
|
||||
typ = config["sensitive_word_avoidance"]["type"]
|
||||
sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"]
|
||||
if not isinstance(typ, str):
|
||||
raise ValueError("sensitive_word_avoidance.type must be a string")
|
||||
|
||||
sensitive_word_avoidance_config = config["sensitive_word_avoidance"].get("config")
|
||||
if sensitive_word_avoidance_config is None:
|
||||
sensitive_word_avoidance_config = {}
|
||||
if not isinstance(sensitive_word_avoidance_config, dict):
|
||||
raise ValueError("sensitive_word_avoidance.config must be a dict")
|
||||
|
||||
ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config)
|
||||
|
||||
|
|
|
|||
|
|
@ -25,10 +25,14 @@ class PromptTemplateConfigManager:
|
|||
if chat_prompt_config:
|
||||
chat_prompt_messages = []
|
||||
for message in chat_prompt_config.get("prompt", []):
|
||||
text = message.get("text")
|
||||
if not isinstance(text, str):
|
||||
raise ValueError("message text must be a string")
|
||||
role = message.get("role")
|
||||
if not isinstance(role, str):
|
||||
raise ValueError("message role must be a string")
|
||||
chat_prompt_messages.append(
|
||||
AdvancedChatMessageEntity(
|
||||
**{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
|
||||
)
|
||||
AdvancedChatMessageEntity(text=text, role=PromptMessageRole.value_of(role))
|
||||
)
|
||||
|
||||
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
response_chunk: dict[str, Any] = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"conversation_id": chunk.conversation_id,
|
||||
"message_id": chunk.message_id,
|
||||
|
|
@ -82,7 +82,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
|
|
@ -102,7 +102,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
response_chunk: dict[str, Any] = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"conversation_id": chunk.conversation_id,
|
||||
"message_id": chunk.message_id,
|
||||
|
|
@ -110,7 +110,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
}
|
||||
|
||||
if isinstance(sub_stream_response, MessageEndStreamResponse):
|
||||
sub_stream_response_dict = sub_stream_response.to_dict()
|
||||
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
|
|
@ -118,8 +118,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
|
||||
yield response_chunk
|
||||
|
|
|
|||
|
|
@ -169,7 +169,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
|
||||
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
||||
|
||||
if self._base_task_pipeline._stream:
|
||||
if self._base_task_pipeline.stream:
|
||||
return self._to_stream_response(generator)
|
||||
else:
|
||||
return self._to_blocking_response(generator)
|
||||
|
|
@ -297,13 +297,13 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
|
||||
def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
|
||||
"""Handle ping events."""
|
||||
yield self._base_task_pipeline._ping_stream_response()
|
||||
yield self._base_task_pipeline.ping_stream_response()
|
||||
|
||||
def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
|
||||
"""Handle error events."""
|
||||
with self._database_session() as session:
|
||||
err = self._base_task_pipeline._handle_error(event=event, session=session, message_id=self._message_id)
|
||||
yield self._base_task_pipeline._error_to_stream_response(err)
|
||||
err = self._base_task_pipeline.handle_error(event=event, session=session, message_id=self._message_id)
|
||||
yield self._base_task_pipeline.error_to_stream_response(err)
|
||||
|
||||
def _handle_workflow_started_event(self, *args, **kwargs) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle workflow started events."""
|
||||
|
|
@ -594,10 +594,10 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
workflow_execution=workflow_execution,
|
||||
)
|
||||
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}"))
|
||||
err = self._base_task_pipeline._handle_error(event=err_event, session=session, message_id=self._message_id)
|
||||
err = self._base_task_pipeline.handle_error(event=err_event, session=session, message_id=self._message_id)
|
||||
|
||||
yield workflow_finish_resp
|
||||
yield self._base_task_pipeline._error_to_stream_response(err)
|
||||
yield self._base_task_pipeline.error_to_stream_response(err)
|
||||
|
||||
def _handle_stop_event(
|
||||
self,
|
||||
|
|
@ -650,7 +650,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
"""Handle advanced chat message end events."""
|
||||
self._ensure_graph_runtime_initialized(graph_runtime_state)
|
||||
|
||||
output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
|
||||
output_moderation_answer = self._base_task_pipeline.handle_output_moderation_when_task_finished(
|
||||
self._task_state.answer
|
||||
)
|
||||
if output_moderation_answer:
|
||||
|
|
@ -846,7 +846,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
|
||||
message.answer = answer_text
|
||||
message.updated_at = naive_utc_now()
|
||||
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
|
||||
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline.start_at
|
||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||
message_files = [
|
||||
MessageFile(
|
||||
|
|
@ -902,9 +902,9 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
:param text: text
|
||||
:return: True if output moderation should direct output, otherwise False
|
||||
"""
|
||||
if self._base_task_pipeline._output_moderation_handler:
|
||||
if self._base_task_pipeline._output_moderation_handler.should_direct_output():
|
||||
self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
|
||||
if self._base_task_pipeline.output_moderation_handler:
|
||||
if self._base_task_pipeline.output_moderation_handler.should_direct_output():
|
||||
self._task_state.answer = self._base_task_pipeline.output_moderation_handler.get_final_output()
|
||||
self._base_task_pipeline.queue_manager.publish(
|
||||
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
|
|
@ -914,7 +914,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
)
|
||||
return True
|
||||
else:
|
||||
self._base_task_pipeline._output_moderation_handler.append_new_token(text)
|
||||
self._base_task_pipeline.output_moderation_handler.append_new_token(text)
|
||||
|
||||
return False
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from core.agent.entities import AgentEntity
|
||||
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
|
||||
|
|
@ -160,7 +160,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
|||
return filtered_config
|
||||
|
||||
@classmethod
|
||||
def validate_agent_mode_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_agent_mode_and_set_defaults(
|
||||
cls, tenant_id: str, config: dict[str, Any]
|
||||
) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate agent_mode and set defaults for agent feature
|
||||
|
||||
|
|
@ -170,30 +172,32 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
|||
if not config.get("agent_mode"):
|
||||
config["agent_mode"] = {"enabled": False, "tools": []}
|
||||
|
||||
if not isinstance(config["agent_mode"], dict):
|
||||
agent_mode = config["agent_mode"]
|
||||
if not isinstance(agent_mode, dict):
|
||||
raise ValueError("agent_mode must be of object type")
|
||||
|
||||
if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]:
|
||||
config["agent_mode"]["enabled"] = False
|
||||
# FIXME(-LAN-): Cast needed due to basedpyright limitation with dict type narrowing
|
||||
agent_mode = cast(dict[str, Any], agent_mode)
|
||||
|
||||
if not isinstance(config["agent_mode"]["enabled"], bool):
|
||||
if "enabled" not in agent_mode or not agent_mode["enabled"]:
|
||||
agent_mode["enabled"] = False
|
||||
|
||||
if not isinstance(agent_mode["enabled"], bool):
|
||||
raise ValueError("enabled in agent_mode must be of boolean type")
|
||||
|
||||
if not config["agent_mode"].get("strategy"):
|
||||
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
|
||||
if not agent_mode.get("strategy"):
|
||||
agent_mode["strategy"] = PlanningStrategy.ROUTER.value
|
||||
|
||||
if config["agent_mode"]["strategy"] not in [
|
||||
member.value for member in list(PlanningStrategy.__members__.values())
|
||||
]:
|
||||
if agent_mode["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]:
|
||||
raise ValueError("strategy in agent_mode must be in the specified strategy list")
|
||||
|
||||
if not config["agent_mode"].get("tools"):
|
||||
config["agent_mode"]["tools"] = []
|
||||
if not agent_mode.get("tools"):
|
||||
agent_mode["tools"] = []
|
||||
|
||||
if not isinstance(config["agent_mode"]["tools"], list):
|
||||
if not isinstance(agent_mode["tools"], list):
|
||||
raise ValueError("tools in agent_mode must be a list of objects")
|
||||
|
||||
for tool in config["agent_mode"]["tools"]:
|
||||
for tool in agent_mode["tools"]:
|
||||
key = list(tool.keys())[0]
|
||||
if key in OLD_TOOLS:
|
||||
# old style, use tool name as key
|
||||
|
|
|
|||
|
|
@ -46,7 +46,10 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
response = cls.convert_blocking_full_response(blocking_response)
|
||||
|
||||
metadata = response.get("metadata", {})
|
||||
response["metadata"] = cls._get_simple_metadata(metadata)
|
||||
if isinstance(metadata, dict):
|
||||
response["metadata"] = cls._get_simple_metadata(metadata)
|
||||
else:
|
||||
response["metadata"] = {}
|
||||
|
||||
return response
|
||||
|
||||
|
|
@ -78,7 +81,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
|
|
@ -106,7 +109,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
}
|
||||
|
||||
if isinstance(sub_stream_response, MessageEndStreamResponse):
|
||||
sub_stream_response_dict = sub_stream_response.to_dict()
|
||||
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
|
|
@ -114,6 +117,6 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
|
||||
yield response_chunk
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ class AppQueueManager:
|
|||
self._task_id = task_id
|
||||
self._user_id = user_id
|
||||
self._invoke_from = invoke_from
|
||||
self.invoke_from = invoke_from # Public accessor for invoke_from
|
||||
|
||||
user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
|
||||
redis_client.setex(
|
||||
|
|
|
|||
|
|
@ -46,7 +46,10 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
response = cls.convert_blocking_full_response(blocking_response)
|
||||
|
||||
metadata = response.get("metadata", {})
|
||||
response["metadata"] = cls._get_simple_metadata(metadata)
|
||||
if isinstance(metadata, dict):
|
||||
response["metadata"] = cls._get_simple_metadata(metadata)
|
||||
else:
|
||||
response["metadata"] = {}
|
||||
|
||||
return response
|
||||
|
||||
|
|
@ -78,7 +81,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
|
|
@ -106,7 +109,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
}
|
||||
|
||||
if isinstance(sub_stream_response, MessageEndStreamResponse):
|
||||
sub_stream_response_dict = sub_stream_response.to_dict()
|
||||
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
|
|
@ -114,6 +117,6 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
|
||||
yield response_chunk
|
||||
|
|
|
|||
|
|
@ -271,6 +271,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
raise MoreLikeThisDisabledError()
|
||||
|
||||
app_model_config = message.app_model_config
|
||||
if not app_model_config:
|
||||
raise ValueError("Message app_model_config is None")
|
||||
override_model_config_dict = app_model_config.to_dict()
|
||||
model_dict = override_model_config_dict["model"]
|
||||
completion_params = model_dict.get("completion_params")
|
||||
|
|
|
|||
|
|
@ -45,7 +45,10 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
response = cls.convert_blocking_full_response(blocking_response)
|
||||
|
||||
metadata = response.get("metadata", {})
|
||||
response["metadata"] = cls._get_simple_metadata(metadata)
|
||||
if isinstance(metadata, dict):
|
||||
response["metadata"] = cls._get_simple_metadata(metadata)
|
||||
else:
|
||||
response["metadata"] = {}
|
||||
|
||||
return response
|
||||
|
||||
|
|
@ -76,7 +79,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
|
|
@ -103,14 +106,16 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
}
|
||||
|
||||
if isinstance(sub_stream_response, MessageEndStreamResponse):
|
||||
sub_stream_response_dict = sub_stream_response.to_dict()
|
||||
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
if not isinstance(metadata, dict):
|
||||
metadata = {}
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
|
||||
yield response_chunk
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
:param blocking_response: blocking response
|
||||
:return:
|
||||
"""
|
||||
return dict(blocking_response.to_dict())
|
||||
return blocking_response.model_dump()
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
|
||||
|
|
@ -51,7 +51,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
response_chunk: dict[str, object] = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"workflow_run_id": chunk.workflow_run_id,
|
||||
}
|
||||
|
|
@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
|
|
@ -80,7 +80,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
response_chunk: dict[str, object] = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"workflow_run_id": chunk.workflow_run_id,
|
||||
}
|
||||
|
|
@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
yield response_chunk
|
||||
|
|
|
|||
|
|
@ -133,7 +133,7 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_features_dict = workflow.features_dict
|
||||
self._workflow_run_id = ""
|
||||
self._invoke_from = queue_manager._invoke_from
|
||||
self._invoke_from = queue_manager.invoke_from
|
||||
self._draft_var_saver_factory = draft_var_saver_factory
|
||||
|
||||
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
|
|
@ -142,7 +142,7 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
:return:
|
||||
"""
|
||||
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
||||
if self._base_task_pipeline._stream:
|
||||
if self._base_task_pipeline.stream:
|
||||
return self._to_stream_response(generator)
|
||||
else:
|
||||
return self._to_blocking_response(generator)
|
||||
|
|
@ -272,12 +272,12 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
|
||||
def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
|
||||
"""Handle ping events."""
|
||||
yield self._base_task_pipeline._ping_stream_response()
|
||||
yield self._base_task_pipeline.ping_stream_response()
|
||||
|
||||
def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
|
||||
"""Handle error events."""
|
||||
err = self._base_task_pipeline._handle_error(event=event)
|
||||
yield self._base_task_pipeline._error_to_stream_response(err)
|
||||
err = self._base_task_pipeline.handle_error(event=event)
|
||||
yield self._base_task_pipeline.error_to_stream_response(err)
|
||||
|
||||
def _handle_workflow_started_event(
|
||||
self, event: QueueWorkflowStartedEvent, **kwargs
|
||||
|
|
|
|||
|
|
@ -123,7 +123,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
|
|||
"""
|
||||
|
||||
# app config
|
||||
app_config: EasyUIBasedAppConfig
|
||||
app_config: EasyUIBasedAppConfig = None # type: ignore
|
||||
model_conf: ModelConfigWithCredentialsEntity
|
||||
|
||||
query: Optional[str] = None
|
||||
|
|
@ -186,7 +186,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
|
|||
"""
|
||||
|
||||
# app config
|
||||
app_config: WorkflowUIBasedAppConfig
|
||||
app_config: WorkflowUIBasedAppConfig = None # type: ignore
|
||||
|
||||
workflow_run_id: Optional[str] = None
|
||||
query: str
|
||||
|
|
@ -218,7 +218,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
|||
"""
|
||||
|
||||
# app config
|
||||
app_config: WorkflowUIBasedAppConfig
|
||||
app_config: WorkflowUIBasedAppConfig = None # type: ignore
|
||||
workflow_execution_id: str
|
||||
|
||||
class SingleIterationRunEntity(BaseModel):
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from typing import Any, Optional
|
|||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
|
|
@ -90,9 +89,6 @@ class StreamResponse(BaseModel):
|
|||
event: StreamEvent
|
||||
task_id: str
|
||||
|
||||
def to_dict(self):
|
||||
return jsonable_encoder(self)
|
||||
|
||||
|
||||
class ErrorStreamResponse(StreamResponse):
|
||||
"""
|
||||
|
|
@ -685,9 +681,6 @@ class AppBlockingResponse(BaseModel):
|
|||
|
||||
task_id: str
|
||||
|
||||
def to_dict(self):
|
||||
return jsonable_encoder(self)
|
||||
|
||||
|
||||
class ChatbotAppBlockingResponse(AppBlockingResponse):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -35,6 +35,9 @@ class AnnotationReplyFeature:
|
|||
|
||||
collection_binding_detail = annotation_setting.collection_binding_detail
|
||||
|
||||
if not collection_binding_detail:
|
||||
return None
|
||||
|
||||
try:
|
||||
score_threshold = annotation_setting.score_threshold or 1
|
||||
embedding_provider_name = collection_binding_detail.provider_name
|
||||
|
|
|
|||
|
|
@ -1 +1,3 @@
|
|||
from .rate_limit import RateLimit
|
||||
|
||||
__all__ = ["RateLimit"]
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ class RateLimit:
|
|||
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
|
||||
_instance_dict: dict[str, "RateLimit"] = {}
|
||||
|
||||
def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int):
|
||||
def __new__(cls, client_id: str, max_active_requests: int):
|
||||
if client_id not in cls._instance_dict:
|
||||
instance = super().__new__(cls)
|
||||
cls._instance_dict[client_id] = instance
|
||||
|
|
|
|||
|
|
@ -38,11 +38,11 @@ class BasedGenerateTaskPipeline:
|
|||
):
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self.queue_manager = queue_manager
|
||||
self._start_at = time.perf_counter()
|
||||
self._output_moderation_handler = self._init_output_moderation()
|
||||
self._stream = stream
|
||||
self.start_at = time.perf_counter()
|
||||
self.output_moderation_handler = self._init_output_moderation()
|
||||
self.stream = stream
|
||||
|
||||
def _handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""):
|
||||
def handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""):
|
||||
logger.debug("error: %s", event.error)
|
||||
e = event.error
|
||||
err: Exception
|
||||
|
|
@ -86,7 +86,7 @@ class BasedGenerateTaskPipeline:
|
|||
|
||||
return message
|
||||
|
||||
def _error_to_stream_response(self, e: Exception):
|
||||
def error_to_stream_response(self, e: Exception):
|
||||
"""
|
||||
Error to stream response.
|
||||
:param e: exception
|
||||
|
|
@ -94,7 +94,7 @@ class BasedGenerateTaskPipeline:
|
|||
"""
|
||||
return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e)
|
||||
|
||||
def _ping_stream_response(self) -> PingStreamResponse:
|
||||
def ping_stream_response(self) -> PingStreamResponse:
|
||||
"""
|
||||
Ping stream response.
|
||||
:return:
|
||||
|
|
@ -118,21 +118,21 @@ class BasedGenerateTaskPipeline:
|
|||
)
|
||||
return None
|
||||
|
||||
def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:
|
||||
def handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:
|
||||
"""
|
||||
Handle output moderation when task finished.
|
||||
:param completion: completion
|
||||
:return:
|
||||
"""
|
||||
# response moderation
|
||||
if self._output_moderation_handler:
|
||||
self._output_moderation_handler.stop_thread()
|
||||
if self.output_moderation_handler:
|
||||
self.output_moderation_handler.stop_thread()
|
||||
|
||||
completion, flagged = self._output_moderation_handler.moderation_completion(
|
||||
completion, flagged = self.output_moderation_handler.moderation_completion(
|
||||
completion=completion, public_event=False
|
||||
)
|
||||
|
||||
self._output_moderation_handler = None
|
||||
self.output_moderation_handler = None
|
||||
if flagged:
|
||||
return completion
|
||||
|
||||
|
|
|
|||
|
|
@ -125,7 +125,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||
)
|
||||
|
||||
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
||||
if self._stream:
|
||||
if self.stream:
|
||||
return self._to_stream_response(generator)
|
||||
else:
|
||||
return self._to_blocking_response(generator)
|
||||
|
|
@ -265,9 +265,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||
|
||||
if isinstance(event, QueueErrorEvent):
|
||||
with Session(db.engine) as session:
|
||||
err = self._handle_error(event=event, session=session, message_id=self._message_id)
|
||||
err = self.handle_error(event=event, session=session, message_id=self._message_id)
|
||||
session.commit()
|
||||
yield self._error_to_stream_response(err)
|
||||
yield self.error_to_stream_response(err)
|
||||
break
|
||||
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
|
||||
if isinstance(event, QueueMessageEndEvent):
|
||||
|
|
@ -277,7 +277,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||
self._handle_stop(event)
|
||||
|
||||
# handle output moderation
|
||||
output_moderation_answer = self._handle_output_moderation_when_task_finished(
|
||||
output_moderation_answer = self.handle_output_moderation_when_task_finished(
|
||||
cast(str, self._task_state.llm_result.message.content)
|
||||
)
|
||||
if output_moderation_answer:
|
||||
|
|
@ -354,7 +354,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||
elif isinstance(event, QueueMessageReplaceEvent):
|
||||
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text)
|
||||
elif isinstance(event, QueuePingEvent):
|
||||
yield self._ping_stream_response()
|
||||
yield self.ping_stream_response()
|
||||
else:
|
||||
continue
|
||||
if publisher:
|
||||
|
|
@ -394,7 +394,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||
message.answer_tokens = usage.completion_tokens
|
||||
message.answer_unit_price = usage.completion_unit_price
|
||||
message.answer_price_unit = usage.completion_price_unit
|
||||
message.provider_response_latency = time.perf_counter() - self._start_at
|
||||
message.provider_response_latency = time.perf_counter() - self.start_at
|
||||
message.total_price = usage.total_price
|
||||
message.currency = usage.currency
|
||||
self._task_state.llm_result.usage.latency = message.provider_response_latency
|
||||
|
|
@ -438,7 +438,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||
# transform usage
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
|
||||
self._task_state.llm_result.usage = model_type_instance.calc_response_usage(
|
||||
model, credentials, prompt_tokens, completion_tokens
|
||||
)
|
||||
|
||||
|
|
@ -498,10 +498,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||
:param text: text
|
||||
:return: True if output moderation should direct output, otherwise False
|
||||
"""
|
||||
if self._output_moderation_handler:
|
||||
if self._output_moderation_handler.should_direct_output():
|
||||
if self.output_moderation_handler:
|
||||
if self.output_moderation_handler.should_direct_output():
|
||||
# stop subscribe new token when output moderation should direct output
|
||||
self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
|
||||
self._task_state.llm_result.message.content = self.output_moderation_handler.get_final_output()
|
||||
self.queue_manager.publish(
|
||||
QueueLLMChunkEvent(
|
||||
chunk=LLMResultChunk(
|
||||
|
|
@ -521,6 +521,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||
)
|
||||
return True
|
||||
else:
|
||||
self._output_moderation_handler.append_new_token(text)
|
||||
self.output_moderation_handler.append_new_token(text)
|
||||
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ class AppGeneratorTTSPublisher:
|
|||
self.voice = voice
|
||||
if not voice or voice not in values:
|
||||
self.voice = self.voices[0].get("value")
|
||||
self.MAX_SENTENCE = 2
|
||||
self.max_sentence = 2
|
||||
self._last_audio_event: Optional[AudioTrunk] = None
|
||||
# FIXME better way to handle this threading.start
|
||||
threading.Thread(target=self._runtime).start()
|
||||
|
|
@ -113,8 +113,8 @@ class AppGeneratorTTSPublisher:
|
|||
self.msg_text += message.event.outputs.get("output", "")
|
||||
self.last_message = message
|
||||
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
|
||||
if len(sentence_arr) >= min(self.MAX_SENTENCE, 7):
|
||||
self.MAX_SENTENCE += 1
|
||||
if len(sentence_arr) >= min(self.max_sentence, 7):
|
||||
self.max_sentence += 1
|
||||
text_content = "".join(sentence_arr)
|
||||
futures_result = self.executor.submit(
|
||||
_invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
|
||||
|
|
|
|||
|
|
@ -1843,8 +1843,14 @@ class ProviderConfigurations(BaseModel):
|
|||
def __setitem__(self, key, value):
|
||||
self.configurations[key] = value
|
||||
|
||||
def __contains__(self, key):
|
||||
if "/" not in key:
|
||||
key = str(ModelProviderID(key))
|
||||
return key in self.configurations
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.configurations)
|
||||
# Return an iterator of (key, value) tuples to match BaseModel's __iter__
|
||||
yield from self.configurations.items()
|
||||
|
||||
def values(self) -> Iterator[ProviderConfiguration]:
|
||||
return iter(self.configurations.values())
|
||||
|
|
|
|||
|
|
@ -98,7 +98,7 @@ def to_prompt_message_content(
|
|||
|
||||
def download(f: File, /):
|
||||
if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE):
|
||||
return _download_file_content(f._storage_key)
|
||||
return _download_file_content(f.storage_key)
|
||||
elif f.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
|
||||
response.raise_for_status()
|
||||
|
|
@ -134,9 +134,9 @@ def _get_encoded_string(f: File, /):
|
|||
response.raise_for_status()
|
||||
data = response.content
|
||||
case FileTransferMethod.LOCAL_FILE:
|
||||
data = _download_file_content(f._storage_key)
|
||||
data = _download_file_content(f.storage_key)
|
||||
case FileTransferMethod.TOOL_FILE:
|
||||
data = _download_file_content(f._storage_key)
|
||||
data = _download_file_content(f.storage_key)
|
||||
|
||||
encoded_string = base64.b64encode(data).decode("utf-8")
|
||||
return encoded_string
|
||||
|
|
|
|||
|
|
@ -146,3 +146,11 @@ class File(BaseModel):
|
|||
if not self.related_id:
|
||||
raise ValueError("Missing file related_id")
|
||||
return self
|
||||
|
||||
@property
|
||||
def storage_key(self) -> str:
|
||||
return self._storage_key
|
||||
|
||||
@storage_key.setter
|
||||
def storage_key(self, value: str):
|
||||
self._storage_key = value
|
||||
|
|
|
|||
|
|
@ -13,18 +13,18 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
|
||||
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY = True # Default value for HTTP_REQUEST_NODE_SSL_VERIFY is True
|
||||
http_request_node_ssl_verify = True # Default value for http_request_node_ssl_verify is True
|
||||
try:
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
|
||||
http_request_node_ssl_verify_lower = str(HTTP_REQUEST_NODE_SSL_VERIFY).lower()
|
||||
config_value = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
|
||||
http_request_node_ssl_verify_lower = str(config_value).lower()
|
||||
if http_request_node_ssl_verify_lower == "true":
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY = True
|
||||
http_request_node_ssl_verify = True
|
||||
elif http_request_node_ssl_verify_lower == "false":
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY = False
|
||||
http_request_node_ssl_verify = False
|
||||
else:
|
||||
raise ValueError("Invalid value. HTTP_REQUEST_NODE_SSL_VERIFY should be 'True' or 'False'")
|
||||
except NameError:
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY = True
|
||||
http_request_node_ssl_verify = True
|
||||
|
||||
BACKOFF_FACTOR = 0.5
|
||||
STATUS_FORCELIST = [429, 500, 502, 503, 504]
|
||||
|
|
@ -51,7 +51,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
|||
)
|
||||
|
||||
if "ssl_verify" not in kwargs:
|
||||
kwargs["ssl_verify"] = HTTP_REQUEST_NODE_SSL_VERIFY
|
||||
kwargs["ssl_verify"] = http_request_node_ssl_verify
|
||||
|
||||
ssl_verify = kwargs.pop("ssl_verify")
|
||||
|
||||
|
|
|
|||
|
|
@ -529,6 +529,7 @@ class IndexingRunner:
|
|||
# chunk nodes by chunk size
|
||||
indexing_start_at = time.perf_counter()
|
||||
tokens = 0
|
||||
create_keyword_thread = None
|
||||
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
|
||||
# create keyword index
|
||||
create_keyword_thread = threading.Thread(
|
||||
|
|
@ -567,7 +568,11 @@ class IndexingRunner:
|
|||
|
||||
for future in futures:
|
||||
tokens += future.result()
|
||||
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
|
||||
if (
|
||||
dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX
|
||||
and dataset.indexing_technique == "economy"
|
||||
and create_keyword_thread is not None
|
||||
):
|
||||
create_keyword_thread.join()
|
||||
indexing_end_at = time.perf_counter()
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from core.llm_generator.prompts import (
|
|||
)
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
|
|
@ -314,14 +314,20 @@ class LLMGenerator:
|
|||
model_type=ModelType.LLM,
|
||||
)
|
||||
|
||||
prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)]
|
||||
prompt_messages: list[PromptMessage] = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)]
|
||||
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
# Explicitly use the non-streaming overload
|
||||
result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters={"temperature": 0.01, "max_tokens": 2000},
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# Runtime type check since pyright has issues with the overload
|
||||
if not isinstance(result, LLMResult):
|
||||
raise TypeError("Expected LLMResult when stream=False")
|
||||
response = result
|
||||
|
||||
answer = cast(str, response.message.content)
|
||||
return answer.strip()
|
||||
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ class SpecialModelType(StrEnum):
|
|||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
*,
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
|
|
@ -53,14 +54,13 @@ def invoke_llm_with_structured_output(
|
|||
model_parameters: Optional[Mapping] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: Literal[True] = True,
|
||||
stream: Literal[True],
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
*,
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
|
|
@ -69,14 +69,13 @@ def invoke_llm_with_structured_output(
|
|||
model_parameters: Optional[Mapping] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: Literal[False] = False,
|
||||
stream: Literal[False],
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> LLMResultWithStructuredOutput: ...
|
||||
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
*,
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
|
|
@ -89,9 +88,8 @@ def invoke_llm_with_structured_output(
|
|||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
|
||||
|
||||
def invoke_llm_with_structured_output(
|
||||
*,
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
|
|
|
|||
|
|
@ -23,13 +23,13 @@ DEFAULT_QUEUE_READ_TIMEOUT = 3
|
|||
@final
|
||||
class _StatusReady:
|
||||
def __init__(self, endpoint_url: str):
|
||||
self._endpoint_url = endpoint_url
|
||||
self.endpoint_url = endpoint_url
|
||||
|
||||
|
||||
@final
|
||||
class _StatusError:
|
||||
def __init__(self, exc: Exception):
|
||||
self._exc = exc
|
||||
self.exc = exc
|
||||
|
||||
|
||||
# Type aliases for better readability
|
||||
|
|
@ -211,9 +211,9 @@ class SSETransport:
|
|||
raise ValueError("failed to get endpoint URL")
|
||||
|
||||
if isinstance(status, _StatusReady):
|
||||
return status._endpoint_url
|
||||
return status.endpoint_url
|
||||
elif isinstance(status, _StatusError):
|
||||
raise status._exc
|
||||
raise status.exc
|
||||
else:
|
||||
raise ValueError("failed to get endpoint URL")
|
||||
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ def handle_mcp_request(
|
|||
"""
|
||||
|
||||
request_type = type(request.root)
|
||||
request_root = request.root
|
||||
|
||||
def create_success_response(result_data: mcp_types.Result) -> mcp_types.JSONRPCResponse:
|
||||
"""Create success response with business result data"""
|
||||
|
|
@ -58,21 +59,20 @@ def handle_mcp_request(
|
|||
error=error_data,
|
||||
)
|
||||
|
||||
# Request handler mapping using functional approach
|
||||
request_handlers = {
|
||||
mcp_types.InitializeRequest: lambda: handle_initialize(mcp_server.description),
|
||||
mcp_types.ListToolsRequest: lambda: handle_list_tools(
|
||||
app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict
|
||||
),
|
||||
mcp_types.CallToolRequest: lambda: handle_call_tool(app, request, user_input_form, end_user),
|
||||
mcp_types.PingRequest: lambda: handle_ping(),
|
||||
}
|
||||
|
||||
try:
|
||||
# Dispatch request to appropriate handler
|
||||
handler = request_handlers.get(request_type)
|
||||
if handler:
|
||||
return create_success_response(handler())
|
||||
# Dispatch request to appropriate handler based on instance type
|
||||
if isinstance(request_root, mcp_types.InitializeRequest):
|
||||
return create_success_response(handle_initialize(mcp_server.description))
|
||||
elif isinstance(request_root, mcp_types.ListToolsRequest):
|
||||
return create_success_response(
|
||||
handle_list_tools(
|
||||
app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict
|
||||
)
|
||||
)
|
||||
elif isinstance(request_root, mcp_types.CallToolRequest):
|
||||
return create_success_response(handle_call_tool(app, request, user_input_form, end_user))
|
||||
elif isinstance(request_root, mcp_types.PingRequest):
|
||||
return create_success_response(handle_ping())
|
||||
else:
|
||||
return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}")
|
||||
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
|||
self.request_meta = request_meta
|
||||
self.request = request
|
||||
self._session = session
|
||||
self._completed = False
|
||||
self.completed = False
|
||||
self._on_complete = on_complete
|
||||
self._entered = False # Track if we're in a context manager
|
||||
|
||||
|
|
@ -98,7 +98,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
|||
):
|
||||
"""Exit the context manager, performing cleanup and notifying completion."""
|
||||
try:
|
||||
if self._completed:
|
||||
if self.completed:
|
||||
self._on_complete(self)
|
||||
finally:
|
||||
self._entered = False
|
||||
|
|
@ -113,9 +113,9 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
|||
"""
|
||||
if not self._entered:
|
||||
raise RuntimeError("RequestResponder must be used as a context manager")
|
||||
assert not self._completed, "Request already responded to"
|
||||
assert not self.completed, "Request already responded to"
|
||||
|
||||
self._completed = True
|
||||
self.completed = True
|
||||
|
||||
self._session._send_response(request_id=self.request_id, response=response)
|
||||
|
||||
|
|
@ -124,7 +124,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
|||
if not self._entered:
|
||||
raise RuntimeError("RequestResponder must be used as a context manager")
|
||||
|
||||
self._completed = True # Mark as completed so it's removed from in_flight
|
||||
self.completed = True # Mark as completed so it's removed from in_flight
|
||||
# Send an error response to indicate cancellation
|
||||
self._session._send_response(
|
||||
request_id=self.request_id,
|
||||
|
|
@ -351,7 +351,7 @@ class BaseSession(
|
|||
self._in_flight[responder.request_id] = responder
|
||||
self._received_request(responder)
|
||||
|
||||
if not responder._completed:
|
||||
if not responder.completed:
|
||||
self._handle_incoming(responder)
|
||||
|
||||
elif isinstance(message.message.root, JSONRPCNotification):
|
||||
|
|
|
|||
|
|
@ -357,7 +357,7 @@ class LargeLanguageModel(AIModel):
|
|||
)
|
||||
return 0
|
||||
|
||||
def _calc_response_usage(
|
||||
def calc_response_usage(
|
||||
self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int
|
||||
) -> LLMUsage:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -82,7 +82,9 @@ def merge_blob_chunks(
|
|||
message_class = type(resp)
|
||||
merged_message = message_class(
|
||||
type=ToolInvokeMessage.MessageType.BLOB,
|
||||
message=ToolInvokeMessage.BlobMessage(blob=files[chunk_id].data[: files[chunk_id].bytes_written]),
|
||||
message=ToolInvokeMessage.BlobMessage(
|
||||
blob=bytes(files[chunk_id].data[: files[chunk_id].bytes_written])
|
||||
),
|
||||
meta=resp.meta,
|
||||
)
|
||||
yield cast(MessageType, merged_message)
|
||||
|
|
|
|||
|
|
@ -101,9 +101,22 @@ class SimplePromptTransform(PromptTransform):
|
|||
with_memory_prompt=histories is not None,
|
||||
)
|
||||
|
||||
variables = {k: inputs[k] for k in prompt_template_config["custom_variable_keys"] if k in inputs}
|
||||
custom_variable_keys_obj = prompt_template_config["custom_variable_keys"]
|
||||
special_variable_keys_obj = prompt_template_config["special_variable_keys"]
|
||||
|
||||
for v in prompt_template_config["special_variable_keys"]:
|
||||
# Type check for custom_variable_keys
|
||||
if not isinstance(custom_variable_keys_obj, list):
|
||||
raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys_obj)}")
|
||||
custom_variable_keys = cast(list[str], custom_variable_keys_obj)
|
||||
|
||||
# Type check for special_variable_keys
|
||||
if not isinstance(special_variable_keys_obj, list):
|
||||
raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys_obj)}")
|
||||
special_variable_keys = cast(list[str], special_variable_keys_obj)
|
||||
|
||||
variables = {k: inputs[k] for k in custom_variable_keys if k in inputs}
|
||||
|
||||
for v in special_variable_keys:
|
||||
# support #context#, #query# and #histories#
|
||||
if v == "#context#":
|
||||
variables["#context#"] = context or ""
|
||||
|
|
@ -113,9 +126,16 @@ class SimplePromptTransform(PromptTransform):
|
|||
variables["#histories#"] = histories or ""
|
||||
|
||||
prompt_template = prompt_template_config["prompt_template"]
|
||||
if not isinstance(prompt_template, PromptTemplateParser):
|
||||
raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template)}")
|
||||
|
||||
prompt = prompt_template.format(variables)
|
||||
|
||||
return prompt, prompt_template_config["prompt_rules"]
|
||||
prompt_rules = prompt_template_config["prompt_rules"]
|
||||
if not isinstance(prompt_rules, dict):
|
||||
raise TypeError(f"Expected dict for prompt_rules, got {type(prompt_rules)}")
|
||||
|
||||
return prompt, prompt_rules
|
||||
|
||||
def get_prompt_template(
|
||||
self,
|
||||
|
|
@ -126,11 +146,11 @@ class SimplePromptTransform(PromptTransform):
|
|||
has_context: bool,
|
||||
query_in_prompt: bool,
|
||||
with_memory_prompt: bool = False,
|
||||
):
|
||||
) -> dict[str, object]:
|
||||
prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model)
|
||||
|
||||
custom_variable_keys = []
|
||||
special_variable_keys = []
|
||||
custom_variable_keys: list[str] = []
|
||||
special_variable_keys: list[str] = []
|
||||
|
||||
prompt = ""
|
||||
for order in prompt_rules["system_prompt_orders"]:
|
||||
|
|
|
|||
|
|
@ -40,6 +40,19 @@ if TYPE_CHECKING:
|
|||
MetadataFilter = Union[DictFilter, common_types.Filter]
|
||||
|
||||
|
||||
class PathQdrantParams(BaseModel):
|
||||
path: str
|
||||
|
||||
|
||||
class UrlQdrantParams(BaseModel):
|
||||
url: str
|
||||
api_key: Optional[str]
|
||||
timeout: float
|
||||
verify: bool
|
||||
grpc_port: int
|
||||
prefer_grpc: bool
|
||||
|
||||
|
||||
class QdrantConfig(BaseModel):
|
||||
endpoint: str
|
||||
api_key: Optional[str] = None
|
||||
|
|
@ -50,7 +63,7 @@ class QdrantConfig(BaseModel):
|
|||
replication_factor: int = 1
|
||||
write_consistency_factor: int = 1
|
||||
|
||||
def to_qdrant_params(self):
|
||||
def to_qdrant_params(self) -> PathQdrantParams | UrlQdrantParams:
|
||||
if self.endpoint and self.endpoint.startswith("path:"):
|
||||
path = self.endpoint.replace("path:", "")
|
||||
if not os.path.isabs(path):
|
||||
|
|
@ -58,23 +71,23 @@ class QdrantConfig(BaseModel):
|
|||
raise ValueError("Root path is not set")
|
||||
path = os.path.join(self.root_path, path)
|
||||
|
||||
return {"path": path}
|
||||
return PathQdrantParams(path=path)
|
||||
else:
|
||||
return {
|
||||
"url": self.endpoint,
|
||||
"api_key": self.api_key,
|
||||
"timeout": self.timeout,
|
||||
"verify": self.endpoint.startswith("https"),
|
||||
"grpc_port": self.grpc_port,
|
||||
"prefer_grpc": self.prefer_grpc,
|
||||
}
|
||||
return UrlQdrantParams(
|
||||
url=self.endpoint,
|
||||
api_key=self.api_key,
|
||||
timeout=self.timeout,
|
||||
verify=self.endpoint.startswith("https"),
|
||||
grpc_port=self.grpc_port,
|
||||
prefer_grpc=self.prefer_grpc,
|
||||
)
|
||||
|
||||
|
||||
class QdrantVector(BaseVector):
|
||||
def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = "Cosine"):
|
||||
super().__init__(collection_name)
|
||||
self._client_config = config
|
||||
self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params())
|
||||
self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params().model_dump())
|
||||
self._distance_func = distance_func.upper()
|
||||
self._group_id = group_id
|
||||
|
||||
|
|
|
|||
|
|
@ -94,10 +94,10 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
|
|||
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
|
||||
|
||||
# In-memory cache for workflow node executions
|
||||
self._execution_cache: dict[str, WorkflowNodeExecution] = {}
|
||||
self._execution_cache = {}
|
||||
|
||||
# Cache for mapping workflow_execution_ids to execution IDs for efficient retrieval
|
||||
self._workflow_execution_mapping: dict[str, list[str]] = {}
|
||||
self._workflow_execution_mapping = {}
|
||||
|
||||
logger.info(
|
||||
"Initialized CeleryWorkflowNodeExecutionRepository for tenant %s, app %s, triggered_from %s",
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from .types import SegmentType
|
|||
|
||||
class SegmentGroup(Segment):
|
||||
value_type: SegmentType = SegmentType.GROUP
|
||||
value: list[Segment]
|
||||
value: list[Segment] = None # type: ignore
|
||||
|
||||
@property
|
||||
def text(self):
|
||||
|
|
|
|||
|
|
@ -74,12 +74,12 @@ class NoneSegment(Segment):
|
|||
|
||||
class StringSegment(Segment):
|
||||
value_type: SegmentType = SegmentType.STRING
|
||||
value: str
|
||||
value: str = None # type: ignore
|
||||
|
||||
|
||||
class FloatSegment(Segment):
|
||||
value_type: SegmentType = SegmentType.FLOAT
|
||||
value: float
|
||||
value: float = None # type: ignore
|
||||
# NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
|
||||
# The following tests cannot pass.
|
||||
#
|
||||
|
|
@ -98,12 +98,12 @@ class FloatSegment(Segment):
|
|||
|
||||
class IntegerSegment(Segment):
|
||||
value_type: SegmentType = SegmentType.INTEGER
|
||||
value: int
|
||||
value: int = None # type: ignore
|
||||
|
||||
|
||||
class ObjectSegment(Segment):
|
||||
value_type: SegmentType = SegmentType.OBJECT
|
||||
value: Mapping[str, Any]
|
||||
value: Mapping[str, Any] = None # type: ignore
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
|
|
@ -136,7 +136,7 @@ class ArraySegment(Segment):
|
|||
|
||||
class FileSegment(Segment):
|
||||
value_type: SegmentType = SegmentType.FILE
|
||||
value: File
|
||||
value: File = None # type: ignore
|
||||
|
||||
@property
|
||||
def markdown(self) -> str:
|
||||
|
|
@ -153,17 +153,17 @@ class FileSegment(Segment):
|
|||
|
||||
class BooleanSegment(Segment):
|
||||
value_type: SegmentType = SegmentType.BOOLEAN
|
||||
value: bool
|
||||
value: bool = None # type: ignore
|
||||
|
||||
|
||||
class ArrayAnySegment(ArraySegment):
|
||||
value_type: SegmentType = SegmentType.ARRAY_ANY
|
||||
value: Sequence[Any]
|
||||
value: Sequence[Any] = None # type: ignore
|
||||
|
||||
|
||||
class ArrayStringSegment(ArraySegment):
|
||||
value_type: SegmentType = SegmentType.ARRAY_STRING
|
||||
value: Sequence[str]
|
||||
value: Sequence[str] = None # type: ignore
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
|
|
@ -175,17 +175,17 @@ class ArrayStringSegment(ArraySegment):
|
|||
|
||||
class ArrayNumberSegment(ArraySegment):
|
||||
value_type: SegmentType = SegmentType.ARRAY_NUMBER
|
||||
value: Sequence[float | int]
|
||||
value: Sequence[float | int] = None # type: ignore
|
||||
|
||||
|
||||
class ArrayObjectSegment(ArraySegment):
|
||||
value_type: SegmentType = SegmentType.ARRAY_OBJECT
|
||||
value: Sequence[Mapping[str, Any]]
|
||||
value: Sequence[Mapping[str, Any]] = None # type: ignore
|
||||
|
||||
|
||||
class ArrayFileSegment(ArraySegment):
|
||||
value_type: SegmentType = SegmentType.ARRAY_FILE
|
||||
value: Sequence[File]
|
||||
value: Sequence[File] = None # type: ignore
|
||||
|
||||
@property
|
||||
def markdown(self) -> str:
|
||||
|
|
@ -205,7 +205,7 @@ class ArrayFileSegment(ArraySegment):
|
|||
|
||||
class ArrayBooleanSegment(ArraySegment):
|
||||
value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
|
||||
value: Sequence[bool]
|
||||
value: Sequence[bool] = None # type: ignore
|
||||
|
||||
|
||||
def get_segment_discriminator(v: Any) -> SegmentType | None:
|
||||
|
|
|
|||
|
|
@ -3,6 +3,6 @@ from core.workflow.nodes.base.node import Node
|
|||
|
||||
class WorkflowNodeRunFailedError(Exception):
|
||||
def __init__(self, node: Node, err_msg: str):
|
||||
self._node = node
|
||||
self._error = err_msg
|
||||
self.node = node
|
||||
self.error = err_msg
|
||||
super().__init__(f"Node {node.title} run failed: {err_msg}")
|
||||
|
|
|
|||
|
|
@ -66,8 +66,8 @@ class ListOperatorNode(Node):
|
|||
return "1"
|
||||
|
||||
def _run(self):
|
||||
inputs: dict[str, list] = {}
|
||||
process_data: dict[str, list] = {}
|
||||
inputs: dict[str, Sequence[object]] = {}
|
||||
process_data: dict[str, Sequence[object]] = {}
|
||||
outputs: dict[str, Any] = {}
|
||||
|
||||
variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable)
|
||||
|
|
|
|||
|
|
@ -1184,7 +1184,8 @@ def _combine_message_content_with_role(
|
|||
return AssistantPromptMessage(content=contents)
|
||||
case PromptMessageRole.SYSTEM:
|
||||
return SystemPromptMessage(content=contents)
|
||||
raise NotImplementedError(f"Role {role} is not supported")
|
||||
case _:
|
||||
raise NotImplementedError(f"Role {role} is not supported")
|
||||
|
||||
|
||||
def _render_jinja2_message(
|
||||
|
|
|
|||
|
|
@ -462,9 +462,9 @@ class StorageKeyLoader:
|
|||
upload_file_row = upload_files.get(model_id)
|
||||
if upload_file_row is None:
|
||||
raise ValueError(f"Upload file not found for id: {model_id}")
|
||||
file._storage_key = upload_file_row.key
|
||||
file.storage_key = upload_file_row.key
|
||||
elif file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
tool_file_row = tool_files.get(model_id)
|
||||
if tool_file_row is None:
|
||||
raise ValueError(f"Tool file not found for id: {model_id}")
|
||||
file._storage_key = tool_file_row.file_key
|
||||
file.storage_key = tool_file_row.file_key
|
||||
|
|
|
|||
|
|
@ -12,4 +12,7 @@ def serialize_value_type(v: _VarTypedDict | Segment) -> str:
|
|||
if isinstance(v, Segment):
|
||||
return v.value_type.exposed_type().value
|
||||
else:
|
||||
return v["value_type"].exposed_type().value
|
||||
value_type = v.get("value_type")
|
||||
if value_type is None:
|
||||
raise ValueError("value_type is required but not provided")
|
||||
return value_type.exposed_type().value
|
||||
|
|
|
|||
|
|
@ -69,6 +69,8 @@ def register_external_error_handlers(api: Api):
|
|||
headers["WWW-Authenticate"] = 'Bearer realm="api"'
|
||||
return data, status_code, headers
|
||||
|
||||
_ = handle_http_exception
|
||||
|
||||
@api.errorhandler(ValueError)
|
||||
def handle_value_error(e: ValueError):
|
||||
got_request_exception.send(current_app, exception=e)
|
||||
|
|
@ -76,6 +78,8 @@ def register_external_error_handlers(api: Api):
|
|||
data = {"code": "invalid_param", "message": str(e), "status": status_code}
|
||||
return data, status_code
|
||||
|
||||
_ = handle_value_error
|
||||
|
||||
@api.errorhandler(AppInvokeQuotaExceededError)
|
||||
def handle_quota_exceeded(e: AppInvokeQuotaExceededError):
|
||||
got_request_exception.send(current_app, exception=e)
|
||||
|
|
@ -83,15 +87,17 @@ def register_external_error_handlers(api: Api):
|
|||
data = {"code": "too_many_requests", "message": str(e), "status": status_code}
|
||||
return data, status_code
|
||||
|
||||
_ = handle_quota_exceeded
|
||||
|
||||
@api.errorhandler(Exception)
|
||||
def handle_general_exception(e: Exception):
|
||||
got_request_exception.send(current_app, exception=e)
|
||||
|
||||
status_code = 500
|
||||
data: dict[str, Any] = getattr(e, "data", {"message": http_status_message(status_code)})
|
||||
data = getattr(e, "data", {"message": http_status_message(status_code)})
|
||||
|
||||
# 🔒 Normalize non-mapping data (e.g., if someone set e.data = Response)
|
||||
if not isinstance(data, Mapping):
|
||||
if not isinstance(data, dict):
|
||||
data = {"message": str(e)}
|
||||
|
||||
data.setdefault("code", "unknown")
|
||||
|
|
@ -101,10 +107,12 @@ def register_external_error_handlers(api: Api):
|
|||
exc_info: Any = sys.exc_info()
|
||||
if exc_info[1] is None:
|
||||
exc_info = None
|
||||
current_app.log_exception(exc_info) # ty: ignore [invalid-argument-type]
|
||||
current_app.log_exception(exc_info)
|
||||
|
||||
return data, status_code
|
||||
|
||||
_ = handle_general_exception
|
||||
|
||||
|
||||
class ExternalApi(Api):
|
||||
_authorizations = {
|
||||
|
|
|
|||
|
|
@ -167,13 +167,6 @@ class DatetimeString:
|
|||
return value
|
||||
|
||||
|
||||
def _get_float(value):
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
raise ValueError(f"{value} is not a valid float")
|
||||
|
||||
|
||||
def timezone(timezone_string):
|
||||
if timezone_string and timezone_string in available_timezones():
|
||||
return timezone_string
|
||||
|
|
|
|||
|
|
@ -1,24 +1,44 @@
|
|||
{
|
||||
"include": ["."],
|
||||
"exclude": [".venv", "tests/", "migrations/"],
|
||||
"ignore": [
|
||||
"core/",
|
||||
"controllers/",
|
||||
"tasks/",
|
||||
"services/",
|
||||
"schedule/",
|
||||
"extensions/",
|
||||
"utils/",
|
||||
"repositories/",
|
||||
"libs/",
|
||||
"fields/",
|
||||
"factories/",
|
||||
"events/",
|
||||
"contexts/",
|
||||
"constants/",
|
||||
"commands.py"
|
||||
"exclude": [
|
||||
".venv",
|
||||
"tests/",
|
||||
"migrations/",
|
||||
"core/rag",
|
||||
"extensions",
|
||||
"libs",
|
||||
"controllers/console/datasets",
|
||||
"controllers/service_api/dataset",
|
||||
"core/ops",
|
||||
"core/tools",
|
||||
"core/model_runtime",
|
||||
"core/workflow",
|
||||
"core/app/app_config/easy_ui_based_app/dataset"
|
||||
],
|
||||
"typeCheckingMode": "strict",
|
||||
"allowedUntypedLibraries": [
|
||||
"flask_restx",
|
||||
"flask_login",
|
||||
"opentelemetry.instrumentation.celery",
|
||||
"opentelemetry.instrumentation.flask",
|
||||
"opentelemetry.instrumentation.requests",
|
||||
"opentelemetry.instrumentation.sqlalchemy",
|
||||
"opentelemetry.instrumentation.redis"
|
||||
],
|
||||
"reportUnknownMemberType": "hint",
|
||||
"reportUnknownParameterType": "hint",
|
||||
"reportUnknownArgumentType": "hint",
|
||||
"reportUnknownVariableType": "hint",
|
||||
"reportUnknownLambdaType": "hint",
|
||||
"reportMissingParameterType": "hint",
|
||||
"reportMissingTypeArgument": "hint",
|
||||
"reportUnnecessaryContains": "hint",
|
||||
"reportUnnecessaryComparison": "hint",
|
||||
"reportUnnecessaryCast": "hint",
|
||||
"reportUnnecessaryIsInstance": "hint",
|
||||
"reportUntypedFunctionDecorator": "hint",
|
||||
|
||||
"reportAttributeAccessIssue": "hint",
|
||||
"pythonVersion": "3.11",
|
||||
"pythonPlatform": "All"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1318,7 +1318,7 @@ class RegisterService:
|
|||
def get_invitation_if_token_valid(
|
||||
cls, workspace_id: Optional[str], email: str, token: str
|
||||
) -> Optional[dict[str, Any]]:
|
||||
invitation_data = cls._get_invitation_by_token(token, workspace_id, email)
|
||||
invitation_data = cls.get_invitation_by_token(token, workspace_id, email)
|
||||
if not invitation_data:
|
||||
return None
|
||||
|
||||
|
|
@ -1355,7 +1355,7 @@ class RegisterService:
|
|||
}
|
||||
|
||||
@classmethod
|
||||
def _get_invitation_by_token(
|
||||
def get_invitation_by_token(
|
||||
cls, token: str, workspace_id: Optional[str] = None, email: Optional[str] = None
|
||||
) -> Optional[dict[str, str]]:
|
||||
if workspace_id is not None and email is not None:
|
||||
|
|
|
|||
|
|
@ -349,7 +349,7 @@ class AppAnnotationService:
|
|||
|
||||
try:
|
||||
# Skip the first row
|
||||
df = pd.read_csv(file, dtype=str)
|
||||
df = pd.read_csv(file.stream, dtype=str)
|
||||
result = []
|
||||
for _, row in df.iterrows():
|
||||
content = {"question": row.iloc[0], "answer": row.iloc[1]}
|
||||
|
|
@ -463,15 +463,23 @@ class AppAnnotationService:
|
|||
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
|
||||
if annotation_setting:
|
||||
collection_binding_detail = annotation_setting.collection_binding_detail
|
||||
return {
|
||||
"id": annotation_setting.id,
|
||||
"enabled": True,
|
||||
"score_threshold": annotation_setting.score_threshold,
|
||||
"embedding_model": {
|
||||
"embedding_provider_name": collection_binding_detail.provider_name,
|
||||
"embedding_model_name": collection_binding_detail.model_name,
|
||||
},
|
||||
}
|
||||
if collection_binding_detail:
|
||||
return {
|
||||
"id": annotation_setting.id,
|
||||
"enabled": True,
|
||||
"score_threshold": annotation_setting.score_threshold,
|
||||
"embedding_model": {
|
||||
"embedding_provider_name": collection_binding_detail.provider_name,
|
||||
"embedding_model_name": collection_binding_detail.model_name,
|
||||
},
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"id": annotation_setting.id,
|
||||
"enabled": True,
|
||||
"score_threshold": annotation_setting.score_threshold,
|
||||
"embedding_model": {},
|
||||
}
|
||||
return {"enabled": False}
|
||||
|
||||
@classmethod
|
||||
|
|
@ -506,15 +514,23 @@ class AppAnnotationService:
|
|||
|
||||
collection_binding_detail = annotation_setting.collection_binding_detail
|
||||
|
||||
return {
|
||||
"id": annotation_setting.id,
|
||||
"enabled": True,
|
||||
"score_threshold": annotation_setting.score_threshold,
|
||||
"embedding_model": {
|
||||
"embedding_provider_name": collection_binding_detail.provider_name,
|
||||
"embedding_model_name": collection_binding_detail.model_name,
|
||||
},
|
||||
}
|
||||
if collection_binding_detail:
|
||||
return {
|
||||
"id": annotation_setting.id,
|
||||
"enabled": True,
|
||||
"score_threshold": annotation_setting.score_threshold,
|
||||
"embedding_model": {
|
||||
"embedding_provider_name": collection_binding_detail.provider_name,
|
||||
"embedding_model_name": collection_binding_detail.model_name,
|
||||
},
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"id": annotation_setting.id,
|
||||
"enabled": True,
|
||||
"score_threshold": annotation_setting.score_threshold,
|
||||
"embedding_model": {},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def clear_all_annotations(cls, app_id: str):
|
||||
|
|
|
|||
|
|
@ -407,6 +407,7 @@ class ClearFreePlanTenantExpiredLogs:
|
|||
datetime.timedelta(hours=1),
|
||||
]
|
||||
|
||||
tenant_count = 0
|
||||
for test_interval in test_intervals:
|
||||
tenant_count = (
|
||||
session.query(Tenant.id)
|
||||
|
|
|
|||
|
|
@ -134,11 +134,14 @@ class DatasetService:
|
|||
|
||||
# Check if tag_ids is not empty to avoid WHERE false condition
|
||||
if tag_ids and len(tag_ids) > 0:
|
||||
target_ids = TagService.get_target_ids_by_tag_ids(
|
||||
"knowledge",
|
||||
tenant_id, # ty: ignore [invalid-argument-type]
|
||||
tag_ids,
|
||||
)
|
||||
if tenant_id is not None:
|
||||
target_ids = TagService.get_target_ids_by_tag_ids(
|
||||
"knowledge",
|
||||
tenant_id,
|
||||
tag_ids,
|
||||
)
|
||||
else:
|
||||
target_ids = []
|
||||
if target_ids and len(target_ids) > 0:
|
||||
query = query.where(Dataset.id.in_(target_ids))
|
||||
else:
|
||||
|
|
@ -987,7 +990,8 @@ class DocumentService:
|
|||
for document in documents
|
||||
if document.data_source_type == "upload_file" and document.data_source_info_dict
|
||||
]
|
||||
batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)
|
||||
if dataset.doc_form is not None:
|
||||
batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)
|
||||
|
||||
for document in documents:
|
||||
db.session.delete(document)
|
||||
|
|
@ -2688,56 +2692,6 @@ class SegmentService:
|
|||
|
||||
return paginated_segments.items, paginated_segments.total
|
||||
|
||||
@classmethod
|
||||
def update_segment_by_id(
|
||||
cls, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, segment_data: dict, user_id: str
|
||||
) -> tuple[DocumentSegment, Document]:
|
||||
"""Update a segment by its ID with validation and checks."""
|
||||
# check dataset
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
|
||||
# check document
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
# check embedding model setting if high quality
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=user_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ValueError(
|
||||
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ValueError(ex.description)
|
||||
|
||||
# check segment
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
# validate and update segment
|
||||
cls.segment_create_args_validate(segment_data, document)
|
||||
updated_segment = cls.update_segment(SegmentUpdateArgs(**segment_data), segment, document, dataset)
|
||||
|
||||
return updated_segment, document
|
||||
|
||||
@classmethod
|
||||
def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]:
|
||||
"""Get a segment by its ID."""
|
||||
|
|
|
|||
|
|
@ -181,7 +181,7 @@ class ExternalDatasetService:
|
|||
do http request depending on api bundle
|
||||
"""
|
||||
|
||||
kwargs = {
|
||||
kwargs: dict[str, Any] = {
|
||||
"url": settings.url,
|
||||
"headers": settings.headers,
|
||||
"follow_redirects": True,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import hashlib
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Literal, Union
|
||||
from typing import Literal, Union
|
||||
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
|
|
@ -35,7 +35,7 @@ class FileService:
|
|||
filename: str,
|
||||
content: bytes,
|
||||
mimetype: str,
|
||||
user: Union[Account, EndUser, Any],
|
||||
user: Union[Account, EndUser],
|
||||
source: Literal["datasets"] | None = None,
|
||||
source_url: str = "",
|
||||
) -> UploadFile:
|
||||
|
|
|
|||
|
|
@ -165,7 +165,7 @@ class ModelLoadBalancingService:
|
|||
|
||||
try:
|
||||
if load_balancing_config.encrypted_config:
|
||||
credentials = json.loads(load_balancing_config.encrypted_config)
|
||||
credentials: dict[str, object] = json.loads(load_balancing_config.encrypted_config)
|
||||
else:
|
||||
credentials = {}
|
||||
except JSONDecodeError:
|
||||
|
|
@ -180,11 +180,13 @@ class ModelLoadBalancingService:
|
|||
for variable in credential_secret_variables:
|
||||
if variable in credentials:
|
||||
try:
|
||||
credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||
credentials.get(variable), # ty: ignore [invalid-argument-type]
|
||||
decoding_rsa_key,
|
||||
decoding_cipher_rsa,
|
||||
)
|
||||
token_value = credentials.get(variable)
|
||||
if isinstance(token_value, str):
|
||||
credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||
token_value,
|
||||
decoding_rsa_key,
|
||||
decoding_cipher_rsa,
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
|
@ -345,8 +347,9 @@ class ModelLoadBalancingService:
|
|||
credential_id = config.get("credential_id")
|
||||
enabled = config.get("enabled")
|
||||
|
||||
credential_record: ProviderCredential | ProviderModelCredential | None = None
|
||||
|
||||
if credential_id:
|
||||
credential_record: ProviderCredential | ProviderModelCredential | None = None
|
||||
if config_from == "predefined-model":
|
||||
credential_record = (
|
||||
db.session.query(ProviderCredential)
|
||||
|
|
|
|||
|
|
@ -100,6 +100,7 @@ class PluginMigration:
|
|||
datetime.timedelta(hours=1),
|
||||
]
|
||||
|
||||
tenant_count = 0
|
||||
for test_interval in test_intervals:
|
||||
tenant_count = (
|
||||
session.query(Tenant.id)
|
||||
|
|
|
|||
|
|
@ -223,8 +223,8 @@ class BuiltinToolManageService:
|
|||
"""
|
||||
add builtin tool provider
|
||||
"""
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine) as session:
|
||||
try:
|
||||
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
|
|
@ -285,9 +285,9 @@ class BuiltinToolManageService:
|
|||
|
||||
session.add(db_provider)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise ValueError(str(e))
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise ValueError(str(e))
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from core.helper import encrypter
|
|||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.simple_prompt_transform import SimplePromptTransform
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.nodes import NodeType
|
||||
from events.app_event import app_was_created
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -420,7 +421,11 @@ class WorkflowConverter:
|
|||
query_in_prompt=False,
|
||||
)
|
||||
|
||||
template = prompt_template_config["prompt_template"].template
|
||||
prompt_template_obj = prompt_template_config["prompt_template"]
|
||||
if not isinstance(prompt_template_obj, PromptTemplateParser):
|
||||
raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template_obj)}")
|
||||
|
||||
template = prompt_template_obj.template
|
||||
if not template:
|
||||
prompts = []
|
||||
else:
|
||||
|
|
@ -457,7 +462,11 @@ class WorkflowConverter:
|
|||
query_in_prompt=False,
|
||||
)
|
||||
|
||||
template = prompt_template_config["prompt_template"].template
|
||||
prompt_template_obj = prompt_template_config["prompt_template"]
|
||||
if not isinstance(prompt_template_obj, PromptTemplateParser):
|
||||
raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template_obj)}")
|
||||
|
||||
template = prompt_template_obj.template
|
||||
template = self._replace_template_variables(
|
||||
template=template,
|
||||
variables=start_node["data"]["variables"],
|
||||
|
|
@ -467,6 +476,9 @@ class WorkflowConverter:
|
|||
prompts = {"text": template}
|
||||
|
||||
prompt_rules = prompt_template_config["prompt_rules"]
|
||||
if not isinstance(prompt_rules, dict):
|
||||
raise TypeError(f"Expected dict for prompt_rules, got {type(prompt_rules)}")
|
||||
|
||||
role_prefix = {
|
||||
"user": prompt_rules.get("human_prefix", "Human"),
|
||||
"assistant": prompt_rules.get("assistant_prefix", "Assistant"),
|
||||
|
|
|
|||
|
|
@ -783,11 +783,13 @@ class WorkflowService:
|
|||
WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
)
|
||||
error = node_run_result.error if not run_succeeded else None
|
||||
|
||||
return node, node_run_result, run_succeeded, error
|
||||
|
||||
except WorkflowNodeRunFailedError as e:
|
||||
return e._node, None, False, e._error
|
||||
node = e.node
|
||||
run_succeeded = False
|
||||
node_run_result = None
|
||||
error = e.error
|
||||
return node, node_run_result, run_succeeded, error
|
||||
|
||||
def _apply_error_strategy(self, node: Node, node_run_result: NodeRunResult) -> NodeRunResult:
|
||||
"""Apply error strategy when node execution fails."""
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ class WorkspaceService:
|
|||
def get_tenant_info(cls, tenant: Tenant):
|
||||
if not tenant:
|
||||
return None
|
||||
tenant_info = {
|
||||
tenant_info: dict[str, object] = {
|
||||
"id": tenant.id,
|
||||
"name": tenant.name,
|
||||
"plan": tenant.plan,
|
||||
|
|
|
|||
|
|
@ -3278,7 +3278,7 @@ class TestRegisterService:
|
|||
redis_client.setex(cache_key, 24 * 60 * 60, account_id)
|
||||
|
||||
# Execute invitation retrieval
|
||||
result = RegisterService._get_invitation_by_token(
|
||||
result = RegisterService.get_invitation_by_token(
|
||||
token=token,
|
||||
workspace_id=workspace_id,
|
||||
email=email,
|
||||
|
|
@ -3316,7 +3316,7 @@ class TestRegisterService:
|
|||
redis_client.setex(token_key, 24 * 60 * 60, json.dumps(invitation_data))
|
||||
|
||||
# Execute invitation retrieval
|
||||
result = RegisterService._get_invitation_by_token(token=token)
|
||||
result = RegisterService.get_invitation_by_token(token=token)
|
||||
|
||||
# Verify result contains expected data
|
||||
assert result is not None
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from core.app.app_config.entities import (
|
|||
VariableEntityType,
|
||||
)
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from models.account import Account, Tenant
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
from models.model import App, AppMode, AppModelConfig
|
||||
|
|
@ -37,7 +38,7 @@ class TestWorkflowConverter:
|
|||
# Setup default mock returns
|
||||
mock_encrypter.decrypt_token.return_value = "decrypted_api_key"
|
||||
mock_prompt_transform.return_value.get_prompt_template.return_value = {
|
||||
"prompt_template": type("obj", (object,), {"template": "You are a helpful assistant {{text_input}}"})(),
|
||||
"prompt_template": PromptTemplateParser(template="You are a helpful assistant {{text_input}}"),
|
||||
"prompt_rules": {"human_prefix": "Human", "assistant_prefix": "Assistant"},
|
||||
}
|
||||
mock_agent_chat_config_manager.get_app_config.return_value = self._create_mock_app_config()
|
||||
|
|
|
|||
|
|
@ -1370,8 +1370,8 @@ class TestRegisterService:
|
|||
account_id="user-123", email="test@example.com"
|
||||
)
|
||||
|
||||
with patch("services.account_service.RegisterService._get_invitation_by_token") as mock_get_invitation_by_token:
|
||||
# Mock the invitation data returned by _get_invitation_by_token
|
||||
with patch("services.account_service.RegisterService.get_invitation_by_token") as mock_get_invitation_by_token:
|
||||
# Mock the invitation data returned by get_invitation_by_token
|
||||
invitation_data = {
|
||||
"account_id": "user-123",
|
||||
"email": "test@example.com",
|
||||
|
|
@ -1503,12 +1503,12 @@ class TestRegisterService:
|
|||
assert result == "member_invite:token:test-token"
|
||||
|
||||
def test_get_invitation_by_token_with_workspace_and_email(self, mock_redis_dependencies):
|
||||
"""Test _get_invitation_by_token with workspace ID and email."""
|
||||
"""Test get_invitation_by_token with workspace ID and email."""
|
||||
# Setup mock
|
||||
mock_redis_dependencies.get.return_value = b"user-123"
|
||||
|
||||
# Execute test
|
||||
result = RegisterService._get_invitation_by_token("token-123", "workspace-456", "test@example.com")
|
||||
result = RegisterService.get_invitation_by_token("token-123", "workspace-456", "test@example.com")
|
||||
|
||||
# Verify results
|
||||
assert result is not None
|
||||
|
|
@ -1517,7 +1517,7 @@ class TestRegisterService:
|
|||
assert result["workspace_id"] == "workspace-456"
|
||||
|
||||
def test_get_invitation_by_token_without_workspace_and_email(self, mock_redis_dependencies):
|
||||
"""Test _get_invitation_by_token without workspace ID and email."""
|
||||
"""Test get_invitation_by_token without workspace ID and email."""
|
||||
# Setup mock
|
||||
invitation_data = {
|
||||
"account_id": "user-123",
|
||||
|
|
@ -1527,19 +1527,19 @@ class TestRegisterService:
|
|||
mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
|
||||
|
||||
# Execute test
|
||||
result = RegisterService._get_invitation_by_token("token-123")
|
||||
result = RegisterService.get_invitation_by_token("token-123")
|
||||
|
||||
# Verify results
|
||||
assert result is not None
|
||||
assert result == invitation_data
|
||||
|
||||
def test_get_invitation_by_token_no_data(self, mock_redis_dependencies):
|
||||
"""Test _get_invitation_by_token with no data."""
|
||||
"""Test get_invitation_by_token with no data."""
|
||||
# Setup mock
|
||||
mock_redis_dependencies.get.return_value = None
|
||||
|
||||
# Execute test
|
||||
result = RegisterService._get_invitation_by_token("token-123")
|
||||
result = RegisterService.get_invitation_by_token("token-123")
|
||||
|
||||
# Verify results
|
||||
assert result is None
|
||||
|
|
|
|||
Loading…
Reference in New Issue