mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 21:28:25 +08:00
Merge remote-tracking branch 'origin/main' into feat/support-agent-sandbox
# Conflicts: # api/core/file/file_manager.py # api/core/workflow/graph_engine/response_coordinator/coordinator.py # api/core/workflow/nodes/llm/node.py # api/core/workflow/nodes/tool/tool_node.py # api/pyproject.toml # web/package.json # web/pnpm-lock.yaml
This commit is contained in:
commit
c111079624
3
.github/CODEOWNERS
vendored
3
.github/CODEOWNERS
vendored
@ -9,6 +9,9 @@
|
||||
# CODEOWNERS file
|
||||
/.github/CODEOWNERS @laipz8200 @crazywoola
|
||||
|
||||
# Agents
|
||||
/.agents/skills/ @hyoban
|
||||
|
||||
# Docs
|
||||
/docs/ @crazywoola
|
||||
|
||||
|
||||
@ -53,6 +53,7 @@ select = [
|
||||
"S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
|
||||
"S302", # suspicious-marshal-usage, disallow use of `marshal` module
|
||||
"S311", # suspicious-non-cryptographic-random-usage,
|
||||
"TID", # flake8-tidy-imports
|
||||
|
||||
]
|
||||
|
||||
@ -88,6 +89,7 @@ ignore = [
|
||||
"SIM113", # enumerate-for-loop
|
||||
"SIM117", # multiple-with-statements
|
||||
"SIM210", # if-expr-with-true-false
|
||||
"TID252", # allow relative imports from parent modules
|
||||
]
|
||||
|
||||
[lint.per-file-ignores]
|
||||
@ -109,10 +111,20 @@ ignore = [
|
||||
"S110", # allow ignoring exceptions in tests code (currently)
|
||||
|
||||
]
|
||||
"controllers/console/explore/trial.py" = ["TID251"]
|
||||
"controllers/console/human_input_form.py" = ["TID251"]
|
||||
"controllers/web/human_input_form.py" = ["TID251"]
|
||||
|
||||
[lint.pyflakes]
|
||||
allowed-unused-imports = [
|
||||
"_pytest.monkeypatch",
|
||||
"tests.integration_tests",
|
||||
"tests.unit_tests",
|
||||
]
|
||||
|
||||
[lint.flake8-tidy-imports]
|
||||
|
||||
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"]
|
||||
msg = "Use Pydantic payload/query models instead of reqparse."
|
||||
|
||||
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse.RequestParser"]
|
||||
msg = "Use Pydantic payload/query models instead of reqparse."
|
||||
|
||||
203
api/commands.py
203
api/commands.py
@ -1451,54 +1451,58 @@ def clear_orphaned_file_records(force: bool):
|
||||
all_ids_in_tables = []
|
||||
for ids_table in ids_tables:
|
||||
query = ""
|
||||
if ids_table["type"] == "uuid":
|
||||
click.echo(
|
||||
click.style(
|
||||
f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}", fg="white"
|
||||
match ids_table["type"]:
|
||||
case "uuid":
|
||||
click.echo(
|
||||
click.style(
|
||||
f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
)
|
||||
query = (
|
||||
f"SELECT {ids_table['column']} FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])})
|
||||
elif ids_table["type"] == "text":
|
||||
click.echo(
|
||||
click.style(
|
||||
f"- Listing file-id-like strings in column {ids_table['column']} in table {ids_table['table']}",
|
||||
fg="white",
|
||||
c = ids_table["column"]
|
||||
query = f"SELECT {c} FROM {ids_table['table']} WHERE {c} IS NOT NULL"
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])})
|
||||
case "text":
|
||||
t = ids_table["table"]
|
||||
click.echo(
|
||||
click.style(
|
||||
f"- Listing file-id-like strings in column {ids_table['column']} in table {t}",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
)
|
||||
query = (
|
||||
f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id "
|
||||
f"FROM {ids_table['table']}"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
for j in i[0]:
|
||||
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
|
||||
elif ids_table["type"] == "json":
|
||||
click.echo(
|
||||
click.style(
|
||||
(
|
||||
f"- Listing file-id-like JSON string in column {ids_table['column']} "
|
||||
f"in table {ids_table['table']}"
|
||||
),
|
||||
fg="white",
|
||||
query = (
|
||||
f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id "
|
||||
f"FROM {ids_table['table']}"
|
||||
)
|
||||
)
|
||||
query = (
|
||||
f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id "
|
||||
f"FROM {ids_table['table']}"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
for j in i[0]:
|
||||
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
for j in i[0]:
|
||||
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
|
||||
case "json":
|
||||
click.echo(
|
||||
click.style(
|
||||
(
|
||||
f"- Listing file-id-like JSON string in column {ids_table['column']} "
|
||||
f"in table {ids_table['table']}"
|
||||
),
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
query = (
|
||||
f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id "
|
||||
f"FROM {ids_table['table']}"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
for j in i[0]:
|
||||
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
|
||||
case _:
|
||||
pass
|
||||
click.echo(click.style(f"Found {len(all_ids_in_tables)} file ids in tables.", fg="white"))
|
||||
|
||||
except Exception as e:
|
||||
@ -1738,59 +1742,18 @@ def file_usage(
|
||||
if src_filter != src:
|
||||
continue
|
||||
|
||||
if ids_table["type"] == "uuid":
|
||||
# Direct UUID match
|
||||
query = (
|
||||
f"SELECT {ids_table['pk_column']}, {ids_table['column']} "
|
||||
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for row in rs:
|
||||
record_id = str(row[0])
|
||||
ref_file_id = str(row[1])
|
||||
if ref_file_id not in file_key_map:
|
||||
continue
|
||||
storage_key = file_key_map[ref_file_id]
|
||||
|
||||
# Apply filters
|
||||
if file_id and ref_file_id != file_id:
|
||||
continue
|
||||
if key and not storage_key.endswith(key):
|
||||
continue
|
||||
|
||||
# Only collect items within the requested page range
|
||||
if offset <= total_count < offset + limit:
|
||||
paginated_usages.append(
|
||||
{
|
||||
"src": f"{ids_table['table']}.{ids_table['column']}",
|
||||
"record_id": record_id,
|
||||
"file_id": ref_file_id,
|
||||
"key": storage_key,
|
||||
}
|
||||
)
|
||||
total_count += 1
|
||||
|
||||
elif ids_table["type"] in ("text", "json"):
|
||||
# Extract UUIDs from text/json content
|
||||
column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"]
|
||||
query = (
|
||||
f"SELECT {ids_table['pk_column']}, {column_cast} "
|
||||
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for row in rs:
|
||||
record_id = str(row[0])
|
||||
content = str(row[1])
|
||||
|
||||
# Find all UUIDs in the content
|
||||
import re
|
||||
|
||||
uuid_pattern = re.compile(guid_regexp, re.IGNORECASE)
|
||||
matches = uuid_pattern.findall(content)
|
||||
|
||||
for ref_file_id in matches:
|
||||
match ids_table["type"]:
|
||||
case "uuid":
|
||||
# Direct UUID match
|
||||
query = (
|
||||
f"SELECT {ids_table['pk_column']}, {ids_table['column']} "
|
||||
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for row in rs:
|
||||
record_id = str(row[0])
|
||||
ref_file_id = str(row[1])
|
||||
if ref_file_id not in file_key_map:
|
||||
continue
|
||||
storage_key = file_key_map[ref_file_id]
|
||||
@ -1813,6 +1776,50 @@ def file_usage(
|
||||
)
|
||||
total_count += 1
|
||||
|
||||
case "text" | "json":
|
||||
# Extract UUIDs from text/json content
|
||||
column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"]
|
||||
query = (
|
||||
f"SELECT {ids_table['pk_column']}, {column_cast} "
|
||||
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for row in rs:
|
||||
record_id = str(row[0])
|
||||
content = str(row[1])
|
||||
|
||||
# Find all UUIDs in the content
|
||||
import re
|
||||
|
||||
uuid_pattern = re.compile(guid_regexp, re.IGNORECASE)
|
||||
matches = uuid_pattern.findall(content)
|
||||
|
||||
for ref_file_id in matches:
|
||||
if ref_file_id not in file_key_map:
|
||||
continue
|
||||
storage_key = file_key_map[ref_file_id]
|
||||
|
||||
# Apply filters
|
||||
if file_id and ref_file_id != file_id:
|
||||
continue
|
||||
if key and not storage_key.endswith(key):
|
||||
continue
|
||||
|
||||
# Only collect items within the requested page range
|
||||
if offset <= total_count < offset + limit:
|
||||
paginated_usages.append(
|
||||
{
|
||||
"src": f"{ids_table['table']}.{ids_table['column']}",
|
||||
"record_id": record_id,
|
||||
"file_id": ref_file_id,
|
||||
"key": storage_key,
|
||||
}
|
||||
)
|
||||
total_count += 1
|
||||
case _:
|
||||
pass
|
||||
|
||||
# Output results
|
||||
if output_json:
|
||||
result = {
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import abort, make_response, request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
|
||||
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
@ -16,9 +17,11 @@ from controllers.console.wraps import (
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.annotation_fields import (
|
||||
annotation_fields,
|
||||
annotation_hit_history_fields,
|
||||
build_annotation_model,
|
||||
Annotation,
|
||||
AnnotationExportList,
|
||||
AnnotationHitHistory,
|
||||
AnnotationHitHistoryList,
|
||||
AnnotationList,
|
||||
)
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
@ -89,6 +92,14 @@ reg(CreateAnnotationPayload)
|
||||
reg(UpdateAnnotationPayload)
|
||||
reg(AnnotationReplyStatusQuery)
|
||||
reg(AnnotationFilePayload)
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
Annotation,
|
||||
AnnotationList,
|
||||
AnnotationExportList,
|
||||
AnnotationHitHistory,
|
||||
AnnotationHitHistoryList,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
|
||||
@ -107,10 +118,11 @@ class AnnotationReplyActionApi(Resource):
|
||||
def post(self, app_id, action: Literal["enable", "disable"]):
|
||||
app_id = str(app_id)
|
||||
args = AnnotationReplyPayload.model_validate(console_ns.payload)
|
||||
if action == "enable":
|
||||
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
|
||||
elif action == "disable":
|
||||
result = AppAnnotationService.disable_app_annotation(app_id)
|
||||
match action:
|
||||
case "enable":
|
||||
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
|
||||
case "disable":
|
||||
result = AppAnnotationService.disable_app_annotation(app_id)
|
||||
return result, 200
|
||||
|
||||
|
||||
@ -201,33 +213,33 @@ class AnnotationApi(Resource):
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
|
||||
response = {
|
||||
"data": marshal(annotation_list, annotation_fields),
|
||||
"has_more": len(annotation_list) == limit,
|
||||
"limit": limit,
|
||||
"total": total,
|
||||
"page": page,
|
||||
}
|
||||
return response, 200
|
||||
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
|
||||
response = AnnotationList(
|
||||
data=annotation_models,
|
||||
has_more=len(annotation_list) == limit,
|
||||
limit=limit,
|
||||
total=total,
|
||||
page=page,
|
||||
)
|
||||
return response.model_dump(mode="json"), 200
|
||||
|
||||
@console_ns.doc("create_annotation")
|
||||
@console_ns.doc(description="Create a new annotation for an app")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__])
|
||||
@console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns))
|
||||
@console_ns.response(201, "Annotation created successfully", console_ns.models[Annotation.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@marshal_with(annotation_fields)
|
||||
@edit_permission_required
|
||||
def post(self, app_id):
|
||||
app_id = str(app_id)
|
||||
args = CreateAnnotationPayload.model_validate(console_ns.payload)
|
||||
data = args.model_dump(exclude_none=True)
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
|
||||
return annotation
|
||||
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -264,7 +276,7 @@ class AnnotationExportApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Annotations exported successfully",
|
||||
console_ns.model("AnnotationList", {"data": fields.List(fields.Nested(build_annotation_model(console_ns)))}),
|
||||
console_ns.models[AnnotationExportList.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@ -274,7 +286,8 @@ class AnnotationExportApi(Resource):
|
||||
def get(self, app_id):
|
||||
app_id = str(app_id)
|
||||
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
|
||||
response_data = {"data": marshal(annotation_list, annotation_fields)}
|
||||
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
|
||||
response_data = AnnotationExportList(data=annotation_models).model_dump(mode="json")
|
||||
|
||||
# Create response with secure headers for CSV export
|
||||
response = make_response(response_data, 200)
|
||||
@ -289,7 +302,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@console_ns.doc("update_delete_annotation")
|
||||
@console_ns.doc(description="Update or delete an annotation")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
|
||||
@console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns))
|
||||
@console_ns.response(200, "Annotation updated successfully", console_ns.models[Annotation.__name__])
|
||||
@console_ns.response(204, "Annotation deleted successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__])
|
||||
@ -298,7 +311,6 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_id, annotation_id):
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
@ -306,7 +318,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(
|
||||
args.model_dump(exclude_none=True), app_id, annotation_id
|
||||
)
|
||||
return annotation
|
||||
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -414,14 +426,7 @@ class AnnotationHitHistoryListApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Hit histories retrieved successfully",
|
||||
console_ns.model(
|
||||
"AnnotationHitHistoryList",
|
||||
{
|
||||
"data": fields.List(
|
||||
fields.Nested(console_ns.model("AnnotationHitHistoryItem", annotation_hit_history_fields))
|
||||
)
|
||||
},
|
||||
),
|
||||
console_ns.models[AnnotationHitHistoryList.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@ -436,11 +441,14 @@ class AnnotationHitHistoryListApi(Resource):
|
||||
annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(
|
||||
app_id, annotation_id, page, limit
|
||||
)
|
||||
response = {
|
||||
"data": marshal(annotation_hit_history_list, annotation_hit_history_fields),
|
||||
"has_more": len(annotation_hit_history_list) == limit,
|
||||
"limit": limit,
|
||||
"total": total,
|
||||
"page": page,
|
||||
}
|
||||
return response
|
||||
history_models = TypeAdapter(list[AnnotationHitHistory]).validate_python(
|
||||
annotation_hit_history_list, from_attributes=True
|
||||
)
|
||||
response = AnnotationHitHistoryList(
|
||||
data=history_models,
|
||||
has_more=len(annotation_hit_history_list) == limit,
|
||||
limit=limit,
|
||||
total=total,
|
||||
page=page,
|
||||
)
|
||||
return response.model_dump(mode="json")
|
||||
|
||||
@ -6,6 +6,7 @@ from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
@ -33,7 +34,6 @@ from services.errors.audio import (
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class TextToSpeechPayload(BaseModel):
|
||||
@ -47,13 +47,11 @@ class TextToSpeechVoiceQuery(BaseModel):
|
||||
language: str = Field(..., description="Language code")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
TextToSpeechVoiceQuery.__name__,
|
||||
TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
class AudioTranscriptResponse(BaseModel):
|
||||
text: str = Field(description="Transcribed text from audio")
|
||||
|
||||
|
||||
register_schema_models(console_ns, AudioTranscriptResponse, TextToSpeechPayload, TextToSpeechVoiceQuery)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/audio-to-text")
|
||||
@ -64,7 +62,7 @@ class ChatMessageAudioApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Audio transcription successful",
|
||||
console_ns.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}),
|
||||
console_ns.models[AudioTranscriptResponse.__name__],
|
||||
)
|
||||
@console_ns.response(400, "Bad request - No audio uploaded or unsupported type")
|
||||
@console_ns.response(413, "Audio file too large")
|
||||
|
||||
@ -508,16 +508,19 @@ class ChatConversationApi(Resource):
|
||||
case "created_at" | "-created_at" | _:
|
||||
query = query.where(Conversation.created_at <= end_datetime_utc)
|
||||
|
||||
if args.annotation_status == "annotated":
|
||||
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
elif args.annotation_status == "not_annotated":
|
||||
query = (
|
||||
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
|
||||
.group_by(Conversation.id)
|
||||
.having(func.count(MessageAnnotation.id) == 0)
|
||||
)
|
||||
match args.annotation_status:
|
||||
case "annotated":
|
||||
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
case "not_annotated":
|
||||
query = (
|
||||
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
|
||||
.group_by(Conversation.id)
|
||||
.having(func.count(MessageAnnotation.id) == 0)
|
||||
)
|
||||
case "all":
|
||||
pass
|
||||
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
|
||||
|
||||
@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import exists, select
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
@ -35,7 +36,6 @@ from services.errors.message import MessageNotExistsError, SuggestedQuestionsAft
|
||||
from services.message_service import MessageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ChatMessagesQuery(BaseModel):
|
||||
@ -90,13 +90,22 @@ class FeedbackExportQuery(BaseModel):
|
||||
raise ValueError("has_comment must be a boolean value")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
class AnnotationCountResponse(BaseModel):
|
||||
count: int = Field(description="Number of annotations")
|
||||
|
||||
|
||||
reg(ChatMessagesQuery)
|
||||
reg(MessageFeedbackPayload)
|
||||
reg(FeedbackExportQuery)
|
||||
class SuggestedQuestionsResponse(BaseModel):
|
||||
data: list[str] = Field(description="Suggested question")
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
ChatMessagesQuery,
|
||||
MessageFeedbackPayload,
|
||||
FeedbackExportQuery,
|
||||
AnnotationCountResponse,
|
||||
SuggestedQuestionsResponse,
|
||||
)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
@ -232,7 +241,7 @@ class ChatMessageListApi(Resource):
|
||||
@marshal_with(message_infinite_scroll_pagination_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = ChatMessagesQuery.model_validate(request.args.to_dict())
|
||||
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
@ -357,7 +366,7 @@ class MessageAnnotationCountApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Annotation count retrieved successfully",
|
||||
console_ns.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}),
|
||||
console_ns.models[AnnotationCountResponse.__name__],
|
||||
)
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@ -377,9 +386,7 @@ class MessageSuggestedQuestionApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Suggested questions retrieved successfully",
|
||||
console_ns.model(
|
||||
"SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}
|
||||
),
|
||||
console_ns.models[SuggestedQuestionsResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Message or conversation not found")
|
||||
@setup_required
|
||||
@ -429,7 +436,7 @@ class MessageFeedbackExportApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = FeedbackExportQuery.model_validate(request.args.to_dict())
|
||||
|
||||
# Import the service function
|
||||
from services.feedback_service import FeedbackService
|
||||
|
||||
@ -2,9 +2,11 @@ import logging
|
||||
|
||||
import httpx
|
||||
from flask import current_app, redirect, request
|
||||
from flask_restx import Resource, fields
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import register_schema_models
|
||||
from libs.login import login_required
|
||||
from libs.oauth_data_source import NotionOAuth
|
||||
|
||||
@ -14,6 +16,26 @@ from ..wraps import account_initialization_required, is_admin_or_owner_required,
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthDataSourceResponse(BaseModel):
|
||||
data: str = Field(description="Authorization URL or 'internal' for internal setup")
|
||||
|
||||
|
||||
class OAuthDataSourceBindingResponse(BaseModel):
|
||||
result: str = Field(description="Operation result")
|
||||
|
||||
|
||||
class OAuthDataSourceSyncResponse(BaseModel):
|
||||
result: str = Field(description="Operation result")
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
OAuthDataSourceResponse,
|
||||
OAuthDataSourceBindingResponse,
|
||||
OAuthDataSourceSyncResponse,
|
||||
)
|
||||
|
||||
|
||||
def get_oauth_providers():
|
||||
with current_app.app_context():
|
||||
notion_oauth = NotionOAuth(
|
||||
@ -34,10 +56,7 @@ class OAuthDataSource(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Authorization URL or internal setup success",
|
||||
console_ns.model(
|
||||
"OAuthDataSourceResponse",
|
||||
{"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")},
|
||||
),
|
||||
console_ns.models[OAuthDataSourceResponse.__name__],
|
||||
)
|
||||
@console_ns.response(400, "Invalid provider")
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@ -101,7 +120,7 @@ class OAuthDataSourceBinding(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Data source binding success",
|
||||
console_ns.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}),
|
||||
console_ns.models[OAuthDataSourceBindingResponse.__name__],
|
||||
)
|
||||
@console_ns.response(400, "Invalid provider or code")
|
||||
def get(self, provider: str):
|
||||
@ -133,7 +152,7 @@ class OAuthDataSourceSync(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Data source sync success",
|
||||
console_ns.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}),
|
||||
console_ns.models[OAuthDataSourceSyncResponse.__name__],
|
||||
)
|
||||
@console_ns.response(400, "Invalid provider or sync failed")
|
||||
@setup_required
|
||||
|
||||
@ -2,10 +2,11 @@ import base64
|
||||
import secrets
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailCodeError,
|
||||
@ -48,8 +49,31 @@ class ForgotPasswordResetPayload(BaseModel):
|
||||
return valid_password(value)
|
||||
|
||||
|
||||
for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
class ForgotPasswordEmailResponse(BaseModel):
|
||||
result: str = Field(description="Operation result")
|
||||
data: str | None = Field(default=None, description="Reset token")
|
||||
code: str | None = Field(default=None, description="Error code if account not found")
|
||||
|
||||
|
||||
class ForgotPasswordCheckResponse(BaseModel):
|
||||
is_valid: bool = Field(description="Whether code is valid")
|
||||
email: EmailStr = Field(description="Email address")
|
||||
token: str = Field(description="New reset token")
|
||||
|
||||
|
||||
class ForgotPasswordResetResponse(BaseModel):
|
||||
result: str = Field(description="Operation result")
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
ForgotPasswordSendPayload,
|
||||
ForgotPasswordCheckPayload,
|
||||
ForgotPasswordResetPayload,
|
||||
ForgotPasswordEmailResponse,
|
||||
ForgotPasswordCheckResponse,
|
||||
ForgotPasswordResetResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/forgot-password")
|
||||
@ -60,14 +84,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Email sent successfully",
|
||||
console_ns.model(
|
||||
"ForgotPasswordEmailResponse",
|
||||
{
|
||||
"result": fields.String(description="Operation result"),
|
||||
"data": fields.String(description="Reset token"),
|
||||
"code": fields.String(description="Error code if account not found"),
|
||||
},
|
||||
),
|
||||
console_ns.models[ForgotPasswordEmailResponse.__name__],
|
||||
)
|
||||
@console_ns.response(400, "Invalid email or rate limit exceeded")
|
||||
@setup_required
|
||||
@ -106,14 +123,7 @@ class ForgotPasswordCheckApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Code verified successfully",
|
||||
console_ns.model(
|
||||
"ForgotPasswordCheckResponse",
|
||||
{
|
||||
"is_valid": fields.Boolean(description="Whether code is valid"),
|
||||
"email": fields.String(description="Email address"),
|
||||
"token": fields.String(description="New reset token"),
|
||||
},
|
||||
),
|
||||
console_ns.models[ForgotPasswordCheckResponse.__name__],
|
||||
)
|
||||
@console_ns.response(400, "Invalid code or token")
|
||||
@setup_required
|
||||
@ -163,7 +173,7 @@ class ForgotPasswordResetApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Password reset successfully",
|
||||
console_ns.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}),
|
||||
console_ns.models[ForgotPasswordResetResponse.__name__],
|
||||
)
|
||||
@console_ns.response(400, "Invalid token or password mismatch")
|
||||
@setup_required
|
||||
|
||||
@ -155,43 +155,43 @@ class OAuthServerUserTokenApi(Resource):
|
||||
grant_type = OAuthGrantType(payload.grant_type)
|
||||
except ValueError:
|
||||
raise BadRequest("invalid grant_type")
|
||||
match grant_type:
|
||||
case OAuthGrantType.AUTHORIZATION_CODE:
|
||||
if not payload.code:
|
||||
raise BadRequest("code is required")
|
||||
|
||||
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
|
||||
if not payload.code:
|
||||
raise BadRequest("code is required")
|
||||
if payload.client_secret != oauth_provider_app.client_secret:
|
||||
raise BadRequest("client_secret is invalid")
|
||||
|
||||
if payload.client_secret != oauth_provider_app.client_secret:
|
||||
raise BadRequest("client_secret is invalid")
|
||||
if payload.redirect_uri not in oauth_provider_app.redirect_uris:
|
||||
raise BadRequest("redirect_uri is invalid")
|
||||
|
||||
if payload.redirect_uri not in oauth_provider_app.redirect_uris:
|
||||
raise BadRequest("redirect_uri is invalid")
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, code=payload.code, client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
)
|
||||
case OAuthGrantType.REFRESH_TOKEN:
|
||||
if not payload.refresh_token:
|
||||
raise BadRequest("refresh_token is required")
|
||||
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, code=payload.code, client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
)
|
||||
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
|
||||
if not payload.refresh_token:
|
||||
raise BadRequest("refresh_token is required")
|
||||
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
)
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/oauth/provider/account")
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
@ -157,9 +157,8 @@ class DataSourceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, binding_id, action):
|
||||
def patch(self, binding_id, action: Literal["enable", "disable"]):
|
||||
binding_id = str(binding_id)
|
||||
action = str(action)
|
||||
with Session(db.engine) as session:
|
||||
data_source_binding = session.execute(
|
||||
select(DataSourceOauthBinding).filter_by(id=binding_id)
|
||||
@ -167,23 +166,24 @@ class DataSourceApi(Resource):
|
||||
if data_source_binding is None:
|
||||
raise NotFound("Data source binding not found.")
|
||||
# enable binding
|
||||
if action == "enable":
|
||||
if data_source_binding.disabled:
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.add(data_source_binding)
|
||||
db.session.commit()
|
||||
else:
|
||||
raise ValueError("Data source is not disabled.")
|
||||
# disable binding
|
||||
if action == "disable":
|
||||
if not data_source_binding.disabled:
|
||||
data_source_binding.disabled = True
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.add(data_source_binding)
|
||||
db.session.commit()
|
||||
else:
|
||||
raise ValueError("Data source is disabled.")
|
||||
match action:
|
||||
case "enable":
|
||||
if data_source_binding.disabled:
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.add(data_source_binding)
|
||||
db.session.commit()
|
||||
else:
|
||||
raise ValueError("Data source is not disabled.")
|
||||
# disable binding
|
||||
case "disable":
|
||||
if not data_source_binding.disabled:
|
||||
data_source_binding.disabled = True
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.add(data_source_binding)
|
||||
db.session.commit()
|
||||
else:
|
||||
raise ValueError("Data source is disabled.")
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
|
||||
@ -576,63 +576,62 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
if document.indexing_status in {"completed", "error"}:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
data_source_info = document.data_source_info_dict
|
||||
match document.data_source_type:
|
||||
case "upload_file":
|
||||
if not data_source_info:
|
||||
continue
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
file_detail = (
|
||||
db.session.query(UploadFile)
|
||||
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if document.data_source_type == "upload_file":
|
||||
if not data_source_info:
|
||||
continue
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
file_detail = (
|
||||
db.session.query(UploadFile)
|
||||
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
|
||||
.first()
|
||||
)
|
||||
if file_detail is None:
|
||||
raise NotFound("File not found.")
|
||||
|
||||
if file_detail is None:
|
||||
raise NotFound("File not found.")
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
case "notion_import":
|
||||
if not data_source_info:
|
||||
continue
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": data_source_info.get("credential_id"),
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
"tenant_id": current_tenant_id,
|
||||
}
|
||||
),
|
||||
document_model=document.doc_form,
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
case "website_crawl":
|
||||
if not data_source_info:
|
||||
continue
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE,
|
||||
website_info=WebsiteInfo.model_validate(
|
||||
{
|
||||
"provider": data_source_info["provider"],
|
||||
"job_id": data_source_info["job_id"],
|
||||
"url": data_source_info["url"],
|
||||
"tenant_id": current_tenant_id,
|
||||
"mode": data_source_info["mode"],
|
||||
"only_main_content": data_source_info["only_main_content"],
|
||||
}
|
||||
),
|
||||
document_model=document.doc_form,
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
|
||||
elif document.data_source_type == "notion_import":
|
||||
if not data_source_info:
|
||||
continue
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": data_source_info.get("credential_id"),
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
"tenant_id": current_tenant_id,
|
||||
}
|
||||
),
|
||||
document_model=document.doc_form,
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
elif document.data_source_type == "website_crawl":
|
||||
if not data_source_info:
|
||||
continue
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE,
|
||||
website_info=WebsiteInfo.model_validate(
|
||||
{
|
||||
"provider": data_source_info["provider"],
|
||||
"job_id": data_source_info["job_id"],
|
||||
"url": data_source_info["url"],
|
||||
"tenant_id": current_tenant_id,
|
||||
"mode": data_source_info["mode"],
|
||||
"only_main_content": data_source_info["only_main_content"],
|
||||
}
|
||||
),
|
||||
document_model=document.doc_form,
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
|
||||
else:
|
||||
raise ValueError("Data source type not support")
|
||||
case _:
|
||||
raise ValueError("Data source type not support")
|
||||
indexing_runner = IndexingRunner()
|
||||
try:
|
||||
response = indexing_runner.indexing_estimate(
|
||||
@ -954,23 +953,24 @@ class DocumentProcessingApi(DocumentResource):
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if action == "pause":
|
||||
if document.indexing_status != "indexing":
|
||||
raise InvalidActionError("Document not in indexing state.")
|
||||
match action:
|
||||
case "pause":
|
||||
if document.indexing_status != "indexing":
|
||||
raise InvalidActionError("Document not in indexing state.")
|
||||
|
||||
document.paused_by = current_user.id
|
||||
document.paused_at = naive_utc_now()
|
||||
document.is_paused = True
|
||||
db.session.commit()
|
||||
document.paused_by = current_user.id
|
||||
document.paused_at = naive_utc_now()
|
||||
document.is_paused = True
|
||||
db.session.commit()
|
||||
|
||||
elif action == "resume":
|
||||
if document.indexing_status not in {"paused", "error"}:
|
||||
raise InvalidActionError("Document not in paused or error state.")
|
||||
case "resume":
|
||||
if document.indexing_status not in {"paused", "error"}:
|
||||
raise InvalidActionError("Document not in paused or error state.")
|
||||
|
||||
document.paused_by = None
|
||||
document.paused_at = None
|
||||
document.is_paused = False
|
||||
db.session.commit()
|
||||
document.paused_by = None
|
||||
document.paused_at = None
|
||||
document.is_paused = False
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@ -1339,6 +1339,18 @@ class DocumentGenerateSummaryApi(Resource):
|
||||
missing_ids = set(document_list) - found_ids
|
||||
raise NotFound(f"Some documents not found: {list(missing_ids)}")
|
||||
|
||||
# Update need_summary to True for documents that don't have it set
|
||||
# This handles the case where documents were created when summary_index_setting was disabled
|
||||
documents_to_update = [doc for doc in documents if not doc.need_summary and doc.doc_form != "qa_model"]
|
||||
|
||||
if documents_to_update:
|
||||
document_ids_to_update = [str(doc.id) for doc in documents_to_update]
|
||||
DocumentService.update_documents_need_summary(
|
||||
dataset_id=dataset_id,
|
||||
document_ids=document_ids_to_update,
|
||||
need_summary=True,
|
||||
)
|
||||
|
||||
# Dispatch async tasks for each document
|
||||
for document in documents:
|
||||
# Skip qa_model documents as they don't generate summaries
|
||||
|
||||
@ -126,10 +126,11 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
if action == "enable":
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
elif action == "disable":
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
match action:
|
||||
case "enable":
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
case "disable":
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
|
||||
@ -1,10 +1,9 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Literal, cast
|
||||
from uuid import UUID
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, marshal_with, reqparse # type: ignore
|
||||
from flask_restx import Resource, marshal_with # type: ignore
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
@ -38,7 +37,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField
|
||||
from libs.helper import TimestampField, UUIDStrOrEmpty
|
||||
from libs.login import current_account_with_tenant, current_user, login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
@ -110,7 +109,7 @@ class NodeIdQuery(BaseModel):
|
||||
|
||||
|
||||
class WorkflowRunQuery(BaseModel):
|
||||
last_id: UUID | None = None
|
||||
last_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
@ -121,6 +120,10 @@ class DatasourceVariablesPayload(BaseModel):
|
||||
start_node_title: str
|
||||
|
||||
|
||||
class RagPipelineRecommendedPluginQuery(BaseModel):
|
||||
type: str = "all"
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
DraftWorkflowSyncPayload,
|
||||
@ -135,6 +138,7 @@ register_schema_models(
|
||||
NodeIdQuery,
|
||||
WorkflowRunQuery,
|
||||
DatasourceVariablesPayload,
|
||||
RagPipelineRecommendedPluginQuery,
|
||||
)
|
||||
|
||||
|
||||
@ -975,11 +979,8 @@ class RagPipelineRecommendedPluginApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("type", type=str, location="args", required=False, default="all")
|
||||
args = parser.parse_args()
|
||||
type = args["type"]
|
||||
query = RagPipelineRecommendedPluginQuery.model_validate(request.args.to_dict())
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
recommended_plugins = rag_pipeline_service.get_recommended_plugins(type)
|
||||
recommended_plugins = rag_pipeline_service.get_recommended_plugins(query.type)
|
||||
return recommended_plugins
|
||||
|
||||
@ -9,7 +9,7 @@ import services
|
||||
from controllers.common.fields import Parameters as ParametersResponse
|
||||
from controllers.common.fields import Site as SiteResponse
|
||||
from controllers.common.schema import get_or_create_model
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
@ -51,7 +51,7 @@ from fields.app_fields import (
|
||||
tag_fields,
|
||||
)
|
||||
from fields.dataset_fields import dataset_fields
|
||||
from fields.member_fields import build_simple_account_model
|
||||
from fields.member_fields import simple_account_fields
|
||||
from fields.workflow_fields import (
|
||||
conversation_variable_fields,
|
||||
pipeline_variable_fields,
|
||||
@ -103,7 +103,7 @@ app_detail_fields_with_site_copy["tags"] = fields.List(fields.Nested(tag_model))
|
||||
app_detail_fields_with_site_copy["site"] = fields.Nested(site_model)
|
||||
app_detail_with_site_model = get_or_create_model("TrialAppDetailWithSite", app_detail_fields_with_site_copy)
|
||||
|
||||
simple_account_model = build_simple_account_model(console_ns)
|
||||
simple_account_model = get_or_create_model("SimpleAccount", simple_account_fields)
|
||||
conversation_variable_model = get_or_create_model("TrialConversationVariable", conversation_variable_fields)
|
||||
pipeline_variable_model = get_or_create_model("TrialPipelineVariable", pipeline_variable_fields)
|
||||
|
||||
|
||||
@ -1,87 +1,74 @@
|
||||
import os
|
||||
from typing import Literal
|
||||
|
||||
from flask import session
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.fastopenapi import console_router
|
||||
from extensions.ext_database import db
|
||||
from models.model import DifySetup
|
||||
from services.account_service import TenantService
|
||||
|
||||
from . import console_ns
|
||||
from .error import AlreadySetupError, InitValidateFailedError
|
||||
from .wraps import only_edition_self_hosted
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class InitValidatePayload(BaseModel):
|
||||
password: str = Field(..., max_length=30)
|
||||
password: str = Field(..., max_length=30, description="Initialization password")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
InitValidatePayload.__name__,
|
||||
InitValidatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
class InitStatusResponse(BaseModel):
|
||||
status: Literal["finished", "not_started"] = Field(..., description="Initialization status")
|
||||
|
||||
|
||||
class InitValidateResponse(BaseModel):
|
||||
result: str = Field(description="Operation result", examples=["success"])
|
||||
|
||||
|
||||
@console_router.get(
|
||||
"/init",
|
||||
response_model=InitStatusResponse,
|
||||
tags=["console"],
|
||||
)
|
||||
def get_init_status() -> InitStatusResponse:
|
||||
"""Get initialization validation status."""
|
||||
init_status = get_init_validate_status()
|
||||
if init_status:
|
||||
return InitStatusResponse(status="finished")
|
||||
return InitStatusResponse(status="not_started")
|
||||
|
||||
|
||||
@console_ns.route("/init")
|
||||
class InitValidateAPI(Resource):
|
||||
@console_ns.doc("get_init_status")
|
||||
@console_ns.doc(description="Get initialization validation status")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
model=console_ns.model(
|
||||
"InitStatusResponse",
|
||||
{"status": fields.String(description="Initialization status", enum=["finished", "not_started"])},
|
||||
),
|
||||
)
|
||||
def get(self):
|
||||
"""Get initialization validation status"""
|
||||
init_status = get_init_validate_status()
|
||||
if init_status:
|
||||
return {"status": "finished"}
|
||||
return {"status": "not_started"}
|
||||
@console_router.post(
|
||||
"/init",
|
||||
response_model=InitValidateResponse,
|
||||
tags=["console"],
|
||||
status_code=201,
|
||||
)
|
||||
@only_edition_self_hosted
|
||||
def validate_init_password(payload: InitValidatePayload) -> InitValidateResponse:
|
||||
"""Validate initialization password."""
|
||||
tenant_count = TenantService.get_tenant_count()
|
||||
if tenant_count > 0:
|
||||
raise AlreadySetupError()
|
||||
|
||||
@console_ns.doc("validate_init_password")
|
||||
@console_ns.doc(description="Validate initialization password for self-hosted edition")
|
||||
@console_ns.expect(console_ns.models[InitValidatePayload.__name__])
|
||||
@console_ns.response(
|
||||
201,
|
||||
"Success",
|
||||
model=console_ns.model("InitValidateResponse", {"result": fields.String(description="Operation result")}),
|
||||
)
|
||||
@console_ns.response(400, "Already setup or validation failed")
|
||||
@only_edition_self_hosted
|
||||
def post(self):
|
||||
"""Validate initialization password"""
|
||||
# is tenant created
|
||||
tenant_count = TenantService.get_tenant_count()
|
||||
if tenant_count > 0:
|
||||
raise AlreadySetupError()
|
||||
if payload.password != os.environ.get("INIT_PASSWORD"):
|
||||
session["is_init_validated"] = False
|
||||
raise InitValidateFailedError()
|
||||
|
||||
payload = InitValidatePayload.model_validate(console_ns.payload)
|
||||
input_password = payload.password
|
||||
|
||||
if input_password != os.environ.get("INIT_PASSWORD"):
|
||||
session["is_init_validated"] = False
|
||||
raise InitValidateFailedError()
|
||||
|
||||
session["is_init_validated"] = True
|
||||
return {"result": "success"}, 201
|
||||
session["is_init_validated"] = True
|
||||
return InitValidateResponse(result="success")
|
||||
|
||||
|
||||
def get_init_validate_status():
|
||||
def get_init_validate_status() -> bool:
|
||||
if dify_config.EDITION == "SELF_HOSTED":
|
||||
if os.environ.get("INIT_PASSWORD"):
|
||||
if session.get("is_init_validated"):
|
||||
return True
|
||||
|
||||
with Session(db.engine) as db_session:
|
||||
return db_session.execute(select(DifySetup)).scalar_one_or_none()
|
||||
return db_session.execute(select(DifySetup)).scalar_one_or_none() is not None
|
||||
|
||||
return True
|
||||
|
||||
@ -1,17 +1,27 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal_with
|
||||
from flask_restx import Namespace, Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from fields.tag_fields import dataset_tag_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.tag_service import TagService
|
||||
|
||||
dataset_tag_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"type": fields.String,
|
||||
"binding_count": fields.String,
|
||||
}
|
||||
|
||||
|
||||
def build_dataset_tag_fields(api_or_ns: Namespace):
|
||||
return api_or_ns.model("DataSetTag", dataset_tag_fields)
|
||||
|
||||
|
||||
class TagBasePayload(BaseModel):
|
||||
name: str = Field(description="Tag name", min_length=1, max_length=50)
|
||||
|
||||
@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailAlreadyInUseError,
|
||||
@ -38,7 +39,7 @@ from controllers.console.wraps import (
|
||||
)
|
||||
from core.file import helpers as file_helpers
|
||||
from extensions.ext_database import db
|
||||
from fields.member_fields import account_fields
|
||||
from fields.member_fields import Account as AccountResponse
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
@ -176,6 +177,12 @@ reg(ChangeEmailSendPayload)
|
||||
reg(ChangeEmailValidityPayload)
|
||||
reg(ChangeEmailResetPayload)
|
||||
reg(CheckEmailUniquePayload)
|
||||
register_schema_models(console_ns, AccountResponse)
|
||||
|
||||
|
||||
def _serialize_account(account) -> dict:
|
||||
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
integrate_fields = {
|
||||
"provider": fields.String,
|
||||
@ -242,11 +249,11 @@ class AccountProfileApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
return current_user
|
||||
return _serialize_account(current_user)
|
||||
|
||||
|
||||
@console_ns.route("/account/name")
|
||||
@ -255,14 +262,14 @@ class AccountNameApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountNamePayload.model_validate(payload)
|
||||
updated_account = AccountService.update_account(current_user, name=args.name)
|
||||
|
||||
return updated_account
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
|
||||
@console_ns.route("/account/avatar")
|
||||
@ -283,7 +290,7 @@ class AccountAvatarApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = console_ns.payload or {}
|
||||
@ -291,7 +298,7 @@ class AccountAvatarApi(Resource):
|
||||
|
||||
updated_account = AccountService.update_account(current_user, avatar=args.avatar)
|
||||
|
||||
return updated_account
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
|
||||
@console_ns.route("/account/interface-language")
|
||||
@ -300,7 +307,7 @@ class AccountInterfaceLanguageApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = console_ns.payload or {}
|
||||
@ -308,7 +315,7 @@ class AccountInterfaceLanguageApi(Resource):
|
||||
|
||||
updated_account = AccountService.update_account(current_user, interface_language=args.interface_language)
|
||||
|
||||
return updated_account
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
|
||||
@console_ns.route("/account/interface-theme")
|
||||
@ -317,7 +324,7 @@ class AccountInterfaceThemeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = console_ns.payload or {}
|
||||
@ -325,7 +332,7 @@ class AccountInterfaceThemeApi(Resource):
|
||||
|
||||
updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme)
|
||||
|
||||
return updated_account
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
|
||||
@console_ns.route("/account/timezone")
|
||||
@ -334,7 +341,7 @@ class AccountTimezoneApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = console_ns.payload or {}
|
||||
@ -342,7 +349,7 @@ class AccountTimezoneApi(Resource):
|
||||
|
||||
updated_account = AccountService.update_account(current_user, timezone=args.timezone)
|
||||
|
||||
return updated_account
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
|
||||
@console_ns.route("/account/password")
|
||||
@ -351,7 +358,7 @@ class AccountPasswordApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = console_ns.payload or {}
|
||||
@ -362,7 +369,7 @@ class AccountPasswordApi(Resource):
|
||||
except ServiceCurrentPasswordIncorrectError:
|
||||
raise CurrentPasswordIncorrectError()
|
||||
|
||||
return {"result": "success"}
|
||||
return _serialize_account(current_user)
|
||||
|
||||
|
||||
@console_ns.route("/account/integrates")
|
||||
@ -638,7 +645,7 @@ class ChangeEmailResetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
payload = console_ns.payload or {}
|
||||
args = ChangeEmailResetPayload.model_validate(payload)
|
||||
@ -667,7 +674,7 @@ class ChangeEmailResetApi(Resource):
|
||||
email=normalized_new_email,
|
||||
)
|
||||
|
||||
return updated_account
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
|
||||
@console_ns.route("/account/change-email/check-email-unique")
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
@ -38,15 +39,53 @@ class EndpointListForPluginQuery(EndpointListQuery):
|
||||
plugin_id: str
|
||||
|
||||
|
||||
class EndpointCreateResponse(BaseModel):
|
||||
success: bool = Field(description="Operation success")
|
||||
|
||||
|
||||
class EndpointListResponse(BaseModel):
|
||||
endpoints: list[dict[str, Any]] = Field(description="Endpoint information")
|
||||
|
||||
|
||||
class PluginEndpointListResponse(BaseModel):
|
||||
endpoints: list[dict[str, Any]] = Field(description="Endpoint information")
|
||||
|
||||
|
||||
class EndpointDeleteResponse(BaseModel):
|
||||
success: bool = Field(description="Operation success")
|
||||
|
||||
|
||||
class EndpointUpdateResponse(BaseModel):
|
||||
success: bool = Field(description="Operation success")
|
||||
|
||||
|
||||
class EndpointEnableResponse(BaseModel):
|
||||
success: bool = Field(description="Operation success")
|
||||
|
||||
|
||||
class EndpointDisableResponse(BaseModel):
|
||||
success: bool = Field(description="Operation success")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(EndpointCreatePayload)
|
||||
reg(EndpointIdPayload)
|
||||
reg(EndpointUpdatePayload)
|
||||
reg(EndpointListQuery)
|
||||
reg(EndpointListForPluginQuery)
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
EndpointCreatePayload,
|
||||
EndpointIdPayload,
|
||||
EndpointUpdatePayload,
|
||||
EndpointListQuery,
|
||||
EndpointListForPluginQuery,
|
||||
EndpointCreateResponse,
|
||||
EndpointListResponse,
|
||||
PluginEndpointListResponse,
|
||||
EndpointDeleteResponse,
|
||||
EndpointUpdateResponse,
|
||||
EndpointEnableResponse,
|
||||
EndpointDisableResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/create")
|
||||
@ -57,7 +96,7 @@ class EndpointCreateApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint created successfully",
|
||||
console_ns.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
console_ns.models[EndpointCreateResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@ -91,9 +130,7 @@ class EndpointListApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.model(
|
||||
"EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
|
||||
),
|
||||
console_ns.models[EndpointListResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -126,9 +163,7 @@ class EndpointListForSinglePluginApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.model(
|
||||
"PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
|
||||
),
|
||||
console_ns.models[PluginEndpointListResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -163,7 +198,7 @@ class EndpointDeleteApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint deleted successfully",
|
||||
console_ns.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
console_ns.models[EndpointDeleteResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@ -190,7 +225,7 @@ class EndpointUpdateApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint updated successfully",
|
||||
console_ns.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
console_ns.models[EndpointUpdateResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@ -221,7 +256,7 @@ class EndpointEnableApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint enabled successfully",
|
||||
console_ns.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
console_ns.models[EndpointEnableResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@ -248,7 +283,7 @@ class EndpointDisableApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint disabled successfully",
|
||||
console_ns.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
console_ns.models[EndpointDisableResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
from urllib import parse
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import get_or_create_model, register_enum_models
|
||||
from controllers.common.schema import register_enum_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
CannotTransferOwnerToSelfError,
|
||||
@ -25,7 +25,7 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.member_fields import account_with_role_fields, account_with_role_list_fields
|
||||
from fields.member_fields import AccountWithRole, AccountWithRoleList
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import Account, TenantAccountRole
|
||||
@ -69,12 +69,7 @@ reg(OwnerTransferEmailPayload)
|
||||
reg(OwnerTransferCheckPayload)
|
||||
reg(OwnerTransferPayload)
|
||||
register_enum_models(console_ns, TenantAccountRole)
|
||||
|
||||
account_with_role_model = get_or_create_model("AccountWithRole", account_with_role_fields)
|
||||
|
||||
account_with_role_list_fields_copy = account_with_role_list_fields.copy()
|
||||
account_with_role_list_fields_copy["accounts"] = fields.List(fields.Nested(account_with_role_model))
|
||||
account_with_role_list_model = get_or_create_model("AccountWithRoleList", account_with_role_list_fields_copy)
|
||||
register_schema_models(console_ns, AccountWithRole, AccountWithRoleList)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members")
|
||||
@ -84,13 +79,15 @@ class MemberListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_with_role_list_model)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
return {"result": "success", "accounts": members}, 200
|
||||
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
|
||||
response = AccountWithRoleList(accounts=member_models)
|
||||
return response.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members/invite-email")
|
||||
@ -235,13 +232,15 @@ class DatasetOperatorMemberListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_with_role_list_model)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
|
||||
return {"result": "success", "accounts": members}, 200
|
||||
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
|
||||
response = AccountWithRoleList(accounts=member_models)
|
||||
return response.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,16 +1,16 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
from flask_restx import Resource
|
||||
from flask_restx.api import HTTPStatus
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.wraps import edit_permission_required
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.annotation_fields import annotation_fields, build_annotation_model
|
||||
from fields.annotation_fields import Annotation, AnnotationList
|
||||
from models.model import App
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
@ -26,7 +26,9 @@ class AnnotationReplyActionPayload(BaseModel):
|
||||
embedding_model_name: str = Field(description="Embedding model name")
|
||||
|
||||
|
||||
register_schema_models(service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload)
|
||||
register_schema_models(
|
||||
service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload, Annotation, AnnotationList
|
||||
)
|
||||
|
||||
|
||||
@service_api_ns.route("/apps/annotation-reply/<string:action>")
|
||||
@ -45,10 +47,11 @@ class AnnotationReplyActionApi(Resource):
|
||||
def post(self, app_model: App, action: Literal["enable", "disable"]):
|
||||
"""Enable or disable annotation reply feature."""
|
||||
args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump()
|
||||
if action == "enable":
|
||||
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
|
||||
elif action == "disable":
|
||||
result = AppAnnotationService.disable_app_annotation(app_model.id)
|
||||
match action:
|
||||
case "enable":
|
||||
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
|
||||
case "disable":
|
||||
result = AppAnnotationService.disable_app_annotation(app_model.id)
|
||||
return result, 200
|
||||
|
||||
|
||||
@ -82,23 +85,6 @@ class AnnotationReplyActionStatusApi(Resource):
|
||||
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
|
||||
|
||||
|
||||
# Define annotation list response model
|
||||
annotation_list_fields = {
|
||||
"data": fields.List(fields.Nested(annotation_fields)),
|
||||
"has_more": fields.Boolean,
|
||||
"limit": fields.Integer,
|
||||
"total": fields.Integer,
|
||||
"page": fields.Integer,
|
||||
}
|
||||
|
||||
|
||||
def build_annotation_list_model(api_or_ns: Namespace):
|
||||
"""Build the annotation list model for the API or Namespace."""
|
||||
copied_annotation_list_fields = annotation_list_fields.copy()
|
||||
copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns)))
|
||||
return api_or_ns.model("AnnotationList", copied_annotation_list_fields)
|
||||
|
||||
|
||||
@service_api_ns.route("/apps/annotations")
|
||||
class AnnotationListApi(Resource):
|
||||
@service_api_ns.doc("list_annotations")
|
||||
@ -109,8 +95,12 @@ class AnnotationListApi(Resource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Annotations retrieved successfully",
|
||||
service_api_ns.models[AnnotationList.__name__],
|
||||
)
|
||||
@validate_app_token
|
||||
@service_api_ns.marshal_with(build_annotation_list_model(service_api_ns))
|
||||
def get(self, app_model: App):
|
||||
"""List annotations for the application."""
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
@ -118,13 +108,15 @@ class AnnotationListApi(Resource):
|
||||
keyword = request.args.get("keyword", default="", type=str)
|
||||
|
||||
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword)
|
||||
return {
|
||||
"data": annotation_list,
|
||||
"has_more": len(annotation_list) == limit,
|
||||
"limit": limit,
|
||||
"total": total,
|
||||
"page": page,
|
||||
}
|
||||
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
|
||||
response = AnnotationList(
|
||||
data=annotation_models,
|
||||
has_more=len(annotation_list) == limit,
|
||||
limit=limit,
|
||||
total=total,
|
||||
page=page,
|
||||
)
|
||||
return response.model_dump(mode="json")
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_annotation")
|
||||
@ -135,13 +127,18 @@ class AnnotationListApi(Resource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
HTTPStatus.CREATED,
|
||||
"Annotation created successfully",
|
||||
service_api_ns.models[Annotation.__name__],
|
||||
)
|
||||
@validate_app_token
|
||||
@service_api_ns.marshal_with(build_annotation_model(service_api_ns), code=HTTPStatus.CREATED)
|
||||
def post(self, app_model: App):
|
||||
"""Create a new annotation."""
|
||||
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
|
||||
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
|
||||
return annotation, 201
|
||||
response = Annotation.model_validate(annotation, from_attributes=True)
|
||||
return response.model_dump(mode="json"), HTTPStatus.CREATED
|
||||
|
||||
|
||||
@service_api_ns.route("/apps/annotations/<uuid:annotation_id>")
|
||||
@ -158,14 +155,19 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
404: "Annotation not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Annotation updated successfully",
|
||||
service_api_ns.models[Annotation.__name__],
|
||||
)
|
||||
@validate_app_token
|
||||
@edit_permission_required
|
||||
@service_api_ns.marshal_with(build_annotation_model(service_api_ns))
|
||||
def put(self, app_model: App, annotation_id: str):
|
||||
"""Update an existing annotation."""
|
||||
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
|
||||
return annotation
|
||||
response = Annotation.model_validate(annotation, from_attributes=True)
|
||||
return response.model_dump(mode="json")
|
||||
|
||||
@service_api_ns.doc("delete_annotation")
|
||||
@service_api_ns.doc(description="Delete an annotation")
|
||||
|
||||
@ -30,6 +30,7 @@ from core.errors.error import (
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
@ -52,7 +53,7 @@ class ChatRequestPayload(BaseModel):
|
||||
query: str
|
||||
files: list[dict[str, Any]] | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
conversation_id: str | None = Field(default=None, description="Conversation UUID")
|
||||
conversation_id: UUIDStrOrEmpty | None = Field(default=None, description="Conversation UUID")
|
||||
retriever_from: str = Field(default="dev")
|
||||
auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
|
||||
workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -23,12 +22,13 @@ from fields.conversation_variable_fields import (
|
||||
build_conversation_variable_infinite_scroll_pagination_model,
|
||||
build_conversation_variable_model,
|
||||
)
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.conversation_service import ConversationService
|
||||
|
||||
|
||||
class ConversationListQuery(BaseModel):
|
||||
last_id: UUID | None = Field(default=None, description="Last conversation ID for pagination")
|
||||
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last conversation ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of conversations to return")
|
||||
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
|
||||
default="-updated_at", description="Sort order for conversations"
|
||||
@ -48,7 +48,7 @@ class ConversationRenamePayload(BaseModel):
|
||||
|
||||
|
||||
class ConversationVariablesQuery(BaseModel):
|
||||
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
|
||||
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
|
||||
variable_name: str | None = Field(
|
||||
default=None, description="Filter variables by name", min_length=1, max_length=255
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import logging
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -15,6 +14,7 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.errors.message import (
|
||||
FirstMessageNotExistsError,
|
||||
@ -27,8 +27,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageListQuery(BaseModel):
|
||||
conversation_id: UUID
|
||||
first_id: UUID | None = None
|
||||
conversation_id: UUIDStrOrEmpty
|
||||
first_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")
|
||||
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@ from controllers.service_api.wraps import (
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from fields.tag_fields import build_dataset_tag_fields
|
||||
from fields.tag_fields import DataSetTag
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
@ -114,6 +114,7 @@ register_schema_models(
|
||||
TagBindingPayload,
|
||||
TagUnbindingPayload,
|
||||
DatasetListQuery,
|
||||
DataSetTag,
|
||||
)
|
||||
|
||||
|
||||
@ -480,15 +481,14 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
|
||||
def get(self, _):
|
||||
"""Get all knowledge type tags."""
|
||||
assert isinstance(current_user, Account)
|
||||
cid = current_user.current_tenant_id
|
||||
assert cid is not None
|
||||
tags = TagService.get_tags("knowledge", cid)
|
||||
|
||||
return tags, 200
|
||||
tag_models = TypeAdapter(list[DataSetTag]).validate_python(tags, from_attributes=True)
|
||||
return [tag.model_dump(mode="json") for tag in tag_models], 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_dataset_tag")
|
||||
@ -500,7 +500,6 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
)
|
||||
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
|
||||
def post(self, _):
|
||||
"""Add a knowledge type tag."""
|
||||
assert isinstance(current_user, Account)
|
||||
@ -510,7 +509,9 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
tag = TagService.save_tags({"name": payload.name, "type": "knowledge"})
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
response = DataSetTag.model_validate(
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
).model_dump(mode="json")
|
||||
return response, 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__])
|
||||
@ -523,7 +524,6 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
)
|
||||
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
|
||||
def patch(self, _):
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
@ -536,8 +536,9 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
|
||||
response = DataSetTag.model_validate(
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
).model_dump(mode="json")
|
||||
return response, 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__])
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
|
||||
|
||||
register_schema_model(service_api_ns, HitTestingPayload)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/hit-testing", "/datasets/<uuid:dataset_id>/retrieve")
|
||||
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
||||
@ -15,6 +18,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
||||
404: "Dataset not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[HitTestingPayload.__name__])
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id):
|
||||
"""Perform hit testing on a dataset.
|
||||
|
||||
@ -168,10 +168,11 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
if action == "enable":
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
elif action == "disable":
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
match action:
|
||||
case "enable":
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
case "disable":
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
|
||||
@ -73,14 +73,14 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
|
||||
|
||||
# If caller needs end-user context, attach EndUser to current_user
|
||||
if fetch_user_arg:
|
||||
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
|
||||
user_id = request.args.get("user")
|
||||
elif fetch_user_arg.fetch_from == WhereisUserArg.JSON:
|
||||
user_id = request.get_json().get("user")
|
||||
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
|
||||
user_id = request.form.get("user")
|
||||
else:
|
||||
user_id = None
|
||||
user_id = None
|
||||
match fetch_user_arg.fetch_from:
|
||||
case WhereisUserArg.QUERY:
|
||||
user_id = request.args.get("user")
|
||||
case WhereisUserArg.JSON:
|
||||
user_id = request.get_json().get("user")
|
||||
case WhereisUserArg.FORM:
|
||||
user_id = request.form.get("user")
|
||||
|
||||
if not user_id and fetch_user_arg.required:
|
||||
raise ValueError("Arg user must be provided.")
|
||||
|
||||
@ -14,16 +14,17 @@ class AgentConfigManager:
|
||||
agent_dict = config.get("agent_mode", {})
|
||||
agent_strategy = agent_dict.get("strategy", "cot")
|
||||
|
||||
if agent_strategy == "function_call":
|
||||
strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
elif agent_strategy in {"cot", "react"}:
|
||||
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
||||
else:
|
||||
# old configs, try to detect default strategy
|
||||
if config["model"]["provider"] == "openai":
|
||||
match agent_strategy:
|
||||
case "function_call":
|
||||
strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
else:
|
||||
case "cot" | "react":
|
||||
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
||||
case _:
|
||||
# old configs, try to detect default strategy
|
||||
if config["model"]["provider"] == "openai":
|
||||
strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
else:
|
||||
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
||||
|
||||
agent_tools = []
|
||||
for tool in agent_dict.get("tools", []):
|
||||
|
||||
@ -253,7 +253,7 @@ class WorkflowResponseConverter:
|
||||
data=WorkflowFinishStreamResponse.Data(
|
||||
id=run_id,
|
||||
workflow_id=workflow_id,
|
||||
status=status.value,
|
||||
status=status,
|
||||
outputs=encoded_outputs,
|
||||
error=error,
|
||||
elapsed_time=elapsed_time,
|
||||
@ -344,13 +344,13 @@ class WorkflowResponseConverter:
|
||||
metadata = self._merge_metadata(event.execution_metadata, snapshot)
|
||||
|
||||
if isinstance(event, QueueNodeSucceededEvent):
|
||||
status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||
status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
error_message = event.error
|
||||
elif isinstance(event, QueueNodeFailedEvent):
|
||||
status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
status = WorkflowNodeExecutionStatus.FAILED
|
||||
error_message = event.error
|
||||
else:
|
||||
status = WorkflowNodeExecutionStatus.EXCEPTION.value
|
||||
status = WorkflowNodeExecutionStatus.EXCEPTION
|
||||
error_message = event.error
|
||||
|
||||
return NodeFinishStreamResponse(
|
||||
@ -418,7 +418,7 @@ class WorkflowResponseConverter:
|
||||
process_data_truncated=process_data_truncated,
|
||||
outputs=outputs,
|
||||
outputs_truncated=outputs_truncated,
|
||||
status=WorkflowNodeExecutionStatus.RETRY.value,
|
||||
status=WorkflowNodeExecutionStatus.RETRY,
|
||||
error=event.error,
|
||||
elapsed_time=elapsed_time,
|
||||
execution_metadata=metadata,
|
||||
|
||||
@ -120,7 +120,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
raise ValueError("Pipeline dataset is required")
|
||||
inputs: Mapping[str, Any] = args["inputs"]
|
||||
start_node_id: str = args["start_node_id"]
|
||||
datasource_type: str = args["datasource_type"]
|
||||
datasource_type = DatasourceProviderType(args["datasource_type"])
|
||||
datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list(
|
||||
datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user
|
||||
)
|
||||
@ -660,7 +660,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
tenant_id: str,
|
||||
dataset_id: str,
|
||||
built_in_field_enabled: bool,
|
||||
datasource_type: str,
|
||||
datasource_type: DatasourceProviderType,
|
||||
datasource_info: Mapping[str, Any],
|
||||
created_from: str,
|
||||
position: int,
|
||||
@ -668,17 +668,17 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
batch: str,
|
||||
document_form: str,
|
||||
):
|
||||
if datasource_type == "local_file":
|
||||
name = datasource_info.get("name", "untitled")
|
||||
elif datasource_type == "online_document":
|
||||
name = datasource_info.get("page", {}).get("page_name", "untitled")
|
||||
elif datasource_type == "website_crawl":
|
||||
name = datasource_info.get("title", "untitled")
|
||||
elif datasource_type == "online_drive":
|
||||
name = datasource_info.get("name", "untitled")
|
||||
else:
|
||||
raise ValueError(f"Unsupported datasource type: {datasource_type}")
|
||||
|
||||
match datasource_type:
|
||||
case DatasourceProviderType.LOCAL_FILE:
|
||||
name = datasource_info.get("name", "untitled")
|
||||
case DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
name = datasource_info.get("page", {}).get("page_name", "untitled")
|
||||
case DatasourceProviderType.WEBSITE_CRAWL:
|
||||
name = datasource_info.get("title", "untitled")
|
||||
case DatasourceProviderType.ONLINE_DRIVE:
|
||||
name = datasource_info.get("name", "untitled")
|
||||
case _:
|
||||
raise ValueError(f"Unsupported datasource type: {datasource_type}")
|
||||
document = Document(
|
||||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
@ -706,7 +706,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
|
||||
def _format_datasource_info_list(
|
||||
self,
|
||||
datasource_type: str,
|
||||
datasource_type: DatasourceProviderType,
|
||||
datasource_info_list: list[Mapping[str, Any]],
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
@ -716,7 +716,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
"""
|
||||
Format datasource info list.
|
||||
"""
|
||||
if datasource_type == "online_drive":
|
||||
if datasource_type == DatasourceProviderType.ONLINE_DRIVE:
|
||||
all_files: list[Mapping[str, Any]] = []
|
||||
datasource_node_data = None
|
||||
datasource_nodes = workflow.graph_dict.get("nodes", [])
|
||||
|
||||
@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class AnnotationReplyAccount(BaseModel):
|
||||
@ -255,7 +255,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
|
||||
|
||||
id: str
|
||||
workflow_id: str
|
||||
status: str
|
||||
status: WorkflowExecutionStatus
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
error: str | None = None
|
||||
elapsed_time: float
|
||||
@ -345,7 +345,7 @@ class NodeFinishStreamResponse(StreamResponse):
|
||||
process_data_truncated: bool = False
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
outputs_truncated: bool = True
|
||||
status: str
|
||||
status: WorkflowNodeExecutionStatus
|
||||
error: str | None = None
|
||||
elapsed_time: float
|
||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
||||
@ -411,7 +411,7 @@ class NodeRetryStreamResponse(StreamResponse):
|
||||
process_data_truncated: bool = False
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
outputs_truncated: bool = False
|
||||
status: str
|
||||
status: WorkflowNodeExecutionStatus
|
||||
error: str | None = None
|
||||
elapsed_time: float
|
||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
||||
@ -798,7 +798,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
|
||||
|
||||
id: str
|
||||
workflow_id: str
|
||||
status: str
|
||||
status: WorkflowExecutionStatus
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
error: str | None = None
|
||||
elapsed_time: float
|
||||
|
||||
@ -4,13 +4,14 @@ from typing import TYPE_CHECKING, final
|
||||
from typing_extensions import override
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import file_manager
|
||||
from core.helper import ssrf_proxy
|
||||
from core.file.file_manager import file_manager
|
||||
from core.helper.code_executor.code_executor import CodeExecutor
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph import NodeFactory
|
||||
from core.workflow.graph.graph import NodeFactory
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
@ -22,7 +23,6 @@ from core.workflow.nodes.template_transform.template_renderer import (
|
||||
Jinja2TemplateRenderer,
|
||||
)
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from libs.typing import is_str, is_str_dict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
@ -47,9 +47,9 @@ class DifyNodeFactory(NodeFactory):
|
||||
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
|
||||
code_limits: CodeNodeLimits | None = None,
|
||||
template_renderer: Jinja2TemplateRenderer | None = None,
|
||||
http_request_http_client: HttpClientProtocol = ssrf_proxy,
|
||||
http_request_http_client: HttpClientProtocol | None = None,
|
||||
http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
|
||||
http_request_file_manager: FileManagerProtocol = file_manager,
|
||||
http_request_file_manager: FileManagerProtocol | None = None,
|
||||
) -> None:
|
||||
self.graph_init_params = graph_init_params
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
@ -68,12 +68,12 @@ class DifyNodeFactory(NodeFactory):
|
||||
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
|
||||
)
|
||||
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
|
||||
self._http_request_http_client = http_request_http_client
|
||||
self._http_request_http_client = http_request_http_client or ssrf_proxy
|
||||
self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
|
||||
self._http_request_file_manager = http_request_file_manager
|
||||
self._http_request_file_manager = http_request_file_manager or file_manager
|
||||
|
||||
@override
|
||||
def create_node(self, node_config: dict[str, object]) -> Node:
|
||||
def create_node(self, node_config: NodeConfigDict) -> Node:
|
||||
"""
|
||||
Create a Node instance from node configuration data using the traditional mapping.
|
||||
|
||||
@ -82,23 +82,14 @@ class DifyNodeFactory(NodeFactory):
|
||||
:raises ValueError: if node type is unknown or configuration is invalid
|
||||
"""
|
||||
# Get node_id from config
|
||||
node_id = node_config.get("id")
|
||||
if not is_str(node_id):
|
||||
raise ValueError("Node config missing id")
|
||||
node_id = node_config["id"]
|
||||
|
||||
# Get node type from config
|
||||
node_data = node_config.get("data", {})
|
||||
if not is_str_dict(node_data):
|
||||
raise ValueError(f"Node {node_id} missing data information")
|
||||
|
||||
node_type_str = node_data.get("type")
|
||||
if not is_str(node_type_str):
|
||||
raise ValueError(f"Node {node_id} missing or invalid type information")
|
||||
|
||||
node_data = node_config["data"]
|
||||
try:
|
||||
node_type = NodeType(node_type_str)
|
||||
node_type = NodeType(node_data["type"])
|
||||
except ValueError:
|
||||
raise ValueError(f"Unknown node type: {node_type_str}")
|
||||
raise ValueError(f"Unknown node type: {node_data['type']}")
|
||||
|
||||
# Get node class
|
||||
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
|
||||
|
||||
@ -312,3 +312,18 @@ def _build_file_from_ref(
|
||||
|
||||
logger.warning("File not found for file_ref: %s", file_ref)
|
||||
return None
|
||||
|
||||
|
||||
class FileManager:
|
||||
"""
|
||||
Adapter exposing file manager helpers behind FileManagerProtocol.
|
||||
|
||||
This is intentionally a thin wrapper over the existing module-level functions so callers can inject it
|
||||
where a protocol-typed file manager is expected.
|
||||
"""
|
||||
|
||||
def download(self, f: File, /) -> bytes:
|
||||
return download(f)
|
||||
|
||||
|
||||
file_manager = FileManager()
|
||||
|
||||
@ -47,15 +47,16 @@ class CodeNodeProvider(BaseModel, ABC):
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls) -> DefaultConfig:
|
||||
return {
|
||||
"type": "code",
|
||||
"config": {
|
||||
"variables": [
|
||||
{"variable": "arg1", "value_selector": []},
|
||||
{"variable": "arg2", "value_selector": []},
|
||||
],
|
||||
"code_language": cls.get_language(),
|
||||
"code": cls.get_default_code(),
|
||||
"outputs": {"result": {"type": "string", "children": None}},
|
||||
},
|
||||
variables: list[VariableConfig] = [
|
||||
{"variable": "arg1", "value_selector": []},
|
||||
{"variable": "arg2", "value_selector": []},
|
||||
]
|
||||
outputs: dict[str, OutputConfig] = {"result": {"type": "string", "children": None}}
|
||||
|
||||
config: CodeConfig = {
|
||||
"variables": variables,
|
||||
"code_language": cls.get_language(),
|
||||
"code": cls.get_default_code(),
|
||||
"outputs": outputs,
|
||||
}
|
||||
return {"type": "code", "config": config}
|
||||
|
||||
@ -230,3 +230,41 @@ def delete(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any)
|
||||
|
||||
def head(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
|
||||
return make_request("HEAD", url, max_retries=max_retries, **kwargs)
|
||||
|
||||
|
||||
class SSRFProxy:
|
||||
"""
|
||||
Adapter exposing SSRF-protected HTTP helpers behind HttpClientProtocol.
|
||||
|
||||
This is intentionally a thin wrapper over the existing module-level functions so callers can inject it
|
||||
where a protocol-typed HTTP client is expected.
|
||||
"""
|
||||
|
||||
@property
|
||||
def max_retries_exceeded_error(self) -> type[Exception]:
|
||||
return max_retries_exceeded_error
|
||||
|
||||
@property
|
||||
def request_error(self) -> type[Exception]:
|
||||
return request_error
|
||||
|
||||
def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
|
||||
return get(url=url, max_retries=max_retries, **kwargs)
|
||||
|
||||
def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
|
||||
return head(url=url, max_retries=max_retries, **kwargs)
|
||||
|
||||
def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
|
||||
return post(url=url, max_retries=max_retries, **kwargs)
|
||||
|
||||
def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
|
||||
return put(url=url, max_retries=max_retries, **kwargs)
|
||||
|
||||
def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
|
||||
return delete(url=url, max_retries=max_retries, **kwargs)
|
||||
|
||||
def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
|
||||
return patch(url=url, max_retries=max_retries, **kwargs)
|
||||
|
||||
|
||||
ssrf_proxy = SSRFProxy()
|
||||
|
||||
@ -369,77 +369,78 @@ class IndexingRunner:
|
||||
# Generate summary preview
|
||||
summary_index_setting = tmp_processing_rule.get("summary_index_setting")
|
||||
if summary_index_setting and summary_index_setting.get("enable") and preview_texts:
|
||||
preview_texts = index_processor.generate_summary_preview(tenant_id, preview_texts, summary_index_setting)
|
||||
preview_texts = index_processor.generate_summary_preview(
|
||||
tenant_id, preview_texts, summary_index_setting, doc_language
|
||||
)
|
||||
|
||||
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
|
||||
|
||||
def _extract(
|
||||
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
|
||||
) -> list[Document]:
|
||||
# load file
|
||||
if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}:
|
||||
return []
|
||||
|
||||
data_source_info = dataset_document.data_source_info_dict
|
||||
text_docs = []
|
||||
if dataset_document.data_source_type == "upload_file":
|
||||
if not data_source_info or "upload_file_id" not in data_source_info:
|
||||
raise ValueError("no upload file found")
|
||||
stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
|
||||
file_detail = db.session.scalars(stmt).one_or_none()
|
||||
match dataset_document.data_source_type:
|
||||
case "upload_file":
|
||||
if not data_source_info or "upload_file_id" not in data_source_info:
|
||||
raise ValueError("no upload file found")
|
||||
stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
|
||||
file_detail = db.session.scalars(stmt).one_or_none()
|
||||
|
||||
if file_detail:
|
||||
if file_detail:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE,
|
||||
upload_file=file_detail,
|
||||
document_model=dataset_document.doc_form,
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
|
||||
case "notion_import":
|
||||
if (
|
||||
not data_source_info
|
||||
or "notion_workspace_id" not in data_source_info
|
||||
or "notion_page_id" not in data_source_info
|
||||
):
|
||||
raise ValueError("no notion import info found")
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE,
|
||||
upload_file=file_detail,
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": data_source_info.get("credential_id"),
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
"document": dataset_document,
|
||||
"tenant_id": dataset_document.tenant_id,
|
||||
}
|
||||
),
|
||||
document_model=dataset_document.doc_form,
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
|
||||
elif dataset_document.data_source_type == "notion_import":
|
||||
if (
|
||||
not data_source_info
|
||||
or "notion_workspace_id" not in data_source_info
|
||||
or "notion_page_id" not in data_source_info
|
||||
):
|
||||
raise ValueError("no notion import info found")
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": data_source_info.get("credential_id"),
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
"document": dataset_document,
|
||||
"tenant_id": dataset_document.tenant_id,
|
||||
}
|
||||
),
|
||||
document_model=dataset_document.doc_form,
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
|
||||
elif dataset_document.data_source_type == "website_crawl":
|
||||
if (
|
||||
not data_source_info
|
||||
or "provider" not in data_source_info
|
||||
or "url" not in data_source_info
|
||||
or "job_id" not in data_source_info
|
||||
):
|
||||
raise ValueError("no website import info found")
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE,
|
||||
website_info=WebsiteInfo.model_validate(
|
||||
{
|
||||
"provider": data_source_info["provider"],
|
||||
"job_id": data_source_info["job_id"],
|
||||
"tenant_id": dataset_document.tenant_id,
|
||||
"url": data_source_info["url"],
|
||||
"mode": data_source_info["mode"],
|
||||
"only_main_content": data_source_info["only_main_content"],
|
||||
}
|
||||
),
|
||||
document_model=dataset_document.doc_form,
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
|
||||
case "website_crawl":
|
||||
if (
|
||||
not data_source_info
|
||||
or "provider" not in data_source_info
|
||||
or "url" not in data_source_info
|
||||
or "job_id" not in data_source_info
|
||||
):
|
||||
raise ValueError("no website import info found")
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE,
|
||||
website_info=WebsiteInfo.model_validate(
|
||||
{
|
||||
"provider": data_source_info["provider"],
|
||||
"job_id": data_source_info["job_id"],
|
||||
"tenant_id": dataset_document.tenant_id,
|
||||
"url": data_source_info["url"],
|
||||
"mode": data_source_info["mode"],
|
||||
"only_main_content": data_source_info["only_main_content"],
|
||||
}
|
||||
),
|
||||
document_model=dataset_document.doc_form,
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
|
||||
case _:
|
||||
return []
|
||||
# update document status to splitting
|
||||
self._update_document_index_status(
|
||||
document_id=dataset_document.id,
|
||||
|
||||
@ -442,11 +442,13 @@ DEFAULT_GENERATOR_SUMMARY_PROMPT = (
|
||||
|
||||
Requirements:
|
||||
1. Write a concise summary in plain text
|
||||
2. Use the same language as the input content
|
||||
2. You must write in {language}. No language other than {language} should be used.
|
||||
3. Focus on important facts, concepts, and details
|
||||
4. If images are included, describe their key information
|
||||
5. Do not use words like "好的", "ok", "I understand", "This text discusses", "The content mentions"
|
||||
6. Write directly without extra words
|
||||
7. If there is not enough content to generate a meaningful summary,
|
||||
return an empty string without any explanation or prompt
|
||||
|
||||
Output only the summary text. Start summarizing now:
|
||||
|
||||
|
||||
@ -88,7 +88,7 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
|
||||
DefaultParameterName.MAX_TOKENS: {
|
||||
"label": {
|
||||
"en_US": "Max Tokens",
|
||||
"zh_Hans": "最大标记",
|
||||
"zh_Hans": "最大 Token 数",
|
||||
},
|
||||
"type": "int",
|
||||
"help": {
|
||||
|
||||
@ -48,12 +48,22 @@ class BaseIndexProcessor(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def generate_summary_preview(
|
||||
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
|
||||
self,
|
||||
tenant_id: str,
|
||||
preview_texts: list[PreviewDetail],
|
||||
summary_index_setting: dict,
|
||||
doc_language: str | None = None,
|
||||
) -> list[PreviewDetail]:
|
||||
"""
|
||||
For each segment in preview_texts, generate a summary using LLM and attach it to the segment.
|
||||
The summary can be stored in a new attribute, e.g., summary.
|
||||
This method should be implemented by subclasses.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
preview_texts: List of preview details to generate summaries for
|
||||
summary_index_setting: Summary index configuration
|
||||
doc_language: Optional document language to ensure summary is generated in the correct language
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -275,7 +275,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
raise ValueError("Chunks is not a list")
|
||||
|
||||
def generate_summary_preview(
|
||||
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
|
||||
self,
|
||||
tenant_id: str,
|
||||
preview_texts: list[PreviewDetail],
|
||||
summary_index_setting: dict,
|
||||
doc_language: str | None = None,
|
||||
) -> list[PreviewDetail]:
|
||||
"""
|
||||
For each segment, concurrently call generate_summary to generate a summary
|
||||
@ -298,11 +302,15 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
if flask_app:
|
||||
# Ensure Flask app context in worker thread
|
||||
with flask_app.app_context():
|
||||
summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting)
|
||||
summary, _ = self.generate_summary(
|
||||
tenant_id, preview.content, summary_index_setting, document_language=doc_language
|
||||
)
|
||||
preview.summary = summary
|
||||
else:
|
||||
# Fallback: try without app context (may fail)
|
||||
summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting)
|
||||
summary, _ = self.generate_summary(
|
||||
tenant_id, preview.content, summary_index_setting, document_language=doc_language
|
||||
)
|
||||
preview.summary = summary
|
||||
|
||||
# Generate summaries concurrently using ThreadPoolExecutor
|
||||
@ -356,6 +364,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
text: str,
|
||||
summary_index_setting: dict | None = None,
|
||||
segment_id: str | None = None,
|
||||
document_language: str | None = None,
|
||||
) -> tuple[str, LLMUsage]:
|
||||
"""
|
||||
Generate summary for the given text using ModelInstance.invoke_llm and the default or custom summary prompt,
|
||||
@ -366,6 +375,8 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
text: Text content to summarize
|
||||
summary_index_setting: Summary index configuration
|
||||
segment_id: Optional segment ID to fetch attachments from SegmentAttachmentBinding table
|
||||
document_language: Optional document language (e.g., "Chinese", "English")
|
||||
to ensure summary is generated in the correct language
|
||||
|
||||
Returns:
|
||||
Tuple of (summary_content, llm_usage) where llm_usage is LLMUsage object
|
||||
@ -381,8 +392,22 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
raise ValueError("model_name and model_provider_name are required in summary_index_setting")
|
||||
|
||||
# Import default summary prompt
|
||||
is_default_prompt = False
|
||||
if not summary_prompt:
|
||||
summary_prompt = DEFAULT_GENERATOR_SUMMARY_PROMPT
|
||||
is_default_prompt = True
|
||||
|
||||
# Format prompt with document language only for default prompt
|
||||
# Custom prompts are used as-is to avoid interfering with user-defined templates
|
||||
# If document_language is provided, use it; otherwise, use "the same language as the input content"
|
||||
# This is especially important for image-only chunks where text is empty or minimal
|
||||
if is_default_prompt:
|
||||
language_for_prompt = document_language or "the same language as the input content"
|
||||
try:
|
||||
summary_prompt = summary_prompt.format(language=language_for_prompt)
|
||||
except KeyError:
|
||||
# If default prompt doesn't have {language} placeholder, use it as-is
|
||||
pass
|
||||
|
||||
provider_manager = ProviderManager()
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||
|
||||
@ -358,7 +358,11 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
}
|
||||
|
||||
def generate_summary_preview(
|
||||
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
|
||||
self,
|
||||
tenant_id: str,
|
||||
preview_texts: list[PreviewDetail],
|
||||
summary_index_setting: dict,
|
||||
doc_language: str | None = None,
|
||||
) -> list[PreviewDetail]:
|
||||
"""
|
||||
For each parent chunk in preview_texts, concurrently call generate_summary to generate a summary
|
||||
@ -389,6 +393,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
tenant_id=tenant_id,
|
||||
text=preview.content,
|
||||
summary_index_setting=summary_index_setting,
|
||||
document_language=doc_language,
|
||||
)
|
||||
preview.summary = summary
|
||||
else:
|
||||
@ -397,6 +402,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
tenant_id=tenant_id,
|
||||
text=preview.content,
|
||||
summary_index_setting=summary_index_setting,
|
||||
document_language=doc_language,
|
||||
)
|
||||
preview.summary = summary
|
||||
|
||||
|
||||
@ -241,7 +241,11 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
}
|
||||
|
||||
def generate_summary_preview(
|
||||
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
|
||||
self,
|
||||
tenant_id: str,
|
||||
preview_texts: list[PreviewDetail],
|
||||
summary_index_setting: dict,
|
||||
doc_language: str | None = None,
|
||||
) -> list[PreviewDetail]:
|
||||
"""
|
||||
QA model doesn't generate summaries, so this method returns preview_texts unchanged.
|
||||
|
||||
@ -35,6 +35,7 @@ class SchemaRegistry:
|
||||
registry.load_all_versions()
|
||||
|
||||
cls._default_instance = registry
|
||||
return cls._default_instance
|
||||
|
||||
return cls._default_instance
|
||||
|
||||
|
||||
@ -192,16 +192,13 @@ class ToolManager:
|
||||
raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
|
||||
|
||||
if not provider_controller.need_credentials:
|
||||
return cast(
|
||||
BuiltinTool,
|
||||
builtin_tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
credentials={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
),
|
||||
return builtin_tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
credentials={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
)
|
||||
builtin_provider = None
|
||||
if isinstance(provider_controller, PluginToolProviderController):
|
||||
@ -303,18 +300,15 @@ class ToolManager:
|
||||
decrypted_credentials = refreshed_credentials.credentials
|
||||
cache.delete()
|
||||
|
||||
return cast(
|
||||
BuiltinTool,
|
||||
builtin_tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
credentials=dict(decrypted_credentials),
|
||||
credential_type=CredentialType.of(builtin_provider.credential_type),
|
||||
runtime_parameters={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
),
|
||||
return builtin_tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
credentials=dict(decrypted_credentials),
|
||||
credential_type=CredentialType.of(builtin_provider.credential_type),
|
||||
runtime_parameters={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
)
|
||||
|
||||
elif provider_type == ToolProviderType.API:
|
||||
|
||||
@ -7,11 +7,6 @@ from core.workflow.nodes.base.entities import OutputVariableEntity
|
||||
|
||||
|
||||
class WorkflowToolConfigurationUtils:
|
||||
@classmethod
|
||||
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
|
||||
for configuration in configurations:
|
||||
WorkflowToolParameterConfiguration.model_validate(configuration)
|
||||
|
||||
@classmethod
|
||||
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
|
||||
"""
|
||||
|
||||
24
api/core/workflow/entities/graph_config.py
Normal file
24
api/core/workflow/entities/graph_config.py
Normal file
@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
|
||||
from pydantic import TypeAdapter, with_config
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import TypedDict
|
||||
else:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
@with_config(extra="allow")
|
||||
class NodeConfigData(TypedDict):
|
||||
type: str
|
||||
|
||||
|
||||
@with_config(extra="allow")
|
||||
class NodeConfigDict(TypedDict):
|
||||
id: str
|
||||
data: NodeConfigData
|
||||
|
||||
|
||||
NodeConfigDictAdapter = TypeAdapter(NodeConfigDict)
|
||||
@ -5,15 +5,20 @@ from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Protocol, cast, final
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from libs.typing import is_str, is_str_dict
|
||||
from libs.typing import is_str
|
||||
|
||||
from .edge import Edge
|
||||
from .validation import get_graph_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ListNodeConfigDict = TypeAdapter(list[NodeConfigDict])
|
||||
|
||||
|
||||
class NodeFactory(Protocol):
|
||||
"""
|
||||
@ -23,7 +28,7 @@ class NodeFactory(Protocol):
|
||||
allowing for different node creation strategies while maintaining type safety.
|
||||
"""
|
||||
|
||||
def create_node(self, node_config: dict[str, object]) -> Node:
|
||||
def create_node(self, node_config: NodeConfigDict) -> Node:
|
||||
"""
|
||||
Create a Node instance from node configuration data.
|
||||
|
||||
@ -63,28 +68,24 @@ class Graph:
|
||||
self.root_node = root_node
|
||||
|
||||
@classmethod
|
||||
def _parse_node_configs(cls, node_configs: list[dict[str, object]]) -> dict[str, dict[str, object]]:
|
||||
def _parse_node_configs(cls, node_configs: list[NodeConfigDict]) -> dict[str, NodeConfigDict]:
|
||||
"""
|
||||
Parse node configurations and build a mapping of node IDs to configs.
|
||||
|
||||
:param node_configs: list of node configuration dictionaries
|
||||
:return: mapping of node ID to node config
|
||||
"""
|
||||
node_configs_map: dict[str, dict[str, object]] = {}
|
||||
node_configs_map: dict[str, NodeConfigDict] = {}
|
||||
|
||||
for node_config in node_configs:
|
||||
node_id = node_config.get("id")
|
||||
if not node_id or not isinstance(node_id, str):
|
||||
continue
|
||||
|
||||
node_configs_map[node_id] = node_config
|
||||
node_configs_map[node_config["id"]] = node_config
|
||||
|
||||
return node_configs_map
|
||||
|
||||
@classmethod
|
||||
def _find_root_node_id(
|
||||
cls,
|
||||
node_configs_map: Mapping[str, Mapping[str, object]],
|
||||
node_configs_map: Mapping[str, NodeConfigDict],
|
||||
edge_configs: Sequence[Mapping[str, object]],
|
||||
root_node_id: str | None = None,
|
||||
) -> str:
|
||||
@ -113,10 +114,8 @@ class Graph:
|
||||
# Prefer START node if available
|
||||
start_node_id = None
|
||||
for nid in root_candidates:
|
||||
node_data = node_configs_map[nid].get("data")
|
||||
if not is_str_dict(node_data):
|
||||
continue
|
||||
node_type = node_data.get("type")
|
||||
node_data = node_configs_map[nid]["data"]
|
||||
node_type = node_data["type"]
|
||||
if not isinstance(node_type, str):
|
||||
continue
|
||||
if NodeType(node_type).is_start_node:
|
||||
@ -176,7 +175,7 @@ class Graph:
|
||||
@classmethod
|
||||
def _create_node_instances(
|
||||
cls,
|
||||
node_configs_map: dict[str, dict[str, object]],
|
||||
node_configs_map: dict[str, NodeConfigDict],
|
||||
node_factory: NodeFactory,
|
||||
) -> dict[str, Node]:
|
||||
"""
|
||||
@ -303,7 +302,7 @@ class Graph:
|
||||
node_configs = graph_config.get("nodes", [])
|
||||
|
||||
edge_configs = cast(list[dict[str, object]], edge_configs)
|
||||
node_configs = cast(list[dict[str, object]], node_configs)
|
||||
node_configs = _ListNodeConfigDict.validate_python(node_configs)
|
||||
|
||||
if not node_configs:
|
||||
raise ValueError("Graph must have at least one node")
|
||||
|
||||
@ -46,7 +46,6 @@ from .graph_traversal import EdgeProcessor, SkipPropagator
|
||||
from .layers.base import GraphEngineLayer
|
||||
from .orchestration import Dispatcher, ExecutionCoordinator
|
||||
from .protocols.command_channel import CommandChannel
|
||||
from .ready_queue import ReadyQueue
|
||||
from .worker_management import WorkerPool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -90,7 +89,7 @@ class GraphEngine:
|
||||
self._graph_execution.workflow_id = workflow_id
|
||||
|
||||
# === Execution Queues ===
|
||||
self._ready_queue = cast(ReadyQueue, self._graph_runtime_state.ready_queue)
|
||||
self._ready_queue = self._graph_runtime_state.ready_queue
|
||||
|
||||
# Queue for events generated during execution
|
||||
self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
|
||||
|
||||
@ -25,6 +25,7 @@ from core.workflow.graph_events import (
|
||||
)
|
||||
from core.workflow.nodes.base.template import TextSegment, VariableSegment
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.runtime.graph_runtime_state import GraphProtocol
|
||||
|
||||
from .path import Path
|
||||
from .session import ResponseSession
|
||||
@ -81,7 +82,7 @@ class ResponseStreamCoordinator:
|
||||
Ensures ordered streaming of responses based on upstream node outputs and constants.
|
||||
"""
|
||||
|
||||
def __init__(self, variable_pool: "VariablePool", graph: "Graph") -> None:
|
||||
def __init__(self, variable_pool: "VariablePool", graph: GraphProtocol) -> None:
|
||||
"""
|
||||
Initialize coordinator with variable pool.
|
||||
|
||||
|
||||
@ -10,10 +10,10 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.template import Template
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
|
||||
from core.workflow.runtime.graph_runtime_state import NodeProtocol
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -29,21 +29,26 @@ class ResponseSession:
|
||||
index: int = 0 # Current position in the template segments
|
||||
|
||||
@classmethod
|
||||
def from_node(cls, node: Node) -> ResponseSession:
|
||||
def from_node(cls, node: NodeProtocol) -> ResponseSession:
|
||||
"""
|
||||
Create a ResponseSession from an AnswerNode or EndNode.
|
||||
Create a ResponseSession from a response-capable node.
|
||||
|
||||
The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer,
|
||||
but at runtime this must be an `AnswerNode`, `EndNode`, or `KnowledgeIndexNode` that provides:
|
||||
- `id: str`
|
||||
- `get_streaming_template() -> Template`
|
||||
|
||||
Args:
|
||||
node: Must be either an AnswerNode or EndNode instance
|
||||
node: Node from the materialized workflow graph.
|
||||
|
||||
Returns:
|
||||
ResponseSession configured with the node's streaming template
|
||||
|
||||
Raises:
|
||||
TypeError: If node is not an AnswerNode or EndNode
|
||||
TypeError: If node is not a supported response node type.
|
||||
"""
|
||||
if not isinstance(node, AnswerNode | EndNode | KnowledgeIndexNode):
|
||||
raise TypeError
|
||||
raise TypeError("ResponseSession.from_node only supports AnswerNode, EndNode, or KnowledgeIndexNode")
|
||||
return cls(
|
||||
node_id=node.id,
|
||||
template=node.get_streaming_template(),
|
||||
|
||||
@ -205,32 +205,33 @@ class AgentNode(Node[AgentNodeData]):
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
agent_input = node_data.agent_parameters[parameter_name]
|
||||
if agent_input.type == "variable":
|
||||
variable = variable_pool.get(agent_input.value) # type: ignore
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(agent_input.value))
|
||||
parameter_value = variable.value
|
||||
elif agent_input.type in {"mixed", "constant"}:
|
||||
# variable_pool.convert_template expects a string template,
|
||||
# but if passing a dict, convert to JSON string first before rendering
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
|
||||
else:
|
||||
match agent_input.type:
|
||||
case "variable":
|
||||
variable = variable_pool.get(agent_input.value) # type: ignore
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(agent_input.value))
|
||||
parameter_value = variable.value
|
||||
case "mixed" | "constant":
|
||||
# variable_pool.convert_template expects a string template,
|
||||
# but if passing a dict, convert to JSON string first before rendering
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
|
||||
else:
|
||||
parameter_value = str(agent_input.value)
|
||||
except TypeError:
|
||||
parameter_value = str(agent_input.value)
|
||||
except TypeError:
|
||||
parameter_value = str(agent_input.value)
|
||||
segment_group = variable_pool.convert_template(parameter_value)
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
# variable_pool.convert_template returns a string,
|
||||
# so we need to convert it back to a dictionary
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.loads(parameter_value)
|
||||
except json.JSONDecodeError:
|
||||
parameter_value = parameter_value
|
||||
else:
|
||||
raise AgentInputTypeError(agent_input.type)
|
||||
segment_group = variable_pool.convert_template(parameter_value)
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
# variable_pool.convert_template returns a string,
|
||||
# so we need to convert it back to a dictionary
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.loads(parameter_value)
|
||||
except json.JSONDecodeError:
|
||||
parameter_value = parameter_value
|
||||
case _:
|
||||
raise AgentInputTypeError(agent_input.type)
|
||||
value = parameter_value
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
@ -387,12 +388,13 @@ class AgentNode(Node[AgentNodeData]):
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in typed_node_data.agent_parameters:
|
||||
input = typed_node_data.agent_parameters[parameter_name]
|
||||
if input.type in ["mixed", "constant"]:
|
||||
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == "variable":
|
||||
result[parameter_name] = input.value
|
||||
match input.type:
|
||||
case "mixed" | "constant":
|
||||
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
case "variable":
|
||||
result[parameter_name] = input.value
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
|
||||
@ -115,7 +115,7 @@ class DefaultValue(BaseModel):
|
||||
@model_validator(mode="after")
|
||||
def validate_value_type(self) -> DefaultValue:
|
||||
# Type validation configuration
|
||||
type_validators = {
|
||||
type_validators: dict[DefaultValueType, dict[str, Any]] = {
|
||||
DefaultValueType.STRING: {
|
||||
"type": str,
|
||||
"converter": lambda x: x,
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Annotated, Literal, Self
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import AfterValidator, BaseModel
|
||||
|
||||
@ -34,7 +34,7 @@ class CodeNodeData(BaseNodeData):
|
||||
|
||||
class Output(BaseModel):
|
||||
type: Annotated[SegmentType, AfterValidator(_validate_type)]
|
||||
children: dict[str, Self] | None = None
|
||||
children: dict[str, "CodeNodeData.Output"] | None = None
|
||||
|
||||
class Dependency(BaseModel):
|
||||
name: str
|
||||
|
||||
@ -69,11 +69,13 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
if datasource_type is None:
|
||||
raise DatasourceNodeError("Datasource type is not set")
|
||||
|
||||
datasource_type = DatasourceProviderType.value_of(datasource_type)
|
||||
|
||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
|
||||
datasource_name=node_data.datasource_name or "",
|
||||
tenant_id=self.tenant_id,
|
||||
datasource_type=DatasourceProviderType.value_of(datasource_type),
|
||||
datasource_type=datasource_type,
|
||||
)
|
||||
datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)
|
||||
|
||||
@ -268,15 +270,18 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
if typed_node_data.datasource_parameters:
|
||||
for parameter_name in typed_node_data.datasource_parameters:
|
||||
input = typed_node_data.datasource_parameters[parameter_name]
|
||||
if input.type == "mixed":
|
||||
assert isinstance(input.value, str)
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == "variable":
|
||||
result[parameter_name] = input.value
|
||||
elif input.type == "constant":
|
||||
pass
|
||||
match input.type:
|
||||
case "mixed":
|
||||
assert isinstance(input.value, str)
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
case "variable":
|
||||
result[parameter_name] = input.value
|
||||
case "constant":
|
||||
pass
|
||||
case None:
|
||||
pass
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
@ -306,99 +311,107 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
if message.type in {
|
||||
DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
DatasourceMessage.MessageType.BINARY_LINK,
|
||||
DatasourceMessage.MessageType.IMAGE,
|
||||
}:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
match message.type:
|
||||
case (
|
||||
DatasourceMessage.MessageType.IMAGE_LINK
|
||||
| DatasourceMessage.MessageType.BINARY_LINK
|
||||
| DatasourceMessage.MessageType.IMAGE
|
||||
):
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
url = message.message.text
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
datasource_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
datasource_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
elif message.type == DatasourceMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
|
||||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk=message.message.text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, DatasourceMessage.JsonMessage)
|
||||
json.append(message.message.json_object)
|
||||
elif message.type == DatasourceMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk=stream_text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, DatasourceMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
files.append(file)
|
||||
case DatasourceMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
|
||||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
)
|
||||
case DatasourceMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, variable_name],
|
||||
chunk=variable_value,
|
||||
selector=[self._node_id, "text"],
|
||||
chunk=message.message.text,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
elif message.type == DatasourceMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
files.append(message.meta["file"])
|
||||
case DatasourceMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, DatasourceMessage.JsonMessage)
|
||||
json.append(message.message.json_object)
|
||||
case DatasourceMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk=stream_text,
|
||||
is_final=False,
|
||||
)
|
||||
case DatasourceMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, DatasourceMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, variable_name],
|
||||
chunk=variable_value,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
case DatasourceMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
files.append(message.meta["file"])
|
||||
case (
|
||||
DatasourceMessage.MessageType.BLOB_CHUNK
|
||||
| DatasourceMessage.MessageType.LOG
|
||||
| DatasourceMessage.MessageType.RETRIEVER_RESOURCES
|
||||
):
|
||||
pass
|
||||
|
||||
# mark the end of the stream
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
|
||||
@ -2,7 +2,7 @@ import base64
|
||||
import json
|
||||
import secrets
|
||||
import string
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Callable, Mapping
|
||||
from copy import deepcopy
|
||||
from typing import Any, Literal
|
||||
from urllib.parse import urlencode, urlparse
|
||||
@ -11,9 +11,9 @@ import httpx
|
||||
from json_repair import repair_json
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import file_manager
|
||||
from core.file.enums import FileTransferMethod
|
||||
from core.helper import ssrf_proxy
|
||||
from core.file.file_manager import file_manager as default_file_manager
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
@ -79,8 +79,8 @@ class Executor:
|
||||
timeout: HttpRequestNodeTimeout,
|
||||
variable_pool: VariablePool,
|
||||
max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
|
||||
http_client: HttpClientProtocol = ssrf_proxy,
|
||||
file_manager: FileManagerProtocol = file_manager,
|
||||
http_client: HttpClientProtocol | None = None,
|
||||
file_manager: FileManagerProtocol | None = None,
|
||||
):
|
||||
# If authorization API key is present, convert the API key using the variable pool
|
||||
if node_data.authorization.type == "api-key":
|
||||
@ -107,8 +107,8 @@ class Executor:
|
||||
self.data = None
|
||||
self.json = None
|
||||
self.max_retries = max_retries
|
||||
self._http_client = http_client
|
||||
self._file_manager = file_manager
|
||||
self._http_client = http_client or ssrf_proxy
|
||||
self._file_manager = file_manager or default_file_manager
|
||||
|
||||
# init template
|
||||
self.variable_pool = variable_pool
|
||||
@ -336,7 +336,7 @@ class Executor:
|
||||
"""
|
||||
do http request depending on api bundle
|
||||
"""
|
||||
_METHOD_MAP = {
|
||||
_METHOD_MAP: dict[str, Callable[..., httpx.Response]] = {
|
||||
"get": self._http_client.get,
|
||||
"head": self._http_client.head,
|
||||
"post": self._http_client.post,
|
||||
@ -348,7 +348,7 @@ class Executor:
|
||||
if method_lc not in _METHOD_MAP:
|
||||
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
|
||||
|
||||
request_args = {
|
||||
request_args: dict[str, Any] = {
|
||||
"data": self.data,
|
||||
"files": self.files,
|
||||
"json": self.json,
|
||||
@ -361,14 +361,13 @@ class Executor:
|
||||
}
|
||||
# request_args = {k: v for k, v in request_args.items() if v is not None}
|
||||
try:
|
||||
response: httpx.Response = _METHOD_MAP[method_lc](
|
||||
response = _METHOD_MAP[method_lc](
|
||||
url=self.url,
|
||||
**request_args,
|
||||
max_retries=self.max_retries,
|
||||
)
|
||||
except (self._http_client.max_retries_exceeded_error, self._http_client.request_error) as e:
|
||||
raise HttpRequestNodeError(str(e)) from e
|
||||
# FIXME: fix type ignore, this maybe httpx type issue
|
||||
return response
|
||||
|
||||
def invoke(self) -> Response:
|
||||
|
||||
@ -4,8 +4,9 @@ from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import File, FileTransferMethod, file_manager
|
||||
from core.helper import ssrf_proxy
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.file.file_manager import file_manager as default_file_manager
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.variables.segments import ArrayFileSegment
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
@ -47,9 +48,9 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
http_client: HttpClientProtocol = ssrf_proxy,
|
||||
http_client: HttpClientProtocol | None = None,
|
||||
tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
|
||||
file_manager: FileManagerProtocol = file_manager,
|
||||
file_manager: FileManagerProtocol | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
@ -57,9 +58,9 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._http_client = http_client
|
||||
self._http_client = http_client or ssrf_proxy
|
||||
self._tool_file_manager_factory = tool_file_manager_factory
|
||||
self._file_manager = file_manager
|
||||
self._file_manager = file_manager or default_file_manager
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
|
||||
@ -397,7 +397,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
return outputs
|
||||
|
||||
# Check if all non-None outputs are lists
|
||||
non_none_outputs = [output for output in outputs if output is not None]
|
||||
non_none_outputs: list[object] = [output for output in outputs if output is not None]
|
||||
if not non_none_outputs:
|
||||
return outputs
|
||||
|
||||
|
||||
@ -78,12 +78,21 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
indexing_technique = node_data.indexing_technique or dataset.indexing_technique
|
||||
summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting
|
||||
|
||||
# Try to get document language if document_id is available
|
||||
doc_language = None
|
||||
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
|
||||
if document_id:
|
||||
document = db.session.query(Document).filter_by(id=document_id.value).first()
|
||||
if document and document.doc_language:
|
||||
doc_language = document.doc_language
|
||||
|
||||
outputs = self._get_preview_output_with_summaries(
|
||||
node_data.chunk_structure,
|
||||
chunks,
|
||||
dataset=dataset,
|
||||
indexing_technique=indexing_technique,
|
||||
summary_index_setting=summary_index_setting,
|
||||
doc_language=doc_language,
|
||||
)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
@ -315,6 +324,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
dataset: Dataset,
|
||||
indexing_technique: str | None = None,
|
||||
summary_index_setting: dict | None = None,
|
||||
doc_language: str | None = None,
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Generate preview output with summaries for chunks in preview mode.
|
||||
@ -326,6 +336,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
dataset: Dataset object (for tenant_id)
|
||||
indexing_technique: Indexing technique from node config or dataset
|
||||
summary_index_setting: Summary index setting from node config or dataset
|
||||
doc_language: Optional document language to ensure summary is generated in the correct language
|
||||
"""
|
||||
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
|
||||
preview_output = index_processor.format_preview(chunks)
|
||||
@ -365,6 +376,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
tenant_id=dataset.tenant_id,
|
||||
text=preview_item["content"],
|
||||
summary_index_setting=summary_index_setting,
|
||||
document_language=doc_language,
|
||||
)
|
||||
if summary:
|
||||
preview_item["summary"] = summary
|
||||
@ -374,6 +386,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
tenant_id=dataset.tenant_id,
|
||||
text=preview_item["content"],
|
||||
summary_index_setting=summary_index_setting,
|
||||
document_language=doc_language,
|
||||
)
|
||||
if summary:
|
||||
preview_item["summary"] = summary
|
||||
|
||||
@ -303,33 +303,34 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
|
||||
if node_data.multiple_retrieval_config is None:
|
||||
raise ValueError("multiple_retrieval_config is required")
|
||||
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
|
||||
if node_data.multiple_retrieval_config.reranking_model:
|
||||
reranking_model = {
|
||||
"reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
|
||||
"reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
|
||||
}
|
||||
else:
|
||||
match node_data.multiple_retrieval_config.reranking_mode:
|
||||
case "reranking_model":
|
||||
if node_data.multiple_retrieval_config.reranking_model:
|
||||
reranking_model = {
|
||||
"reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
|
||||
"reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
|
||||
}
|
||||
else:
|
||||
reranking_model = None
|
||||
weights = None
|
||||
case "weighted_score":
|
||||
if node_data.multiple_retrieval_config.weights is None:
|
||||
raise ValueError("weights is required")
|
||||
reranking_model = None
|
||||
weights = None
|
||||
elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score":
|
||||
if node_data.multiple_retrieval_config.weights is None:
|
||||
raise ValueError("weights is required")
|
||||
reranking_model = None
|
||||
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
|
||||
weights = {
|
||||
"vector_setting": {
|
||||
"vector_weight": vector_setting.vector_weight,
|
||||
"embedding_provider_name": vector_setting.embedding_provider_name,
|
||||
"embedding_model_name": vector_setting.embedding_model_name,
|
||||
},
|
||||
"keyword_setting": {
|
||||
"keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
|
||||
},
|
||||
}
|
||||
else:
|
||||
reranking_model = None
|
||||
weights = None
|
||||
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
|
||||
weights = {
|
||||
"vector_setting": {
|
||||
"vector_weight": vector_setting.vector_weight,
|
||||
"embedding_provider_name": vector_setting.embedding_provider_name,
|
||||
"embedding_model_name": vector_setting.embedding_model_name,
|
||||
},
|
||||
"keyword_setting": {
|
||||
"keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
|
||||
},
|
||||
}
|
||||
case _:
|
||||
reranking_model = None
|
||||
weights = None
|
||||
all_documents = dataset_retrieval.multiple_retrieve(
|
||||
app_id=self.app_id,
|
||||
tenant_id=self.tenant_id,
|
||||
@ -453,73 +454,74 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
)
|
||||
filters: list[Any] = []
|
||||
metadata_condition = None
|
||||
if node_data.metadata_filtering_mode == "disabled":
|
||||
return None, None, usage
|
||||
elif node_data.metadata_filtering_mode == "automatic":
|
||||
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
|
||||
dataset_ids, query, node_data
|
||||
)
|
||||
usage = self._merge_usage(usage, automatic_usage)
|
||||
if automatic_metadata_filters:
|
||||
conditions = []
|
||||
for sequence, filter in enumerate(automatic_metadata_filters):
|
||||
DatasetRetrieval.process_metadata_filter_func(
|
||||
sequence,
|
||||
filter.get("condition", ""),
|
||||
filter.get("metadata_name", ""),
|
||||
filter.get("value"),
|
||||
filters,
|
||||
)
|
||||
conditions.append(
|
||||
Condition(
|
||||
name=filter.get("metadata_name"), # type: ignore
|
||||
comparison_operator=filter.get("condition"), # type: ignore
|
||||
value=filter.get("value"),
|
||||
)
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
logical_operator=node_data.metadata_filtering_conditions.logical_operator
|
||||
if node_data.metadata_filtering_conditions
|
||||
else "or",
|
||||
conditions=conditions,
|
||||
match node_data.metadata_filtering_mode:
|
||||
case "disabled":
|
||||
return None, None, usage
|
||||
case "automatic":
|
||||
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
|
||||
dataset_ids, query, node_data
|
||||
)
|
||||
elif node_data.metadata_filtering_mode == "manual":
|
||||
if node_data.metadata_filtering_conditions:
|
||||
conditions = []
|
||||
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
|
||||
metadata_name = condition.name
|
||||
expected_value = condition.value
|
||||
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
|
||||
if isinstance(expected_value, str):
|
||||
expected_value = self.graph_runtime_state.variable_pool.convert_template(
|
||||
expected_value
|
||||
).value[0]
|
||||
if expected_value.value_type in {"number", "integer", "float"}:
|
||||
expected_value = expected_value.value
|
||||
elif expected_value.value_type == "string":
|
||||
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
|
||||
else:
|
||||
raise ValueError("Invalid expected metadata value type")
|
||||
conditions.append(
|
||||
Condition(
|
||||
name=metadata_name,
|
||||
comparison_operator=condition.comparison_operator,
|
||||
value=expected_value,
|
||||
usage = self._merge_usage(usage, automatic_usage)
|
||||
if automatic_metadata_filters:
|
||||
conditions = []
|
||||
for sequence, filter in enumerate(automatic_metadata_filters):
|
||||
DatasetRetrieval.process_metadata_filter_func(
|
||||
sequence,
|
||||
filter.get("condition", ""),
|
||||
filter.get("metadata_name", ""),
|
||||
filter.get("value"),
|
||||
filters,
|
||||
)
|
||||
conditions.append(
|
||||
Condition(
|
||||
name=filter.get("metadata_name"), # type: ignore
|
||||
comparison_operator=filter.get("condition"), # type: ignore
|
||||
value=filter.get("value"),
|
||||
)
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
logical_operator=node_data.metadata_filtering_conditions.logical_operator
|
||||
if node_data.metadata_filtering_conditions
|
||||
else "or",
|
||||
conditions=conditions,
|
||||
)
|
||||
filters = DatasetRetrieval.process_metadata_filter_func(
|
||||
sequence,
|
||||
condition.comparison_operator,
|
||||
metadata_name,
|
||||
expected_value,
|
||||
filters,
|
||||
case "manual":
|
||||
if node_data.metadata_filtering_conditions:
|
||||
conditions = []
|
||||
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
|
||||
metadata_name = condition.name
|
||||
expected_value = condition.value
|
||||
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
|
||||
if isinstance(expected_value, str):
|
||||
expected_value = self.graph_runtime_state.variable_pool.convert_template(
|
||||
expected_value
|
||||
).value[0]
|
||||
if expected_value.value_type in {"number", "integer", "float"}:
|
||||
expected_value = expected_value.value
|
||||
elif expected_value.value_type == "string":
|
||||
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
|
||||
else:
|
||||
raise ValueError("Invalid expected metadata value type")
|
||||
conditions.append(
|
||||
Condition(
|
||||
name=metadata_name,
|
||||
comparison_operator=condition.comparison_operator,
|
||||
value=expected_value,
|
||||
)
|
||||
)
|
||||
filters = DatasetRetrieval.process_metadata_filter_func(
|
||||
sequence,
|
||||
condition.comparison_operator,
|
||||
metadata_name,
|
||||
expected_value,
|
||||
filters,
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
|
||||
conditions=conditions,
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
|
||||
conditions=conditions,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid metadata filtering mode")
|
||||
case _:
|
||||
raise ValueError("Invalid metadata filtering mode")
|
||||
if filters:
|
||||
if (
|
||||
node_data.metadata_filtering_conditions
|
||||
|
||||
@ -196,13 +196,13 @@ def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
|
||||
case "name":
|
||||
return lambda x: x.filename or ""
|
||||
case "type":
|
||||
return lambda x: x.type
|
||||
return lambda x: str(x.type)
|
||||
case "extension":
|
||||
return lambda x: x.extension or ""
|
||||
case "mime_type":
|
||||
return lambda x: x.mime_type or ""
|
||||
case "transfer_method":
|
||||
return lambda x: x.transfer_method
|
||||
return lambda x: str(x.transfer_method)
|
||||
case "url":
|
||||
return lambda x: x.remote_url or ""
|
||||
case "related_id":
|
||||
@ -276,7 +276,6 @@ def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Calla
|
||||
|
||||
|
||||
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
|
||||
extract_func: Callable[[File], Any]
|
||||
if key in {"name", "extension", "mime_type", "url", "related_id"} and isinstance(value, str):
|
||||
extract_func = _get_file_extract_string_func(key=key)
|
||||
return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x))
|
||||
@ -284,8 +283,8 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str
|
||||
extract_func = _get_file_extract_string_func(key=key)
|
||||
return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x))
|
||||
elif key == "size" and isinstance(value, str):
|
||||
extract_func = _get_file_extract_number_func(key=key)
|
||||
return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x))
|
||||
extract_number = _get_file_extract_number_func(key=key)
|
||||
return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_number(x))
|
||||
else:
|
||||
raise InvalidKeyError(f"Invalid key: {key}")
|
||||
|
||||
|
||||
@ -1288,18 +1288,16 @@ class LLMNode(Node[LLMNodeData]):
|
||||
# Insert histories into the prompt
|
||||
prompt_content = prompt_messages[0].content
|
||||
# For issue #11247 - Check if prompt content is a string or a list
|
||||
prompt_content_type = type(prompt_content)
|
||||
if prompt_content_type == str:
|
||||
if isinstance(prompt_content, str):
|
||||
prompt_content = str(prompt_content)
|
||||
if "#histories#" in prompt_content:
|
||||
prompt_content = prompt_content.replace("#histories#", memory_text)
|
||||
else:
|
||||
prompt_content = memory_text + "\n" + prompt_content
|
||||
prompt_messages[0].content = prompt_content
|
||||
elif prompt_content_type == list:
|
||||
prompt_content = prompt_content if isinstance(prompt_content, list) else []
|
||||
elif isinstance(prompt_content, list):
|
||||
for content_item in prompt_content:
|
||||
if content_item.type == PromptMessageContentType.TEXT:
|
||||
if isinstance(content_item, TextPromptMessageContent):
|
||||
if "#histories#" in content_item.data:
|
||||
content_item.data = content_item.data.replace("#histories#", memory_text)
|
||||
else:
|
||||
@ -1309,13 +1307,12 @@ class LLMNode(Node[LLMNodeData]):
|
||||
|
||||
# Add current query to the prompt message
|
||||
if sys_query:
|
||||
if prompt_content_type == str:
|
||||
if isinstance(prompt_content, str):
|
||||
prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
|
||||
prompt_messages[0].content = prompt_content
|
||||
elif prompt_content_type == list:
|
||||
prompt_content = prompt_content if isinstance(prompt_content, list) else []
|
||||
elif isinstance(prompt_content, list):
|
||||
for content_item in prompt_content:
|
||||
if content_item.type == PromptMessageContentType.TEXT:
|
||||
if isinstance(content_item, TextPromptMessageContent):
|
||||
content_item.data = sys_query + "\n" + content_item.data
|
||||
else:
|
||||
raise ValueError("Invalid prompt content type")
|
||||
@ -1481,13 +1478,14 @@ class LLMNode(Node[LLMNodeData]):
|
||||
if typed_node_data.prompt_config:
|
||||
enable_jinja = False
|
||||
|
||||
if isinstance(prompt_template, list):
|
||||
for item in prompt_template:
|
||||
if isinstance(item, LLMNodeChatModelMessage) and item.edition_type == "jinja2":
|
||||
if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
|
||||
if prompt_template.edition_type == "jinja2":
|
||||
enable_jinja = True
|
||||
else:
|
||||
for prompt in prompt_template:
|
||||
if prompt.edition_type == "jinja2":
|
||||
enable_jinja = True
|
||||
break
|
||||
else:
|
||||
enable_jinja = True
|
||||
|
||||
if enable_jinja:
|
||||
for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Protocol
|
||||
from typing import Any, Protocol
|
||||
|
||||
import httpx
|
||||
|
||||
@ -12,17 +12,17 @@ class HttpClientProtocol(Protocol):
|
||||
@property
|
||||
def request_error(self) -> type[Exception]: ...
|
||||
|
||||
def get(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
def get(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
|
||||
|
||||
def head(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
def head(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
|
||||
|
||||
def post(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
def post(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
|
||||
|
||||
def put(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
def put(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
|
||||
|
||||
def delete(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
def delete(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
|
||||
|
||||
def patch(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
def patch(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
|
||||
|
||||
|
||||
class FileManagerProtocol(Protocol):
|
||||
|
||||
@ -513,25 +513,26 @@ class ToolNode(Node[ToolNodeData]):
|
||||
result: dict[str, Sequence[str]] = {}
|
||||
for parameter_name in typed_node_data.tool_parameters:
|
||||
input = typed_node_data.tool_parameters[parameter_name]
|
||||
if input.type == "mixed":
|
||||
assert isinstance(input.value, str)
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == "variable":
|
||||
if isinstance(input.value, list):
|
||||
selector_key = ".".join(input.value)
|
||||
result[f"#{selector_key}#"] = input.value
|
||||
elif input.type == "nested_node":
|
||||
# Nested node type: extract variable selector from nested_node_config
|
||||
# The full selector is extractor_node_id + output_selector
|
||||
if input.nested_node_config is not None:
|
||||
config = input.nested_node_config
|
||||
full_selector = [config.extractor_node_id] + list(config.output_selector)
|
||||
selector_key = ".".join(full_selector)
|
||||
result[f"#{selector_key}#"] = full_selector
|
||||
elif input.type == "constant":
|
||||
pass
|
||||
match input.type:
|
||||
case "mixed":
|
||||
assert isinstance(input.value, str)
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
case "variable":
|
||||
if isinstance(input.value, list):
|
||||
selector_key = ".".join(input.value)
|
||||
result[f"#{selector_key}#"] = input.value
|
||||
case "nested_node":
|
||||
# Nested node type: extract variable selector from nested_node_config
|
||||
# The full selector is extractor_node_id + output_selector
|
||||
if input.nested_node_config is not None:
|
||||
config = input.nested_node_config
|
||||
full_selector = [config.extractor_node_id] + list(config.output_selector)
|
||||
selector_key = ".".join(full_selector)
|
||||
result[f"#{selector_key}#"] = full_selector
|
||||
case "constant":
|
||||
pass
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
|
||||
@ -6,13 +6,14 @@ import threading
|
||||
from collections.abc import Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol
|
||||
from typing import Any, ClassVar, Protocol
|
||||
|
||||
from pydantic.json import pydantic_encoder
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.sandbox.sandbox import Sandbox
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
|
||||
from core.workflow.runtime.variable_pool import VariablePool
|
||||
|
||||
|
||||
@ -104,14 +105,33 @@ class ResponseStreamCoordinatorProtocol(Protocol):
|
||||
...
|
||||
|
||||
|
||||
class NodeProtocol(Protocol):
|
||||
"""Structural interface for graph nodes."""
|
||||
|
||||
id: str
|
||||
state: NodeState
|
||||
execution_type: NodeExecutionType
|
||||
node_type: ClassVar[NodeType]
|
||||
|
||||
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: ...
|
||||
|
||||
|
||||
class EdgeProtocol(Protocol):
|
||||
id: str
|
||||
state: NodeState
|
||||
tail: str
|
||||
head: str
|
||||
source_handle: str
|
||||
|
||||
|
||||
class GraphProtocol(Protocol):
|
||||
"""Structural interface required from graph instances attached to the runtime state."""
|
||||
|
||||
nodes: Mapping[str, object]
|
||||
edges: Mapping[str, object]
|
||||
root_node: object
|
||||
nodes: Mapping[str, NodeProtocol]
|
||||
edges: Mapping[str, EdgeProtocol]
|
||||
root_node: NodeProtocol
|
||||
|
||||
def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ...
|
||||
def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ...
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
|
||||
@ -146,11 +146,11 @@ class WorkflowEntry:
|
||||
:param user_inputs: user inputs
|
||||
:return:
|
||||
"""
|
||||
node_config = dict(workflow.get_node_config_by_id(node_id))
|
||||
node_config_data = node_config.get("data", {})
|
||||
node_config = workflow.get_node_config_by_id(node_id)
|
||||
node_config_data = node_config["data"]
|
||||
|
||||
# Get node type
|
||||
node_type = NodeType(node_config_data.get("type"))
|
||||
node_type = NodeType(node_config_data["type"])
|
||||
|
||||
# init graph init params and runtime state
|
||||
graph_init_params = GraphInitParams(
|
||||
|
||||
@ -27,9 +27,11 @@ def init_app(app: DifyApp) -> None:
|
||||
)
|
||||
|
||||
# Ensure route decorators are evaluated.
|
||||
import controllers.console.init_validate as init_validate_module
|
||||
import controllers.console.ping as ping_module
|
||||
from controllers.console import remote_files, setup
|
||||
|
||||
_ = init_validate_module
|
||||
_ = ping_module
|
||||
_ = remote_files
|
||||
_ = setup
|
||||
|
||||
@ -1,36 +1,69 @@
|
||||
from flask_restx import Namespace, fields
|
||||
from __future__ import annotations
|
||||
|
||||
from libs.helper import TimestampField
|
||||
from datetime import datetime
|
||||
|
||||
annotation_fields = {
|
||||
"id": fields.String,
|
||||
"question": fields.String,
|
||||
"answer": fields.Raw(attribute="content"),
|
||||
"hit_count": fields.Integer,
|
||||
"created_at": TimestampField,
|
||||
# 'account': fields.Nested(simple_account_fields, allow_null=True)
|
||||
}
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
def build_annotation_model(api_or_ns: Namespace):
|
||||
"""Build the annotation model for the API or Namespace."""
|
||||
return api_or_ns.model("Annotation", annotation_fields)
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
annotation_list_fields = {
|
||||
"data": fields.List(fields.Nested(annotation_fields)),
|
||||
}
|
||||
class ResponseModel(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
extra="ignore",
|
||||
populate_by_name=True,
|
||||
serialize_by_alias=True,
|
||||
protected_namespaces=(),
|
||||
)
|
||||
|
||||
annotation_hit_history_fields = {
|
||||
"id": fields.String,
|
||||
"source": fields.String,
|
||||
"score": fields.Float,
|
||||
"question": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"match": fields.String(attribute="annotation_question"),
|
||||
"response": fields.String(attribute="annotation_content"),
|
||||
}
|
||||
|
||||
annotation_hit_history_list_fields = {
|
||||
"data": fields.List(fields.Nested(annotation_hit_history_fields)),
|
||||
}
|
||||
class Annotation(ResponseModel):
|
||||
id: str
|
||||
question: str | None = None
|
||||
answer: str | None = Field(default=None, validation_alias="content")
|
||||
hit_count: int | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class AnnotationList(ResponseModel):
|
||||
data: list[Annotation]
|
||||
has_more: bool
|
||||
limit: int
|
||||
total: int
|
||||
page: int
|
||||
|
||||
|
||||
class AnnotationExportList(ResponseModel):
|
||||
data: list[Annotation]
|
||||
|
||||
|
||||
class AnnotationHitHistory(ResponseModel):
|
||||
id: str
|
||||
source: str | None = None
|
||||
score: float | None = None
|
||||
question: str | None = None
|
||||
created_at: int | None = None
|
||||
match: str | None = Field(default=None, validation_alias="annotation_question")
|
||||
response: str | None = Field(default=None, validation_alias="annotation_content")
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class AnnotationHitHistoryList(ResponseModel):
|
||||
data: list[AnnotationHitHistory]
|
||||
has_more: bool
|
||||
limit: int
|
||||
total: int
|
||||
page: int
|
||||
|
||||
@ -1,4 +1,7 @@
|
||||
from flask_restx import Namespace, fields
|
||||
from __future__ import annotations
|
||||
|
||||
from flask_restx import fields
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
simple_end_user_fields = {
|
||||
"id": fields.String,
|
||||
@ -8,5 +11,18 @@ simple_end_user_fields = {
|
||||
}
|
||||
|
||||
|
||||
def build_simple_end_user_model(api_or_ns: Namespace):
|
||||
return api_or_ns.model("SimpleEndUser", simple_end_user_fields)
|
||||
class ResponseModel(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
extra="ignore",
|
||||
populate_by_name=True,
|
||||
serialize_by_alias=True,
|
||||
protected_namespaces=(),
|
||||
)
|
||||
|
||||
|
||||
class SimpleEndUser(ResponseModel):
|
||||
id: str
|
||||
type: str
|
||||
is_anonymous: bool
|
||||
session_id: str | None = None
|
||||
|
||||
@ -1,6 +1,11 @@
|
||||
from flask_restx import Namespace, fields
|
||||
from __future__ import annotations
|
||||
|
||||
from libs.helper import AvatarUrlField, TimestampField
|
||||
from datetime import datetime
|
||||
|
||||
from flask_restx import fields
|
||||
from pydantic import BaseModel, ConfigDict, computed_field, field_validator
|
||||
|
||||
from core.file import helpers as file_helpers
|
||||
|
||||
simple_account_fields = {
|
||||
"id": fields.String,
|
||||
@ -9,36 +14,78 @@ simple_account_fields = {
|
||||
}
|
||||
|
||||
|
||||
def build_simple_account_model(api_or_ns: Namespace):
|
||||
return api_or_ns.model("SimpleAccount", simple_account_fields)
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
account_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"avatar": fields.String,
|
||||
"avatar_url": AvatarUrlField,
|
||||
"email": fields.String,
|
||||
"is_password_set": fields.Boolean,
|
||||
"interface_language": fields.String,
|
||||
"interface_theme": fields.String,
|
||||
"timezone": fields.String,
|
||||
"last_login_at": TimestampField,
|
||||
"last_login_ip": fields.String,
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
def _build_avatar_url(avatar: str | None) -> str | None:
|
||||
if avatar is None:
|
||||
return None
|
||||
if avatar.startswith(("http://", "https://")):
|
||||
return avatar
|
||||
return file_helpers.get_signed_file_url(avatar)
|
||||
|
||||
account_with_role_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"avatar": fields.String,
|
||||
"avatar_url": AvatarUrlField,
|
||||
"email": fields.String,
|
||||
"last_login_at": TimestampField,
|
||||
"last_active_at": TimestampField,
|
||||
"created_at": TimestampField,
|
||||
"role": fields.String,
|
||||
"status": fields.String,
|
||||
}
|
||||
|
||||
account_with_role_list_fields = {"accounts": fields.List(fields.Nested(account_with_role_fields))}
|
||||
class ResponseModel(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
extra="ignore",
|
||||
populate_by_name=True,
|
||||
serialize_by_alias=True,
|
||||
protected_namespaces=(),
|
||||
)
|
||||
|
||||
|
||||
class SimpleAccount(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
|
||||
|
||||
class _AccountAvatar(ResponseModel):
|
||||
avatar: str | None = None
|
||||
|
||||
@computed_field(return_type=str | None) # type: ignore[prop-decorator]
|
||||
@property
|
||||
def avatar_url(self) -> str | None:
|
||||
return _build_avatar_url(self.avatar)
|
||||
|
||||
|
||||
class Account(_AccountAvatar):
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
is_password_set: bool
|
||||
interface_language: str | None = None
|
||||
interface_theme: str | None = None
|
||||
timezone: str | None = None
|
||||
last_login_at: int | None = None
|
||||
last_login_ip: str | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("last_login_at", "created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class AccountWithRole(_AccountAvatar):
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
last_login_at: int | None = None
|
||||
last_active_at: int | None = None
|
||||
created_at: int | None = None
|
||||
role: str
|
||||
status: str
|
||||
|
||||
@field_validator("last_login_at", "last_active_at", "created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class AccountWithRoleList(ResponseModel):
|
||||
accounts: list[AccountWithRole]
|
||||
|
||||
@ -1,12 +1,20 @@
|
||||
from flask_restx import Namespace, fields
|
||||
from __future__ import annotations
|
||||
|
||||
dataset_tag_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"type": fields.String,
|
||||
"binding_count": fields.String,
|
||||
}
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
def build_dataset_tag_fields(api_or_ns: Namespace):
|
||||
return api_or_ns.model("DataSetTag", dataset_tag_fields)
|
||||
class ResponseModel(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
extra="ignore",
|
||||
populate_by_name=True,
|
||||
serialize_by_alias=True,
|
||||
protected_namespaces=(),
|
||||
)
|
||||
|
||||
|
||||
class DataSetTag(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
type: str
|
||||
binding_count: str | None = None
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from flask_restx import Namespace, fields
|
||||
|
||||
from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields
|
||||
from fields.member_fields import build_simple_account_model, simple_account_fields
|
||||
from fields.end_user_fields import simple_end_user_fields
|
||||
from fields.member_fields import simple_account_fields
|
||||
from fields.workflow_run_fields import (
|
||||
build_workflow_run_for_archived_log_model,
|
||||
build_workflow_run_for_log_model,
|
||||
@ -25,17 +25,9 @@ workflow_app_log_partial_fields = {
|
||||
def build_workflow_app_log_partial_model(api_or_ns: Namespace):
|
||||
"""Build the workflow app log partial model for the API or Namespace."""
|
||||
workflow_run_model = build_workflow_run_for_log_model(api_or_ns)
|
||||
simple_account_model = build_simple_account_model(api_or_ns)
|
||||
simple_end_user_model = build_simple_end_user_model(api_or_ns)
|
||||
|
||||
copied_fields = workflow_app_log_partial_fields.copy()
|
||||
copied_fields["workflow_run"] = fields.Nested(workflow_run_model, attribute="workflow_run", allow_null=True)
|
||||
copied_fields["created_by_account"] = fields.Nested(
|
||||
simple_account_model, attribute="created_by_account", allow_null=True
|
||||
)
|
||||
copied_fields["created_by_end_user"] = fields.Nested(
|
||||
simple_end_user_model, attribute="created_by_end_user", allow_null=True
|
||||
)
|
||||
return api_or_ns.model("WorkflowAppLogPartial", copied_fields)
|
||||
|
||||
|
||||
@ -52,17 +44,9 @@ workflow_archived_log_partial_fields = {
|
||||
def build_workflow_archived_log_partial_model(api_or_ns: Namespace):
|
||||
"""Build the workflow archived log partial model for the API or Namespace."""
|
||||
workflow_run_model = build_workflow_run_for_archived_log_model(api_or_ns)
|
||||
simple_account_model = build_simple_account_model(api_or_ns)
|
||||
simple_end_user_model = build_simple_end_user_model(api_or_ns)
|
||||
|
||||
copied_fields = workflow_archived_log_partial_fields.copy()
|
||||
copied_fields["workflow_run"] = fields.Nested(workflow_run_model, allow_null=True)
|
||||
copied_fields["created_by_account"] = fields.Nested(
|
||||
simple_account_model, attribute="created_by_account", allow_null=True
|
||||
)
|
||||
copied_fields["created_by_end_user"] = fields.Nested(
|
||||
simple_end_user_model, attribute="created_by_end_user", allow_null=True
|
||||
)
|
||||
return api_or_ns.model("WorkflowArchivedLogPartial", copied_fields)
|
||||
|
||||
|
||||
|
||||
@ -29,6 +29,7 @@ from core.workflow.constants import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
)
|
||||
from core.workflow.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
|
||||
from core.workflow.enums import NodeType
|
||||
from extensions.ext_storage import Storage
|
||||
@ -260,7 +261,7 @@ class Workflow(Base): # bug
|
||||
# - `_get_graph_and_variable_pool_for_single_node_run`.
|
||||
return json.loads(self.graph) if self.graph else {}
|
||||
|
||||
def get_node_config_by_id(self, node_id: str) -> Mapping[str, Any]:
|
||||
def get_node_config_by_id(self, node_id: str) -> NodeConfigDict:
|
||||
"""Extract a node configuration from the workflow graph by node ID.
|
||||
A node configuration is a dictionary containing the node's properties, including
|
||||
the node's id, title, and its data as a dict.
|
||||
@ -278,8 +279,7 @@ class Workflow(Base): # bug
|
||||
node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes))
|
||||
except StopIteration:
|
||||
raise NodeNotFoundError(node_id)
|
||||
assert isinstance(node_config, dict)
|
||||
return node_config
|
||||
return NodeConfigDictAdapter.validate_python(node_config)
|
||||
|
||||
@staticmethod
|
||||
def get_node_type_from_node_config(node_config: Mapping[str, Any]) -> NodeType:
|
||||
|
||||
@ -91,7 +91,7 @@ dependencies = [
|
||||
"sseclient-py~=1.8.0",
|
||||
"httpx-sse~=0.4.0",
|
||||
"sendgrid~=6.12.3",
|
||||
"flask-restx~=1.3.0",
|
||||
"flask-restx~=1.3.2",
|
||||
"packaging==24.1",
|
||||
"croniter>=6.0.0",
|
||||
"weaviate-client==4.17.0",
|
||||
@ -122,7 +122,7 @@ dev = [
|
||||
"dotenv-linter~=0.5.0",
|
||||
"faker~=38.2.0",
|
||||
"lxml-stubs~=0.5.1",
|
||||
"ty~=0.0.1a19",
|
||||
"ty>=0.0.14",
|
||||
"basedpyright~=1.31.0",
|
||||
"ruff~=0.14.0",
|
||||
"pytest~=8.3.2",
|
||||
@ -151,7 +151,7 @@ dev = [
|
||||
"types-openpyxl~=3.1.5",
|
||||
"types-pexpect~=4.9.0",
|
||||
"types-protobuf~=5.29.1",
|
||||
"types-psutil~=7.0.0",
|
||||
"types-psutil~=7.2.2",
|
||||
"types-psycopg2~=2.9.21",
|
||||
"types-pygments~=2.19.0",
|
||||
"types-pymysql~=1.1.0",
|
||||
|
||||
@ -158,7 +158,7 @@ class AppAnnotationService:
|
||||
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
|
||||
)
|
||||
annotations = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
return annotations.items, annotations.total
|
||||
return annotations.items, annotations.total or 0
|
||||
|
||||
@classmethod
|
||||
def export_annotation_list_by_app_id(cls, app_id: str):
|
||||
@ -524,7 +524,7 @@ class AppAnnotationService:
|
||||
annotation_hit_histories = db.paginate(
|
||||
select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False
|
||||
)
|
||||
return annotation_hit_histories.items, annotation_hit_histories.total
|
||||
return annotation_hit_histories.items, annotation_hit_histories.total or 0
|
||||
|
||||
@classmethod
|
||||
def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None:
|
||||
|
||||
@ -16,6 +16,7 @@ from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.file import helpers as file_helpers
|
||||
from core.helper.name_generator import generate_incremental_name
|
||||
@ -1388,6 +1389,46 @@ class DocumentService:
|
||||
).all()
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def update_documents_need_summary(dataset_id: str, document_ids: Sequence[str], need_summary: bool = True) -> int:
|
||||
"""
|
||||
Update need_summary field for multiple documents.
|
||||
|
||||
This method handles the case where documents were created when summary_index_setting was disabled,
|
||||
and need to be updated when summary_index_setting is later enabled.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset ID
|
||||
document_ids: List of document IDs to update
|
||||
need_summary: Value to set for need_summary field (default: True)
|
||||
|
||||
Returns:
|
||||
Number of documents updated
|
||||
"""
|
||||
if not document_ids:
|
||||
return 0
|
||||
|
||||
document_id_list: list[str] = [str(document_id) for document_id in document_ids]
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
updated_count = (
|
||||
session.query(Document)
|
||||
.filter(
|
||||
Document.id.in_(document_id_list),
|
||||
Document.dataset_id == dataset_id,
|
||||
Document.doc_form != "qa_model", # Skip qa_model documents
|
||||
)
|
||||
.update({Document.need_summary: need_summary}, synchronize_session=False)
|
||||
)
|
||||
session.commit()
|
||||
logger.info(
|
||||
"Updated need_summary to %s for %d documents in dataset %s",
|
||||
need_summary,
|
||||
updated_count,
|
||||
dataset_id,
|
||||
)
|
||||
return updated_count
|
||||
|
||||
@staticmethod
|
||||
def get_document_download_url(document: Document) -> str:
|
||||
"""
|
||||
@ -2937,14 +2978,15 @@ class DocumentService:
|
||||
"""
|
||||
now = naive_utc_now()
|
||||
|
||||
if action == "enable":
|
||||
return DocumentService._prepare_enable_update(document, now)
|
||||
elif action == "disable":
|
||||
return DocumentService._prepare_disable_update(document, user, now)
|
||||
elif action == "archive":
|
||||
return DocumentService._prepare_archive_update(document, user, now)
|
||||
elif action == "un_archive":
|
||||
return DocumentService._prepare_unarchive_update(document, now)
|
||||
match action:
|
||||
case "enable":
|
||||
return DocumentService._prepare_enable_update(document, now)
|
||||
case "disable":
|
||||
return DocumentService._prepare_disable_update(document, user, now)
|
||||
case "archive":
|
||||
return DocumentService._prepare_archive_update(document, user, now)
|
||||
case "un_archive":
|
||||
return DocumentService._prepare_unarchive_update(document, now)
|
||||
|
||||
return None
|
||||
|
||||
@ -3581,56 +3623,57 @@ class SegmentService:
|
||||
# Check if segment_ids is not empty to avoid WHERE false condition
|
||||
if not segment_ids or len(segment_ids) == 0:
|
||||
return
|
||||
if action == "enable":
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.enabled == False,
|
||||
)
|
||||
).all()
|
||||
if not segments:
|
||||
return
|
||||
real_deal_segment_ids = []
|
||||
for segment in segments:
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is not None:
|
||||
continue
|
||||
segment.enabled = True
|
||||
segment.disabled_at = None
|
||||
segment.disabled_by = None
|
||||
db.session.add(segment)
|
||||
real_deal_segment_ids.append(segment.id)
|
||||
db.session.commit()
|
||||
match action:
|
||||
case "enable":
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.enabled == False,
|
||||
)
|
||||
).all()
|
||||
if not segments:
|
||||
return
|
||||
real_deal_segment_ids = []
|
||||
for segment in segments:
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is not None:
|
||||
continue
|
||||
segment.enabled = True
|
||||
segment.disabled_at = None
|
||||
segment.disabled_by = None
|
||||
db.session.add(segment)
|
||||
real_deal_segment_ids.append(segment.id)
|
||||
db.session.commit()
|
||||
|
||||
enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
|
||||
elif action == "disable":
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
).all()
|
||||
if not segments:
|
||||
return
|
||||
real_deal_segment_ids = []
|
||||
for segment in segments:
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is not None:
|
||||
continue
|
||||
segment.enabled = False
|
||||
segment.disabled_at = naive_utc_now()
|
||||
segment.disabled_by = current_user.id
|
||||
db.session.add(segment)
|
||||
real_deal_segment_ids.append(segment.id)
|
||||
db.session.commit()
|
||||
enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
|
||||
case "disable":
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
).all()
|
||||
if not segments:
|
||||
return
|
||||
real_deal_segment_ids = []
|
||||
for segment in segments:
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is not None:
|
||||
continue
|
||||
segment.enabled = False
|
||||
segment.disabled_at = naive_utc_now()
|
||||
segment.disabled_by = current_user.id
|
||||
db.session.add(segment)
|
||||
real_deal_segment_ids.append(segment.id)
|
||||
db.session.commit()
|
||||
|
||||
disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
|
||||
disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
|
||||
|
||||
@classmethod
|
||||
def create_child_chunk(
|
||||
|
||||
@ -174,6 +174,10 @@ class RagPipelineTransformService:
|
||||
else:
|
||||
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
|
||||
|
||||
# Copy summary_index_setting from dataset to knowledge_index node configuration
|
||||
if dataset.summary_index_setting:
|
||||
knowledge_configuration.summary_index_setting = dataset.summary_index_setting
|
||||
|
||||
knowledge_configuration_dict.update(knowledge_configuration.model_dump())
|
||||
node["data"] = knowledge_configuration_dict
|
||||
return node
|
||||
|
||||
@ -49,11 +49,18 @@ class SummaryIndexService:
|
||||
# Use lazy import to avoid circular import
|
||||
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
|
||||
|
||||
# Get document language to ensure summary is generated in the correct language
|
||||
# This is especially important for image-only chunks where text is empty or minimal
|
||||
document_language = None
|
||||
if segment.document and segment.document.doc_language:
|
||||
document_language = segment.document.doc_language
|
||||
|
||||
summary_content, usage = ParagraphIndexProcessor.generate_summary(
|
||||
tenant_id=dataset.tenant_id,
|
||||
text=segment.content,
|
||||
summary_index_setting=summary_index_setting,
|
||||
segment_id=segment.id,
|
||||
document_language=document_language,
|
||||
)
|
||||
|
||||
if not summary_content:
|
||||
@ -558,6 +565,9 @@ class SummaryIndexService:
|
||||
)
|
||||
session.add(summary_record)
|
||||
|
||||
# Commit the batch created records
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def update_summary_record_error(
|
||||
segment: DocumentSegment,
|
||||
@ -762,7 +772,6 @@ class SummaryIndexService:
|
||||
dataset=dataset,
|
||||
status="not_started",
|
||||
)
|
||||
session.commit() # Commit initial records
|
||||
|
||||
summary_records = []
|
||||
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
@ -10,8 +8,8 @@ from sqlalchemy.orm import Session
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
@ -38,12 +36,10 @@ class WorkflowToolManageService:
|
||||
label: str,
|
||||
icon: dict,
|
||||
description: str,
|
||||
parameters: list[Mapping[str, Any]],
|
||||
parameters: list[WorkflowToolParameterConfiguration],
|
||||
privacy_policy: str = "",
|
||||
labels: list[str] | None = None,
|
||||
):
|
||||
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
|
||||
|
||||
# check if the name is unique
|
||||
existing_workflow_tool_provider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
@ -75,7 +71,7 @@ class WorkflowToolManageService:
|
||||
label=label,
|
||||
icon=json.dumps(icon),
|
||||
description=description,
|
||||
parameter_configuration=json.dumps(parameters),
|
||||
parameter_configuration=json.dumps([p.model_dump() for p in parameters]),
|
||||
privacy_policy=privacy_policy,
|
||||
version=workflow.version,
|
||||
)
|
||||
@ -104,7 +100,7 @@ class WorkflowToolManageService:
|
||||
label: str,
|
||||
icon: dict,
|
||||
description: str,
|
||||
parameters: list[Mapping[str, Any]],
|
||||
parameters: list[WorkflowToolParameterConfiguration],
|
||||
privacy_policy: str = "",
|
||||
labels: list[str] | None = None,
|
||||
):
|
||||
@ -122,8 +118,6 @@ class WorkflowToolManageService:
|
||||
:param labels: labels
|
||||
:return: the updated tool
|
||||
"""
|
||||
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
|
||||
|
||||
# check if the name is unique
|
||||
existing_workflow_tool_provider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
@ -162,7 +156,7 @@ class WorkflowToolManageService:
|
||||
workflow_tool_provider.label = label
|
||||
workflow_tool_provider.icon = json.dumps(icon)
|
||||
workflow_tool_provider.description = description
|
||||
workflow_tool_provider.parameter_configuration = json.dumps(parameters)
|
||||
workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in parameters])
|
||||
workflow_tool_provider.privacy_policy = privacy_policy
|
||||
workflow_tool_provider.version = workflow.version
|
||||
workflow_tool_provider.updated_at = datetime.now()
|
||||
|
||||
@ -90,6 +90,7 @@ class TestWebhookService:
|
||||
"id": "webhook_node",
|
||||
"type": "webhook",
|
||||
"data": {
|
||||
"type": "trigger-webhook",
|
||||
"title": "Test Webhook",
|
||||
"method": "post",
|
||||
"content_type": "application/json",
|
||||
|
||||
@ -3,7 +3,9 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
from models.tools import WorkflowToolProvider
|
||||
from models.workflow import Workflow as WorkflowModel
|
||||
from services.account_service import AccountService, TenantService
|
||||
@ -130,20 +132,24 @@ class TestWorkflowToolManageService:
|
||||
def _create_test_workflow_tool_parameters(self):
|
||||
"""Helper method to create valid workflow tool parameters."""
|
||||
return [
|
||||
{
|
||||
"name": "input_text",
|
||||
"description": "Input text for processing",
|
||||
"form": "form",
|
||||
"type": "string",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "output_format",
|
||||
"description": "Output format specification",
|
||||
"form": "form",
|
||||
"type": "select",
|
||||
"required": False,
|
||||
},
|
||||
WorkflowToolParameterConfiguration.model_validate(
|
||||
{
|
||||
"name": "input_text",
|
||||
"description": "Input text for processing",
|
||||
"form": "form",
|
||||
"type": "string",
|
||||
"required": True,
|
||||
}
|
||||
),
|
||||
WorkflowToolParameterConfiguration.model_validate(
|
||||
{
|
||||
"name": "output_format",
|
||||
"description": "Output format specification",
|
||||
"form": "form",
|
||||
"type": "select",
|
||||
"required": False,
|
||||
}
|
||||
),
|
||||
]
|
||||
|
||||
def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
@ -208,7 +214,7 @@ class TestWorkflowToolManageService:
|
||||
assert created_tool_provider.label == tool_label
|
||||
assert created_tool_provider.icon == json.dumps(tool_icon)
|
||||
assert created_tool_provider.description == tool_description
|
||||
assert created_tool_provider.parameter_configuration == json.dumps(tool_parameters)
|
||||
assert created_tool_provider.parameter_configuration == json.dumps([p.model_dump() for p in tool_parameters])
|
||||
assert created_tool_provider.privacy_policy == tool_privacy_policy
|
||||
assert created_tool_provider.version == workflow.version
|
||||
assert created_tool_provider.user_id == account.id
|
||||
@ -353,18 +359,9 @@ class TestWorkflowToolManageService:
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Setup invalid workflow tool parameters (missing required fields)
|
||||
invalid_parameters = [
|
||||
{
|
||||
"name": "input_text",
|
||||
# Missing description and form fields
|
||||
"type": "string",
|
||||
"required": True,
|
||||
}
|
||||
]
|
||||
# Attempt to create workflow tool with invalid parameters
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
# Setup invalid workflow tool parameters (missing required fields)
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
@ -373,7 +370,16 @@ class TestWorkflowToolManageService:
|
||||
label=fake.word(),
|
||||
icon={"type": "emoji", "emoji": "🔧"},
|
||||
description=fake.text(max_nb_chars=200),
|
||||
parameters=invalid_parameters,
|
||||
parameters=[
|
||||
WorkflowToolParameterConfiguration.model_validate(
|
||||
{
|
||||
"name": "input_text",
|
||||
# Missing description and form fields
|
||||
"type": "string",
|
||||
"required": True,
|
||||
}
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# Verify error message contains validation error
|
||||
@ -579,11 +585,12 @@ class TestWorkflowToolManageService:
|
||||
|
||||
# Verify database state was updated
|
||||
db.session.refresh(created_tool)
|
||||
assert created_tool is not None
|
||||
assert created_tool.name == updated_tool_name
|
||||
assert created_tool.label == updated_tool_label
|
||||
assert created_tool.icon == json.dumps(updated_tool_icon)
|
||||
assert created_tool.description == updated_tool_description
|
||||
assert created_tool.parameter_configuration == json.dumps(updated_tool_parameters)
|
||||
assert created_tool.parameter_configuration == json.dumps([p.model_dump() for p in updated_tool_parameters])
|
||||
assert created_tool.privacy_policy == updated_tool_privacy_policy
|
||||
assert created_tool.version == workflow.version
|
||||
assert created_tool.updated_at is not None
|
||||
@ -750,13 +757,15 @@ class TestWorkflowToolManageService:
|
||||
|
||||
# Setup workflow tool parameters with FILE type
|
||||
file_parameters = [
|
||||
{
|
||||
"name": "document",
|
||||
"description": "Upload a document",
|
||||
"form": "form",
|
||||
"type": "file",
|
||||
"required": False,
|
||||
}
|
||||
WorkflowToolParameterConfiguration.model_validate(
|
||||
{
|
||||
"name": "document",
|
||||
"description": "Upload a document",
|
||||
"form": "form",
|
||||
"type": "file",
|
||||
"required": False,
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
# Execute the method under test
|
||||
@ -823,13 +832,15 @@ class TestWorkflowToolManageService:
|
||||
|
||||
# Setup workflow tool parameters with FILES type
|
||||
files_parameters = [
|
||||
{
|
||||
"name": "documents",
|
||||
"description": "Upload multiple documents",
|
||||
"form": "form",
|
||||
"type": "files",
|
||||
"required": False,
|
||||
}
|
||||
WorkflowToolParameterConfiguration.model_validate(
|
||||
{
|
||||
"name": "documents",
|
||||
"description": "Upload multiple documents",
|
||||
"form": "form",
|
||||
"type": "files",
|
||||
"required": False,
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
# Execute the method under test
|
||||
|
||||
@ -0,0 +1,46 @@
|
||||
import builtins
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from extensions import ext_fastopenapi
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.secret_key = "test-secret-key"
|
||||
return app
|
||||
|
||||
|
||||
def test_console_init_get_returns_finished_when_no_init_password(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
ext_fastopenapi.init_app(app)
|
||||
monkeypatch.delenv("INIT_PASSWORD", raising=False)
|
||||
|
||||
with patch("controllers.console.init_validate.dify_config.EDITION", "SELF_HOSTED"):
|
||||
client = app.test_client()
|
||||
response = client.get("/console/api/init")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"status": "finished"}
|
||||
|
||||
|
||||
def test_console_init_post_returns_success(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
ext_fastopenapi.init_app(app)
|
||||
monkeypatch.setenv("INIT_PASSWORD", "test-init-password")
|
||||
|
||||
with (
|
||||
patch("controllers.console.init_validate.dify_config.EDITION", "SELF_HOSTED"),
|
||||
patch("controllers.console.init_validate.TenantService.get_tenant_count", return_value=0),
|
||||
):
|
||||
client = app.test_client()
|
||||
response = client.post("/console/api/init", json={"password": "test-init-password"})
|
||||
|
||||
assert response.status_code == 201
|
||||
assert response.get_json() == {"result": "success"}
|
||||
@ -0,0 +1,364 @@
|
||||
"""Endpoint tests for controllers.console.workspace.tool_providers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import importlib
|
||||
from contextlib import contextmanager
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
_CONTROLLER_MODULE: ModuleType | None = None
|
||||
_WRAPS_MODULE: ModuleType | None = None
|
||||
_CONTROLLER_PATCHERS: list[patch] = []
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _mock_db():
|
||||
mock_session = SimpleNamespace(query=lambda *args, **kwargs: SimpleNamespace(first=lambda: True))
|
||||
with patch("extensions.ext_database.db.session", mock_session):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def controller_module(monkeypatch: pytest.MonkeyPatch):
|
||||
module_name = "controllers.console.workspace.tool_providers"
|
||||
global _CONTROLLER_MODULE
|
||||
if _CONTROLLER_MODULE is None:
|
||||
|
||||
def _noop(func):
|
||||
return func
|
||||
|
||||
patch_targets = [
|
||||
("libs.login.login_required", _noop),
|
||||
("controllers.console.wraps.setup_required", _noop),
|
||||
("controllers.console.wraps.account_initialization_required", _noop),
|
||||
("controllers.console.wraps.is_admin_or_owner_required", _noop),
|
||||
("controllers.console.wraps.enterprise_license_required", _noop),
|
||||
]
|
||||
for target, value in patch_targets:
|
||||
patcher = patch(target, value)
|
||||
patcher.start()
|
||||
_CONTROLLER_PATCHERS.append(patcher)
|
||||
monkeypatch.setenv("DIFY_SETUP_READY", "true")
|
||||
with _mock_db():
|
||||
_CONTROLLER_MODULE = importlib.import_module(module_name)
|
||||
|
||||
module = _CONTROLLER_MODULE
|
||||
monkeypatch.setattr(module, "jsonable_encoder", lambda payload: payload)
|
||||
|
||||
# Ensure decorators that consult deployment edition do not reach the database.
|
||||
global _WRAPS_MODULE
|
||||
wraps_module = importlib.import_module("controllers.console.wraps")
|
||||
_WRAPS_MODULE = wraps_module
|
||||
monkeypatch.setattr(module.dify_config, "EDITION", "CLOUD")
|
||||
monkeypatch.setattr(wraps_module.dify_config, "EDITION", "CLOUD")
|
||||
|
||||
login_module = importlib.import_module("libs.login")
|
||||
monkeypatch.setattr(login_module, "check_csrf_token", lambda *args, **kwargs: None)
|
||||
return module
|
||||
|
||||
|
||||
def _mock_account(user_id: str = "user-123") -> SimpleNamespace:
|
||||
return SimpleNamespace(id=user_id, status="active", is_authenticated=True, current_tenant_id=None)
|
||||
|
||||
|
||||
def _set_current_account(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
controller_module: ModuleType,
|
||||
user: SimpleNamespace,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
def _getter():
|
||||
return user, tenant_id
|
||||
|
||||
user.current_tenant_id = tenant_id
|
||||
|
||||
monkeypatch.setattr(controller_module, "current_account_with_tenant", _getter)
|
||||
if _WRAPS_MODULE is not None:
|
||||
monkeypatch.setattr(_WRAPS_MODULE, "current_account_with_tenant", _getter)
|
||||
|
||||
login_module = importlib.import_module("libs.login")
|
||||
monkeypatch.setattr(login_module, "_get_user", lambda: user)
|
||||
|
||||
|
||||
def test_tool_provider_list_calls_service_with_query(
|
||||
app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-456")
|
||||
|
||||
service_mock = MagicMock(return_value=[{"provider": "builtin"}])
|
||||
monkeypatch.setattr(controller_module.ToolCommonService, "list_tool_providers", service_mock)
|
||||
|
||||
with app.test_request_context("/workspaces/current/tool-providers?type=builtin"):
|
||||
response = controller_module.ToolProviderListApi().get()
|
||||
|
||||
assert response == [{"provider": "builtin"}]
|
||||
service_mock.assert_called_once_with(user.id, "tenant-456", "builtin")
|
||||
|
||||
|
||||
def test_builtin_provider_add_passes_payload(
|
||||
app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-456")
|
||||
|
||||
service_mock = MagicMock(return_value={"status": "ok"})
|
||||
monkeypatch.setattr(controller_module.BuiltinToolManageService, "add_builtin_tool_provider", service_mock)
|
||||
|
||||
payload = {
|
||||
"credentials": {"api_key": "sk-test"},
|
||||
"name": "MyTool",
|
||||
"type": controller_module.CredentialType.API_KEY,
|
||||
}
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/tool-provider/builtin/openai/add",
|
||||
method="POST",
|
||||
json=payload,
|
||||
):
|
||||
response = controller_module.ToolBuiltinProviderAddApi().post(provider="openai")
|
||||
|
||||
assert response == {"status": "ok"}
|
||||
service_mock.assert_called_once_with(
|
||||
user_id="user-123",
|
||||
tenant_id="tenant-456",
|
||||
provider="openai",
|
||||
credentials={"api_key": "sk-test"},
|
||||
name="MyTool",
|
||||
api_type=controller_module.CredentialType.API_KEY,
|
||||
)
|
||||
|
||||
|
||||
def test_builtin_provider_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-tenant-789")
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-789")
|
||||
|
||||
service_mock = MagicMock(return_value=[{"name": "tool-a"}])
|
||||
monkeypatch.setattr(controller_module.BuiltinToolManageService, "list_builtin_tool_provider_tools", service_mock)
|
||||
monkeypatch.setattr(controller_module, "jsonable_encoder", lambda payload: payload)
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/tool-provider/builtin/my-provider/tools",
|
||||
method="GET",
|
||||
):
|
||||
response = controller_module.ToolBuiltinProviderListToolsApi().get(provider="my-provider")
|
||||
|
||||
assert response == [{"name": "tool-a"}]
|
||||
service_mock.assert_called_once_with("tenant-789", "my-provider")
|
||||
|
||||
|
||||
def test_builtin_provider_info_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-tenant-9")
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-9")
|
||||
service_mock = MagicMock(return_value={"info": True})
|
||||
monkeypatch.setattr(controller_module.BuiltinToolManageService, "get_builtin_tool_provider_info", service_mock)
|
||||
|
||||
with app.test_request_context("/info", method="GET"):
|
||||
resp = controller_module.ToolBuiltinProviderInfoApi().get(provider="demo")
|
||||
|
||||
assert resp == {"info": True}
|
||||
service_mock.assert_called_once_with("tenant-9", "demo")
|
||||
|
||||
|
||||
def test_builtin_provider_credentials_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-tenant-cred")
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-cred")
|
||||
service_mock = MagicMock(return_value=[{"cred": 1}])
|
||||
monkeypatch.setattr(
|
||||
controller_module.BuiltinToolManageService,
|
||||
"get_builtin_tool_provider_credentials",
|
||||
service_mock,
|
||||
)
|
||||
|
||||
with app.test_request_context("/creds", method="GET"):
|
||||
resp = controller_module.ToolBuiltinProviderGetCredentialsApi().get(provider="demo")
|
||||
|
||||
assert resp == [{"cred": 1}]
|
||||
service_mock.assert_called_once_with(tenant_id="tenant-cred", provider_name="demo")
|
||||
|
||||
|
||||
def test_api_provider_remote_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-10")
|
||||
service_mock = MagicMock(return_value={"schema": "ok"})
|
||||
monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider_remote_schema", service_mock)
|
||||
|
||||
with app.test_request_context("/remote?url=https://example.com/"):
|
||||
resp = controller_module.ToolApiProviderGetRemoteSchemaApi().get()
|
||||
|
||||
assert resp == {"schema": "ok"}
|
||||
service_mock.assert_called_once_with(user.id, "tenant-10", "https://example.com/")
|
||||
|
||||
|
||||
def test_api_provider_list_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-11")
|
||||
service_mock = MagicMock(return_value=[{"tool": "t"}])
|
||||
monkeypatch.setattr(controller_module.ApiToolManageService, "list_api_tool_provider_tools", service_mock)
|
||||
|
||||
with app.test_request_context("/tools?provider=foo"):
|
||||
resp = controller_module.ToolApiProviderListToolsApi().get()
|
||||
|
||||
assert resp == [{"tool": "t"}]
|
||||
service_mock.assert_called_once_with(user.id, "tenant-11", "foo")
|
||||
|
||||
|
||||
def test_api_provider_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-12")
|
||||
service_mock = MagicMock(return_value={"provider": "foo"})
|
||||
monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider", service_mock)
|
||||
|
||||
with app.test_request_context("/get?provider=foo"):
|
||||
resp = controller_module.ToolApiProviderGetApi().get()
|
||||
|
||||
assert resp == {"provider": "foo"}
|
||||
service_mock.assert_called_once_with(user.id, "tenant-12", "foo")
|
||||
|
||||
|
||||
def test_builtin_provider_credentials_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-tenant-13")
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-13")
|
||||
service_mock = MagicMock(return_value={"schema": True})
|
||||
monkeypatch.setattr(
|
||||
controller_module.BuiltinToolManageService,
|
||||
"list_builtin_provider_credentials_schema",
|
||||
service_mock,
|
||||
)
|
||||
|
||||
with app.test_request_context("/schema", method="GET"):
|
||||
resp = controller_module.ToolBuiltinProviderCredentialsSchemaApi().get(
|
||||
provider="demo", credential_type="api-key"
|
||||
)
|
||||
|
||||
assert resp == {"schema": True}
|
||||
service_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_workflow_provider_get_by_tool(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-wf")
|
||||
tool_service = MagicMock(return_value={"wf": 1})
|
||||
monkeypatch.setattr(
|
||||
controller_module.WorkflowToolManageService,
|
||||
"get_workflow_tool_by_tool_id",
|
||||
tool_service,
|
||||
)
|
||||
|
||||
tool_id = "00000000-0000-0000-0000-000000000001"
|
||||
with app.test_request_context(f"/workflow?workflow_tool_id={tool_id}"):
|
||||
resp = controller_module.ToolWorkflowProviderGetApi().get()
|
||||
|
||||
assert resp == {"wf": 1}
|
||||
tool_service.assert_called_once_with(user.id, "tenant-wf", tool_id)
|
||||
|
||||
|
||||
def test_workflow_provider_get_by_app(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-wf2")
|
||||
service_mock = MagicMock(return_value={"app": 1})
|
||||
monkeypatch.setattr(
|
||||
controller_module.WorkflowToolManageService,
|
||||
"get_workflow_tool_by_app_id",
|
||||
service_mock,
|
||||
)
|
||||
|
||||
app_id = "00000000-0000-0000-0000-000000000002"
|
||||
with app.test_request_context(f"/workflow?workflow_app_id={app_id}"):
|
||||
resp = controller_module.ToolWorkflowProviderGetApi().get()
|
||||
|
||||
assert resp == {"app": 1}
|
||||
service_mock.assert_called_once_with(user.id, "tenant-wf2", app_id)
|
||||
|
||||
|
||||
def test_workflow_provider_list_tools(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-wf3")
|
||||
service_mock = MagicMock(return_value=[{"id": 1}])
|
||||
monkeypatch.setattr(controller_module.WorkflowToolManageService, "list_single_workflow_tools", service_mock)
|
||||
|
||||
tool_id = "00000000-0000-0000-0000-000000000003"
|
||||
with app.test_request_context(f"/workflow/tools?workflow_tool_id={tool_id}"):
|
||||
resp = controller_module.ToolWorkflowProviderListToolApi().get()
|
||||
|
||||
assert resp == [{"id": 1}]
|
||||
service_mock.assert_called_once_with(user.id, "tenant-wf3", tool_id)
|
||||
|
||||
|
||||
def test_builtin_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-bt")
|
||||
|
||||
provider = SimpleNamespace(to_dict=lambda: {"name": "builtin"})
|
||||
monkeypatch.setattr(
|
||||
controller_module.BuiltinToolManageService,
|
||||
"list_builtin_tools",
|
||||
MagicMock(return_value=[provider]),
|
||||
)
|
||||
|
||||
with app.test_request_context("/tools/builtin"):
|
||||
resp = controller_module.ToolBuiltinListApi().get()
|
||||
|
||||
assert resp == [{"name": "builtin"}]
|
||||
|
||||
|
||||
def test_api_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-tenant-api")
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-api")
|
||||
|
||||
provider = SimpleNamespace(to_dict=lambda: {"name": "api"})
|
||||
monkeypatch.setattr(
|
||||
controller_module.ApiToolManageService,
|
||||
"list_api_tools",
|
||||
MagicMock(return_value=[provider]),
|
||||
)
|
||||
|
||||
with app.test_request_context("/tools/api"):
|
||||
resp = controller_module.ToolApiListApi().get()
|
||||
|
||||
assert resp == [{"name": "api"}]
|
||||
|
||||
|
||||
def test_workflow_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-wf4")
|
||||
|
||||
provider = SimpleNamespace(to_dict=lambda: {"name": "wf"})
|
||||
monkeypatch.setattr(
|
||||
controller_module.WorkflowToolManageService,
|
||||
"list_tenant_workflow_tools",
|
||||
MagicMock(return_value=[provider]),
|
||||
)
|
||||
|
||||
with app.test_request_context("/tools/workflow"):
|
||||
resp = controller_module.ToolWorkflowListApi().get()
|
||||
|
||||
assert resp == [{"name": "wf"}]
|
||||
|
||||
|
||||
def test_tool_labels_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-label")
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-labels")
|
||||
monkeypatch.setattr(controller_module.ToolLabelsService, "list_tool_labels", lambda: ["a", "b"])
|
||||
|
||||
with app.test_request_context("/tool-labels"):
|
||||
resp = controller_module.ToolLabelsApi().get()
|
||||
|
||||
assert resp == ["a", "b"]
|
||||
12
api/ty.toml
12
api/ty.toml
@ -1,16 +1,15 @@
|
||||
[src]
|
||||
exclude = [
|
||||
# deps groups (A1/A2/B/C/D/E)
|
||||
# A2: workflow engine/nodes
|
||||
"core/workflow",
|
||||
"core/app/workflow",
|
||||
"core/helper/code_executor",
|
||||
# B: app runner + prompt
|
||||
"core/prompt",
|
||||
"core/app/apps/base_app_runner.py",
|
||||
"core/app/apps/workflow_app_runner.py",
|
||||
"core/agent",
|
||||
"core/plugin",
|
||||
# C: services/controllers/fields/libs
|
||||
"services",
|
||||
"controllers/inner_api",
|
||||
"controllers/console/app",
|
||||
"controllers/console/explore",
|
||||
"controllers/console/datasets",
|
||||
@ -28,3 +27,8 @@ exclude = [
|
||||
"tests",
|
||||
]
|
||||
|
||||
|
||||
[rules]
|
||||
deprecated = "ignore"
|
||||
unused-ignore-comment = "ignore"
|
||||
# possibly-missing-attribute = "ignore"
|
||||
12
api/uv.lock
generated
12
api/uv.lock
generated
@ -1716,7 +1716,7 @@ requires-dist = [
|
||||
{ name = "flask-login", specifier = "~=0.6.3" },
|
||||
{ name = "flask-migrate", specifier = "~=4.0.7" },
|
||||
{ name = "flask-orjson", specifier = "~=2.0.0" },
|
||||
{ name = "flask-restx", specifier = "~=1.3.0" },
|
||||
{ name = "flask-restx", specifier = "~=1.3.2" },
|
||||
{ name = "flask-sqlalchemy", specifier = "~=3.1.1" },
|
||||
{ name = "gevent", specifier = "~=25.9.1" },
|
||||
{ name = "gevent-websocket", specifier = "~=0.10.1" },
|
||||
@ -1814,7 +1814,7 @@ dev = [
|
||||
{ name = "scipy-stubs", specifier = ">=1.15.3.0" },
|
||||
{ name = "sseclient-py", specifier = ">=1.8.0" },
|
||||
{ name = "testcontainers", specifier = "~=4.13.2" },
|
||||
{ name = "ty", specifier = "~=0.0.1a19" },
|
||||
{ name = "ty", specifier = ">=0.0.14" },
|
||||
{ name = "types-aiofiles", specifier = "~=24.1.0" },
|
||||
{ name = "types-beautifulsoup4", specifier = "~=4.12.0" },
|
||||
{ name = "types-cachetools", specifier = "~=5.5.0" },
|
||||
@ -1837,7 +1837,7 @@ dev = [
|
||||
{ name = "types-openpyxl", specifier = "~=3.1.5" },
|
||||
{ name = "types-pexpect", specifier = "~=4.9.0" },
|
||||
{ name = "types-protobuf", specifier = "~=5.29.1" },
|
||||
{ name = "types-psutil", specifier = "~=7.0.0" },
|
||||
{ name = "types-psutil", specifier = "~=7.2.2" },
|
||||
{ name = "types-psycopg2", specifier = "~=2.9.21" },
|
||||
{ name = "types-pygments", specifier = "~=2.19.0" },
|
||||
{ name = "types-pymysql", specifier = "~=1.1.0" },
|
||||
@ -6779,11 +6779,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "types-psutil"
|
||||
version = "7.0.0.20251116"
|
||||
version = "7.2.2.20260130"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/47/ec/c1e9308b91582cad1d7e7d3007fd003ef45a62c2500f8219313df5fc3bba/types_psutil-7.0.0.20251116.tar.gz", hash = "sha256:92b5c78962e55ce1ed7b0189901a4409ece36ab9fd50c3029cca7e681c606c8a", size = 22192, upload-time = "2025-11-16T03:10:32.859Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/69/14/fc5fb0a6ddfadf68c27e254a02ececd4d5c7fdb0efcb7e7e917a183497fb/types_psutil-7.2.2.20260130.tar.gz", hash = "sha256:15b0ab69c52841cf9ce3c383e8480c620a4d13d6a8e22b16978ebddac5590950", size = 26535, upload-time = "2026-01-30T03:58:14.116Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c3/0e/11ba08a5375c21039ed5f8e6bba41e9452fb69f0e2f7ee05ed5cca2a2cdf/types_psutil-7.0.0.20251116-py3-none-any.whl", hash = "sha256:74c052de077c2024b85cd435e2cba971165fe92a5eace79cbeb821e776dbc047", size = 25376, upload-time = "2025-11-16T03:10:31.813Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/17/d7/60974b7e31545d3768d1770c5fe6e093182c3bfd819429b33133ba6b3e89/types_psutil-7.2.2.20260130-py3-none-any.whl", hash = "sha256:15523a3caa7b3ff03ac7f9b78a6470a59f88f48df1d74a39e70e06d2a99107da", size = 32876, upload-time = "2026-01-30T03:58:13.172Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
#!/bin/bash
|
||||
set -x
|
||||
set -euxo pipefail
|
||||
|
||||
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
|
||||
cd "$SCRIPT_DIR/../.."
|
||||
|
||||
@ -663,13 +663,14 @@ services:
|
||||
- "${IRIS_SUPER_SERVER_PORT:-1972}:1972"
|
||||
- "${IRIS_WEB_SERVER_PORT:-52773}:52773"
|
||||
volumes:
|
||||
- ./volumes/iris:/opt/iris
|
||||
- ./volumes/iris:/durable
|
||||
- ./iris/iris-init.script:/iris-init.script
|
||||
- ./iris/docker-entrypoint.sh:/custom-entrypoint.sh
|
||||
entrypoint: ["/custom-entrypoint.sh"]
|
||||
tty: true
|
||||
environment:
|
||||
TZ: ${IRIS_TIMEZONE:-UTC}
|
||||
ISC_DATA_DIRECTORY: /durable/iris
|
||||
|
||||
# Oracle vector database
|
||||
oracle:
|
||||
|
||||
@ -1351,13 +1351,14 @@ services:
|
||||
- "${IRIS_SUPER_SERVER_PORT:-1972}:1972"
|
||||
- "${IRIS_WEB_SERVER_PORT:-52773}:52773"
|
||||
volumes:
|
||||
- ./volumes/iris:/opt/iris
|
||||
- ./volumes/iris:/durable
|
||||
- ./iris/iris-init.script:/iris-init.script
|
||||
- ./iris/docker-entrypoint.sh:/custom-entrypoint.sh
|
||||
entrypoint: ["/custom-entrypoint.sh"]
|
||||
tty: true
|
||||
environment:
|
||||
TZ: ${IRIS_TIMEZONE:-UTC}
|
||||
ISC_DATA_DIRECTORY: /durable/iris
|
||||
|
||||
# Oracle vector database
|
||||
oracle:
|
||||
|
||||
@ -1,15 +1,33 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# IRIS configuration flag file
|
||||
IRIS_CONFIG_DONE="/opt/iris/.iris-configured"
|
||||
# IRIS configuration flag file (stored in durable directory to persist with data)
|
||||
IRIS_CONFIG_DONE="/durable/.iris-configured"
|
||||
|
||||
# Function to wait for IRIS to be ready
|
||||
wait_for_iris() {
|
||||
echo "Waiting for IRIS to be ready..."
|
||||
local max_attempts=30
|
||||
local attempt=1
|
||||
while [ "$attempt" -le "$max_attempts" ]; do
|
||||
if iris qlist IRIS 2>/dev/null | grep -q "running"; then
|
||||
echo "IRIS is ready."
|
||||
return 0
|
||||
fi
|
||||
echo "Attempt $attempt/$max_attempts: IRIS not ready yet, waiting..."
|
||||
sleep 2
|
||||
attempt=$((attempt + 1))
|
||||
done
|
||||
echo "ERROR: IRIS failed to start within expected time." >&2
|
||||
return 1
|
||||
}
|
||||
|
||||
# Function to configure IRIS
|
||||
configure_iris() {
|
||||
echo "Configuring IRIS for first-time setup..."
|
||||
|
||||
# Wait for IRIS to be fully started
|
||||
sleep 5
|
||||
wait_for_iris
|
||||
|
||||
# Execute the initialization script
|
||||
iris session IRIS < /iris-init.script
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import Cookies from 'js-cookie'
|
||||
import { usePathname, useRouter, useSearchParams } from 'next/navigation'
|
||||
import { parseAsString, useQueryState } from 'nuqs'
|
||||
import { parseAsBoolean, useQueryState } from 'nuqs'
|
||||
import { useCallback, useEffect, useState } from 'react'
|
||||
import {
|
||||
EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION,
|
||||
@ -28,7 +28,7 @@ export const AppInitializer = ({
|
||||
const [init, setInit] = useState(false)
|
||||
const [oauthNewUser, setOauthNewUser] = useQueryState(
|
||||
'oauth_new_user',
|
||||
parseAsString.withOptions({ history: 'replace' }),
|
||||
parseAsBoolean.withOptions({ history: 'replace' }),
|
||||
)
|
||||
|
||||
const isSetupFinished = useCallback(async () => {
|
||||
@ -46,7 +46,7 @@ export const AppInitializer = ({
|
||||
(async () => {
|
||||
const action = searchParams.get('action')
|
||||
|
||||
if (oauthNewUser === 'true') {
|
||||
if (oauthNewUser) {
|
||||
let utmInfo = null
|
||||
const utmInfoStr = Cookies.get('utm_info')
|
||||
if (utmInfoStr) {
|
||||
|
||||
@ -62,19 +62,19 @@ const AppCard = ({
|
||||
{app.description}
|
||||
</div>
|
||||
</div>
|
||||
{canCreate && (
|
||||
{(canCreate || isTrialApp) && (
|
||||
<div className={cn('absolute bottom-0 left-0 right-0 hidden bg-gradient-to-t from-components-panel-gradient-2 from-[60.27%] to-transparent p-4 pt-8 group-hover:flex')}>
|
||||
<div className={cn('grid h-8 w-full grid-cols-1 items-center space-x-2', isTrialApp && 'grid-cols-2')}>
|
||||
<Button variant="primary" onClick={() => onCreate()}>
|
||||
<PlusIcon className="mr-1 h-4 w-4" />
|
||||
<span className="text-xs">{t('newApp.useTemplate', { ns: 'app' })}</span>
|
||||
</Button>
|
||||
{isTrialApp && (
|
||||
<Button onClick={showTryAPPPanel(app.app_id)}>
|
||||
<RiInformation2Line className="mr-1 size-4" />
|
||||
<span>{t('appCard.try', { ns: 'explore' })}</span>
|
||||
<div className={cn('grid h-8 w-full grid-cols-1 items-center space-x-2', canCreate && 'grid-cols-2')}>
|
||||
{canCreate && (
|
||||
<Button variant="primary" onClick={() => onCreate()}>
|
||||
<PlusIcon className="mr-1 h-4 w-4" />
|
||||
<span className="text-xs">{t('newApp.useTemplate', { ns: 'app' })}</span>
|
||||
</Button>
|
||||
)}
|
||||
<Button onClick={showTryAPPPanel(app.app_id)}>
|
||||
<RiInformation2Line className="mr-1 size-4" />
|
||||
<span>{t('appCard.try', { ns: 'explore' })}</span>
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@ -124,7 +124,7 @@ describe('CreateAppModal', () => {
|
||||
|
||||
const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder')
|
||||
fireEvent.change(nameInput, { target: { value: 'My App' } })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: /app\.newApp\.Create/ }))
|
||||
|
||||
await waitFor(() => expect(mockCreateApp).toHaveBeenCalledWith({
|
||||
name: 'My App',
|
||||
@ -152,7 +152,7 @@ describe('CreateAppModal', () => {
|
||||
|
||||
const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder')
|
||||
fireEvent.change(nameInput, { target: { value: 'My App' } })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: /app\.newApp\.Create/ }))
|
||||
|
||||
await waitFor(() => expect(mockCreateApp).toHaveBeenCalled())
|
||||
expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'boom' })
|
||||
|
||||
@ -3,8 +3,6 @@ import type { FC, ReactNode } from 'react'
|
||||
import type { SliceProps } from './type'
|
||||
import { autoUpdate, flip, FloatingFocusManager, offset, shift, useDismiss, useFloating, useHover, useInteractions, useRole } from '@floating-ui/react'
|
||||
import { RiDeleteBinLine } from '@remixicon/react'
|
||||
// @ts-expect-error no types available
|
||||
import lineClamp from 'line-clamp'
|
||||
import { useState } from 'react'
|
||||
import ActionButton, { ActionButtonState } from '@/app/components/base/action-button'
|
||||
import { cn } from '@/utils/classnames'
|
||||
@ -58,12 +56,8 @@ export const EditSlice: FC<EditSliceProps> = (props) => {
|
||||
<>
|
||||
<SliceContainer
|
||||
{...rest}
|
||||
className={cn('mr-0 block', className)}
|
||||
ref={(ref) => {
|
||||
refs.setReference(ref)
|
||||
if (ref)
|
||||
lineClamp(ref, 4)
|
||||
}}
|
||||
className={cn('mr-0 line-clamp-4 block', className)}
|
||||
ref={refs.setReference}
|
||||
{...getReferenceProps()}
|
||||
>
|
||||
<SliceLabel
|
||||
|
||||
@ -74,11 +74,15 @@ const AppCard = ({
|
||||
</div>
|
||||
{isExplore && (canCreate || isTrialApp) && (
|
||||
<div className={cn('absolute bottom-0 left-0 right-0 hidden bg-gradient-to-t from-components-panel-gradient-2 from-[60.27%] to-transparent p-4 pt-8 group-hover:flex')}>
|
||||
<div className={cn('grid h-8 w-full grid-cols-2 space-x-2')}>
|
||||
<Button variant="primary" className="h-7" onClick={() => onCreate()}>
|
||||
<PlusIcon className="mr-1 h-4 w-4" />
|
||||
<span className="text-xs">{t('appCard.addToWorkspace', { ns: 'explore' })}</span>
|
||||
</Button>
|
||||
<div className={cn('grid h-8 w-full grid-cols-1 space-x-2', canCreate && 'grid-cols-2')}>
|
||||
{
|
||||
canCreate && (
|
||||
<Button variant="primary" className="h-7" onClick={() => onCreate()}>
|
||||
<PlusIcon className="mr-1 h-4 w-4" />
|
||||
<span className="text-xs">{t('appCard.addToWorkspace', { ns: 'explore' })}</span>
|
||||
</Button>
|
||||
)
|
||||
}
|
||||
<Button className="h-7" onClick={showTryAPPPanel(app.app_id)}>
|
||||
<RiInformation2Line className="mr-1 size-4" />
|
||||
<span>{t('appCard.try', { ns: 'explore' })}</span>
|
||||
|
||||
@ -138,7 +138,7 @@ describe('CreateAppModal', () => {
|
||||
setup({ appName: 'My App', isEditModal: false })
|
||||
|
||||
expect(screen.getByText('explore.appCustomize.title:{"name":"My App"}')).toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: 'common.operation.create' })).toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: /common\.operation\.create/ })).toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: 'common.operation.cancel' })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
@ -146,7 +146,7 @@ describe('CreateAppModal', () => {
|
||||
setup({ isEditModal: true, appMode: AppModeEnum.CHAT, max_active_requests: 5 })
|
||||
|
||||
expect(screen.getByText('app.editAppTitle')).toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: /common\.operation\.save/ })).toBeInTheDocument()
|
||||
expect(screen.getByRole('switch')).toBeInTheDocument()
|
||||
expect((screen.getByRole('spinbutton') as HTMLInputElement).value).toBe('5')
|
||||
})
|
||||
@ -166,7 +166,7 @@ describe('CreateAppModal', () => {
|
||||
it('should not render modal content when hidden', () => {
|
||||
setup({ show: false })
|
||||
|
||||
expect(screen.queryByRole('button', { name: 'common.operation.create' })).not.toBeInTheDocument()
|
||||
expect(screen.queryByRole('button', { name: /common\.operation\.create/ })).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@ -175,13 +175,13 @@ describe('CreateAppModal', () => {
|
||||
it('should disable confirm action when confirmDisabled is true', () => {
|
||||
setup({ confirmDisabled: true })
|
||||
|
||||
expect(screen.getByRole('button', { name: 'common.operation.create' })).toBeDisabled()
|
||||
expect(screen.getByRole('button', { name: /common\.operation\.create/ })).toBeDisabled()
|
||||
})
|
||||
|
||||
it('should disable confirm action when appName is empty', () => {
|
||||
setup({ appName: ' ' })
|
||||
|
||||
expect(screen.getByRole('button', { name: 'common.operation.create' })).toBeDisabled()
|
||||
expect(screen.getByRole('button', { name: /common\.operation\.create/ })).toBeDisabled()
|
||||
})
|
||||
})
|
||||
|
||||
@ -245,7 +245,7 @@ describe('CreateAppModal', () => {
|
||||
setup({ isEditModal: false })
|
||||
|
||||
expect(screen.getByText('billing.apps.fullTip2')).toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: 'common.operation.create' })).toBeDisabled()
|
||||
expect(screen.getByRole('button', { name: /common\.operation\.create/ })).toBeDisabled()
|
||||
})
|
||||
|
||||
it('should allow saving when apps quota is reached in edit mode', () => {
|
||||
@ -257,7 +257,7 @@ describe('CreateAppModal', () => {
|
||||
setup({ isEditModal: true })
|
||||
|
||||
expect(screen.queryByText('billing.apps.fullTip2')).not.toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeEnabled()
|
||||
expect(screen.getByRole('button', { name: /common\.operation\.save/ })).toBeEnabled()
|
||||
})
|
||||
})
|
||||
|
||||
@ -384,7 +384,7 @@ describe('CreateAppModal', () => {
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'app.iconPicker.ok' }))
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ }))
|
||||
act(() => {
|
||||
vi.advanceTimersByTime(300)
|
||||
})
|
||||
@ -433,7 +433,7 @@ describe('CreateAppModal', () => {
|
||||
expect(screen.queryByRole('button', { name: 'app.iconPicker.cancel' })).not.toBeInTheDocument()
|
||||
|
||||
// Submit and verify the payload uses the original icon (cancel reverts to props)
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ }))
|
||||
act(() => {
|
||||
vi.advanceTimersByTime(300)
|
||||
})
|
||||
@ -471,7 +471,7 @@ describe('CreateAppModal', () => {
|
||||
appIconBackground: '#000000',
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ }))
|
||||
act(() => {
|
||||
vi.advanceTimersByTime(300)
|
||||
})
|
||||
@ -495,7 +495,7 @@ describe('CreateAppModal', () => {
|
||||
const { onConfirm } = setup({ appDescription: 'Old description' })
|
||||
|
||||
fireEvent.change(screen.getByPlaceholderText('app.newApp.appDescriptionPlaceholder'), { target: { value: 'Updated description' } })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ }))
|
||||
act(() => {
|
||||
vi.advanceTimersByTime(300)
|
||||
})
|
||||
@ -512,7 +512,7 @@ describe('CreateAppModal', () => {
|
||||
appIconBackground: null,
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ }))
|
||||
act(() => {
|
||||
vi.advanceTimersByTime(300)
|
||||
})
|
||||
@ -536,7 +536,7 @@ describe('CreateAppModal', () => {
|
||||
fireEvent.click(screen.getByRole('switch'))
|
||||
fireEvent.change(screen.getByRole('spinbutton'), { target: { value: '12' } })
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.save/ }))
|
||||
act(() => {
|
||||
vi.advanceTimersByTime(300)
|
||||
})
|
||||
@ -551,7 +551,7 @@ describe('CreateAppModal', () => {
|
||||
it('should omit max_active_requests when input is empty', () => {
|
||||
const { onConfirm } = setup({ isEditModal: true, max_active_requests: null })
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.save/ }))
|
||||
act(() => {
|
||||
vi.advanceTimersByTime(300)
|
||||
})
|
||||
@ -564,7 +564,7 @@ describe('CreateAppModal', () => {
|
||||
const { onConfirm } = setup({ isEditModal: true, max_active_requests: null })
|
||||
|
||||
fireEvent.change(screen.getByRole('spinbutton'), { target: { value: 'abc' } })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.save/ }))
|
||||
act(() => {
|
||||
vi.advanceTimersByTime(300)
|
||||
})
|
||||
@ -576,7 +576,7 @@ describe('CreateAppModal', () => {
|
||||
it('should show toast error and not submit when name becomes empty before debounced submit runs', () => {
|
||||
const { onConfirm, onHide } = setup({ appName: 'My App' })
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ }))
|
||||
fireEvent.change(screen.getByPlaceholderText('app.newApp.appNamePlaceholder'), { target: { value: ' ' } })
|
||||
|
||||
act(() => {
|
||||
|
||||
@ -16,6 +16,14 @@ vi.mock('react-i18next', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/config', async (importOriginal) => {
|
||||
const actual = await importOriginal() as object
|
||||
return {
|
||||
...actual,
|
||||
IS_CLOUD_EDITION: true,
|
||||
}
|
||||
})
|
||||
|
||||
const mockUseGetTryAppInfo = vi.fn()
|
||||
|
||||
vi.mock('@/service/use-try-app', () => ({
|
||||
|
||||
@ -14,6 +14,14 @@ vi.mock('react-i18next', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/config', async (importOriginal) => {
|
||||
const actual = await importOriginal() as object
|
||||
return {
|
||||
...actual,
|
||||
IS_CLOUD_EDITION: true,
|
||||
}
|
||||
})
|
||||
|
||||
describe('Tab', () => {
|
||||
afterEach(() => {
|
||||
cleanup()
|
||||
|
||||
@ -81,4 +81,205 @@ describe('CommandSelector', () => {
|
||||
|
||||
expect(onSelect).toHaveBeenCalledWith('/zen')
|
||||
})
|
||||
|
||||
it('should show all slash commands when no filter provided', () => {
|
||||
const actions = createActions()
|
||||
const onSelect = vi.fn()
|
||||
|
||||
render(
|
||||
<Command>
|
||||
<CommandSelector
|
||||
actions={actions}
|
||||
onCommandSelect={onSelect}
|
||||
searchFilter=""
|
||||
originalQuery="/"
|
||||
/>
|
||||
</Command>,
|
||||
)
|
||||
|
||||
// Should show the zen command from mock
|
||||
expect(screen.getByText('/zen')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should exclude slash action when in @ mode', () => {
|
||||
const actions = {
|
||||
...createActions(),
|
||||
slash: {
|
||||
key: '/',
|
||||
shortcut: '/',
|
||||
title: 'Slash',
|
||||
search: vi.fn(),
|
||||
description: '',
|
||||
} as ActionItem,
|
||||
}
|
||||
const onSelect = vi.fn()
|
||||
|
||||
render(
|
||||
<Command>
|
||||
<CommandSelector
|
||||
actions={actions}
|
||||
onCommandSelect={onSelect}
|
||||
searchFilter=""
|
||||
originalQuery="@"
|
||||
/>
|
||||
</Command>,
|
||||
)
|
||||
|
||||
// Should show @ commands but not /
|
||||
expect(screen.getByText('@app')).toBeInTheDocument()
|
||||
expect(screen.queryByText('/')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show all actions when no filter in @ mode', () => {
|
||||
const actions = createActions()
|
||||
const onSelect = vi.fn()
|
||||
|
||||
render(
|
||||
<Command>
|
||||
<CommandSelector
|
||||
actions={actions}
|
||||
onCommandSelect={onSelect}
|
||||
searchFilter=""
|
||||
originalQuery="@"
|
||||
/>
|
||||
</Command>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('@app')).toBeInTheDocument()
|
||||
expect(screen.getByText('@plugin')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should set default command value when items exist but value does not', () => {
|
||||
const actions = createActions()
|
||||
const onSelect = vi.fn()
|
||||
const onCommandValueChange = vi.fn()
|
||||
|
||||
render(
|
||||
<Command>
|
||||
<CommandSelector
|
||||
actions={actions}
|
||||
onCommandSelect={onSelect}
|
||||
searchFilter=""
|
||||
originalQuery="@"
|
||||
commandValue="non-existent"
|
||||
onCommandValueChange={onCommandValueChange}
|
||||
/>
|
||||
</Command>,
|
||||
)
|
||||
|
||||
expect(onCommandValueChange).toHaveBeenCalledWith('@app')
|
||||
})
|
||||
|
||||
it('should NOT set command value when value already exists in items', () => {
|
||||
const actions = createActions()
|
||||
const onSelect = vi.fn()
|
||||
const onCommandValueChange = vi.fn()
|
||||
|
||||
render(
|
||||
<Command>
|
||||
<CommandSelector
|
||||
actions={actions}
|
||||
onCommandSelect={onSelect}
|
||||
searchFilter=""
|
||||
originalQuery="@"
|
||||
commandValue="@app"
|
||||
onCommandValueChange={onCommandValueChange}
|
||||
/>
|
||||
</Command>,
|
||||
)
|
||||
|
||||
expect(onCommandValueChange).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should show no matching commands message when filter has no results', () => {
|
||||
const actions = createActions()
|
||||
const onSelect = vi.fn()
|
||||
|
||||
render(
|
||||
<Command>
|
||||
<CommandSelector
|
||||
actions={actions}
|
||||
onCommandSelect={onSelect}
|
||||
searchFilter="nonexistent"
|
||||
originalQuery="@nonexistent"
|
||||
/>
|
||||
</Command>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument()
|
||||
expect(screen.getByText('app.gotoAnything.tryDifferentSearch')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show no matching commands for slash mode with no results', () => {
|
||||
const actions = createActions()
|
||||
const onSelect = vi.fn()
|
||||
|
||||
render(
|
||||
<Command>
|
||||
<CommandSelector
|
||||
actions={actions}
|
||||
onCommandSelect={onSelect}
|
||||
searchFilter="nonexistentcommand"
|
||||
originalQuery="/nonexistentcommand"
|
||||
/>
|
||||
</Command>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render description for @ commands', () => {
|
||||
const actions = createActions()
|
||||
const onSelect = vi.fn()
|
||||
|
||||
render(
|
||||
<Command>
|
||||
<CommandSelector
|
||||
actions={actions}
|
||||
onCommandSelect={onSelect}
|
||||
searchFilter=""
|
||||
originalQuery="@"
|
||||
/>
|
||||
</Command>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.actions.searchApplicationsDesc')).toBeInTheDocument()
|
||||
expect(screen.getByText('app.gotoAnything.actions.searchPluginsDesc')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render group header for @ mode', () => {
|
||||
const actions = createActions()
|
||||
const onSelect = vi.fn()
|
||||
|
||||
render(
|
||||
<Command>
|
||||
<CommandSelector
|
||||
actions={actions}
|
||||
onCommandSelect={onSelect}
|
||||
searchFilter=""
|
||||
originalQuery="@"
|
||||
/>
|
||||
</Command>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.selectSearchType')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render group header for slash mode', () => {
|
||||
const actions = createActions()
|
||||
const onSelect = vi.fn()
|
||||
|
||||
render(
|
||||
<Command>
|
||||
<CommandSelector
|
||||
actions={actions}
|
||||
onCommandSelect={onSelect}
|
||||
searchFilter=""
|
||||
originalQuery="/"
|
||||
/>
|
||||
</Command>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.groups.commands')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
157
web/app/components/goto-anything/components/empty-state.spec.tsx
Normal file
157
web/app/components/goto-anything/components/empty-state.spec.tsx
Normal file
@ -0,0 +1,157 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import EmptyState from './empty-state'
|
||||
|
||||
vi.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string, options?: { ns?: string, shortcuts?: string }) => {
|
||||
if (options?.shortcuts !== undefined)
|
||||
return `${key}:${options.shortcuts}`
|
||||
return `${options?.ns || 'common'}.${key}`
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
describe('EmptyState', () => {
|
||||
describe('loading variant', () => {
|
||||
it('should render loading spinner', () => {
|
||||
render(<EmptyState variant="loading" />)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.searching')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should have spinner animation class', () => {
|
||||
const { container } = render(<EmptyState variant="loading" />)
|
||||
|
||||
const spinner = container.querySelector('.animate-spin')
|
||||
expect(spinner).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('error variant', () => {
|
||||
it('should render error message when error has message', () => {
|
||||
const error = new Error('Connection failed')
|
||||
render(<EmptyState variant="error" error={error} />)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.searchFailed')).toBeInTheDocument()
|
||||
expect(screen.getByText('Connection failed')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render generic error when error has no message', () => {
|
||||
render(<EmptyState variant="error" error={null} />)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.searchTemporarilyUnavailable')).toBeInTheDocument()
|
||||
expect(screen.getByText('app.gotoAnything.servicesUnavailableMessage')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render generic error when error is undefined', () => {
|
||||
render(<EmptyState variant="error" />)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.searchTemporarilyUnavailable')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should have red error text styling', () => {
|
||||
const error = new Error('Test error')
|
||||
const { container } = render(<EmptyState variant="error" error={error} />)
|
||||
|
||||
const errorText = container.querySelector('.text-red-500')
|
||||
expect(errorText).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('default variant', () => {
|
||||
it('should render search title', () => {
|
||||
render(<EmptyState variant="default" />)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.searchTitle')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render all hint messages', () => {
|
||||
render(<EmptyState variant="default" />)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.searchHint')).toBeInTheDocument()
|
||||
expect(screen.getByText('app.gotoAnything.commandHint')).toBeInTheDocument()
|
||||
expect(screen.getByText('app.gotoAnything.slashHint')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('no-results variant', () => {
|
||||
describe('general search mode', () => {
|
||||
it('should render generic no results message', () => {
|
||||
render(<EmptyState variant="no-results" searchMode="general" />)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.noResults')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show specific search hint with shortcuts', () => {
|
||||
const Actions = {
|
||||
app: { key: '@app', shortcut: '@app' },
|
||||
plugin: { key: '@plugin', shortcut: '@plugin' },
|
||||
} as unknown as Record<string, import('../actions/types').ActionItem>
|
||||
render(<EmptyState variant="no-results" searchMode="general" Actions={Actions} />)
|
||||
|
||||
expect(screen.getByText('gotoAnything.emptyState.trySpecificSearch:@app, @plugin')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('app search mode', () => {
|
||||
it('should render no apps found message', () => {
|
||||
render(<EmptyState variant="no-results" searchMode="@app" />)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.emptyState.noAppsFound')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show try different term hint', () => {
|
||||
render(<EmptyState variant="no-results" searchMode="@app" />)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.emptyState.tryDifferentTerm')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('plugin search mode', () => {
|
||||
it('should render no plugins found message', () => {
|
||||
render(<EmptyState variant="no-results" searchMode="@plugin" />)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.emptyState.noPluginsFound')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('knowledge search mode', () => {
|
||||
it('should render no knowledge bases found message', () => {
|
||||
render(<EmptyState variant="no-results" searchMode="@knowledge" />)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.emptyState.noKnowledgeBasesFound')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('node search mode', () => {
|
||||
it('should render no workflow nodes found message', () => {
|
||||
render(<EmptyState variant="no-results" searchMode="@node" />)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.emptyState.noWorkflowNodesFound')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('unknown search mode', () => {
|
||||
it('should fallback to generic no results message', () => {
|
||||
render(<EmptyState variant="no-results" searchMode="@unknown" />)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.noResults')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('default props', () => {
|
||||
it('should use general as default searchMode', () => {
|
||||
render(<EmptyState variant="no-results" />)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.noResults')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should use empty object as default Actions', () => {
|
||||
render(<EmptyState variant="no-results" searchMode="general" />)
|
||||
|
||||
// Should show empty shortcuts
|
||||
expect(screen.getByText('gotoAnything.emptyState.trySpecificSearch:')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user