Merge branch 'main' into jzh

JzoNg 2026-02-03 14:42:30 +08:00
commit 0187838a54
152 changed files with 7420 additions and 2686 deletions

.github/CODEOWNERS (vendored, 3 lines changed)
View File

@@ -9,6 +9,9 @@
# CODEOWNERS file
/.github/CODEOWNERS @laipz8200 @crazywoola
# Agents
/.agents/skills/ @hyoban
# Docs
/docs/ @crazywoola

View File

@@ -53,6 +53,7 @@ select = [
"S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
"S302", # suspicious-marshal-usage, disallow use of `marshal` module
"S311", # suspicious-non-cryptographic-random-usage,
"TID", # flake8-tidy-imports
]
@@ -88,6 +89,7 @@ ignore = [
"SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false
"TID252", # allow relative imports from parent modules
]
[lint.per-file-ignores]
@@ -109,10 +111,20 @@ ignore = [
"S110", # allow ignoring exceptions in tests code (currently)
]
"controllers/console/explore/trial.py" = ["TID251"]
"controllers/console/human_input_form.py" = ["TID251"]
"controllers/web/human_input_form.py" = ["TID251"]
[lint.pyflakes]
allowed-unused-imports = [
"_pytest.monkeypatch",
"tests.integration_tests",
"tests.unit_tests",
]
[lint.flake8-tidy-imports]
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"]
msg = "Use Pydantic payload/query models instead of reqparse."
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse.RequestParser"]
msg = "Use Pydantic payload/query models instead of reqparse."

View File

@@ -1,4 +1,12 @@
from __future__ import annotations
import sys
from typing import TYPE_CHECKING, cast
if TYPE_CHECKING:
from celery import Celery
celery: Celery
def is_db_command() -> bool:
@@ -23,7 +31,7 @@ else:
from app_factory import create_app
app = create_app()
celery = app.extensions["celery"]
celery = cast("Celery", app.extensions["celery"])
if __name__ == "__main__":
app.run(host="0.0.0.0", port=5001)
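The string form of the cast keeps celery strictly a type-checking dependency; a standalone sketch of the pattern:

from __future__ import annotations

from typing import TYPE_CHECKING, Any, cast

if TYPE_CHECKING:
    from celery import Celery  # imported only when a type checker runs

def get_celery(extensions: dict[str, Any]) -> "Celery":
    # cast() is a no-op at runtime, and quoting "Celery" means the name
    # never has to exist outside the TYPE_CHECKING block.
    return cast("Celery", extensions["celery"])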

View File

@@ -149,7 +149,7 @@ def initialize_extensions(app: DifyApp):
logger.info("Loaded %s (%s ms)", short_name, round((end_time - start_time) * 1000, 2))
def create_migrations_app():
def create_migrations_app() -> DifyApp:
app = create_flask_app_with_configs()
from extensions import ext_database, ext_migrate

View File

@@ -1450,54 +1450,58 @@ def clear_orphaned_file_records(force: bool):
all_ids_in_tables = []
for ids_table in ids_tables:
query = ""
if ids_table["type"] == "uuid":
click.echo(
click.style(
f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}", fg="white"
match ids_table["type"]:
case "uuid":
click.echo(
click.style(
f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}",
fg="white",
)
)
)
query = (
f"SELECT {ids_table['column']} FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for i in rs:
all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])})
elif ids_table["type"] == "text":
click.echo(
click.style(
f"- Listing file-id-like strings in column {ids_table['column']} in table {ids_table['table']}",
fg="white",
c = ids_table["column"]
query = f"SELECT {c} FROM {ids_table['table']} WHERE {c} IS NOT NULL"
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for i in rs:
all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])})
case "text":
t = ids_table["table"]
click.echo(
click.style(
f"- Listing file-id-like strings in column {ids_table['column']} in table {t}",
fg="white",
)
)
)
query = (
f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id "
f"FROM {ids_table['table']}"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for i in rs:
for j in i[0]:
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
elif ids_table["type"] == "json":
click.echo(
click.style(
(
f"- Listing file-id-like JSON string in column {ids_table['column']} "
f"in table {ids_table['table']}"
),
fg="white",
query = (
f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id "
f"FROM {ids_table['table']}"
)
)
query = (
f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id "
f"FROM {ids_table['table']}"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for i in rs:
for j in i[0]:
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for i in rs:
for j in i[0]:
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
case "json":
click.echo(
click.style(
(
f"- Listing file-id-like JSON string in column {ids_table['column']} "
f"in table {ids_table['table']}"
),
fg="white",
)
)
query = (
f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id "
f"FROM {ids_table['table']}"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for i in rs:
for j in i[0]:
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
case _:
pass
click.echo(click.style(f"Found {len(all_ids_in_tables)} file ids in tables.", fg="white"))
except Exception as e:
@@ -1737,59 +1741,18 @@ def file_usage(
if src_filter != src:
continue
if ids_table["type"] == "uuid":
# Direct UUID match
query = (
f"SELECT {ids_table['pk_column']}, {ids_table['column']} "
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
record_id = str(row[0])
ref_file_id = str(row[1])
if ref_file_id not in file_key_map:
continue
storage_key = file_key_map[ref_file_id]
# Apply filters
if file_id and ref_file_id != file_id:
continue
if key and not storage_key.endswith(key):
continue
# Only collect items within the requested page range
if offset <= total_count < offset + limit:
paginated_usages.append(
{
"src": f"{ids_table['table']}.{ids_table['column']}",
"record_id": record_id,
"file_id": ref_file_id,
"key": storage_key,
}
)
total_count += 1
elif ids_table["type"] in ("text", "json"):
# Extract UUIDs from text/json content
column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"]
query = (
f"SELECT {ids_table['pk_column']}, {column_cast} "
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
record_id = str(row[0])
content = str(row[1])
# Find all UUIDs in the content
import re
uuid_pattern = re.compile(guid_regexp, re.IGNORECASE)
matches = uuid_pattern.findall(content)
for ref_file_id in matches:
match ids_table["type"]:
case "uuid":
# Direct UUID match
query = (
f"SELECT {ids_table['pk_column']}, {ids_table['column']} "
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
record_id = str(row[0])
ref_file_id = str(row[1])
if ref_file_id not in file_key_map:
continue
storage_key = file_key_map[ref_file_id]
@@ -1812,6 +1775,50 @@ def file_usage(
)
total_count += 1
case "text" | "json":
# Extract UUIDs from text/json content
column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"]
query = (
f"SELECT {ids_table['pk_column']}, {column_cast} "
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
record_id = str(row[0])
content = str(row[1])
# Find all UUIDs in the content
import re
uuid_pattern = re.compile(guid_regexp, re.IGNORECASE)
matches = uuid_pattern.findall(content)
for ref_file_id in matches:
if ref_file_id not in file_key_map:
continue
storage_key = file_key_map[ref_file_id]
# Apply filters
if file_id and ref_file_id != file_id:
continue
if key and not storage_key.endswith(key):
continue
# Only collect items within the requested page range
if offset <= total_count < offset + limit:
paginated_usages.append(
{
"src": f"{ids_table['table']}.{ids_table['column']}",
"record_id": record_id,
"file_id": ref_file_id,
"key": storage_key,
}
)
total_count += 1
case _:
pass
# Output results
if output_json:
result = {
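
A runnable miniature of the match/case dispatch these hunks introduce (the table descriptors and the regexp are illustrative, echoing the ids_tables entries above):

ids_tables = [
    {"type": "uuid", "table": "message_files", "column": "upload_file_id"},
    {"type": "text", "table": "messages", "column": "answer"},
]
guid_regexp = "[0-9a-f]{8}-([0-9a-f]{4}-){3}[0-9a-f]{12}"  # illustrative

for ids_table in ids_tables:
    # Dispatch on the discriminator field instead of chained if/elif.
    match ids_table["type"]:
        case "uuid":
            c = ids_table["column"]
            query = f"SELECT {c} FROM {ids_table['table']} WHERE {c} IS NOT NULL"
        case "text" | "json":
            query = (
                f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') "
                f"FROM {ids_table['table']}"
            )
        case _:
            query = ""
    print(query)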

View File

@@ -1,10 +1,11 @@
from typing import Any, Literal
from flask import abort, make_response, request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
@@ -16,9 +17,11 @@ from controllers.console.wraps import (
)
from extensions.ext_redis import redis_client
from fields.annotation_fields import (
annotation_fields,
annotation_hit_history_fields,
build_annotation_model,
Annotation,
AnnotationExportList,
AnnotationHitHistory,
AnnotationHitHistoryList,
AnnotationList,
)
from libs.helper import uuid_value
from libs.login import login_required
@@ -89,6 +92,14 @@ reg(CreateAnnotationPayload)
reg(UpdateAnnotationPayload)
reg(AnnotationReplyStatusQuery)
reg(AnnotationFilePayload)
register_schema_models(
console_ns,
Annotation,
AnnotationList,
AnnotationExportList,
AnnotationHitHistory,
AnnotationHitHistoryList,
)
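
register_schema_models shows up throughout this commit; a plausible reading of the helper, inferred from the explicit schema_model calls it replaces elsewhere in the diff (the real implementation in controllers.common.schema may differ):

from flask_restx import Namespace
from pydantic import BaseModel

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"

def register_schema_models(ns: Namespace, *models: type[BaseModel]) -> None:
    # Mirrors the console_ns.schema_model(...) pattern used in audio.py and
    # forgot_password.py before this refactor.
    for cls in models:
        ns.schema_model(
            cls.__name__,
            cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
        )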
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
@@ -107,10 +118,11 @@ class AnnotationReplyActionApi(Resource):
def post(self, app_id, action: Literal["enable", "disable"]):
app_id = str(app_id)
args = AnnotationReplyPayload.model_validate(console_ns.payload)
if action == "enable":
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_id)
match action:
case "enable":
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
case "disable":
result = AppAnnotationService.disable_app_annotation(app_id)
return result, 200
@@ -201,33 +213,33 @@ class AnnotationApi(Resource):
app_id = str(app_id)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
response = {
"data": marshal(annotation_list, annotation_fields),
"has_more": len(annotation_list) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response, 200
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
response = AnnotationList(
data=annotation_models,
has_more=len(annotation_list) == limit,
limit=limit,
total=total,
page=page,
)
return response.model_dump(mode="json"), 200
@console_ns.doc("create_annotation")
@console_ns.doc(description="Create a new annotation for an app")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__])
@console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns))
@console_ns.response(201, "Annotation created successfully", console_ns.models[Annotation.__name__])
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@marshal_with(annotation_fields)
@edit_permission_required
def post(self, app_id):
app_id = str(app_id)
args = CreateAnnotationPayload.model_validate(console_ns.payload)
data = args.model_dump(exclude_none=True)
annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
return annotation
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
@setup_required
@login_required
@@ -264,7 +276,7 @@ class AnnotationExportApi(Resource):
@console_ns.response(
200,
"Annotations exported successfully",
console_ns.model("AnnotationList", {"data": fields.List(fields.Nested(build_annotation_model(console_ns)))}),
console_ns.models[AnnotationExportList.__name__],
)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@@ -274,7 +286,8 @@ class AnnotationExportApi(Resource):
def get(self, app_id):
app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
response_data = {"data": marshal(annotation_list, annotation_fields)}
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
response_data = AnnotationExportList(data=annotation_models).model_dump(mode="json")
# Create response with secure headers for CSV export
response = make_response(response_data, 200)
@@ -289,7 +302,7 @@ class AnnotationUpdateDeleteApi(Resource):
@console_ns.doc("update_delete_annotation")
@console_ns.doc(description="Update or delete an annotation")
@console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
@console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns))
@console_ns.response(200, "Annotation updated successfully", console_ns.models[Annotation.__name__])
@console_ns.response(204, "Annotation deleted successfully")
@console_ns.response(403, "Insufficient permissions")
@console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__])
@@ -298,7 +311,6 @@ class AnnotationUpdateDeleteApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
@marshal_with(annotation_fields)
def post(self, app_id, annotation_id):
app_id = str(app_id)
annotation_id = str(annotation_id)
@@ -306,7 +318,7 @@ class AnnotationUpdateDeleteApi(Resource):
annotation = AppAnnotationService.update_app_annotation_directly(
args.model_dump(exclude_none=True), app_id, annotation_id
)
return annotation
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
@setup_required
@login_required
@@ -414,14 +426,7 @@ class AnnotationHitHistoryListApi(Resource):
@console_ns.response(
200,
"Hit histories retrieved successfully",
console_ns.model(
"AnnotationHitHistoryList",
{
"data": fields.List(
fields.Nested(console_ns.model("AnnotationHitHistoryItem", annotation_hit_history_fields))
)
},
),
console_ns.models[AnnotationHitHistoryList.__name__],
)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@@ -436,11 +441,14 @@ class AnnotationHitHistoryListApi(Resource):
annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(
app_id, annotation_id, page, limit
)
response = {
"data": marshal(annotation_hit_history_list, annotation_hit_history_fields),
"has_more": len(annotation_hit_history_list) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response
history_models = TypeAdapter(list[AnnotationHitHistory]).validate_python(
annotation_hit_history_list, from_attributes=True
)
response = AnnotationHitHistoryList(
data=history_models,
has_more=len(annotation_hit_history_list) == limit,
limit=limit,
total=total,
page=page,
)
return response.model_dump(mode="json")
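
The TypeAdapter pattern above turns SQLAlchemy rows directly into response models; a self-contained sketch with illustrative fields:

from pydantic import BaseModel, TypeAdapter

class Annotation(BaseModel):
    id: str
    question: str
    content: str

class Row:  # stands in for a SQLAlchemy ORM object
    def __init__(self, id: str, question: str, content: str):
        self.id, self.question, self.content = id, question, content

rows = [Row("a1", "What is Dify?", "An LLM app platform")]
# from_attributes=True lets validation read attributes instead of dict keys.
items = TypeAdapter(list[Annotation]).validate_python(rows, from_attributes=True)
print(items[0].model_dump(mode="json"))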

View File

@@ -6,6 +6,7 @@ from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError
import services
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import (
AppUnavailableError,
@@ -33,7 +34,6 @@ from services.errors.audio import (
)
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class TextToSpeechPayload(BaseModel):
@@ -47,13 +47,11 @@ class TextToSpeechVoiceQuery(BaseModel):
language: str = Field(..., description="Language code")
console_ns.schema_model(
TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
TextToSpeechVoiceQuery.__name__,
TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
class AudioTranscriptResponse(BaseModel):
text: str = Field(description="Transcribed text from audio")
register_schema_models(console_ns, AudioTranscriptResponse, TextToSpeechPayload, TextToSpeechVoiceQuery)
@console_ns.route("/apps/<uuid:app_id>/audio-to-text")
@@ -64,7 +62,7 @@ class ChatMessageAudioApi(Resource):
@console_ns.response(
200,
"Audio transcription successful",
console_ns.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}),
console_ns.models[AudioTranscriptResponse.__name__],
)
@console_ns.response(400, "Bad request - No audio uploaded or unsupported type")
@console_ns.response(413, "Audio file too large")

View File

@@ -508,16 +508,19 @@ class ChatConversationApi(Resource):
case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at <= end_datetime_utc)
if args.annotation_status == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
elif args.annotation_status == "not_annotated":
query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(MessageAnnotation.id) == 0)
)
match args.annotation_status:
case "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
case "not_annotated":
query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(MessageAnnotation.id) == 0)
)
case "all":
pass
if app_model.mode == AppMode.ADVANCED_CHAT:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)

View File

@@ -1,5 +1,4 @@
from collections.abc import Sequence
from typing import Any
from flask_restx import Resource
from pydantic import BaseModel, Field
@@ -12,10 +11,12 @@ from controllers.console.app.error import (
ProviderQuotaExceededError,
)
from controllers.console.wraps import account_initialization_required, setup_required
from core.app.app_config.entities import ModelConfig
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
@@ -26,28 +27,13 @@ from services.workflow_service import WorkflowService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class RuleGeneratePayload(BaseModel):
instruction: str = Field(..., description="Rule generation instruction")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
no_variable: bool = Field(default=False, description="Whether to exclude variables")
class RuleCodeGeneratePayload(RuleGeneratePayload):
code_language: str = Field(default="javascript", description="Programming language for code generation")
class RuleStructuredOutputPayload(BaseModel):
instruction: str = Field(..., description="Structured output generation instruction")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
class InstructionGeneratePayload(BaseModel):
flow_id: str = Field(..., description="Workflow/Flow ID")
node_id: str = Field(default="", description="Node ID for workflow context")
current: str = Field(default="", description="Current instruction text")
language: str = Field(default="javascript", description="Programming language (javascript/python)")
instruction: str = Field(..., description="Instruction for generation")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
ideal_output: str = Field(default="", description="Expected ideal output")
@@ -64,6 +50,7 @@ reg(RuleCodeGeneratePayload)
reg(RuleStructuredOutputPayload)
reg(InstructionGeneratePayload)
reg(InstructionTemplatePayload)
reg(ModelConfig)
@console_ns.route("/rule-generate")
@@ -82,12 +69,7 @@ class RuleGenerateApi(Resource):
_, current_tenant_id = current_account_with_tenant()
try:
rules = LLMGenerator.generate_rule_config(
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=args.no_variable,
)
rules = LLMGenerator.generate_rule_config(tenant_id=current_tenant_id, args=args)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
@@ -118,9 +100,7 @@ class RuleCodeGenerateApi(Resource):
try:
code_result = LLMGenerator.generate_code(
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
code_language=args.code_language,
args=args,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -152,8 +132,7 @@ class RuleStructuredOutputGenerateApi(Resource):
try:
structured_output = LLMGenerator.generate_structured_output(
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
args=args,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -204,23 +183,29 @@ class InstructionGenerateApi(Resource):
case "llm":
return LLMGenerator.generate_rule_config(
current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=True,
args=RuleGeneratePayload(
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=True,
),
)
case "agent":
return LLMGenerator.generate_rule_config(
current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=True,
args=RuleGeneratePayload(
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=True,
),
)
case "code":
return LLMGenerator.generate_code(
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
code_language=args.language,
args=RuleCodeGeneratePayload(
instruction=args.instruction,
model_config=args.model_config_data,
code_language=args.language,
),
)
case _:
return {"error": f"invalid node type: {node_type}"}

View File

@@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, select
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import (
CompletionRequestError,
@@ -35,7 +36,6 @@ from services.errors.message import MessageNotExistsError, SuggestedQuestionsAft
from services.message_service import MessageService
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ChatMessagesQuery(BaseModel):
@@ -90,13 +90,22 @@ class FeedbackExportQuery(BaseModel):
raise ValueError("has_comment must be a boolean value")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
class AnnotationCountResponse(BaseModel):
count: int = Field(description="Number of annotations")
reg(ChatMessagesQuery)
reg(MessageFeedbackPayload)
reg(FeedbackExportQuery)
class SuggestedQuestionsResponse(BaseModel):
data: list[str] = Field(description="Suggested question")
register_schema_models(
console_ns,
ChatMessagesQuery,
MessageFeedbackPayload,
FeedbackExportQuery,
AnnotationCountResponse,
SuggestedQuestionsResponse,
)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@@ -231,7 +240,7 @@ class ChatMessageListApi(Resource):
@marshal_with(message_infinite_scroll_pagination_model)
@edit_permission_required
def get(self, app_model):
args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = ChatMessagesQuery.model_validate(request.args.to_dict())
conversation = (
db.session.query(Conversation)
@@ -356,7 +365,7 @@ class MessageAnnotationCountApi(Resource):
@console_ns.response(
200,
"Annotation count retrieved successfully",
console_ns.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}),
console_ns.models[AnnotationCountResponse.__name__],
)
@get_app_model
@setup_required
@@ -376,9 +385,7 @@ class MessageSuggestedQuestionApi(Resource):
@console_ns.response(
200,
"Suggested questions retrieved successfully",
console_ns.model(
"SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}
),
console_ns.models[SuggestedQuestionsResponse.__name__],
)
@console_ns.response(404, "Message or conversation not found")
@setup_required
@@ -428,7 +435,7 @@ class MessageFeedbackExportApi(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = FeedbackExportQuery.model_validate(request.args.to_dict())
# Import the service function
from services.feedback_service import FeedbackService

View File

@@ -2,9 +2,11 @@ import logging
import httpx
from flask import current_app, redirect, request
from flask_restx import Resource, fields
from flask_restx import Resource
from pydantic import BaseModel, Field
from configs import dify_config
from controllers.common.schema import register_schema_models
from libs.login import login_required
from libs.oauth_data_source import NotionOAuth
@@ -14,6 +16,26 @@ from ..wraps import account_initialization_required, is_admin_or_owner_required,
logger = logging.getLogger(__name__)
class OAuthDataSourceResponse(BaseModel):
data: str = Field(description="Authorization URL or 'internal' for internal setup")
class OAuthDataSourceBindingResponse(BaseModel):
result: str = Field(description="Operation result")
class OAuthDataSourceSyncResponse(BaseModel):
result: str = Field(description="Operation result")
register_schema_models(
console_ns,
OAuthDataSourceResponse,
OAuthDataSourceBindingResponse,
OAuthDataSourceSyncResponse,
)
def get_oauth_providers():
with current_app.app_context():
notion_oauth = NotionOAuth(
@@ -34,10 +56,7 @@ class OAuthDataSource(Resource):
@console_ns.response(
200,
"Authorization URL or internal setup success",
console_ns.model(
"OAuthDataSourceResponse",
{"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")},
),
console_ns.models[OAuthDataSourceResponse.__name__],
)
@console_ns.response(400, "Invalid provider")
@console_ns.response(403, "Admin privileges required")
@@ -101,7 +120,7 @@ class OAuthDataSourceBinding(Resource):
@console_ns.response(
200,
"Data source binding success",
console_ns.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}),
console_ns.models[OAuthDataSourceBindingResponse.__name__],
)
@console_ns.response(400, "Invalid provider or code")
def get(self, provider: str):
@@ -133,7 +152,7 @@ class OAuthDataSourceSync(Resource):
@console_ns.response(
200,
"Data source sync success",
console_ns.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}),
console_ns.models[OAuthDataSourceSyncResponse.__name__],
)
@console_ns.response(400, "Invalid provider or sync failed")
@setup_required

View File

@@ -2,10 +2,11 @@ import base64
import secrets
from flask import request
from flask_restx import Resource, fields
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.auth.error import (
EmailCodeError,
@@ -48,8 +49,31 @@ class ForgotPasswordResetPayload(BaseModel):
return valid_password(value)
for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
class ForgotPasswordEmailResponse(BaseModel):
result: str = Field(description="Operation result")
data: str | None = Field(default=None, description="Reset token")
code: str | None = Field(default=None, description="Error code if account not found")
class ForgotPasswordCheckResponse(BaseModel):
is_valid: bool = Field(description="Whether code is valid")
email: EmailStr = Field(description="Email address")
token: str = Field(description="New reset token")
class ForgotPasswordResetResponse(BaseModel):
result: str = Field(description="Operation result")
register_schema_models(
console_ns,
ForgotPasswordSendPayload,
ForgotPasswordCheckPayload,
ForgotPasswordResetPayload,
ForgotPasswordEmailResponse,
ForgotPasswordCheckResponse,
ForgotPasswordResetResponse,
)
@console_ns.route("/forgot-password")
@@ -60,14 +84,7 @@ class ForgotPasswordSendEmailApi(Resource):
@console_ns.response(
200,
"Email sent successfully",
console_ns.model(
"ForgotPasswordEmailResponse",
{
"result": fields.String(description="Operation result"),
"data": fields.String(description="Reset token"),
"code": fields.String(description="Error code if account not found"),
},
),
console_ns.models[ForgotPasswordEmailResponse.__name__],
)
@console_ns.response(400, "Invalid email or rate limit exceeded")
@setup_required
@@ -106,14 +123,7 @@ class ForgotPasswordCheckApi(Resource):
@console_ns.response(
200,
"Code verified successfully",
console_ns.model(
"ForgotPasswordCheckResponse",
{
"is_valid": fields.Boolean(description="Whether code is valid"),
"email": fields.String(description="Email address"),
"token": fields.String(description="New reset token"),
},
),
console_ns.models[ForgotPasswordCheckResponse.__name__],
)
@console_ns.response(400, "Invalid code or token")
@setup_required
@@ -163,7 +173,7 @@ class ForgotPasswordResetApi(Resource):
@console_ns.response(
200,
"Password reset successfully",
console_ns.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}),
console_ns.models[ForgotPasswordResetResponse.__name__],
)
@console_ns.response(400, "Invalid token or password mismatch")
@setup_required

View File

@@ -155,43 +155,43 @@ class OAuthServerUserTokenApi(Resource):
grant_type = OAuthGrantType(payload.grant_type)
except ValueError:
raise BadRequest("invalid grant_type")
match grant_type:
case OAuthGrantType.AUTHORIZATION_CODE:
if not payload.code:
raise BadRequest("code is required")
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
if not payload.code:
raise BadRequest("code is required")
if payload.client_secret != oauth_provider_app.client_secret:
raise BadRequest("client_secret is invalid")
if payload.client_secret != oauth_provider_app.client_secret:
raise BadRequest("client_secret is invalid")
if payload.redirect_uri not in oauth_provider_app.redirect_uris:
raise BadRequest("redirect_uri is invalid")
if payload.redirect_uri not in oauth_provider_app.redirect_uris:
raise BadRequest("redirect_uri is invalid")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, code=payload.code, client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
case OAuthGrantType.REFRESH_TOKEN:
if not payload.refresh_token:
raise BadRequest("refresh_token is required")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, code=payload.code, client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
if not payload.refresh_token:
raise BadRequest("refresh_token is required")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
@console_ns.route("/oauth/provider/account")

View File

@@ -1,6 +1,6 @@
import json
from collections.abc import Generator
from typing import Any, cast
from typing import Any, Literal, cast
from flask import request
from flask_restx import Resource, fields, marshal_with
@@ -157,9 +157,8 @@ class DataSourceApi(Resource):
@setup_required
@login_required
@account_initialization_required
def patch(self, binding_id, action):
def patch(self, binding_id, action: Literal["enable", "disable"]):
binding_id = str(binding_id)
action = str(action)
with Session(db.engine) as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).filter_by(id=binding_id)
@@ -167,23 +166,24 @@
if data_source_binding is None:
raise NotFound("Data source binding not found.")
# enable binding
if action == "enable":
if data_source_binding.disabled:
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.add(data_source_binding)
db.session.commit()
else:
raise ValueError("Data source is not disabled.")
# disable binding
if action == "disable":
if not data_source_binding.disabled:
data_source_binding.disabled = True
data_source_binding.updated_at = naive_utc_now()
db.session.add(data_source_binding)
db.session.commit()
else:
raise ValueError("Data source is disabled.")
match action:
case "enable":
if data_source_binding.disabled:
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.add(data_source_binding)
db.session.commit()
else:
raise ValueError("Data source is not disabled.")
# disable binding
case "disable":
if not data_source_binding.disabled:
data_source_binding.disabled = True
data_source_binding.updated_at = naive_utc_now()
db.session.add(data_source_binding)
db.session.commit()
else:
raise ValueError("Data source is disabled.")
return {"result": "success"}, 200

View File

@@ -576,63 +576,62 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
if document.indexing_status in {"completed", "error"}:
raise DocumentAlreadyFinishedError()
data_source_info = document.data_source_info_dict
match document.data_source_type:
case "upload_file":
if not data_source_info:
continue
file_id = data_source_info["upload_file_id"]
file_detail = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
.first()
)
if document.data_source_type == "upload_file":
if not data_source_info:
continue
file_id = data_source_info["upload_file_id"]
file_detail = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
.first()
)
if file_detail is None:
raise NotFound("File not found.")
if file_detail is None:
raise NotFound("File not found.")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
)
extract_settings.append(extract_setting)
case "notion_import":
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"tenant_id": current_tenant_id,
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)
case "website_crawl":
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"url": data_source_info["url"],
"tenant_id": current_tenant_id,
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
)
extract_settings.append(extract_setting)
elif document.data_source_type == "notion_import":
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"tenant_id": current_tenant_id,
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)
elif document.data_source_type == "website_crawl":
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"url": data_source_info["url"],
"tenant_id": current_tenant_id,
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)
else:
raise ValueError("Data source type not supported")
case _:
raise ValueError("Data source type not supported")
indexing_runner = IndexingRunner()
try:
response = indexing_runner.indexing_estimate(
@@ -954,23 +953,24 @@ class DocumentProcessingApi(DocumentResource):
if not current_user.is_dataset_editor:
raise Forbidden()
if action == "pause":
if document.indexing_status != "indexing":
raise InvalidActionError("Document not in indexing state.")
match action:
case "pause":
if document.indexing_status != "indexing":
raise InvalidActionError("Document not in indexing state.")
document.paused_by = current_user.id
document.paused_at = naive_utc_now()
document.is_paused = True
db.session.commit()
document.paused_by = current_user.id
document.paused_at = naive_utc_now()
document.is_paused = True
db.session.commit()
elif action == "resume":
if document.indexing_status not in {"paused", "error"}:
raise InvalidActionError("Document not in paused or error state.")
case "resume":
if document.indexing_status not in {"paused", "error"}:
raise InvalidActionError("Document not in paused or error state.")
document.paused_by = None
document.paused_at = None
document.is_paused = False
db.session.commit()
document.paused_by = None
document.paused_at = None
document.is_paused = False
db.session.commit()
return {"result": "success"}, 200
@@ -1339,6 +1339,18 @@ class DocumentGenerateSummaryApi(Resource):
missing_ids = set(document_list) - found_ids
raise NotFound(f"Some documents not found: {list(missing_ids)}")
# Update need_summary to True for documents that don't have it set
# This handles the case where documents were created when summary_index_setting was disabled
documents_to_update = [doc for doc in documents if not doc.need_summary and doc.doc_form != "qa_model"]
if documents_to_update:
document_ids_to_update = [str(doc.id) for doc in documents_to_update]
DocumentService.update_documents_need_summary(
dataset_id=dataset_id,
document_ids=document_ids_to_update,
need_summary=True,
)
# Dispatch async tasks for each document
for document in documents:
# Skip qa_model documents as they don't generate summaries
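
update_documents_need_summary is only invoked here; a plausible sketch of such a bulk flag update, assuming the Document model in models.dataset carries the need_summary column (the real DocumentService method may differ):

from sqlalchemy import update

from extensions.ext_database import db
from models.dataset import Document

def update_documents_need_summary(dataset_id: str, document_ids: list[str], need_summary: bool) -> None:
    # Set the flag in one bulk UPDATE rather than mutating rows one by one.
    db.session.execute(
        update(Document)
        .where(Document.dataset_id == dataset_id, Document.id.in_(document_ids))
        .values(need_summary=need_summary)
    )
    db.session.commit()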

View File

@@ -126,10 +126,11 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
if action == "enable":
MetadataService.enable_built_in_field(dataset)
elif action == "disable":
MetadataService.disable_built_in_field(dataset)
match action:
case "enable":
MetadataService.enable_built_in_field(dataset)
case "disable":
MetadataService.disable_built_in_field(dataset)
return {"result": "success"}, 200

View File

@@ -1,10 +1,9 @@
import json
import logging
from typing import Any, Literal, cast
from uuid import UUID
from flask import abort, request
from flask_restx import Resource, marshal_with, reqparse # type: ignore
from flask_restx import Resource, marshal_with # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@@ -38,7 +37,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from factories import variable_factory
from libs import helper
from libs.helper import TimestampField
from libs.helper import TimestampField, UUIDStrOrEmpty
from libs.login import current_account_with_tenant, current_user, login_required
from models import Account
from models.dataset import Pipeline
@@ -110,7 +109,7 @@ class NodeIdQuery(BaseModel):
class WorkflowRunQuery(BaseModel):
last_id: UUID | None = None
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
@@ -121,6 +120,10 @@ class DatasourceVariablesPayload(BaseModel):
start_node_title: str
class RagPipelineRecommendedPluginQuery(BaseModel):
type: str = "all"
register_schema_models(
console_ns,
DraftWorkflowSyncPayload,
@@ -135,6 +138,7 @@ register_schema_models(
NodeIdQuery,
WorkflowRunQuery,
DatasourceVariablesPayload,
RagPipelineRecommendedPluginQuery,
)
@@ -975,11 +979,8 @@ class RagPipelineRecommendedPluginApi(Resource):
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("type", type=str, location="args", required=False, default="all")
args = parser.parse_args()
type = args["type"]
query = RagPipelineRecommendedPluginQuery.model_validate(request.args.to_dict())
rag_pipeline_service = RagPipelineService()
recommended_plugins = rag_pipeline_service.get_recommended_plugins(type)
recommended_plugins = rag_pipeline_service.get_recommended_plugins(query.type)
return recommended_plugins

View File

@@ -9,7 +9,7 @@ import services
from controllers.common.fields import Parameters as ParametersResponse
from controllers.common.fields import Site as SiteResponse
from controllers.common.schema import get_or_create_model
from controllers.console import api, console_ns
from controllers.console import api
from controllers.console.app.error import (
AppUnavailableError,
AudioTooLargeError,
@@ -51,7 +51,7 @@ from fields.app_fields import (
tag_fields,
)
from fields.dataset_fields import dataset_fields
from fields.member_fields import build_simple_account_model
from fields.member_fields import simple_account_fields
from fields.workflow_fields import (
conversation_variable_fields,
pipeline_variable_fields,
@@ -103,7 +103,7 @@ app_detail_fields_with_site_copy["tags"] = fields.List(fields.Nested(tag_model))
app_detail_fields_with_site_copy["site"] = fields.Nested(site_model)
app_detail_with_site_model = get_or_create_model("TrialAppDetailWithSite", app_detail_fields_with_site_copy)
simple_account_model = build_simple_account_model(console_ns)
simple_account_model = get_or_create_model("SimpleAccount", simple_account_fields)
conversation_variable_model = get_or_create_model("TrialConversationVariable", conversation_variable_fields)
pipeline_variable_model = get_or_create_model("TrialPipelineVariable", pipeline_variable_fields)

View File

@@ -1,87 +1,74 @@
import os
from typing import Literal
from flask import session
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from controllers.fastopenapi import console_router
from extensions.ext_database import db
from models.model import DifySetup
from services.account_service import TenantService
from . import console_ns
from .error import AlreadySetupError, InitValidateFailedError
from .wraps import only_edition_self_hosted
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class InitValidatePayload(BaseModel):
password: str = Field(..., max_length=30)
password: str = Field(..., max_length=30, description="Initialization password")
console_ns.schema_model(
InitValidatePayload.__name__,
InitValidatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
class InitStatusResponse(BaseModel):
status: Literal["finished", "not_started"] = Field(..., description="Initialization status")
class InitValidateResponse(BaseModel):
result: str = Field(description="Operation result", examples=["success"])
@console_router.get(
"/init",
response_model=InitStatusResponse,
tags=["console"],
)
def get_init_status() -> InitStatusResponse:
"""Get initialization validation status."""
init_status = get_init_validate_status()
if init_status:
return InitStatusResponse(status="finished")
return InitStatusResponse(status="not_started")
@console_ns.route("/init")
class InitValidateAPI(Resource):
@console_ns.doc("get_init_status")
@console_ns.doc(description="Get initialization validation status")
@console_ns.response(
200,
"Success",
model=console_ns.model(
"InitStatusResponse",
{"status": fields.String(description="Initialization status", enum=["finished", "not_started"])},
),
)
def get(self):
"""Get initialization validation status"""
init_status = get_init_validate_status()
if init_status:
return {"status": "finished"}
return {"status": "not_started"}
@console_router.post(
"/init",
response_model=InitValidateResponse,
tags=["console"],
status_code=201,
)
@only_edition_self_hosted
def validate_init_password(payload: InitValidatePayload) -> InitValidateResponse:
"""Validate initialization password."""
tenant_count = TenantService.get_tenant_count()
if tenant_count > 0:
raise AlreadySetupError()
@console_ns.doc("validate_init_password")
@console_ns.doc(description="Validate initialization password for self-hosted edition")
@console_ns.expect(console_ns.models[InitValidatePayload.__name__])
@console_ns.response(
201,
"Success",
model=console_ns.model("InitValidateResponse", {"result": fields.String(description="Operation result")}),
)
@console_ns.response(400, "Already setup or validation failed")
@only_edition_self_hosted
def post(self):
"""Validate initialization password"""
# is tenant created
tenant_count = TenantService.get_tenant_count()
if tenant_count > 0:
raise AlreadySetupError()
if payload.password != os.environ.get("INIT_PASSWORD"):
session["is_init_validated"] = False
raise InitValidateFailedError()
payload = InitValidatePayload.model_validate(console_ns.payload)
input_password = payload.password
if input_password != os.environ.get("INIT_PASSWORD"):
session["is_init_validated"] = False
raise InitValidateFailedError()
session["is_init_validated"] = True
return {"result": "success"}, 201
session["is_init_validated"] = True
return InitValidateResponse(result="success")
def get_init_validate_status():
def get_init_validate_status() -> bool:
if dify_config.EDITION == "SELF_HOSTED":
if os.environ.get("INIT_PASSWORD"):
if session.get("is_init_validated"):
return True
with Session(db.engine) as db_session:
return db_session.execute(select(DifySetup)).scalar_one_or_none()
return db_session.execute(select(DifySetup)).scalar_one_or_none() is not None
return True
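
init.py, remote_files.py, and tags.py all migrate from flask_restx Resource classes to fastopenapi's console_router; the new endpoint shape in isolation (the path and response model are illustrative, assuming console_router exposes the get/post decorators used in the handlers above):

from pydantic import BaseModel, Field

from controllers.fastopenapi import console_router

class PingResponse(BaseModel):
    status: str = Field(description="Liveness status")

@console_router.get(
    "/ping",  # illustrative path, not an endpoint added by this commit
    response_model=PingResponse,
    tags=["console"],
)
def ping() -> PingResponse:
    # Handlers are plain functions returning Pydantic models; the router
    # serializes the model and documents it from response_model.
    return PingResponse(status="ok")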

View File

@@ -1,7 +1,6 @@
import urllib.parse
import httpx
from flask_restx import Resource
from pydantic import BaseModel, Field
import services
@@ -11,7 +10,7 @@ from controllers.common.errors import (
RemoteFileUploadError,
UnsupportedFileTypeError,
)
from controllers.common.schema import register_schema_models
from controllers.fastopenapi import console_router
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from extensions.ext_database import db
@@ -19,84 +18,74 @@ from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from libs.login import current_account_with_tenant
from services.file_service import FileService
from . import console_ns
register_schema_models(console_ns, RemoteFileInfo, FileWithSignedUrl)
@console_ns.route("/remote-files/<path:url>")
class RemoteFileInfoApi(Resource):
@console_ns.response(200, "Remote file info", console_ns.models[RemoteFileInfo.__name__])
def get(self, url):
decoded_url = urllib.parse.unquote(url)
resp = ssrf_proxy.head(decoded_url)
if resp.status_code != httpx.codes.OK:
# fall back to the GET method
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp.raise_for_status()
info = RemoteFileInfo(
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
file_length=int(resp.headers.get("Content-Length", 0)),
)
return info.model_dump(mode="json")
class RemoteFileUploadPayload(BaseModel):
url: str = Field(..., description="URL to fetch")
console_ns.schema_model(
RemoteFileUploadPayload.__name__,
RemoteFileUploadPayload.model_json_schema(ref_template="#/definitions/{model}"),
@console_router.get(
"/remote-files/<path:url>",
response_model=RemoteFileInfo,
tags=["console"],
)
def get_remote_file_info(url: str) -> RemoteFileInfo:
decoded_url = urllib.parse.unquote(url)
resp = ssrf_proxy.head(decoded_url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp.raise_for_status()
return RemoteFileInfo(
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
file_length=int(resp.headers.get("Content-Length", 0)),
)
@console_ns.route("/remote-files/upload")
class RemoteFileUploadApi(Resource):
@console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__])
@console_ns.response(201, "Remote file uploaded", console_ns.models[FileWithSignedUrl.__name__])
def post(self):
args = RemoteFileUploadPayload.model_validate(console_ns.payload)
url = args.url
@console_router.post(
"/remote-files/upload",
response_model=FileWithSignedUrl,
tags=["console"],
status_code=201,
)
def upload_remote_file(payload: RemoteFileUploadPayload) -> FileWithSignedUrl:
url = payload.url
try:
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
except httpx.RequestError as e:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")
try:
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
except httpx.RequestError as e:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")
file_info = helpers.guess_file_info_from_response(resp)
file_info = helpers.guess_file_info_from_response(resp)
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
raise FileTooLargeError
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
raise FileTooLargeError
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
try:
user, _ = current_account_with_tenant()
upload_file = FileService(db.engine).upload_file(
filename=file_info.filename,
content=content,
mimetype=file_info.mimetype,
user=user,
source_url=url,
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
payload = FileWithSignedUrl(
id=upload_file.id,
name=upload_file.name,
size=upload_file.size,
extension=upload_file.extension,
url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
mime_type=upload_file.mime_type,
created_by=upload_file.created_by,
created_at=int(upload_file.created_at.timestamp()),
try:
user, _ = current_account_with_tenant()
upload_file = FileService(db.engine).upload_file(
filename=file_info.filename,
content=content,
mimetype=file_info.mimetype,
user=user,
source_url=url,
)
return payload.model_dump(mode="json"), 201
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return FileWithSignedUrl(
id=upload_file.id,
name=upload_file.name,
size=upload_file.size,
extension=upload_file.extension,
url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
mime_type=upload_file.mime_type,
created_by=upload_file.created_by,
created_at=int(upload_file.created_at.timestamp()),
)

View File

@@ -1,14 +1,11 @@
from typing import Literal
from uuid import UUID
from flask import request
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from fields.tag_fields import dataset_tag_fields
from controllers.fastopenapi import console_router
from libs.login import current_account_with_tenant, login_required
from services.tag_service import TagService
@@ -35,115 +32,129 @@ class TagListQueryParam(BaseModel):
keyword: str | None = Field(None, description="Search keyword")
register_schema_models(
console_ns,
TagBasePayload,
TagBindingPayload,
TagBindingRemovePayload,
TagListQueryParam,
class TagResponse(BaseModel):
id: str = Field(description="Tag ID")
name: str = Field(description="Tag name")
type: str = Field(description="Tag type")
binding_count: int = Field(description="Number of bindings")
class TagBindingResult(BaseModel):
result: Literal["success"] = Field(description="Operation result", examples=["success"])
@console_router.get(
"/tags",
response_model=list[TagResponse],
tags=["console"],
)
@setup_required
@login_required
@account_initialization_required
def list_tags(query: TagListQueryParam) -> list[TagResponse]:
_, current_tenant_id = current_account_with_tenant()
tags = TagService.get_tags(query.type, current_tenant_id, query.keyword)
return [
TagResponse(
id=tag.id,
name=tag.name,
type=tag.type,
binding_count=int(tag.binding_count),
)
for tag in tags
]
@console_ns.route("/tags")
class TagListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.doc(
params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
)
@marshal_with(dataset_tag_fields)
def get(self):
_, current_tenant_id = current_account_with_tenant()
raw_args = request.args.to_dict()
param = TagListQueryParam.model_validate(raw_args)
tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
@console_router.post(
"/tags",
response_model=TagResponse,
tags=["console"],
)
@setup_required
@login_required
@account_initialization_required
def create_tag(payload: TagBasePayload) -> TagResponse:
current_user, _ = current_account_with_tenant()
# The role of the current user in the tag table must be admin, owner, or editor
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
return tags, 200
tag = TagService.save_tags(payload.model_dump())
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
# The role of the current user in the tag table must be admin, owner, or editor
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.save_tags(payload.model_dump())
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
return response, 200
return TagResponse(id=tag.id, name=tag.name, type=tag.type, binding_count=0)
@console_ns.route("/tags/<uuid:tag_id>")
class TagUpdateDeleteApi(Resource):
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def patch(self, tag_id):
current_user, _ = current_account_with_tenant()
tag_id = str(tag_id)
# The role of the current user in the tag table must be admin, owner, or editor
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
@console_router.patch(
"/tags/<uuid:tag_id>",
response_model=TagResponse,
tags=["console"],
)
@setup_required
@login_required
@account_initialization_required
def update_tag(tag_id: UUID, payload: TagBasePayload) -> TagResponse:
current_user, _ = current_account_with_tenant()
tag_id_str = str(tag_id)
# The role of the current user in the tag table must be admin, owner, or editor
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.update_tags(payload.model_dump(), tag_id)
tag = TagService.update_tags(payload.model_dump(), tag_id_str)
binding_count = TagService.get_tag_binding_count(tag_id)
binding_count = TagService.get_tag_binding_count(tag_id_str)
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
return response, 200
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, tag_id):
tag_id = str(tag_id)
TagService.delete_tag(tag_id)
return 204
return TagResponse(id=tag.id, name=tag.name, type=tag.type, binding_count=binding_count)
@console_ns.route("/tag-bindings/create")
class TagBindingCreateApi(Resource):
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
# The role of the current user in the tag table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
@console_router.delete(
"/tags/<uuid:tag_id>",
tags=["console"],
status_code=204,
)
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete_tag(tag_id: UUID) -> None:
tag_id_str = str(tag_id)
payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding(payload.model_dump())
return {"result": "success"}, 200
TagService.delete_tag(tag_id_str)
@console_ns.route("/tag-bindings/remove")
class TagBindingDeleteApi(Resource):
@console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
# The role of the current user in the tag table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
@console_router.post(
"/tag-bindings/create",
response_model=TagBindingResult,
tags=["console"],
)
@setup_required
@login_required
@account_initialization_required
def create_tag_binding(payload: TagBindingPayload) -> TagBindingResult:
current_user, _ = current_account_with_tenant()
# The role of the current user in the tag table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(payload.model_dump())
TagService.save_tag_binding(payload.model_dump())
return {"result": "success"}, 200
return TagBindingResult(result="success")
@console_router.post(
"/tag-bindings/remove",
response_model=TagBindingResult,
tags=["console"],
)
@setup_required
@login_required
@account_initialization_required
def delete_tag_binding(payload: TagBindingRemovePayload) -> TagBindingResult:
current_user, _ = current_account_with_tenant()
# The role of the current user in the tag table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
TagService.delete_tag_binding(payload.model_dump())
return TagBindingResult(result="success")
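
The migration above trades flask_restx Resource classes for plain functions registered on console_router, with Pydantic models driving request binding, response serialization, and the OpenAPI schema. A minimal sketch of the same idiom, assuming a FastOpenAPI-style router that injects validated models from type annotations (all names below are illustrative, not the project's routes):

from pydantic import BaseModel, Field

class WidgetQuery(BaseModel):
    keyword: str | None = Field(None, description="Search keyword")

class WidgetResponse(BaseModel):
    id: str
    name: str

# With the real router this would be decorated, e.g.:
# @console_router.get("/widgets", response_model=list[WidgetResponse], tags=["console"])
def list_widgets(query: WidgetQuery) -> list[WidgetResponse]:
    # The router validates ?keyword=... into WidgetQuery before calling us,
    # and serializes the returned models according to response_model.
    return [WidgetResponse(id="1", name=query.keyword or "untitled")]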

View File

@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from constants.languages import supported_language
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.auth.error import (
EmailAlreadyInUseError,
@ -37,7 +38,7 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
from fields.member_fields import account_fields
from fields.member_fields import Account as AccountResponse
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required
@ -170,6 +171,12 @@ reg(ChangeEmailSendPayload)
reg(ChangeEmailValidityPayload)
reg(ChangeEmailResetPayload)
reg(CheckEmailUniquePayload)
register_schema_models(console_ns, AccountResponse)
def _serialize_account(account) -> dict:
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")
integrate_fields = {
"provider": fields.String,
@ -236,11 +243,11 @@ class AccountProfileApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
@enterprise_license_required
def get(self):
current_user, _ = current_account_with_tenant()
return current_user
return _serialize_account(current_user)
@console_ns.route("/account/name")
@ -249,14 +256,14 @@ class AccountNameApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
args = AccountNamePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, name=args.name)
return updated_account
return _serialize_account(updated_account)
@console_ns.route("/account/avatar")
@ -265,7 +272,7 @@ class AccountAvatarApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
@ -273,7 +280,7 @@ class AccountAvatarApi(Resource):
updated_account = AccountService.update_account(current_user, avatar=args.avatar)
return updated_account
return _serialize_account(updated_account)
@console_ns.route("/account/interface-language")
@ -282,7 +289,7 @@ class AccountInterfaceLanguageApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
@ -290,7 +297,7 @@ class AccountInterfaceLanguageApi(Resource):
updated_account = AccountService.update_account(current_user, interface_language=args.interface_language)
return updated_account
return _serialize_account(updated_account)
@console_ns.route("/account/interface-theme")
@ -299,7 +306,7 @@ class AccountInterfaceThemeApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
@ -307,7 +314,7 @@ class AccountInterfaceThemeApi(Resource):
updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme)
return updated_account
return _serialize_account(updated_account)
@console_ns.route("/account/timezone")
@ -316,7 +323,7 @@ class AccountTimezoneApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
@ -324,7 +331,7 @@ class AccountTimezoneApi(Resource):
updated_account = AccountService.update_account(current_user, timezone=args.timezone)
return updated_account
return _serialize_account(updated_account)
@console_ns.route("/account/password")
@ -333,7 +340,7 @@ class AccountPasswordApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
@ -344,7 +351,7 @@ class AccountPasswordApi(Resource):
except ServiceCurrentPasswordIncorrectError:
raise CurrentPasswordIncorrectError()
return {"result": "success"}
return _serialize_account(current_user)
@console_ns.route("/account/integrates")
@ -620,7 +627,7 @@ class ChangeEmailResetApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
payload = console_ns.payload or {}
args = ChangeEmailResetPayload.model_validate(payload)
@ -649,7 +656,7 @@ class ChangeEmailResetApi(Resource):
email=normalized_new_email,
)
return updated_account
return _serialize_account(updated_account)
@console_ns.route("/account/change-email/check-email-unique")

View File

@ -1,9 +1,10 @@
from typing import Any
from flask import request
from flask_restx import Resource, fields
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
@ -38,15 +39,53 @@ class EndpointListForPluginQuery(EndpointListQuery):
plugin_id: str
class EndpointCreateResponse(BaseModel):
success: bool = Field(description="Operation success")
class EndpointListResponse(BaseModel):
endpoints: list[dict[str, Any]] = Field(description="Endpoint information")
class PluginEndpointListResponse(BaseModel):
endpoints: list[dict[str, Any]] = Field(description="Endpoint information")
class EndpointDeleteResponse(BaseModel):
success: bool = Field(description="Operation success")
class EndpointUpdateResponse(BaseModel):
success: bool = Field(description="Operation success")
class EndpointEnableResponse(BaseModel):
success: bool = Field(description="Operation success")
class EndpointDisableResponse(BaseModel):
success: bool = Field(description="Operation success")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(EndpointCreatePayload)
reg(EndpointIdPayload)
reg(EndpointUpdatePayload)
reg(EndpointListQuery)
reg(EndpointListForPluginQuery)
register_schema_models(
console_ns,
EndpointCreatePayload,
EndpointIdPayload,
EndpointUpdatePayload,
EndpointListQuery,
EndpointListForPluginQuery,
EndpointCreateResponse,
EndpointListResponse,
PluginEndpointListResponse,
EndpointDeleteResponse,
EndpointUpdateResponse,
EndpointEnableResponse,
EndpointDisableResponse,
)
@console_ns.route("/workspaces/current/endpoints/create")
@ -57,7 +96,7 @@ class EndpointCreateApi(Resource):
@console_ns.response(
200,
"Endpoint created successfully",
console_ns.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}),
console_ns.models[EndpointCreateResponse.__name__],
)
@console_ns.response(403, "Admin privileges required")
@setup_required
@ -91,9 +130,7 @@ class EndpointListApi(Resource):
@console_ns.response(
200,
"Success",
console_ns.model(
"EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
),
console_ns.models[EndpointListResponse.__name__],
)
@setup_required
@login_required
@ -126,9 +163,7 @@ class EndpointListForSinglePluginApi(Resource):
@console_ns.response(
200,
"Success",
console_ns.model(
"PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
),
console_ns.models[PluginEndpointListResponse.__name__],
)
@setup_required
@login_required
@ -163,7 +198,7 @@ class EndpointDeleteApi(Resource):
@console_ns.response(
200,
"Endpoint deleted successfully",
console_ns.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}),
console_ns.models[EndpointDeleteResponse.__name__],
)
@console_ns.response(403, "Admin privileges required")
@setup_required
@ -190,7 +225,7 @@ class EndpointUpdateApi(Resource):
@console_ns.response(
200,
"Endpoint updated successfully",
console_ns.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}),
console_ns.models[EndpointUpdateResponse.__name__],
)
@console_ns.response(403, "Admin privileges required")
@setup_required
@ -221,7 +256,7 @@ class EndpointEnableApi(Resource):
@console_ns.response(
200,
"Endpoint enabled successfully",
console_ns.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}),
console_ns.models[EndpointEnableResponse.__name__],
)
@console_ns.response(403, "Admin privileges required")
@setup_required
@ -248,7 +283,7 @@ class EndpointDisableApi(Resource):
@console_ns.response(
200,
"Endpoint disabled successfully",
console_ns.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}),
console_ns.models[EndpointDisableResponse.__name__],
)
@console_ns.response(403, "Admin privileges required")
@setup_required
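
register_schema_models replaces the inline console_ns.model(...) definitions with schemas registered once under each class name, generalizing the reg helper shown earlier in this diff. A sketch of what such a helper plausibly does (the real one lives in controllers.common.schema and may differ in details):

from pydantic import BaseModel

def register_schema_models(ns, *models: type[BaseModel]) -> None:
    for model in models:
        # Keying by class name lets routes reference
        # ns.models[Model.__name__] instead of redefining fields inline.
        ns.schema_model(model.__name__, model.model_json_schema())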

View File

@ -1,12 +1,12 @@
from urllib import parse
from flask import abort, request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter
import services
from configs import dify_config
from controllers.common.schema import get_or_create_model, register_enum_models
from controllers.common.schema import register_enum_models, register_schema_models
from controllers.console import console_ns
from controllers.console.auth.error import (
CannotTransferOwnerToSelfError,
@ -25,7 +25,7 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
from fields.member_fields import account_with_role_fields, account_with_role_list_fields
from fields.member_fields import AccountWithRole, AccountWithRoleList
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant, login_required
from models.account import Account, TenantAccountRole
@ -69,12 +69,7 @@ reg(OwnerTransferEmailPayload)
reg(OwnerTransferCheckPayload)
reg(OwnerTransferPayload)
register_enum_models(console_ns, TenantAccountRole)
account_with_role_model = get_or_create_model("AccountWithRole", account_with_role_fields)
account_with_role_list_fields_copy = account_with_role_list_fields.copy()
account_with_role_list_fields_copy["accounts"] = fields.List(fields.Nested(account_with_role_model))
account_with_role_list_model = get_or_create_model("AccountWithRoleList", account_with_role_list_fields_copy)
register_schema_models(console_ns, AccountWithRole, AccountWithRoleList)
@console_ns.route("/workspaces/current/members")
@ -84,13 +79,15 @@ class MemberListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_with_role_list_model)
@console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
def get(self):
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_tenant_members(current_user.current_tenant)
return {"result": "success", "accounts": members}, 200
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
response = AccountWithRoleList(accounts=member_models)
return response.model_dump(mode="json"), 200
@console_ns.route("/workspaces/current/members/invite-email")
@ -235,13 +232,15 @@ class DatasetOperatorMemberListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_with_role_list_model)
@console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
def get(self):
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
return {"result": "success", "accounts": members}, 200
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
response = AccountWithRoleList(accounts=member_models)
return response.model_dump(mode="json"), 200
@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email")

File diff suppressed because it is too large

View File

@ -1,16 +1,16 @@
from typing import Literal
from flask import request
from flask_restx import Namespace, Resource, fields
from flask_restx import Resource
from flask_restx.api import HTTPStatus
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, TypeAdapter
from controllers.common.schema import register_schema_models
from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_app_token
from extensions.ext_redis import redis_client
from fields.annotation_fields import annotation_fields, build_annotation_model
from fields.annotation_fields import Annotation, AnnotationList
from models.model import App
from services.annotation_service import AppAnnotationService
@ -26,7 +26,9 @@ class AnnotationReplyActionPayload(BaseModel):
embedding_model_name: str = Field(description="Embedding model name")
register_schema_models(service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload)
register_schema_models(
service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload, Annotation, AnnotationList
)
@service_api_ns.route("/apps/annotation-reply/<string:action>")
@ -45,10 +47,11 @@ class AnnotationReplyActionApi(Resource):
def post(self, app_model: App, action: Literal["enable", "disable"]):
"""Enable or disable annotation reply feature."""
args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump()
if action == "enable":
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_model.id)
match action:
case "enable":
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
case "disable":
result = AppAnnotationService.disable_app_annotation(app_model.id)
return result, 200
@ -82,23 +85,6 @@ class AnnotationReplyActionStatusApi(Resource):
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
# Define annotation list response model
annotation_list_fields = {
"data": fields.List(fields.Nested(annotation_fields)),
"has_more": fields.Boolean,
"limit": fields.Integer,
"total": fields.Integer,
"page": fields.Integer,
}
def build_annotation_list_model(api_or_ns: Namespace):
"""Build the annotation list model for the API or Namespace."""
copied_annotation_list_fields = annotation_list_fields.copy()
copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns)))
return api_or_ns.model("AnnotationList", copied_annotation_list_fields)
@service_api_ns.route("/apps/annotations")
class AnnotationListApi(Resource):
@service_api_ns.doc("list_annotations")
@ -109,8 +95,12 @@ class AnnotationListApi(Resource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.response(
200,
"Annotations retrieved successfully",
service_api_ns.models[AnnotationList.__name__],
)
@validate_app_token
@service_api_ns.marshal_with(build_annotation_list_model(service_api_ns))
def get(self, app_model: App):
"""List annotations for the application."""
page = request.args.get("page", default=1, type=int)
@ -118,13 +108,15 @@ class AnnotationListApi(Resource):
keyword = request.args.get("keyword", default="", type=str)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword)
return {
"data": annotation_list,
"has_more": len(annotation_list) == limit,
"limit": limit,
"total": total,
"page": page,
}
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
response = AnnotationList(
data=annotation_models,
has_more=len(annotation_list) == limit,
limit=limit,
total=total,
page=page,
)
return response.model_dump(mode="json")
@service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__])
@service_api_ns.doc("create_annotation")
@ -135,13 +127,18 @@ class AnnotationListApi(Resource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.response(
HTTPStatus.CREATED,
"Annotation created successfully",
service_api_ns.models[Annotation.__name__],
)
@validate_app_token
@service_api_ns.marshal_with(build_annotation_model(service_api_ns), code=HTTPStatus.CREATED)
def post(self, app_model: App):
"""Create a new annotation."""
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
return annotation, 201
response = Annotation.model_validate(annotation, from_attributes=True)
return response.model_dump(mode="json"), HTTPStatus.CREATED
@service_api_ns.route("/apps/annotations/<uuid:annotation_id>")
@ -158,14 +155,19 @@ class AnnotationUpdateDeleteApi(Resource):
404: "Annotation not found",
}
)
@service_api_ns.response(
200,
"Annotation updated successfully",
service_api_ns.models[Annotation.__name__],
)
@validate_app_token
@edit_permission_required
@service_api_ns.marshal_with(build_annotation_model(service_api_ns))
def put(self, app_model: App, annotation_id: str):
"""Update an existing annotation."""
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
return annotation
response = Annotation.model_validate(annotation, from_attributes=True)
return response.model_dump(mode="json")
@service_api_ns.doc("delete_annotation")
@service_api_ns.doc(description="Delete an annotation")

View File

@ -30,6 +30,7 @@ from core.errors.error import (
from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import UUIDStrOrEmpty
from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService
@ -52,7 +53,7 @@ class ChatRequestPayload(BaseModel):
query: str
files: list[dict[str, Any]] | None = None
response_mode: Literal["blocking", "streaming"] | None = None
conversation_id: str | None = Field(default=None, description="Conversation UUID")
conversation_id: UUIDStrOrEmpty | None = Field(default=None, description="Conversation UUID")
retriever_from: str = Field(default="dev")
auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")
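
Switching conversation_id from plain str to UUIDStrOrEmpty tightens validation while still tolerating the empty string some clients send. The project's definition lives in libs.helper; one plausible shape for such a type, purely illustrative:

from typing import Annotated
from uuid import UUID

from pydantic import BaseModel, BeforeValidator, Field

def _uuid_str_or_empty(value: str) -> str:
    # Accept "" unchanged; anything else must parse as a UUID.
    if value == "":
        return value
    return str(UUID(value))

UUIDStrOrEmpty = Annotated[str, BeforeValidator(_uuid_str_or_empty)]

class Query(BaseModel):
    conversation_id: UUIDStrOrEmpty | None = Field(default=None)

assert Query(conversation_id="").conversation_id == ""
assert Query(conversation_id="123e4567-e89b-12d3-a456-426614174000").conversation_id is not None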

View File

@ -1,5 +1,4 @@
from typing import Any, Literal
from uuid import UUID
from flask import request
from flask_restx import Resource
@ -23,12 +22,13 @@ from fields.conversation_variable_fields import (
build_conversation_variable_infinite_scroll_pagination_model,
build_conversation_variable_model,
)
from libs.helper import UUIDStrOrEmpty
from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService
class ConversationListQuery(BaseModel):
last_id: UUID | None = Field(default=None, description="Last conversation ID for pagination")
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last conversation ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of conversations to return")
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
default="-updated_at", description="Sort order for conversations"
@ -48,7 +48,7 @@ class ConversationRenamePayload(BaseModel):
class ConversationVariablesQuery(BaseModel):
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
variable_name: str | None = Field(
default=None, description="Filter variables by name", min_length=1, max_length=255

View File

@ -1,6 +1,5 @@
import logging
from typing import Literal
from uuid import UUID
from flask import request
from flask_restx import Resource
@ -15,6 +14,7 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
from libs.helper import UUIDStrOrEmpty
from models.model import App, AppMode, EndUser
from services.errors.message import (
FirstMessageNotExistsError,
@ -27,8 +27,8 @@ logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: UUID
first_id: UUID | None = None
conversation_id: UUIDStrOrEmpty
first_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")

View File

@ -17,7 +17,7 @@ from controllers.service_api.wraps import (
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import build_dataset_tag_fields
from fields.tag_fields import DataSetTag
from libs.login import current_user
from models.account import Account
from models.dataset import DatasetPermissionEnum
@ -114,6 +114,7 @@ register_schema_models(
TagBindingPayload,
TagUnbindingPayload,
DatasetListQuery,
DataSetTag,
)
@ -480,15 +481,14 @@ class DatasetTagsApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
def get(self, _):
"""Get all knowledge type tags."""
assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
assert cid is not None
tags = TagService.get_tags("knowledge", cid)
return tags, 200
tag_models = TypeAdapter(list[DataSetTag]).validate_python(tags, from_attributes=True)
return [tag.model_dump(mode="json") for tag in tag_models], 200
@service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__])
@service_api_ns.doc("create_dataset_tag")
@ -500,7 +500,6 @@ class DatasetTagsApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
def post(self, _):
"""Add a knowledge type tag."""
assert isinstance(current_user, Account)
@ -510,7 +509,9 @@ class DatasetTagsApi(DatasetApiResource):
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
tag = TagService.save_tags({"name": payload.name, "type": "knowledge"})
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
response = DataSetTag.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
).model_dump(mode="json")
return response, 200
@service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__])
@ -523,7 +524,6 @@ class DatasetTagsApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
def patch(self, _):
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@ -536,8 +536,9 @@ class DatasetTagsApi(DatasetApiResource):
binding_count = TagService.get_tag_binding_count(tag_id)
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
response = DataSetTag.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
).model_dump(mode="json")
return response, 200
@service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__])

View File

@ -1,7 +1,10 @@
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
from controllers.common.schema import register_schema_model
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
register_schema_model(service_api_ns, HitTestingPayload)
@service_api_ns.route("/datasets/<uuid:dataset_id>/hit-testing", "/datasets/<uuid:dataset_id>/retrieve")
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
@ -15,6 +18,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
404: "Dataset not found",
}
)
@service_api_ns.expect(service_api_ns.models[HitTestingPayload.__name__])
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
"""Perform hit testing on a dataset.

View File

@ -168,10 +168,11 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
if action == "enable":
MetadataService.enable_built_in_field(dataset)
elif action == "disable":
MetadataService.disable_built_in_field(dataset)
match action:
case "enable":
MetadataService.enable_built_in_field(dataset)
case "disable":
MetadataService.disable_built_in_field(dataset)
return {"result": "success"}, 200

View File

@ -73,14 +73,14 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
# If caller needs end-user context, attach EndUser to current_user
if fetch_user_arg:
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
user_id = request.args.get("user")
elif fetch_user_arg.fetch_from == WhereisUserArg.JSON:
user_id = request.get_json().get("user")
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
user_id = request.form.get("user")
else:
user_id = None
user_id = None
match fetch_user_arg.fetch_from:
case WhereisUserArg.QUERY:
user_id = request.args.get("user")
case WhereisUserArg.JSON:
user_id = request.get_json().get("user")
case WhereisUserArg.FORM:
user_id = request.form.get("user")
if not user_id and fetch_user_arg.required:
raise ValueError("Arg user must be provided.")

View File

@ -14,16 +14,17 @@ class AgentConfigManager:
agent_dict = config.get("agent_mode", {})
agent_strategy = agent_dict.get("strategy", "cot")
if agent_strategy == "function_call":
strategy = AgentEntity.Strategy.FUNCTION_CALLING
elif agent_strategy in {"cot", "react"}:
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
else:
# old configs, try to detect default strategy
if config["model"]["provider"] == "openai":
match agent_strategy:
case "function_call":
strategy = AgentEntity.Strategy.FUNCTION_CALLING
else:
case "cot" | "react":
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
case _:
# old configs, try to detect default strategy
if config["model"]["provider"] == "openai":
strategy = AgentEntity.Strategy.FUNCTION_CALLING
else:
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
agent_tools = []
for tool in agent_dict.get("tools", []):

View File

@ -250,7 +250,7 @@ class WorkflowResponseConverter:
data=WorkflowFinishStreamResponse.Data(
id=run_id,
workflow_id=workflow_id,
status=status.value,
status=status,
outputs=encoded_outputs,
error=error,
elapsed_time=elapsed_time,
@ -340,13 +340,13 @@ class WorkflowResponseConverter:
metadata = self._merge_metadata(event.execution_metadata, snapshot)
if isinstance(event, QueueNodeSucceededEvent):
status = WorkflowNodeExecutionStatus.SUCCEEDED.value
status = WorkflowNodeExecutionStatus.SUCCEEDED
error_message = event.error
elif isinstance(event, QueueNodeFailedEvent):
status = WorkflowNodeExecutionStatus.FAILED.value
status = WorkflowNodeExecutionStatus.FAILED
error_message = event.error
else:
status = WorkflowNodeExecutionStatus.EXCEPTION.value
status = WorkflowNodeExecutionStatus.EXCEPTION
error_message = event.error
return NodeFinishStreamResponse(
@ -413,7 +413,7 @@ class WorkflowResponseConverter:
process_data_truncated=process_data_truncated,
outputs=outputs,
outputs_truncated=outputs_truncated,
status=WorkflowNodeExecutionStatus.RETRY.value,
status=WorkflowNodeExecutionStatus.RETRY,
error=event.error,
elapsed_time=elapsed_time,
execution_metadata=metadata,

View File

@ -120,7 +120,7 @@ class PipelineGenerator(BaseAppGenerator):
raise ValueError("Pipeline dataset is required")
inputs: Mapping[str, Any] = args["inputs"]
start_node_id: str = args["start_node_id"]
datasource_type: str = args["datasource_type"]
datasource_type = DatasourceProviderType(args["datasource_type"])
datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list(
datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user
)
@ -660,7 +660,7 @@ class PipelineGenerator(BaseAppGenerator):
tenant_id: str,
dataset_id: str,
built_in_field_enabled: bool,
datasource_type: str,
datasource_type: DatasourceProviderType,
datasource_info: Mapping[str, Any],
created_from: str,
position: int,
@ -668,17 +668,17 @@ class PipelineGenerator(BaseAppGenerator):
batch: str,
document_form: str,
):
if datasource_type == "local_file":
name = datasource_info.get("name", "untitled")
elif datasource_type == "online_document":
name = datasource_info.get("page", {}).get("page_name", "untitled")
elif datasource_type == "website_crawl":
name = datasource_info.get("title", "untitled")
elif datasource_type == "online_drive":
name = datasource_info.get("name", "untitled")
else:
raise ValueError(f"Unsupported datasource type: {datasource_type}")
match datasource_type:
case DatasourceProviderType.LOCAL_FILE:
name = datasource_info.get("name", "untitled")
case DatasourceProviderType.ONLINE_DOCUMENT:
name = datasource_info.get("page", {}).get("page_name", "untitled")
case DatasourceProviderType.WEBSITE_CRAWL:
name = datasource_info.get("title", "untitled")
case DatasourceProviderType.ONLINE_DRIVE:
name = datasource_info.get("name", "untitled")
case _:
raise ValueError(f"Unsupported datasource type: {datasource_type}")
document = Document(
tenant_id=tenant_id,
dataset_id=dataset_id,
@ -706,7 +706,7 @@ class PipelineGenerator(BaseAppGenerator):
def _format_datasource_info_list(
self,
datasource_type: str,
datasource_type: DatasourceProviderType,
datasource_info_list: list[Mapping[str, Any]],
pipeline: Pipeline,
workflow: Workflow,
@ -716,7 +716,7 @@ class PipelineGenerator(BaseAppGenerator):
"""
Format datasource info list.
"""
if datasource_type == "online_drive":
if datasource_type == DatasourceProviderType.ONLINE_DRIVE:
all_files: list[Mapping[str, Any]] = []
datasource_node_data = None
datasource_nodes = workflow.graph_dict.get("nodes", [])

View File

@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
class AnnotationReplyAccount(BaseModel):
@ -223,7 +223,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
id: str
workflow_id: str
status: str
status: WorkflowExecutionStatus
outputs: Mapping[str, Any] | None = None
error: str | None = None
elapsed_time: float
@ -311,7 +311,7 @@ class NodeFinishStreamResponse(StreamResponse):
process_data_truncated: bool = False
outputs: Mapping[str, Any] | None = None
outputs_truncated: bool = True
status: str
status: WorkflowNodeExecutionStatus
error: str | None = None
elapsed_time: float
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
@ -375,7 +375,7 @@ class NodeRetryStreamResponse(StreamResponse):
process_data_truncated: bool = False
outputs: Mapping[str, Any] | None = None
outputs_truncated: bool = False
status: str
status: WorkflowNodeExecutionStatus
error: str | None = None
elapsed_time: float
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
@ -719,7 +719,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
id: str
workflow_id: str
status: str
status: WorkflowExecutionStatus
outputs: Mapping[str, Any] | None = None
error: str | None = None
elapsed_time: float
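
Typing status as the enum rather than str keeps the wire format identical, because Pydantic dumps str-based enums to their values in JSON mode; call sites simply stop writing .value. A quick check of that behavior, assuming the workflow status enums derive from str:

from enum import StrEnum

from pydantic import BaseModel

class RunStatus(StrEnum):
    SUCCEEDED = "succeeded"
    FAILED = "failed"

class FinishEvent(BaseModel):
    status: RunStatus

event = FinishEvent(status=RunStatus.SUCCEEDED)
# JSON-mode dumping emits the enum's value, not the member name.
assert event.model_dump(mode="json") == {"status": "succeeded"}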

View File

@ -4,13 +4,14 @@ from typing import TYPE_CHECKING, final
from typing_extensions import override
from configs import dify_config
from core.file import file_manager
from core.helper import ssrf_proxy
from core.file.file_manager import file_manager
from core.helper.code_executor.code_executor import CodeExecutor
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.ssrf_proxy import ssrf_proxy
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import NodeType
from core.workflow.graph import NodeFactory
from core.workflow.graph.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
@ -22,7 +23,6 @@ from core.workflow.nodes.template_transform.template_renderer import (
Jinja2TemplateRenderer,
)
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from libs.typing import is_str, is_str_dict
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
@ -47,9 +47,9 @@ class DifyNodeFactory(NodeFactory):
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits | None = None,
template_renderer: Jinja2TemplateRenderer | None = None,
http_request_http_client: HttpClientProtocol = ssrf_proxy,
http_request_http_client: HttpClientProtocol | None = None,
http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
http_request_file_manager: FileManagerProtocol = file_manager,
http_request_file_manager: FileManagerProtocol | None = None,
) -> None:
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
@ -68,12 +68,12 @@ class DifyNodeFactory(NodeFactory):
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
self._http_request_http_client = http_request_http_client
self._http_request_http_client = http_request_http_client or ssrf_proxy
self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
self._http_request_file_manager = http_request_file_manager
self._http_request_file_manager = http_request_file_manager or file_manager
@override
def create_node(self, node_config: dict[str, object]) -> Node:
def create_node(self, node_config: NodeConfigDict) -> Node:
"""
Create a Node instance from node configuration data using the traditional mapping.
@ -82,23 +82,14 @@ class DifyNodeFactory(NodeFactory):
:raises ValueError: if node type is unknown or configuration is invalid
"""
# Get node_id from config
node_id = node_config.get("id")
if not is_str(node_id):
raise ValueError("Node config missing id")
node_id = node_config["id"]
# Get node type from config
node_data = node_config.get("data", {})
if not is_str_dict(node_data):
raise ValueError(f"Node {node_id} missing data information")
node_type_str = node_data.get("type")
if not is_str(node_type_str):
raise ValueError(f"Node {node_id} missing or invalid type information")
node_data = node_config["data"]
try:
node_type = NodeType(node_type_str)
node_type = NodeType(node_data["type"])
except ValueError:
raise ValueError(f"Unknown node type: {node_type_str}")
raise ValueError(f"Unknown node type: {node_data['type']}")
# Get node class
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)

View File

@ -168,3 +168,18 @@ def _to_url(f: File, /):
return sign_tool_file(tool_file_id=f.related_id, extension=f.extension)
else:
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
class FileManager:
"""
Adapter exposing file manager helpers behind FileManagerProtocol.
This is intentionally a thin wrapper over the existing module-level functions so callers can inject it
where a protocol-typed file manager is expected.
"""
def download(self, f: File, /) -> bytes:
return download(f)
file_manager = FileManager()
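
The point of the adapter is substitutability: tests can hand DifyNodeFactory any object satisfying the protocol. A sketch assuming FileManagerProtocol only requires download (the real protocol may require more):

class InMemoryFileManager:
    # Test double: serves blobs from memory instead of object storage.
    # "f" is expected to be a core.file File with a related_id attribute.
    def __init__(self, blobs: dict[str, bytes]) -> None:
        self._blobs = blobs

    def download(self, f, /) -> bytes:
        return self._blobs[f.related_id]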

View File

@ -47,15 +47,16 @@ class CodeNodeProvider(BaseModel, ABC):
@classmethod
def get_default_config(cls) -> DefaultConfig:
return {
"type": "code",
"config": {
"variables": [
{"variable": "arg1", "value_selector": []},
{"variable": "arg2", "value_selector": []},
],
"code_language": cls.get_language(),
"code": cls.get_default_code(),
"outputs": {"result": {"type": "string", "children": None}},
},
variables: list[VariableConfig] = [
{"variable": "arg1", "value_selector": []},
{"variable": "arg2", "value_selector": []},
]
outputs: dict[str, OutputConfig] = {"result": {"type": "string", "children": None}}
config: CodeConfig = {
"variables": variables,
"code_language": cls.get_language(),
"code": cls.get_default_code(),
"outputs": outputs,
}
return {"type": "code", "config": config}

View File

@ -230,3 +230,41 @@ def delete(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any)
def head(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return make_request("HEAD", url, max_retries=max_retries, **kwargs)
class SSRFProxy:
"""
Adapter exposing SSRF-protected HTTP helpers behind HttpClientProtocol.
This is intentionally a thin wrapper over the existing module-level functions so callers can inject it
where a protocol-typed HTTP client is expected.
"""
@property
def max_retries_exceeded_error(self) -> type[Exception]:
return max_retries_exceeded_error
@property
def request_error(self) -> type[Exception]:
return request_error
def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return get(url=url, max_retries=max_retries, **kwargs)
def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return head(url=url, max_retries=max_retries, **kwargs)
def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return post(url=url, max_retries=max_retries, **kwargs)
def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return put(url=url, max_retries=max_retries, **kwargs)
def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return delete(url=url, max_retries=max_retries, **kwargs)
def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return patch(url=url, max_retries=max_retries, **kwargs)
ssrf_proxy = SSRFProxy()
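
As with the file manager, the proxy object exists so HTTP behavior can be injected. A canned-response double satisfying the same surface, sketched with only GET shown (the remaining verbs follow the same pattern; the protocol is assumed to cover the methods and error-type properties above):

import httpx

class FakeHttpClient:
    @property
    def max_retries_exceeded_error(self) -> type[Exception]:
        return RuntimeError

    @property
    def request_error(self) -> type[Exception]:
        return RuntimeError

    def get(self, url: str, max_retries: int = 3, **kwargs) -> httpx.Response:
        # Return a canned 200 so nodes under test never touch the network.
        return httpx.Response(200, text="ok", request=httpx.Request("GET", url))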

View File

@ -369,77 +369,78 @@ class IndexingRunner:
# Generate summary preview
summary_index_setting = tmp_processing_rule.get("summary_index_setting")
if summary_index_setting and summary_index_setting.get("enable") and preview_texts:
preview_texts = index_processor.generate_summary_preview(tenant_id, preview_texts, summary_index_setting)
preview_texts = index_processor.generate_summary_preview(
tenant_id, preview_texts, summary_index_setting, doc_language
)
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
def _extract(
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
) -> list[Document]:
# load file
if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}:
return []
data_source_info = dataset_document.data_source_info_dict
text_docs = []
if dataset_document.data_source_type == "upload_file":
if not data_source_info or "upload_file_id" not in data_source_info:
raise ValueError("no upload file found")
stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
file_detail = db.session.scalars(stmt).one_or_none()
match dataset_document.data_source_type:
case "upload_file":
if not data_source_info or "upload_file_id" not in data_source_info:
raise ValueError("no upload file found")
stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
file_detail = db.session.scalars(stmt).one_or_none()
if file_detail:
if file_detail:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE,
upload_file=file_detail,
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
case "notion_import":
if (
not data_source_info
or "notion_workspace_id" not in data_source_info
or "notion_page_id" not in data_source_info
):
raise ValueError("no notion import info found")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE,
upload_file=file_detail,
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"document": dataset_document,
"tenant_id": dataset_document.tenant_id,
}
),
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
elif dataset_document.data_source_type == "notion_import":
if (
not data_source_info
or "notion_workspace_id" not in data_source_info
or "notion_page_id" not in data_source_info
):
raise ValueError("no notion import info found")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"document": dataset_document,
"tenant_id": dataset_document.tenant_id,
}
),
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
elif dataset_document.data_source_type == "website_crawl":
if (
not data_source_info
or "provider" not in data_source_info
or "url" not in data_source_info
or "job_id" not in data_source_info
):
raise ValueError("no website import info found")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"tenant_id": dataset_document.tenant_id,
"url": data_source_info["url"],
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
),
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
case "website_crawl":
if (
not data_source_info
or "provider" not in data_source_info
or "url" not in data_source_info
or "job_id" not in data_source_info
):
raise ValueError("no website import info found")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"tenant_id": dataset_document.tenant_id,
"url": data_source_info["url"],
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
),
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
case _:
return []
# update document status to splitting
self._update_document_index_status(
document_id=dataset_document.id,

View File

@ -0,0 +1,20 @@
"""Shared payload models for LLM generator helpers and controllers."""
from pydantic import BaseModel, Field
from core.app.app_config.entities import ModelConfig
class RuleGeneratePayload(BaseModel):
instruction: str = Field(..., description="Rule generation instruction")
model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
no_variable: bool = Field(default=False, description="Whether to exclude variables")
class RuleCodeGeneratePayload(RuleGeneratePayload):
code_language: str = Field(default="javascript", description="Programming language for code generation")
class RuleStructuredOutputPayload(BaseModel):
instruction: str = Field(..., description="Structured output generation instruction")
model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
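
Because model_config is reserved on Pydantic models, the field is named model_config_data and aliased, so request bodies keep the original key. Validating a hypothetical body (ModelConfig is assumed to accept the provider/name/completion_params keys the generator reads):

from core.llm_generator.entities import RuleGeneratePayload

payload = RuleGeneratePayload.model_validate(
    {
        "instruction": "Summarize the input",
        "model_config": {"provider": "openai", "name": "gpt-4o", "completion_params": {}},
    }
)
assert payload.model_config_data.name == "gpt-4o"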

View File

@ -6,6 +6,8 @@ from typing import Protocol, cast
import json_repair
from core.app.app_config.entities import ModelConfig
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.llm_generator.prompts import (
@ -151,19 +153,19 @@ class LLMGenerator:
return questions
@classmethod
def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool):
def generate_rule_config(cls, tenant_id: str, args: RuleGeneratePayload):
output_parser = RuleConfigGeneratorOutputParser()
error = ""
error_step = ""
rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""}
model_parameters = model_config.get("completion_params", {})
if no_variable:
model_parameters = args.model_config_data.completion_params
if args.no_variable:
prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE)
prompt_generate = prompt_template.format(
inputs={
"TASK_DESCRIPTION": instruction,
"TASK_DESCRIPTION": args.instruction,
},
remove_template_variables=False,
)
@ -175,8 +177,8 @@ class LLMGenerator:
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
provider=args.model_config_data.provider,
model=args.model_config_data.name,
)
try:
@ -190,7 +192,7 @@ class LLMGenerator:
error = str(e)
error_step = "generate rule config"
except Exception as e:
logger.exception("Failed to generate rule config, model: %s", model_config.get("name"))
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
rule_config["error"] = str(e)
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
@ -209,7 +211,7 @@ class LLMGenerator:
# format the prompt_generate_prompt
prompt_generate_prompt = prompt_template.format(
inputs={
"TASK_DESCRIPTION": instruction,
"TASK_DESCRIPTION": args.instruction,
},
remove_template_variables=False,
)
@ -220,8 +222,8 @@ class LLMGenerator:
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
provider=args.model_config_data.provider,
model=args.model_config_data.name,
)
try:
@ -250,7 +252,7 @@ class LLMGenerator:
# the second step to generate the task_parameter and task_statement
statement_generate_prompt = statement_template.format(
inputs={
"TASK_DESCRIPTION": instruction,
"TASK_DESCRIPTION": args.instruction,
"INPUT_TEXT": prompt_content.message.get_text_content(),
},
remove_template_variables=False,
@ -276,7 +278,7 @@ class LLMGenerator:
error_step = "generate conversation opener"
except Exception as e:
logger.exception("Failed to generate rule config, model: %s", model_config.get("name"))
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
rule_config["error"] = str(e)
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
@ -284,16 +286,20 @@ class LLMGenerator:
return rule_config
@classmethod
def generate_code(cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"):
if code_language == "python":
def generate_code(
cls,
tenant_id: str,
args: RuleCodeGeneratePayload,
):
if args.code_language == "python":
prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
else:
prompt_template = PromptTemplateParser(JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE)
prompt = prompt_template.format(
inputs={
"INSTRUCTION": instruction,
"CODE_LANGUAGE": code_language,
"INSTRUCTION": args.instruction,
"CODE_LANGUAGE": args.code_language,
},
remove_template_variables=False,
)
@ -302,28 +308,28 @@ class LLMGenerator:
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
provider=args.model_config_data.provider,
model=args.model_config_data.name,
)
prompt_messages = [UserPromptMessage(content=prompt)]
model_parameters = model_config.get("completion_params", {})
model_parameters = args.model_config_data.completion_params
try:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
generated_code = response.message.get_text_content()
return {"code": generated_code, "language": code_language, "error": ""}
return {"code": generated_code, "language": args.code_language, "error": ""}
except InvokeError as e:
error = str(e)
return {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"}
return {"code": "", "language": args.code_language, "error": f"Failed to generate code. Error: {error}"}
except Exception as e:
logger.exception(
"Failed to invoke LLM model, model: %s, language: %s", model_config.get("name"), code_language
"Failed to invoke LLM model, model: %s, language: %s", args.model_config_data.name, args.code_language
)
return {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"}
return {"code": "", "language": args.code_language, "error": f"An unexpected error occurred: {str(e)}"}
@classmethod
def generate_qa_document(cls, tenant_id: str, query, document_language: str):
@ -353,20 +359,20 @@ class LLMGenerator:
return answer.strip()
@classmethod
def generate_structured_output(cls, tenant_id: str, instruction: str, model_config: dict):
def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload):
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
provider=args.model_config_data.provider,
model=args.model_config_data.name,
)
prompt_messages = [
SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE),
UserPromptMessage(content=instruction),
UserPromptMessage(content=args.instruction),
]
model_parameters = model_config.get("model_parameters", {})
model_parameters = args.model_config_data.completion_params
try:
response: LLMResult = model_instance.invoke_llm(
@ -390,12 +396,17 @@ class LLMGenerator:
error = str(e)
return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
except Exception as e:
logger.exception("Failed to invoke LLM model, model: %s", model_config.get("name"))
logger.exception("Failed to invoke LLM model, model: %s", args.model_config_data.name)
return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
@staticmethod
def instruction_modify_legacy(
tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None
tenant_id: str,
flow_id: str,
current: str,
instruction: str,
model_config: ModelConfig,
ideal_output: str | None,
):
last_run: Message | None = (
db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
@ -434,7 +445,7 @@ class LLMGenerator:
node_id: str,
current: str,
instruction: str,
model_config: dict,
model_config: ModelConfig,
ideal_output: str | None,
workflow_service: WorkflowServiceInterface,
):
@ -505,7 +516,7 @@ class LLMGenerator:
@staticmethod
def __instruction_modify_common(
tenant_id: str,
model_config: dict,
model_config: ModelConfig,
last_run: dict | None,
current: str | None,
error_message: str | None,
@ -526,8 +537,8 @@ class LLMGenerator:
model_instance = ModelManager().get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
provider=model_config.provider,
model=model_config.name,
)
match node_type:
case "llm" | "agent":
@ -570,7 +581,5 @@ class LLMGenerator:
error = str(e)
return {"error": f"Failed to generate code. Error: {error}"}
except Exception as e:
logger.exception(
"Failed to invoke LLM model, model: %s", json.dumps(model_config.get("name")), exc_info=True
)
logger.exception("Failed to invoke LLM model, model: %s", json.dumps(model_config.name), exc_info=True)
return {"error": f"An unexpected error occurred: {str(e)}"}

View File

@ -441,11 +441,13 @@ DEFAULT_GENERATOR_SUMMARY_PROMPT = (
Requirements:
1. Write a concise summary in plain text
2. Use the same language as the input content
2. You must write in {language}. No language other than {language} should be used.
3. Focus on important facts, concepts, and details
4. If images are included, describe their key information
5. Do not use words like "好的", "ok", "I understand", "This text discusses", "The content mentions"
6. Write directly without extra words
7. If there is not enough content to generate a meaningful summary,
return an empty string without any explanation or prompt
Output only the summary text. Start summarizing now:

View File

@ -347,7 +347,7 @@ class BaseSession(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
responder = RequestResponder(
responder = RequestResponder[ReceiveRequestT, SendResultT](
request_id=message.message.root.id,
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
request=validated_request,

View File

@ -88,7 +88,7 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
DefaultParameterName.MAX_TOKENS: {
"label": {
"en_US": "Max Tokens",
"zh_Hans": "最大标记",
"zh_Hans": "最大 Token 数",
},
"type": "int",
"help": {

View File

@ -92,6 +92,10 @@ def _build_llm_result_from_first_chunk(
Build a single `LLMResult` from the first returned chunk.
This is used for `stream=False` because the plugin side may still implement the response via a chunked stream.
Note:
This function always drains the `chunks` iterator after reading the first chunk to ensure any underlying
streaming resources are released (e.g., HTTP connections owned by the plugin runtime).
"""
content = ""
content_list: list[PromptMessageContentUnionTypes] = []
@ -99,18 +103,25 @@ def _build_llm_result_from_first_chunk(
system_fingerprint: str | None = None
tools_calls: list[AssistantPromptMessage.ToolCall] = []
first_chunk = next(chunks, None)
if first_chunk is not None:
if isinstance(first_chunk.delta.message.content, str):
content += first_chunk.delta.message.content
elif isinstance(first_chunk.delta.message.content, list):
content_list.extend(first_chunk.delta.message.content)
try:
first_chunk = next(chunks, None)
if first_chunk is not None:
if isinstance(first_chunk.delta.message.content, str):
content += first_chunk.delta.message.content
elif isinstance(first_chunk.delta.message.content, list):
content_list.extend(first_chunk.delta.message.content)
if first_chunk.delta.message.tool_calls:
_increase_tool_call(first_chunk.delta.message.tool_calls, tools_calls)
if first_chunk.delta.message.tool_calls:
_increase_tool_call(first_chunk.delta.message.tool_calls, tools_calls)
usage = first_chunk.delta.usage or LLMUsage.empty_usage()
system_fingerprint = first_chunk.system_fingerprint
usage = first_chunk.delta.usage or LLMUsage.empty_usage()
system_fingerprint = first_chunk.system_fingerprint
finally:
try:
for _ in chunks:
pass
except Exception:
logger.debug("Failed to drain non-stream plugin chunk iterator.", exc_info=True)
return LLMResult(
model=model,
@ -283,7 +294,7 @@ class LargeLanguageModel(AIModel):
# TODO
raise self._transform_invoke_error(e)
if stream and isinstance(result, Generator):
if stream and not isinstance(result, LLMResult):
return self._invoke_result_generator(
model=model,
result=result,
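The try/finally added above guarantees the chunk iterator is exhausted even when only the first chunk is consumed, so the plugin runtime can release its HTTP stream. A stripped-down sketch of the same pattern:

from collections.abc import Iterator

def first_then_drain(chunks: Iterator[str]) -> str | None:
    # The finally block runs after the return value is computed but before
    # the function actually returns, so the producer is always drained.
    try:
        return next(chunks, None)
    finally:
        for _ in chunks:
            pass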

View File

@ -314,6 +314,8 @@ class ModelProviderFactory:
elif model_type == ModelType.TTS:
return TTSModel.model_validate(init_params)
raise ValueError(f"Unsupported model type: {model_type}")
def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
"""
Get provider icon

View File

@ -48,12 +48,22 @@ class BaseIndexProcessor(ABC):
@abstractmethod
def generate_summary_preview(
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
self,
tenant_id: str,
preview_texts: list[PreviewDetail],
summary_index_setting: dict,
doc_language: str | None = None,
) -> list[PreviewDetail]:
"""
For each segment in preview_texts, generate a summary using an LLM and attach it to the segment.
The summary can be stored in a new attribute, e.g., summary.
This method should be implemented by subclasses.
Args:
tenant_id: Tenant ID
preview_texts: List of preview details to generate summaries for
summary_index_setting: Summary index configuration
doc_language: Optional document language to ensure summary is generated in the correct language
"""
raise NotImplementedError

View File

@ -275,7 +275,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
raise ValueError("Chunks is not a list")
def generate_summary_preview(
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
self,
tenant_id: str,
preview_texts: list[PreviewDetail],
summary_index_setting: dict,
doc_language: str | None = None,
) -> list[PreviewDetail]:
"""
For each segment, concurrently call generate_summary to generate a summary
@ -298,11 +302,15 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
if flask_app:
# Ensure Flask app context in worker thread
with flask_app.app_context():
summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting)
summary, _ = self.generate_summary(
tenant_id, preview.content, summary_index_setting, document_language=doc_language
)
preview.summary = summary
else:
# Fallback: try without app context (may fail)
summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting)
summary, _ = self.generate_summary(
tenant_id, preview.content, summary_index_setting, document_language=doc_language
)
preview.summary = summary
# Generate summaries concurrently using ThreadPoolExecutor
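The worker branch above pushes a Flask app context before doing any work, because threads spawned by ThreadPoolExecutor do not inherit the context of the request thread. A minimal, codebase-independent sketch of that pattern:

from concurrent.futures import ThreadPoolExecutor

from flask import Flask, current_app

def summarize(app: Flask, text: str) -> str:
    # Worker threads get no app context automatically; push one explicitly
    # so current_app, DB sessions, etc. resolve correctly.
    with app.app_context():
        return f"{current_app.name}: {text[:20]}"

app = Flask(__name__)
with ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(summarize, app, t) for t in ("chunk a", "chunk b")]
    results = [f.result() for f in futures]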
@ -356,6 +364,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
text: str,
summary_index_setting: dict | None = None,
segment_id: str | None = None,
document_language: str | None = None,
) -> tuple[str, LLMUsage]:
"""
Generate summary for the given text using ModelInstance.invoke_llm and the default or custom summary prompt,
@ -366,6 +375,8 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
text: Text content to summarize
summary_index_setting: Summary index configuration
segment_id: Optional segment ID to fetch attachments from SegmentAttachmentBinding table
document_language: Optional document language (e.g., "Chinese", "English")
to ensure summary is generated in the correct language
Returns:
Tuple of (summary_content, llm_usage) where llm_usage is LLMUsage object
@ -381,8 +392,22 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
raise ValueError("model_name and model_provider_name are required in summary_index_setting")
# Import default summary prompt
is_default_prompt = False
if not summary_prompt:
summary_prompt = DEFAULT_GENERATOR_SUMMARY_PROMPT
is_default_prompt = True
# Format prompt with document language only for default prompt
# Custom prompts are used as-is to avoid interfering with user-defined templates
# If document_language is provided, use it; otherwise, use "the same language as the input content"
# This is especially important for image-only chunks where text is empty or minimal
if is_default_prompt:
language_for_prompt = document_language or "the same language as the input content"
try:
summary_prompt = summary_prompt.format(language=language_for_prompt)
except KeyError:
# If default prompt doesn't have {language} placeholder, use it as-is
pass
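A self-contained sketch of the fallback above (the prompt literal is illustrative, not the real DEFAULT_GENERATOR_SUMMARY_PROMPT). Note that str.format ignores surplus keyword arguments, so the except branch only fires when the prompt contains some other unfilled placeholder or literal braces:

DEFAULT_PROMPT = "Summarize the text. You must write in {language}."

def render_prompt(prompt: str, language: str | None) -> str:
    try:
        return prompt.format(language=language or "the same language as the input content")
    except (KeyError, IndexError):
        # Prompt carries placeholders we cannot fill; use it verbatim.
        return prompt

assert "Chinese" in render_prompt(DEFAULT_PROMPT, "Chinese")
assert render_prompt("Summarize {not_a_slot}.", None) == "Summarize {not_a_slot}."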
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(

View File

@ -358,7 +358,11 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
}
def generate_summary_preview(
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
self,
tenant_id: str,
preview_texts: list[PreviewDetail],
summary_index_setting: dict,
doc_language: str | None = None,
) -> list[PreviewDetail]:
"""
For each parent chunk in preview_texts, concurrently call generate_summary to generate a summary
@ -389,6 +393,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
tenant_id=tenant_id,
text=preview.content,
summary_index_setting=summary_index_setting,
document_language=doc_language,
)
preview.summary = summary
else:
@ -397,6 +402,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
tenant_id=tenant_id,
text=preview.content,
summary_index_setting=summary_index_setting,
document_language=doc_language,
)
preview.summary = summary

View File

@ -241,7 +241,11 @@ class QAIndexProcessor(BaseIndexProcessor):
}
def generate_summary_preview(
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
self,
tenant_id: str,
preview_texts: list[PreviewDetail],
summary_index_setting: dict,
doc_language: str | None = None,
) -> list[PreviewDetail]:
"""
QA model doesn't generate summaries, so this method returns preview_texts unchanged.

View File

@ -35,6 +35,7 @@ class SchemaRegistry:
registry.load_all_versions()
cls._default_instance = registry
return cls._default_instance
return cls._default_instance

View File

@ -189,16 +189,13 @@ class ToolManager:
raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
if not provider_controller.need_credentials:
return cast(
BuiltinTool,
builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
),
return builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
)
builtin_provider = None
if isinstance(provider_controller, PluginToolProviderController):
@ -300,18 +297,15 @@ class ToolManager:
decrypted_credentials = refreshed_credentials.credentials
cache.delete()
return cast(
BuiltinTool,
builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials=dict(decrypted_credentials),
credential_type=CredentialType.of(builtin_provider.credential_type),
runtime_parameters={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
),
return builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials=dict(decrypted_credentials),
credential_type=CredentialType.of(builtin_provider.credential_type),
runtime_parameters={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
)
elif provider_type == ToolProviderType.API:

View File

@ -7,11 +7,6 @@ from core.workflow.nodes.base.entities import OutputVariableEntity
class WorkflowToolConfigurationUtils:
@classmethod
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
for configuration in configurations:
WorkflowToolParameterConfiguration.model_validate(configuration)
@classmethod
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
"""

View File

@ -23,8 +23,8 @@ class TriggerDebugEventBus:
"""
# LUA_SELECT: Atomic poll or register for event
# KEYS[1] = trigger_debug_inbox:{tenant_id}:{address_id}
# KEYS[2] = trigger_debug_waiting_pool:{tenant_id}:...
# KEYS[1] = trigger_debug_inbox:{<tenant_id>}:<address_id>
# KEYS[2] = trigger_debug_waiting_pool:{<tenant_id>}:...
# ARGV[1] = address_id
LUA_SELECT = (
"local v=redis.call('GET',KEYS[1]);"
@ -35,7 +35,7 @@ class TriggerDebugEventBus:
)
# LUA_DISPATCH: Dispatch event to all waiting addresses
# KEYS[1] = trigger_debug_waiting_pool:{tenant_id}:...
# KEYS[1] = trigger_debug_waiting_pool:{<tenant_id>}:...
# ARGV[1] = tenant_id
# ARGV[2] = event_json
LUA_DISPATCH = (
@ -43,7 +43,7 @@ class TriggerDebugEventBus:
"if #a==0 then return 0 end;"
"redis.call('DEL',KEYS[1]);"
"for i=1,#a do "
f"redis.call('SET','trigger_debug_inbox:'..ARGV[1]..':'..a[i],ARGV[2],'EX',{TRIGGER_DEBUG_EVENT_TTL});"
f"redis.call('SET','trigger_debug_inbox:{{'..ARGV[1]..'}}'..':'..a[i],ARGV[2],'EX',{TRIGGER_DEBUG_EVENT_TTL});"
"end;"
"return #a"
)
@ -108,7 +108,7 @@ class TriggerDebugEventBus:
Event object if available, None otherwise
"""
address_id: str = hashlib.sha256(f"{user_id}|{app_id}|{node_id}".encode()).hexdigest()
address: str = f"trigger_debug_inbox:{tenant_id}:{address_id}"
address: str = f"trigger_debug_inbox:{{{tenant_id}}}:{address_id}"
try:
event_data = redis_client.eval(

View File

@ -42,7 +42,7 @@ def build_webhook_pool_key(tenant_id: str, app_id: str, node_id: str) -> str:
app_id: App ID
node_id: Node ID
"""
return f"{TriggerDebugPoolKey.WEBHOOK}:{tenant_id}:{app_id}:{node_id}"
return f"{TriggerDebugPoolKey.WEBHOOK}:{{{tenant_id}}}:{app_id}:{node_id}"
class PluginTriggerDebugEvent(BaseDebugEvent):
@ -64,4 +64,4 @@ def build_plugin_pool_key(tenant_id: str, provider_id: str, subscription_id: str
provider_id: Provider ID
subscription_id: Subscription ID
"""
return f"{TriggerDebugPoolKey.PLUGIN}:{tenant_id}:{str(provider_id)}:{subscription_id}:{name}"
return f"{TriggerDebugPoolKey.PLUGIN}:{{{tenant_id}}}:{str(provider_id)}:{subscription_id}:{name}"

View File

@ -0,0 +1,24 @@
from __future__ import annotations
import sys
from pydantic import TypeAdapter, with_config
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict
@with_config(extra="allow")
class NodeConfigData(TypedDict):
type: str
@with_config(extra="allow")
class NodeConfigDict(TypedDict):
id: str
data: NodeConfigData
NodeConfigDictAdapter = TypeAdapter(NodeConfigDict)
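A quick sketch of how this adapter behaves: the required keys are enforced, while unknown keys survive validation because of with_config(extra="allow") (behavior as of Pydantic v2; the payload shape here is invented for illustration):

from pydantic import ValidationError

raw = {"id": "node-1", "data": {"type": "start", "title": "Start"}, "position": {"x": 0}}
node = NodeConfigDictAdapter.validate_python(raw)
assert node["data"]["type"] == "start"
assert "position" in node  # extra keys are kept, not stripped

try:
    NodeConfigDictAdapter.validate_python({"id": "node-2", "data": {}})
except ValidationError as exc:
    print(exc.error_count(), "validation error: data.type is required")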

View File

@ -5,15 +5,20 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Protocol, cast, final
from pydantic import TypeAdapter
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
from core.workflow.nodes.base.node import Node
from libs.typing import is_str, is_str_dict
from libs.typing import is_str
from .edge import Edge
from .validation import get_graph_validator
logger = logging.getLogger(__name__)
_ListNodeConfigDict = TypeAdapter(list[NodeConfigDict])
class NodeFactory(Protocol):
"""
@ -23,7 +28,7 @@ class NodeFactory(Protocol):
allowing for different node creation strategies while maintaining type safety.
"""
def create_node(self, node_config: dict[str, object]) -> Node:
def create_node(self, node_config: NodeConfigDict) -> Node:
"""
Create a Node instance from node configuration data.
@ -63,28 +68,24 @@ class Graph:
self.root_node = root_node
@classmethod
def _parse_node_configs(cls, node_configs: list[dict[str, object]]) -> dict[str, dict[str, object]]:
def _parse_node_configs(cls, node_configs: list[NodeConfigDict]) -> dict[str, NodeConfigDict]:
"""
Parse node configurations and build a mapping of node IDs to configs.
:param node_configs: list of node configuration dictionaries
:return: mapping of node ID to node config
"""
node_configs_map: dict[str, dict[str, object]] = {}
node_configs_map: dict[str, NodeConfigDict] = {}
for node_config in node_configs:
node_id = node_config.get("id")
if not node_id or not isinstance(node_id, str):
continue
node_configs_map[node_id] = node_config
node_configs_map[node_config["id"]] = node_config
return node_configs_map
@classmethod
def _find_root_node_id(
cls,
node_configs_map: Mapping[str, Mapping[str, object]],
node_configs_map: Mapping[str, NodeConfigDict],
edge_configs: Sequence[Mapping[str, object]],
root_node_id: str | None = None,
) -> str:
@ -113,10 +114,8 @@ class Graph:
# Prefer START node if available
start_node_id = None
for nid in root_candidates:
node_data = node_configs_map[nid].get("data")
if not is_str_dict(node_data):
continue
node_type = node_data.get("type")
node_data = node_configs_map[nid]["data"]
node_type = node_data["type"]
if not isinstance(node_type, str):
continue
if NodeType(node_type).is_start_node:
@ -176,7 +175,7 @@ class Graph:
@classmethod
def _create_node_instances(
cls,
node_configs_map: dict[str, dict[str, object]],
node_configs_map: dict[str, NodeConfigDict],
node_factory: NodeFactory,
) -> dict[str, Node]:
"""
@ -303,7 +302,7 @@ class Graph:
node_configs = graph_config.get("nodes", [])
edge_configs = cast(list[dict[str, object]], edge_configs)
node_configs = cast(list[dict[str, object]], node_configs)
node_configs = _ListNodeConfigDict.validate_python(node_configs)
if not node_configs:
raise ValueError("Graph must have at least one node")

View File

@ -46,7 +46,6 @@ from .graph_traversal import EdgeProcessor, SkipPropagator
from .layers.base import GraphEngineLayer
from .orchestration import Dispatcher, ExecutionCoordinator
from .protocols.command_channel import CommandChannel
from .ready_queue import ReadyQueue
from .worker_management import WorkerPool
if TYPE_CHECKING:
@ -90,7 +89,7 @@ class GraphEngine:
self._graph_execution.workflow_id = workflow_id
# === Execution Queues ===
self._ready_queue = cast(ReadyQueue, self._graph_runtime_state.ready_queue)
self._ready_queue = self._graph_runtime_state.ready_queue
# Queue for events generated during execution
self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()

View File

@ -15,10 +15,10 @@ from uuid import uuid4
from pydantic import BaseModel, Field
from core.workflow.enums import NodeExecutionType, NodeState
from core.workflow.graph import Graph
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
from core.workflow.nodes.base.template import TextSegment, VariableSegment
from core.workflow.runtime import VariablePool
from core.workflow.runtime.graph_runtime_state import GraphProtocol
from .path import Path
from .session import ResponseSession
@ -75,7 +75,7 @@ class ResponseStreamCoordinator:
Ensures ordered streaming of responses based on upstream node outputs and constants.
"""
def __init__(self, variable_pool: "VariablePool", graph: "Graph") -> None:
def __init__(self, variable_pool: "VariablePool", graph: GraphProtocol) -> None:
"""
Initialize coordinator with variable pool.

View File

@ -10,10 +10,10 @@ from __future__ import annotations
from dataclasses import dataclass
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
from core.workflow.runtime.graph_runtime_state import NodeProtocol
@dataclass
@ -29,21 +29,26 @@ class ResponseSession:
index: int = 0 # Current position in the template segments
@classmethod
def from_node(cls, node: Node) -> ResponseSession:
def from_node(cls, node: NodeProtocol) -> ResponseSession:
"""
Create a ResponseSession from an AnswerNode or EndNode.
Create a ResponseSession from a response-capable node.
The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer,
but at runtime this must be an `AnswerNode`, `EndNode`, or `KnowledgeIndexNode` that provides:
- `id: str`
- `get_streaming_template() -> Template`
Args:
node: Must be either an AnswerNode or EndNode instance
node: Node from the materialized workflow graph.
Returns:
ResponseSession configured with the node's streaming template
Raises:
TypeError: If node is not an AnswerNode or EndNode
TypeError: If node is not a supported response node type.
"""
if not isinstance(node, AnswerNode | EndNode | KnowledgeIndexNode):
raise TypeError
raise TypeError("ResponseSession.from_node only supports AnswerNode, EndNode, or KnowledgeIndexNode")
return cls(
node_id=node.id,
template=node.get_streaming_template(),

View File

@ -192,32 +192,33 @@ class AgentNode(Node[AgentNodeData]):
result[parameter_name] = None
continue
agent_input = node_data.agent_parameters[parameter_name]
if agent_input.type == "variable":
variable = variable_pool.get(agent_input.value) # type: ignore
if variable is None:
raise AgentVariableNotFoundError(str(agent_input.value))
parameter_value = variable.value
elif agent_input.type in {"mixed", "constant"}:
# variable_pool.convert_template expects a string template,
# but if passing a dict, convert to JSON string first before rendering
try:
if not isinstance(agent_input.value, str):
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
else:
match agent_input.type:
case "variable":
variable = variable_pool.get(agent_input.value) # type: ignore
if variable is None:
raise AgentVariableNotFoundError(str(agent_input.value))
parameter_value = variable.value
case "mixed" | "constant":
# variable_pool.convert_template expects a string template,
# but if passing a dict, convert to JSON string first before rendering
try:
if not isinstance(agent_input.value, str):
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
else:
parameter_value = str(agent_input.value)
except TypeError:
parameter_value = str(agent_input.value)
except TypeError:
parameter_value = str(agent_input.value)
segment_group = variable_pool.convert_template(parameter_value)
parameter_value = segment_group.log if for_log else segment_group.text
# variable_pool.convert_template returns a string,
# so we need to convert it back to a dictionary
try:
if not isinstance(agent_input.value, str):
parameter_value = json.loads(parameter_value)
except json.JSONDecodeError:
parameter_value = parameter_value
else:
raise AgentInputTypeError(agent_input.type)
segment_group = variable_pool.convert_template(parameter_value)
parameter_value = segment_group.log if for_log else segment_group.text
# variable_pool.convert_template returns a string,
# so we need to convert it back to a dictionary
try:
if not isinstance(agent_input.value, str):
parameter_value = json.loads(parameter_value)
except json.JSONDecodeError:
parameter_value = parameter_value
case _:
raise AgentInputTypeError(agent_input.type)
value = parameter_value
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
@ -374,12 +375,13 @@ class AgentNode(Node[AgentNodeData]):
result: dict[str, Any] = {}
for parameter_name in typed_node_data.agent_parameters:
input = typed_node_data.agent_parameters[parameter_name]
if input.type in ["mixed", "constant"]:
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == "variable":
result[parameter_name] = input.value
match input.type:
case "mixed" | "constant":
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
case "variable":
result[parameter_name] = input.value
result = {node_id + "." + key: value for key, value in result.items()}

View File

@ -115,7 +115,7 @@ class DefaultValue(BaseModel):
@model_validator(mode="after")
def validate_value_type(self) -> DefaultValue:
# Type validation configuration
type_validators = {
type_validators: dict[DefaultValueType, dict[str, Any]] = {
DefaultValueType.STRING: {
"type": str,
"converter": lambda x: x,

View File

@ -1,4 +1,4 @@
from typing import Annotated, Literal, Self
from typing import Annotated, Literal
from pydantic import AfterValidator, BaseModel
@ -34,7 +34,7 @@ class CodeNodeData(BaseNodeData):
class Output(BaseModel):
type: Annotated[SegmentType, AfterValidator(_validate_type)]
children: dict[str, Self] | None = None
children: dict[str, "CodeNodeData.Output"] | None = None
class Dependency(BaseModel):
name: str

View File

@ -69,11 +69,13 @@ class DatasourceNode(Node[DatasourceNodeData]):
if datasource_type is None:
raise DatasourceNodeError("Datasource type is not set")
datasource_type = DatasourceProviderType.value_of(datasource_type)
datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
datasource_name=node_data.datasource_name or "",
tenant_id=self.tenant_id,
datasource_type=DatasourceProviderType.value_of(datasource_type),
datasource_type=datasource_type,
)
datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)
@ -268,15 +270,18 @@ class DatasourceNode(Node[DatasourceNodeData]):
if typed_node_data.datasource_parameters:
for parameter_name in typed_node_data.datasource_parameters:
input = typed_node_data.datasource_parameters[parameter_name]
if input.type == "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == "variable":
result[parameter_name] = input.value
elif input.type == "constant":
pass
match input.type:
case "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
case "variable":
result[parameter_name] = input.value
case "constant":
pass
case None:
pass
result = {node_id + "." + key: value for key, value in result.items()}
@ -306,99 +311,107 @@ class DatasourceNode(Node[DatasourceNodeData]):
variables: dict[str, Any] = {}
for message in message_stream:
if message.type in {
DatasourceMessage.MessageType.IMAGE_LINK,
DatasourceMessage.MessageType.BINARY_LINK,
DatasourceMessage.MessageType.IMAGE,
}:
assert isinstance(message.message, DatasourceMessage.TextMessage)
match message.type:
case (
DatasourceMessage.MessageType.IMAGE_LINK
| DatasourceMessage.MessageType.BINARY_LINK
| DatasourceMessage.MessageType.IMAGE
):
assert isinstance(message.message, DatasourceMessage.TextMessage)
url = message.message.text
transfer_method = FileTransferMethod.TOOL_FILE
url = message.message.text
transfer_method = FileTransferMethod.TOOL_FILE
datasource_file_id = str(url).split("/")[-1].split(".")[0]
datasource_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
datasource_file = session.scalar(stmt)
if datasource_file is None:
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
datasource_file = session.scalar(stmt)
if datasource_file is None:
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
mapping = {
"tool_file_id": datasource_file_id,
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
files.append(file)
elif message.type == DatasourceMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, DatasourceMessage.TextMessage)
assert message.meta
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
datasource_file = session.scalar(stmt)
if datasource_file is None:
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
mapping = {
"tool_file_id": datasource_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping = {
"tool_file_id": datasource_file_id,
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
)
elif message.type == DatasourceMessage.MessageType.TEXT:
assert isinstance(message.message, DatasourceMessage.TextMessage)
text += message.message.text
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=message.message.text,
is_final=False,
)
elif message.type == DatasourceMessage.MessageType.JSON:
assert isinstance(message.message, DatasourceMessage.JsonMessage)
json.append(message.message.json_object)
elif message.type == DatasourceMessage.MessageType.LINK:
assert isinstance(message.message, DatasourceMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=stream_text,
is_final=False,
)
elif message.type == DatasourceMessage.MessageType.VARIABLE:
assert isinstance(message.message, DatasourceMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
files.append(file)
case DatasourceMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, DatasourceMessage.TextMessage)
assert message.meta
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
datasource_file = session.scalar(stmt)
if datasource_file is None:
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
mapping = {
"tool_file_id": datasource_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
)
case DatasourceMessage.MessageType.TEXT:
assert isinstance(message.message, DatasourceMessage.TextMessage)
text += message.message.text
yield StreamChunkEvent(
selector=[self._node_id, variable_name],
chunk=variable_value,
selector=[self._node_id, "text"],
chunk=message.message.text,
is_final=False,
)
else:
variables[variable_name] = variable_value
elif message.type == DatasourceMessage.MessageType.FILE:
assert message.meta is not None
files.append(message.meta["file"])
case DatasourceMessage.MessageType.JSON:
assert isinstance(message.message, DatasourceMessage.JsonMessage)
json.append(message.message.json_object)
case DatasourceMessage.MessageType.LINK:
assert isinstance(message.message, DatasourceMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=stream_text,
is_final=False,
)
case DatasourceMessage.MessageType.VARIABLE:
assert isinstance(message.message, DatasourceMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield StreamChunkEvent(
selector=[self._node_id, variable_name],
chunk=variable_value,
is_final=False,
)
else:
variables[variable_name] = variable_value
case DatasourceMessage.MessageType.FILE:
assert message.meta is not None
files.append(message.meta["file"])
case (
DatasourceMessage.MessageType.BLOB_CHUNK
| DatasourceMessage.MessageType.LOG
| DatasourceMessage.MessageType.RETRIEVER_RESOURCES
):
pass
# mark the end of the stream
yield StreamChunkEvent(
selector=[self._node_id, "text"],

View File

@ -2,7 +2,7 @@ import base64
import json
import secrets
import string
from collections.abc import Mapping
from collections.abc import Callable, Mapping
from copy import deepcopy
from typing import Any, Literal
from urllib.parse import urlencode, urlparse
@ -11,9 +11,9 @@ import httpx
from json_repair import repair_json
from configs import dify_config
from core.file import file_manager
from core.file.enums import FileTransferMethod
from core.helper import ssrf_proxy
from core.file.file_manager import file_manager as default_file_manager
from core.helper.ssrf_proxy import ssrf_proxy
from core.variables.segments import ArrayFileSegment, FileSegment
from core.workflow.runtime import VariablePool
@ -79,8 +79,8 @@ class Executor:
timeout: HttpRequestNodeTimeout,
variable_pool: VariablePool,
max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
http_client: HttpClientProtocol = ssrf_proxy,
file_manager: FileManagerProtocol = file_manager,
http_client: HttpClientProtocol | None = None,
file_manager: FileManagerProtocol | None = None,
):
# If authorization API key is present, convert the API key using the variable pool
if node_data.authorization.type == "api-key":
@ -107,8 +107,8 @@ class Executor:
self.data = None
self.json = None
self.max_retries = max_retries
self._http_client = http_client
self._file_manager = file_manager
self._http_client = http_client or ssrf_proxy
self._file_manager = file_manager or default_file_manager
# init template
self.variable_pool = variable_pool
@ -336,7 +336,7 @@ class Executor:
"""
do http request depending on api bundle
"""
_METHOD_MAP = {
_METHOD_MAP: dict[str, Callable[..., httpx.Response]] = {
"get": self._http_client.get,
"head": self._http_client.head,
"post": self._http_client.post,
@ -348,7 +348,7 @@ class Executor:
if method_lc not in _METHOD_MAP:
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
request_args = {
request_args: dict[str, Any] = {
"data": self.data,
"files": self.files,
"json": self.json,
@ -361,14 +361,13 @@ class Executor:
}
# request_args = {k: v for k, v in request_args.items() if v is not None}
try:
response: httpx.Response = _METHOD_MAP[method_lc](
response = _METHOD_MAP[method_lc](
url=self.url,
**request_args,
max_retries=self.max_retries,
)
except (self._http_client.max_retries_exceeded_error, self._http_client.request_error) as e:
raise HttpRequestNodeError(str(e)) from e
# FIXME: fix type ignore; this may be an httpx typing issue
return response
def invoke(self) -> Response:

View File

@ -4,8 +4,9 @@ from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from configs import dify_config
from core.file import File, FileTransferMethod, file_manager
from core.helper import ssrf_proxy
from core.file import File, FileTransferMethod
from core.file.file_manager import file_manager as default_file_manager
from core.helper.ssrf_proxy import ssrf_proxy
from core.tools.tool_file_manager import ToolFileManager
from core.variables.segments import ArrayFileSegment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@ -47,9 +48,9 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
http_client: HttpClientProtocol = ssrf_proxy,
http_client: HttpClientProtocol | None = None,
tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
file_manager: FileManagerProtocol = file_manager,
file_manager: FileManagerProtocol | None = None,
) -> None:
super().__init__(
id=id,
@ -57,9 +58,9 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._http_client = http_client
self._http_client = http_client or ssrf_proxy
self._tool_file_manager_factory = tool_file_manager_factory
self._file_manager = file_manager
self._file_manager = file_manager or default_file_manager
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:

View File

@ -397,7 +397,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
return outputs
# Check if all non-None outputs are lists
non_none_outputs = [output for output in outputs if output is not None]
non_none_outputs: list[object] = [output for output in outputs if output is not None]
if not non_none_outputs:
return outputs

View File

@ -78,12 +78,21 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
indexing_technique = node_data.indexing_technique or dataset.indexing_technique
summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting
# Try to get document language if document_id is available
doc_language = None
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
if document_id:
document = db.session.query(Document).filter_by(id=document_id.value).first()
if document and document.doc_language:
doc_language = document.doc_language
outputs = self._get_preview_output_with_summaries(
node_data.chunk_structure,
chunks,
dataset=dataset,
indexing_technique=indexing_technique,
summary_index_setting=summary_index_setting,
doc_language=doc_language,
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -315,6 +324,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
dataset: Dataset,
indexing_technique: str | None = None,
summary_index_setting: dict | None = None,
doc_language: str | None = None,
) -> Mapping[str, Any]:
"""
Generate preview output with summaries for chunks in preview mode.
@ -326,6 +336,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
dataset: Dataset object (for tenant_id)
indexing_technique: Indexing technique from node config or dataset
summary_index_setting: Summary index setting from node config or dataset
doc_language: Optional document language to ensure summary is generated in the correct language
"""
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
preview_output = index_processor.format_preview(chunks)
@ -365,6 +376,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
tenant_id=dataset.tenant_id,
text=preview_item["content"],
summary_index_setting=summary_index_setting,
document_language=doc_language,
)
if summary:
preview_item["summary"] = summary
@ -374,6 +386,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
tenant_id=dataset.tenant_id,
text=preview_item["content"],
summary_index_setting=summary_index_setting,
document_language=doc_language,
)
if summary:
preview_item["summary"] = summary

View File

@ -303,33 +303,34 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
if node_data.multiple_retrieval_config is None:
raise ValueError("multiple_retrieval_config is required")
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
if node_data.multiple_retrieval_config.reranking_model:
reranking_model = {
"reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
"reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
}
else:
match node_data.multiple_retrieval_config.reranking_mode:
case "reranking_model":
if node_data.multiple_retrieval_config.reranking_model:
reranking_model = {
"reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
"reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
}
else:
reranking_model = None
weights = None
case "weighted_score":
if node_data.multiple_retrieval_config.weights is None:
raise ValueError("weights is required")
reranking_model = None
weights = None
elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score":
if node_data.multiple_retrieval_config.weights is None:
raise ValueError("weights is required")
reranking_model = None
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
weights = {
"vector_setting": {
"vector_weight": vector_setting.vector_weight,
"embedding_provider_name": vector_setting.embedding_provider_name,
"embedding_model_name": vector_setting.embedding_model_name,
},
"keyword_setting": {
"keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
},
}
else:
reranking_model = None
weights = None
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
weights = {
"vector_setting": {
"vector_weight": vector_setting.vector_weight,
"embedding_provider_name": vector_setting.embedding_provider_name,
"embedding_model_name": vector_setting.embedding_model_name,
},
"keyword_setting": {
"keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
},
}
case _:
reranking_model = None
weights = None
all_documents = dataset_retrieval.multiple_retrieve(
app_id=self.app_id,
tenant_id=self.tenant_id,
@ -453,73 +454,74 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
)
filters: list[Any] = []
metadata_condition = None
if node_data.metadata_filtering_mode == "disabled":
return None, None, usage
elif node_data.metadata_filtering_mode == "automatic":
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
dataset_ids, query, node_data
)
usage = self._merge_usage(usage, automatic_usage)
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
DatasetRetrieval.process_metadata_filter_func(
sequence,
filter.get("condition", ""),
filter.get("metadata_name", ""),
filter.get("value"),
filters,
)
conditions.append(
Condition(
name=filter.get("metadata_name"), # type: ignore
comparison_operator=filter.get("condition"), # type: ignore
value=filter.get("value"),
)
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator
if node_data.metadata_filtering_conditions
else "or",
conditions=conditions,
match node_data.metadata_filtering_mode:
case "disabled":
return None, None, usage
case "automatic":
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
dataset_ids, query, node_data
)
elif node_data.metadata_filtering_mode == "manual":
if node_data.metadata_filtering_conditions:
conditions = []
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
metadata_name = condition.name
expected_value = condition.value
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
if isinstance(expected_value, str):
expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value
).value[0]
if expected_value.value_type in {"number", "integer", "float"}:
expected_value = expected_value.value
elif expected_value.value_type == "string":
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
else:
raise ValueError("Invalid expected metadata value type")
conditions.append(
Condition(
name=metadata_name,
comparison_operator=condition.comparison_operator,
value=expected_value,
usage = self._merge_usage(usage, automatic_usage)
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
DatasetRetrieval.process_metadata_filter_func(
sequence,
filter.get("condition", ""),
filter.get("metadata_name", ""),
filter.get("value"),
filters,
)
conditions.append(
Condition(
name=filter.get("metadata_name"), # type: ignore
comparison_operator=filter.get("condition"), # type: ignore
value=filter.get("value"),
)
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator
if node_data.metadata_filtering_conditions
else "or",
conditions=conditions,
)
filters = DatasetRetrieval.process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
expected_value,
filters,
case "manual":
if node_data.metadata_filtering_conditions:
conditions = []
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
metadata_name = condition.name
expected_value = condition.value
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
if isinstance(expected_value, str):
expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value
).value[0]
if expected_value.value_type in {"number", "integer", "float"}:
expected_value = expected_value.value
elif expected_value.value_type == "string":
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
else:
raise ValueError("Invalid expected metadata value type")
conditions.append(
Condition(
name=metadata_name,
comparison_operator=condition.comparison_operator,
value=expected_value,
)
)
filters = DatasetRetrieval.process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
expected_value,
filters,
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
conditions=conditions,
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
conditions=conditions,
)
else:
raise ValueError("Invalid metadata filtering mode")
case _:
raise ValueError("Invalid metadata filtering mode")
if filters:
if (
node_data.metadata_filtering_conditions

View File

@ -196,13 +196,13 @@ def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
case "name":
return lambda x: x.filename or ""
case "type":
return lambda x: x.type
return lambda x: str(x.type)
case "extension":
return lambda x: x.extension or ""
case "mime_type":
return lambda x: x.mime_type or ""
case "transfer_method":
return lambda x: x.transfer_method
return lambda x: str(x.transfer_method)
case "url":
return lambda x: x.remote_url or ""
case "related_id":
@ -276,7 +276,6 @@ def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Calla
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
extract_func: Callable[[File], Any]
if key in {"name", "extension", "mime_type", "url", "related_id"} and isinstance(value, str):
extract_func = _get_file_extract_string_func(key=key)
return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x))
@ -284,8 +283,8 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str
extract_func = _get_file_extract_string_func(key=key)
return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x))
elif key == "size" and isinstance(value, str):
extract_func = _get_file_extract_number_func(key=key)
return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x))
extract_number = _get_file_extract_number_func(key=key)
return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_number(x))
else:
raise InvalidKeyError(f"Invalid key: {key}")

View File

@ -852,18 +852,16 @@ class LLMNode(Node[LLMNodeData]):
# Insert histories into the prompt
prompt_content = prompt_messages[0].content
# For issue #11247 - Check if prompt content is a string or a list
prompt_content_type = type(prompt_content)
if prompt_content_type == str:
if isinstance(prompt_content, str):
prompt_content = str(prompt_content)
if "#histories#" in prompt_content:
prompt_content = prompt_content.replace("#histories#", memory_text)
else:
prompt_content = memory_text + "\n" + prompt_content
prompt_messages[0].content = prompt_content
elif prompt_content_type == list:
prompt_content = prompt_content if isinstance(prompt_content, list) else []
elif isinstance(prompt_content, list):
for content_item in prompt_content:
if content_item.type == PromptMessageContentType.TEXT:
if isinstance(content_item, TextPromptMessageContent):
if "#histories#" in content_item.data:
content_item.data = content_item.data.replace("#histories#", memory_text)
else:
@ -873,13 +871,12 @@ class LLMNode(Node[LLMNodeData]):
# Add current query to the prompt message
if sys_query:
if prompt_content_type == str:
if isinstance(prompt_content, str):
prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
prompt_messages[0].content = prompt_content
elif prompt_content_type == list:
prompt_content = prompt_content if isinstance(prompt_content, list) else []
elif isinstance(prompt_content, list):
for content_item in prompt_content:
if content_item.type == PromptMessageContentType.TEXT:
if isinstance(content_item, TextPromptMessageContent):
content_item.data = sys_query + "\n" + content_item.data
else:
raise ValueError("Invalid prompt content type")
@ -1033,14 +1030,14 @@ class LLMNode(Node[LLMNodeData]):
if typed_node_data.prompt_config:
enable_jinja = False
if isinstance(prompt_template, list):
if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
if prompt_template.edition_type == "jinja2":
enable_jinja = True
else:
for prompt in prompt_template:
if prompt.edition_type == "jinja2":
enable_jinja = True
break
else:
if prompt_template.edition_type == "jinja2":
enable_jinja = True
if enable_jinja:
for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:

View File

@ -1,4 +1,4 @@
from typing import Protocol
from typing import Any, Protocol
import httpx
@ -12,17 +12,17 @@ class HttpClientProtocol(Protocol):
@property
def request_error(self) -> type[Exception]: ...
def get(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def get(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
def head(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def head(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
def post(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def post(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
def put(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def put(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
def delete(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def delete(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
def patch(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def patch(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
class FileManagerProtocol(Protocol):
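Because HttpClientProtocol is structural, tests can substitute any object with matching attributes; no inheritance is required. A minimal sketch of such a stub (invented for illustration, not the project's actual test double; head/put/delete/patch would follow the same shape):

import httpx

class StubHttpClient:
    # Plain class attributes satisfy the protocol's read-only properties.
    max_retries_exceeded_error: type[Exception] = TimeoutError
    request_error: type[Exception] = ConnectionError

    def _reply(self, method: str, url: str) -> httpx.Response:
        return httpx.Response(200, request=httpx.Request(method, url), text="ok")

    def get(self, url: str, max_retries: int = 3, **kwargs: object) -> httpx.Response:
        return self._reply("GET", url)

    def post(self, url: str, max_retries: int = 3, **kwargs: object) -> httpx.Response:
        return self._reply("POST", url)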

View File

@ -482,16 +482,17 @@ class ToolNode(Node[ToolNodeData]):
result = {}
for parameter_name in typed_node_data.tool_parameters:
input = typed_node_data.tool_parameters[parameter_name]
if input.type == "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == "variable":
selector_key = ".".join(input.value)
result[f"#{selector_key}#"] = input.value
elif input.type == "constant":
pass
match input.type:
case "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
case "variable":
selector_key = ".".join(input.value)
result[f"#{selector_key}#"] = input.value
case "constant":
pass
result = {node_id + "." + key: value for key, value in result.items()}

View File

@ -6,12 +6,13 @@ import threading
from collections.abc import Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Protocol
from typing import Any, ClassVar, Protocol
from pydantic.json import pydantic_encoder
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
from core.workflow.runtime.variable_pool import VariablePool
@ -103,14 +104,33 @@ class ResponseStreamCoordinatorProtocol(Protocol):
...
class NodeProtocol(Protocol):
"""Structural interface for graph nodes."""
id: str
state: NodeState
execution_type: NodeExecutionType
node_type: ClassVar[NodeType]
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: ...
class EdgeProtocol(Protocol):
id: str
state: NodeState
tail: str
head: str
source_handle: str
class GraphProtocol(Protocol):
"""Structural interface required from graph instances attached to the runtime state."""
nodes: Mapping[str, object]
edges: Mapping[str, object]
root_node: object
nodes: Mapping[str, NodeProtocol]
edges: Mapping[str, EdgeProtocol]
root_node: NodeProtocol
def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ...
def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ...
@dataclass(slots=True)
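These Protocol classes are purely structural (PEP 544): the concrete Graph never inherits from GraphProtocol, it simply has attributes of the right shape. Code written against the protocol therefore works for the real graph or any test fake, as in this sketch:

from collections.abc import Sequence

def outgoing_edge_ids(graph: GraphProtocol, node_id: str) -> Sequence[str]:
    # Accepts anything whose get_outgoing_edges() yields EdgeProtocol-shaped
    # objects; no import of the concrete Graph required.
    return [edge.id for edge in graph.get_outgoing_edges(node_id)]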

View File

@ -144,11 +144,11 @@ class WorkflowEntry:
:param user_inputs: user inputs
:return:
"""
node_config = dict(workflow.get_node_config_by_id(node_id))
node_config_data = node_config.get("data", {})
node_config = workflow.get_node_config_by_id(node_id)
node_config_data = node_config["data"]
# Get node type
node_type = NodeType(node_config_data.get("type"))
node_type = NodeType(node_config_data["type"])
# init graph init params and runtime state
graph_init_params = GraphInitParams(

View File

@ -27,10 +27,13 @@ def init_app(app: DifyApp) -> None:
)
# Ensure route decorators are evaluated.
import controllers.console.init_validate as init_validate_module
import controllers.console.ping as ping_module
from controllers.console import setup
from controllers.console import remote_files, setup
_ = init_validate_module
_ = ping_module
_ = remote_files
_ = setup
router.include_router(console_router, prefix="/console/api")

View File

@ -390,8 +390,7 @@ class ClickZettaVolumeStorage(BaseStorage):
"""
content = self.load_once(filename)
with Path(target_filepath).open("wb") as f:
f.write(content)
Path(target_filepath).write_bytes(content)
logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath)

View File

@ -1,36 +1,69 @@
from flask_restx import Namespace, fields
from __future__ import annotations
from libs.helper import TimestampField
from datetime import datetime
annotation_fields = {
"id": fields.String,
"question": fields.String,
"answer": fields.Raw(attribute="content"),
"hit_count": fields.Integer,
"created_at": TimestampField,
# 'account': fields.Nested(simple_account_fields, allow_null=True)
}
from pydantic import BaseModel, ConfigDict, Field, field_validator
def build_annotation_model(api_or_ns: Namespace):
"""Build the annotation model for the API or Namespace."""
return api_or_ns.model("Annotation", annotation_fields)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
annotation_list_fields = {
"data": fields.List(fields.Nested(annotation_fields)),
}
class ResponseModel(BaseModel):
model_config = ConfigDict(
from_attributes=True,
extra="ignore",
populate_by_name=True,
serialize_by_alias=True,
protected_namespaces=(),
)
annotation_hit_history_fields = {
"id": fields.String,
"source": fields.String,
"score": fields.Float,
"question": fields.String,
"created_at": TimestampField,
"match": fields.String(attribute="annotation_question"),
"response": fields.String(attribute="annotation_content"),
}
annotation_hit_history_list_fields = {
"data": fields.List(fields.Nested(annotation_hit_history_fields)),
}
class Annotation(ResponseModel):
id: str
question: str | None = None
answer: str | None = Field(default=None, validation_alias="content")
hit_count: int | None = None
created_at: int | None = None
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class AnnotationList(ResponseModel):
data: list[Annotation]
has_more: bool
limit: int
total: int
page: int
class AnnotationExportList(ResponseModel):
data: list[Annotation]
class AnnotationHitHistory(ResponseModel):
id: str
source: str | None = None
score: float | None = None
question: str | None = None
created_at: int | None = None
match: str | None = Field(default=None, validation_alias="annotation_question")
response: str | None = Field(default=None, validation_alias="annotation_content")
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class AnnotationHitHistoryList(ResponseModel):
data: list[AnnotationHitHistory]
has_more: bool
limit: int
total: int
page: int
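
The hunk above swaps flask_restx marshalling dicts for Pydantic response models: `from_attributes` lets them read ORM rows directly, `validation_alias` maps the stored `content` column onto the public `answer` field, and the `created_at` validator normalizes datetimes to unix timestamps. A minimal sketch, assuming the `Annotation` model above is in scope (the row and its values are hypothetical):

from datetime import datetime
from types import SimpleNamespace

# Stand-in for an ORM row; from_attributes reads plain attributes too.
row = SimpleNamespace(
    id="ann-1",
    question="What is Dify?",
    content="An LLM app platform.",  # surfaced as `answer` via the alias
    hit_count=3,
    created_at=datetime(2026, 2, 3),
)

annotation = Annotation.model_validate(row)
assert annotation.answer == "An LLM app platform."
assert isinstance(annotation.created_at, int)  # normalized by the validator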

View File

@ -1,4 +1,7 @@
from flask_restx import Namespace, fields
from __future__ import annotations
from flask_restx import fields
from pydantic import BaseModel, ConfigDict
simple_end_user_fields = {
"id": fields.String,
@ -8,5 +11,18 @@ simple_end_user_fields = {
}
def build_simple_end_user_model(api_or_ns: Namespace):
return api_or_ns.model("SimpleEndUser", simple_end_user_fields)
class ResponseModel(BaseModel):
model_config = ConfigDict(
from_attributes=True,
extra="ignore",
populate_by_name=True,
serialize_by_alias=True,
protected_namespaces=(),
)
class SimpleEndUser(ResponseModel):
id: str
type: str
is_anonymous: bool
session_id: str | None = None

View File

@ -1,6 +1,11 @@
from flask_restx import Namespace, fields
from __future__ import annotations
from libs.helper import AvatarUrlField, TimestampField
from datetime import datetime
from flask_restx import fields
from pydantic import BaseModel, ConfigDict, computed_field, field_validator
from core.file import helpers as file_helpers
simple_account_fields = {
"id": fields.String,
@ -9,36 +14,78 @@ simple_account_fields = {
}
def build_simple_account_model(api_or_ns: Namespace):
return api_or_ns.model("SimpleAccount", simple_account_fields)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
account_fields = {
"id": fields.String,
"name": fields.String,
"avatar": fields.String,
"avatar_url": AvatarUrlField,
"email": fields.String,
"is_password_set": fields.Boolean,
"interface_language": fields.String,
"interface_theme": fields.String,
"timezone": fields.String,
"last_login_at": TimestampField,
"last_login_ip": fields.String,
"created_at": TimestampField,
}
def _build_avatar_url(avatar: str | None) -> str | None:
if avatar is None:
return None
if avatar.startswith(("http://", "https://")):
return avatar
return file_helpers.get_signed_file_url(avatar)
account_with_role_fields = {
"id": fields.String,
"name": fields.String,
"avatar": fields.String,
"avatar_url": AvatarUrlField,
"email": fields.String,
"last_login_at": TimestampField,
"last_active_at": TimestampField,
"created_at": TimestampField,
"role": fields.String,
"status": fields.String,
}
account_with_role_list_fields = {"accounts": fields.List(fields.Nested(account_with_role_fields))}
class ResponseModel(BaseModel):
model_config = ConfigDict(
from_attributes=True,
extra="ignore",
populate_by_name=True,
serialize_by_alias=True,
protected_namespaces=(),
)
class SimpleAccount(ResponseModel):
id: str
name: str
email: str
class _AccountAvatar(ResponseModel):
avatar: str | None = None
@computed_field(return_type=str | None) # type: ignore[prop-decorator]
@property
def avatar_url(self) -> str | None:
return _build_avatar_url(self.avatar)
class Account(_AccountAvatar):
id: str
name: str
email: str
is_password_set: bool
interface_language: str | None = None
interface_theme: str | None = None
timezone: str | None = None
last_login_at: int | None = None
last_login_ip: str | None = None
created_at: int | None = None
@field_validator("last_login_at", "created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class AccountWithRole(_AccountAvatar):
id: str
name: str
email: str
last_login_at: int | None = None
last_active_at: int | None = None
created_at: int | None = None
role: str
status: str
@field_validator("last_login_at", "last_active_at", "created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class AccountWithRoleList(ResponseModel):
accounts: list[AccountWithRole]
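
`avatar_url` is no longer marshalled from a stored field; the `@computed_field` property derives it on every dump, signing bare storage keys and passing absolute URLs through untouched. A short usage sketch with hypothetical values (an absolute URL short-circuits `_build_avatar_url`, so nothing beyond the model itself is needed to run it):

member = AccountWithRole.model_validate(
    {
        "id": "u-1",
        "name": "Ada",
        "email": "ada@example.com",
        "role": "admin",
        "status": "active",
        "avatar": "https://cdn.example.com/ada.png",  # absolute, passed through
    }
)
assert member.model_dump()["avatar_url"] == "https://cdn.example.com/ada.png"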

View File

@ -1,12 +1,20 @@
from flask_restx import Namespace, fields
from __future__ import annotations
dataset_tag_fields = {
"id": fields.String,
"name": fields.String,
"type": fields.String,
"binding_count": fields.String,
}
from pydantic import BaseModel, ConfigDict
def build_dataset_tag_fields(api_or_ns: Namespace):
return api_or_ns.model("DataSetTag", dataset_tag_fields)
class ResponseModel(BaseModel):
model_config = ConfigDict(
from_attributes=True,
extra="ignore",
populate_by_name=True,
serialize_by_alias=True,
protected_namespaces=(),
)
class DataSetTag(ResponseModel):
id: str
name: str
type: str
binding_count: str | None = None

View File

@ -1,7 +1,7 @@
from flask_restx import Namespace, fields
from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields
from fields.member_fields import build_simple_account_model, simple_account_fields
from fields.end_user_fields import simple_end_user_fields
from fields.member_fields import simple_account_fields
from fields.workflow_run_fields import (
build_workflow_run_for_archived_log_model,
build_workflow_run_for_log_model,
@ -25,17 +25,9 @@ workflow_app_log_partial_fields = {
def build_workflow_app_log_partial_model(api_or_ns: Namespace):
"""Build the workflow app log partial model for the API or Namespace."""
workflow_run_model = build_workflow_run_for_log_model(api_or_ns)
simple_account_model = build_simple_account_model(api_or_ns)
simple_end_user_model = build_simple_end_user_model(api_or_ns)
copied_fields = workflow_app_log_partial_fields.copy()
copied_fields["workflow_run"] = fields.Nested(workflow_run_model, attribute="workflow_run", allow_null=True)
copied_fields["created_by_account"] = fields.Nested(
simple_account_model, attribute="created_by_account", allow_null=True
)
copied_fields["created_by_end_user"] = fields.Nested(
simple_end_user_model, attribute="created_by_end_user", allow_null=True
)
return api_or_ns.model("WorkflowAppLogPartial", copied_fields)
@ -52,17 +44,9 @@ workflow_archived_log_partial_fields = {
def build_workflow_archived_log_partial_model(api_or_ns: Namespace):
"""Build the workflow archived log partial model for the API or Namespace."""
workflow_run_model = build_workflow_run_for_archived_log_model(api_or_ns)
simple_account_model = build_simple_account_model(api_or_ns)
simple_end_user_model = build_simple_end_user_model(api_or_ns)
copied_fields = workflow_archived_log_partial_fields.copy()
copied_fields["workflow_run"] = fields.Nested(workflow_run_model, allow_null=True)
copied_fields["created_by_account"] = fields.Nested(
simple_account_model, attribute="created_by_account", allow_null=True
)
copied_fields["created_by_end_user"] = fields.Nested(
simple_end_user_model, attribute="created_by_end_user", allow_null=True
)
return api_or_ns.model("WorkflowArchivedLogPartial", copied_fields)

View File

@ -136,7 +136,7 @@ class PKCS1OAepCipher:
# Step 3a (OS2IP)
em_int = bytes_to_long(em)
# Step 3b (RSAEP)
m_int = gmpy2.powmod(em_int, self._key.e, self._key.n)
m_int: int = gmpy2.powmod(em_int, self._key.e, self._key.n) # type: ignore[attr-defined]
# Step 3c (I2OSP)
c = long_to_bytes(m_int, k)
return c
@ -169,7 +169,7 @@ class PKCS1OAepCipher:
ct_int = bytes_to_long(ciphertext)
# Step 2b (RSADP)
# m_int = self._key._decrypt(ct_int)
m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n)
m_int: int = gmpy2.powmod(ct_int, self._key.d, self._key.n) # type: ignore[attr-defined]
# Complete step 2c (I2OSP)
em = long_to_bytes(m_int, k)
# Step 3a
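
For context on the two annotated calls: `gmpy2.powmod(b, e, n)` computes the same modular exponentiation as Python's built-in three-argument `pow`, just faster on large operands, which is why the results can safely be typed as `int`. A quick sanity check with arbitrary small values:

import gmpy2

base, exp, mod = 12345, 65537, 99991
assert int(gmpy2.powmod(base, exp, mod)) == pow(base, exp, mod)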

View File

@ -51,7 +51,7 @@ def upgrade():
batch_op.add_column(sa.Column('summary_index_setting', models.types.AdjustedJSON(), nullable=True))
with op.batch_alter_table('documents', schema=None) as batch_op:
batch_op.add_column(sa.Column('need_summary', sa.Boolean(), server_default=sa.text('false'), nullable=True))
batch_op.add_column(sa.Column('need_summary', sa.Boolean(), server_default=sa.text('false'), nullable=False))
else:
# MySQL: Use compatible syntax
op.create_table(
@ -83,7 +83,7 @@ def upgrade():
batch_op.add_column(sa.Column('summary_index_setting', models.types.AdjustedJSON(), nullable=True))
with op.batch_alter_table('documents', schema=None) as batch_op:
batch_op.add_column(sa.Column('need_summary', sa.Boolean(), server_default=sa.text('false'), nullable=True))
batch_op.add_column(sa.Column('need_summary', sa.Boolean(), server_default=sa.text('false'), nullable=False))
# ### end Alembic commands ###

View File

@ -420,7 +420,7 @@ class Document(Base):
doc_metadata = mapped_column(AdjustedJSON, nullable=True)
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'"))
doc_language = mapped_column(String(255), nullable=True)
need_summary: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
need_summary: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"]

View File

@ -29,6 +29,7 @@ from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
)
from core.workflow.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
from core.workflow.enums import NodeType
from extensions.ext_storage import Storage
@ -229,7 +230,7 @@ class Workflow(Base): # bug
# - `_get_graph_and_variable_pool_for_single_node_run`.
return json.loads(self.graph) if self.graph else {}
def get_node_config_by_id(self, node_id: str) -> Mapping[str, Any]:
def get_node_config_by_id(self, node_id: str) -> NodeConfigDict:
"""Extract a node configuration from the workflow graph by node ID.
A node configuration is a dictionary containing the node's properties, including
the node's id, title, and its data as a dict.
@ -247,8 +248,7 @@ class Workflow(Base): # bug
node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes))
except StopIteration:
raise NodeNotFoundError(node_id)
assert isinstance(node_config, dict)
return node_config
return NodeConfigDictAdapter.validate_python(node_config)
@staticmethod
def get_node_type_from_node_config(node_config: Mapping[str, Any]) -> NodeType:
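
`NodeConfigDictAdapter` is not shown in this diff; from its use it is a pydantic `TypeAdapter`, which validates a plain mapping against a typed shape and narrows the return type without defining a full `BaseModel`. A rough sketch of that pattern with a hypothetical stand-in for `NodeConfigDict`:

from typing import Any

from pydantic import TypeAdapter
from typing_extensions import TypedDict


class _NodeConfig(TypedDict):  # hypothetical stand-in for NodeConfigDict
    id: str
    data: dict[str, Any]


_adapter = TypeAdapter(_NodeConfig)  # counterpart of NodeConfigDictAdapter

# validate_python raises on a malformed node and types the result
node = _adapter.validate_python({"id": "start", "data": {"type": "start"}})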

View File

@ -87,7 +87,7 @@ dependencies = [
"sseclient-py~=1.8.0",
"httpx-sse~=0.4.0",
"sendgrid~=6.12.3",
"flask-restx~=1.3.0",
"flask-restx~=1.3.2",
"packaging~=23.2",
"croniter>=6.0.0",
"weaviate-client==4.17.0",
@ -116,7 +116,7 @@ dev = [
"dotenv-linter~=0.5.0",
"faker~=38.2.0",
"lxml-stubs~=0.5.1",
"ty~=0.0.1a19",
"ty>=0.0.14",
"basedpyright~=1.31.0",
"ruff~=0.14.0",
"pytest~=8.3.2",
@ -145,7 +145,7 @@ dev = [
"types-openpyxl~=3.1.5",
"types-pexpect~=4.9.0",
"types-protobuf~=5.29.1",
"types-psutil~=7.0.0",
"types-psutil~=7.2.2",
"types-psycopg2~=2.9.21",
"types-pygments~=2.19.0",
"types-pymysql~=1.1.0",

View File

@ -158,7 +158,7 @@ class AppAnnotationService:
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
)
annotations = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False)
return annotations.items, annotations.total
return annotations.items, annotations.total or 0
@classmethod
def export_annotation_list_by_app_id(cls, app_id: str):
@ -524,7 +524,7 @@ class AppAnnotationService:
annotation_hit_histories = db.paginate(
select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False
)
return annotation_hit_histories.items, annotation_hit_histories.total
return annotation_hit_histories.items, annotation_hit_histories.total or 0
@classmethod
def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None:

View File

@ -16,6 +16,7 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from core.db.session_factory import session_factory
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.file import helpers as file_helpers
from core.helper.name_generator import generate_incremental_name
@ -1388,6 +1389,46 @@ class DocumentService:
).all()
return documents
@staticmethod
def update_documents_need_summary(dataset_id: str, document_ids: Sequence[str], need_summary: bool = True) -> int:
"""
Update need_summary field for multiple documents.
This handles documents that were created while summary_index_setting was disabled
and therefore need the flag set once summary indexing is enabled later.
Args:
dataset_id: Dataset ID
document_ids: List of document IDs to update
need_summary: Value to set for need_summary field (default: True)
Returns:
Number of documents updated
"""
if not document_ids:
return 0
document_id_list: list[str] = [str(document_id) for document_id in document_ids]
with session_factory.create_session() as session:
updated_count = (
session.query(Document)
.filter(
Document.id.in_(document_id_list),
Document.dataset_id == dataset_id,
Document.doc_form != "qa_model", # Skip qa_model documents
)
.update({Document.need_summary: need_summary}, synchronize_session=False)
)
session.commit()
logger.info(
"Updated need_summary to %s for %d documents in dataset %s",
need_summary,
updated_count,
dataset_id,
)
return updated_count
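
A hypothetical call site for the backfill above, e.g. the handler that turns summary indexing on for a dataset (IDs are made up; `need_summary` defaults to True):

updated = DocumentService.update_documents_need_summary(
    dataset_id="ds-123",
    document_ids=["doc-1", "doc-2"],
)
# qa_model documents are skipped, so `updated` can be smaller than the input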
@staticmethod
def get_document_download_url(document: Document) -> str:
"""
@ -2937,14 +2978,15 @@ class DocumentService:
"""
now = naive_utc_now()
if action == "enable":
return DocumentService._prepare_enable_update(document, now)
elif action == "disable":
return DocumentService._prepare_disable_update(document, user, now)
elif action == "archive":
return DocumentService._prepare_archive_update(document, user, now)
elif action == "un_archive":
return DocumentService._prepare_unarchive_update(document, now)
match action:
case "enable":
return DocumentService._prepare_enable_update(document, now)
case "disable":
return DocumentService._prepare_disable_update(document, user, now)
case "archive":
return DocumentService._prepare_archive_update(document, user, now)
case "un_archive":
return DocumentService._prepare_unarchive_update(document, now)
return None
@ -3581,56 +3623,57 @@ class SegmentService:
# Check if segment_ids is not empty to avoid WHERE false condition
if not segment_ids or len(segment_ids) == 0:
return
if action == "enable":
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
DocumentSegment.enabled == False,
)
).all()
if not segments:
return
real_deal_segment_ids = []
for segment in segments:
indexing_cache_key = f"segment_{segment.id}_indexing"
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
continue
segment.enabled = True
segment.disabled_at = None
segment.disabled_by = None
db.session.add(segment)
real_deal_segment_ids.append(segment.id)
db.session.commit()
match action:
case "enable":
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
DocumentSegment.enabled == False,
)
).all()
if not segments:
return
real_deal_segment_ids = []
for segment in segments:
indexing_cache_key = f"segment_{segment.id}_indexing"
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
continue
segment.enabled = True
segment.disabled_at = None
segment.disabled_by = None
db.session.add(segment)
real_deal_segment_ids.append(segment.id)
db.session.commit()
enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
elif action == "disable":
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
DocumentSegment.enabled == True,
)
).all()
if not segments:
return
real_deal_segment_ids = []
for segment in segments:
indexing_cache_key = f"segment_{segment.id}_indexing"
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
continue
segment.enabled = False
segment.disabled_at = naive_utc_now()
segment.disabled_by = current_user.id
db.session.add(segment)
real_deal_segment_ids.append(segment.id)
db.session.commit()
enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
case "disable":
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
DocumentSegment.enabled == True,
)
).all()
if not segments:
return
real_deal_segment_ids = []
for segment in segments:
indexing_cache_key = f"segment_{segment.id}_indexing"
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
continue
segment.enabled = False
segment.disabled_at = naive_utc_now()
segment.disabled_by = current_user.id
db.session.add(segment)
real_deal_segment_ids.append(segment.id)
db.session.commit()
disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
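
Both branches share the same guard: a segment whose indexing cache key is still alive is mid-reindex, so the batch skips it instead of flipping `enabled` underneath the running task. The key convention, isolated (assumes a reachable Redis and that the indexing task sets the key for its duration):

import redis

r = redis.Redis()

def is_mid_indexing(segment_id: str) -> bool:
    # set by the indexing task while it runs (assumption)
    return r.get(f"segment_{segment_id}_indexing") is not None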
@classmethod
def create_child_chunk(

View File

@ -174,6 +174,10 @@ class RagPipelineTransformService:
else:
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
# Copy summary_index_setting from dataset to knowledge_index node configuration
if dataset.summary_index_setting:
knowledge_configuration.summary_index_setting = dataset.summary_index_setting
knowledge_configuration_dict.update(knowledge_configuration.model_dump())
node["data"] = knowledge_configuration_dict
return node

View File

@ -49,11 +49,18 @@ class SummaryIndexService:
# Use lazy import to avoid circular import
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
# Get document language to ensure summary is generated in the correct language
# This is especially important for image-only chunks where text is empty or minimal
document_language = None
if segment.document and segment.document.doc_language:
document_language = segment.document.doc_language
summary_content, usage = ParagraphIndexProcessor.generate_summary(
tenant_id=dataset.tenant_id,
text=segment.content,
summary_index_setting=summary_index_setting,
segment_id=segment.id,
document_language=document_language,
)
if not summary_content:
@ -558,6 +565,9 @@ class SummaryIndexService:
)
session.add(summary_record)
# Commit the batch created records
session.commit()
@staticmethod
def update_summary_record_error(
segment: DocumentSegment,
@ -762,7 +772,6 @@ class SummaryIndexService:
dataset=dataset,
status="not_started",
)
session.commit() # Commit initial records
summary_records = []

View File

@ -24,7 +24,7 @@ class TagService:
escaped_keyword = escape_like_pattern(keyword)
query = query.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\")))
query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
results: list = query.order_by(Tag.created_at.desc()).all()
results = query.order_by(Tag.created_at.desc()).all()
return results
@staticmethod

View File

@ -1,8 +1,6 @@
import json
import logging
from collections.abc import Mapping
from datetime import datetime
from typing import Any
from sqlalchemy import or_, select
from sqlalchemy.orm import Session
@ -10,8 +8,8 @@ from sqlalchemy.orm import Session
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
@ -38,12 +36,10 @@ class WorkflowToolManageService:
label: str,
icon: dict,
description: str,
parameters: list[Mapping[str, Any]],
parameters: list[WorkflowToolParameterConfiguration],
privacy_policy: str = "",
labels: list[str] | None = None,
):
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
# check if the name is unique
existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider)
@ -75,7 +71,7 @@ class WorkflowToolManageService:
label=label,
icon=json.dumps(icon),
description=description,
parameter_configuration=json.dumps(parameters),
parameter_configuration=json.dumps([p.model_dump() for p in parameters]),
privacy_policy=privacy_policy,
version=workflow.version,
)
@ -104,7 +100,7 @@ class WorkflowToolManageService:
label: str,
icon: dict,
description: str,
parameters: list[Mapping[str, Any]],
parameters: list[WorkflowToolParameterConfiguration],
privacy_policy: str = "",
labels: list[str] | None = None,
):
@ -122,8 +118,6 @@ class WorkflowToolManageService:
:param labels: labels
:return: the updated tool
"""
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
# check if the name is unique
existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider)
@ -162,7 +156,7 @@ class WorkflowToolManageService:
workflow_tool_provider.label = label
workflow_tool_provider.icon = json.dumps(icon)
workflow_tool_provider.description = description
workflow_tool_provider.parameter_configuration = json.dumps(parameters)
workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in parameters])
workflow_tool_provider.privacy_policy = privacy_policy
workflow_tool_provider.version = workflow.version
workflow_tool_provider.updated_at = datetime.now()
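
With `parameters` now typed as `WorkflowToolParameterConfiguration`, validation happens when the models are built, and persistence round-trips through `model_dump()`. A short sketch reusing the field shape from the tests further down:

import json

param = WorkflowToolParameterConfiguration.model_validate(
    {
        "name": "input_text",
        "description": "Input text for processing",
        "form": "form",
        "type": "string",
        "required": True,
    }
)
stored = json.dumps([param.model_dump()])  # what parameter_configuration stores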

View File

@ -90,6 +90,7 @@ class TestWebhookService:
"id": "webhook_node",
"type": "webhook",
"data": {
"type": "trigger-webhook",
"title": "Test Webhook",
"method": "post",
"content_type": "application/json",

View File

@ -3,7 +3,9 @@ from unittest.mock import patch
import pytest
from faker import Faker
from pydantic import ValidationError
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from models.tools import WorkflowToolProvider
from models.workflow import Workflow as WorkflowModel
from services.account_service import AccountService, TenantService
@ -130,20 +132,24 @@ class TestWorkflowToolManageService:
def _create_test_workflow_tool_parameters(self):
"""Helper method to create valid workflow tool parameters."""
return [
{
"name": "input_text",
"description": "Input text for processing",
"form": "form",
"type": "string",
"required": True,
},
{
"name": "output_format",
"description": "Output format specification",
"form": "form",
"type": "select",
"required": False,
},
WorkflowToolParameterConfiguration.model_validate(
{
"name": "input_text",
"description": "Input text for processing",
"form": "form",
"type": "string",
"required": True,
}
),
WorkflowToolParameterConfiguration.model_validate(
{
"name": "output_format",
"description": "Output format specification",
"form": "form",
"type": "select",
"required": False,
}
),
]
def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
@ -208,7 +214,7 @@ class TestWorkflowToolManageService:
assert created_tool_provider.label == tool_label
assert created_tool_provider.icon == json.dumps(tool_icon)
assert created_tool_provider.description == tool_description
assert created_tool_provider.parameter_configuration == json.dumps(tool_parameters)
assert created_tool_provider.parameter_configuration == json.dumps([p.model_dump() for p in tool_parameters])
assert created_tool_provider.privacy_policy == tool_privacy_policy
assert created_tool_provider.version == workflow.version
assert created_tool_provider.user_id == account.id
@ -353,18 +359,9 @@ class TestWorkflowToolManageService:
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
# Setup invalid workflow tool parameters (missing required fields)
invalid_parameters = [
{
"name": "input_text",
# Missing description and form fields
"type": "string",
"required": True,
}
]
# Attempt to create workflow tool with invalid parameters
with pytest.raises(ValueError) as exc_info:
with pytest.raises(ValidationError) as exc_info:
# Setup invalid workflow tool parameters (missing required fields)
WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
@ -373,7 +370,16 @@ class TestWorkflowToolManageService:
label=fake.word(),
icon={"type": "emoji", "emoji": "🔧"},
description=fake.text(max_nb_chars=200),
parameters=invalid_parameters,
parameters=[
WorkflowToolParameterConfiguration.model_validate(
{
"name": "input_text",
# Missing description and form fields
"type": "string",
"required": True,
}
)
],
)
# Verify error message contains validation error
@ -579,11 +585,12 @@ class TestWorkflowToolManageService:
# Verify database state was updated
db.session.refresh(created_tool)
assert created_tool is not None
assert created_tool.name == updated_tool_name
assert created_tool.label == updated_tool_label
assert created_tool.icon == json.dumps(updated_tool_icon)
assert created_tool.description == updated_tool_description
assert created_tool.parameter_configuration == json.dumps(updated_tool_parameters)
assert created_tool.parameter_configuration == json.dumps([p.model_dump() for p in updated_tool_parameters])
assert created_tool.privacy_policy == updated_tool_privacy_policy
assert created_tool.version == workflow.version
assert created_tool.updated_at is not None
@ -750,13 +757,15 @@ class TestWorkflowToolManageService:
# Setup workflow tool parameters with FILE type
file_parameters = [
{
"name": "document",
"description": "Upload a document",
"form": "form",
"type": "file",
"required": False,
}
WorkflowToolParameterConfiguration.model_validate(
{
"name": "document",
"description": "Upload a document",
"form": "form",
"type": "file",
"required": False,
}
)
]
# Execute the method under test
@ -823,13 +832,15 @@ class TestWorkflowToolManageService:
# Setup workflow tool parameters with FILES type
files_parameters = [
{
"name": "documents",
"description": "Upload multiple documents",
"form": "form",
"type": "files",
"required": False,
}
WorkflowToolParameterConfiguration.model_validate(
{
"name": "documents",
"description": "Upload multiple documents",
"form": "form",
"type": "files",
"required": False,
}
)
]
# Execute the method under test

View File

@ -0,0 +1,46 @@
import builtins
from unittest.mock import patch
import pytest
from flask import Flask
from flask.views import MethodView
from extensions import ext_fastopenapi
if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
@pytest.fixture
def app() -> Flask:
app = Flask(__name__)
app.config["TESTING"] = True
app.secret_key = "test-secret-key"
return app
def test_console_init_get_returns_finished_when_no_init_password(app: Flask, monkeypatch: pytest.MonkeyPatch):
ext_fastopenapi.init_app(app)
monkeypatch.delenv("INIT_PASSWORD", raising=False)
with patch("controllers.console.init_validate.dify_config.EDITION", "SELF_HOSTED"):
client = app.test_client()
response = client.get("/console/api/init")
assert response.status_code == 200
assert response.get_json() == {"status": "finished"}
def test_console_init_post_returns_success(app: Flask, monkeypatch: pytest.MonkeyPatch):
ext_fastopenapi.init_app(app)
monkeypatch.setenv("INIT_PASSWORD", "test-init-password")
with (
patch("controllers.console.init_validate.dify_config.EDITION", "SELF_HOSTED"),
patch("controllers.console.init_validate.TenantService.get_tenant_count", return_value=0),
):
client = app.test_client()
response = client.post("/console/api/init", json={"password": "test-init-password"})
assert response.status_code == 201
assert response.get_json() == {"result": "success"}

View File

@ -0,0 +1,92 @@
import builtins
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import patch
import httpx
import pytest
from flask import Flask
from flask.views import MethodView
from extensions import ext_fastopenapi
if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
@pytest.fixture
def app() -> Flask:
app = Flask(__name__)
app.config["TESTING"] = True
return app
def test_console_remote_files_fastopenapi_get_info(app: Flask):
ext_fastopenapi.init_app(app)
response = httpx.Response(
200,
request=httpx.Request("HEAD", "http://example.com/file.txt"),
headers={"Content-Type": "text/plain", "Content-Length": "10"},
)
with patch("controllers.console.remote_files.ssrf_proxy.head", return_value=response):
client = app.test_client()
encoded_url = "http%3A%2F%2Fexample.com%2Ffile.txt"
resp = client.get(f"/console/api/remote-files/{encoded_url}")
assert resp.status_code == 200
assert resp.get_json() == {"file_type": "text/plain", "file_length": 10}
def test_console_remote_files_fastopenapi_upload(app: Flask):
ext_fastopenapi.init_app(app)
head_response = httpx.Response(
200,
request=httpx.Request("GET", "http://example.com/file.txt"),
content=b"hello",
)
file_info = SimpleNamespace(
extension="txt",
size=5,
filename="file.txt",
mimetype="text/plain",
)
uploaded = SimpleNamespace(
id="file-id",
name="file.txt",
size=5,
extension="txt",
mime_type="text/plain",
created_by="user-id",
created_at=datetime(2024, 1, 1),
)
with (
patch("controllers.console.remote_files.db", new=SimpleNamespace(engine=object())),
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_response),
patch("controllers.console.remote_files.helpers.guess_file_info_from_response", return_value=file_info),
patch("controllers.console.remote_files.FileService.is_file_size_within_limit", return_value=True),
patch("controllers.console.remote_files.FileService.__init__", return_value=None),
patch("controllers.console.remote_files.current_account_with_tenant", return_value=(object(), "tenant-id")),
patch("controllers.console.remote_files.FileService.upload_file", return_value=uploaded),
patch("controllers.console.remote_files.file_helpers.get_signed_file_url", return_value="signed-url"),
):
client = app.test_client()
resp = client.post(
"/console/api/remote-files/upload",
json={"url": "http://example.com/file.txt"},
)
assert resp.status_code == 201
assert resp.get_json() == {
"id": "file-id",
"name": "file.txt",
"size": 5,
"extension": "txt",
"url": "signed-url",
"mime_type": "text/plain",
"created_by": "user-id",
"created_at": int(uploaded.created_at.timestamp()),
}

Some files were not shown because too many files have changed in this diff.