From e12f4009d396f0b244cf25c3487a4d84820b8d09 Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 30 Sep 2024 17:46:31 +0800 Subject: [PATCH] feat: optimize icon url --- .../console/workspace/model_providers.py | 63 +++++-------------- api/core/file/upload_file_parser.py | 22 ++----- api/core/helper/url_signer.py | 52 +++++++++++++++ .../model_providers/model_provider_factory.py | 30 ++++++--- api/core/plugin/entities/plugin_daemon.py | 6 +- api/core/plugin/manager/asset.py | 2 +- api/services/model_provider_service.py | 63 ++----------------- 7 files changed, 103 insertions(+), 135 deletions(-) create mode 100644 api/core/helper/url_signer.py diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 03d11b5619..d422455c87 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -1,6 +1,6 @@ import io -from flask import send_file +from flask import request, send_file from flask_login import current_user from flask_restful import Resource, reqparse from werkzeug.exceptions import Forbidden @@ -126,13 +126,17 @@ class ModelProviderIconApi(Resource): Get model provider icon """ - @setup_required - @login_required - @account_initialization_required def get(self, provider: str, icon_type: str, lang: str): + tenant_id = request.args.get("tenant_id") + if not tenant_id: + return {"content": "Invalid request."}, 400 + model_provider_service = ModelProviderService() icon, mimetype = model_provider_service.get_model_provider_icon( - tenant_id=current_user.current_tenant_id, provider=provider, icon_type=icon_type, lang=lang + tenant_id=tenant_id, + provider=provider, + icon_type=icon_type, + lang=lang, ) return send_file(io.BytesIO(icon), mimetype=mimetype) @@ -184,53 +188,16 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): return data -class ModelProviderFreeQuotaSubmitApi(Resource): - @setup_required - @login_required - @account_initialization_required - def post(self, provider: str): - model_provider_service = ModelProviderService() - result = model_provider_service.free_quota_submit(tenant_id=current_user.current_tenant_id, provider=provider) - - return result - - -class ModelProviderFreeQuotaQualificationVerifyApi(Resource): - @setup_required - @login_required - @account_initialization_required - def get(self, provider: str): - parser = reqparse.RequestParser() - parser.add_argument("token", type=str, required=False, nullable=True, location="args") - args = parser.parse_args() - - model_provider_service = ModelProviderService() - result = model_provider_service.free_quota_qualification_verify( - tenant_id=current_user.current_tenant_id, provider=provider, token=args["token"] - ) - - return result - - api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers") -api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers//credentials") -api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers//credentials/validate") -api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/") +api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers//credentials") +api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers//credentials/validate") +api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/") api.add_resource( - ModelProviderIconApi, "/workspaces/current/model-providers///" + ModelProviderIconApi, "/workspaces/current/model-providers///" ) api.add_resource( - PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers//preferred-provider-type" -) -api.add_resource( - ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers//checkout-url" -) -api.add_resource( - ModelProviderFreeQuotaSubmitApi, "/workspaces/current/model-providers//free-quota-submit" -) -api.add_resource( - ModelProviderFreeQuotaQualificationVerifyApi, - "/workspaces/current/model-providers//free-quota-qualification-verify", + PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers//preferred-provider-type" ) +api.add_resource(ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers//checkout-url") diff --git a/api/core/file/upload_file_parser.py b/api/core/file/upload_file_parser.py index a8c1fd4d02..a244f651c9 100644 --- a/api/core/file/upload_file_parser.py +++ b/api/core/file/upload_file_parser.py @@ -1,12 +1,10 @@ import base64 -import hashlib -import hmac import logging -import os import time from typing import Optional from configs import dify_config +from core.helper.url_signer import UrlSigner from extensions.ext_storage import storage IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] @@ -46,14 +44,7 @@ class UploadFileParser: base_url = dify_config.FILES_URL image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview" - timestamp = str(int(time.time())) - nonce = os.urandom(16).hex() - data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() - sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - encoded_sign = base64.urlsafe_b64encode(sign).decode() - - return f"{image_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + return UrlSigner.get_signed_url(url=image_preview_url, sign_key=upload_file_id, prefix="image-preview") @classmethod def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: @@ -66,13 +57,12 @@ class UploadFileParser: :param sign: signature :return: """ - data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() - recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() + result = UrlSigner.verify( + sign_key=upload_file_id, timestamp=timestamp, nonce=nonce, sign=sign, prefix="image-preview" + ) # verify signature - if sign != recalculated_encoded_sign: + if not result: return False current_time = int(time.time()) diff --git a/api/core/helper/url_signer.py b/api/core/helper/url_signer.py new file mode 100644 index 0000000000..dfb143f4c4 --- /dev/null +++ b/api/core/helper/url_signer.py @@ -0,0 +1,52 @@ +import base64 +import hashlib +import hmac +import os +import time + +from pydantic import BaseModel, Field + +from configs import dify_config + + +class SignedUrlParams(BaseModel): + sign_key: str = Field(..., description="The sign key") + timestamp: str = Field(..., description="Timestamp") + nonce: str = Field(..., description="Nonce") + sign: str = Field(..., description="Signature") + + +class UrlSigner: + @classmethod + def get_signed_url(cls, url: str, sign_key: str, prefix: str) -> str: + signed_url_params = cls.get_signed_url_params(sign_key, prefix) + return ( + f"{url}?timestamp={signed_url_params.timestamp}" + f"&nonce={signed_url_params.nonce}&sign={signed_url_params.sign}" + ) + + @classmethod + def get_signed_url_params(cls, sign_key: str, prefix: str) -> SignedUrlParams: + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + sign = cls._sign(sign_key, timestamp, nonce, prefix) + + return SignedUrlParams(sign_key=sign_key, timestamp=timestamp, nonce=nonce, sign=sign) + + @classmethod + def verify(cls, sign_key: str, timestamp: str, nonce: str, sign: str, prefix: str) -> bool: + recalculated_sign = cls._sign(sign_key, timestamp, nonce, prefix) + + return sign == recalculated_sign + + @classmethod + def _sign(cls, sign_key: str, timestamp: str, nonce: str, prefix: str) -> str: + if not dify_config.SECRET_KEY: + raise Exception("SECRET_KEY is not set") + + data_to_sign = f"{prefix}|{sign_key}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + return encoded_sign diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index f2c6d6a650..85a79dc0ce 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -96,15 +96,9 @@ class ModelProviderFactory: # fetch plugin model providers plugin_model_provider_entities = self.get_plugin_model_providers() - plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider) - # get the provider plugin_model_provider_entity = next( - ( - p - for p in plugin_model_provider_entities - if p.declaration.provider == provider_name and (plugin_id and p.plugin_id == plugin_id) - ), + (p for p in plugin_model_provider_entities if p.declaration.provider == provider), None, ) @@ -284,7 +278,7 @@ class ModelProviderFactory: elif model_type == ModelType.TTS: return TTSModel(**init_params) - def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> bytes: + def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: """ Get provider icon :param provider: provider name @@ -315,9 +309,27 @@ class ModelProviderFactory: if not file_name: raise ValueError(f"Provider {provider} does not have icon.") + image_mime_types = { + "jpg": "image/jpeg", + "jpeg": "image/jpeg", + "png": "image/png", + "gif": "image/gif", + "bmp": "image/bmp", + "tiff": "image/tiff", + "tif": "image/tiff", + "webp": "image/webp", + "svg": "image/svg+xml", + "ico": "image/vnd.microsoft.icon", + "heif": "image/heif", + "heic": "image/heic", + } + + extension = file_name.split(".")[-1] + mime_type = image_mime_types.get(extension, "image/png") + # get icon bytes from plugin asset manager plugin_asset_manager = PluginAssetManager() - return plugin_asset_manager.fetch_asset(tenant_id=self.tenant_id, id=file_name) + return plugin_asset_manager.fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type def get_plugin_id_and_provider_name_from_provider(self, provider: str) -> tuple[str, str]: """ diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index ebce6e6cee..9596972def 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -58,9 +58,9 @@ class PluginModelSchemaEntity(BaseModel): class PluginModelProviderEntity(BaseModel): - id: str = Field(alias="ID", description="ID") - created_at: datetime = Field(alias="CreatedAt", description="The created at time of the model provider.") - updated_at: datetime = Field(alias="UpdatedAt", description="The updated at time of the model provider.") + id: str = Field(description="ID") + created_at: datetime = Field(description="The created at time of the model provider.") + updated_at: datetime = Field(description="The updated at time of the model provider.") provider: str = Field(description="The provider of the model.") tenant_id: str = Field(description="The tenant ID.") plugin_unique_identifier: str = Field(description="The plugin unique identifier.") diff --git a/api/core/plugin/manager/asset.py b/api/core/plugin/manager/asset.py index fc4a99ad49..17755d3561 100644 --- a/api/core/plugin/manager/asset.py +++ b/api/core/plugin/manager/asset.py @@ -6,7 +6,7 @@ class PluginAssetManager(BasePluginManager): """ Fetch an asset by id. """ - response = self._request(method="GET", path=f"plugin/{tenant_id}/assets/{id}") + response = self._request(method="GET", path=f"plugin/{tenant_id}/asset/{id}") if response.status_code != 200: raise ValueError(f"can not found asset {id}") return response.content diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 630d575f82..0375974041 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -1,9 +1,6 @@ import logging -import os from typing import Optional -import requests - from core.entities.model_entities import ModelStatus, ProviderModelWithStatusEntity from core.model_runtime.entities.model_entities import ModelType, ParameterRule from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory @@ -368,8 +365,9 @@ class ModelProviderService: :return: """ model_type_enum = ModelType.value_of(model_type) - result = self.provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type_enum) + try: + result = self.provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type_enum) return ( DefaultModelResponse( model=result.model, @@ -386,7 +384,7 @@ class ModelProviderService: else None ) except Exception as e: - logger.info(f"get_default_model_of_model_type error: {e}") + logger.debug(f"get_default_model_of_model_type error: {e}") return None def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None: @@ -417,9 +415,9 @@ class ModelProviderService: :return: """ model_provider_factory = ModelProviderFactory(tenant_id) - byte_data = model_provider_factory.get_provider_icon(provider, icon_type, lang) + byte_data, mime_type = model_provider_factory.get_provider_icon(provider, icon_type, lang) - return byte_data, "application/octet-stream" + return byte_data, mime_type def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str) -> None: """ @@ -485,54 +483,3 @@ class ModelProviderService: # Enable model provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type)) - - def free_quota_submit(self, tenant_id: str, provider: str): - api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY") - api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL") - if not api_base_url: - raise Exception("FREE_QUOTA_APPLY_BASE_URL is not set") - - api_url = api_base_url + "/api/v1/providers/apply" - - headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} - response = requests.post(api_url, headers=headers, json={"workspace_id": tenant_id, "provider_name": provider}) - if not response.ok: - logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ") - raise ValueError(f"Error: {response.status_code} ") - - if response.json()["code"] != "success": - raise ValueError(f"error: {response.json()['message']}") - - rst = response.json() - - if rst["type"] == "redirect": - return {"type": rst["type"], "redirect_url": rst["redirect_url"]} - else: - return {"type": rst["type"], "result": "success"} - - def free_quota_qualification_verify(self, tenant_id: str, provider: str, token: Optional[str]): - api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY") - api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL") - if not api_base_url: - raise Exception("FREE_QUOTA_APPLY_BASE_URL is not set") - - api_url = api_base_url + "/api/v1/providers/qualification-verify" - - headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} - json_data = {"workspace_id": tenant_id, "provider_name": provider} - if token: - json_data["token"] = token - response = requests.post(api_url, headers=headers, json=json_data) - if not response.ok: - logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ") - raise ValueError(f"Error: {response.status_code} ") - - rst = response.json() - if rst["code"] != "success": - raise ValueError(f"error: {rst['message']}") - - data = rst["data"] - if data["qualified"] is True: - return {"result": "success", "provider_name": provider, "flag": True} - else: - return {"result": "success", "provider_name": provider, "flag": False, "reason": data["reason"]}