mirror of https://github.com/langgenius/dify.git
parent
3225ca8337
commit
995bfe16fc
|
|
@ -3,15 +3,16 @@ from flask_restx import Resource, fields, marshal_with
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from controllers.console import console_ns
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.api_based_extension_fields import api_based_extension_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
from services.api_based_extension_service import APIBasedExtensionService
|
||||
from services.code_based_extension_service import CodeBasedExtensionService
|
||||
|
||||
from ..common.schema import register_schema_models
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, setup_required
|
||||
|
||||
|
||||
class CodeBasedExtensionQuery(BaseModel):
|
||||
module: str
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from fields.tag_fields import dataset_tag_fields
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from core.mcp.mcp_client import MCPClient
|
|||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import alphanumeric, uuid_value
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
|
|
@ -74,12 +75,12 @@ class BuiltinToolUpdatePayload(BaseModel):
|
|||
class ApiToolProviderBasePayload(BaseModel):
|
||||
credentials: dict[str, Any]
|
||||
schema_type: str
|
||||
schema: str
|
||||
schema_: str = Field(alias="schema")
|
||||
provider: str
|
||||
icon: dict[str, Any]
|
||||
privacy_policy: str | None = None
|
||||
labels: list[str] | None = None
|
||||
custom_disclaimer: str | None = None
|
||||
custom_disclaimer: str = ""
|
||||
|
||||
|
||||
class ApiToolProviderAddPayload(ApiToolProviderBasePayload):
|
||||
|
|
@ -110,7 +111,7 @@ class ApiToolProviderDeletePayload(BaseModel):
|
|||
|
||||
|
||||
class ApiToolSchemaPayload(BaseModel):
|
||||
schema: str
|
||||
schema_: str = Field(alias="schema")
|
||||
|
||||
|
||||
class ApiToolTestPayload(BaseModel):
|
||||
|
|
@ -119,7 +120,7 @@ class ApiToolTestPayload(BaseModel):
|
|||
credentials: dict[str, Any]
|
||||
parameters: dict[str, Any]
|
||||
schema_type: str
|
||||
schema: str
|
||||
schema_: str = Field(alias="schema")
|
||||
|
||||
|
||||
class WorkflowToolBasePayload(BaseModel):
|
||||
|
|
@ -127,7 +128,7 @@ class WorkflowToolBasePayload(BaseModel):
|
|||
label: str
|
||||
description: str
|
||||
icon: dict[str, Any]
|
||||
parameters: list[dict[str, Any]]
|
||||
parameters: list[WorkflowToolParameterConfiguration]
|
||||
privacy_policy: str | None = ""
|
||||
labels: list[str] | None = None
|
||||
|
||||
|
|
@ -205,7 +206,7 @@ class MCPProviderBasePayload(BaseModel):
|
|||
name: str
|
||||
icon: str
|
||||
icon_type: str
|
||||
icon_background: str | None = ""
|
||||
icon_background: str = ""
|
||||
server_identifier: str
|
||||
configuration: dict[str, Any] | None = Field(default_factory=dict)
|
||||
headers: dict[str, Any] | None = Field(default_factory=dict)
|
||||
|
|
@ -266,10 +267,10 @@ class ToolProviderListApi(Resource):
|
|||
|
||||
user_id = user.id
|
||||
|
||||
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type]
|
||||
raw_args = request.args.to_dict()
|
||||
query = ToolProviderListQuery.model_validate(raw_args)
|
||||
|
||||
return ToolCommonService.list_tool_providers(user_id, tenant_id, query.type)
|
||||
return ToolCommonService.list_tool_providers(user_id, tenant_id, query.type) # type: ignore
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/tools")
|
||||
|
|
@ -411,7 +412,7 @@ class ToolApiProviderAddApi(Resource):
|
|||
payload.icon,
|
||||
payload.credentials,
|
||||
payload.schema_type,
|
||||
payload.schema,
|
||||
payload.schema_,
|
||||
payload.privacy_policy or "",
|
||||
payload.custom_disclaimer or "",
|
||||
payload.labels or [],
|
||||
|
|
@ -428,7 +429,7 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
|
|||
|
||||
user_id = user.id
|
||||
|
||||
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type]
|
||||
raw_args = request.args.to_dict()
|
||||
query = UrlQuery.model_validate(raw_args)
|
||||
|
||||
return ApiToolManageService.get_api_tool_provider_remote_schema(
|
||||
|
|
@ -448,7 +449,7 @@ class ToolApiProviderListToolsApi(Resource):
|
|||
|
||||
user_id = user.id
|
||||
|
||||
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type]
|
||||
raw_args = request.args.to_dict()
|
||||
query = ProviderQuery.model_validate(raw_args)
|
||||
|
||||
return jsonable_encoder(
|
||||
|
|
@ -482,7 +483,7 @@ class ToolApiProviderUpdateApi(Resource):
|
|||
payload.icon,
|
||||
payload.credentials,
|
||||
payload.schema_type,
|
||||
payload.schema,
|
||||
payload.schema_,
|
||||
payload.privacy_policy,
|
||||
payload.custom_disclaimer,
|
||||
payload.labels or [],
|
||||
|
|
@ -520,7 +521,7 @@ class ToolApiProviderGetApi(Resource):
|
|||
|
||||
user_id = user.id
|
||||
|
||||
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type]
|
||||
raw_args = request.args.to_dict()
|
||||
query = ProviderQuery.model_validate(raw_args)
|
||||
|
||||
return ApiToolManageService.get_api_tool_provider(
|
||||
|
|
@ -555,7 +556,7 @@ class ToolApiProviderSchemaApi(Resource):
|
|||
payload = ApiToolSchemaPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
return ApiToolManageService.parser_api_schema(
|
||||
schema=payload.schema,
|
||||
schema=payload.schema_,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -575,7 +576,7 @@ class ToolApiProviderPreviousTestApi(Resource):
|
|||
payload.credentials,
|
||||
payload.parameters,
|
||||
payload.schema_type,
|
||||
payload.schema,
|
||||
payload.schema_,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -665,7 +666,7 @@ class ToolWorkflowProviderGetApi(Resource):
|
|||
|
||||
user_id = user.id
|
||||
|
||||
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type]
|
||||
raw_args = request.args.to_dict()
|
||||
query = WorkflowToolGetQuery.model_validate(raw_args)
|
||||
|
||||
if query.workflow_tool_id:
|
||||
|
|
@ -696,7 +697,7 @@ class ToolWorkflowProviderListToolApi(Resource):
|
|||
|
||||
user_id = user.id
|
||||
|
||||
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type]
|
||||
raw_args = request.args.to_dict()
|
||||
query = WorkflowToolListQuery.model_validate(raw_args)
|
||||
|
||||
return jsonable_encoder(
|
||||
|
|
@ -1155,7 +1156,7 @@ class ToolMCPUpdateApi(Resource):
|
|||
@console_ns.route("/mcp/oauth/callback")
|
||||
class ToolMCPCallbackApi(Resource):
|
||||
def get(self):
|
||||
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type]
|
||||
raw_args = request.args.to_dict()
|
||||
query = MCPCallbackQuery.model_validate(raw_args)
|
||||
state_key = query.state
|
||||
authorization_code = query.code
|
||||
|
|
|
|||
|
|
@ -7,9 +7,7 @@ from werkzeug.exceptions import Unauthorized
|
|||
|
||||
from constants import HEADER_NAME_APP_CODE
|
||||
from controllers.common import fields
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import AppUnavailableError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from controllers.common.schema import register_schema_models
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from libs.passport import PassportService
|
||||
from libs.token import extract_webapp_passport
|
||||
|
|
@ -19,6 +17,10 @@ from services.enterprise.enterprise_service import EnterpriseService
|
|||
from services.feature_service import FeatureService
|
||||
from services.webapp_auth_service import WebAppAuthService
|
||||
|
||||
from . import web_ns
|
||||
from .error import AppUnavailableError
|
||||
from .wraps import WebApiResource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -29,6 +31,9 @@ class AppAccessModeQuery(BaseModel):
|
|||
app_code: str | None = Field(default=None, alias="appCode", description="Application code")
|
||||
|
||||
|
||||
register_schema_models(web_ns, AppAccessModeQuery)
|
||||
|
||||
|
||||
@web_ns.route("/parameters")
|
||||
class AppParameterApi(WebApiResource):
|
||||
"""Resource for app variables."""
|
||||
|
|
@ -99,7 +104,7 @@ class AppAccessMode(Resource):
|
|||
}
|
||||
)
|
||||
def get(self):
|
||||
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type]
|
||||
raw_args = request.args.to_dict()
|
||||
args = AppAccessModeQuery.model_validate(raw_args)
|
||||
|
||||
features = FeatureService.get_system_features()
|
||||
|
|
|
|||
|
|
@ -18,8 +18,6 @@ from controllers.web.error import (
|
|||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from ..common import register_schema_models
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs.helper import uuid_value
|
||||
|
|
@ -32,6 +30,9 @@ from services.errors.audio import (
|
|||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
from ..common.schema import register_schema_models
|
||||
from ..web.wraps import WebApiResource
|
||||
|
||||
|
||||
class TextToAudioPayload(BaseModel):
|
||||
message_id: str | None = None
|
||||
|
|
@ -46,6 +47,7 @@ class TextToAudioPayload(BaseModel):
|
|||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
register_schema_models(web_ns, TextToAudioPayload)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ class ConversationListApi(WebApiResource):
|
|||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type]
|
||||
raw_args = request.args.to_dict()
|
||||
query = ConversationListQuery.model_validate(raw_args)
|
||||
|
||||
try:
|
||||
|
|
@ -171,7 +171,7 @@ class ConversationRenameApi(WebApiResource):
|
|||
payload = ConversationRenamePayload.model_validate(web_ns.payload or {})
|
||||
|
||||
try:
|
||||
return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate)
|
||||
return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate) # type: ignore
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
|
|
|||
|
|
@ -108,7 +108,7 @@ class MessageListApi(WebApiResource):
|
|||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type]
|
||||
raw_args = request.args.to_dict()
|
||||
query = MessageListQuery.model_validate(raw_args)
|
||||
|
||||
try:
|
||||
|
|
@ -182,7 +182,7 @@ class MessageMoreLikeThisApi(WebApiResource):
|
|||
|
||||
message_id = str(message_id)
|
||||
|
||||
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type]
|
||||
raw_args = request.args.to_dict()
|
||||
query = MessageMoreLikeThisQuery.model_validate(raw_args)
|
||||
|
||||
streaming = query.response_mode == "streaming"
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ class SavedMessageListApi(WebApiResource):
|
|||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type]
|
||||
raw_args = request.args.to_dict()
|
||||
query = SavedMessageListQuery.model_validate(raw_args)
|
||||
|
||||
return SavedMessageService.pagination_by_last_id(app_model, end_user, query.last_id, query.limit)
|
||||
|
|
|
|||
|
|
@ -7,11 +7,6 @@ from core.workflow.nodes.base.entities import OutputVariableEntity
|
|||
|
||||
|
||||
class WorkflowToolConfigurationUtils:
|
||||
@classmethod
|
||||
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
|
||||
for configuration in configurations:
|
||||
WorkflowToolParameterConfiguration.model_validate(configuration)
|
||||
|
||||
@classmethod
|
||||
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -247,14 +247,14 @@ class ApiToolManageService:
|
|||
credentials: dict,
|
||||
schema_type: str,
|
||||
schema: str,
|
||||
privacy_policy: str,
|
||||
privacy_policy: str | None,
|
||||
custom_disclaimer: str,
|
||||
labels: list[str],
|
||||
):
|
||||
"""
|
||||
update api tool provider
|
||||
"""
|
||||
if schema_type not in [member.value for member in ApiProviderSchemaType]:
|
||||
if schema_type not in list(ApiProviderSchemaType):
|
||||
raise ValueError(f"invalid schema type {schema}")
|
||||
|
||||
provider_name = provider_name.strip()
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
|
@ -11,8 +9,8 @@ from core.helper.tool_provider_cache import ToolProviderListCache
|
|||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -39,12 +37,10 @@ class WorkflowToolManageService:
|
|||
label: str,
|
||||
icon: dict,
|
||||
description: str,
|
||||
parameters: list[Mapping[str, Any]],
|
||||
parameters: list[WorkflowToolParameterConfiguration],
|
||||
privacy_policy: str = "",
|
||||
labels: list[str] | None = None,
|
||||
):
|
||||
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
|
||||
|
||||
# check if the name is unique
|
||||
existing_workflow_tool_provider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
|
|
@ -77,7 +73,7 @@ class WorkflowToolManageService:
|
|||
label=label,
|
||||
icon=json.dumps(icon),
|
||||
description=description,
|
||||
parameter_configuration=json.dumps(parameters),
|
||||
parameter_configuration=json.dumps([p.model_dump() for p in parameters]),
|
||||
privacy_policy=privacy_policy,
|
||||
version=workflow.version,
|
||||
)
|
||||
|
|
@ -108,7 +104,7 @@ class WorkflowToolManageService:
|
|||
label: str,
|
||||
icon: dict,
|
||||
description: str,
|
||||
parameters: list[Mapping[str, Any]],
|
||||
parameters: list[WorkflowToolParameterConfiguration],
|
||||
privacy_policy: str = "",
|
||||
labels: list[str] | None = None,
|
||||
):
|
||||
|
|
@ -126,8 +122,6 @@ class WorkflowToolManageService:
|
|||
:param labels: labels
|
||||
:return: the updated tool
|
||||
"""
|
||||
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
|
||||
|
||||
# check if the name is unique
|
||||
existing_workflow_tool_provider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
|
|
@ -166,7 +160,7 @@ class WorkflowToolManageService:
|
|||
workflow_tool_provider.label = label
|
||||
workflow_tool_provider.icon = json.dumps(icon)
|
||||
workflow_tool_provider.description = description
|
||||
workflow_tool_provider.parameter_configuration = json.dumps(parameters)
|
||||
workflow_tool_provider.parameter_configuration = json.dumps(obj=[p.model_dump() for p in parameters])
|
||||
workflow_tool_provider.privacy_policy = privacy_policy
|
||||
workflow_tool_provider.version = workflow.version
|
||||
workflow_tool_provider.updated_at = datetime.now()
|
||||
|
|
|
|||
Loading…
Reference in New Issue