mirror of https://github.com/langgenius/dify.git
Merge ef6d73c3f4 into 2c919efa69
This commit is contained in:
commit
9e19a05368
|
|
@ -49,6 +49,7 @@ select = [
|
|||
"S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
|
||||
"S302", # suspicious-marshal-usage, disallow use of `marshal` module
|
||||
"S311", # suspicious-non-cryptographic-random-usage,
|
||||
"TID", # flake8-tidy-imports
|
||||
|
||||
]
|
||||
|
||||
|
|
@ -84,6 +85,7 @@ ignore = [
|
|||
"SIM113", # enumerate-for-loop
|
||||
"SIM117", # multiple-with-statements
|
||||
"SIM210", # if-expr-with-true-false
|
||||
"TID252", # allow relative imports from parent modules
|
||||
]
|
||||
|
||||
[lint.per-file-ignores]
|
||||
|
|
@ -112,3 +114,11 @@ allowed-unused-imports = [
|
|||
"tests.integration_tests",
|
||||
"tests.unit_tests",
|
||||
]
|
||||
|
||||
[lint.flake8-tidy-imports]
|
||||
|
||||
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"]
|
||||
msg = "Use Pydantic payload/query models instead of reqparse."
|
||||
|
||||
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse.RequestParser"]
|
||||
msg = "Use Pydantic payload/query models instead of reqparse."
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from flask_restx import marshal, reqparse
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
|
|
@ -56,15 +56,10 @@ class DatasetsHitTestingBase:
|
|||
HitTestingService.hit_testing_args_check(args)
|
||||
|
||||
@staticmethod
|
||||
def parse_args():
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("query", type=str, required=False, location="json")
|
||||
.add_argument("attachment_ids", type=list, required=False, location="json")
|
||||
.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
)
|
||||
return parser.parse_args()
|
||||
def parse_args(payload: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
"""Validate and return hit-testing arguments from an incoming payload."""
|
||||
hit_testing_payload = HitTestingPayload.model_validate(payload or {})
|
||||
return hit_testing_payload.model_dump(exclude_none=True)
|
||||
|
||||
@staticmethod
|
||||
def perform_hit_testing(dataset, args):
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
import json
|
||||
import logging
|
||||
from typing import Any, Literal, cast
|
||||
from uuid import UUID
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, marshal_with, reqparse # type: ignore
|
||||
from flask_restx import Resource, marshal_with # type: ignore
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
|
@ -38,7 +37,7 @@ from fields.workflow_run_fields import (
|
|||
workflow_run_pagination_fields,
|
||||
)
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField
|
||||
from libs.helper import TimestampField, UUIDStrOrEmpty
|
||||
from libs.login import current_account_with_tenant, current_user, login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
|
|
@ -110,7 +109,7 @@ class NodeIdQuery(BaseModel):
|
|||
|
||||
|
||||
class WorkflowRunQuery(BaseModel):
|
||||
last_id: UUID | None = None
|
||||
last_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
|
|
@ -121,6 +120,10 @@ class DatasourceVariablesPayload(BaseModel):
|
|||
start_node_title: str
|
||||
|
||||
|
||||
class RagPipelineRecommendedPluginQuery(BaseModel):
|
||||
type: str = "all"
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
DraftWorkflowSyncPayload,
|
||||
|
|
@ -135,6 +138,7 @@ register_schema_models(
|
|||
NodeIdQuery,
|
||||
WorkflowRunQuery,
|
||||
DatasourceVariablesPayload,
|
||||
RagPipelineRecommendedPluginQuery,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -975,11 +979,8 @@ class RagPipelineRecommendedPluginApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("type", type=str, location="args", required=False, default="all")
|
||||
args = parser.parse_args()
|
||||
type = args["type"]
|
||||
query = RagPipelineRecommendedPluginQuery.model_validate(request.args.to_dict())
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
recommended_plugins = rag_pipeline_service.get_recommended_plugins(type)
|
||||
recommended_plugins = rag_pipeline_service.get_recommended_plugins(query.type)
|
||||
return recommended_plugins
|
||||
|
|
|
|||
|
|
@ -11,6 +11,20 @@ from libs.login import current_account_with_tenant, login_required
|
|||
from models import TenantAccountRole
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class LoadBalancingCredentialPayload(BaseModel):
|
||||
model: str
|
||||
model_type: ModelType
|
||||
credentials: dict
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
LoadBalancingCredentialPayload.__name__,
|
||||
LoadBalancingCredentialPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
class LoadBalancingCredentialPayload(BaseModel):
|
||||
model: str
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,15 +1,14 @@
|
|||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from flask import make_response, redirect, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web.error import NotFoundError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
|
|
@ -36,29 +35,32 @@ from ..wraps import (
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerSubscriptionUpdateRequest(BaseModel):
|
||||
"""Request payload for updating a trigger subscription"""
|
||||
|
||||
name: str | None = Field(default=None, description="The name for the subscription")
|
||||
credentials: Mapping[str, Any] | None = Field(default=None, description="The credentials for the subscription")
|
||||
parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters for the subscription")
|
||||
properties: Mapping[str, Any] | None = Field(default=None, description="The properties for the subscription")
|
||||
class TriggerSubscriptionBuilderCreatePayload(BaseModel):
|
||||
credential_type: str | None = None
|
||||
|
||||
|
||||
class TriggerSubscriptionVerifyRequest(BaseModel):
|
||||
"""Request payload for verifying subscription credentials."""
|
||||
|
||||
credentials: Mapping[str, Any] = Field(description="The credentials to verify")
|
||||
class TriggerSubscriptionBuilderVerifyPayload(BaseModel):
|
||||
credentials: dict[str, Any] | None = None
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
TriggerSubscriptionUpdateRequest.__name__,
|
||||
TriggerSubscriptionUpdateRequest.model_json_schema(ref_template="#/definitions/{model}"),
|
||||
)
|
||||
class TriggerSubscriptionBuilderUpdatePayload(BaseModel):
|
||||
name: str | None = None
|
||||
parameters: dict[str, Any] | None = None
|
||||
properties: dict[str, Any] | None = None
|
||||
credentials: dict[str, Any] | None = None
|
||||
|
||||
console_ns.schema_model(
|
||||
TriggerSubscriptionVerifyRequest.__name__,
|
||||
TriggerSubscriptionVerifyRequest.model_json_schema(ref_template="#/definitions/{model}"),
|
||||
|
||||
class TriggerOAuthClientPayload(BaseModel):
|
||||
client_params: dict[str, Any] | None = None
|
||||
enabled: bool | None = None
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
TriggerSubscriptionBuilderCreatePayload,
|
||||
TriggerSubscriptionBuilderVerifyPayload,
|
||||
TriggerSubscriptionBuilderUpdatePayload,
|
||||
TriggerOAuthClientPayload,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -127,16 +129,11 @@ class TriggerSubscriptionListApi(Resource):
|
|||
raise
|
||||
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"credential_type", type=str, required=False, nullable=True, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
|
||||
)
|
||||
class TriggerSubscriptionBuilderCreateApi(Resource):
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderCreatePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
|
|
@ -146,10 +143,10 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
|
|||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
args = parser.parse_args()
|
||||
payload = TriggerSubscriptionBuilderCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
credential_type = CredentialType.of(args.get("credential_type") or CredentialType.UNAUTHORIZED.value)
|
||||
credential_type = CredentialType.of(payload.credential_type or CredentialType.UNAUTHORIZED.value)
|
||||
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
|
|
@ -177,18 +174,11 @@ class TriggerSubscriptionBuilderGetApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
parser_api = (
|
||||
reqparse.RequestParser()
|
||||
# The credentials of the subscription builder
|
||||
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify-and-update/<path:subscription_builder_id>",
|
||||
)
|
||||
class TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource):
|
||||
@console_ns.expect(parser_api)
|
||||
class TriggerSubscriptionBuilderVerifyApi(Resource):
|
||||
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderVerifyPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
|
|
@ -198,7 +188,7 @@ class TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource):
|
|||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
args = parser_api.parse_args()
|
||||
payload = TriggerSubscriptionBuilderVerifyPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
# Use atomic update_and_verify to prevent race conditions
|
||||
|
|
@ -208,7 +198,7 @@ class TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource):
|
|||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||
credentials=args.get("credentials", None),
|
||||
credentials=payload.credentials,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
|
|
@ -216,24 +206,11 @@ class TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource):
|
|||
raise ValueError(str(e)) from e
|
||||
|
||||
|
||||
parser_update_api = (
|
||||
reqparse.RequestParser()
|
||||
# The name of the subscription builder
|
||||
.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
# The parameters of the subscription builder
|
||||
.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
|
||||
# The properties of the subscription builder
|
||||
.add_argument("properties", type=dict, required=False, nullable=True, location="json")
|
||||
# The credentials of the subscription builder
|
||||
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>",
|
||||
)
|
||||
class TriggerSubscriptionBuilderUpdateApi(Resource):
|
||||
@console_ns.expect(parser_update_api)
|
||||
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
|
|
@ -244,7 +221,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
|
|||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
args = parser_update_api.parse_args()
|
||||
payload = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||
|
|
@ -252,10 +229,10 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
|
|||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||
name=args.get("name", None),
|
||||
parameters=args.get("parameters", None),
|
||||
properties=args.get("properties", None),
|
||||
credentials=args.get("credentials", None),
|
||||
name=payload.name,
|
||||
parameters=payload.parameters,
|
||||
properties=payload.properties,
|
||||
credentials=payload.credentials,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
|
@ -290,7 +267,7 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
|
|||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<path:subscription_builder_id>",
|
||||
)
|
||||
class TriggerSubscriptionBuilderBuildApi(Resource):
|
||||
@console_ns.expect(parser_update_api)
|
||||
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
|
|
@ -299,7 +276,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
|
|||
"""Build a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
args = parser_update_api.parse_args()
|
||||
payload = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
|
||||
try:
|
||||
# Use atomic update_and_build to prevent race conditions
|
||||
TriggerSubscriptionBuilderService.update_and_build_builder(
|
||||
|
|
@ -308,9 +285,9 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
|
|||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||
name=args.get("name", None),
|
||||
parameters=args.get("parameters", None),
|
||||
properties=args.get("properties", None),
|
||||
name=payload.name,
|
||||
parameters=payload.parameters,
|
||||
properties=payload.properties,
|
||||
),
|
||||
)
|
||||
return 200
|
||||
|
|
@ -581,13 +558,6 @@ class TriggerOAuthCallbackApi(Resource):
|
|||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
parser_oauth_client = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/oauth/client")
|
||||
class TriggerOAuthClientManageApi(Resource):
|
||||
@setup_required
|
||||
|
|
@ -635,7 +605,7 @@ class TriggerOAuthClientManageApi(Resource):
|
|||
logger.exception("Error getting OAuth client", exc_info=e)
|
||||
raise
|
||||
|
||||
@console_ns.expect(parser_oauth_client)
|
||||
@console_ns.expect(console_ns.models[TriggerOAuthClientPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -645,15 +615,15 @@ class TriggerOAuthClientManageApi(Resource):
|
|||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
args = parser_oauth_client.parse_args()
|
||||
payload = TriggerOAuthClientPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
return TriggerProviderService.save_custom_oauth_client_params(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
client_params=args.get("client_params"),
|
||||
enabled=args.get("enabled"),
|
||||
client_params=payload.client_params,
|
||||
enabled=payload.enabled,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from core.errors.error import (
|
|||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
|
|
@ -52,7 +53,7 @@ class ChatRequestPayload(BaseModel):
|
|||
query: str
|
||||
files: list[dict[str, Any]] | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
conversation_id: str | None = Field(default=None, description="Conversation UUID")
|
||||
conversation_id: UUIDStrOrEmpty | None = Field(default=None, description="Conversation UUID")
|
||||
retriever_from: str = Field(default="dev")
|
||||
auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
|
||||
workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
|
|
@ -24,12 +23,13 @@ from fields.conversation_variable_fields import (
|
|||
build_conversation_variable_infinite_scroll_pagination_model,
|
||||
build_conversation_variable_model,
|
||||
)
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.conversation_service import ConversationService
|
||||
|
||||
|
||||
class ConversationListQuery(BaseModel):
|
||||
last_id: UUID | None = Field(default=None, description="Last conversation ID for pagination")
|
||||
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last conversation ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of conversations to return")
|
||||
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
|
||||
default="-updated_at", description="Sort order for conversations"
|
||||
|
|
@ -49,7 +49,7 @@ class ConversationRenamePayload(BaseModel):
|
|||
|
||||
|
||||
class ConversationVariablesQuery(BaseModel):
|
||||
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
|
||||
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
|
||||
variable_name: str | None = Field(
|
||||
default=None, description="Filter variables by name", min_length=1, max_length=255
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import json
|
||||
import logging
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
|
|
@ -17,7 +16,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
|||
from fields.conversation_fields import build_message_file_model
|
||||
from fields.message_fields import build_agent_thought_model, build_feedback_model
|
||||
from fields.raws import FilesContainedField
|
||||
from libs.helper import TimestampField
|
||||
from libs.helper import TimestampField, UUIDStrOrEmpty
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.errors.message import (
|
||||
FirstMessageNotExistsError,
|
||||
|
|
@ -30,8 +29,8 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class MessageListQuery(BaseModel):
|
||||
conversation_id: UUID
|
||||
first_id: UUID | None = None
|
||||
conversation_id: UUIDStrOrEmpty
|
||||
first_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
|
||||
|
||||
register_schema_model(service_api_ns, HitTestingPayload)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/hit-testing", "/datasets/<uuid:dataset_id>/retrieve")
|
||||
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
||||
|
|
@ -15,6 +18,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
|||
404: "Dataset not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[HitTestingPayload.__name__])
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id):
|
||||
"""Perform hit testing on a dataset.
|
||||
|
|
@ -24,7 +28,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
|||
dataset_id_str = str(dataset_id)
|
||||
|
||||
dataset = self.get_and_validate_dataset(dataset_id_str)
|
||||
args = self.parse_args()
|
||||
args = self.parse_args(service_api_ns.payload)
|
||||
self.hit_testing_args_check(args)
|
||||
|
||||
return self.perform_hit_testing(dataset, args)
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ from controllers.web.error import (
|
|||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs.helper import uuid_value
|
||||
|
|
|
|||
|
|
@ -1,8 +1,12 @@
|
|||
from flask_restx import fields, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotChatAppError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
|
|
@ -16,6 +20,35 @@ from services.errors.conversation import ConversationNotExistsError, LastConvers
|
|||
from services.web_conversation_service import WebConversationService
|
||||
|
||||
|
||||
class ConversationListQuery(BaseModel):
|
||||
last_id: str | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
pinned: bool | None = None
|
||||
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = "-updated_at"
|
||||
|
||||
@field_validator("last_id")
|
||||
@classmethod
|
||||
def validate_last_id(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class ConversationRenamePayload(BaseModel):
|
||||
name: str | None = None
|
||||
auto_generate: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_name_requirement(self):
|
||||
if not self.auto_generate:
|
||||
if self.name is None or not self.name.strip():
|
||||
raise ValueError("name is required when auto_generate is false")
|
||||
return self
|
||||
|
||||
|
||||
register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload)
|
||||
|
||||
|
||||
@web_ns.route("/conversations")
|
||||
class ConversationListApi(WebApiResource):
|
||||
@web_ns.doc("Get Conversation List")
|
||||
|
|
@ -60,25 +93,8 @@ class ConversationListApi(WebApiResource):
|
|||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("last_id", type=uuid_value, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
.add_argument("pinned", type=str, choices=["true", "false", None], location="args")
|
||||
.add_argument(
|
||||
"sort_by",
|
||||
type=str,
|
||||
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
||||
required=False,
|
||||
default="-updated_at",
|
||||
location="args",
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
pinned = None
|
||||
if "pinned" in args and args["pinned"] is not None:
|
||||
pinned = args["pinned"] == "true"
|
||||
raw_args = request.args.to_dict()
|
||||
query = ConversationListQuery.model_validate(raw_args)
|
||||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
|
|
@ -86,11 +102,11 @@ class ConversationListApi(WebApiResource):
|
|||
session=session,
|
||||
app_model=app_model,
|
||||
user=end_user,
|
||||
last_id=args["last_id"],
|
||||
limit=args["limit"],
|
||||
last_id=query.last_id,
|
||||
limit=query.limit,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
pinned=pinned,
|
||||
sort_by=args["sort_by"],
|
||||
pinned=query.pinned,
|
||||
sort_by=query.sort_by,
|
||||
)
|
||||
except LastConversationNotExistsError:
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
|
|
@ -163,15 +179,12 @@ class ConversationRenameApi(WebApiResource):
|
|||
|
||||
conversation_id = str(c_id)
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, required=False, location="json")
|
||||
.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = ConversationRenamePayload.model_validate(web_ns.payload or {})
|
||||
|
||||
try:
|
||||
return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"])
|
||||
return ConversationService.rename(
|
||||
app_model, conversation_id, end_user, payload.name or "", payload.auto_generate
|
||||
)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
from flask import make_response, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from jwt import InvalidTokenError
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.auth.error import (
|
||||
AuthenticationFailedError,
|
||||
EmailCodeError,
|
||||
|
|
@ -13,7 +15,7 @@ from controllers.console.error import AccountBannedError
|
|||
from controllers.console.wraps import only_edition_enterprise, setup_required
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.wraps import decode_jwt_token
|
||||
from libs.helper import email
|
||||
from libs.helper import EmailStr
|
||||
from libs.passport import PassportService
|
||||
from libs.password import valid_password
|
||||
from libs.token import (
|
||||
|
|
@ -25,6 +27,30 @@ from services.app_service import AppService
|
|||
from services.webapp_auth_service import WebAppAuthService
|
||||
|
||||
|
||||
class LoginPayload(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def validate_password(cls, value: str) -> str:
|
||||
return valid_password(value)
|
||||
|
||||
|
||||
class EmailCodeLoginSendPayload(BaseModel):
|
||||
email: EmailStr
|
||||
language: str | None = None
|
||||
|
||||
|
||||
class EmailCodeLoginVerifyPayload(BaseModel):
|
||||
email: EmailStr
|
||||
code: str
|
||||
token: str = Field(min_length=1)
|
||||
|
||||
|
||||
register_schema_models(web_ns, LoginPayload, EmailCodeLoginSendPayload, EmailCodeLoginVerifyPayload)
|
||||
|
||||
|
||||
@web_ns.route("/login")
|
||||
class LoginApi(Resource):
|
||||
"""Resource for web app email/password login."""
|
||||
|
|
@ -44,15 +70,10 @@ class LoginApi(Resource):
|
|||
)
|
||||
def post(self):
|
||||
"""Authenticate user and login."""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("password", type=valid_password, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = LoginPayload.model_validate(web_ns.payload or {})
|
||||
|
||||
try:
|
||||
account = WebAppAuthService.authenticate(args["email"], args["password"])
|
||||
account = WebAppAuthService.authenticate(payload.email, payload.password)
|
||||
except services.errors.account.AccountLoginError:
|
||||
raise AccountBannedError()
|
||||
except services.errors.account.AccountPasswordError:
|
||||
|
|
@ -139,6 +160,7 @@ class EmailCodeLoginSendEmailApi(Resource):
|
|||
@only_edition_enterprise
|
||||
@web_ns.doc("send_email_code_login")
|
||||
@web_ns.doc(description="Send email verification code for login")
|
||||
@web_ns.expect(web_ns.models[EmailCodeLoginSendPayload.__name__])
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Email code sent successfully",
|
||||
|
|
@ -147,19 +169,14 @@ class EmailCodeLoginSendEmailApi(Resource):
|
|||
}
|
||||
)
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = EmailCodeLoginSendPayload.model_validate(web_ns.payload or {})
|
||||
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
if payload.language == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
account = WebAppAuthService.get_user_through_email(args["email"])
|
||||
account = WebAppAuthService.get_user_through_email(payload.email)
|
||||
if account is None:
|
||||
raise AuthenticationFailedError()
|
||||
else:
|
||||
|
|
@ -173,6 +190,7 @@ class EmailCodeLoginApi(Resource):
|
|||
@only_edition_enterprise
|
||||
@web_ns.doc("verify_email_code_login")
|
||||
@web_ns.doc(description="Verify email code and complete login")
|
||||
@web_ns.expect(web_ns.models[EmailCodeLoginVerifyPayload.__name__])
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Email code verified and login successful",
|
||||
|
|
@ -182,33 +200,27 @@ class EmailCodeLoginApi(Resource):
|
|||
}
|
||||
)
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = EmailCodeLoginVerifyPayload.model_validate(web_ns.payload or {})
|
||||
|
||||
user_email = args["email"]
|
||||
user_email = payload.email
|
||||
|
||||
token_data = WebAppAuthService.get_email_code_login_data(args["token"])
|
||||
token_data = WebAppAuthService.get_email_code_login_data(payload.token)
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if token_data["email"] != args["email"]:
|
||||
if token_data["email"] != payload.email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if token_data["code"] != args["code"]:
|
||||
if token_data["code"] != payload.code:
|
||||
raise EmailCodeError()
|
||||
|
||||
WebAppAuthService.revoke_email_code_login_token(args["token"])
|
||||
WebAppAuthService.revoke_email_code_login_token(payload.token)
|
||||
account = WebAppAuthService.get_user_through_email(user_email)
|
||||
if not account:
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
token = WebAppAuthService.login(account=account)
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
AccountService.reset_login_error_rate_limit(payload.email)
|
||||
response = make_response({"result": "success", "data": {"access_token": token}})
|
||||
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -1,12 +1,14 @@
|
|||
from flask_restx import fields, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from flask import request
|
||||
from flask_restx import fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotCompletionAppError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.helper import TimestampField, UUIDStrOrEmpty
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
|
|
@ -23,6 +25,18 @@ message_fields = {
|
|||
}
|
||||
|
||||
|
||||
class SavedMessageListQuery(BaseModel):
|
||||
last_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class SavedMessageCreatePayload(BaseModel):
|
||||
message_id: UUIDStrOrEmpty
|
||||
|
||||
|
||||
register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload)
|
||||
|
||||
|
||||
@web_ns.route("/saved-messages")
|
||||
class SavedMessageListApi(WebApiResource):
|
||||
saved_message_infinite_scroll_pagination_fields = {
|
||||
|
|
@ -63,14 +77,10 @@ class SavedMessageListApi(WebApiResource):
|
|||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("last_id", type=uuid_value, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
raw_args = request.args.to_dict()
|
||||
query = SavedMessageListQuery.model_validate(raw_args)
|
||||
|
||||
return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"])
|
||||
return SavedMessageService.pagination_by_last_id(app_model, end_user, query.last_id, query.limit)
|
||||
|
||||
@web_ns.doc("Save Message")
|
||||
@web_ns.doc(description="Save a specific message for later reference.")
|
||||
|
|
@ -94,11 +104,10 @@ class SavedMessageListApi(WebApiResource):
|
|||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
payload = SavedMessageCreatePayload.model_validate(web_ns.payload or {})
|
||||
|
||||
try:
|
||||
SavedMessageService.save(app_model, end_user, args["message_id"])
|
||||
SavedMessageService.save(app_model, end_user, payload.message_id)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from flask_restx import reqparse
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import (
|
||||
CompletionRequestError,
|
||||
|
|
@ -27,8 +29,16 @@ from models.model import App, AppMode, EndUser
|
|||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
|
||||
class WorkflowRunPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
register_schema_models(web_ns, WorkflowRunPayload)
|
||||
|
||||
|
||||
@web_ns.route("/workflows/run")
|
||||
class WorkflowRunApi(WebApiResource):
|
||||
|
|
@ -58,12 +68,8 @@ class WorkflowRunApi(WebApiResource):
|
|||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("files", type=list, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = WorkflowRunPayload.model_validate(web_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
|
|
|
|||
|
|
@ -378,7 +378,7 @@ class ApiBasedToolSchemaParser:
|
|||
@staticmethod
|
||||
def auto_parse_to_tool_bundle(
|
||||
content: str, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> tuple[list[ApiToolBundle], str]:
|
||||
) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
|
||||
"""
|
||||
auto parse to tool bundle
|
||||
|
||||
|
|
|
|||
|
|
@ -7,11 +7,6 @@ from core.workflow.nodes.base.entities import OutputVariableEntity
|
|||
|
||||
|
||||
class WorkflowToolConfigurationUtils:
|
||||
@classmethod
|
||||
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
|
||||
for configuration in configurations:
|
||||
WorkflowToolParameterConfiguration.model_validate(configuration)
|
||||
|
||||
@classmethod
|
||||
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -86,7 +86,9 @@ class ApiToolManageService:
|
|||
raise ValueError(f"invalid schema: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def convert_schema_to_tool_bundles(schema: str, extra_info: dict | None = None) -> tuple[list[ApiToolBundle], str]:
|
||||
def convert_schema_to_tool_bundles(
|
||||
schema: str, extra_info: dict | None = None
|
||||
) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
|
||||
"""
|
||||
convert schema to tool bundles
|
||||
|
||||
|
|
@ -104,7 +106,7 @@ class ApiToolManageService:
|
|||
provider_name: str,
|
||||
icon: dict,
|
||||
credentials: dict,
|
||||
schema_type: str,
|
||||
schema_type: ApiProviderSchemaType,
|
||||
schema: str,
|
||||
privacy_policy: str,
|
||||
custom_disclaimer: str,
|
||||
|
|
@ -113,9 +115,6 @@ class ApiToolManageService:
|
|||
"""
|
||||
create api tool provider
|
||||
"""
|
||||
if schema_type not in [member.value for member in ApiProviderSchemaType]:
|
||||
raise ValueError(f"invalid schema type {schema}")
|
||||
|
||||
provider_name = provider_name.strip()
|
||||
|
||||
# check if the provider exists
|
||||
|
|
@ -245,18 +244,15 @@ class ApiToolManageService:
|
|||
original_provider: str,
|
||||
icon: dict,
|
||||
credentials: dict,
|
||||
schema_type: str,
|
||||
_schema_type: ApiProviderSchemaType,
|
||||
schema: str,
|
||||
privacy_policy: str,
|
||||
privacy_policy: str | None,
|
||||
custom_disclaimer: str,
|
||||
labels: list[str],
|
||||
):
|
||||
"""
|
||||
update api tool provider
|
||||
"""
|
||||
if schema_type not in [member.value for member in ApiProviderSchemaType]:
|
||||
raise ValueError(f"invalid schema type {schema}")
|
||||
|
||||
provider_name = provider_name.strip()
|
||||
|
||||
# check if the provider exists
|
||||
|
|
@ -281,7 +277,7 @@ class ApiToolManageService:
|
|||
provider.icon = json.dumps(icon)
|
||||
provider.schema = schema
|
||||
provider.description = extra_info.get("description", "")
|
||||
provider.schema_type_str = ApiProviderSchemaType.OPENAPI
|
||||
provider.schema_type_str = schema_type
|
||||
provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
|
||||
provider.privacy_policy = privacy_policy
|
||||
provider.custom_disclaimer = custom_disclaimer
|
||||
|
|
@ -366,7 +362,7 @@ class ApiToolManageService:
|
|||
tool_name: str,
|
||||
credentials: dict,
|
||||
parameters: dict,
|
||||
schema_type: str,
|
||||
schema_type: ApiProviderSchemaType,
|
||||
schema: str,
|
||||
):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import or_, select
|
||||
|
||||
|
|
@ -11,8 +9,8 @@ from core.helper.tool_provider_cache import ToolProviderListCache
|
|||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -39,12 +37,10 @@ class WorkflowToolManageService:
|
|||
label: str,
|
||||
icon: dict,
|
||||
description: str,
|
||||
parameters: list[Mapping[str, Any]],
|
||||
parameters: list[WorkflowToolParameterConfiguration],
|
||||
privacy_policy: str = "",
|
||||
labels: list[str] | None = None,
|
||||
):
|
||||
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
|
||||
|
||||
# check if the name is unique
|
||||
existing_workflow_tool_provider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
|
|
@ -109,7 +105,7 @@ class WorkflowToolManageService:
|
|||
label: str,
|
||||
icon: dict,
|
||||
description: str,
|
||||
parameters: list[Mapping[str, Any]],
|
||||
parameters: list[WorkflowToolParameterConfiguration],
|
||||
privacy_policy: str = "",
|
||||
labels: list[str] | None = None,
|
||||
):
|
||||
|
|
@ -127,8 +123,6 @@ class WorkflowToolManageService:
|
|||
:param labels: labels
|
||||
:return: the updated tool
|
||||
"""
|
||||
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
|
||||
|
||||
# check if the name is unique
|
||||
existing_workflow_tool_provider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
|
|
@ -167,7 +161,7 @@ class WorkflowToolManageService:
|
|||
workflow_tool_provider.label = label
|
||||
workflow_tool_provider.icon = json.dumps(icon)
|
||||
workflow_tool_provider.description = description
|
||||
workflow_tool_provider.parameter_configuration = json.dumps(parameters)
|
||||
workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in parameters])
|
||||
workflow_tool_provider.privacy_policy = privacy_policy
|
||||
workflow_tool_provider.version = workflow.version
|
||||
workflow_tool_provider.updated_at = datetime.now()
|
||||
|
|
|
|||
|
|
@ -3,7 +3,9 @@ from unittest.mock import patch
|
|||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
from models.tools import WorkflowToolProvider
|
||||
from models.workflow import Workflow as WorkflowModel
|
||||
from services.account_service import AccountService, TenantService
|
||||
|
|
@ -130,20 +132,24 @@ class TestWorkflowToolManageService:
|
|||
def _create_test_workflow_tool_parameters(self):
|
||||
"""Helper method to create valid workflow tool parameters."""
|
||||
return [
|
||||
{
|
||||
"name": "input_text",
|
||||
"description": "Input text for processing",
|
||||
"form": "form",
|
||||
"type": "string",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "output_format",
|
||||
"description": "Output format specification",
|
||||
"form": "form",
|
||||
"type": "select",
|
||||
"required": False,
|
||||
},
|
||||
WorkflowToolParameterConfiguration.model_validate(
|
||||
{
|
||||
"name": "input_text",
|
||||
"description": "Input text for processing",
|
||||
"form": "form",
|
||||
"type": "string",
|
||||
"required": True,
|
||||
}
|
||||
),
|
||||
WorkflowToolParameterConfiguration.model_validate(
|
||||
{
|
||||
"name": "output_format",
|
||||
"description": "Output format specification",
|
||||
"form": "form",
|
||||
"type": "select",
|
||||
"required": False,
|
||||
}
|
||||
),
|
||||
]
|
||||
|
||||
def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
|
|
@ -208,7 +214,7 @@ class TestWorkflowToolManageService:
|
|||
assert created_tool_provider.label == tool_label
|
||||
assert created_tool_provider.icon == json.dumps(tool_icon)
|
||||
assert created_tool_provider.description == tool_description
|
||||
assert created_tool_provider.parameter_configuration == json.dumps(tool_parameters)
|
||||
assert created_tool_provider.parameter_configuration == json.dumps([p.model_dump() for p in tool_parameters])
|
||||
assert created_tool_provider.privacy_policy == tool_privacy_policy
|
||||
assert created_tool_provider.version == workflow.version
|
||||
assert created_tool_provider.user_id == account.id
|
||||
|
|
@ -353,18 +359,9 @@ class TestWorkflowToolManageService:
|
|||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Setup invalid workflow tool parameters (missing required fields)
|
||||
invalid_parameters = [
|
||||
{
|
||||
"name": "input_text",
|
||||
# Missing description and form fields
|
||||
"type": "string",
|
||||
"required": True,
|
||||
}
|
||||
]
|
||||
# Attempt to create workflow tool with invalid parameters
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
# Setup invalid workflow tool parameters (missing required fields)
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
|
|
@ -373,7 +370,16 @@ class TestWorkflowToolManageService:
|
|||
label=fake.word(),
|
||||
icon={"type": "emoji", "emoji": "🔧"},
|
||||
description=fake.text(max_nb_chars=200),
|
||||
parameters=invalid_parameters,
|
||||
parameters=[
|
||||
WorkflowToolParameterConfiguration.model_validate(
|
||||
{
|
||||
"name": "input_text",
|
||||
# Missing description and form fields
|
||||
"type": "string",
|
||||
"required": True,
|
||||
}
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# Verify error message contains validation error
|
||||
|
|
@ -579,11 +585,12 @@ class TestWorkflowToolManageService:
|
|||
|
||||
# Verify database state was updated
|
||||
db.session.refresh(created_tool)
|
||||
assert created_tool is not None
|
||||
assert created_tool.name == updated_tool_name
|
||||
assert created_tool.label == updated_tool_label
|
||||
assert created_tool.icon == json.dumps(updated_tool_icon)
|
||||
assert created_tool.description == updated_tool_description
|
||||
assert created_tool.parameter_configuration == json.dumps(updated_tool_parameters)
|
||||
assert created_tool.parameter_configuration == json.dumps([p.model_dump() for p in updated_tool_parameters])
|
||||
assert created_tool.privacy_policy == updated_tool_privacy_policy
|
||||
assert created_tool.version == workflow.version
|
||||
assert created_tool.updated_at is not None
|
||||
|
|
|
|||
Loading…
Reference in New Issue