diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index e3d87549bb..6d05743bfe 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -12,7 +12,7 @@ from controllers.common.schema import register_response_schema_models, register_ from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict from libs.passport import PassportService from libs.token import extract_webapp_passport -from models.model import App, AppMode +from models.model import App, AppMode, EndUser from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService @@ -56,7 +56,7 @@ class AppParameterApi(WebApiResource): 500: "Internal Server Error", } ) - def get(self, app_model: App, end_user): + def get(self, app_model: App, end_user: EndUser): """Retrieve app parameters.""" if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow @@ -92,7 +92,7 @@ class AppMeta(WebApiResource): 500: "Internal Server Error", } ) - def get(self, app_model: App, end_user): + def get(self, app_model: App, end_user: EndUser): """Get app meta""" return AppService().get_app_meta(app_model) diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index f08d08ab7d..258493303f 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -29,7 +29,7 @@ from core.errors.error import ( from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value -from models.model import AppMode +from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService from services.app_task_service import AppTaskService from services.errors.llm import InvokeRateLimitError @@ -86,7 +86,7 @@ class CompletionApi(WebApiResource): 500: "Internal Server Error", } ) - def post(self, app_model, end_user): + def post(self, app_model: App, end_user: EndUser): if app_model.mode != AppMode.COMPLETION: raise NotCompletionAppError() @@ -140,7 +140,7 @@ class CompletionStopApi(WebApiResource): } ) @web_ns.response(200, "Success", web_ns.models[SimpleResultResponse.__name__]) - def post(self, app_model, end_user, task_id: str): + def post(self, app_model: App, end_user: EndUser, task_id: str): if app_model.mode != AppMode.COMPLETION: raise NotCompletionAppError() @@ -169,7 +169,7 @@ class ChatApi(WebApiResource): 500: "Internal Server Error", } ) - def post(self, app_model, end_user): + def post(self, app_model: App, end_user: EndUser): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() @@ -226,7 +226,7 @@ class ChatStopApi(WebApiResource): } ) @web_ns.response(200, "Success", web_ns.models[SimpleResultResponse.__name__]) - def post(self, app_model, end_user, task_id: str): + def post(self, app_model: App, end_user: EndUser, task_id: str): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 00db29a606..7803b11f4e 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -19,7 +19,7 @@ from fields.conversation_fields import ( SimpleConversation, ) from libs.helper import uuid_value -from models.model import AppMode +from models.model import App, AppMode, EndUser from services.conversation_service import ConversationService from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError from services.web_conversation_service import WebConversationService @@ -81,7 +81,7 @@ class ConversationListApi(WebApiResource): 500: "Internal Server Error", } ) - def get(self, app_model, end_user): + def get(self, app_model: App, end_user: EndUser): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() @@ -127,7 +127,7 @@ class ConversationApi(WebApiResource): 500: "Internal Server Error", } ) - def delete(self, app_model, end_user, c_id: UUID): + def delete(self, app_model: App, end_user: EndUser, c_id: UUID): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() @@ -166,7 +166,7 @@ class ConversationRenameApi(WebApiResource): 500: "Internal Server Error", } ) - def post(self, app_model, end_user, c_id: UUID): + def post(self, app_model: App, end_user: EndUser, c_id: UUID): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() @@ -204,7 +204,7 @@ class ConversationPinApi(WebApiResource): } ) @web_ns.response(200, "Conversation pinned successfully", web_ns.models[ResultResponse.__name__]) - def patch(self, app_model, end_user, c_id: UUID): + def patch(self, app_model: App, end_user: EndUser, c_id: UUID): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() @@ -235,7 +235,7 @@ class ConversationUnPinApi(WebApiResource): } ) @web_ns.response(200, "Conversation unpinned successfully", web_ns.models[ResultResponse.__name__]) - def patch(self, app_model, end_user, c_id: UUID): + def patch(self, app_model: App, end_user: EndUser, c_id: UUID): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() diff --git a/api/controllers/web/files.py b/api/controllers/web/files.py index 6128490104..e08a337364 100644 --- a/api/controllers/web/files.py +++ b/api/controllers/web/files.py @@ -13,6 +13,7 @@ from controllers.web import web_ns from controllers.web.wraps import WebApiResource from extensions.ext_database import db from fields.file_fields import FileResponse +from models.model import App, EndUser from services.file_service import FileService register_schema_models(web_ns, FileResponse) @@ -31,7 +32,7 @@ class FileApi(WebApiResource): } ) @web_ns.response(201, "File uploaded successfully", web_ns.models[FileResponse.__name__]) - def post(self, app_model, end_user): + def post(self, app_model: App, end_user: EndUser): """Upload a file for use in web applications. Accepts file uploads for use within web applications, supporting diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index cf0363b66e..ee58433679 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -27,7 +27,7 @@ from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfinite from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from models.enums import FeedbackRating -from models.model import AppMode +from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError @@ -81,7 +81,7 @@ class MessageListApi(WebApiResource): 500: "Internal Server Error", } ) - def get(self, app_model, end_user): + def get(self, app_model: App, end_user: EndUser): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() @@ -133,7 +133,7 @@ class MessageFeedbackApi(WebApiResource): } ) @web_ns.response(200, "Feedback submitted successfully", web_ns.models[ResultResponse.__name__]) - def post(self, app_model, end_user, message_id: UUID): + def post(self, app_model: App, end_user: EndUser, message_id: UUID): message_id_str = str(message_id) payload = MessageFeedbackPayload.model_validate(web_ns.payload or {}) @@ -167,7 +167,7 @@ class MessageMoreLikeThisApi(WebApiResource): 500: "Internal Server Error", } ) - def get(self, app_model, end_user, message_id: UUID): + def get(self, app_model: App, end_user: EndUser, message_id: UUID): if app_model.mode != "completion": raise NotCompletionAppError() @@ -223,7 +223,7 @@ class MessageSuggestedQuestionApi(WebApiResource): 500: "Internal Server Error", } ) - def get(self, app_model, end_user, message_id: UUID): + def get(self, app_model: App, end_user: EndUser, message_id: UUID): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index e9f727097b..c18c05d3e9 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -13,6 +13,7 @@ from core.helper import ssrf_proxy from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo from graphon.file import helpers as file_helpers +from models.model import App, EndUser from services.file_service import FileService from ..common.schema import register_response_schema_models, register_schema_models @@ -41,7 +42,7 @@ class RemoteFileInfoApi(WebApiResource): } ) @web_ns.response(200, "Remote file info", web_ns.models[RemoteFileInfo.__name__]) - def get(self, app_model, end_user, url): + def get(self, app_model: App, end_user: EndUser, url: str): """Get information about a remote file. Retrieves basic information about a file located at a remote URL, @@ -85,7 +86,7 @@ class RemoteFileUploadApi(WebApiResource): } ) @web_ns.response(201, "Remote file uploaded", web_ns.models[FileWithSignedUrl.__name__]) - def post(self, app_model, end_user): + def post(self, app_model: App, end_user: EndUser): """Upload a file from a remote URL. Downloads a file from the provided remote URL and uploads it diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index 766cfc6c60..7ce72e56ab 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -11,6 +11,7 @@ from controllers.web.error import NotCompletionAppError from controllers.web.wraps import WebApiResource from fields.conversation_fields import ResultResponse from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem +from models.model import App, EndUser from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService @@ -43,7 +44,7 @@ class SavedMessageListApi(WebApiResource): 500: "Internal Server Error", } ) - def get(self, app_model, end_user): + def get(self, app_model: App, end_user: EndUser): if app_model.mode != "completion": raise NotCompletionAppError() @@ -77,7 +78,7 @@ class SavedMessageListApi(WebApiResource): } ) @web_ns.response(200, "Message saved successfully", web_ns.models[ResultResponse.__name__]) - def post(self, app_model, end_user): + def post(self, app_model: App, end_user: EndUser): if app_model.mode != "completion": raise NotCompletionAppError() @@ -106,7 +107,7 @@ class SavedMessageApi(WebApiResource): 500: "Internal Server Error", } ) - def delete(self, app_model, end_user, message_id: UUID): + def delete(self, app_model: App, end_user: EndUser, message_id: UUID): message_id_str = str(message_id) if app_model.mode != "completion": diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index bd21632b05..19b04b7acc 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -10,7 +10,7 @@ from controllers.web.wraps import WebApiResource from extensions.ext_database import db from libs.helper import AppIconUrlField from models.account import TenantStatus -from models.model import App, Site +from models.model import App, EndUser, Site from services.feature_service import FeatureService @@ -70,7 +70,7 @@ class AppSiteApi(WebApiResource): } ) @marshal_with(app_fields) - def get(self, app_model, end_user): + def get(self, app_model: App, end_user: EndUser): """Retrieve app site info.""" # get site site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) @@ -78,7 +78,7 @@ class AppSiteApi(WebApiResource): if not site: raise Forbidden() - if app_model.tenant.status == TenantStatus.ARCHIVE: + if app_model.tenant and app_model.tenant.status == TenantStatus.ARCHIVE: raise Forbidden() can_replace_logo = FeatureService.get_features(app_model.tenant_id, exclude_vector_space=True).can_replace_logo