mirror of https://github.com/langgenius/dify.git
Fix basedpyright type errors (#25435)
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
2ac7a9c8fc
commit
08dd3f7b50
|
|
@ -511,7 +511,7 @@ def add_qdrant_index(field: str):
|
||||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||||
from qdrant_client.http.models import PayloadSchemaType
|
from qdrant_client.http.models import PayloadSchemaType
|
||||||
|
|
||||||
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig
|
from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
|
||||||
|
|
||||||
for binding in bindings:
|
for binding in bindings:
|
||||||
if dify_config.QDRANT_URL is None:
|
if dify_config.QDRANT_URL is None:
|
||||||
|
|
@ -525,7 +525,21 @@ def add_qdrant_index(field: str):
|
||||||
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
|
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params())
|
params = qdrant_config.to_qdrant_params()
|
||||||
|
# Check the type before using
|
||||||
|
if isinstance(params, PathQdrantParams):
|
||||||
|
# PathQdrantParams case
|
||||||
|
client = qdrant_client.QdrantClient(path=params.path)
|
||||||
|
else:
|
||||||
|
# UrlQdrantParams case - params is UrlQdrantParams
|
||||||
|
client = qdrant_client.QdrantClient(
|
||||||
|
url=params.url,
|
||||||
|
api_key=params.api_key,
|
||||||
|
timeout=int(params.timeout),
|
||||||
|
verify=params.verify,
|
||||||
|
grpc_port=params.grpc_port,
|
||||||
|
prefer_grpc=params.prefer_grpc,
|
||||||
|
)
|
||||||
# create payload index
|
# create payload index
|
||||||
client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD)
|
client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD)
|
||||||
create_count += 1
|
create_count += 1
|
||||||
|
|
|
||||||
|
|
@ -16,14 +16,14 @@ AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"]
|
||||||
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
|
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
|
||||||
|
|
||||||
|
|
||||||
|
_doc_extensions: list[str]
|
||||||
if dify_config.ETL_TYPE == "Unstructured":
|
if dify_config.ETL_TYPE == "Unstructured":
|
||||||
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"]
|
_doc_extensions = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"]
|
||||||
DOCUMENT_EXTENSIONS.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
|
_doc_extensions.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
|
||||||
if dify_config.UNSTRUCTURED_API_URL:
|
if dify_config.UNSTRUCTURED_API_URL:
|
||||||
DOCUMENT_EXTENSIONS.append("ppt")
|
_doc_extensions.append("ppt")
|
||||||
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
|
|
||||||
else:
|
else:
|
||||||
DOCUMENT_EXTENSIONS = [
|
_doc_extensions = [
|
||||||
"txt",
|
"txt",
|
||||||
"markdown",
|
"markdown",
|
||||||
"md",
|
"md",
|
||||||
|
|
@ -38,4 +38,4 @@ else:
|
||||||
"vtt",
|
"vtt",
|
||||||
"properties",
|
"properties",
|
||||||
]
|
]
|
||||||
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
|
DOCUMENT_EXTENSIONS = _doc_extensions + [ext.upper() for ext in _doc_extensions]
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,6 @@ if TYPE_CHECKING:
|
||||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -43,56 +43,64 @@ api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm"
|
||||||
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
|
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
|
||||||
|
|
||||||
# Import other controllers
|
# Import other controllers
|
||||||
from . import admin, apikey, extension, feature, ping, setup, version
|
from . import admin, apikey, extension, feature, ping, setup, version # pyright: ignore[reportUnusedImport]
|
||||||
|
|
||||||
# Import app controllers
|
# Import app controllers
|
||||||
from .app import (
|
from .app import (
|
||||||
advanced_prompt_template,
|
advanced_prompt_template, # pyright: ignore[reportUnusedImport]
|
||||||
agent,
|
agent, # pyright: ignore[reportUnusedImport]
|
||||||
annotation,
|
annotation, # pyright: ignore[reportUnusedImport]
|
||||||
app,
|
app, # pyright: ignore[reportUnusedImport]
|
||||||
audio,
|
audio, # pyright: ignore[reportUnusedImport]
|
||||||
completion,
|
completion, # pyright: ignore[reportUnusedImport]
|
||||||
conversation,
|
conversation, # pyright: ignore[reportUnusedImport]
|
||||||
conversation_variables,
|
conversation_variables, # pyright: ignore[reportUnusedImport]
|
||||||
generator,
|
generator, # pyright: ignore[reportUnusedImport]
|
||||||
mcp_server,
|
mcp_server, # pyright: ignore[reportUnusedImport]
|
||||||
message,
|
message, # pyright: ignore[reportUnusedImport]
|
||||||
model_config,
|
model_config, # pyright: ignore[reportUnusedImport]
|
||||||
ops_trace,
|
ops_trace, # pyright: ignore[reportUnusedImport]
|
||||||
site,
|
site, # pyright: ignore[reportUnusedImport]
|
||||||
statistic,
|
statistic, # pyright: ignore[reportUnusedImport]
|
||||||
workflow,
|
workflow, # pyright: ignore[reportUnusedImport]
|
||||||
workflow_app_log,
|
workflow_app_log, # pyright: ignore[reportUnusedImport]
|
||||||
workflow_draft_variable,
|
workflow_draft_variable, # pyright: ignore[reportUnusedImport]
|
||||||
workflow_run,
|
workflow_run, # pyright: ignore[reportUnusedImport]
|
||||||
workflow_statistic,
|
workflow_statistic, # pyright: ignore[reportUnusedImport]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Import auth controllers
|
# Import auth controllers
|
||||||
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth, oauth_server
|
from .auth import (
|
||||||
|
activate, # pyright: ignore[reportUnusedImport]
|
||||||
|
data_source_bearer_auth, # pyright: ignore[reportUnusedImport]
|
||||||
|
data_source_oauth, # pyright: ignore[reportUnusedImport]
|
||||||
|
forgot_password, # pyright: ignore[reportUnusedImport]
|
||||||
|
login, # pyright: ignore[reportUnusedImport]
|
||||||
|
oauth, # pyright: ignore[reportUnusedImport]
|
||||||
|
oauth_server, # pyright: ignore[reportUnusedImport]
|
||||||
|
)
|
||||||
|
|
||||||
# Import billing controllers
|
# Import billing controllers
|
||||||
from .billing import billing, compliance
|
from .billing import billing, compliance # pyright: ignore[reportUnusedImport]
|
||||||
|
|
||||||
# Import datasets controllers
|
# Import datasets controllers
|
||||||
from .datasets import (
|
from .datasets import (
|
||||||
data_source,
|
data_source, # pyright: ignore[reportUnusedImport]
|
||||||
datasets,
|
datasets, # pyright: ignore[reportUnusedImport]
|
||||||
datasets_document,
|
datasets_document, # pyright: ignore[reportUnusedImport]
|
||||||
datasets_segments,
|
datasets_segments, # pyright: ignore[reportUnusedImport]
|
||||||
external,
|
external, # pyright: ignore[reportUnusedImport]
|
||||||
hit_testing,
|
hit_testing, # pyright: ignore[reportUnusedImport]
|
||||||
metadata,
|
metadata, # pyright: ignore[reportUnusedImport]
|
||||||
website,
|
website, # pyright: ignore[reportUnusedImport]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Import explore controllers
|
# Import explore controllers
|
||||||
from .explore import (
|
from .explore import (
|
||||||
installed_app,
|
installed_app, # pyright: ignore[reportUnusedImport]
|
||||||
parameter,
|
parameter, # pyright: ignore[reportUnusedImport]
|
||||||
recommended_app,
|
recommended_app, # pyright: ignore[reportUnusedImport]
|
||||||
saved_message,
|
saved_message, # pyright: ignore[reportUnusedImport]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Explore Audio
|
# Explore Audio
|
||||||
|
|
@ -167,18 +175,18 @@ api.add_resource(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Import tag controllers
|
# Import tag controllers
|
||||||
from .tag import tags
|
from .tag import tags # pyright: ignore[reportUnusedImport]
|
||||||
|
|
||||||
# Import workspace controllers
|
# Import workspace controllers
|
||||||
from .workspace import (
|
from .workspace import (
|
||||||
account,
|
account, # pyright: ignore[reportUnusedImport]
|
||||||
agent_providers,
|
agent_providers, # pyright: ignore[reportUnusedImport]
|
||||||
endpoint,
|
endpoint, # pyright: ignore[reportUnusedImport]
|
||||||
load_balancing_config,
|
load_balancing_config, # pyright: ignore[reportUnusedImport]
|
||||||
members,
|
members, # pyright: ignore[reportUnusedImport]
|
||||||
model_providers,
|
model_providers, # pyright: ignore[reportUnusedImport]
|
||||||
models,
|
models, # pyright: ignore[reportUnusedImport]
|
||||||
plugin,
|
plugin, # pyright: ignore[reportUnusedImport]
|
||||||
tool_providers,
|
tool_providers, # pyright: ignore[reportUnusedImport]
|
||||||
workspace,
|
workspace, # pyright: ignore[reportUnusedImport]
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
from typing import Any, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import flask_restx
|
import flask_restx
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restx import Resource, fields, marshal_with
|
from flask_restx import Resource, fields, marshal_with
|
||||||
|
from flask_restx._http import HTTPStatus
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
@ -40,7 +41,7 @@ def _get_resource(resource_id, tenant_id, resource_model):
|
||||||
).scalar_one_or_none()
|
).scalar_one_or_none()
|
||||||
|
|
||||||
if resource is None:
|
if resource is None:
|
||||||
flask_restx.abort(404, message=f"{resource_model.__name__} not found.")
|
flask_restx.abort(HTTPStatus.NOT_FOUND, message=f"{resource_model.__name__} not found.")
|
||||||
|
|
||||||
return resource
|
return resource
|
||||||
|
|
||||||
|
|
@ -49,7 +50,7 @@ class BaseApiKeyListResource(Resource):
|
||||||
method_decorators = [account_initialization_required, login_required, setup_required]
|
method_decorators = [account_initialization_required, login_required, setup_required]
|
||||||
|
|
||||||
resource_type: str | None = None
|
resource_type: str | None = None
|
||||||
resource_model: Optional[Any] = None
|
resource_model: Optional[type] = None
|
||||||
resource_id_field: str | None = None
|
resource_id_field: str | None = None
|
||||||
token_prefix: str | None = None
|
token_prefix: str | None = None
|
||||||
max_keys = 10
|
max_keys = 10
|
||||||
|
|
@ -82,7 +83,7 @@ class BaseApiKeyListResource(Resource):
|
||||||
|
|
||||||
if current_key_count >= self.max_keys:
|
if current_key_count >= self.max_keys:
|
||||||
flask_restx.abort(
|
flask_restx.abort(
|
||||||
400,
|
HTTPStatus.BAD_REQUEST,
|
||||||
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
|
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
|
||||||
custom="max_keys_exceeded",
|
custom="max_keys_exceeded",
|
||||||
)
|
)
|
||||||
|
|
@ -102,7 +103,7 @@ class BaseApiKeyResource(Resource):
|
||||||
method_decorators = [account_initialization_required, login_required, setup_required]
|
method_decorators = [account_initialization_required, login_required, setup_required]
|
||||||
|
|
||||||
resource_type: str | None = None
|
resource_type: str | None = None
|
||||||
resource_model: Optional[Any] = None
|
resource_model: Optional[type] = None
|
||||||
resource_id_field: str | None = None
|
resource_id_field: str | None = None
|
||||||
|
|
||||||
def delete(self, resource_id, api_key_id):
|
def delete(self, resource_id, api_key_id):
|
||||||
|
|
@ -126,7 +127,7 @@ class BaseApiKeyResource(Resource):
|
||||||
)
|
)
|
||||||
|
|
||||||
if key is None:
|
if key is None:
|
||||||
flask_restx.abort(404, message="API key not found")
|
flask_restx.abort(HTTPStatus.NOT_FOUND, message="API key not found")
|
||||||
|
|
||||||
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
|
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
|
||||||
|
|
@ -115,6 +115,10 @@ class AppListApi(Resource):
|
||||||
raise BadRequest("mode is required")
|
raise BadRequest("mode is required")
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
|
if current_user.current_tenant_id is None:
|
||||||
|
raise ValueError("current_user.current_tenant_id cannot be None")
|
||||||
app = app_service.create_app(current_user.current_tenant_id, args, current_user)
|
app = app_service.create_app(current_user.current_tenant_id, args, current_user)
|
||||||
|
|
||||||
return app, 201
|
return app, 201
|
||||||
|
|
@ -161,14 +165,26 @@ class AppApi(Resource):
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
app_model = app_service.update_app(app_model, args)
|
# Construct ArgsDict from parsed arguments
|
||||||
|
from services.app_service import AppService as AppServiceType
|
||||||
|
|
||||||
|
args_dict: AppServiceType.ArgsDict = {
|
||||||
|
"name": args["name"],
|
||||||
|
"description": args.get("description", ""),
|
||||||
|
"icon_type": args.get("icon_type", ""),
|
||||||
|
"icon": args.get("icon", ""),
|
||||||
|
"icon_background": args.get("icon_background", ""),
|
||||||
|
"use_icon_as_answer_icon": args.get("use_icon_as_answer_icon", False),
|
||||||
|
"max_active_requests": args.get("max_active_requests", 0),
|
||||||
|
}
|
||||||
|
app_model = app_service.update_app(app_model, args_dict)
|
||||||
|
|
||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
|
@get_app_model
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
|
||||||
def delete(self, app_model):
|
def delete(self, app_model):
|
||||||
"""Delete app"""
|
"""Delete app"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
|
|
@ -224,10 +240,10 @@ class AppCopyApi(Resource):
|
||||||
|
|
||||||
|
|
||||||
class AppExportApi(Resource):
|
class AppExportApi(Resource):
|
||||||
|
@get_app_model
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
"""Export app"""
|
"""Export app"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
|
|
@ -263,7 +279,7 @@ class AppNameApi(Resource):
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
app_model = app_service.update_app_name(app_model, args.get("name"))
|
app_model = app_service.update_app_name(app_model, args["name"])
|
||||||
|
|
||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
|
|
@ -285,7 +301,7 @@ class AppIconApi(Resource):
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
app_model = app_service.update_app_icon(app_model, args.get("icon"), args.get("icon_background"))
|
app_model = app_service.update_app_icon(app_model, args.get("icon") or "", args.get("icon_background") or "")
|
||||||
|
|
||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
|
|
@ -306,7 +322,7 @@ class AppSiteStatus(Resource):
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
app_model = app_service.update_app_site_status(app_model, args.get("enable_site"))
|
app_model = app_service.update_app_site_status(app_model, args["enable_site"])
|
||||||
|
|
||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
|
|
@ -327,7 +343,7 @@ class AppApiStatus(Resource):
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
app_model = app_service.update_app_api_status(app_model, args.get("enable_api"))
|
app_model = app_service.update_app_api_status(app_model, args["enable_api"])
|
||||||
|
|
||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -77,10 +77,10 @@ class ChatMessageAudioApi(Resource):
|
||||||
|
|
||||||
|
|
||||||
class ChatMessageTextApi(Resource):
|
class ChatMessageTextApi(Resource):
|
||||||
|
@get_app_model
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
|
||||||
def post(self, app_model: App):
|
def post(self, app_model: App):
|
||||||
try:
|
try:
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
|
@ -125,10 +125,10 @@ class ChatMessageTextApi(Resource):
|
||||||
|
|
||||||
|
|
||||||
class TextModesApi(Resource):
|
class TextModesApi(Resource):
|
||||||
|
@get_app_model
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
try:
|
try:
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import flask_login
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
from werkzeug.exceptions import InternalServerError, NotFound
|
from werkzeug.exceptions import InternalServerError, NotFound
|
||||||
|
|
@ -29,7 +28,8 @@ from core.helper.trace_id_helper import get_external_trace_id
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.helper import uuid_value
|
from libs.helper import uuid_value
|
||||||
from libs.login import login_required
|
from libs.login import current_user, login_required
|
||||||
|
from models import Account
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
from services.app_generate_service import AppGenerateService
|
from services.app_generate_service import AppGenerateService
|
||||||
from services.errors.llm import InvokeRateLimitError
|
from services.errors.llm import InvokeRateLimitError
|
||||||
|
|
@ -56,11 +56,11 @@ class CompletionMessageApi(Resource):
|
||||||
streaming = args["response_mode"] != "blocking"
|
streaming = args["response_mode"] != "blocking"
|
||||||
args["auto_generate_name"] = False
|
args["auto_generate_name"] = False
|
||||||
|
|
||||||
account = flask_login.current_user
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account or EndUser instance")
|
||||||
response = AppGenerateService.generate(
|
response = AppGenerateService.generate(
|
||||||
app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
|
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
|
||||||
)
|
)
|
||||||
|
|
||||||
return helper.compact_generate_response(response)
|
return helper.compact_generate_response(response)
|
||||||
|
|
@ -92,9 +92,9 @@ class CompletionMessageStopApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=AppMode.COMPLETION)
|
@get_app_model(mode=AppMode.COMPLETION)
|
||||||
def post(self, app_model, task_id):
|
def post(self, app_model, task_id):
|
||||||
account = flask_login.current_user
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||||
|
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
@ -123,11 +123,11 @@ class ChatMessageApi(Resource):
|
||||||
if external_trace_id:
|
if external_trace_id:
|
||||||
args["external_trace_id"] = external_trace_id
|
args["external_trace_id"] = external_trace_id
|
||||||
|
|
||||||
account = flask_login.current_user
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account or EndUser instance")
|
||||||
response = AppGenerateService.generate(
|
response = AppGenerateService.generate(
|
||||||
app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
|
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
|
||||||
)
|
)
|
||||||
|
|
||||||
return helper.compact_generate_response(response)
|
return helper.compact_generate_response(response)
|
||||||
|
|
@ -161,9 +161,9 @@ class ChatMessageStopApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||||
def post(self, app_model, task_id):
|
def post(self, app_model, task_id):
|
||||||
account = flask_login.current_user
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||||
|
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ from fields.conversation_fields import (
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.helper import DatetimeString
|
from libs.helper import DatetimeString
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models import Conversation, EndUser, Message, MessageAnnotation
|
from models import Account, Conversation, EndUser, Message, MessageAnnotation
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
from services.conversation_service import ConversationService
|
from services.conversation_service import ConversationService
|
||||||
from services.errors.conversation import ConversationNotExistsError
|
from services.errors.conversation import ConversationNotExistsError
|
||||||
|
|
@ -124,6 +124,8 @@ class CompletionConversationDetailApi(Resource):
|
||||||
conversation_id = str(conversation_id)
|
conversation_id = str(conversation_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
ConversationService.delete(app_model, conversation_id, current_user)
|
ConversationService.delete(app_model, conversation_id, current_user)
|
||||||
except ConversationNotExistsError:
|
except ConversationNotExistsError:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
|
@ -282,6 +284,8 @@ class ChatConversationDetailApi(Resource):
|
||||||
conversation_id = str(conversation_id)
|
conversation_id = str(conversation_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
ConversationService.delete(app_model, conversation_id, current_user)
|
ConversationService.delete(app_model, conversation_id, current_user)
|
||||||
except ConversationNotExistsError:
|
except ConversationNotExistsError:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||||
from flask_restx.inputs import int_range
|
from flask_restx.inputs import int_range
|
||||||
from sqlalchemy import exists, select
|
from sqlalchemy import exists, select
|
||||||
|
|
@ -27,7 +26,8 @@ from extensions.ext_database import db
|
||||||
from fields.conversation_fields import annotation_fields, message_detail_fields
|
from fields.conversation_fields import annotation_fields, message_detail_fields
|
||||||
from libs.helper import uuid_value
|
from libs.helper import uuid_value
|
||||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||||
from libs.login import login_required
|
from libs.login import current_user, login_required
|
||||||
|
from models.account import Account
|
||||||
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||||
from services.annotation_service import AppAnnotationService
|
from services.annotation_service import AppAnnotationService
|
||||||
from services.errors.conversation import ConversationNotExistsError
|
from services.errors.conversation import ConversationNotExistsError
|
||||||
|
|
@ -118,11 +118,14 @@ class ChatMessageListApi(Resource):
|
||||||
|
|
||||||
|
|
||||||
class MessageFeedbackApi(Resource):
|
class MessageFeedbackApi(Resource):
|
||||||
|
@get_app_model
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
|
if current_user is None:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("message_id", required=True, type=uuid_value, location="json")
|
parser.add_argument("message_id", required=True, type=uuid_value, location="json")
|
||||||
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
||||||
|
|
@ -167,6 +170,8 @@ class MessageAnnotationApi(Resource):
|
||||||
@get_app_model
|
@get_app_model
|
||||||
@marshal_with(annotation_fields)
|
@marshal_with(annotation_fields)
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise Forbidden()
|
||||||
if not current_user.is_editor:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
|
|
@ -182,10 +187,10 @@ class MessageAnnotationApi(Resource):
|
||||||
|
|
||||||
|
|
||||||
class MessageAnnotationCountApi(Resource):
|
class MessageAnnotationCountApi(Resource):
|
||||||
|
@get_app_model
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count()
|
count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from extensions.ext_database import db
|
||||||
from fields.app_fields import app_site_fields
|
from fields.app_fields import app_site_fields
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models import Site
|
from models import Account, Site
|
||||||
|
|
||||||
|
|
||||||
def parse_app_site_args():
|
def parse_app_site_args():
|
||||||
|
|
@ -75,6 +75,8 @@ class AppSite(Resource):
|
||||||
if value is not None:
|
if value is not None:
|
||||||
setattr(site, attr_name, value)
|
setattr(site, attr_name, value)
|
||||||
|
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
site.updated_by = current_user.id
|
site.updated_by = current_user.id
|
||||||
site.updated_at = naive_utc_now()
|
site.updated_at = naive_utc_now()
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
@ -99,6 +101,8 @@ class AppSiteAccessTokenReset(Resource):
|
||||||
raise NotFound
|
raise NotFound
|
||||||
|
|
||||||
site.code = Site.generate_code(16)
|
site.code = Site.generate_code(16)
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
site.updated_by = current_user.id
|
site.updated_by = current_user.id
|
||||||
site.updated_at = naive_utc_now()
|
site.updated_at = naive_utc_now()
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
|
||||||
|
|
@ -18,10 +18,10 @@ from models import AppMode, Message
|
||||||
|
|
||||||
|
|
||||||
class DailyMessageStatistic(Resource):
|
class DailyMessageStatistic(Resource):
|
||||||
|
@get_app_model
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
|
|
@ -75,10 +75,10 @@ WHERE
|
||||||
|
|
||||||
|
|
||||||
class DailyConversationStatistic(Resource):
|
class DailyConversationStatistic(Resource):
|
||||||
|
@get_app_model
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
|
|
@ -127,10 +127,10 @@ class DailyConversationStatistic(Resource):
|
||||||
|
|
||||||
|
|
||||||
class DailyTerminalsStatistic(Resource):
|
class DailyTerminalsStatistic(Resource):
|
||||||
|
@get_app_model
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
|
|
@ -184,10 +184,10 @@ WHERE
|
||||||
|
|
||||||
|
|
||||||
class DailyTokenCostStatistic(Resource):
|
class DailyTokenCostStatistic(Resource):
|
||||||
|
@get_app_model
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
|
|
@ -320,10 +320,10 @@ ORDER BY
|
||||||
|
|
||||||
|
|
||||||
class UserSatisfactionRateStatistic(Resource):
|
class UserSatisfactionRateStatistic(Resource):
|
||||||
|
@get_app_model
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
|
|
@ -443,10 +443,10 @@ WHERE
|
||||||
|
|
||||||
|
|
||||||
class TokensPerSecondStatistic(Resource):
|
class TokensPerSecondStatistic(Resource):
|
||||||
|
@get_app_model
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,10 +18,10 @@ from models.model import AppMode
|
||||||
|
|
||||||
|
|
||||||
class WorkflowDailyRunsStatistic(Resource):
|
class WorkflowDailyRunsStatistic(Resource):
|
||||||
|
@get_app_model
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
|
|
@ -80,10 +80,10 @@ WHERE
|
||||||
|
|
||||||
|
|
||||||
class WorkflowDailyTerminalsStatistic(Resource):
|
class WorkflowDailyTerminalsStatistic(Resource):
|
||||||
|
@get_app_model
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
|
|
@ -142,10 +142,10 @@ WHERE
|
||||||
|
|
||||||
|
|
||||||
class WorkflowDailyTokenCostStatistic(Resource):
|
class WorkflowDailyTokenCostStatistic(Resource):
|
||||||
|
@get_app_model
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -77,6 +77,9 @@ class OAuthCallback(Resource):
|
||||||
if state:
|
if state:
|
||||||
invite_token = state
|
invite_token = state
|
||||||
|
|
||||||
|
if not code:
|
||||||
|
return {"error": "Authorization code is required"}, 400
|
||||||
|
|
||||||
try:
|
try:
|
||||||
token = oauth_provider.get_access_token(code)
|
token = oauth_provider.get_access_token(code)
|
||||||
user_info = oauth_provider.get_user_info(token)
|
user_info = oauth_provider.get_user_info(token)
|
||||||
|
|
@ -86,7 +89,7 @@ class OAuthCallback(Resource):
|
||||||
return {"error": "OAuth process failed"}, 400
|
return {"error": "OAuth process failed"}, 400
|
||||||
|
|
||||||
if invite_token and RegisterService.is_valid_invite_token(invite_token):
|
if invite_token and RegisterService.is_valid_invite_token(invite_token):
|
||||||
invitation = RegisterService._get_invitation_by_token(token=invite_token)
|
invitation = RegisterService.get_invitation_by_token(token=invite_token)
|
||||||
if invitation:
|
if invitation:
|
||||||
invitation_email = invitation.get("email", None)
|
invitation_email = invitation.get("email", None)
|
||||||
if invitation_email != user_info.email:
|
if invitation_email != user_info.email:
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import reqparse
|
from flask_restx import reqparse
|
||||||
from werkzeug.exceptions import InternalServerError, NotFound
|
from werkzeug.exceptions import InternalServerError, NotFound
|
||||||
|
|
||||||
|
|
@ -28,6 +27,8 @@ from extensions.ext_database import db
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.helper import uuid_value
|
from libs.helper import uuid_value
|
||||||
|
from libs.login import current_user
|
||||||
|
from models import Account
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
from services.app_generate_service import AppGenerateService
|
from services.app_generate_service import AppGenerateService
|
||||||
from services.errors.llm import InvokeRateLimitError
|
from services.errors.llm import InvokeRateLimitError
|
||||||
|
|
@ -57,6 +58,8 @@ class CompletionApi(InstalledAppResource):
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
response = AppGenerateService.generate(
|
response = AppGenerateService.generate(
|
||||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
|
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
|
||||||
)
|
)
|
||||||
|
|
@ -90,6 +93,8 @@ class CompletionStopApi(InstalledAppResource):
|
||||||
if app_model.mode != "completion":
|
if app_model.mode != "completion":
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||||
|
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
@ -117,6 +122,8 @@ class ChatApi(InstalledAppResource):
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
response = AppGenerateService.generate(
|
response = AppGenerateService.generate(
|
||||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||||
)
|
)
|
||||||
|
|
@ -153,6 +160,8 @@ class ChatStopApi(InstalledAppResource):
|
||||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||||
|
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import marshal_with, reqparse
|
from flask_restx import marshal_with, reqparse
|
||||||
from flask_restx.inputs import int_range
|
from flask_restx.inputs import int_range
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
@ -10,6 +9,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||||
from libs.helper import uuid_value
|
from libs.helper import uuid_value
|
||||||
|
from libs.login import current_user
|
||||||
|
from models import Account
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
from services.conversation_service import ConversationService
|
from services.conversation_service import ConversationService
|
||||||
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
|
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
|
||||||
|
|
@ -35,6 +36,8 @@ class ConversationListApi(InstalledAppResource):
|
||||||
pinned = args["pinned"] == "true"
|
pinned = args["pinned"] == "true"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
return WebConversationService.pagination_by_last_id(
|
return WebConversationService.pagination_by_last_id(
|
||||||
session=session,
|
session=session,
|
||||||
|
|
@ -58,6 +61,8 @@ class ConversationApi(InstalledAppResource):
|
||||||
|
|
||||||
conversation_id = str(c_id)
|
conversation_id = str(c_id)
|
||||||
try:
|
try:
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
ConversationService.delete(app_model, conversation_id, current_user)
|
ConversationService.delete(app_model, conversation_id, current_user)
|
||||||
except ConversationNotExistsError:
|
except ConversationNotExistsError:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
|
@ -81,6 +86,8 @@ class ConversationRenameApi(InstalledAppResource):
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
return ConversationService.rename(
|
return ConversationService.rename(
|
||||||
app_model, conversation_id, current_user, args["name"], args["auto_generate"]
|
app_model, conversation_id, current_user, args["name"], args["auto_generate"]
|
||||||
)
|
)
|
||||||
|
|
@ -98,6 +105,8 @@ class ConversationPinApi(InstalledAppResource):
|
||||||
conversation_id = str(c_id)
|
conversation_id = str(c_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
WebConversationService.pin(app_model, conversation_id, current_user)
|
WebConversationService.pin(app_model, conversation_id, current_user)
|
||||||
except ConversationNotExistsError:
|
except ConversationNotExistsError:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
|
@ -113,6 +122,8 @@ class ConversationUnPinApi(InstalledAppResource):
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
conversation_id = str(c_id)
|
conversation_id = str(c_id)
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
WebConversationService.unpin(app_model, conversation_id, current_user)
|
WebConversationService.unpin(app_model, conversation_id, current_user)
|
||||||
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@ import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, inputs, marshal_with, reqparse
|
from flask_restx import Resource, inputs, marshal_with, reqparse
|
||||||
from sqlalchemy import and_
|
from sqlalchemy import and_
|
||||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||||
|
|
@ -13,8 +12,8 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.installed_app_fields import installed_app_list_fields
|
from fields.installed_app_fields import installed_app_list_fields
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.login import login_required
|
from libs.login import current_user, login_required
|
||||||
from models import App, InstalledApp, RecommendedApp
|
from models import Account, App, InstalledApp, RecommendedApp
|
||||||
from services.account_service import TenantService
|
from services.account_service import TenantService
|
||||||
from services.app_service import AppService
|
from services.app_service import AppService
|
||||||
from services.enterprise.enterprise_service import EnterpriseService
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
|
|
@ -29,6 +28,8 @@ class InstalledAppsListApi(Resource):
|
||||||
@marshal_with(installed_app_list_fields)
|
@marshal_with(installed_app_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
app_id = request.args.get("app_id", default=None, type=str)
|
app_id = request.args.get("app_id", default=None, type=str)
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
current_tenant_id = current_user.current_tenant_id
|
current_tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
if app_id:
|
if app_id:
|
||||||
|
|
@ -40,6 +41,8 @@ class InstalledAppsListApi(Resource):
|
||||||
else:
|
else:
|
||||||
installed_apps = db.session.query(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id).all()
|
installed_apps = db.session.query(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id).all()
|
||||||
|
|
||||||
|
if current_user.current_tenant is None:
|
||||||
|
raise ValueError("current_user.current_tenant must not be None")
|
||||||
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
|
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
|
||||||
installed_app_list: list[dict[str, Any]] = [
|
installed_app_list: list[dict[str, Any]] = [
|
||||||
{
|
{
|
||||||
|
|
@ -115,6 +118,8 @@ class InstalledAppsListApi(Resource):
|
||||||
if recommended_app is None:
|
if recommended_app is None:
|
||||||
raise NotFound("App not found")
|
raise NotFound("App not found")
|
||||||
|
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
current_tenant_id = current_user.current_tenant_id
|
current_tenant_id = current_user.current_tenant_id
|
||||||
app = db.session.query(App).where(App.id == args["app_id"]).first()
|
app = db.session.query(App).where(App.id == args["app_id"]).first()
|
||||||
|
|
||||||
|
|
@ -154,6 +159,8 @@ class InstalledAppApi(InstalledAppResource):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def delete(self, installed_app):
|
def delete(self, installed_app):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
|
if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
|
||||||
raise BadRequest("You can't uninstall an app owned by the current tenant")
|
raise BadRequest("You can't uninstall an app owned by the current tenant")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import marshal_with, reqparse
|
from flask_restx import marshal_with, reqparse
|
||||||
from flask_restx.inputs import int_range
|
from flask_restx.inputs import int_range
|
||||||
from werkzeug.exceptions import InternalServerError, NotFound
|
from werkzeug.exceptions import InternalServerError, NotFound
|
||||||
|
|
@ -24,6 +23,8 @@ from core.model_runtime.errors.invoke import InvokeError
|
||||||
from fields.message_fields import message_infinite_scroll_pagination_fields
|
from fields.message_fields import message_infinite_scroll_pagination_fields
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.helper import uuid_value
|
from libs.helper import uuid_value
|
||||||
|
from libs.login import current_user
|
||||||
|
from models import Account
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
from services.app_generate_service import AppGenerateService
|
from services.app_generate_service import AppGenerateService
|
||||||
from services.errors.app import MoreLikeThisDisabledError
|
from services.errors.app import MoreLikeThisDisabledError
|
||||||
|
|
@ -54,6 +55,8 @@ class MessageListApi(InstalledAppResource):
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
return MessageService.pagination_by_first_id(
|
return MessageService.pagination_by_first_id(
|
||||||
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
|
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
|
||||||
)
|
)
|
||||||
|
|
@ -75,6 +78,8 @@ class MessageFeedbackApi(InstalledAppResource):
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
MessageService.create_feedback(
|
MessageService.create_feedback(
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
|
|
@ -105,6 +110,8 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
||||||
streaming = args["response_mode"] == "streaming"
|
streaming = args["response_mode"] == "streaming"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
response = AppGenerateService.generate_more_like_this(
|
response = AppGenerateService.generate_more_like_this(
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
user=current_user,
|
user=current_user,
|
||||||
|
|
@ -142,6 +149,8 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
|
||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
questions = MessageService.get_suggested_questions_after_answer(
|
questions = MessageService.get_suggested_questions_after_answer(
|
||||||
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
|
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,10 @@
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||||
|
|
||||||
from constants.languages import languages
|
from constants.languages import languages
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
from libs.helper import AppIconUrlField
|
from libs.helper import AppIconUrlField
|
||||||
from libs.login import login_required
|
from libs.login import current_user, login_required
|
||||||
from services.recommended_app_service import RecommendedAppService
|
from services.recommended_app_service import RecommendedAppService
|
||||||
|
|
||||||
app_fields = {
|
app_fields = {
|
||||||
|
|
@ -46,8 +45,9 @@ class RecommendedAppListApi(Resource):
|
||||||
parser.add_argument("language", type=str, location="args")
|
parser.add_argument("language", type=str, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.get("language") and args.get("language") in languages:
|
language = args.get("language")
|
||||||
language_prefix = args.get("language")
|
if language and language in languages:
|
||||||
|
language_prefix = language
|
||||||
elif current_user and current_user.interface_language:
|
elif current_user and current_user.interface_language:
|
||||||
language_prefix = current_user.interface_language
|
language_prefix = current_user.interface_language
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import fields, marshal_with, reqparse
|
from flask_restx import fields, marshal_with, reqparse
|
||||||
from flask_restx.inputs import int_range
|
from flask_restx.inputs import int_range
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
@ -8,6 +7,8 @@ from controllers.console.explore.error import NotCompletionAppError
|
||||||
from controllers.console.explore.wraps import InstalledAppResource
|
from controllers.console.explore.wraps import InstalledAppResource
|
||||||
from fields.conversation_fields import message_file_fields
|
from fields.conversation_fields import message_file_fields
|
||||||
from libs.helper import TimestampField, uuid_value
|
from libs.helper import TimestampField, uuid_value
|
||||||
|
from libs.login import current_user
|
||||||
|
from models import Account
|
||||||
from services.errors.message import MessageNotExistsError
|
from services.errors.message import MessageNotExistsError
|
||||||
from services.saved_message_service import SavedMessageService
|
from services.saved_message_service import SavedMessageService
|
||||||
|
|
||||||
|
|
@ -42,6 +43,8 @@ class SavedMessageListApi(InstalledAppResource):
|
||||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])
|
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])
|
||||||
|
|
||||||
def post(self, installed_app):
|
def post(self, installed_app):
|
||||||
|
|
@ -54,6 +57,8 @@ class SavedMessageListApi(InstalledAppResource):
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
SavedMessageService.save(app_model, current_user, args["message_id"])
|
SavedMessageService.save(app_model, current_user, args["message_id"])
|
||||||
except MessageNotExistsError:
|
except MessageNotExistsError:
|
||||||
raise NotFound("Message Not Exists.")
|
raise NotFound("Message Not Exists.")
|
||||||
|
|
@ -70,6 +75,8 @@ class SavedMessageApi(InstalledAppResource):
|
||||||
if app_model.mode != "completion":
|
if app_model.mode != "completion":
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
SavedMessageService.delete(app_model, current_user, message_id)
|
SavedMessageService.delete(app_model, current_user, message_id)
|
||||||
|
|
||||||
return {"result": "success"}, 204
|
return {"result": "success"}, 204
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ from controllers.console.wraps import (
|
||||||
)
|
)
|
||||||
from fields.file_fields import file_fields, upload_config_fields
|
from fields.file_fields import file_fields, upload_config_fields
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
|
from models import Account
|
||||||
from services.file_service import FileService
|
from services.file_service import FileService
|
||||||
|
|
||||||
PREVIEW_WORDS_LIMIT = 3000
|
PREVIEW_WORDS_LIMIT = 3000
|
||||||
|
|
@ -68,6 +69,8 @@ class FileApi(Resource):
|
||||||
source = None
|
source = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
upload_file = FileService.upload_file(
|
upload_file = FileService.upload_file(
|
||||||
filename=file.filename,
|
filename=file.filename,
|
||||||
content=file.read(),
|
content=file.read(),
|
||||||
|
|
|
||||||
|
|
@ -34,14 +34,14 @@ class VersionApi(Resource):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = requests.get(check_update_url, {"current_version": args.get("current_version")}, timeout=(3, 10))
|
response = requests.get(check_update_url, {"current_version": args["current_version"]}, timeout=(3, 10))
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
logger.warning("Check update version error: %s.", str(error))
|
logger.warning("Check update version error: %s.", str(error))
|
||||||
result["version"] = args.get("current_version")
|
result["version"] = args["current_version"]
|
||||||
return result
|
return result
|
||||||
|
|
||||||
content = json.loads(response.content)
|
content = json.loads(response.content)
|
||||||
if _has_new_version(latest_version=content["version"], current_version=f"{args.get('current_version')}"):
|
if _has_new_version(latest_version=content["version"], current_version=f"{args['current_version']}"):
|
||||||
result["version"] = content["version"]
|
result["version"] = content["version"]
|
||||||
result["release_date"] = content["releaseDate"]
|
result["release_date"] = content["releaseDate"]
|
||||||
result["release_notes"] = content["releaseNotes"]
|
result["release_notes"] = content["releaseNotes"]
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,8 @@ class AccountInitApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
def post(self):
|
def post(self):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
if account.status == "active":
|
if account.status == "active":
|
||||||
|
|
@ -102,6 +104,8 @@ class AccountProfileApi(Resource):
|
||||||
@marshal_with(account_fields)
|
@marshal_with(account_fields)
|
||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
def get(self):
|
def get(self):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -111,6 +115,8 @@ class AccountNameApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(account_fields)
|
@marshal_with(account_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("name", type=str, required=True, location="json")
|
parser.add_argument("name", type=str, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
@ -130,6 +136,8 @@ class AccountAvatarApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(account_fields)
|
@marshal_with(account_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("avatar", type=str, required=True, location="json")
|
parser.add_argument("avatar", type=str, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
@ -145,6 +153,8 @@ class AccountInterfaceLanguageApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(account_fields)
|
@marshal_with(account_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("interface_language", type=supported_language, required=True, location="json")
|
parser.add_argument("interface_language", type=supported_language, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
@ -160,6 +170,8 @@ class AccountInterfaceThemeApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(account_fields)
|
@marshal_with(account_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json")
|
parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
@ -175,6 +187,8 @@ class AccountTimezoneApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(account_fields)
|
@marshal_with(account_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("timezone", type=str, required=True, location="json")
|
parser.add_argument("timezone", type=str, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
@ -194,6 +208,8 @@ class AccountPasswordApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(account_fields)
|
@marshal_with(account_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("password", type=str, required=False, location="json")
|
parser.add_argument("password", type=str, required=False, location="json")
|
||||||
parser.add_argument("new_password", type=str, required=True, location="json")
|
parser.add_argument("new_password", type=str, required=True, location="json")
|
||||||
|
|
@ -228,6 +244,8 @@ class AccountIntegrateApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(integrate_list_fields)
|
@marshal_with(integrate_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
account_integrates = db.session.query(AccountIntegrate).where(AccountIntegrate.account_id == account.id).all()
|
account_integrates = db.session.query(AccountIntegrate).where(AccountIntegrate.account_id == account.id).all()
|
||||||
|
|
@ -268,6 +286,8 @@ class AccountDeleteVerifyApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
token, code = AccountService.generate_account_deletion_verification_code(account)
|
token, code = AccountService.generate_account_deletion_verification_code(account)
|
||||||
|
|
@ -281,6 +301,8 @@ class AccountDeleteApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
|
@ -321,6 +343,8 @@ class EducationVerifyApi(Resource):
|
||||||
@cloud_edition_billing_enabled
|
@cloud_edition_billing_enabled
|
||||||
@marshal_with(verify_fields)
|
@marshal_with(verify_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
return BillingService.EducationIdentity.verify(account.id, account.email)
|
return BillingService.EducationIdentity.verify(account.id, account.email)
|
||||||
|
|
@ -340,6 +364,8 @@ class EducationApi(Resource):
|
||||||
@only_edition_cloud
|
@only_edition_cloud
|
||||||
@cloud_edition_billing_enabled
|
@cloud_edition_billing_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
|
@ -357,6 +383,8 @@ class EducationApi(Resource):
|
||||||
@cloud_edition_billing_enabled
|
@cloud_edition_billing_enabled
|
||||||
@marshal_with(status_fields)
|
@marshal_with(status_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
res = BillingService.EducationIdentity.status(account.id)
|
res = BillingService.EducationIdentity.status(account.id)
|
||||||
|
|
@ -421,6 +449,8 @@ class ChangeEmailSendEmailApi(Resource):
|
||||||
raise InvalidTokenError()
|
raise InvalidTokenError()
|
||||||
user_email = reset_data.get("email", "")
|
user_email = reset_data.get("email", "")
|
||||||
|
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
if user_email != current_user.email:
|
if user_email != current_user.email:
|
||||||
raise InvalidEmailError()
|
raise InvalidEmailError()
|
||||||
else:
|
else:
|
||||||
|
|
@ -501,6 +531,8 @@ class ChangeEmailResetApi(Resource):
|
||||||
AccountService.revoke_change_email_token(args["token"])
|
AccountService.revoke_change_email_token(args["token"])
|
||||||
|
|
||||||
old_email = reset_data.get("old_email", "")
|
old_email = reset_data.get("old_email", "")
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
if current_user.email != old_email:
|
if current_user.email != old_email:
|
||||||
raise AccountNotFound()
|
raise AccountNotFound()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
from urllib import parse
|
from urllib import parse
|
||||||
|
|
||||||
from flask import request
|
from flask import abort, request
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restx import Resource, abort, marshal_with, reqparse
|
from flask_restx import Resource, marshal_with, reqparse
|
||||||
|
|
||||||
import services
|
import services
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
|
@ -41,6 +41,10 @@ class MemberListApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(account_with_role_list_fields)
|
@marshal_with(account_with_role_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
|
if not current_user.current_tenant:
|
||||||
|
raise ValueError("No current tenant")
|
||||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||||
return {"result": "success", "accounts": members}, 200
|
return {"result": "success", "accounts": members}, 200
|
||||||
|
|
||||||
|
|
@ -65,7 +69,11 @@ class MemberInviteEmailApi(Resource):
|
||||||
if not TenantAccountRole.is_non_owner_role(invitee_role):
|
if not TenantAccountRole.is_non_owner_role(invitee_role):
|
||||||
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||||
|
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
inviter = current_user
|
inviter = current_user
|
||||||
|
if not inviter.current_tenant:
|
||||||
|
raise ValueError("No current tenant")
|
||||||
invitation_results = []
|
invitation_results = []
|
||||||
console_web_url = dify_config.CONSOLE_WEB_URL
|
console_web_url = dify_config.CONSOLE_WEB_URL
|
||||||
|
|
||||||
|
|
@ -76,6 +84,8 @@ class MemberInviteEmailApi(Resource):
|
||||||
|
|
||||||
for invitee_email in invitee_emails:
|
for invitee_email in invitee_emails:
|
||||||
try:
|
try:
|
||||||
|
if not inviter.current_tenant:
|
||||||
|
raise ValueError("No current tenant")
|
||||||
token = RegisterService.invite_new_member(
|
token = RegisterService.invite_new_member(
|
||||||
inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
|
inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
|
||||||
)
|
)
|
||||||
|
|
@ -97,7 +107,7 @@ class MemberInviteEmailApi(Resource):
|
||||||
return {
|
return {
|
||||||
"result": "success",
|
"result": "success",
|
||||||
"invitation_results": invitation_results,
|
"invitation_results": invitation_results,
|
||||||
"tenant_id": str(current_user.current_tenant.id),
|
"tenant_id": str(inviter.current_tenant.id) if inviter.current_tenant else "",
|
||||||
}, 201
|
}, 201
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -108,6 +118,10 @@ class MemberCancelInviteApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, member_id):
|
def delete(self, member_id):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
|
if not current_user.current_tenant:
|
||||||
|
raise ValueError("No current tenant")
|
||||||
member = db.session.query(Account).where(Account.id == str(member_id)).first()
|
member = db.session.query(Account).where(Account.id == str(member_id)).first()
|
||||||
if member is None:
|
if member is None:
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
@ -123,7 +137,10 @@ class MemberCancelInviteApi(Resource):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(str(e))
|
raise ValueError(str(e))
|
||||||
|
|
||||||
return {"result": "success", "tenant_id": str(current_user.current_tenant.id)}, 200
|
return {
|
||||||
|
"result": "success",
|
||||||
|
"tenant_id": str(current_user.current_tenant.id) if current_user.current_tenant else "",
|
||||||
|
}, 200
|
||||||
|
|
||||||
|
|
||||||
class MemberUpdateRoleApi(Resource):
|
class MemberUpdateRoleApi(Resource):
|
||||||
|
|
@ -141,6 +158,10 @@ class MemberUpdateRoleApi(Resource):
|
||||||
if not TenantAccountRole.is_valid_role(new_role):
|
if not TenantAccountRole.is_valid_role(new_role):
|
||||||
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||||
|
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
|
if not current_user.current_tenant:
|
||||||
|
raise ValueError("No current tenant")
|
||||||
member = db.session.get(Account, str(member_id))
|
member = db.session.get(Account, str(member_id))
|
||||||
if not member:
|
if not member:
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
@ -164,6 +185,10 @@ class DatasetOperatorMemberListApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(account_with_role_list_fields)
|
@marshal_with(account_with_role_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
|
if not current_user.current_tenant:
|
||||||
|
raise ValueError("No current tenant")
|
||||||
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
|
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
|
||||||
return {"result": "success", "accounts": members}, 200
|
return {"result": "success", "accounts": members}, 200
|
||||||
|
|
||||||
|
|
@ -184,6 +209,10 @@ class SendOwnerTransferEmailApi(Resource):
|
||||||
raise EmailSendIpLimitError()
|
raise EmailSendIpLimitError()
|
||||||
|
|
||||||
# check if the current user is the owner of the workspace
|
# check if the current user is the owner of the workspace
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
|
if not current_user.current_tenant:
|
||||||
|
raise ValueError("No current tenant")
|
||||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||||
raise NotOwnerError()
|
raise NotOwnerError()
|
||||||
|
|
||||||
|
|
@ -198,7 +227,7 @@ class SendOwnerTransferEmailApi(Resource):
|
||||||
account=current_user,
|
account=current_user,
|
||||||
email=email,
|
email=email,
|
||||||
language=language,
|
language=language,
|
||||||
workspace_name=current_user.current_tenant.name,
|
workspace_name=current_user.current_tenant.name if current_user.current_tenant else "",
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"result": "success", "data": token}
|
return {"result": "success", "data": token}
|
||||||
|
|
@ -215,6 +244,10 @@ class OwnerTransferCheckApi(Resource):
|
||||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
# check if the current user is the owner of the workspace
|
# check if the current user is the owner of the workspace
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
|
if not current_user.current_tenant:
|
||||||
|
raise ValueError("No current tenant")
|
||||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||||
raise NotOwnerError()
|
raise NotOwnerError()
|
||||||
|
|
||||||
|
|
@ -256,6 +289,10 @@ class OwnerTransfer(Resource):
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# check if the current user is the owner of the workspace
|
# check if the current user is the owner of the workspace
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
|
if not current_user.current_tenant:
|
||||||
|
raise ValueError("No current tenant")
|
||||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||||
raise NotOwnerError()
|
raise NotOwnerError()
|
||||||
|
|
||||||
|
|
@ -274,9 +311,11 @@ class OwnerTransfer(Resource):
|
||||||
member = db.session.get(Account, str(member_id))
|
member = db.session.get(Account, str(member_id))
|
||||||
if not member:
|
if not member:
|
||||||
abort(404)
|
abort(404)
|
||||||
else:
|
return # Never reached, but helps type checker
|
||||||
member_account = member
|
|
||||||
if not TenantService.is_member(member_account, current_user.current_tenant):
|
if not current_user.current_tenant:
|
||||||
|
raise ValueError("No current tenant")
|
||||||
|
if not TenantService.is_member(member, current_user.current_tenant):
|
||||||
raise MemberNotInTenantError()
|
raise MemberNotInTenantError()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -286,13 +325,13 @@ class OwnerTransfer(Resource):
|
||||||
AccountService.send_new_owner_transfer_notify_email(
|
AccountService.send_new_owner_transfer_notify_email(
|
||||||
account=member,
|
account=member,
|
||||||
email=member.email,
|
email=member.email,
|
||||||
workspace_name=current_user.current_tenant.name,
|
workspace_name=current_user.current_tenant.name if current_user.current_tenant else "",
|
||||||
)
|
)
|
||||||
|
|
||||||
AccountService.send_old_owner_transfer_notify_email(
|
AccountService.send_old_owner_transfer_notify_email(
|
||||||
account=current_user,
|
account=current_user,
|
||||||
email=current_user.email,
|
email=current_user.email,
|
||||||
workspace_name=current_user.current_tenant.name,
|
workspace_name=current_user.current_tenant.name if current_user.current_tenant else "",
|
||||||
new_owner_email=member.email,
|
new_owner_email=member.email,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from libs.helper import StrLen, uuid_value
|
from libs.helper import StrLen, uuid_value
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
|
from models.account import Account
|
||||||
from services.billing_service import BillingService
|
from services.billing_service import BillingService
|
||||||
from services.model_provider_service import ModelProviderService
|
from services.model_provider_service import ModelProviderService
|
||||||
|
|
||||||
|
|
@ -21,6 +22,10 @@ class ModelProviderListApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
|
if not current_user.current_tenant_id:
|
||||||
|
raise ValueError("No current tenant")
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
|
@ -45,6 +50,10 @@ class ModelProviderCredentialApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider: str):
|
def get(self, provider: str):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
|
if not current_user.current_tenant_id:
|
||||||
|
raise ValueError("No current tenant")
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
# if credential_id is not provided, return current used credential
|
# if credential_id is not provided, return current used credential
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
|
@ -62,6 +71,8 @@ class ModelProviderCredentialApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider: str):
|
def post(self, provider: str):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
|
|
@ -72,6 +83,8 @@ class ModelProviderCredentialApi(Resource):
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
|
|
||||||
|
if not current_user.current_tenant_id:
|
||||||
|
raise ValueError("No current tenant")
|
||||||
try:
|
try:
|
||||||
model_provider_service.create_provider_credential(
|
model_provider_service.create_provider_credential(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
|
|
@ -88,6 +101,8 @@ class ModelProviderCredentialApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def put(self, provider: str):
|
def put(self, provider: str):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
|
|
@ -99,6 +114,8 @@ class ModelProviderCredentialApi(Resource):
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
|
|
||||||
|
if not current_user.current_tenant_id:
|
||||||
|
raise ValueError("No current tenant")
|
||||||
try:
|
try:
|
||||||
model_provider_service.update_provider_credential(
|
model_provider_service.update_provider_credential(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
|
|
@ -116,12 +133,16 @@ class ModelProviderCredentialApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, provider: str):
|
def delete(self, provider: str):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if not current_user.current_tenant_id:
|
||||||
|
raise ValueError("No current tenant")
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
model_provider_service.remove_provider_credential(
|
model_provider_service.remove_provider_credential(
|
||||||
tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"]
|
tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"]
|
||||||
|
|
@ -135,12 +156,16 @@ class ModelProviderCredentialSwitchApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider: str):
|
def post(self, provider: str):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if not current_user.current_tenant_id:
|
||||||
|
raise ValueError("No current tenant")
|
||||||
service = ModelProviderService()
|
service = ModelProviderService()
|
||||||
service.switch_active_provider_credential(
|
service.switch_active_provider_credential(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
|
|
@ -155,10 +180,14 @@ class ModelProviderValidateApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider: str):
|
def post(self, provider: str):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if not current_user.current_tenant_id:
|
||||||
|
raise ValueError("No current tenant")
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
|
|
@ -205,9 +234,13 @@ class PreferredProviderTypeUpdateApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider: str):
|
def post(self, provider: str):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
|
if not current_user.current_tenant_id:
|
||||||
|
raise ValueError("No current tenant")
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
|
@ -236,7 +269,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
|
||||||
def get(self, provider: str):
|
def get(self, provider: str):
|
||||||
if provider != "anthropic":
|
if provider != "anthropic":
|
||||||
raise ValueError(f"provider name {provider} is invalid")
|
raise ValueError(f"provider name {provider} is invalid")
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
BillingService.is_tenant_owner_or_admin(current_user)
|
BillingService.is_tenant_owner_or_admin(current_user)
|
||||||
|
if not current_user.current_tenant_id:
|
||||||
|
raise ValueError("No current tenant")
|
||||||
data = BillingService.get_model_provider_payment_link(
|
data = BillingService.get_model_provider_payment_link(
|
||||||
provider_name=provider,
|
provider_name=provider,
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ from controllers.console.wraps import (
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import TimestampField
|
from libs.helper import TimestampField
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.account import Tenant, TenantStatus
|
from models.account import Account, Tenant, TenantStatus
|
||||||
from services.account_service import TenantService
|
from services.account_service import TenantService
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
from services.file_service import FileService
|
from services.file_service import FileService
|
||||||
|
|
@ -70,6 +70,8 @@ class TenantListApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
tenants = TenantService.get_join_tenants(current_user)
|
tenants = TenantService.get_join_tenants(current_user)
|
||||||
tenant_dicts = []
|
tenant_dicts = []
|
||||||
|
|
||||||
|
|
@ -83,7 +85,7 @@ class TenantListApi(Resource):
|
||||||
"status": tenant.status,
|
"status": tenant.status,
|
||||||
"created_at": tenant.created_at,
|
"created_at": tenant.created_at,
|
||||||
"plan": features.billing.subscription.plan if features.billing.enabled else "sandbox",
|
"plan": features.billing.subscription.plan if features.billing.enabled else "sandbox",
|
||||||
"current": tenant.id == current_user.current_tenant_id,
|
"current": tenant.id == current_user.current_tenant_id if current_user.current_tenant_id else False,
|
||||||
}
|
}
|
||||||
|
|
||||||
tenant_dicts.append(tenant_dict)
|
tenant_dicts.append(tenant_dict)
|
||||||
|
|
@ -125,7 +127,11 @@ class TenantApi(Resource):
|
||||||
if request.path == "/info":
|
if request.path == "/info":
|
||||||
logger.warning("Deprecated URL /info was used.")
|
logger.warning("Deprecated URL /info was used.")
|
||||||
|
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
tenant = current_user.current_tenant
|
tenant = current_user.current_tenant
|
||||||
|
if not tenant:
|
||||||
|
raise ValueError("No current tenant")
|
||||||
|
|
||||||
if tenant.status == TenantStatus.ARCHIVE:
|
if tenant.status == TenantStatus.ARCHIVE:
|
||||||
tenants = TenantService.get_join_tenants(current_user)
|
tenants = TenantService.get_join_tenants(current_user)
|
||||||
|
|
@ -137,6 +143,8 @@ class TenantApi(Resource):
|
||||||
else:
|
else:
|
||||||
raise Unauthorized("workspace is archived")
|
raise Unauthorized("workspace is archived")
|
||||||
|
|
||||||
|
if not tenant:
|
||||||
|
raise ValueError("No tenant available")
|
||||||
return WorkspaceService.get_tenant_info(tenant), 200
|
return WorkspaceService.get_tenant_info(tenant), 200
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -145,6 +153,8 @@ class SwitchWorkspaceApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("tenant_id", type=str, required=True, location="json")
|
parser.add_argument("tenant_id", type=str, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
@ -168,11 +178,15 @@ class CustomConfigWorkspaceApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("workspace_custom")
|
@cloud_edition_billing_resource_check("workspace_custom")
|
||||||
def post(self):
|
def post(self):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("remove_webapp_brand", type=bool, location="json")
|
parser.add_argument("remove_webapp_brand", type=bool, location="json")
|
||||||
parser.add_argument("replace_webapp_logo", type=str, location="json")
|
parser.add_argument("replace_webapp_logo", type=str, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if not current_user.current_tenant_id:
|
||||||
|
raise ValueError("No current tenant")
|
||||||
tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
|
tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
|
||||||
|
|
||||||
custom_config_dict = {
|
custom_config_dict = {
|
||||||
|
|
@ -194,6 +208,8 @@ class WebappLogoWorkspaceApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("workspace_custom")
|
@cloud_edition_billing_resource_check("workspace_custom")
|
||||||
def post(self):
|
def post(self):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
# check file
|
# check file
|
||||||
if "file" not in request.files:
|
if "file" not in request.files:
|
||||||
raise NoFileUploadedError()
|
raise NoFileUploadedError()
|
||||||
|
|
@ -232,10 +248,14 @@ class WorkspaceInfoApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
# Change workspace name
|
# Change workspace name
|
||||||
def post(self):
|
def post(self):
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("name", type=str, required=True, location="json")
|
parser.add_argument("name", type=str, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if not current_user.current_tenant_id:
|
||||||
|
raise ValueError("No current tenant")
|
||||||
tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
|
tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
|
||||||
tenant.name = args["name"]
|
tenant.name = args["name"]
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,6 @@ api = ExternalApi(
|
||||||
|
|
||||||
files_ns = Namespace("files", description="File operations", path="/")
|
files_ns = Namespace("files", description="File operations", path="/")
|
||||||
|
|
||||||
from . import image_preview, tool_files, upload
|
from . import image_preview, tool_files, upload # pyright: ignore[reportUnusedImport]
|
||||||
|
|
||||||
api.add_namespace(files_ns)
|
api.add_namespace(files_ns)
|
||||||
|
|
|
||||||
|
|
@ -16,8 +16,8 @@ api = ExternalApi(
|
||||||
# Create namespace
|
# Create namespace
|
||||||
inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/")
|
inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/")
|
||||||
|
|
||||||
from . import mail
|
from . import mail as _mail # pyright: ignore[reportUnusedImport]
|
||||||
from .plugin import plugin
|
from .plugin import plugin as _plugin # pyright: ignore[reportUnusedImport]
|
||||||
from .workspace import workspace
|
from .workspace import workspace as _workspace # pyright: ignore[reportUnusedImport]
|
||||||
|
|
||||||
api.add_namespace(inner_api_ns)
|
api.add_namespace(inner_api_ns)
|
||||||
|
|
|
||||||
|
|
@ -37,9 +37,9 @@ from models.model import EndUser
|
||||||
|
|
||||||
@inner_api_ns.route("/invoke/llm")
|
@inner_api_ns.route("/invoke/llm")
|
||||||
class PluginInvokeLLMApi(Resource):
|
class PluginInvokeLLMApi(Resource):
|
||||||
|
@get_user_tenant
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
@get_user_tenant
|
|
||||||
@plugin_data(payload_type=RequestInvokeLLM)
|
@plugin_data(payload_type=RequestInvokeLLM)
|
||||||
@inner_api_ns.doc("plugin_invoke_llm")
|
@inner_api_ns.doc("plugin_invoke_llm")
|
||||||
@inner_api_ns.doc(description="Invoke LLM models through plugin interface")
|
@inner_api_ns.doc(description="Invoke LLM models through plugin interface")
|
||||||
|
|
@ -60,9 +60,9 @@ class PluginInvokeLLMApi(Resource):
|
||||||
|
|
||||||
@inner_api_ns.route("/invoke/llm/structured-output")
|
@inner_api_ns.route("/invoke/llm/structured-output")
|
||||||
class PluginInvokeLLMWithStructuredOutputApi(Resource):
|
class PluginInvokeLLMWithStructuredOutputApi(Resource):
|
||||||
|
@get_user_tenant
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
@get_user_tenant
|
|
||||||
@plugin_data(payload_type=RequestInvokeLLMWithStructuredOutput)
|
@plugin_data(payload_type=RequestInvokeLLMWithStructuredOutput)
|
||||||
@inner_api_ns.doc("plugin_invoke_llm_structured")
|
@inner_api_ns.doc("plugin_invoke_llm_structured")
|
||||||
@inner_api_ns.doc(description="Invoke LLM models with structured output through plugin interface")
|
@inner_api_ns.doc(description="Invoke LLM models with structured output through plugin interface")
|
||||||
|
|
@ -85,9 +85,9 @@ class PluginInvokeLLMWithStructuredOutputApi(Resource):
|
||||||
|
|
||||||
@inner_api_ns.route("/invoke/text-embedding")
|
@inner_api_ns.route("/invoke/text-embedding")
|
||||||
class PluginInvokeTextEmbeddingApi(Resource):
|
class PluginInvokeTextEmbeddingApi(Resource):
|
||||||
|
@get_user_tenant
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
@get_user_tenant
|
|
||||||
@plugin_data(payload_type=RequestInvokeTextEmbedding)
|
@plugin_data(payload_type=RequestInvokeTextEmbedding)
|
||||||
@inner_api_ns.doc("plugin_invoke_text_embedding")
|
@inner_api_ns.doc("plugin_invoke_text_embedding")
|
||||||
@inner_api_ns.doc(description="Invoke text embedding models through plugin interface")
|
@inner_api_ns.doc(description="Invoke text embedding models through plugin interface")
|
||||||
|
|
@ -115,9 +115,9 @@ class PluginInvokeTextEmbeddingApi(Resource):
|
||||||
|
|
||||||
@inner_api_ns.route("/invoke/rerank")
|
@inner_api_ns.route("/invoke/rerank")
|
||||||
class PluginInvokeRerankApi(Resource):
|
class PluginInvokeRerankApi(Resource):
|
||||||
|
@get_user_tenant
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
@get_user_tenant
|
|
||||||
@plugin_data(payload_type=RequestInvokeRerank)
|
@plugin_data(payload_type=RequestInvokeRerank)
|
||||||
@inner_api_ns.doc("plugin_invoke_rerank")
|
@inner_api_ns.doc("plugin_invoke_rerank")
|
||||||
@inner_api_ns.doc(description="Invoke rerank models through plugin interface")
|
@inner_api_ns.doc(description="Invoke rerank models through plugin interface")
|
||||||
|
|
@ -141,9 +141,9 @@ class PluginInvokeRerankApi(Resource):
|
||||||
|
|
||||||
@inner_api_ns.route("/invoke/tts")
|
@inner_api_ns.route("/invoke/tts")
|
||||||
class PluginInvokeTTSApi(Resource):
|
class PluginInvokeTTSApi(Resource):
|
||||||
|
@get_user_tenant
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
@get_user_tenant
|
|
||||||
@plugin_data(payload_type=RequestInvokeTTS)
|
@plugin_data(payload_type=RequestInvokeTTS)
|
||||||
@inner_api_ns.doc("plugin_invoke_tts")
|
@inner_api_ns.doc("plugin_invoke_tts")
|
||||||
@inner_api_ns.doc(description="Invoke text-to-speech models through plugin interface")
|
@inner_api_ns.doc(description="Invoke text-to-speech models through plugin interface")
|
||||||
|
|
@ -168,9 +168,9 @@ class PluginInvokeTTSApi(Resource):
|
||||||
|
|
||||||
@inner_api_ns.route("/invoke/speech2text")
|
@inner_api_ns.route("/invoke/speech2text")
|
||||||
class PluginInvokeSpeech2TextApi(Resource):
|
class PluginInvokeSpeech2TextApi(Resource):
|
||||||
|
@get_user_tenant
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
@get_user_tenant
|
|
||||||
@plugin_data(payload_type=RequestInvokeSpeech2Text)
|
@plugin_data(payload_type=RequestInvokeSpeech2Text)
|
||||||
@inner_api_ns.doc("plugin_invoke_speech2text")
|
@inner_api_ns.doc("plugin_invoke_speech2text")
|
||||||
@inner_api_ns.doc(description="Invoke speech-to-text models through plugin interface")
|
@inner_api_ns.doc(description="Invoke speech-to-text models through plugin interface")
|
||||||
|
|
@ -194,9 +194,9 @@ class PluginInvokeSpeech2TextApi(Resource):
|
||||||
|
|
||||||
@inner_api_ns.route("/invoke/moderation")
|
@inner_api_ns.route("/invoke/moderation")
|
||||||
class PluginInvokeModerationApi(Resource):
|
class PluginInvokeModerationApi(Resource):
|
||||||
|
@get_user_tenant
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
@get_user_tenant
|
|
||||||
@plugin_data(payload_type=RequestInvokeModeration)
|
@plugin_data(payload_type=RequestInvokeModeration)
|
||||||
@inner_api_ns.doc("plugin_invoke_moderation")
|
@inner_api_ns.doc("plugin_invoke_moderation")
|
||||||
@inner_api_ns.doc(description="Invoke moderation models through plugin interface")
|
@inner_api_ns.doc(description="Invoke moderation models through plugin interface")
|
||||||
|
|
@ -220,9 +220,9 @@ class PluginInvokeModerationApi(Resource):
|
||||||
|
|
||||||
@inner_api_ns.route("/invoke/tool")
|
@inner_api_ns.route("/invoke/tool")
|
||||||
class PluginInvokeToolApi(Resource):
|
class PluginInvokeToolApi(Resource):
|
||||||
|
@get_user_tenant
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
@get_user_tenant
|
|
||||||
@plugin_data(payload_type=RequestInvokeTool)
|
@plugin_data(payload_type=RequestInvokeTool)
|
||||||
@inner_api_ns.doc("plugin_invoke_tool")
|
@inner_api_ns.doc("plugin_invoke_tool")
|
||||||
@inner_api_ns.doc(description="Invoke tools through plugin interface")
|
@inner_api_ns.doc(description="Invoke tools through plugin interface")
|
||||||
|
|
@ -252,9 +252,9 @@ class PluginInvokeToolApi(Resource):
|
||||||
|
|
||||||
@inner_api_ns.route("/invoke/parameter-extractor")
|
@inner_api_ns.route("/invoke/parameter-extractor")
|
||||||
class PluginInvokeParameterExtractorNodeApi(Resource):
|
class PluginInvokeParameterExtractorNodeApi(Resource):
|
||||||
|
@get_user_tenant
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
@get_user_tenant
|
|
||||||
@plugin_data(payload_type=RequestInvokeParameterExtractorNode)
|
@plugin_data(payload_type=RequestInvokeParameterExtractorNode)
|
||||||
@inner_api_ns.doc("plugin_invoke_parameter_extractor")
|
@inner_api_ns.doc("plugin_invoke_parameter_extractor")
|
||||||
@inner_api_ns.doc(description="Invoke parameter extractor node through plugin interface")
|
@inner_api_ns.doc(description="Invoke parameter extractor node through plugin interface")
|
||||||
|
|
@ -285,9 +285,9 @@ class PluginInvokeParameterExtractorNodeApi(Resource):
|
||||||
|
|
||||||
@inner_api_ns.route("/invoke/question-classifier")
|
@inner_api_ns.route("/invoke/question-classifier")
|
||||||
class PluginInvokeQuestionClassifierNodeApi(Resource):
|
class PluginInvokeQuestionClassifierNodeApi(Resource):
|
||||||
|
@get_user_tenant
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
@get_user_tenant
|
|
||||||
@plugin_data(payload_type=RequestInvokeQuestionClassifierNode)
|
@plugin_data(payload_type=RequestInvokeQuestionClassifierNode)
|
||||||
@inner_api_ns.doc("plugin_invoke_question_classifier")
|
@inner_api_ns.doc("plugin_invoke_question_classifier")
|
||||||
@inner_api_ns.doc(description="Invoke question classifier node through plugin interface")
|
@inner_api_ns.doc(description="Invoke question classifier node through plugin interface")
|
||||||
|
|
@ -318,9 +318,9 @@ class PluginInvokeQuestionClassifierNodeApi(Resource):
|
||||||
|
|
||||||
@inner_api_ns.route("/invoke/app")
|
@inner_api_ns.route("/invoke/app")
|
||||||
class PluginInvokeAppApi(Resource):
|
class PluginInvokeAppApi(Resource):
|
||||||
|
@get_user_tenant
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
@get_user_tenant
|
|
||||||
@plugin_data(payload_type=RequestInvokeApp)
|
@plugin_data(payload_type=RequestInvokeApp)
|
||||||
@inner_api_ns.doc("plugin_invoke_app")
|
@inner_api_ns.doc("plugin_invoke_app")
|
||||||
@inner_api_ns.doc(description="Invoke application through plugin interface")
|
@inner_api_ns.doc(description="Invoke application through plugin interface")
|
||||||
|
|
@ -348,9 +348,9 @@ class PluginInvokeAppApi(Resource):
|
||||||
|
|
||||||
@inner_api_ns.route("/invoke/encrypt")
|
@inner_api_ns.route("/invoke/encrypt")
|
||||||
class PluginInvokeEncryptApi(Resource):
|
class PluginInvokeEncryptApi(Resource):
|
||||||
|
@get_user_tenant
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
@get_user_tenant
|
|
||||||
@plugin_data(payload_type=RequestInvokeEncrypt)
|
@plugin_data(payload_type=RequestInvokeEncrypt)
|
||||||
@inner_api_ns.doc("plugin_invoke_encrypt")
|
@inner_api_ns.doc("plugin_invoke_encrypt")
|
||||||
@inner_api_ns.doc(description="Encrypt or decrypt data through plugin interface")
|
@inner_api_ns.doc(description="Encrypt or decrypt data through plugin interface")
|
||||||
|
|
@ -375,9 +375,9 @@ class PluginInvokeEncryptApi(Resource):
|
||||||
|
|
||||||
@inner_api_ns.route("/invoke/summary")
|
@inner_api_ns.route("/invoke/summary")
|
||||||
class PluginInvokeSummaryApi(Resource):
|
class PluginInvokeSummaryApi(Resource):
|
||||||
|
@get_user_tenant
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
@get_user_tenant
|
|
||||||
@plugin_data(payload_type=RequestInvokeSummary)
|
@plugin_data(payload_type=RequestInvokeSummary)
|
||||||
@inner_api_ns.doc("plugin_invoke_summary")
|
@inner_api_ns.doc("plugin_invoke_summary")
|
||||||
@inner_api_ns.doc(description="Invoke summary functionality through plugin interface")
|
@inner_api_ns.doc(description="Invoke summary functionality through plugin interface")
|
||||||
|
|
@ -405,9 +405,9 @@ class PluginInvokeSummaryApi(Resource):
|
||||||
|
|
||||||
@inner_api_ns.route("/upload/file/request")
|
@inner_api_ns.route("/upload/file/request")
|
||||||
class PluginUploadFileRequestApi(Resource):
|
class PluginUploadFileRequestApi(Resource):
|
||||||
|
@get_user_tenant
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
@get_user_tenant
|
|
||||||
@plugin_data(payload_type=RequestRequestUploadFile)
|
@plugin_data(payload_type=RequestRequestUploadFile)
|
||||||
@inner_api_ns.doc("plugin_upload_file_request")
|
@inner_api_ns.doc("plugin_upload_file_request")
|
||||||
@inner_api_ns.doc(description="Request signed URL for file upload through plugin interface")
|
@inner_api_ns.doc(description="Request signed URL for file upload through plugin interface")
|
||||||
|
|
@ -426,9 +426,9 @@ class PluginUploadFileRequestApi(Resource):
|
||||||
|
|
||||||
@inner_api_ns.route("/fetch/app/info")
|
@inner_api_ns.route("/fetch/app/info")
|
||||||
class PluginFetchAppInfoApi(Resource):
|
class PluginFetchAppInfoApi(Resource):
|
||||||
|
@get_user_tenant
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
@get_user_tenant
|
|
||||||
@plugin_data(payload_type=RequestFetchAppInfo)
|
@plugin_data(payload_type=RequestFetchAppInfo)
|
||||||
@inner_api_ns.doc("plugin_fetch_app_info")
|
@inner_api_ns.doc("plugin_fetch_app_info")
|
||||||
@inner_api_ns.doc(description="Fetch application information through plugin interface")
|
@inner_api_ns.doc(description="Fetch application information through plugin interface")
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Optional, ParamSpec, TypeVar
|
from typing import Optional, ParamSpec, TypeVar, cast
|
||||||
|
|
||||||
from flask import current_app, request
|
from flask import current_app, request
|
||||||
from flask_login import user_logged_in
|
from flask_login import user_logged_in
|
||||||
|
|
@ -10,7 +10,7 @@ from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from core.file.constants import DEFAULT_SERVICE_API_USER_ID
|
from core.file.constants import DEFAULT_SERVICE_API_USER_ID
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.login import _get_user
|
from libs.login import current_user
|
||||||
from models.account import Tenant
|
from models.account import Tenant
|
||||||
from models.model import EndUser
|
from models.model import EndUser
|
||||||
|
|
||||||
|
|
@ -66,8 +66,8 @@ def get_user_tenant(view: Optional[Callable[P, R]] = None):
|
||||||
|
|
||||||
p = parser.parse_args()
|
p = parser.parse_args()
|
||||||
|
|
||||||
user_id: Optional[str] = p.get("user_id")
|
user_id = cast(str, p.get("user_id"))
|
||||||
tenant_id: str = p.get("tenant_id")
|
tenant_id = cast(str, p.get("tenant_id"))
|
||||||
|
|
||||||
if not tenant_id:
|
if not tenant_id:
|
||||||
raise ValueError("tenant_id is required")
|
raise ValueError("tenant_id is required")
|
||||||
|
|
@ -98,7 +98,7 @@ def get_user_tenant(view: Optional[Callable[P, R]] = None):
|
||||||
kwargs["user_model"] = user
|
kwargs["user_model"] = user
|
||||||
|
|
||||||
current_app.login_manager._update_request_context_with_user(user) # type: ignore
|
current_app.login_manager._update_request_context_with_user(user) # type: ignore
|
||||||
user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
|
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
|
||||||
|
|
||||||
return view_func(*args, **kwargs)
|
return view_func(*args, **kwargs)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,6 @@ api = ExternalApi(
|
||||||
|
|
||||||
mcp_ns = Namespace("mcp", description="MCP operations", path="/")
|
mcp_ns = Namespace("mcp", description="MCP operations", path="/")
|
||||||
|
|
||||||
from . import mcp
|
from . import mcp # pyright: ignore[reportUnusedImport]
|
||||||
|
|
||||||
api.add_namespace(mcp_ns)
|
api.add_namespace(mcp_ns)
|
||||||
|
|
|
||||||
|
|
@ -15,9 +15,27 @@ api = ExternalApi(
|
||||||
|
|
||||||
service_api_ns = Namespace("service_api", description="Service operations", path="/")
|
service_api_ns = Namespace("service_api", description="Service operations", path="/")
|
||||||
|
|
||||||
from . import index
|
from . import index # pyright: ignore[reportUnusedImport]
|
||||||
from .app import annotation, app, audio, completion, conversation, file, file_preview, message, site, workflow
|
from .app import (
|
||||||
from .dataset import dataset, document, hit_testing, metadata, segment, upload_file
|
annotation, # pyright: ignore[reportUnusedImport]
|
||||||
from .workspace import models
|
app, # pyright: ignore[reportUnusedImport]
|
||||||
|
audio, # pyright: ignore[reportUnusedImport]
|
||||||
|
completion, # pyright: ignore[reportUnusedImport]
|
||||||
|
conversation, # pyright: ignore[reportUnusedImport]
|
||||||
|
file, # pyright: ignore[reportUnusedImport]
|
||||||
|
file_preview, # pyright: ignore[reportUnusedImport]
|
||||||
|
message, # pyright: ignore[reportUnusedImport]
|
||||||
|
site, # pyright: ignore[reportUnusedImport]
|
||||||
|
workflow, # pyright: ignore[reportUnusedImport]
|
||||||
|
)
|
||||||
|
from .dataset import (
|
||||||
|
dataset, # pyright: ignore[reportUnusedImport]
|
||||||
|
document, # pyright: ignore[reportUnusedImport]
|
||||||
|
hit_testing, # pyright: ignore[reportUnusedImport]
|
||||||
|
metadata, # pyright: ignore[reportUnusedImport]
|
||||||
|
segment, # pyright: ignore[reportUnusedImport]
|
||||||
|
upload_file, # pyright: ignore[reportUnusedImport]
|
||||||
|
)
|
||||||
|
from .workspace import models # pyright: ignore[reportUnusedImport]
|
||||||
|
|
||||||
api.add_namespace(service_api_ns)
|
api.add_namespace(service_api_ns)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
|
from flask_restx._http import HTTPStatus
|
||||||
from flask_restx.inputs import int_range
|
from flask_restx.inputs import int_range
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import BadRequest, NotFound
|
from werkzeug.exceptions import BadRequest, NotFound
|
||||||
|
|
@ -121,7 +122,7 @@ class ConversationDetailApi(Resource):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||||
@service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=204)
|
@service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=HTTPStatus.NO_CONTENT)
|
||||||
def delete(self, app_model: App, end_user: EndUser, c_id):
|
def delete(self, app_model: App, end_user: EndUser, c_id):
|
||||||
"""Delete a specific conversation."""
|
"""Delete a specific conversation."""
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@ from extensions.ext_database import db
|
||||||
from fields.document_fields import document_fields, document_status_fields
|
from fields.document_fields import document_fields, document_status_fields
|
||||||
from libs.login import current_user
|
from libs.login import current_user
|
||||||
from models.dataset import Dataset, Document, DocumentSegment
|
from models.dataset import Dataset, Document, DocumentSegment
|
||||||
|
from models.model import EndUser
|
||||||
from services.dataset_service import DatasetService, DocumentService
|
from services.dataset_service import DatasetService, DocumentService
|
||||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
|
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
|
||||||
from services.file_service import FileService
|
from services.file_service import FileService
|
||||||
|
|
@ -298,6 +299,9 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||||
if not file.filename:
|
if not file.filename:
|
||||||
raise FilenameNotExistsError
|
raise FilenameNotExistsError
|
||||||
|
|
||||||
|
if not isinstance(current_user, EndUser):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
|
|
||||||
upload_file = FileService.upload_file(
|
upload_file = FileService.upload_file(
|
||||||
filename=file.filename,
|
filename=file.filename,
|
||||||
content=file.read(),
|
content=file.read(),
|
||||||
|
|
@ -387,6 +391,8 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
||||||
raise FilenameNotExistsError
|
raise FilenameNotExistsError
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if not isinstance(current_user, EndUser):
|
||||||
|
raise ValueError("Invalid user account")
|
||||||
upload_file = FileService.upload_file(
|
upload_file = FileService.upload_file(
|
||||||
filename=file.filename,
|
filename=file.filename,
|
||||||
content=file.read(),
|
content=file.read(),
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ from core.file.constants import DEFAULT_SERVICE_API_USER_ID
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.login import _get_user
|
from libs.login import current_user
|
||||||
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
|
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
|
||||||
from models.dataset import Dataset, RateLimitLog
|
from models.dataset import Dataset, RateLimitLog
|
||||||
from models.model import ApiToken, App, EndUser
|
from models.model import ApiToken, App, EndUser
|
||||||
|
|
@ -210,7 +210,7 @@ def validate_dataset_token(view: Optional[Callable[Concatenate[T, P], R]] = None
|
||||||
if account:
|
if account:
|
||||||
account.current_tenant = tenant
|
account.current_tenant = tenant
|
||||||
current_app.login_manager._update_request_context_with_user(account) # type: ignore
|
current_app.login_manager._update_request_context_with_user(account) # type: ignore
|
||||||
user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
|
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
|
||||||
else:
|
else:
|
||||||
raise Unauthorized("Tenant owner account does not exist.")
|
raise Unauthorized("Tenant owner account does not exist.")
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -17,20 +17,20 @@ api = ExternalApi(
|
||||||
web_ns = Namespace("web", description="Web application API operations", path="/")
|
web_ns = Namespace("web", description="Web application API operations", path="/")
|
||||||
|
|
||||||
from . import (
|
from . import (
|
||||||
app,
|
app, # pyright: ignore[reportUnusedImport]
|
||||||
audio,
|
audio, # pyright: ignore[reportUnusedImport]
|
||||||
completion,
|
completion, # pyright: ignore[reportUnusedImport]
|
||||||
conversation,
|
conversation, # pyright: ignore[reportUnusedImport]
|
||||||
feature,
|
feature, # pyright: ignore[reportUnusedImport]
|
||||||
files,
|
files, # pyright: ignore[reportUnusedImport]
|
||||||
forgot_password,
|
forgot_password, # pyright: ignore[reportUnusedImport]
|
||||||
login,
|
login, # pyright: ignore[reportUnusedImport]
|
||||||
message,
|
message, # pyright: ignore[reportUnusedImport]
|
||||||
passport,
|
passport, # pyright: ignore[reportUnusedImport]
|
||||||
remote_files,
|
remote_files, # pyright: ignore[reportUnusedImport]
|
||||||
saved_message,
|
saved_message, # pyright: ignore[reportUnusedImport]
|
||||||
site,
|
site, # pyright: ignore[reportUnusedImport]
|
||||||
workflow,
|
workflow, # pyright: ignore[reportUnusedImport]
|
||||||
)
|
)
|
||||||
|
|
||||||
api.add_namespace(web_ns)
|
api.add_namespace(web_ns)
|
||||||
|
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
import core.moderation.base
|
|
||||||
|
|
@ -72,6 +72,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||||
function_call_state = True
|
function_call_state = True
|
||||||
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
|
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
|
||||||
final_answer = ""
|
final_answer = ""
|
||||||
|
prompt_messages: list = [] # Initialize prompt_messages
|
||||||
|
agent_thought_id = "" # Initialize agent_thought_id
|
||||||
|
|
||||||
def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
|
def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
|
||||||
if not final_llm_usage_dict["usage"]:
|
if not final_llm_usage_dict["usage"]:
|
||||||
|
|
|
||||||
|
|
@ -54,6 +54,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||||
function_call_state = True
|
function_call_state = True
|
||||||
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
|
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
|
||||||
final_answer = ""
|
final_answer = ""
|
||||||
|
prompt_messages: list = [] # Initialize prompt_messages
|
||||||
|
|
||||||
# get tracing instance
|
# get tracing instance
|
||||||
trace_manager = app_generate_entity.trace_manager
|
trace_manager = app_generate_entity.trace_manager
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ class SensitiveWordAvoidanceConfigManager:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_and_set_defaults(
|
def validate_and_set_defaults(
|
||||||
cls, tenant_id, config: dict, only_structure_validate: bool = False
|
cls, tenant_id: str, config: dict, only_structure_validate: bool = False
|
||||||
) -> tuple[dict, list[str]]:
|
) -> tuple[dict, list[str]]:
|
||||||
if not config.get("sensitive_word_avoidance"):
|
if not config.get("sensitive_word_avoidance"):
|
||||||
config["sensitive_word_avoidance"] = {"enabled": False}
|
config["sensitive_word_avoidance"] = {"enabled": False}
|
||||||
|
|
@ -38,7 +38,14 @@ class SensitiveWordAvoidanceConfigManager:
|
||||||
|
|
||||||
if not only_structure_validate:
|
if not only_structure_validate:
|
||||||
typ = config["sensitive_word_avoidance"]["type"]
|
typ = config["sensitive_word_avoidance"]["type"]
|
||||||
sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"]
|
if not isinstance(typ, str):
|
||||||
|
raise ValueError("sensitive_word_avoidance.type must be a string")
|
||||||
|
|
||||||
|
sensitive_word_avoidance_config = config["sensitive_word_avoidance"].get("config")
|
||||||
|
if sensitive_word_avoidance_config is None:
|
||||||
|
sensitive_word_avoidance_config = {}
|
||||||
|
if not isinstance(sensitive_word_avoidance_config, dict):
|
||||||
|
raise ValueError("sensitive_word_avoidance.config must be a dict")
|
||||||
|
|
||||||
ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config)
|
ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,10 +25,14 @@ class PromptTemplateConfigManager:
|
||||||
if chat_prompt_config:
|
if chat_prompt_config:
|
||||||
chat_prompt_messages = []
|
chat_prompt_messages = []
|
||||||
for message in chat_prompt_config.get("prompt", []):
|
for message in chat_prompt_config.get("prompt", []):
|
||||||
|
text = message.get("text")
|
||||||
|
if not isinstance(text, str):
|
||||||
|
raise ValueError("message text must be a string")
|
||||||
|
role = message.get("role")
|
||||||
|
if not isinstance(role, str):
|
||||||
|
raise ValueError("message role must be a string")
|
||||||
chat_prompt_messages.append(
|
chat_prompt_messages.append(
|
||||||
AdvancedChatMessageEntity(
|
AdvancedChatMessageEntity(text=text, role=PromptMessageRole.value_of(role))
|
||||||
**{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)
|
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)
|
||||||
|
|
|
||||||
|
|
@ -71,7 +71,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||||
yield "ping"
|
yield "ping"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
response_chunk = {
|
response_chunk: dict[str, Any] = {
|
||||||
"event": sub_stream_response.event.value,
|
"event": sub_stream_response.event.value,
|
||||||
"conversation_id": chunk.conversation_id,
|
"conversation_id": chunk.conversation_id,
|
||||||
"message_id": chunk.message_id,
|
"message_id": chunk.message_id,
|
||||||
|
|
@ -82,7 +82,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||||
response_chunk.update(data)
|
response_chunk.update(data)
|
||||||
else:
|
else:
|
||||||
response_chunk.update(sub_stream_response.to_dict())
|
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||||
yield response_chunk
|
yield response_chunk
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -102,7 +102,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||||
yield "ping"
|
yield "ping"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
response_chunk = {
|
response_chunk: dict[str, Any] = {
|
||||||
"event": sub_stream_response.event.value,
|
"event": sub_stream_response.event.value,
|
||||||
"conversation_id": chunk.conversation_id,
|
"conversation_id": chunk.conversation_id,
|
||||||
"message_id": chunk.message_id,
|
"message_id": chunk.message_id,
|
||||||
|
|
@ -110,7 +110,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||||
}
|
}
|
||||||
|
|
||||||
if isinstance(sub_stream_response, MessageEndStreamResponse):
|
if isinstance(sub_stream_response, MessageEndStreamResponse):
|
||||||
sub_stream_response_dict = sub_stream_response.to_dict()
|
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
|
||||||
metadata = sub_stream_response_dict.get("metadata", {})
|
metadata = sub_stream_response_dict.get("metadata", {})
|
||||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||||
response_chunk.update(sub_stream_response_dict)
|
response_chunk.update(sub_stream_response_dict)
|
||||||
|
|
@ -118,8 +118,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||||
response_chunk.update(data)
|
response_chunk.update(data)
|
||||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
|
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
|
||||||
else:
|
else:
|
||||||
response_chunk.update(sub_stream_response.to_dict())
|
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||||
|
|
||||||
yield response_chunk
|
yield response_chunk
|
||||||
|
|
|
||||||
|
|
@ -174,7 +174,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||||
|
|
||||||
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
||||||
|
|
||||||
if self._base_task_pipeline._stream:
|
if self._base_task_pipeline.stream:
|
||||||
return self._to_stream_response(generator)
|
return self._to_stream_response(generator)
|
||||||
else:
|
else:
|
||||||
return self._to_blocking_response(generator)
|
return self._to_blocking_response(generator)
|
||||||
|
|
@ -302,13 +302,13 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||||
|
|
||||||
def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
|
def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
|
||||||
"""Handle ping events."""
|
"""Handle ping events."""
|
||||||
yield self._base_task_pipeline._ping_stream_response()
|
yield self._base_task_pipeline.ping_stream_response()
|
||||||
|
|
||||||
def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
|
def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
|
||||||
"""Handle error events."""
|
"""Handle error events."""
|
||||||
with self._database_session() as session:
|
with self._database_session() as session:
|
||||||
err = self._base_task_pipeline._handle_error(event=event, session=session, message_id=self._message_id)
|
err = self._base_task_pipeline.handle_error(event=event, session=session, message_id=self._message_id)
|
||||||
yield self._base_task_pipeline._error_to_stream_response(err)
|
yield self._base_task_pipeline.error_to_stream_response(err)
|
||||||
|
|
||||||
def _handle_workflow_started_event(self, *args, **kwargs) -> Generator[StreamResponse, None, None]:
|
def _handle_workflow_started_event(self, *args, **kwargs) -> Generator[StreamResponse, None, None]:
|
||||||
"""Handle workflow started events."""
|
"""Handle workflow started events."""
|
||||||
|
|
@ -627,10 +627,10 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||||
workflow_execution=workflow_execution,
|
workflow_execution=workflow_execution,
|
||||||
)
|
)
|
||||||
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}"))
|
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}"))
|
||||||
err = self._base_task_pipeline._handle_error(event=err_event, session=session, message_id=self._message_id)
|
err = self._base_task_pipeline.handle_error(event=err_event, session=session, message_id=self._message_id)
|
||||||
|
|
||||||
yield workflow_finish_resp
|
yield workflow_finish_resp
|
||||||
yield self._base_task_pipeline._error_to_stream_response(err)
|
yield self._base_task_pipeline.error_to_stream_response(err)
|
||||||
|
|
||||||
def _handle_stop_event(
|
def _handle_stop_event(
|
||||||
self,
|
self,
|
||||||
|
|
@ -683,7 +683,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||||
"""Handle advanced chat message end events."""
|
"""Handle advanced chat message end events."""
|
||||||
self._ensure_graph_runtime_initialized(graph_runtime_state)
|
self._ensure_graph_runtime_initialized(graph_runtime_state)
|
||||||
|
|
||||||
output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
|
output_moderation_answer = self._base_task_pipeline.handle_output_moderation_when_task_finished(
|
||||||
self._task_state.answer
|
self._task_state.answer
|
||||||
)
|
)
|
||||||
if output_moderation_answer:
|
if output_moderation_answer:
|
||||||
|
|
@ -899,7 +899,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||||
|
|
||||||
message.answer = answer_text
|
message.answer = answer_text
|
||||||
message.updated_at = naive_utc_now()
|
message.updated_at = naive_utc_now()
|
||||||
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
|
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline.start_at
|
||||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||||
message_files = [
|
message_files = [
|
||||||
MessageFile(
|
MessageFile(
|
||||||
|
|
@ -955,9 +955,9 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||||
:param text: text
|
:param text: text
|
||||||
:return: True if output moderation should direct output, otherwise False
|
:return: True if output moderation should direct output, otherwise False
|
||||||
"""
|
"""
|
||||||
if self._base_task_pipeline._output_moderation_handler:
|
if self._base_task_pipeline.output_moderation_handler:
|
||||||
if self._base_task_pipeline._output_moderation_handler.should_direct_output():
|
if self._base_task_pipeline.output_moderation_handler.should_direct_output():
|
||||||
self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
|
self._task_state.answer = self._base_task_pipeline.output_moderation_handler.get_final_output()
|
||||||
self._base_task_pipeline.queue_manager.publish(
|
self._base_task_pipeline.queue_manager.publish(
|
||||||
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
|
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
|
||||||
)
|
)
|
||||||
|
|
@ -967,7 +967,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
self._base_task_pipeline._output_moderation_handler.append_new_token(text)
|
self._base_task_pipeline.output_moderation_handler.append_new_token(text)
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
from core.agent.entities import AgentEntity
|
from core.agent.entities import AgentEntity
|
||||||
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
|
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
|
||||||
|
|
@ -160,7 +160,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
||||||
return filtered_config
|
return filtered_config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_agent_mode_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
|
def validate_agent_mode_and_set_defaults(
|
||||||
|
cls, tenant_id: str, config: dict[str, Any]
|
||||||
|
) -> tuple[dict[str, Any], list[str]]:
|
||||||
"""
|
"""
|
||||||
Validate agent_mode and set defaults for agent feature
|
Validate agent_mode and set defaults for agent feature
|
||||||
|
|
||||||
|
|
@ -170,30 +172,32 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
||||||
if not config.get("agent_mode"):
|
if not config.get("agent_mode"):
|
||||||
config["agent_mode"] = {"enabled": False, "tools": []}
|
config["agent_mode"] = {"enabled": False, "tools": []}
|
||||||
|
|
||||||
if not isinstance(config["agent_mode"], dict):
|
agent_mode = config["agent_mode"]
|
||||||
|
if not isinstance(agent_mode, dict):
|
||||||
raise ValueError("agent_mode must be of object type")
|
raise ValueError("agent_mode must be of object type")
|
||||||
|
|
||||||
if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]:
|
# FIXME(-LAN-): Cast needed due to basedpyright limitation with dict type narrowing
|
||||||
config["agent_mode"]["enabled"] = False
|
agent_mode = cast(dict[str, Any], agent_mode)
|
||||||
|
|
||||||
if not isinstance(config["agent_mode"]["enabled"], bool):
|
if "enabled" not in agent_mode or not agent_mode["enabled"]:
|
||||||
|
agent_mode["enabled"] = False
|
||||||
|
|
||||||
|
if not isinstance(agent_mode["enabled"], bool):
|
||||||
raise ValueError("enabled in agent_mode must be of boolean type")
|
raise ValueError("enabled in agent_mode must be of boolean type")
|
||||||
|
|
||||||
if not config["agent_mode"].get("strategy"):
|
if not agent_mode.get("strategy"):
|
||||||
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
|
agent_mode["strategy"] = PlanningStrategy.ROUTER.value
|
||||||
|
|
||||||
if config["agent_mode"]["strategy"] not in [
|
if agent_mode["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]:
|
||||||
member.value for member in list(PlanningStrategy.__members__.values())
|
|
||||||
]:
|
|
||||||
raise ValueError("strategy in agent_mode must be in the specified strategy list")
|
raise ValueError("strategy in agent_mode must be in the specified strategy list")
|
||||||
|
|
||||||
if not config["agent_mode"].get("tools"):
|
if not agent_mode.get("tools"):
|
||||||
config["agent_mode"]["tools"] = []
|
agent_mode["tools"] = []
|
||||||
|
|
||||||
if not isinstance(config["agent_mode"]["tools"], list):
|
if not isinstance(agent_mode["tools"], list):
|
||||||
raise ValueError("tools in agent_mode must be a list of objects")
|
raise ValueError("tools in agent_mode must be a list of objects")
|
||||||
|
|
||||||
for tool in config["agent_mode"]["tools"]:
|
for tool in agent_mode["tools"]:
|
||||||
key = list(tool.keys())[0]
|
key = list(tool.keys())[0]
|
||||||
if key in OLD_TOOLS:
|
if key in OLD_TOOLS:
|
||||||
# old style, use tool name as key
|
# old style, use tool name as key
|
||||||
|
|
|
||||||
|
|
@ -46,7 +46,10 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||||
response = cls.convert_blocking_full_response(blocking_response)
|
response = cls.convert_blocking_full_response(blocking_response)
|
||||||
|
|
||||||
metadata = response.get("metadata", {})
|
metadata = response.get("metadata", {})
|
||||||
response["metadata"] = cls._get_simple_metadata(metadata)
|
if isinstance(metadata, dict):
|
||||||
|
response["metadata"] = cls._get_simple_metadata(metadata)
|
||||||
|
else:
|
||||||
|
response["metadata"] = {}
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
@ -78,7 +81,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||||
response_chunk.update(data)
|
response_chunk.update(data)
|
||||||
else:
|
else:
|
||||||
response_chunk.update(sub_stream_response.to_dict())
|
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||||
yield response_chunk
|
yield response_chunk
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -106,7 +109,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||||
}
|
}
|
||||||
|
|
||||||
if isinstance(sub_stream_response, MessageEndStreamResponse):
|
if isinstance(sub_stream_response, MessageEndStreamResponse):
|
||||||
sub_stream_response_dict = sub_stream_response.to_dict()
|
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
|
||||||
metadata = sub_stream_response_dict.get("metadata", {})
|
metadata = sub_stream_response_dict.get("metadata", {})
|
||||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||||
response_chunk.update(sub_stream_response_dict)
|
response_chunk.update(sub_stream_response_dict)
|
||||||
|
|
@ -114,6 +117,6 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||||
response_chunk.update(data)
|
response_chunk.update(data)
|
||||||
else:
|
else:
|
||||||
response_chunk.update(sub_stream_response.to_dict())
|
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||||
|
|
||||||
yield response_chunk
|
yield response_chunk
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,7 @@ class AppQueueManager:
|
||||||
self._task_id = task_id
|
self._task_id = task_id
|
||||||
self._user_id = user_id
|
self._user_id = user_id
|
||||||
self._invoke_from = invoke_from
|
self._invoke_from = invoke_from
|
||||||
|
self.invoke_from = invoke_from # Public accessor for invoke_from
|
||||||
|
|
||||||
user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
|
user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
|
||||||
redis_client.setex(
|
redis_client.setex(
|
||||||
|
|
|
||||||
|
|
@ -46,7 +46,10 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||||
response = cls.convert_blocking_full_response(blocking_response)
|
response = cls.convert_blocking_full_response(blocking_response)
|
||||||
|
|
||||||
metadata = response.get("metadata", {})
|
metadata = response.get("metadata", {})
|
||||||
response["metadata"] = cls._get_simple_metadata(metadata)
|
if isinstance(metadata, dict):
|
||||||
|
response["metadata"] = cls._get_simple_metadata(metadata)
|
||||||
|
else:
|
||||||
|
response["metadata"] = {}
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
@ -78,7 +81,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||||
response_chunk.update(data)
|
response_chunk.update(data)
|
||||||
else:
|
else:
|
||||||
response_chunk.update(sub_stream_response.to_dict())
|
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||||
yield response_chunk
|
yield response_chunk
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -106,7 +109,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||||
}
|
}
|
||||||
|
|
||||||
if isinstance(sub_stream_response, MessageEndStreamResponse):
|
if isinstance(sub_stream_response, MessageEndStreamResponse):
|
||||||
sub_stream_response_dict = sub_stream_response.to_dict()
|
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
|
||||||
metadata = sub_stream_response_dict.get("metadata", {})
|
metadata = sub_stream_response_dict.get("metadata", {})
|
||||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||||
response_chunk.update(sub_stream_response_dict)
|
response_chunk.update(sub_stream_response_dict)
|
||||||
|
|
@ -114,6 +117,6 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||||
response_chunk.update(data)
|
response_chunk.update(data)
|
||||||
else:
|
else:
|
||||||
response_chunk.update(sub_stream_response.to_dict())
|
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||||
|
|
||||||
yield response_chunk
|
yield response_chunk
|
||||||
|
|
|
||||||
|
|
@ -271,6 +271,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||||
raise MoreLikeThisDisabledError()
|
raise MoreLikeThisDisabledError()
|
||||||
|
|
||||||
app_model_config = message.app_model_config
|
app_model_config = message.app_model_config
|
||||||
|
if not app_model_config:
|
||||||
|
raise ValueError("Message app_model_config is None")
|
||||||
override_model_config_dict = app_model_config.to_dict()
|
override_model_config_dict = app_model_config.to_dict()
|
||||||
model_dict = override_model_config_dict["model"]
|
model_dict = override_model_config_dict["model"]
|
||||||
completion_params = model_dict.get("completion_params")
|
completion_params = model_dict.get("completion_params")
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,10 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||||
response = cls.convert_blocking_full_response(blocking_response)
|
response = cls.convert_blocking_full_response(blocking_response)
|
||||||
|
|
||||||
metadata = response.get("metadata", {})
|
metadata = response.get("metadata", {})
|
||||||
response["metadata"] = cls._get_simple_metadata(metadata)
|
if isinstance(metadata, dict):
|
||||||
|
response["metadata"] = cls._get_simple_metadata(metadata)
|
||||||
|
else:
|
||||||
|
response["metadata"] = {}
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
@ -76,7 +79,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||||
response_chunk.update(data)
|
response_chunk.update(data)
|
||||||
else:
|
else:
|
||||||
response_chunk.update(sub_stream_response.to_dict())
|
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||||
yield response_chunk
|
yield response_chunk
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -103,14 +106,16 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||||
}
|
}
|
||||||
|
|
||||||
if isinstance(sub_stream_response, MessageEndStreamResponse):
|
if isinstance(sub_stream_response, MessageEndStreamResponse):
|
||||||
sub_stream_response_dict = sub_stream_response.to_dict()
|
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
|
||||||
metadata = sub_stream_response_dict.get("metadata", {})
|
metadata = sub_stream_response_dict.get("metadata", {})
|
||||||
|
if not isinstance(metadata, dict):
|
||||||
|
metadata = {}
|
||||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||||
response_chunk.update(sub_stream_response_dict)
|
response_chunk.update(sub_stream_response_dict)
|
||||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||||
response_chunk.update(data)
|
response_chunk.update(data)
|
||||||
else:
|
else:
|
||||||
response_chunk.update(sub_stream_response.to_dict())
|
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||||
|
|
||||||
yield response_chunk
|
yield response_chunk
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||||
:param blocking_response: blocking response
|
:param blocking_response: blocking response
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
return dict(blocking_response.to_dict())
|
return blocking_response.model_dump()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
|
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
|
||||||
|
|
@ -51,7 +51,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||||
yield "ping"
|
yield "ping"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
response_chunk = {
|
response_chunk: dict[str, object] = {
|
||||||
"event": sub_stream_response.event.value,
|
"event": sub_stream_response.event.value,
|
||||||
"workflow_run_id": chunk.workflow_run_id,
|
"workflow_run_id": chunk.workflow_run_id,
|
||||||
}
|
}
|
||||||
|
|
@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||||
response_chunk.update(data)
|
response_chunk.update(data)
|
||||||
else:
|
else:
|
||||||
response_chunk.update(sub_stream_response.to_dict())
|
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||||
yield response_chunk
|
yield response_chunk
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -80,7 +80,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||||
yield "ping"
|
yield "ping"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
response_chunk = {
|
response_chunk: dict[str, object] = {
|
||||||
"event": sub_stream_response.event.value,
|
"event": sub_stream_response.event.value,
|
||||||
"workflow_run_id": chunk.workflow_run_id,
|
"workflow_run_id": chunk.workflow_run_id,
|
||||||
}
|
}
|
||||||
|
|
@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
|
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
|
||||||
else:
|
else:
|
||||||
response_chunk.update(sub_stream_response.to_dict())
|
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||||
yield response_chunk
|
yield response_chunk
|
||||||
|
|
|
||||||
|
|
@ -137,7 +137,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||||
self._application_generate_entity = application_generate_entity
|
self._application_generate_entity = application_generate_entity
|
||||||
self._workflow_features_dict = workflow.features_dict
|
self._workflow_features_dict = workflow.features_dict
|
||||||
self._workflow_run_id = ""
|
self._workflow_run_id = ""
|
||||||
self._invoke_from = queue_manager._invoke_from
|
self._invoke_from = queue_manager.invoke_from
|
||||||
self._draft_var_saver_factory = draft_var_saver_factory
|
self._draft_var_saver_factory = draft_var_saver_factory
|
||||||
|
|
||||||
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||||
|
|
@ -146,7 +146,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
||||||
if self._base_task_pipeline._stream:
|
if self._base_task_pipeline.stream:
|
||||||
return self._to_stream_response(generator)
|
return self._to_stream_response(generator)
|
||||||
else:
|
else:
|
||||||
return self._to_blocking_response(generator)
|
return self._to_blocking_response(generator)
|
||||||
|
|
@ -276,12 +276,12 @@ class WorkflowAppGenerateTaskPipeline:
|
||||||
|
|
||||||
def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
|
def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
|
||||||
"""Handle ping events."""
|
"""Handle ping events."""
|
||||||
yield self._base_task_pipeline._ping_stream_response()
|
yield self._base_task_pipeline.ping_stream_response()
|
||||||
|
|
||||||
def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
|
def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
|
||||||
"""Handle error events."""
|
"""Handle error events."""
|
||||||
err = self._base_task_pipeline._handle_error(event=event)
|
err = self._base_task_pipeline.handle_error(event=event)
|
||||||
yield self._base_task_pipeline._error_to_stream_response(err)
|
yield self._base_task_pipeline.error_to_stream_response(err)
|
||||||
|
|
||||||
def _handle_workflow_started_event(
|
def _handle_workflow_started_event(
|
||||||
self, event: QueueWorkflowStartedEvent, **kwargs
|
self, event: QueueWorkflowStartedEvent, **kwargs
|
||||||
|
|
|
||||||
|
|
@ -123,7 +123,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# app config
|
# app config
|
||||||
app_config: EasyUIBasedAppConfig
|
app_config: EasyUIBasedAppConfig = None # type: ignore
|
||||||
model_conf: ModelConfigWithCredentialsEntity
|
model_conf: ModelConfigWithCredentialsEntity
|
||||||
|
|
||||||
query: Optional[str] = None
|
query: Optional[str] = None
|
||||||
|
|
@ -186,7 +186,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# app config
|
# app config
|
||||||
app_config: WorkflowUIBasedAppConfig
|
app_config: WorkflowUIBasedAppConfig = None # type: ignore
|
||||||
|
|
||||||
workflow_run_id: Optional[str] = None
|
workflow_run_id: Optional[str] = None
|
||||||
query: str
|
query: str
|
||||||
|
|
@ -218,7 +218,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# app config
|
# app config
|
||||||
app_config: WorkflowUIBasedAppConfig
|
app_config: WorkflowUIBasedAppConfig = None # type: ignore
|
||||||
workflow_execution_id: str
|
workflow_execution_id: str
|
||||||
|
|
||||||
class SingleIterationRunEntity(BaseModel):
|
class SingleIterationRunEntity(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ from typing import Any, Optional
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
|
||||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||||
|
|
@ -92,9 +91,6 @@ class StreamResponse(BaseModel):
|
||||||
event: StreamEvent
|
event: StreamEvent
|
||||||
task_id: str
|
task_id: str
|
||||||
|
|
||||||
def to_dict(self):
|
|
||||||
return jsonable_encoder(self)
|
|
||||||
|
|
||||||
|
|
||||||
class ErrorStreamResponse(StreamResponse):
|
class ErrorStreamResponse(StreamResponse):
|
||||||
"""
|
"""
|
||||||
|
|
@ -745,9 +741,6 @@ class AppBlockingResponse(BaseModel):
|
||||||
|
|
||||||
task_id: str
|
task_id: str
|
||||||
|
|
||||||
def to_dict(self):
|
|
||||||
return jsonable_encoder(self)
|
|
||||||
|
|
||||||
|
|
||||||
class ChatbotAppBlockingResponse(AppBlockingResponse):
|
class ChatbotAppBlockingResponse(AppBlockingResponse):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,9 @@ class AnnotationReplyFeature:
|
||||||
|
|
||||||
collection_binding_detail = annotation_setting.collection_binding_detail
|
collection_binding_detail = annotation_setting.collection_binding_detail
|
||||||
|
|
||||||
|
if not collection_binding_detail:
|
||||||
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
score_threshold = annotation_setting.score_threshold or 1
|
score_threshold = annotation_setting.score_threshold or 1
|
||||||
embedding_provider_name = collection_binding_detail.provider_name
|
embedding_provider_name = collection_binding_detail.provider_name
|
||||||
|
|
|
||||||
|
|
@ -1 +1,3 @@
|
||||||
from .rate_limit import RateLimit
|
from .rate_limit import RateLimit
|
||||||
|
|
||||||
|
__all__ = ["RateLimit"]
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ class RateLimit:
|
||||||
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
|
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
|
||||||
_instance_dict: dict[str, "RateLimit"] = {}
|
_instance_dict: dict[str, "RateLimit"] = {}
|
||||||
|
|
||||||
def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int):
|
def __new__(cls, client_id: str, max_active_requests: int):
|
||||||
if client_id not in cls._instance_dict:
|
if client_id not in cls._instance_dict:
|
||||||
instance = super().__new__(cls)
|
instance = super().__new__(cls)
|
||||||
cls._instance_dict[client_id] = instance
|
cls._instance_dict[client_id] = instance
|
||||||
|
|
|
||||||
|
|
@ -38,11 +38,11 @@ class BasedGenerateTaskPipeline:
|
||||||
):
|
):
|
||||||
self._application_generate_entity = application_generate_entity
|
self._application_generate_entity = application_generate_entity
|
||||||
self.queue_manager = queue_manager
|
self.queue_manager = queue_manager
|
||||||
self._start_at = time.perf_counter()
|
self.start_at = time.perf_counter()
|
||||||
self._output_moderation_handler = self._init_output_moderation()
|
self.output_moderation_handler = self._init_output_moderation()
|
||||||
self._stream = stream
|
self.stream = stream
|
||||||
|
|
||||||
def _handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""):
|
def handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""):
|
||||||
logger.debug("error: %s", event.error)
|
logger.debug("error: %s", event.error)
|
||||||
e = event.error
|
e = event.error
|
||||||
err: Exception
|
err: Exception
|
||||||
|
|
@ -86,7 +86,7 @@ class BasedGenerateTaskPipeline:
|
||||||
|
|
||||||
return message
|
return message
|
||||||
|
|
||||||
def _error_to_stream_response(self, e: Exception):
|
def error_to_stream_response(self, e: Exception):
|
||||||
"""
|
"""
|
||||||
Error to stream response.
|
Error to stream response.
|
||||||
:param e: exception
|
:param e: exception
|
||||||
|
|
@ -94,7 +94,7 @@ class BasedGenerateTaskPipeline:
|
||||||
"""
|
"""
|
||||||
return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e)
|
return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e)
|
||||||
|
|
||||||
def _ping_stream_response(self) -> PingStreamResponse:
|
def ping_stream_response(self) -> PingStreamResponse:
|
||||||
"""
|
"""
|
||||||
Ping stream response.
|
Ping stream response.
|
||||||
:return:
|
:return:
|
||||||
|
|
@ -118,21 +118,21 @@ class BasedGenerateTaskPipeline:
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:
|
def handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Handle output moderation when task finished.
|
Handle output moderation when task finished.
|
||||||
:param completion: completion
|
:param completion: completion
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
# response moderation
|
# response moderation
|
||||||
if self._output_moderation_handler:
|
if self.output_moderation_handler:
|
||||||
self._output_moderation_handler.stop_thread()
|
self.output_moderation_handler.stop_thread()
|
||||||
|
|
||||||
completion, flagged = self._output_moderation_handler.moderation_completion(
|
completion, flagged = self.output_moderation_handler.moderation_completion(
|
||||||
completion=completion, public_event=False
|
completion=completion, public_event=False
|
||||||
)
|
)
|
||||||
|
|
||||||
self._output_moderation_handler = None
|
self.output_moderation_handler = None
|
||||||
if flagged:
|
if flagged:
|
||||||
return completion
|
return completion
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -125,7 +125,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||||
)
|
)
|
||||||
|
|
||||||
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
||||||
if self._stream:
|
if self.stream:
|
||||||
return self._to_stream_response(generator)
|
return self._to_stream_response(generator)
|
||||||
else:
|
else:
|
||||||
return self._to_blocking_response(generator)
|
return self._to_blocking_response(generator)
|
||||||
|
|
@ -265,9 +265,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||||
|
|
||||||
if isinstance(event, QueueErrorEvent):
|
if isinstance(event, QueueErrorEvent):
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
err = self._handle_error(event=event, session=session, message_id=self._message_id)
|
err = self.handle_error(event=event, session=session, message_id=self._message_id)
|
||||||
session.commit()
|
session.commit()
|
||||||
yield self._error_to_stream_response(err)
|
yield self.error_to_stream_response(err)
|
||||||
break
|
break
|
||||||
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
|
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
|
||||||
if isinstance(event, QueueMessageEndEvent):
|
if isinstance(event, QueueMessageEndEvent):
|
||||||
|
|
@ -277,7 +277,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||||
self._handle_stop(event)
|
self._handle_stop(event)
|
||||||
|
|
||||||
# handle output moderation
|
# handle output moderation
|
||||||
output_moderation_answer = self._handle_output_moderation_when_task_finished(
|
output_moderation_answer = self.handle_output_moderation_when_task_finished(
|
||||||
cast(str, self._task_state.llm_result.message.content)
|
cast(str, self._task_state.llm_result.message.content)
|
||||||
)
|
)
|
||||||
if output_moderation_answer:
|
if output_moderation_answer:
|
||||||
|
|
@ -354,7 +354,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||||
elif isinstance(event, QueueMessageReplaceEvent):
|
elif isinstance(event, QueueMessageReplaceEvent):
|
||||||
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text)
|
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text)
|
||||||
elif isinstance(event, QueuePingEvent):
|
elif isinstance(event, QueuePingEvent):
|
||||||
yield self._ping_stream_response()
|
yield self.ping_stream_response()
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
if publisher:
|
if publisher:
|
||||||
|
|
@ -394,7 +394,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||||
message.answer_tokens = usage.completion_tokens
|
message.answer_tokens = usage.completion_tokens
|
||||||
message.answer_unit_price = usage.completion_unit_price
|
message.answer_unit_price = usage.completion_unit_price
|
||||||
message.answer_price_unit = usage.completion_price_unit
|
message.answer_price_unit = usage.completion_price_unit
|
||||||
message.provider_response_latency = time.perf_counter() - self._start_at
|
message.provider_response_latency = time.perf_counter() - self.start_at
|
||||||
message.total_price = usage.total_price
|
message.total_price = usage.total_price
|
||||||
message.currency = usage.currency
|
message.currency = usage.currency
|
||||||
self._task_state.llm_result.usage.latency = message.provider_response_latency
|
self._task_state.llm_result.usage.latency = message.provider_response_latency
|
||||||
|
|
@ -438,7 +438,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||||
# transform usage
|
# transform usage
|
||||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||||
self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
|
self._task_state.llm_result.usage = model_type_instance.calc_response_usage(
|
||||||
model, credentials, prompt_tokens, completion_tokens
|
model, credentials, prompt_tokens, completion_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -498,10 +498,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||||
:param text: text
|
:param text: text
|
||||||
:return: True if output moderation should direct output, otherwise False
|
:return: True if output moderation should direct output, otherwise False
|
||||||
"""
|
"""
|
||||||
if self._output_moderation_handler:
|
if self.output_moderation_handler:
|
||||||
if self._output_moderation_handler.should_direct_output():
|
if self.output_moderation_handler.should_direct_output():
|
||||||
# stop subscribe new token when output moderation should direct output
|
# stop subscribe new token when output moderation should direct output
|
||||||
self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
|
self._task_state.llm_result.message.content = self.output_moderation_handler.get_final_output()
|
||||||
self.queue_manager.publish(
|
self.queue_manager.publish(
|
||||||
QueueLLMChunkEvent(
|
QueueLLMChunkEvent(
|
||||||
chunk=LLMResultChunk(
|
chunk=LLMResultChunk(
|
||||||
|
|
@ -521,6 +521,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
self._output_moderation_handler.append_new_token(text)
|
self.output_moderation_handler.append_new_token(text)
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
|
||||||
|
|
@ -72,7 +72,7 @@ class AppGeneratorTTSPublisher:
|
||||||
self.voice = voice
|
self.voice = voice
|
||||||
if not voice or voice not in values:
|
if not voice or voice not in values:
|
||||||
self.voice = self.voices[0].get("value")
|
self.voice = self.voices[0].get("value")
|
||||||
self.MAX_SENTENCE = 2
|
self.max_sentence = 2
|
||||||
self._last_audio_event: Optional[AudioTrunk] = None
|
self._last_audio_event: Optional[AudioTrunk] = None
|
||||||
# FIXME better way to handle this threading.start
|
# FIXME better way to handle this threading.start
|
||||||
threading.Thread(target=self._runtime).start()
|
threading.Thread(target=self._runtime).start()
|
||||||
|
|
@ -113,8 +113,8 @@ class AppGeneratorTTSPublisher:
|
||||||
self.msg_text += message.event.outputs.get("output", "")
|
self.msg_text += message.event.outputs.get("output", "")
|
||||||
self.last_message = message
|
self.last_message = message
|
||||||
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
|
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
|
||||||
if len(sentence_arr) >= min(self.MAX_SENTENCE, 7):
|
if len(sentence_arr) >= min(self.max_sentence, 7):
|
||||||
self.MAX_SENTENCE += 1
|
self.max_sentence += 1
|
||||||
text_content = "".join(sentence_arr)
|
text_content = "".join(sentence_arr)
|
||||||
futures_result = self.executor.submit(
|
futures_result = self.executor.submit(
|
||||||
_invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
|
_invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
|
||||||
|
|
|
||||||
|
|
@ -1840,8 +1840,14 @@ class ProviderConfigurations(BaseModel):
|
||||||
def __setitem__(self, key, value):
|
def __setitem__(self, key, value):
|
||||||
self.configurations[key] = value
|
self.configurations[key] = value
|
||||||
|
|
||||||
|
def __contains__(self, key):
|
||||||
|
if "/" not in key:
|
||||||
|
key = str(ModelProviderID(key))
|
||||||
|
return key in self.configurations
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return iter(self.configurations)
|
# Return an iterator of (key, value) tuples to match BaseModel's __iter__
|
||||||
|
yield from self.configurations.items()
|
||||||
|
|
||||||
def values(self) -> Iterator[ProviderConfiguration]:
|
def values(self) -> Iterator[ProviderConfiguration]:
|
||||||
return iter(self.configurations.values())
|
return iter(self.configurations.values())
|
||||||
|
|
|
||||||
|
|
@ -98,7 +98,7 @@ def to_prompt_message_content(
|
||||||
|
|
||||||
def download(f: File, /):
|
def download(f: File, /):
|
||||||
if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE):
|
if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE):
|
||||||
return _download_file_content(f._storage_key)
|
return _download_file_content(f.storage_key)
|
||||||
elif f.transfer_method == FileTransferMethod.REMOTE_URL:
|
elif f.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||||
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
|
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
@ -134,9 +134,9 @@ def _get_encoded_string(f: File, /):
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.content
|
data = response.content
|
||||||
case FileTransferMethod.LOCAL_FILE:
|
case FileTransferMethod.LOCAL_FILE:
|
||||||
data = _download_file_content(f._storage_key)
|
data = _download_file_content(f.storage_key)
|
||||||
case FileTransferMethod.TOOL_FILE:
|
case FileTransferMethod.TOOL_FILE:
|
||||||
data = _download_file_content(f._storage_key)
|
data = _download_file_content(f.storage_key)
|
||||||
|
|
||||||
encoded_string = base64.b64encode(data).decode("utf-8")
|
encoded_string = base64.b64encode(data).decode("utf-8")
|
||||||
return encoded_string
|
return encoded_string
|
||||||
|
|
|
||||||
|
|
@ -146,3 +146,11 @@ class File(BaseModel):
|
||||||
if not self.related_id:
|
if not self.related_id:
|
||||||
raise ValueError("Missing file related_id")
|
raise ValueError("Missing file related_id")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@property
|
||||||
|
def storage_key(self) -> str:
|
||||||
|
return self._storage_key
|
||||||
|
|
||||||
|
@storage_key.setter
|
||||||
|
def storage_key(self, value: str):
|
||||||
|
self._storage_key = value
|
||||||
|
|
|
||||||
|
|
@ -13,18 +13,18 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
|
SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
|
||||||
|
|
||||||
HTTP_REQUEST_NODE_SSL_VERIFY = True # Default value for HTTP_REQUEST_NODE_SSL_VERIFY is True
|
http_request_node_ssl_verify = True # Default value for http_request_node_ssl_verify is True
|
||||||
try:
|
try:
|
||||||
HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
|
config_value = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
|
||||||
http_request_node_ssl_verify_lower = str(HTTP_REQUEST_NODE_SSL_VERIFY).lower()
|
http_request_node_ssl_verify_lower = str(config_value).lower()
|
||||||
if http_request_node_ssl_verify_lower == "true":
|
if http_request_node_ssl_verify_lower == "true":
|
||||||
HTTP_REQUEST_NODE_SSL_VERIFY = True
|
http_request_node_ssl_verify = True
|
||||||
elif http_request_node_ssl_verify_lower == "false":
|
elif http_request_node_ssl_verify_lower == "false":
|
||||||
HTTP_REQUEST_NODE_SSL_VERIFY = False
|
http_request_node_ssl_verify = False
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid value. HTTP_REQUEST_NODE_SSL_VERIFY should be 'True' or 'False'")
|
raise ValueError("Invalid value. HTTP_REQUEST_NODE_SSL_VERIFY should be 'True' or 'False'")
|
||||||
except NameError:
|
except NameError:
|
||||||
HTTP_REQUEST_NODE_SSL_VERIFY = True
|
http_request_node_ssl_verify = True
|
||||||
|
|
||||||
BACKOFF_FACTOR = 0.5
|
BACKOFF_FACTOR = 0.5
|
||||||
STATUS_FORCELIST = [429, 500, 502, 503, 504]
|
STATUS_FORCELIST = [429, 500, 502, 503, 504]
|
||||||
|
|
@ -51,7 +51,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||||
)
|
)
|
||||||
|
|
||||||
if "ssl_verify" not in kwargs:
|
if "ssl_verify" not in kwargs:
|
||||||
kwargs["ssl_verify"] = HTTP_REQUEST_NODE_SSL_VERIFY
|
kwargs["ssl_verify"] = http_request_node_ssl_verify
|
||||||
|
|
||||||
ssl_verify = kwargs.pop("ssl_verify")
|
ssl_verify = kwargs.pop("ssl_verify")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -529,6 +529,7 @@ class IndexingRunner:
|
||||||
# chunk nodes by chunk size
|
# chunk nodes by chunk size
|
||||||
indexing_start_at = time.perf_counter()
|
indexing_start_at = time.perf_counter()
|
||||||
tokens = 0
|
tokens = 0
|
||||||
|
create_keyword_thread = None
|
||||||
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
|
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
|
||||||
# create keyword index
|
# create keyword index
|
||||||
create_keyword_thread = threading.Thread(
|
create_keyword_thread = threading.Thread(
|
||||||
|
|
@ -567,7 +568,11 @@ class IndexingRunner:
|
||||||
|
|
||||||
for future in futures:
|
for future in futures:
|
||||||
tokens += future.result()
|
tokens += future.result()
|
||||||
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
|
if (
|
||||||
|
dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX
|
||||||
|
and dataset.indexing_technique == "economy"
|
||||||
|
and create_keyword_thread is not None
|
||||||
|
):
|
||||||
create_keyword_thread.join()
|
create_keyword_thread.join()
|
||||||
indexing_end_at = time.perf_counter()
|
indexing_end_at = time.perf_counter()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@ from core.llm_generator.prompts import (
|
||||||
)
|
)
|
||||||
from core.model_manager import ModelManager
|
from core.model_manager import ModelManager
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult
|
from core.model_runtime.entities.llm_entities import LLMResult
|
||||||
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
|
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||||
from core.ops.entities.trace_entity import TraceTaskName
|
from core.ops.entities.trace_entity import TraceTaskName
|
||||||
|
|
@ -313,14 +313,20 @@ class LLMGenerator:
|
||||||
model_type=ModelType.LLM,
|
model_type=ModelType.LLM,
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)]
|
prompt_messages: list[PromptMessage] = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)]
|
||||||
|
|
||||||
response: LLMResult = model_instance.invoke_llm(
|
# Explicitly use the non-streaming overload
|
||||||
|
result = model_instance.invoke_llm(
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
model_parameters={"temperature": 0.01, "max_tokens": 2000},
|
model_parameters={"temperature": 0.01, "max_tokens": 2000},
|
||||||
stream=False,
|
stream=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Runtime type check since pyright has issues with the overload
|
||||||
|
if not isinstance(result, LLMResult):
|
||||||
|
raise TypeError("Expected LLMResult when stream=False")
|
||||||
|
response = result
|
||||||
|
|
||||||
answer = cast(str, response.message.content)
|
answer = cast(str, response.message.content)
|
||||||
return answer.strip()
|
return answer.strip()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -45,6 +45,7 @@ class SpecialModelType(StrEnum):
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def invoke_llm_with_structured_output(
|
def invoke_llm_with_structured_output(
|
||||||
|
*,
|
||||||
provider: str,
|
provider: str,
|
||||||
model_schema: AIModelEntity,
|
model_schema: AIModelEntity,
|
||||||
model_instance: ModelInstance,
|
model_instance: ModelInstance,
|
||||||
|
|
@ -53,14 +54,13 @@ def invoke_llm_with_structured_output(
|
||||||
model_parameters: Optional[Mapping] = None,
|
model_parameters: Optional[Mapping] = None,
|
||||||
tools: Sequence[PromptMessageTool] | None = None,
|
tools: Sequence[PromptMessageTool] | None = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[list[str]] = None,
|
||||||
stream: Literal[True] = True,
|
stream: Literal[True],
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
callbacks: Optional[list[Callback]] = None,
|
callbacks: Optional[list[Callback]] = None,
|
||||||
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def invoke_llm_with_structured_output(
|
def invoke_llm_with_structured_output(
|
||||||
|
*,
|
||||||
provider: str,
|
provider: str,
|
||||||
model_schema: AIModelEntity,
|
model_schema: AIModelEntity,
|
||||||
model_instance: ModelInstance,
|
model_instance: ModelInstance,
|
||||||
|
|
@ -69,14 +69,13 @@ def invoke_llm_with_structured_output(
|
||||||
model_parameters: Optional[Mapping] = None,
|
model_parameters: Optional[Mapping] = None,
|
||||||
tools: Sequence[PromptMessageTool] | None = None,
|
tools: Sequence[PromptMessageTool] | None = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[list[str]] = None,
|
||||||
stream: Literal[False] = False,
|
stream: Literal[False],
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
callbacks: Optional[list[Callback]] = None,
|
callbacks: Optional[list[Callback]] = None,
|
||||||
) -> LLMResultWithStructuredOutput: ...
|
) -> LLMResultWithStructuredOutput: ...
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def invoke_llm_with_structured_output(
|
def invoke_llm_with_structured_output(
|
||||||
|
*,
|
||||||
provider: str,
|
provider: str,
|
||||||
model_schema: AIModelEntity,
|
model_schema: AIModelEntity,
|
||||||
model_instance: ModelInstance,
|
model_instance: ModelInstance,
|
||||||
|
|
@ -89,9 +88,8 @@ def invoke_llm_with_structured_output(
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
callbacks: Optional[list[Callback]] = None,
|
callbacks: Optional[list[Callback]] = None,
|
||||||
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||||
|
|
||||||
|
|
||||||
def invoke_llm_with_structured_output(
|
def invoke_llm_with_structured_output(
|
||||||
|
*,
|
||||||
provider: str,
|
provider: str,
|
||||||
model_schema: AIModelEntity,
|
model_schema: AIModelEntity,
|
||||||
model_instance: ModelInstance,
|
model_instance: ModelInstance,
|
||||||
|
|
|
||||||
|
|
@ -23,13 +23,13 @@ DEFAULT_QUEUE_READ_TIMEOUT = 3
|
||||||
@final
|
@final
|
||||||
class _StatusReady:
|
class _StatusReady:
|
||||||
def __init__(self, endpoint_url: str):
|
def __init__(self, endpoint_url: str):
|
||||||
self._endpoint_url = endpoint_url
|
self.endpoint_url = endpoint_url
|
||||||
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
class _StatusError:
|
class _StatusError:
|
||||||
def __init__(self, exc: Exception):
|
def __init__(self, exc: Exception):
|
||||||
self._exc = exc
|
self.exc = exc
|
||||||
|
|
||||||
|
|
||||||
# Type aliases for better readability
|
# Type aliases for better readability
|
||||||
|
|
@ -211,9 +211,9 @@ class SSETransport:
|
||||||
raise ValueError("failed to get endpoint URL")
|
raise ValueError("failed to get endpoint URL")
|
||||||
|
|
||||||
if isinstance(status, _StatusReady):
|
if isinstance(status, _StatusReady):
|
||||||
return status._endpoint_url
|
return status.endpoint_url
|
||||||
elif isinstance(status, _StatusError):
|
elif isinstance(status, _StatusError):
|
||||||
raise status._exc
|
raise status.exc
|
||||||
else:
|
else:
|
||||||
raise ValueError("failed to get endpoint URL")
|
raise ValueError("failed to get endpoint URL")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,7 @@ def handle_mcp_request(
|
||||||
"""
|
"""
|
||||||
|
|
||||||
request_type = type(request.root)
|
request_type = type(request.root)
|
||||||
|
request_root = request.root
|
||||||
|
|
||||||
def create_success_response(result_data: mcp_types.Result) -> mcp_types.JSONRPCResponse:
|
def create_success_response(result_data: mcp_types.Result) -> mcp_types.JSONRPCResponse:
|
||||||
"""Create success response with business result data"""
|
"""Create success response with business result data"""
|
||||||
|
|
@ -58,21 +59,20 @@ def handle_mcp_request(
|
||||||
error=error_data,
|
error=error_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Request handler mapping using functional approach
|
|
||||||
request_handlers = {
|
|
||||||
mcp_types.InitializeRequest: lambda: handle_initialize(mcp_server.description),
|
|
||||||
mcp_types.ListToolsRequest: lambda: handle_list_tools(
|
|
||||||
app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict
|
|
||||||
),
|
|
||||||
mcp_types.CallToolRequest: lambda: handle_call_tool(app, request, user_input_form, end_user),
|
|
||||||
mcp_types.PingRequest: lambda: handle_ping(),
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Dispatch request to appropriate handler
|
# Dispatch request to appropriate handler based on instance type
|
||||||
handler = request_handlers.get(request_type)
|
if isinstance(request_root, mcp_types.InitializeRequest):
|
||||||
if handler:
|
return create_success_response(handle_initialize(mcp_server.description))
|
||||||
return create_success_response(handler())
|
elif isinstance(request_root, mcp_types.ListToolsRequest):
|
||||||
|
return create_success_response(
|
||||||
|
handle_list_tools(
|
||||||
|
app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif isinstance(request_root, mcp_types.CallToolRequest):
|
||||||
|
return create_success_response(handle_call_tool(app, request, user_input_form, end_user))
|
||||||
|
elif isinstance(request_root, mcp_types.PingRequest):
|
||||||
|
return create_success_response(handle_ping())
|
||||||
else:
|
else:
|
||||||
return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}")
|
return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -81,7 +81,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||||
self.request_meta = request_meta
|
self.request_meta = request_meta
|
||||||
self.request = request
|
self.request = request
|
||||||
self._session = session
|
self._session = session
|
||||||
self._completed = False
|
self.completed = False
|
||||||
self._on_complete = on_complete
|
self._on_complete = on_complete
|
||||||
self._entered = False # Track if we're in a context manager
|
self._entered = False # Track if we're in a context manager
|
||||||
|
|
||||||
|
|
@ -98,7 +98,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||||
):
|
):
|
||||||
"""Exit the context manager, performing cleanup and notifying completion."""
|
"""Exit the context manager, performing cleanup and notifying completion."""
|
||||||
try:
|
try:
|
||||||
if self._completed:
|
if self.completed:
|
||||||
self._on_complete(self)
|
self._on_complete(self)
|
||||||
finally:
|
finally:
|
||||||
self._entered = False
|
self._entered = False
|
||||||
|
|
@ -113,9 +113,9 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||||
"""
|
"""
|
||||||
if not self._entered:
|
if not self._entered:
|
||||||
raise RuntimeError("RequestResponder must be used as a context manager")
|
raise RuntimeError("RequestResponder must be used as a context manager")
|
||||||
assert not self._completed, "Request already responded to"
|
assert not self.completed, "Request already responded to"
|
||||||
|
|
||||||
self._completed = True
|
self.completed = True
|
||||||
|
|
||||||
self._session._send_response(request_id=self.request_id, response=response)
|
self._session._send_response(request_id=self.request_id, response=response)
|
||||||
|
|
||||||
|
|
@ -124,7 +124,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||||
if not self._entered:
|
if not self._entered:
|
||||||
raise RuntimeError("RequestResponder must be used as a context manager")
|
raise RuntimeError("RequestResponder must be used as a context manager")
|
||||||
|
|
||||||
self._completed = True # Mark as completed so it's removed from in_flight
|
self.completed = True # Mark as completed so it's removed from in_flight
|
||||||
# Send an error response to indicate cancellation
|
# Send an error response to indicate cancellation
|
||||||
self._session._send_response(
|
self._session._send_response(
|
||||||
request_id=self.request_id,
|
request_id=self.request_id,
|
||||||
|
|
@ -351,7 +351,7 @@ class BaseSession(
|
||||||
self._in_flight[responder.request_id] = responder
|
self._in_flight[responder.request_id] = responder
|
||||||
self._received_request(responder)
|
self._received_request(responder)
|
||||||
|
|
||||||
if not responder._completed:
|
if not responder.completed:
|
||||||
self._handle_incoming(responder)
|
self._handle_incoming(responder)
|
||||||
|
|
||||||
elif isinstance(message.message.root, JSONRPCNotification):
|
elif isinstance(message.message.root, JSONRPCNotification):
|
||||||
|
|
|
||||||
|
|
@ -354,7 +354,7 @@ class LargeLanguageModel(AIModel):
|
||||||
)
|
)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def _calc_response_usage(
|
def calc_response_usage(
|
||||||
self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int
|
self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int
|
||||||
) -> LLMUsage:
|
) -> LLMUsage:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import enum
|
import enum
|
||||||
|
import json
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
@ -162,8 +163,6 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
|
||||||
# Try to parse JSON string for arrays
|
# Try to parse JSON string for arrays
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
try:
|
try:
|
||||||
import json
|
|
||||||
|
|
||||||
parsed_value = json.loads(value)
|
parsed_value = json.loads(value)
|
||||||
if isinstance(parsed_value, list):
|
if isinstance(parsed_value, list):
|
||||||
return parsed_value
|
return parsed_value
|
||||||
|
|
@ -176,8 +175,6 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
|
||||||
# Try to parse JSON string for objects
|
# Try to parse JSON string for objects
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
try:
|
try:
|
||||||
import json
|
|
||||||
|
|
||||||
parsed_value = json.loads(value)
|
parsed_value = json.loads(value)
|
||||||
if isinstance(parsed_value, dict):
|
if isinstance(parsed_value, dict):
|
||||||
return parsed_value
|
return parsed_value
|
||||||
|
|
|
||||||
|
|
@ -82,7 +82,9 @@ def merge_blob_chunks(
|
||||||
message_class = type(resp)
|
message_class = type(resp)
|
||||||
merged_message = message_class(
|
merged_message = message_class(
|
||||||
type=ToolInvokeMessage.MessageType.BLOB,
|
type=ToolInvokeMessage.MessageType.BLOB,
|
||||||
message=ToolInvokeMessage.BlobMessage(blob=files[chunk_id].data[: files[chunk_id].bytes_written]),
|
message=ToolInvokeMessage.BlobMessage(
|
||||||
|
blob=bytes(files[chunk_id].data[: files[chunk_id].bytes_written])
|
||||||
|
),
|
||||||
meta=resp.meta,
|
meta=resp.meta,
|
||||||
)
|
)
|
||||||
yield cast(MessageType, merged_message)
|
yield cast(MessageType, merged_message)
|
||||||
|
|
|
||||||
|
|
@ -101,9 +101,22 @@ class SimplePromptTransform(PromptTransform):
|
||||||
with_memory_prompt=histories is not None,
|
with_memory_prompt=histories is not None,
|
||||||
)
|
)
|
||||||
|
|
||||||
variables = {k: inputs[k] for k in prompt_template_config["custom_variable_keys"] if k in inputs}
|
custom_variable_keys_obj = prompt_template_config["custom_variable_keys"]
|
||||||
|
special_variable_keys_obj = prompt_template_config["special_variable_keys"]
|
||||||
|
|
||||||
for v in prompt_template_config["special_variable_keys"]:
|
# Type check for custom_variable_keys
|
||||||
|
if not isinstance(custom_variable_keys_obj, list):
|
||||||
|
raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys_obj)}")
|
||||||
|
custom_variable_keys = cast(list[str], custom_variable_keys_obj)
|
||||||
|
|
||||||
|
# Type check for special_variable_keys
|
||||||
|
if not isinstance(special_variable_keys_obj, list):
|
||||||
|
raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys_obj)}")
|
||||||
|
special_variable_keys = cast(list[str], special_variable_keys_obj)
|
||||||
|
|
||||||
|
variables = {k: inputs[k] for k in custom_variable_keys if k in inputs}
|
||||||
|
|
||||||
|
for v in special_variable_keys:
|
||||||
# support #context#, #query# and #histories#
|
# support #context#, #query# and #histories#
|
||||||
if v == "#context#":
|
if v == "#context#":
|
||||||
variables["#context#"] = context or ""
|
variables["#context#"] = context or ""
|
||||||
|
|
@ -113,9 +126,16 @@ class SimplePromptTransform(PromptTransform):
|
||||||
variables["#histories#"] = histories or ""
|
variables["#histories#"] = histories or ""
|
||||||
|
|
||||||
prompt_template = prompt_template_config["prompt_template"]
|
prompt_template = prompt_template_config["prompt_template"]
|
||||||
|
if not isinstance(prompt_template, PromptTemplateParser):
|
||||||
|
raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template)}")
|
||||||
|
|
||||||
prompt = prompt_template.format(variables)
|
prompt = prompt_template.format(variables)
|
||||||
|
|
||||||
return prompt, prompt_template_config["prompt_rules"]
|
prompt_rules = prompt_template_config["prompt_rules"]
|
||||||
|
if not isinstance(prompt_rules, dict):
|
||||||
|
raise TypeError(f"Expected dict for prompt_rules, got {type(prompt_rules)}")
|
||||||
|
|
||||||
|
return prompt, prompt_rules
|
||||||
|
|
||||||
def get_prompt_template(
|
def get_prompt_template(
|
||||||
self,
|
self,
|
||||||
|
|
@ -126,11 +146,11 @@ class SimplePromptTransform(PromptTransform):
|
||||||
has_context: bool,
|
has_context: bool,
|
||||||
query_in_prompt: bool,
|
query_in_prompt: bool,
|
||||||
with_memory_prompt: bool = False,
|
with_memory_prompt: bool = False,
|
||||||
):
|
) -> dict[str, object]:
|
||||||
prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model)
|
prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model)
|
||||||
|
|
||||||
custom_variable_keys = []
|
custom_variable_keys: list[str] = []
|
||||||
special_variable_keys = []
|
special_variable_keys: list[str] = []
|
||||||
|
|
||||||
prompt = ""
|
prompt = ""
|
||||||
for order in prompt_rules["system_prompt_orders"]:
|
for order in prompt_rules["system_prompt_orders"]:
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,19 @@ if TYPE_CHECKING:
|
||||||
MetadataFilter = Union[DictFilter, common_types.Filter]
|
MetadataFilter = Union[DictFilter, common_types.Filter]
|
||||||
|
|
||||||
|
|
||||||
|
class PathQdrantParams(BaseModel):
|
||||||
|
path: str
|
||||||
|
|
||||||
|
|
||||||
|
class UrlQdrantParams(BaseModel):
|
||||||
|
url: str
|
||||||
|
api_key: Optional[str]
|
||||||
|
timeout: float
|
||||||
|
verify: bool
|
||||||
|
grpc_port: int
|
||||||
|
prefer_grpc: bool
|
||||||
|
|
||||||
|
|
||||||
class QdrantConfig(BaseModel):
|
class QdrantConfig(BaseModel):
|
||||||
endpoint: str
|
endpoint: str
|
||||||
api_key: Optional[str] = None
|
api_key: Optional[str] = None
|
||||||
|
|
@ -50,7 +63,7 @@ class QdrantConfig(BaseModel):
|
||||||
replication_factor: int = 1
|
replication_factor: int = 1
|
||||||
write_consistency_factor: int = 1
|
write_consistency_factor: int = 1
|
||||||
|
|
||||||
def to_qdrant_params(self):
|
def to_qdrant_params(self) -> PathQdrantParams | UrlQdrantParams:
|
||||||
if self.endpoint and self.endpoint.startswith("path:"):
|
if self.endpoint and self.endpoint.startswith("path:"):
|
||||||
path = self.endpoint.replace("path:", "")
|
path = self.endpoint.replace("path:", "")
|
||||||
if not os.path.isabs(path):
|
if not os.path.isabs(path):
|
||||||
|
|
@ -58,23 +71,23 @@ class QdrantConfig(BaseModel):
|
||||||
raise ValueError("Root path is not set")
|
raise ValueError("Root path is not set")
|
||||||
path = os.path.join(self.root_path, path)
|
path = os.path.join(self.root_path, path)
|
||||||
|
|
||||||
return {"path": path}
|
return PathQdrantParams(path=path)
|
||||||
else:
|
else:
|
||||||
return {
|
return UrlQdrantParams(
|
||||||
"url": self.endpoint,
|
url=self.endpoint,
|
||||||
"api_key": self.api_key,
|
api_key=self.api_key,
|
||||||
"timeout": self.timeout,
|
timeout=self.timeout,
|
||||||
"verify": self.endpoint.startswith("https"),
|
verify=self.endpoint.startswith("https"),
|
||||||
"grpc_port": self.grpc_port,
|
grpc_port=self.grpc_port,
|
||||||
"prefer_grpc": self.prefer_grpc,
|
prefer_grpc=self.prefer_grpc,
|
||||||
}
|
)
|
||||||
|
|
||||||
|
|
||||||
class QdrantVector(BaseVector):
|
class QdrantVector(BaseVector):
|
||||||
def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = "Cosine"):
|
def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = "Cosine"):
|
||||||
super().__init__(collection_name)
|
super().__init__(collection_name)
|
||||||
self._client_config = config
|
self._client_config = config
|
||||||
self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params())
|
self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params().model_dump())
|
||||||
self._distance_func = distance_func.upper()
|
self._distance_func = distance_func.upper()
|
||||||
self._group_id = group_id
|
self._group_id = group_id
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -94,10 +94,10 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
|
||||||
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
|
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
|
||||||
|
|
||||||
# In-memory cache for workflow node executions
|
# In-memory cache for workflow node executions
|
||||||
self._execution_cache: dict[str, WorkflowNodeExecution] = {}
|
self._execution_cache = {}
|
||||||
|
|
||||||
# Cache for mapping workflow_execution_ids to execution IDs for efficient retrieval
|
# Cache for mapping workflow_execution_ids to execution IDs for efficient retrieval
|
||||||
self._workflow_execution_mapping: dict[str, list[str]] = {}
|
self._workflow_execution_mapping = {}
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Initialized CeleryWorkflowNodeExecutionRepository for tenant %s, app %s, triggered_from %s",
|
"Initialized CeleryWorkflowNodeExecutionRepository for tenant %s, app %s, triggered_from %s",
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from .types import SegmentType
|
||||||
|
|
||||||
class SegmentGroup(Segment):
|
class SegmentGroup(Segment):
|
||||||
value_type: SegmentType = SegmentType.GROUP
|
value_type: SegmentType = SegmentType.GROUP
|
||||||
value: list[Segment]
|
value: list[Segment] = None # type: ignore
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def text(self):
|
def text(self):
|
||||||
|
|
|
||||||
|
|
@ -74,12 +74,12 @@ class NoneSegment(Segment):
|
||||||
|
|
||||||
class StringSegment(Segment):
|
class StringSegment(Segment):
|
||||||
value_type: SegmentType = SegmentType.STRING
|
value_type: SegmentType = SegmentType.STRING
|
||||||
value: str
|
value: str = None # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class FloatSegment(Segment):
|
class FloatSegment(Segment):
|
||||||
value_type: SegmentType = SegmentType.FLOAT
|
value_type: SegmentType = SegmentType.FLOAT
|
||||||
value: float
|
value: float = None # type: ignore
|
||||||
# NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
|
# NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
|
||||||
# The following tests cannot pass.
|
# The following tests cannot pass.
|
||||||
#
|
#
|
||||||
|
|
@ -98,12 +98,12 @@ class FloatSegment(Segment):
|
||||||
|
|
||||||
class IntegerSegment(Segment):
|
class IntegerSegment(Segment):
|
||||||
value_type: SegmentType = SegmentType.INTEGER
|
value_type: SegmentType = SegmentType.INTEGER
|
||||||
value: int
|
value: int = None # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class ObjectSegment(Segment):
|
class ObjectSegment(Segment):
|
||||||
value_type: SegmentType = SegmentType.OBJECT
|
value_type: SegmentType = SegmentType.OBJECT
|
||||||
value: Mapping[str, Any]
|
value: Mapping[str, Any] = None # type: ignore
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def text(self) -> str:
|
def text(self) -> str:
|
||||||
|
|
@ -136,7 +136,7 @@ class ArraySegment(Segment):
|
||||||
|
|
||||||
class FileSegment(Segment):
|
class FileSegment(Segment):
|
||||||
value_type: SegmentType = SegmentType.FILE
|
value_type: SegmentType = SegmentType.FILE
|
||||||
value: File
|
value: File = None # type: ignore
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def markdown(self) -> str:
|
def markdown(self) -> str:
|
||||||
|
|
@ -153,17 +153,17 @@ class FileSegment(Segment):
|
||||||
|
|
||||||
class BooleanSegment(Segment):
|
class BooleanSegment(Segment):
|
||||||
value_type: SegmentType = SegmentType.BOOLEAN
|
value_type: SegmentType = SegmentType.BOOLEAN
|
||||||
value: bool
|
value: bool = None # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class ArrayAnySegment(ArraySegment):
|
class ArrayAnySegment(ArraySegment):
|
||||||
value_type: SegmentType = SegmentType.ARRAY_ANY
|
value_type: SegmentType = SegmentType.ARRAY_ANY
|
||||||
value: Sequence[Any]
|
value: Sequence[Any] = None # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class ArrayStringSegment(ArraySegment):
|
class ArrayStringSegment(ArraySegment):
|
||||||
value_type: SegmentType = SegmentType.ARRAY_STRING
|
value_type: SegmentType = SegmentType.ARRAY_STRING
|
||||||
value: Sequence[str]
|
value: Sequence[str] = None # type: ignore
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def text(self) -> str:
|
def text(self) -> str:
|
||||||
|
|
@ -175,17 +175,17 @@ class ArrayStringSegment(ArraySegment):
|
||||||
|
|
||||||
class ArrayNumberSegment(ArraySegment):
|
class ArrayNumberSegment(ArraySegment):
|
||||||
value_type: SegmentType = SegmentType.ARRAY_NUMBER
|
value_type: SegmentType = SegmentType.ARRAY_NUMBER
|
||||||
value: Sequence[float | int]
|
value: Sequence[float | int] = None # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class ArrayObjectSegment(ArraySegment):
|
class ArrayObjectSegment(ArraySegment):
|
||||||
value_type: SegmentType = SegmentType.ARRAY_OBJECT
|
value_type: SegmentType = SegmentType.ARRAY_OBJECT
|
||||||
value: Sequence[Mapping[str, Any]]
|
value: Sequence[Mapping[str, Any]] = None # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class ArrayFileSegment(ArraySegment):
|
class ArrayFileSegment(ArraySegment):
|
||||||
value_type: SegmentType = SegmentType.ARRAY_FILE
|
value_type: SegmentType = SegmentType.ARRAY_FILE
|
||||||
value: Sequence[File]
|
value: Sequence[File] = None # type: ignore
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def markdown(self) -> str:
|
def markdown(self) -> str:
|
||||||
|
|
@ -205,7 +205,7 @@ class ArrayFileSegment(ArraySegment):
|
||||||
|
|
||||||
class ArrayBooleanSegment(ArraySegment):
|
class ArrayBooleanSegment(ArraySegment):
|
||||||
value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
|
value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
|
||||||
value: Sequence[bool]
|
value: Sequence[bool] = None # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def get_segment_discriminator(v: Any) -> SegmentType | None:
|
def get_segment_discriminator(v: Any) -> SegmentType | None:
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,6 @@ from core.workflow.nodes.base import BaseNode
|
||||||
|
|
||||||
class WorkflowNodeRunFailedError(Exception):
|
class WorkflowNodeRunFailedError(Exception):
|
||||||
def __init__(self, node: BaseNode, err_msg: str):
|
def __init__(self, node: BaseNode, err_msg: str):
|
||||||
self._node = node
|
self.node = node
|
||||||
self._error = err_msg
|
self.error = err_msg
|
||||||
super().__init__(f"Node {node.title} run failed: {err_msg}")
|
super().__init__(f"Node {node.title} run failed: {err_msg}")
|
||||||
|
|
|
||||||
|
|
@ -67,8 +67,8 @@ class ListOperatorNode(BaseNode):
|
||||||
return "1"
|
return "1"
|
||||||
|
|
||||||
def _run(self):
|
def _run(self):
|
||||||
inputs: dict[str, list] = {}
|
inputs: dict[str, Sequence[object]] = {}
|
||||||
process_data: dict[str, list] = {}
|
process_data: dict[str, Sequence[object]] = {}
|
||||||
outputs: dict[str, Any] = {}
|
outputs: dict[str, Any] = {}
|
||||||
|
|
||||||
variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable)
|
variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable)
|
||||||
|
|
|
||||||
|
|
@ -1183,7 +1183,8 @@ def _combine_message_content_with_role(
|
||||||
return AssistantPromptMessage(content=contents)
|
return AssistantPromptMessage(content=contents)
|
||||||
case PromptMessageRole.SYSTEM:
|
case PromptMessageRole.SYSTEM:
|
||||||
return SystemPromptMessage(content=contents)
|
return SystemPromptMessage(content=contents)
|
||||||
raise NotImplementedError(f"Role {role} is not supported")
|
case _:
|
||||||
|
raise NotImplementedError(f"Role {role} is not supported")
|
||||||
|
|
||||||
|
|
||||||
def _render_jinja2_message(
|
def _render_jinja2_message(
|
||||||
|
|
|
||||||
|
|
@ -462,9 +462,9 @@ class StorageKeyLoader:
|
||||||
upload_file_row = upload_files.get(model_id)
|
upload_file_row = upload_files.get(model_id)
|
||||||
if upload_file_row is None:
|
if upload_file_row is None:
|
||||||
raise ValueError(f"Upload file not found for id: {model_id}")
|
raise ValueError(f"Upload file not found for id: {model_id}")
|
||||||
file._storage_key = upload_file_row.key
|
file.storage_key = upload_file_row.key
|
||||||
elif file.transfer_method == FileTransferMethod.TOOL_FILE:
|
elif file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||||
tool_file_row = tool_files.get(model_id)
|
tool_file_row = tool_files.get(model_id)
|
||||||
if tool_file_row is None:
|
if tool_file_row is None:
|
||||||
raise ValueError(f"Tool file not found for id: {model_id}")
|
raise ValueError(f"Tool file not found for id: {model_id}")
|
||||||
file._storage_key = tool_file_row.file_key
|
file.storage_key = tool_file_row.file_key
|
||||||
|
|
|
||||||
|
|
@ -12,4 +12,7 @@ def serialize_value_type(v: _VarTypedDict | Segment) -> str:
|
||||||
if isinstance(v, Segment):
|
if isinstance(v, Segment):
|
||||||
return v.value_type.exposed_type().value
|
return v.value_type.exposed_type().value
|
||||||
else:
|
else:
|
||||||
return v["value_type"].exposed_type().value
|
value_type = v.get("value_type")
|
||||||
|
if value_type is None:
|
||||||
|
raise ValueError("value_type is required but not provided")
|
||||||
|
return value_type.exposed_type().value
|
||||||
|
|
|
||||||
|
|
@ -69,6 +69,8 @@ def register_external_error_handlers(api: Api):
|
||||||
headers["WWW-Authenticate"] = 'Bearer realm="api"'
|
headers["WWW-Authenticate"] = 'Bearer realm="api"'
|
||||||
return data, status_code, headers
|
return data, status_code, headers
|
||||||
|
|
||||||
|
_ = handle_http_exception
|
||||||
|
|
||||||
@api.errorhandler(ValueError)
|
@api.errorhandler(ValueError)
|
||||||
def handle_value_error(e: ValueError):
|
def handle_value_error(e: ValueError):
|
||||||
got_request_exception.send(current_app, exception=e)
|
got_request_exception.send(current_app, exception=e)
|
||||||
|
|
@ -76,6 +78,8 @@ def register_external_error_handlers(api: Api):
|
||||||
data = {"code": "invalid_param", "message": str(e), "status": status_code}
|
data = {"code": "invalid_param", "message": str(e), "status": status_code}
|
||||||
return data, status_code
|
return data, status_code
|
||||||
|
|
||||||
|
_ = handle_value_error
|
||||||
|
|
||||||
@api.errorhandler(AppInvokeQuotaExceededError)
|
@api.errorhandler(AppInvokeQuotaExceededError)
|
||||||
def handle_quota_exceeded(e: AppInvokeQuotaExceededError):
|
def handle_quota_exceeded(e: AppInvokeQuotaExceededError):
|
||||||
got_request_exception.send(current_app, exception=e)
|
got_request_exception.send(current_app, exception=e)
|
||||||
|
|
@ -83,15 +87,17 @@ def register_external_error_handlers(api: Api):
|
||||||
data = {"code": "too_many_requests", "message": str(e), "status": status_code}
|
data = {"code": "too_many_requests", "message": str(e), "status": status_code}
|
||||||
return data, status_code
|
return data, status_code
|
||||||
|
|
||||||
|
_ = handle_quota_exceeded
|
||||||
|
|
||||||
@api.errorhandler(Exception)
|
@api.errorhandler(Exception)
|
||||||
def handle_general_exception(e: Exception):
|
def handle_general_exception(e: Exception):
|
||||||
got_request_exception.send(current_app, exception=e)
|
got_request_exception.send(current_app, exception=e)
|
||||||
|
|
||||||
status_code = 500
|
status_code = 500
|
||||||
data: dict[str, Any] = getattr(e, "data", {"message": http_status_message(status_code)})
|
data = getattr(e, "data", {"message": http_status_message(status_code)})
|
||||||
|
|
||||||
# 🔒 Normalize non-mapping data (e.g., if someone set e.data = Response)
|
# 🔒 Normalize non-mapping data (e.g., if someone set e.data = Response)
|
||||||
if not isinstance(data, Mapping):
|
if not isinstance(data, dict):
|
||||||
data = {"message": str(e)}
|
data = {"message": str(e)}
|
||||||
|
|
||||||
data.setdefault("code", "unknown")
|
data.setdefault("code", "unknown")
|
||||||
|
|
@ -101,10 +107,12 @@ def register_external_error_handlers(api: Api):
|
||||||
exc_info: Any = sys.exc_info()
|
exc_info: Any = sys.exc_info()
|
||||||
if exc_info[1] is None:
|
if exc_info[1] is None:
|
||||||
exc_info = None
|
exc_info = None
|
||||||
current_app.log_exception(exc_info) # ty: ignore [invalid-argument-type]
|
current_app.log_exception(exc_info)
|
||||||
|
|
||||||
return data, status_code
|
return data, status_code
|
||||||
|
|
||||||
|
_ = handle_general_exception
|
||||||
|
|
||||||
|
|
||||||
class ExternalApi(Api):
|
class ExternalApi(Api):
|
||||||
_authorizations = {
|
_authorizations = {
|
||||||
|
|
|
||||||
|
|
@ -167,13 +167,6 @@ class DatetimeString:
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
def _get_float(value):
|
|
||||||
try:
|
|
||||||
return float(value)
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
raise ValueError(f"{value} is not a valid float")
|
|
||||||
|
|
||||||
|
|
||||||
def timezone(timezone_string):
|
def timezone(timezone_string):
|
||||||
if timezone_string and timezone_string in available_timezones():
|
if timezone_string and timezone_string in available_timezones():
|
||||||
return timezone_string
|
return timezone_string
|
||||||
|
|
|
||||||
|
|
@ -1,24 +1,44 @@
|
||||||
{
|
{
|
||||||
"include": ["."],
|
"include": ["."],
|
||||||
"exclude": [".venv", "tests/", "migrations/"],
|
"exclude": [
|
||||||
"ignore": [
|
".venv",
|
||||||
"core/",
|
"tests/",
|
||||||
"controllers/",
|
"migrations/",
|
||||||
"tasks/",
|
"core/rag",
|
||||||
"services/",
|
"extensions",
|
||||||
"schedule/",
|
"libs",
|
||||||
"extensions/",
|
"controllers/console/datasets",
|
||||||
"utils/",
|
"controllers/service_api/dataset",
|
||||||
"repositories/",
|
"core/ops",
|
||||||
"libs/",
|
"core/tools",
|
||||||
"fields/",
|
"core/model_runtime",
|
||||||
"factories/",
|
"core/workflow",
|
||||||
"events/",
|
"core/app/app_config/easy_ui_based_app/dataset"
|
||||||
"contexts/",
|
|
||||||
"constants/",
|
|
||||||
"commands.py"
|
|
||||||
],
|
],
|
||||||
"typeCheckingMode": "strict",
|
"typeCheckingMode": "strict",
|
||||||
|
"allowedUntypedLibraries": [
|
||||||
|
"flask_restx",
|
||||||
|
"flask_login",
|
||||||
|
"opentelemetry.instrumentation.celery",
|
||||||
|
"opentelemetry.instrumentation.flask",
|
||||||
|
"opentelemetry.instrumentation.requests",
|
||||||
|
"opentelemetry.instrumentation.sqlalchemy",
|
||||||
|
"opentelemetry.instrumentation.redis"
|
||||||
|
],
|
||||||
|
"reportUnknownMemberType": "hint",
|
||||||
|
"reportUnknownParameterType": "hint",
|
||||||
|
"reportUnknownArgumentType": "hint",
|
||||||
|
"reportUnknownVariableType": "hint",
|
||||||
|
"reportUnknownLambdaType": "hint",
|
||||||
|
"reportMissingParameterType": "hint",
|
||||||
|
"reportMissingTypeArgument": "hint",
|
||||||
|
"reportUnnecessaryContains": "hint",
|
||||||
|
"reportUnnecessaryComparison": "hint",
|
||||||
|
"reportUnnecessaryCast": "hint",
|
||||||
|
"reportUnnecessaryIsInstance": "hint",
|
||||||
|
"reportUntypedFunctionDecorator": "hint",
|
||||||
|
|
||||||
|
"reportAttributeAccessIssue": "hint",
|
||||||
"pythonVersion": "3.11",
|
"pythonVersion": "3.11",
|
||||||
"pythonPlatform": "All"
|
"pythonPlatform": "All"
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1318,7 +1318,7 @@ class RegisterService:
|
||||||
def get_invitation_if_token_valid(
|
def get_invitation_if_token_valid(
|
||||||
cls, workspace_id: Optional[str], email: str, token: str
|
cls, workspace_id: Optional[str], email: str, token: str
|
||||||
) -> Optional[dict[str, Any]]:
|
) -> Optional[dict[str, Any]]:
|
||||||
invitation_data = cls._get_invitation_by_token(token, workspace_id, email)
|
invitation_data = cls.get_invitation_by_token(token, workspace_id, email)
|
||||||
if not invitation_data:
|
if not invitation_data:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -1355,7 +1355,7 @@ class RegisterService:
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_invitation_by_token(
|
def get_invitation_by_token(
|
||||||
cls, token: str, workspace_id: Optional[str] = None, email: Optional[str] = None
|
cls, token: str, workspace_id: Optional[str] = None, email: Optional[str] = None
|
||||||
) -> Optional[dict[str, str]]:
|
) -> Optional[dict[str, str]]:
|
||||||
if workspace_id is not None and email is not None:
|
if workspace_id is not None and email is not None:
|
||||||
|
|
|
||||||
|
|
@ -349,7 +349,7 @@ class AppAnnotationService:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Skip the first row
|
# Skip the first row
|
||||||
df = pd.read_csv(file, dtype=str)
|
df = pd.read_csv(file.stream, dtype=str)
|
||||||
result = []
|
result = []
|
||||||
for _, row in df.iterrows():
|
for _, row in df.iterrows():
|
||||||
content = {"question": row.iloc[0], "answer": row.iloc[1]}
|
content = {"question": row.iloc[0], "answer": row.iloc[1]}
|
||||||
|
|
@ -463,15 +463,23 @@ class AppAnnotationService:
|
||||||
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
|
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
|
||||||
if annotation_setting:
|
if annotation_setting:
|
||||||
collection_binding_detail = annotation_setting.collection_binding_detail
|
collection_binding_detail = annotation_setting.collection_binding_detail
|
||||||
return {
|
if collection_binding_detail:
|
||||||
"id": annotation_setting.id,
|
return {
|
||||||
"enabled": True,
|
"id": annotation_setting.id,
|
||||||
"score_threshold": annotation_setting.score_threshold,
|
"enabled": True,
|
||||||
"embedding_model": {
|
"score_threshold": annotation_setting.score_threshold,
|
||||||
"embedding_provider_name": collection_binding_detail.provider_name,
|
"embedding_model": {
|
||||||
"embedding_model_name": collection_binding_detail.model_name,
|
"embedding_provider_name": collection_binding_detail.provider_name,
|
||||||
},
|
"embedding_model_name": collection_binding_detail.model_name,
|
||||||
}
|
},
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"id": annotation_setting.id,
|
||||||
|
"enabled": True,
|
||||||
|
"score_threshold": annotation_setting.score_threshold,
|
||||||
|
"embedding_model": {},
|
||||||
|
}
|
||||||
return {"enabled": False}
|
return {"enabled": False}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -506,15 +514,23 @@ class AppAnnotationService:
|
||||||
|
|
||||||
collection_binding_detail = annotation_setting.collection_binding_detail
|
collection_binding_detail = annotation_setting.collection_binding_detail
|
||||||
|
|
||||||
return {
|
if collection_binding_detail:
|
||||||
"id": annotation_setting.id,
|
return {
|
||||||
"enabled": True,
|
"id": annotation_setting.id,
|
||||||
"score_threshold": annotation_setting.score_threshold,
|
"enabled": True,
|
||||||
"embedding_model": {
|
"score_threshold": annotation_setting.score_threshold,
|
||||||
"embedding_provider_name": collection_binding_detail.provider_name,
|
"embedding_model": {
|
||||||
"embedding_model_name": collection_binding_detail.model_name,
|
"embedding_provider_name": collection_binding_detail.provider_name,
|
||||||
},
|
"embedding_model_name": collection_binding_detail.model_name,
|
||||||
}
|
},
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"id": annotation_setting.id,
|
||||||
|
"enabled": True,
|
||||||
|
"score_threshold": annotation_setting.score_threshold,
|
||||||
|
"embedding_model": {},
|
||||||
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def clear_all_annotations(cls, app_id: str):
|
def clear_all_annotations(cls, app_id: str):
|
||||||
|
|
|
||||||
|
|
@ -407,6 +407,7 @@ class ClearFreePlanTenantExpiredLogs:
|
||||||
datetime.timedelta(hours=1),
|
datetime.timedelta(hours=1),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
tenant_count = 0
|
||||||
for test_interval in test_intervals:
|
for test_interval in test_intervals:
|
||||||
tenant_count = (
|
tenant_count = (
|
||||||
session.query(Tenant.id)
|
session.query(Tenant.id)
|
||||||
|
|
|
||||||
|
|
@ -134,11 +134,14 @@ class DatasetService:
|
||||||
|
|
||||||
# Check if tag_ids is not empty to avoid WHERE false condition
|
# Check if tag_ids is not empty to avoid WHERE false condition
|
||||||
if tag_ids and len(tag_ids) > 0:
|
if tag_ids and len(tag_ids) > 0:
|
||||||
target_ids = TagService.get_target_ids_by_tag_ids(
|
if tenant_id is not None:
|
||||||
"knowledge",
|
target_ids = TagService.get_target_ids_by_tag_ids(
|
||||||
tenant_id, # ty: ignore [invalid-argument-type]
|
"knowledge",
|
||||||
tag_ids,
|
tenant_id,
|
||||||
)
|
tag_ids,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
target_ids = []
|
||||||
if target_ids and len(target_ids) > 0:
|
if target_ids and len(target_ids) > 0:
|
||||||
query = query.where(Dataset.id.in_(target_ids))
|
query = query.where(Dataset.id.in_(target_ids))
|
||||||
else:
|
else:
|
||||||
|
|
@ -987,7 +990,8 @@ class DocumentService:
|
||||||
for document in documents
|
for document in documents
|
||||||
if document.data_source_type == "upload_file" and document.data_source_info_dict
|
if document.data_source_type == "upload_file" and document.data_source_info_dict
|
||||||
]
|
]
|
||||||
batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)
|
if dataset.doc_form is not None:
|
||||||
|
batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)
|
||||||
|
|
||||||
for document in documents:
|
for document in documents:
|
||||||
db.session.delete(document)
|
db.session.delete(document)
|
||||||
|
|
@ -2688,56 +2692,6 @@ class SegmentService:
|
||||||
|
|
||||||
return paginated_segments.items, paginated_segments.total
|
return paginated_segments.items, paginated_segments.total
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def update_segment_by_id(
|
|
||||||
cls, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, segment_data: dict, user_id: str
|
|
||||||
) -> tuple[DocumentSegment, Document]:
|
|
||||||
"""Update a segment by its ID with validation and checks."""
|
|
||||||
# check dataset
|
|
||||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
|
||||||
if not dataset:
|
|
||||||
raise NotFound("Dataset not found.")
|
|
||||||
|
|
||||||
# check user's model setting
|
|
||||||
DatasetService.check_dataset_model_setting(dataset)
|
|
||||||
|
|
||||||
# check document
|
|
||||||
document = DocumentService.get_document(dataset_id, document_id)
|
|
||||||
if not document:
|
|
||||||
raise NotFound("Document not found.")
|
|
||||||
|
|
||||||
# check embedding model setting if high quality
|
|
||||||
if dataset.indexing_technique == "high_quality":
|
|
||||||
try:
|
|
||||||
model_manager = ModelManager()
|
|
||||||
model_manager.get_model_instance(
|
|
||||||
tenant_id=user_id,
|
|
||||||
provider=dataset.embedding_model_provider,
|
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
|
||||||
model=dataset.embedding_model,
|
|
||||||
)
|
|
||||||
except LLMBadRequestError:
|
|
||||||
raise ValueError(
|
|
||||||
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
|
|
||||||
)
|
|
||||||
except ProviderTokenNotInitError as ex:
|
|
||||||
raise ValueError(ex.description)
|
|
||||||
|
|
||||||
# check segment
|
|
||||||
segment = (
|
|
||||||
db.session.query(DocumentSegment)
|
|
||||||
.where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if not segment:
|
|
||||||
raise NotFound("Segment not found.")
|
|
||||||
|
|
||||||
# validate and update segment
|
|
||||||
cls.segment_create_args_validate(segment_data, document)
|
|
||||||
updated_segment = cls.update_segment(SegmentUpdateArgs(**segment_data), segment, document, dataset)
|
|
||||||
|
|
||||||
return updated_segment, document
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]:
|
def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]:
|
||||||
"""Get a segment by its ID."""
|
"""Get a segment by its ID."""
|
||||||
|
|
|
||||||
|
|
@ -181,7 +181,7 @@ class ExternalDatasetService:
|
||||||
do http request depending on api bundle
|
do http request depending on api bundle
|
||||||
"""
|
"""
|
||||||
|
|
||||||
kwargs = {
|
kwargs: dict[str, Any] = {
|
||||||
"url": settings.url,
|
"url": settings.url,
|
||||||
"headers": settings.headers,
|
"headers": settings.headers,
|
||||||
"follow_redirects": True,
|
"follow_redirects": True,
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Literal, Union
|
from typing import Literal, Union
|
||||||
|
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
|
|
@ -35,7 +35,7 @@ class FileService:
|
||||||
filename: str,
|
filename: str,
|
||||||
content: bytes,
|
content: bytes,
|
||||||
mimetype: str,
|
mimetype: str,
|
||||||
user: Union[Account, EndUser, Any],
|
user: Union[Account, EndUser],
|
||||||
source: Literal["datasets"] | None = None,
|
source: Literal["datasets"] | None = None,
|
||||||
source_url: str = "",
|
source_url: str = "",
|
||||||
) -> UploadFile:
|
) -> UploadFile:
|
||||||
|
|
|
||||||
|
|
@ -165,7 +165,7 @@ class ModelLoadBalancingService:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if load_balancing_config.encrypted_config:
|
if load_balancing_config.encrypted_config:
|
||||||
credentials = json.loads(load_balancing_config.encrypted_config)
|
credentials: dict[str, object] = json.loads(load_balancing_config.encrypted_config)
|
||||||
else:
|
else:
|
||||||
credentials = {}
|
credentials = {}
|
||||||
except JSONDecodeError:
|
except JSONDecodeError:
|
||||||
|
|
@ -180,11 +180,13 @@ class ModelLoadBalancingService:
|
||||||
for variable in credential_secret_variables:
|
for variable in credential_secret_variables:
|
||||||
if variable in credentials:
|
if variable in credentials:
|
||||||
try:
|
try:
|
||||||
credentials[variable] = encrypter.decrypt_token_with_decoding(
|
token_value = credentials.get(variable)
|
||||||
credentials.get(variable), # ty: ignore [invalid-argument-type]
|
if isinstance(token_value, str):
|
||||||
decoding_rsa_key,
|
credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||||
decoding_cipher_rsa,
|
token_value,
|
||||||
)
|
decoding_rsa_key,
|
||||||
|
decoding_cipher_rsa,
|
||||||
|
)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
@ -345,8 +347,9 @@ class ModelLoadBalancingService:
|
||||||
credential_id = config.get("credential_id")
|
credential_id = config.get("credential_id")
|
||||||
enabled = config.get("enabled")
|
enabled = config.get("enabled")
|
||||||
|
|
||||||
|
credential_record: ProviderCredential | ProviderModelCredential | None = None
|
||||||
|
|
||||||
if credential_id:
|
if credential_id:
|
||||||
credential_record: ProviderCredential | ProviderModelCredential | None = None
|
|
||||||
if config_from == "predefined-model":
|
if config_from == "predefined-model":
|
||||||
credential_record = (
|
credential_record = (
|
||||||
db.session.query(ProviderCredential)
|
db.session.query(ProviderCredential)
|
||||||
|
|
|
||||||
|
|
@ -99,6 +99,7 @@ class PluginMigration:
|
||||||
datetime.timedelta(hours=1),
|
datetime.timedelta(hours=1),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
tenant_count = 0
|
||||||
for test_interval in test_intervals:
|
for test_interval in test_intervals:
|
||||||
tenant_count = (
|
tenant_count = (
|
||||||
session.query(Tenant.id)
|
session.query(Tenant.id)
|
||||||
|
|
|
||||||
|
|
@ -223,8 +223,8 @@ class BuiltinToolManageService:
|
||||||
"""
|
"""
|
||||||
add builtin tool provider
|
add builtin tool provider
|
||||||
"""
|
"""
|
||||||
try:
|
with Session(db.engine) as session:
|
||||||
with Session(db.engine) as session:
|
try:
|
||||||
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
|
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
|
||||||
with redis_client.lock(lock, timeout=20):
|
with redis_client.lock(lock, timeout=20):
|
||||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||||
|
|
@ -285,9 +285,9 @@ class BuiltinToolManageService:
|
||||||
|
|
||||||
session.add(db_provider)
|
session.add(db_provider)
|
||||||
session.commit()
|
session.commit()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
session.rollback()
|
session.rollback()
|
||||||
raise ValueError(str(e))
|
raise ValueError(str(e))
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ from core.helper import encrypter
|
||||||
from core.model_runtime.entities.llm_entities import LLMMode
|
from core.model_runtime.entities.llm_entities import LLMMode
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.prompt.simple_prompt_transform import SimplePromptTransform
|
from core.prompt.simple_prompt_transform import SimplePromptTransform
|
||||||
|
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||||
from core.workflow.nodes import NodeType
|
from core.workflow.nodes import NodeType
|
||||||
from events.app_event import app_was_created
|
from events.app_event import app_was_created
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
|
@ -420,7 +421,11 @@ class WorkflowConverter:
|
||||||
query_in_prompt=False,
|
query_in_prompt=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
template = prompt_template_config["prompt_template"].template
|
prompt_template_obj = prompt_template_config["prompt_template"]
|
||||||
|
if not isinstance(prompt_template_obj, PromptTemplateParser):
|
||||||
|
raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template_obj)}")
|
||||||
|
|
||||||
|
template = prompt_template_obj.template
|
||||||
if not template:
|
if not template:
|
||||||
prompts = []
|
prompts = []
|
||||||
else:
|
else:
|
||||||
|
|
@ -457,7 +462,11 @@ class WorkflowConverter:
|
||||||
query_in_prompt=False,
|
query_in_prompt=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
template = prompt_template_config["prompt_template"].template
|
prompt_template_obj = prompt_template_config["prompt_template"]
|
||||||
|
if not isinstance(prompt_template_obj, PromptTemplateParser):
|
||||||
|
raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template_obj)}")
|
||||||
|
|
||||||
|
template = prompt_template_obj.template
|
||||||
template = self._replace_template_variables(
|
template = self._replace_template_variables(
|
||||||
template=template,
|
template=template,
|
||||||
variables=start_node["data"]["variables"],
|
variables=start_node["data"]["variables"],
|
||||||
|
|
@ -467,6 +476,9 @@ class WorkflowConverter:
|
||||||
prompts = {"text": template}
|
prompts = {"text": template}
|
||||||
|
|
||||||
prompt_rules = prompt_template_config["prompt_rules"]
|
prompt_rules = prompt_template_config["prompt_rules"]
|
||||||
|
if not isinstance(prompt_rules, dict):
|
||||||
|
raise TypeError(f"Expected dict for prompt_rules, got {type(prompt_rules)}")
|
||||||
|
|
||||||
role_prefix = {
|
role_prefix = {
|
||||||
"user": prompt_rules.get("human_prefix", "Human"),
|
"user": prompt_rules.get("human_prefix", "Human"),
|
||||||
"assistant": prompt_rules.get("assistant_prefix", "Assistant"),
|
"assistant": prompt_rules.get("assistant_prefix", "Assistant"),
|
||||||
|
|
|
||||||
|
|
@ -769,10 +769,10 @@ class WorkflowService:
|
||||||
)
|
)
|
||||||
error = node_run_result.error if not run_succeeded else None
|
error = node_run_result.error if not run_succeeded else None
|
||||||
except WorkflowNodeRunFailedError as e:
|
except WorkflowNodeRunFailedError as e:
|
||||||
node = e._node
|
node = e.node
|
||||||
run_succeeded = False
|
run_succeeded = False
|
||||||
node_run_result = None
|
node_run_result = None
|
||||||
error = e._error
|
error = e.error
|
||||||
|
|
||||||
# Create a NodeExecution domain model
|
# Create a NodeExecution domain model
|
||||||
node_execution = WorkflowNodeExecution(
|
node_execution = WorkflowNodeExecution(
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ class WorkspaceService:
|
||||||
def get_tenant_info(cls, tenant: Tenant):
|
def get_tenant_info(cls, tenant: Tenant):
|
||||||
if not tenant:
|
if not tenant:
|
||||||
return None
|
return None
|
||||||
tenant_info = {
|
tenant_info: dict[str, object] = {
|
||||||
"id": tenant.id,
|
"id": tenant.id,
|
||||||
"name": tenant.name,
|
"name": tenant.name,
|
||||||
"plan": tenant.plan,
|
"plan": tenant.plan,
|
||||||
|
|
|
||||||
|
|
@ -3278,7 +3278,7 @@ class TestRegisterService:
|
||||||
redis_client.setex(cache_key, 24 * 60 * 60, account_id)
|
redis_client.setex(cache_key, 24 * 60 * 60, account_id)
|
||||||
|
|
||||||
# Execute invitation retrieval
|
# Execute invitation retrieval
|
||||||
result = RegisterService._get_invitation_by_token(
|
result = RegisterService.get_invitation_by_token(
|
||||||
token=token,
|
token=token,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
email=email,
|
email=email,
|
||||||
|
|
@ -3316,7 +3316,7 @@ class TestRegisterService:
|
||||||
redis_client.setex(token_key, 24 * 60 * 60, json.dumps(invitation_data))
|
redis_client.setex(token_key, 24 * 60 * 60, json.dumps(invitation_data))
|
||||||
|
|
||||||
# Execute invitation retrieval
|
# Execute invitation retrieval
|
||||||
result = RegisterService._get_invitation_by_token(token=token)
|
result = RegisterService.get_invitation_by_token(token=token)
|
||||||
|
|
||||||
# Verify result contains expected data
|
# Verify result contains expected data
|
||||||
assert result is not None
|
assert result is not None
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ from core.app.app_config.entities import (
|
||||||
VariableEntityType,
|
VariableEntityType,
|
||||||
)
|
)
|
||||||
from core.model_runtime.entities.llm_entities import LLMMode
|
from core.model_runtime.entities.llm_entities import LLMMode
|
||||||
|
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||||
from models.account import Account, Tenant
|
from models.account import Account, Tenant
|
||||||
from models.api_based_extension import APIBasedExtension
|
from models.api_based_extension import APIBasedExtension
|
||||||
from models.model import App, AppMode, AppModelConfig
|
from models.model import App, AppMode, AppModelConfig
|
||||||
|
|
@ -37,7 +38,7 @@ class TestWorkflowConverter:
|
||||||
# Setup default mock returns
|
# Setup default mock returns
|
||||||
mock_encrypter.decrypt_token.return_value = "decrypted_api_key"
|
mock_encrypter.decrypt_token.return_value = "decrypted_api_key"
|
||||||
mock_prompt_transform.return_value.get_prompt_template.return_value = {
|
mock_prompt_transform.return_value.get_prompt_template.return_value = {
|
||||||
"prompt_template": type("obj", (object,), {"template": "You are a helpful assistant {{text_input}}"})(),
|
"prompt_template": PromptTemplateParser(template="You are a helpful assistant {{text_input}}"),
|
||||||
"prompt_rules": {"human_prefix": "Human", "assistant_prefix": "Assistant"},
|
"prompt_rules": {"human_prefix": "Human", "assistant_prefix": "Assistant"},
|
||||||
}
|
}
|
||||||
mock_agent_chat_config_manager.get_app_config.return_value = self._create_mock_app_config()
|
mock_agent_chat_config_manager.get_app_config.return_value = self._create_mock_app_config()
|
||||||
|
|
|
||||||
|
|
@ -1370,8 +1370,8 @@ class TestRegisterService:
|
||||||
account_id="user-123", email="test@example.com"
|
account_id="user-123", email="test@example.com"
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch("services.account_service.RegisterService._get_invitation_by_token") as mock_get_invitation_by_token:
|
with patch("services.account_service.RegisterService.get_invitation_by_token") as mock_get_invitation_by_token:
|
||||||
# Mock the invitation data returned by _get_invitation_by_token
|
# Mock the invitation data returned by get_invitation_by_token
|
||||||
invitation_data = {
|
invitation_data = {
|
||||||
"account_id": "user-123",
|
"account_id": "user-123",
|
||||||
"email": "test@example.com",
|
"email": "test@example.com",
|
||||||
|
|
@ -1503,12 +1503,12 @@ class TestRegisterService:
|
||||||
assert result == "member_invite:token:test-token"
|
assert result == "member_invite:token:test-token"
|
||||||
|
|
||||||
def test_get_invitation_by_token_with_workspace_and_email(self, mock_redis_dependencies):
|
def test_get_invitation_by_token_with_workspace_and_email(self, mock_redis_dependencies):
|
||||||
"""Test _get_invitation_by_token with workspace ID and email."""
|
"""Test get_invitation_by_token with workspace ID and email."""
|
||||||
# Setup mock
|
# Setup mock
|
||||||
mock_redis_dependencies.get.return_value = b"user-123"
|
mock_redis_dependencies.get.return_value = b"user-123"
|
||||||
|
|
||||||
# Execute test
|
# Execute test
|
||||||
result = RegisterService._get_invitation_by_token("token-123", "workspace-456", "test@example.com")
|
result = RegisterService.get_invitation_by_token("token-123", "workspace-456", "test@example.com")
|
||||||
|
|
||||||
# Verify results
|
# Verify results
|
||||||
assert result is not None
|
assert result is not None
|
||||||
|
|
@ -1517,7 +1517,7 @@ class TestRegisterService:
|
||||||
assert result["workspace_id"] == "workspace-456"
|
assert result["workspace_id"] == "workspace-456"
|
||||||
|
|
||||||
def test_get_invitation_by_token_without_workspace_and_email(self, mock_redis_dependencies):
|
def test_get_invitation_by_token_without_workspace_and_email(self, mock_redis_dependencies):
|
||||||
"""Test _get_invitation_by_token without workspace ID and email."""
|
"""Test get_invitation_by_token without workspace ID and email."""
|
||||||
# Setup mock
|
# Setup mock
|
||||||
invitation_data = {
|
invitation_data = {
|
||||||
"account_id": "user-123",
|
"account_id": "user-123",
|
||||||
|
|
@ -1527,19 +1527,19 @@ class TestRegisterService:
|
||||||
mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
|
mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
|
||||||
|
|
||||||
# Execute test
|
# Execute test
|
||||||
result = RegisterService._get_invitation_by_token("token-123")
|
result = RegisterService.get_invitation_by_token("token-123")
|
||||||
|
|
||||||
# Verify results
|
# Verify results
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result == invitation_data
|
assert result == invitation_data
|
||||||
|
|
||||||
def test_get_invitation_by_token_no_data(self, mock_redis_dependencies):
|
def test_get_invitation_by_token_no_data(self, mock_redis_dependencies):
|
||||||
"""Test _get_invitation_by_token with no data."""
|
"""Test get_invitation_by_token with no data."""
|
||||||
# Setup mock
|
# Setup mock
|
||||||
mock_redis_dependencies.get.return_value = None
|
mock_redis_dependencies.get.return_value = None
|
||||||
|
|
||||||
# Execute test
|
# Execute test
|
||||||
result = RegisterService._get_invitation_by_token("token-123")
|
result = RegisterService.get_invitation_by_token("token-123")
|
||||||
|
|
||||||
# Verify results
|
# Verify results
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue