Merge branch 'main' into feat/rag-pipeline

This commit is contained in:
zxhlyh 2025-04-29 16:27:49 +08:00
commit a46b4e3616
279 changed files with 4905 additions and 859 deletions

View File

@ -6,7 +6,7 @@
本指南和 Dify 一样在不断完善中。如果有任何滞后于项目实际情况的地方,恳请谅解,我们也欢迎任何改进建议。
关于许可证,请花一分钟阅读我们简短的[许可和贡献者协议](./LICENSE)。社区同时也遵循[行为准则](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)。
关于许可证,请花一分钟阅读我们简短的[许可和贡献者协议](./LICENSE)。同时也遵循社区[行为准则](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)。
## 开始之前

View File

@ -54,7 +54,7 @@
<a href="./README_BN.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a>
</p>
Dify is an open-source LLM app development platform. Its intuitive interface combines agentic AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production.
Dify is an open-source LLM app development platform. Its intuitive interface combines agentic AI workflow, RAG pipeline, agent capabilities, model management, observability features, and more, allowing you to quickly move from prototype to production.
## Quick start
@ -188,7 +188,7 @@ All of Dify's offerings come with corresponding APIs, so you could effortlessly
- **Dify for enterprise / organizations</br>**
We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) to discuss enterprise needs. </br>
> For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one-click. It's an affordable AMI offering with the option to create apps with custom logo and branding.
> For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one click. It's an affordable AMI offering with the option to create apps with custom logo and branding.
## Staying ahead
@ -233,7 +233,7 @@ Deploy Dify to AWS with [CDK](https://aws.amazon.com/cdk/)
For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
At the same time, please consider supporting Dify by sharing it on social media and at events and conferences.
> We are looking for contributors to help with translating Dify to languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c).
> We are looking for contributors to help translate Dify into languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c).
## Community & contact

View File

@ -297,6 +297,7 @@ LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:
LINDORM_USERNAME=admin
LINDORM_PASSWORD=admin
USING_UGC_INDEX=False
LINDORM_QUERY_TIMEOUT=1
# OceanBase Vector configuration
OCEANBASE_VECTOR_HOST=127.0.0.1

View File

@ -52,7 +52,6 @@ def initialize_extensions(app: DifyApp):
ext_mail,
ext_migrate,
ext_otel,
ext_otel_patch,
ext_proxy_fix,
ext_redis,
ext_repositories,
@ -85,7 +84,6 @@ def initialize_extensions(app: DifyApp):
ext_proxy_fix,
ext_blueprints,
ext_commands,
ext_otel_patch, # Apply patch before initializing OpenTelemetry
ext_otel,
]
for ext in extensions:

View File

@ -17,6 +17,7 @@ from core.rag.models.document import Document
from events.app_event import app_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs.helper import email as email_validate
from libs.password import hash_password, password_pattern, valid_password
from libs.rsa import generate_key_pair
@ -271,6 +272,7 @@ def migrate_knowledge_vector_database():
upper_collection_vector_types = {
VectorType.MILVUS,
VectorType.PGVECTOR,
VectorType.VASTBASE,
VectorType.RELYT,
VectorType.WEAVIATE,
VectorType.ORACLE,
@ -814,3 +816,331 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[
ClearFreePlanTenantExpiredLogs.process(days, batch, tenant_ids)
click.echo(click.style("Clear free plan tenant expired logs completed.", fg="green"))
@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.")
@click.command("clear-orphaned-file-records", help="Clear orphaned file records.")
def clear_orphaned_file_records(force: bool):
"""
Clear orphaned file records in the database.
"""
# define tables and columns to process
files_tables = [
{"table": "upload_files", "id_column": "id", "key_column": "key"},
{"table": "tool_files", "id_column": "id", "key_column": "file_key"},
]
ids_tables = [
{"type": "uuid", "table": "message_files", "column": "upload_file_id"},
{"type": "text", "table": "documents", "column": "data_source_info"},
{"type": "text", "table": "document_segments", "column": "content"},
{"type": "text", "table": "messages", "column": "answer"},
{"type": "text", "table": "workflow_node_executions", "column": "inputs"},
{"type": "text", "table": "workflow_node_executions", "column": "process_data"},
{"type": "text", "table": "workflow_node_executions", "column": "outputs"},
{"type": "text", "table": "conversations", "column": "introduction"},
{"type": "text", "table": "conversations", "column": "system_instruction"},
{"type": "json", "table": "messages", "column": "inputs"},
{"type": "json", "table": "messages", "column": "message"},
]
# notify user and ask for confirmation
click.echo(
click.style(
"This command will first find and delete orphaned file records from the message_files table,", fg="yellow"
)
)
click.echo(
click.style(
"and then it will find and delete orphaned file records in the following tables:",
fg="yellow",
)
)
for files_table in files_tables:
click.echo(click.style(f"- {files_table['table']}", fg="yellow"))
click.echo(
click.style("The following tables and columns will be scanned to find orphaned file records:", fg="yellow")
)
for ids_table in ids_tables:
click.echo(click.style(f"- {ids_table['table']} ({ids_table['column']})", fg="yellow"))
click.echo("")
click.echo(click.style("!!! USE WITH CAUTION !!!", fg="red"))
click.echo(
click.style(
(
"Since not all patterns have been fully tested, "
"please note that this command may delete unintended file records."
),
fg="yellow",
)
)
click.echo(
click.style("This cannot be undone. Please make sure to back up your database before proceeding.", fg="yellow")
)
click.echo(
click.style(
(
"It is also recommended to run this during the maintenance window, "
"as this may cause high load on your instance."
),
fg="yellow",
)
)
if not force:
click.confirm("Do you want to proceed?", abort=True)
# start the cleanup process
click.echo(click.style("Starting orphaned file records cleanup.", fg="white"))
# clean up the orphaned records in the message_files table where message_id doesn't exist in messages table
try:
click.echo(
click.style("- Listing message_files records where message_id doesn't exist in messages table", fg="white")
)
query = (
"SELECT mf.id, mf.message_id "
"FROM message_files mf LEFT JOIN messages m ON mf.message_id = m.id "
"WHERE m.id IS NULL"
)
orphaned_message_files = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(query))
for i in rs:
orphaned_message_files.append({"id": str(i[0]), "message_id": str(i[1])})
if orphaned_message_files:
click.echo(click.style(f"Found {len(orphaned_message_files)} orphaned message_files records:", fg="white"))
for record in orphaned_message_files:
click.echo(click.style(f" - id: {record['id']}, message_id: {record['message_id']}", fg="black"))
if not force:
click.confirm(
(
f"Do you want to proceed "
f"to delete all {len(orphaned_message_files)} orphaned message_files records?"
),
abort=True,
)
click.echo(click.style("- Deleting orphaned message_files records", fg="white"))
query = "DELETE FROM message_files WHERE id IN :ids"
with db.engine.begin() as conn:
conn.execute(db.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])})
click.echo(
click.style(f"Removed {len(orphaned_message_files)} orphaned message_files records.", fg="green")
)
else:
click.echo(click.style("No orphaned message_files records found. There is nothing to delete.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error deleting orphaned message_files records: {str(e)}", fg="red"))
# clean up the orphaned records in the rest of the *_files tables
try:
# fetch file id and keys from each table
all_files_in_tables = []
for files_table in files_tables:
click.echo(click.style(f"- Listing file records in table {files_table['table']}", fg="white"))
query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}"
with db.engine.begin() as conn:
rs = conn.execute(db.text(query))
for i in rs:
all_files_in_tables.append({"table": files_table["table"], "id": str(i[0]), "key": i[1]})
click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))
# fetch referred table and columns
guid_regexp = "[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
all_ids_in_tables = []
for ids_table in ids_tables:
query = ""
if ids_table["type"] == "uuid":
click.echo(
click.style(
f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}", fg="white"
)
)
query = (
f"SELECT {ids_table['column']} FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(db.text(query))
for i in rs:
all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])})
elif ids_table["type"] == "text":
click.echo(
click.style(
f"- Listing file-id-like strings in column {ids_table['column']} in table {ids_table['table']}",
fg="white",
)
)
query = (
f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id "
f"FROM {ids_table['table']}"
)
with db.engine.begin() as conn:
rs = conn.execute(db.text(query))
for i in rs:
for j in i[0]:
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
elif ids_table["type"] == "json":
click.echo(
click.style(
(
f"- Listing file-id-like JSON string in column {ids_table['column']} "
f"in table {ids_table['table']}"
),
fg="white",
)
)
query = (
f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id "
f"FROM {ids_table['table']}"
)
with db.engine.begin() as conn:
rs = conn.execute(db.text(query))
for i in rs:
for j in i[0]:
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
click.echo(click.style(f"Found {len(all_ids_in_tables)} file ids in tables.", fg="white"))
except Exception as e:
click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red"))
return
# find orphaned files
all_files = [file["id"] for file in all_files_in_tables]
all_ids = [file["id"] for file in all_ids_in_tables]
orphaned_files = list(set(all_files) - set(all_ids))
if not orphaned_files:
click.echo(click.style("No orphaned file records found. There is nothing to delete.", fg="green"))
return
click.echo(click.style(f"Found {len(orphaned_files)} orphaned file records.", fg="white"))
for file in orphaned_files:
click.echo(click.style(f"- orphaned file id: {file}", fg="black"))
if not force:
click.confirm(f"Do you want to proceed to delete all {len(orphaned_files)} orphaned file records?", abort=True)
# delete orphaned records for each file
try:
for files_table in files_tables:
click.echo(click.style(f"- Deleting orphaned file records in table {files_table['table']}", fg="white"))
query = f"DELETE FROM {files_table['table']} WHERE {files_table['id_column']} IN :ids"
with db.engine.begin() as conn:
conn.execute(db.text(query), {"ids": tuple(orphaned_files)})
except Exception as e:
click.echo(click.style(f"Error deleting orphaned file records: {str(e)}", fg="red"))
return
click.echo(click.style(f"Removed {len(orphaned_files)} orphaned file records.", fg="green"))
@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.")
@click.command("remove-orphaned-files-on-storage", help="Remove orphaned files on the storage.")
def remove_orphaned_files_on_storage(force: bool):
"""
Remove orphaned files on the storage.
"""
# define tables and columns to process
files_tables = [
{"table": "upload_files", "key_column": "key"},
{"table": "tool_files", "key_column": "file_key"},
]
storage_paths = ["image_files", "tools", "upload_files"]
# notify user and ask for confirmation
click.echo(click.style("This command will find and remove orphaned files on the storage,", fg="yellow"))
click.echo(
click.style("by comparing the files on the storage with the records in the following tables:", fg="yellow")
)
for files_table in files_tables:
click.echo(click.style(f"- {files_table['table']}", fg="yellow"))
click.echo(click.style("The following paths on the storage will be scanned to find orphaned files:", fg="yellow"))
for storage_path in storage_paths:
click.echo(click.style(f"- {storage_path}", fg="yellow"))
click.echo("")
click.echo(click.style("!!! USE WITH CAUTION !!!", fg="red"))
click.echo(
click.style(
"Currently, this command will work only for opendal based storage (STORAGE_TYPE=opendal).", fg="yellow"
)
)
click.echo(
click.style(
"Since not all patterns have been fully tested, please note that this command may delete unintended files.",
fg="yellow",
)
)
click.echo(
click.style("This cannot be undone. Please make sure to back up your storage before proceeding.", fg="yellow")
)
click.echo(
click.style(
(
"It is also recommended to run this during the maintenance window, "
"as this may cause high load on your instance."
),
fg="yellow",
)
)
if not force:
click.confirm("Do you want to proceed?", abort=True)
# start the cleanup process
click.echo(click.style("Starting orphaned files cleanup.", fg="white"))
# fetch file id and keys from each table
all_files_in_tables = []
try:
for files_table in files_tables:
click.echo(click.style(f"- Listing files from table {files_table['table']}", fg="white"))
query = f"SELECT {files_table['key_column']} FROM {files_table['table']}"
with db.engine.begin() as conn:
rs = conn.execute(db.text(query))
for i in rs:
all_files_in_tables.append(str(i[0]))
click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))
except Exception as e:
click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red"))
all_files_on_storage = []
for storage_path in storage_paths:
try:
click.echo(click.style(f"- Scanning files on storage path {storage_path}", fg="white"))
files = storage.scan(path=storage_path, files=True, directories=False)
all_files_on_storage.extend(files)
except FileNotFoundError as e:
click.echo(click.style(f" -> Skipping path {storage_path} as it does not exist.", fg="yellow"))
continue
except Exception as e:
click.echo(click.style(f" -> Error scanning files on storage path {storage_path}: {str(e)}", fg="red"))
continue
click.echo(click.style(f"Found {len(all_files_on_storage)} files on storage.", fg="white"))
# find orphaned files
orphaned_files = list(set(all_files_on_storage) - set(all_files_in_tables))
if not orphaned_files:
click.echo(click.style("No orphaned files found. There is nothing to remove.", fg="green"))
return
click.echo(click.style(f"Found {len(orphaned_files)} orphaned files.", fg="white"))
for file in orphaned_files:
click.echo(click.style(f"- orphaned file: {file}", fg="black"))
if not force:
click.confirm(f"Do you want to proceed to remove all {len(orphaned_files)} orphaned files?", abort=True)
# delete orphaned files
removed_files = 0
error_files = 0
for file in orphaned_files:
try:
storage.delete(file)
removed_files += 1
click.echo(click.style(f"- Removing orphaned file: {file}", fg="white"))
except Exception as e:
error_files += 1
click.echo(click.style(f"- Error deleting orphaned file {file}: {str(e)}", fg="red"))
continue
if error_files == 0:
click.echo(click.style(f"Removed {removed_files} orphaned files without errors.", fg="green"))
else:
click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow"))

View File

@ -39,6 +39,7 @@ from .vdb.tencent_vector_config import TencentVectorDBConfig
from .vdb.tidb_on_qdrant_config import TidbOnQdrantConfig
from .vdb.tidb_vector_config import TiDBVectorConfig
from .vdb.upstash_config import UpstashConfig
from .vdb.vastbase_vector_config import VastbaseVectorConfig
from .vdb.vikingdb_config import VikingDBConfig
from .vdb.weaviate_config import WeaviateConfig
@ -270,6 +271,7 @@ class MiddlewareConfig(
OpenSearchConfig,
OracleConfig,
PGVectorConfig,
VastbaseVectorConfig,
PGVectoRSConfig,
QdrantConfig,
RelytConfig,

View File

@ -32,3 +32,4 @@ class LindormConfig(BaseSettings):
description="Using UGC index will store the same type of Index in a single index but can retrieve separately.",
default=False,
)
LINDORM_QUERY_TIMEOUT: Optional[float] = Field(description="The lindorm search request timeout (s)", default=2.0)

View File

@ -1,4 +1,5 @@
from typing import Optional
import enum
from typing import Literal, Optional
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
@ -9,6 +10,14 @@ class OpenSearchConfig(BaseSettings):
Configuration settings for OpenSearch
"""
class AuthMethod(enum.StrEnum):
"""
Authentication method for OpenSearch
"""
BASIC = "basic"
AWS_MANAGED_IAM = "aws_managed_iam"
OPENSEARCH_HOST: Optional[str] = Field(
description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')",
default=None,
@ -19,6 +28,16 @@ class OpenSearchConfig(BaseSettings):
default=9200,
)
OPENSEARCH_SECURE: bool = Field(
description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)",
default=False,
)
OPENSEARCH_AUTH_METHOD: AuthMethod = Field(
description="Authentication method for OpenSearch connection (default is 'basic')",
default=AuthMethod.BASIC,
)
OPENSEARCH_USER: Optional[str] = Field(
description="Username for authenticating with OpenSearch",
default=None,
@ -29,7 +48,11 @@ class OpenSearchConfig(BaseSettings):
default=None,
)
OPENSEARCH_SECURE: bool = Field(
description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)",
default=False,
OPENSEARCH_AWS_REGION: Optional[str] = Field(
description="AWS region for OpenSearch (e.g. 'us-west-2')",
default=None,
)
OPENSEARCH_AWS_SERVICE: Optional[Literal["es", "aoss"]] = Field(
description="AWS service for OpenSearch (e.g. 'aoss' for OpenSearch Serverless)", default=None
)

View File

@ -0,0 +1,45 @@
from typing import Optional
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class VastbaseVectorConfig(BaseSettings):
"""
Configuration settings for Vector (Vastbase with vector extension)
"""
VASTBASE_HOST: Optional[str] = Field(
description="Hostname or IP address of the Vastbase server with Vector extension (e.g., 'localhost')",
default=None,
)
VASTBASE_PORT: PositiveInt = Field(
description="Port number on which the Vastbase server is listening (default is 5432)",
default=5432,
)
VASTBASE_USER: Optional[str] = Field(
description="Username for authenticating with the Vastbase database",
default=None,
)
VASTBASE_PASSWORD: Optional[str] = Field(
description="Password for authenticating with the Vastbase database",
default=None,
)
VASTBASE_DATABASE: Optional[str] = Field(
description="Name of the Vastbase database to connect to",
default=None,
)
VASTBASE_MIN_CONNECTION: PositiveInt = Field(
description="Min connection of the Vastbase database",
default=1,
)
VASTBASE_MAX_CONNECTION: PositiveInt = Field(
description="Max connection of the Vastbase database",
default=5,
)

View File

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="1.2.0",
default="1.3.1",
)
COMMIT_SHA: str = Field(

View File

@ -16,11 +16,25 @@ AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
if dify_config.ETL_TYPE == "Unstructured":
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls"]
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"]
DOCUMENT_EXTENSIONS.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
if dify_config.UNSTRUCTURED_API_URL:
DOCUMENT_EXTENSIONS.append("ppt")
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
else:
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
DOCUMENT_EXTENSIONS = [
"txt",
"markdown",
"md",
"mdx",
"pdf",
"html",
"htm",
"xlsx",
"xls",
"docx",
"csv",
"vtt",
"properties",
]
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])

View File

@ -186,7 +186,7 @@ class AnnotationUpdateDeleteApi(Resource):
app_id = str(app_id)
annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
return {"result": "success"}, 200
return {"result": "success"}, 204
class AnnotationBatchImportApi(Resource):

View File

@ -80,8 +80,6 @@ class ChatMessageTextApi(Resource):
@account_initialization_required
@get_app_model
def post(self, app_model: App):
from werkzeug.exceptions import InternalServerError
try:
parser = reqparse.RequestParser()
parser.add_argument("message_id", type=str, location="json")

View File

@ -84,7 +84,7 @@ class TraceAppConfigApi(Resource):
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
if not result:
raise TracingConfigNotExist()
return {"result": "success"}
return {"result": "success"}, 204
except Exception as e:
raise BadRequest(str(e))

View File

@ -65,7 +65,7 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
return {"result": "success"}, 200
return {"result": "success"}, 204
api.add_resource(ApiKeyAuthDataSource, "/api-key-auth/data-source")

View File

@ -657,6 +657,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.ELASTICSEARCH
| VectorType.ELASTICSEARCH_JA
| VectorType.PGVECTOR
| VectorType.VASTBASE
| VectorType.TIDB_ON_QDRANT
| VectorType.LINDORM
| VectorType.COUCHBASE
@ -706,6 +707,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.ELASTICSEARCH_JA
| VectorType.COUCHBASE
| VectorType.PGVECTOR
| VectorType.VASTBASE
| VectorType.LINDORM
| VectorType.OPENGAUSS
| VectorType.OCEANBASE

View File

@ -40,7 +40,7 @@ from core.indexing_runner import IndexingRunner
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.plugin.manager.exc import PluginDaemonClientSideError
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db
from extensions.ext_redis import redis_client

View File

@ -131,7 +131,7 @@ class DatasetDocumentSegmentListApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
SegmentService.delete_segments(segment_ids, document, dataset)
return {"result": "success"}, 200
return {"result": "success"}, 204
class DatasetDocumentSegmentApi(Resource):
@ -333,7 +333,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
SegmentService.delete_segment(segment, document, dataset)
return {"result": "success"}, 200
return {"result": "success"}, 204
class DatasetDocumentSegmentBatchImportApi(Resource):
@ -590,7 +590,7 @@ class ChildChunkUpdateApi(Resource):
SegmentService.delete_child_chunk(child_chunk, dataset)
except ChildChunkDeleteIndexServiceError as e:
raise ChildChunkDeleteIndexError(str(e))
return {"result": "success"}, 200
return {"result": "success"}, 204
@setup_required
@login_required

View File

@ -135,7 +135,7 @@ class ExternalApiTemplateApi(Resource):
raise Forbidden()
ExternalDatasetService.delete_external_knowledge_api(current_user.current_tenant_id, external_knowledge_api_id)
return {"result": "success"}, 200
return {"result": "success"}, 204
class ExternalApiUseCheckApi(Resource):

View File

@ -82,7 +82,7 @@ class DatasetMetadataApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user)
MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
return 200
return {"result": "success"}, 204
class DatasetMetadataBuiltInFieldApi(Resource):

View File

@ -113,7 +113,7 @@ class InstalledAppApi(InstalledAppResource):
db.session.delete(installed_app)
db.session.commit()
return {"result": "success", "message": "App uninstalled successfully"}
return {"result": "success", "message": "App uninstalled successfully"}, 204
def patch(self, installed_app):
parser = reqparse.RequestParser()

View File

@ -72,7 +72,7 @@ class SavedMessageApi(InstalledAppResource):
SavedMessageService.delete(app_model, current_user, message_id)
return {"result": "success"}
return {"result": "success"}, 204
api.add_resource(

View File

@ -99,7 +99,7 @@ class APIBasedExtensionDetailAPI(Resource):
APIBasedExtensionService.delete(extension_data_from_db)
return {"result": "success"}
return {"result": "success"}, 204
api.add_resource(CodeBasedExtensionAPI, "/code-based-extension")

View File

@ -86,7 +86,7 @@ class TagUpdateDeleteApi(Resource):
TagService.delete_tag(tag_id)
return 200
return 204
class TagBindingCreateApi(Resource):

View File

@ -5,7 +5,7 @@ from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.manager.exc import PluginPermissionDeniedError
from core.plugin.impl.exc import PluginPermissionDeniedError
from libs.login import login_required
from services.plugin.endpoint_service import EndpointService

View File

@ -10,7 +10,7 @@ from controllers.console import api
from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.manager.exc import PluginDaemonClientSideError
from core.plugin.impl.exc import PluginDaemonClientSideError
from libs.login import login_required
from models.account import TenantPluginPermission
from services.plugin.plugin_permission_service import PluginPermissionService

View File

@ -10,6 +10,7 @@ from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import AccountStatus
from models.dataset import RateLimitLog
from models.model import DifySetup
from services.feature_service import FeatureService, LicenseStatus
@ -24,7 +25,7 @@ def account_initialization_required(view):
# check account initialization
account = current_user
if account.status == "uninitialized":
if account.status == AccountStatus.UNINITIALIZED:
raise AccountNotInitializedError()
return view(*args, **kwargs)

View File

@ -75,7 +75,7 @@ class FilePreviewApi(Resource):
if args["as_attachment"]:
encoded_filename = quote(upload_file.name)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
response.headers["Content-Type"] = "application/octet-stream"
response.headers["Content-Type"] = "application/octet-stream"
return response

View File

@ -79,7 +79,7 @@ class AnnotationListApi(Resource):
class AnnotationUpdateDeleteApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(annotation_fields)
def post(self, app_model: App, end_user: EndUser, annotation_id):
def put(self, app_model: App, end_user: EndUser, annotation_id):
if not current_user.is_editor:
raise Forbidden()
@ -98,7 +98,7 @@ class AnnotationUpdateDeleteApi(Resource):
annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_model.id, annotation_id)
return {"result": "success"}, 200
return {"result": "success"}, 204
api.add_resource(AnnotationReplyActionApi, "/apps/annotation-reply/<string:action>")

View File

@ -14,6 +14,9 @@ from fields.conversation_fields import (
conversation_infinite_scroll_pagination_fields,
simple_conversation_fields,
)
from fields.conversation_variable_fields import (
conversation_variable_infinite_scroll_pagination_fields,
)
from libs.helper import uuid_value
from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService
@ -69,7 +72,7 @@ class ConversationDetailApi(Resource):
ConversationService.delete(app_model, conversation_id, end_user)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return {"result": "success"}, 200
return {"result": "success"}, 204
class ConversationRenameApi(Resource):
@ -93,6 +96,31 @@ class ConversationRenameApi(Resource):
raise NotFound("Conversation Not Exists.")
class ConversationVariablesApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@marshal_with(conversation_variable_infinite_scroll_pagination_fields)
def get(self, app_model: App, end_user: EndUser, c_id):
# conversational variable only for chat app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args()
try:
return ConversationService.get_conversational_variable(
app_model, conversation_id, end_user, args["limit"], args["last_id"]
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
api.add_resource(ConversationRenameApi, "/conversations/<uuid:c_id>/name", endpoint="conversation_name")
api.add_resource(ConversationApi, "/conversations")
api.add_resource(ConversationDetailApi, "/conversations/<uuid:c_id>", endpoint="conversation_detail")
api.add_resource(ConversationVariablesApi, "/conversations/<uuid:c_id>/variables", endpoint="conversation_variables")

View File

@ -59,7 +59,7 @@ class WorkflowRunDetailApi(Resource):
Get a workflow task running detail
"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
if app_mode not in [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]:
raise NotWorkflowAppError()
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()

View File

@ -323,7 +323,7 @@ class DocumentDeleteApi(DatasetApiResource):
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return {"result": "success"}, 200
return {"result": "success"}, 204
class DocumentListApi(DatasetApiResource):

View File

@ -63,7 +63,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
DatasetService.check_dataset_permission(dataset, current_user)
MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
return 200
return 204
class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):

View File

@ -159,7 +159,7 @@ class DatasetSegmentApi(DatasetApiResource):
if not segment:
raise NotFound("Segment not found.")
SegmentService.delete_segment(segment, document, dataset)
return {"result": "success"}, 200
return {"result": "success"}, 204
@cloud_edition_billing_resource_check("vector_space", "dataset")
def post(self, tenant_id, dataset_id, document_id, segment_id):
@ -344,7 +344,7 @@ class DatasetChildChunkApi(DatasetApiResource):
except ChildChunkDeleteIndexServiceError as e:
raise ChildChunkDeleteIndexError(str(e))
return {"result": "success"}, 200
return {"result": "success"}, 204
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")

View File

@ -67,7 +67,7 @@ class SavedMessageApi(WebApiResource):
SavedMessageService.delete(app_model, end_user, message_id)
return {"result": "success"}
return {"result": "success"}, 204
api.add_resource(SavedMessageListApi, "/saved-messages")

View File

@ -4,7 +4,7 @@ from typing import Any, Optional
from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
from core.agent.strategy.base import BaseAgentStrategy
from core.plugin.manager.agent import PluginAgentManager
from core.plugin.impl.agent import PluginAgentClient
from core.plugin.utils.converter import convert_parameters_to_plugin_format
@ -42,7 +42,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
"""
Invoke the agent strategy.
"""
manager = PluginAgentManager()
manager = PluginAgentClient()
initialized_params = self.initialize_parameters(params)
params = convert_parameters_to_plugin_format(initialized_params)

View File

@ -7,6 +7,7 @@ from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy.orm import sessionmaker
import contexts
from configs import dify_config
@ -24,6 +25,8 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from core.workflow.repository import RepositoryFactory
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
from factories import file_factory
from models.account import Account
@ -158,11 +161,22 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": application_generate_entity.app_config.tenant_id,
"app_id": application_generate_entity.app_config.app_id,
"session_factory": session_factory,
}
)
return self._generate(
workflow=workflow,
user=user,
invoke_from=invoke_from,
application_generate_entity=application_generate_entity,
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=conversation,
stream=streaming,
)
@ -215,11 +229,22 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": application_generate_entity.app_config.tenant_id,
"app_id": application_generate_entity.app_config.app_id,
"session_factory": session_factory,
}
)
return self._generate(
workflow=workflow,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=None,
stream=streaming,
)
@ -270,11 +295,22 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": application_generate_entity.app_config.tenant_id,
"app_id": application_generate_entity.app_config.app_id,
"session_factory": session_factory,
}
)
return self._generate(
workflow=workflow,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=None,
stream=streaming,
)
@ -286,6 +322,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
conversation: Optional[Conversation] = None,
stream: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
@ -296,6 +333,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param user: account or end user
:param invoke_from: invoke from source
:param application_generate_entity: application generate entity
:param workflow_node_execution_repository: repository for workflow node execution
:param conversation: conversation
:param stream: is stream
"""
@ -348,6 +386,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
user=user,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=stream,
)
@ -419,6 +458,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
stream: bool = False,
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
@ -430,6 +470,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param message: message
:param user: account or end user
:param stream: is stream
:param workflow_node_execution_repository: optional repository for workflow node execution
:return:
"""
# init generate task pipeline
@ -442,6 +483,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user=user,
stream=stream,
dialogue_count=self._dialogue_count,
workflow_node_execution_repository=workflow_node_execution_repository,
)
try:

View File

@ -65,6 +65,7 @@ from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from events.message_event import message_was_created
from extensions.ext_database import db
from models import Conversation, EndUser, Message, MessageFile
@ -93,6 +94,7 @@ class AdvancedChatAppGenerateTaskPipeline:
user: Union[Account, EndUser],
stream: bool,
dialogue_count: int,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
) -> None:
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
@ -123,6 +125,7 @@ class AdvancedChatAppGenerateTaskPipeline:
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
},
workflow_node_execution_repository=workflow_node_execution_repository,
)
self._task_state = WorkflowTaskState()
@ -684,7 +687,9 @@ class AdvancedChatAppGenerateTaskPipeline:
)
elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation
yield self._message_cycle_manager._message_replace_to_stream_response(answer=event.text)
yield self._message_cycle_manager._message_replace_to_stream_response(
answer=event.text, reason=event.reason
)
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
@ -695,7 +700,8 @@ class AdvancedChatAppGenerateTaskPipeline:
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
yield self._message_cycle_manager._message_replace_to_stream_response(
answer=output_moderation_answer
answer=output_moderation_answer,
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
)
# Save message

View File

@ -153,6 +153,8 @@ class MessageBasedAppGenerator(BaseAppGenerator):
query = application_generate_entity.query or "New conversation"
else:
query = next(iter(application_generate_entity.inputs.values()), "New conversation")
if isinstance(query, int):
query = str(query)
query = query or "New conversation"
conversation_name = (query[:20] + "") if len(query) > 20 else query

View File

@ -7,6 +7,7 @@ from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy.orm import sessionmaker
import contexts
from configs import dify_config
@ -22,6 +23,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.repository import RepositoryFactory
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
from factories import file_factory
from models import Account, App, EndUser, Workflow
@ -133,12 +136,23 @@ class WorkflowAppGenerator(BaseAppGenerator):
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": application_generate_entity.app_config.tenant_id,
"app_id": application_generate_entity.app_config.app_id,
"session_factory": session_factory,
}
)
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)
@ -151,6 +165,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
user: Union[Account, EndUser],
application_generate_entity: WorkflowAppGenerateEntity,
invoke_from: InvokeFrom,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
@ -162,6 +177,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param user: account or end user
:param application_generate_entity: application generate entity
:param invoke_from: invoke from source
:param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream
:param workflow_thread_pool_id: workflow thread pool id
"""
@ -193,6 +209,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow=workflow,
queue_manager=queue_manager,
user=user,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=streaming,
)
@ -245,12 +262,23 @@ class WorkflowAppGenerator(BaseAppGenerator):
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": application_generate_entity.app_config.tenant_id,
"app_id": application_generate_entity.app_config.app_id,
"session_factory": session_factory,
}
)
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
)
@ -299,12 +327,23 @@ class WorkflowAppGenerator(BaseAppGenerator):
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": application_generate_entity.app_config.tenant_id,
"app_id": application_generate_entity.app_config.app_id,
"session_factory": session_factory,
}
)
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
)
@ -361,6 +400,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
@ -370,6 +410,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param queue_manager: queue manager
:param user: account or end user
:param stream: is stream
:param workflow_node_execution_repository: optional repository for workflow node execution
:return:
"""
# init generate task pipeline
@ -379,6 +420,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
queue_manager=queue_manager,
user=user,
stream=stream,
workflow_node_execution_repository=workflow_node_execution_repository,
)
try:

View File

@ -55,6 +55,7 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.enums import SystemVariableKey
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
from models.account import Account
from models.enums import CreatedByRole
@ -82,6 +83,7 @@ class WorkflowAppGenerateTaskPipeline:
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
) -> None:
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
@ -109,6 +111,7 @@ class WorkflowAppGenerateTaskPipeline:
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
},
workflow_node_execution_repository=workflow_node_execution_repository,
)
self._application_generate_entity = application_generate_entity

View File

@ -264,8 +264,16 @@ class QueueMessageReplaceEvent(AppQueueEvent):
QueueMessageReplaceEvent entity
"""
class MessageReplaceReason(StrEnum):
"""
Reason for message replace event
"""
OUTPUT_MODERATION = "output_moderation"
event: QueueEvent = QueueEvent.MESSAGE_REPLACE
text: str
reason: str
class QueueRetrieverResourcesEvent(AppQueueEvent):

View File

@ -148,6 +148,7 @@ class MessageReplaceStreamResponse(StreamResponse):
event: StreamEvent = StreamEvent.MESSAGE_REPLACE
answer: str
reason: str
class AgentThoughtStreamResponse(StreamResponse):

View File

@ -126,12 +126,12 @@ class BasedGenerateTaskPipeline:
if self._output_moderation_handler:
self._output_moderation_handler.stop_thread()
completion = self._output_moderation_handler.moderation_completion(
completion, flagged = self._output_moderation_handler.moderation_completion(
completion=completion, public_event=False
)
self._output_moderation_handler = None
return completion
if flagged:
return completion
return None

View File

@ -182,10 +182,12 @@ class MessageCycleManage:
from_variable_selector=from_variable_selector,
)
def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse:
def _message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
"""
Message replace to stream response.
:param answer: answer
:return:
"""
return MessageReplaceStreamResponse(task_id=self._application_generate_entity.task_id, answer=answer)
return MessageReplaceStreamResponse(
task_id=self._application_generate_entity.task_id, answer=answer, reason=reason
)

View File

@ -6,7 +6,7 @@ from typing import Any, Optional, Union, cast
from uuid import uuid4
from sqlalchemy import func, select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
@ -49,14 +49,13 @@ from core.file import FILE_MODEL_IDENTITY, File
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.repository import RepositoryFactory
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.account import Account
from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
from models.model import EndUser
@ -76,26 +75,13 @@ class WorkflowCycleManage:
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
workflow_system_variables: dict[SystemVariableKey, Any],
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
) -> None:
self._workflow_run: WorkflowRun | None = None
self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
self._application_generate_entity = application_generate_entity
self._workflow_system_variables = workflow_system_variables
# Initialize the session factory and repository
# We use the global db engine instead of the session passed to methods
# Disable expire_on_commit to avoid the need for merging objects
self._session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": self._application_generate_entity.app_config.tenant_id,
"app_id": self._application_generate_entity.app_config.app_id,
"session_factory": self._session_factory,
}
)
# We'll still keep the cache for backward compatibility and performance
# but use the repository for database operations
self._workflow_node_execution_repository = workflow_node_execution_repository
def _handle_workflow_run_start(
self,
@ -395,6 +381,8 @@ class WorkflowCycleManage:
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.execution_metadata = execution_metadata
self._workflow_node_execution_repository.update(workflow_node_execution)
return workflow_node_execution
def _handle_workflow_node_execution_retried(

View File

@ -798,7 +798,25 @@ class ProviderConfiguration(BaseModel):
provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]
# resort provider_models
return sorted(provider_models, key=lambda x: x.model_type.value)
# Optimize sorting logic: first sort by provider.position order, then by model_type.value
# Get the position list for model types (retrieve only once for better performance)
model_type_positions = {}
if hasattr(self.provider, "position") and self.provider.position:
model_type_positions = self.provider.position
def get_sort_key(model: ModelWithProviderEntity):
# Get the position list for the current model type
positions = model_type_positions.get(model.model_type.value, [])
# If the model name is in the position list, use its index for sorting
# Otherwise use a large value (list length) to place undefined models at the end
position_index = positions.index(model.model) if model.model in positions else len(positions)
# Return composite sort key: (model_type value, model position index)
return (model.model_type.value, position_index)
# Sort using the composite sort key
return sorted(provider_models, key=get_sort_key)
def _get_system_provider_models(
self,

View File

@ -3,6 +3,8 @@ import logging
import re
from typing import Optional, cast
import json_repair
from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.llm_generator.prompts import (
@ -366,7 +368,20 @@ class LLMGenerator:
),
)
generated_json_schema = cast(str, response.message.content)
raw_content = response.message.content
if not isinstance(raw_content, str):
raise ValueError(f"LLM response content must be a string, got: {type(raw_content)}")
try:
parsed_content = json.loads(raw_content)
except json.JSONDecodeError:
parsed_content = json_repair.loads(raw_content)
if not isinstance(parsed_content, dict | list):
raise ValueError(f"Failed to parse structured output from llm: {raw_content}")
generated_json_schema = json.dumps(parsed_content, indent=2, ensure_ascii=False)
return {"output": generated_json_schema, "error": ""}
except InvokeError as e:

View File

@ -1,7 +1,7 @@
# Written by YORKI MINAKO🤡, Edited by Xiaoyi
CONVERSATION_TITLE_PROMPT = """You need to decompose the user's input into "subject" and "intention" in order to accurately figure out what the user's input language actually is.
Notice: the language type user use could be diverse, which can be English, Chinese, Italian, Español, Arabic, Japanese, French, and etc.
MAKE SURE your output is the SAME language as the user's input!
Notice: the language type user uses could be diverse, which can be English, Chinese, Italian, Español, Arabic, Japanese, French, and etc.
ENSURE your output is in the SAME language as the user's input!
Your output is restricted only to: (Input language) Intention + Subject(short as possible)
Your output MUST be a valid JSON.
@ -19,7 +19,7 @@ User Input: hi, yesterday i had some burgers.
example 2:
User Input: hello
{
"Language Type": "The user's input is written in pure English",
"Language Type": "The user's input is pure English",
"Your Reasoning": "The language of my output must be pure English.",
"Your Output": "Greeting myself☺"
}
@ -46,7 +46,7 @@ example 5:
User Input: why小红的年龄is老than小明
{
"Language Type": "The user's input is English-Chinese mixed",
"Your Reasoning": "The English parts are subjective particles, the main intention is written in Chinese, besides, Chinese occupies a greater \"actual meaning\" than English, so the language of my output must be using Chinese.",
"Your Reasoning": "The English parts are filler words, the main intention is written in Chinese, besides, Chinese occupies a greater \"actual meaning\" than English, so the language of my output must be using Chinese.",
"Your Output": "询问小红和小明的年龄"
}
@ -114,6 +114,13 @@ JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE = (
"4. The returned object should contain at least one key-value pair.\n\n"
"5. The returned object should always be in the format: {result: ...}\n\n"
"Example:\n"
"/**\n"
" * Multiplies two numbers together.\n"
" *\n"
" * @param {number} arg1 - The first number to multiply.\n"
" * @param {number} arg2 - The second number to multiply.\n"
" * @returns {{ result: number }} The result of the multiplication.\n"
" */\n"
"function main(arg1, arg2) {\n"
" return {\n"
" result: arg1 * arg2\n"
@ -130,7 +137,7 @@ JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE = (
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
"Please help me predict the three most likely questions that human would ask, "
"and keeping each question under 20 characters.\n"
"and keep each question under 20 characters.\n"
"MAKE SURE your output is the SAME language as the Assistant's latest response. "
"The output must be an array in JSON format following the specified schema:\n"
'["question1","question2","question3"]\n'
@ -157,9 +164,9 @@ Here is a task description for which I would like you to create a high-quality p
</task_description>
Based on task description, please create a well-structured prompt template that another AI could use to consistently complete the task. The prompt template should include:
- Do not include <input> or <output> section and variables in the prompt, assume user will add them at their own will.
- Clear instructions for the AI that will be using this prompt, demarcated with <instructions> tags. The instructions should provide step-by-step directions on how to complete the task using the input variables. Also Specifies in the instructions that the output should not contain any xml tag.
- Clear instructions for the AI that will be using this prompt, demarcated with <instruction> tags. The instructions should provide step-by-step directions on how to complete the task using the input variables. Also Specifies in the instructions that the output should not contain any xml tag.
- Relevant examples if needed to clarify the task further, demarcated with <example> tags. Do not include variables in the prompt. Give three pairs of input and output examples.
- Include other relevant sections demarcated with appropriate XML tags like <examples>, <instructions>.
- Include other relevant sections demarcated with appropriate XML tags like <examples>, <instruction>.
- Use the same language as task description.
- Output in ``` xml ``` and start with <instruction>
Please generate the full prompt template with at least 300 words and output only the prompt template.
@ -172,7 +179,7 @@ Here is a task description for which I would like you to create a high-quality p
</task_description>
Based on task description, please create a well-structured prompt template that another AI could use to consistently complete the task. The prompt template should include:
- Descriptive variable names surrounded by {{ }} (two curly brackets) to indicate where the actual values will be substituted in. Choose variable names that clearly indicate the type of value expected. Variable names have to be composed of number, english alphabets and underline and nothing else.
- Clear instructions for the AI that will be using this prompt, demarcated with <instructions> tags. The instructions should provide step-by-step directions on how to complete the task using the input variables. Also Specifies in the instructions that the output should not contain any xml tag.
- Clear instructions for the AI that will be using this prompt, demarcated with <instruction> tags. The instructions should provide step-by-step directions on how to complete the task using the input variables. Also Specifies in the instructions that the output should not contain any xml tag.
- Relevant examples if needed to clarify the task further, demarcated with <example> tags. Do not use curly brackets any other than in <instruction> section.
- Any other relevant sections demarcated with appropriate XML tags like <input>, <output>, etc.
- Use the same language as task description.
@ -291,32 +298,30 @@ Your task is to convert simple user descriptions into properly formatted JSON Sc
{
"type": "object",
"properties": {
"properties": {
"songs": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"id": {
"type": "string"
},
"duration": {
"type": "string"
},
"aritst": {
"type": "string"
}
"songs": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"required": [
"name",
"id",
"duration",
"aritst"
]
}
"id": {
"type": "string"
},
"duration": {
"type": "string"
},
"aritst": {
"type": "string"
}
},
"required": [
"name",
"id",
"duration",
"aritst"
]
}
}
},

View File

@ -134,6 +134,9 @@ class ProviderEntity(BaseModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
# position from plugin _position.yaml
position: Optional[dict[str, list[str]]] = {}
@field_validator("models", mode="before")
@classmethod
def validate_models(cls, v):

View File

@ -26,7 +26,7 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity
from core.plugin.manager.model import PluginModelManager
from core.plugin.impl.model import PluginModelClient
class AIModel(BaseModel):
@ -141,7 +141,7 @@ class AIModel(BaseModel):
:param credentials: model credentials
:return: model schema
"""
plugin_model_manager = PluginModelManager()
plugin_model_manager = PluginModelClient()
cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
# sort credentials
sorted_credentials = sorted(credentials.items()) if credentials else []

View File

@ -2,7 +2,7 @@ import logging
import time
import uuid
from collections.abc import Generator, Sequence
from typing import Optional, Union
from typing import Optional, Union, cast
from pydantic import ConfigDict
@ -20,7 +20,8 @@ from core.model_runtime.entities.model_entities import (
PriceType,
)
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.plugin.manager.model import PluginModelManager
from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str
from core.plugin.impl.model import PluginModelClient
logger = logging.getLogger(__name__)
@ -140,7 +141,7 @@ class LargeLanguageModel(AIModel):
result: Union[LLMResult, Generator[LLMResultChunk, None, None]]
try:
plugin_model_manager = PluginModelManager()
plugin_model_manager = PluginModelClient()
result = plugin_model_manager.invoke_llm(
tenant_id=self.tenant_id,
user_id=user or "unknown",
@ -280,7 +281,9 @@ class LargeLanguageModel(AIModel):
callbacks=callbacks,
)
assistant_message.content += chunk.delta.message.content
text = convert_llm_result_chunk_to_str(chunk.delta.message.content)
current_content = cast(str, assistant_message.content)
assistant_message.content = current_content + text
real_model = chunk.model
if chunk.delta.usage:
usage = chunk.delta.usage
@ -326,7 +329,7 @@ class LargeLanguageModel(AIModel):
:return:
"""
if dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
plugin_model_manager = PluginModelManager()
plugin_model_manager = PluginModelClient()
return plugin_model_manager.get_llm_num_tokens(
tenant_id=self.tenant_id,
user_id="unknown",

View File

@ -5,7 +5,7 @@ from pydantic import ConfigDict
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.plugin.manager.model import PluginModelManager
from core.plugin.impl.model import PluginModelClient
class ModerationModel(AIModel):
@ -31,7 +31,7 @@ class ModerationModel(AIModel):
self.started_at = time.perf_counter()
try:
plugin_model_manager = PluginModelManager()
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_moderation(
tenant_id=self.tenant_id,
user_id=user or "unknown",

View File

@ -3,7 +3,7 @@ from typing import Optional
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.plugin.manager.model import PluginModelManager
from core.plugin.impl.model import PluginModelClient
class RerankModel(AIModel):
@ -36,7 +36,7 @@ class RerankModel(AIModel):
:return: rerank result
"""
try:
plugin_model_manager = PluginModelManager()
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_rerank(
tenant_id=self.tenant_id,
user_id=user or "unknown",

View File

@ -4,7 +4,7 @@ from pydantic import ConfigDict
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.plugin.manager.model import PluginModelManager
from core.plugin.impl.model import PluginModelClient
class Speech2TextModel(AIModel):
@ -28,7 +28,7 @@ class Speech2TextModel(AIModel):
:return: text for given audio file
"""
try:
plugin_model_manager = PluginModelManager()
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_speech_to_text(
tenant_id=self.tenant_id,
user_id=user or "unknown",

View File

@ -6,7 +6,7 @@ from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.plugin.manager.model import PluginModelManager
from core.plugin.impl.model import PluginModelClient
class TextEmbeddingModel(AIModel):
@ -38,7 +38,7 @@ class TextEmbeddingModel(AIModel):
:return: embeddings result
"""
try:
plugin_model_manager = PluginModelManager()
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_text_embedding(
tenant_id=self.tenant_id,
user_id=user or "unknown",
@ -61,7 +61,7 @@ class TextEmbeddingModel(AIModel):
:param texts: texts to embed
:return:
"""
plugin_model_manager = PluginModelManager()
plugin_model_manager = PluginModelClient()
return plugin_model_manager.get_text_embedding_num_tokens(
tenant_id=self.tenant_id,
user_id="unknown",

View File

@ -6,7 +6,7 @@ from pydantic import ConfigDict
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.plugin.manager.model import PluginModelManager
from core.plugin.impl.model import PluginModelClient
logger = logging.getLogger(__name__)
@ -42,7 +42,7 @@ class TTSModel(AIModel):
:return: translated audio file
"""
try:
plugin_model_manager = PluginModelManager()
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_tts(
tenant_id=self.tenant_id,
user_id=user or "unknown",
@ -65,7 +65,7 @@ class TTSModel(AIModel):
:param credentials: The credentials required to access the TTS model.
:return: A list of voices supported by the TTS model.
"""
plugin_model_manager = PluginModelManager()
plugin_model_manager = PluginModelClient()
return plugin_model_manager.get_tts_model_voices(
tenant_id=self.tenant_id,
user_id="unknown",

View File

@ -22,8 +22,8 @@ from core.model_runtime.schema_validators.model_credential_schema_validator impo
from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
from core.plugin.entities.plugin import ModelProviderID
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.plugin.manager.asset import PluginAssetManager
from core.plugin.manager.model import PluginModelManager
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.model import PluginModelClient
logger = logging.getLogger(__name__)
@ -40,7 +40,7 @@ class ModelProviderFactory:
self.provider_position_map = {}
self.tenant_id = tenant_id
self.plugin_model_manager = PluginModelManager()
self.plugin_model_manager = PluginModelClient()
if not self.provider_position_map:
# get the path of current classes

View File

@ -1,6 +1,8 @@
import pydantic
from pydantic import BaseModel
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
def dump_model(model: BaseModel) -> dict:
if hasattr(pydantic, "model_dump"):
@ -8,3 +10,18 @@ def dump_model(model: BaseModel) -> dict:
return pydantic.model_dump(model) # type: ignore
else:
return model.model_dump()
def convert_llm_result_chunk_to_str(content: None | str | list[PromptMessageContentUnionTypes]) -> str:
if content is None:
message_text = ""
elif isinstance(content, str):
message_text = content
elif isinstance(content, list):
# Assuming the list contains PromptMessageContent objects with a "data" attribute
message_text = "".join(
item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content
)
else:
message_text = str(content)
return message_text

View File

@ -46,14 +46,14 @@ class OutputModeration(BaseModel):
if not self.thread:
self.thread = self.start_thread()
def moderation_completion(self, completion: str, public_event: bool = False) -> str:
def moderation_completion(self, completion: str, public_event: bool = False) -> tuple[str, bool]:
self.buffer = completion
self.is_final_chunk = True
result = self.moderation(tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=completion)
if not result or not result.flagged:
return completion
return completion, False
if result.action == ModerationAction.DIRECT_OUTPUT:
final_output = result.preset_response
@ -61,9 +61,14 @@ class OutputModeration(BaseModel):
final_output = result.text
if public_event:
self.queue_manager.publish(QueueMessageReplaceEvent(text=final_output), PublishFrom.TASK_PIPELINE)
self.queue_manager.publish(
QueueMessageReplaceEvent(
text=final_output, reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION
),
PublishFrom.TASK_PIPELINE,
)
return final_output
return final_output, True
def start_thread(self) -> threading.Thread:
buffer_size = dify_config.MODERATION_BUFFER_SIZE
@ -112,7 +117,12 @@ class OutputModeration(BaseModel):
# trigger replace event
if self.thread_running:
self.queue_manager.publish(QueueMessageReplaceEvent(text=final_output), PublishFrom.TASK_PIPELINE)
self.queue_manager.publish(
QueueMessageReplaceEvent(
text=final_output, reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION
),
PublishFrom.TASK_PIPELINE,
)
if result.action == ModerationAction.DIRECT_OUTPUT:
break

View File

@ -7,6 +7,7 @@ class TracingProviderEnum(Enum):
LANGFUSE = "langfuse"
LANGSMITH = "langsmith"
OPIK = "opik"
WEAVE = "weave"
class BaseTracingConfig(BaseModel):
@ -88,5 +89,26 @@ class OpikConfig(BaseTracingConfig):
return v
class WeaveConfig(BaseTracingConfig):
"""
Model class for Weave tracing config.
"""
api_key: str
entity: str | None = None
project: str
endpoint: str = "https://trace.wandb.ai"
@field_validator("endpoint")
@classmethod
def set_value(cls, v, info: ValidationInfo):
if v is None or v == "":
v = "https://trace.wandb.ai"
if not v.startswith("https://"):
raise ValueError("endpoint must start with https://")
return v
OPS_FILE_PATH = "ops_trace/"
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"

View File

@ -29,7 +29,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
UnitEnum,
)
from core.ops.utils import filter_none_values
from core.repository.repository_factory import RepositoryFactory
from core.workflow.repository.repository_factory import RepositoryFactory
from extensions.ext_database import db
from models.model import EndUser

View File

@ -28,7 +28,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
LangSmithRunUpdateModel,
)
from core.ops.utils import filter_none_values, generate_dotted_order
from core.repository.repository_factory import RepositoryFactory
from core.workflow.repository.repository_factory import RepositoryFactory
from extensions.ext_database import db
from models.model import EndUser, MessageFile

View File

@ -22,7 +22,7 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.repository.repository_factory import RepositoryFactory
from core.workflow.repository.repository_factory import RepositoryFactory
from extensions.ext_database import db
from models.model import EndUser, MessageFile

View File

@ -20,6 +20,7 @@ from core.ops.entities.config_entity import (
LangSmithConfig,
OpikConfig,
TracingProviderEnum,
WeaveConfig,
)
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
@ -34,7 +35,9 @@ from core.ops.entities.trace_entity import (
)
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
from core.ops.opik_trace.opik_trace import OpikDataTrace
from core.ops.utils import get_message_data
from core.ops.weave_trace.weave_trace import WeaveDataTrace
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
@ -43,8 +46,6 @@ from tasks.ops_trace_task import process_trace_tasks
def build_opik_trace_instance(config: OpikConfig):
from core.ops.opik_trace.opik_trace import OpikDataTrace
return OpikDataTrace(config)
@ -67,6 +68,12 @@ provider_config_map: dict[str, dict[str, Any]] = {
"other_keys": ["project", "url", "workspace"],
"trace_instance": lambda config: build_opik_trace_instance(config),
},
TracingProviderEnum.WEAVE.value: {
"config_class": WeaveConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "entity", "endpoint"],
"trace_instance": WeaveDataTrace,
},
}

View File

View File

@ -0,0 +1,97 @@
from typing import Any, Optional, Union
from pydantic import BaseModel, Field, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.ops.utils import replace_text_with_content
class WeaveTokenUsage(BaseModel):
input_tokens: Optional[int] = None
output_tokens: Optional[int] = None
total_tokens: Optional[int] = None
class WeaveMultiModel(BaseModel):
file_list: Optional[list[str]] = Field(None, description="List of files")
class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
id: str = Field(..., description="ID of the trace")
op: str = Field(..., description="Name of the operation")
inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the trace")
outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the trace")
attributes: Optional[Union[str, dict[str, Any], list, None]] = Field(
None, description="Metadata and attributes associated with trace"
)
exception: Optional[str] = Field(None, description="Exception message of the trace")
@field_validator("inputs", "outputs")
@classmethod
def ensure_dict(cls, v, info: ValidationInfo):
field_name = info.field_name
values = info.data
if v == {} or v is None:
return v
usage_metadata = {
"input_tokens": values.get("input_tokens", 0),
"output_tokens": values.get("output_tokens", 0),
"total_tokens": values.get("total_tokens", 0),
}
file_list = values.get("file_list", [])
if isinstance(v, str):
if field_name == "inputs":
return {
"messages": {
"role": "user",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
elif field_name == "outputs":
return {
"choices": {
"role": "ai",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
elif isinstance(v, list):
data = {}
if len(v) > 0 and isinstance(v[0], dict):
# rename text to content
v = replace_text_with_content(data=v)
if field_name == "inputs":
data = {
"messages": [
dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) for msg in v
]
if isinstance(v, list)
else v,
}
elif field_name == "outputs":
data = {
"choices": {
"role": "ai",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
return data
else:
return {
"choices": {
"role": "ai" if field_name == "outputs" else "user",
"content": str(v),
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
if isinstance(v, dict):
v["usage_metadata"] = usage_metadata
v["file_list"] = file_list
return v
return v

View File

@ -0,0 +1,420 @@
import json
import logging
import os
import uuid
from datetime import datetime, timedelta
from typing import Any, Optional, cast
import wandb
import weave
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import WeaveConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
from extensions.ext_database import db
from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecution
logger = logging.getLogger(__name__)
class WeaveDataTrace(BaseTraceInstance):
def __init__(
self,
weave_config: WeaveConfig,
):
super().__init__(weave_config)
self.weave_api_key = weave_config.api_key
self.project_name = weave_config.project
self.entity = weave_config.entity
# Login with API key first
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
if not login_status:
logger.error("Failed to login to Weights & Biases with the provided API key")
raise ValueError("Weave login failed")
# Then initialize weave client
self.weave_client = weave.init(
project_name=(f"{self.entity}/{self.project_name}" if self.entity else self.project_name)
)
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
self.calls: dict[str, Any] = {}
def get_project_url(
self,
):
try:
project_url = f"https://wandb.ai/{self.weave_client._project_id()}"
return project_url
except Exception as e:
logger.debug(f"Weave get run url failed: {str(e)}")
raise ValueError(f"Weave get run url failed: {str(e)}")
def trace(self, trace_info: BaseTraceInfo):
logger.debug(f"Trace info: {trace_info}")
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
if isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info)
if isinstance(trace_info, ModerationTraceInfo):
self.moderation_trace(trace_info)
if isinstance(trace_info, SuggestedQuestionTraceInfo):
self.suggested_question_trace(trace_info)
if isinstance(trace_info, DatasetRetrievalTraceInfo):
self.dataset_retrieval_trace(trace_info)
if isinstance(trace_info, ToolTraceInfo):
self.tool_trace(trace_info)
if isinstance(trace_info, GenerateNameTraceInfo):
self.generate_name_trace(trace_info)
def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = trace_info.message_id or trace_info.workflow_run_id
if trace_info.start_time is None:
trace_info.start_time = datetime.now()
if trace_info.message_id:
message_attributes = trace_info.metadata
message_attributes["workflow_app_log_id"] = trace_info.workflow_app_log_id
message_attributes["message_id"] = trace_info.message_id
message_attributes["workflow_run_id"] = trace_info.workflow_run_id
message_attributes["trace_id"] = trace_id
message_attributes["start_time"] = trace_info.start_time
message_attributes["end_time"] = trace_info.end_time
message_attributes["tags"] = ["message", "workflow"]
message_run = WeaveTraceModel(
id=trace_info.message_id,
op=str(TraceTaskName.MESSAGE_TRACE.value),
inputs=dict(trace_info.workflow_run_inputs),
outputs=dict(trace_info.workflow_run_outputs),
total_tokens=trace_info.total_tokens,
attributes=message_attributes,
exception=trace_info.error,
file_list=[],
)
self.start_call(message_run, parent_run_id=trace_info.workflow_run_id)
self.finish_call(message_run)
workflow_attributes = trace_info.metadata
workflow_attributes["workflow_run_id"] = trace_info.workflow_run_id
workflow_attributes["trace_id"] = trace_id
workflow_attributes["start_time"] = trace_info.start_time
workflow_attributes["end_time"] = trace_info.end_time
workflow_attributes["tags"] = ["workflow"]
workflow_run = WeaveTraceModel(
file_list=trace_info.file_list,
total_tokens=trace_info.total_tokens,
id=trace_info.workflow_run_id,
op=str(TraceTaskName.WORKFLOW_TRACE.value),
inputs=dict(trace_info.workflow_run_inputs),
outputs=dict(trace_info.workflow_run_outputs),
attributes=workflow_attributes,
exception=trace_info.error,
)
self.start_call(workflow_run, parent_run_id=trace_info.message_id)
# through workflow_run_id get all_nodes_execution
workflow_nodes_execution_id_records = (
db.session.query(WorkflowNodeExecution.id)
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
.all()
)
for node_execution_id_record in workflow_nodes_execution_id_records:
node_execution = (
db.session.query(
WorkflowNodeExecution.id,
WorkflowNodeExecution.tenant_id,
WorkflowNodeExecution.app_id,
WorkflowNodeExecution.title,
WorkflowNodeExecution.node_type,
WorkflowNodeExecution.status,
WorkflowNodeExecution.inputs,
WorkflowNodeExecution.outputs,
WorkflowNodeExecution.created_at,
WorkflowNodeExecution.elapsed_time,
WorkflowNodeExecution.process_data,
WorkflowNodeExecution.execution_metadata,
)
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
.first()
)
if not node_execution:
continue
node_execution_id = node_execution.id
tenant_id = node_execution.tenant_id
app_id = node_execution.app_id
node_name = node_execution.title
node_type = node_execution.node_type
status = node_execution.status
if node_type == "llm":
inputs = (
json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
)
else:
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = (
json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
)
node_total_tokens = execution_metadata.get("total_tokens", 0)
attributes = execution_metadata.copy()
attributes.update(
{
"workflow_run_id": trace_info.workflow_run_id,
"node_execution_id": node_execution_id,
"tenant_id": tenant_id,
"app_id": app_id,
"app_name": node_name,
"node_type": node_type,
"status": status,
}
)
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
if process_data and process_data.get("model_mode") == "chat":
attributes.update(
{
"ls_provider": process_data.get("model_provider", ""),
"ls_model_name": process_data.get("model_name", ""),
}
)
attributes["tags"] = ["node_execution"]
attributes["start_time"] = created_at
attributes["end_time"] = finished_at
attributes["elapsed_time"] = elapsed_time
attributes["workflow_run_id"] = trace_info.workflow_run_id
attributes["trace_id"] = trace_id
node_run = WeaveTraceModel(
total_tokens=node_total_tokens,
op=node_type,
inputs=inputs,
outputs=outputs,
file_list=trace_info.file_list,
attributes=attributes,
id=node_execution_id,
exception=None,
)
self.start_call(node_run, parent_run_id=trace_info.workflow_run_id)
self.finish_call(node_run)
self.finish_call(workflow_run)
def message_trace(self, trace_info: MessageTraceInfo):
# get message file data
file_list = cast(list[str], trace_info.file_list) or []
message_file_data: Optional[MessageFile] = trace_info.message_file_data
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
file_list.append(file_url)
attributes = trace_info.metadata
message_data = trace_info.message_data
if message_data is None:
return
message_id = message_data.id
user_id = message_data.from_account_id
attributes["user_id"] = user_id
if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
end_user_id = end_user_data.session_id
attributes["end_user_id"] = end_user_id
attributes["message_id"] = message_id
attributes["start_time"] = trace_info.start_time
attributes["end_time"] = trace_info.end_time
attributes["tags"] = ["message", str(trace_info.conversation_mode)]
message_run = WeaveTraceModel(
id=message_id,
op=str(TraceTaskName.MESSAGE_TRACE.value),
input_tokens=trace_info.message_tokens,
output_tokens=trace_info.answer_tokens,
total_tokens=trace_info.total_tokens,
inputs=trace_info.inputs,
outputs=trace_info.outputs,
exception=trace_info.error,
file_list=file_list,
attributes=attributes,
)
self.start_call(message_run)
# create llm run parented to message run
llm_run = WeaveTraceModel(
id=str(uuid.uuid4()),
input_tokens=trace_info.message_tokens,
output_tokens=trace_info.answer_tokens,
total_tokens=trace_info.total_tokens,
op="llm",
inputs=trace_info.inputs,
outputs=trace_info.outputs,
attributes=attributes,
file_list=[],
exception=None,
)
self.start_call(
llm_run,
parent_run_id=message_id,
)
self.finish_call(llm_run)
self.finish_call(message_run)
def moderation_trace(self, trace_info: ModerationTraceInfo):
if trace_info.message_data is None:
return
attributes = trace_info.metadata
attributes["tags"] = ["moderation"]
attributes["message_id"] = trace_info.message_id
attributes["start_time"] = trace_info.start_time or trace_info.message_data.created_at
attributes["end_time"] = trace_info.end_time or trace_info.message_data.updated_at
moderation_run = WeaveTraceModel(
id=str(uuid.uuid4()),
op=str(TraceTaskName.MODERATION_TRACE.value),
inputs=trace_info.inputs,
outputs={
"action": trace_info.action,
"flagged": trace_info.flagged,
"preset_response": trace_info.preset_response,
"inputs": trace_info.inputs,
},
attributes=attributes,
exception=getattr(trace_info, "error", None),
file_list=[],
)
self.start_call(moderation_run, parent_run_id=trace_info.message_id)
self.finish_call(moderation_run)
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
message_data = trace_info.message_data
if message_data is None:
return
attributes = trace_info.metadata
attributes["message_id"] = trace_info.message_id
attributes["tags"] = ["suggested_question"]
attributes["start_time"] = (trace_info.start_time or message_data.created_at,)
attributes["end_time"] = (trace_info.end_time or message_data.updated_at,)
suggested_question_run = WeaveTraceModel(
id=str(uuid.uuid4()),
op=str(TraceTaskName.SUGGESTED_QUESTION_TRACE.value),
inputs=trace_info.inputs,
outputs=trace_info.suggested_question,
attributes=attributes,
exception=trace_info.error,
file_list=[],
)
self.start_call(suggested_question_run, parent_run_id=trace_info.message_id)
self.finish_call(suggested_question_run)
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
if trace_info.message_data is None:
return
attributes = trace_info.metadata
attributes["message_id"] = trace_info.message_id
attributes["tags"] = ["dataset_retrieval"]
attributes["start_time"] = (trace_info.start_time or trace_info.message_data.created_at,)
attributes["end_time"] = (trace_info.end_time or trace_info.message_data.updated_at,)
dataset_retrieval_run = WeaveTraceModel(
id=str(uuid.uuid4()),
op=str(TraceTaskName.DATASET_RETRIEVAL_TRACE.value),
inputs=trace_info.inputs,
outputs={"documents": trace_info.documents},
attributes=attributes,
exception=getattr(trace_info, "error", None),
file_list=[],
)
self.start_call(dataset_retrieval_run, parent_run_id=trace_info.message_id)
self.finish_call(dataset_retrieval_run)
def tool_trace(self, trace_info: ToolTraceInfo):
attributes = trace_info.metadata
attributes["tags"] = ["tool", trace_info.tool_name]
attributes["start_time"] = trace_info.start_time
attributes["end_time"] = trace_info.end_time
tool_run = WeaveTraceModel(
id=str(uuid.uuid4()),
op=trace_info.tool_name,
inputs=trace_info.tool_inputs,
outputs=trace_info.tool_outputs,
file_list=[cast(str, trace_info.file_url)] if trace_info.file_url else [],
attributes=attributes,
exception=trace_info.error,
)
message_id = trace_info.message_id or getattr(trace_info, "conversation_id", None)
message_id = message_id or None
self.start_call(tool_run, parent_run_id=message_id)
self.finish_call(tool_run)
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
attributes = trace_info.metadata
attributes["tags"] = ["generate_name"]
attributes["start_time"] = trace_info.start_time
attributes["end_time"] = trace_info.end_time
name_run = WeaveTraceModel(
id=str(uuid.uuid4()),
op=str(TraceTaskName.GENERATE_NAME_TRACE.value),
inputs=trace_info.inputs,
outputs=trace_info.outputs,
attributes=attributes,
exception=getattr(trace_info, "error", None),
file_list=[],
)
self.start_call(name_run)
self.finish_call(name_run)
def api_check(self):
try:
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
if not login_status:
raise ValueError("Weave login failed")
else:
print("Weave login successful")
return True
except Exception as e:
logger.debug(f"Weave API check failed: {str(e)}")
raise ValueError(f"Weave API check failed: {str(e)}")
def start_call(self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None):
call = self.weave_client.create_call(op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes)
self.calls[run_data.id] = call
if parent_run_id:
self.calls[run_data.id].parent_id = parent_run_id
def finish_call(self, run_data: WeaveTraceModel):
call = self.calls.get(run_data.id)
if call:
self.weave_client.finish_call(call=call, output=run_data.outputs, exception=run_data.exception)
else:
raise ValueError(f"Call with id {run_data.id} not found")

View File

@ -72,7 +72,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
raise ValueError("missing query")
return cls.invoke_chat_app(app, user, conversation_id, query, stream, inputs, files)
elif app.mode == AppMode.WORKFLOW.value:
elif app.mode == AppMode.WORKFLOW:
return cls.invoke_workflow_app(app, user, stream, inputs, files)
elif app.mode == AppMode.COMPLETION:
return cls.invoke_completion_app(app, user, stream, inputs, files)

View File

@ -1,6 +1,7 @@
from collections.abc import Mapping
from datetime import datetime
from enum import StrEnum
from typing import Generic, Optional, TypeVar
from typing import Any, Generic, Optional, TypeVar
from pydantic import BaseModel, ConfigDict, Field
@ -158,3 +159,11 @@ class PluginInstallTaskStartResponse(BaseModel):
class PluginUploadResponse(BaseModel):
unique_identifier: str = Field(description="The unique identifier of the plugin.")
manifest: PluginDeclaration
class PluginOAuthAuthorizationUrlResponse(BaseModel):
authorization_url: str = Field(description="The URL of the authorization.")
class PluginOAuthCredentialsResponse(BaseModel):
credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.")

View File

@ -6,10 +6,10 @@ from core.plugin.entities.plugin import GenericProviderID
from core.plugin.entities.plugin_daemon import (
PluginAgentProviderEntity,
)
from core.plugin.manager.base import BasePluginManager
from core.plugin.impl.base import BasePluginClient
class PluginAgentManager(BasePluginManager):
class PluginAgentClient(BasePluginClient):
def fetch_agent_strategy_providers(self, tenant_id: str) -> list[PluginAgentProviderEntity]:
"""
Fetch agent providers for the given tenant.

View File

@ -1,7 +1,7 @@
from core.plugin.manager.base import BasePluginManager
from core.plugin.impl.base import BasePluginClient
class PluginAssetManager(BasePluginManager):
class PluginAssetManager(BasePluginClient):
def fetch_asset(self, tenant_id: str, id: str) -> bytes:
"""
Fetch an asset by id.

View File

@ -18,7 +18,7 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse, PluginDaemonError, PluginDaemonInnerError
from core.plugin.manager.exc import (
from core.plugin.impl.exc import (
PluginDaemonBadRequestError,
PluginDaemonInternalServerError,
PluginDaemonNotFoundError,
@ -37,7 +37,7 @@ T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
logger = logging.getLogger(__name__)
class BasePluginManager:
class BasePluginClient:
def _request(
self,
method: str,

View File

@ -1,9 +1,9 @@
from pydantic import BaseModel
from core.plugin.manager.base import BasePluginManager
from core.plugin.impl.base import BasePluginClient
class PluginDebuggingManager(BasePluginManager):
class PluginDebuggingClient(BasePluginClient):
def get_debugging_key(self, tenant_id: str) -> str:
"""
Get the debugging key for the given tenant.

View File

@ -1,8 +1,8 @@
from core.plugin.entities.endpoint import EndpointEntityWithInstance
from core.plugin.manager.base import BasePluginManager
from core.plugin.impl.base import BasePluginClient
class PluginEndpointManager(BasePluginManager):
class PluginEndpointClient(BasePluginClient):
def create_endpoint(
self, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict
) -> bool:

View File

@ -18,10 +18,10 @@ from core.plugin.entities.plugin_daemon import (
PluginTextEmbeddingNumTokensResponse,
PluginVoicesResponse,
)
from core.plugin.manager.base import BasePluginManager
from core.plugin.impl.base import BasePluginClient
class PluginModelManager(BasePluginManager):
class PluginModelClient(BasePluginClient):
def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]:
"""
Fetch model providers for the given tenant.

View File

@ -0,0 +1,98 @@
from collections.abc import Mapping
from typing import Any
from werkzeug import Request
from core.plugin.entities.plugin_daemon import PluginOAuthAuthorizationUrlResponse, PluginOAuthCredentialsResponse
from core.plugin.impl.base import BasePluginClient
class OAuthHandler(BasePluginClient):
def get_authorization_url(
self,
tenant_id: str,
user_id: str,
plugin_id: str,
provider: str,
system_credentials: Mapping[str, Any],
) -> PluginOAuthAuthorizationUrlResponse:
return self._request_with_plugin_daemon_response(
"POST",
f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url",
PluginOAuthAuthorizationUrlResponse,
data={
"user_id": user_id,
"data": {
"provider": provider,
"system_credentials": system_credentials,
},
},
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
def get_credentials(
self,
tenant_id: str,
user_id: str,
plugin_id: str,
provider: str,
system_credentials: Mapping[str, Any],
request: Request,
) -> PluginOAuthCredentialsResponse:
"""
Get credentials from the given request.
"""
# encode request to raw http request
raw_request_bytes = self._convert_request_to_raw_data(request)
return self._request_with_plugin_daemon_response(
"POST",
f"plugin/{tenant_id}/dispatch/oauth/get_credentials",
PluginOAuthCredentialsResponse,
data={
"user_id": user_id,
"data": {
"provider": provider,
"system_credentials": system_credentials,
"raw_request_bytes": raw_request_bytes,
},
},
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
def _convert_request_to_raw_data(self, request: Request) -> bytes:
"""
Convert a Request object to raw HTTP data.
Args:
request: The Request object to convert.
Returns:
The raw HTTP data as bytes.
"""
# Start with the request line
method = request.method
path = request.path
protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1")
raw_data = f"{method} {path} {protocol}\r\n".encode()
# Add headers
for header_name, header_value in request.headers.items():
raw_data += f"{header_name}: {header_value}\r\n".encode()
# Add empty line to separate headers from body
raw_data += b"\r\n"
# Add body if exists
body = request.get_data(as_text=False)
if body:
raw_data += body
return raw_data

View File

@ -10,10 +10,10 @@ from core.plugin.entities.plugin import (
PluginInstallationSource,
)
from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginInstallTaskStartResponse, PluginUploadResponse
from core.plugin.manager.base import BasePluginManager
from core.plugin.impl.base import BasePluginClient
class PluginInstallationManager(BasePluginManager):
class PluginInstaller(BasePluginClient):
def fetch_plugin_by_identifier(
self,
tenant_id: str,

View File

@ -5,11 +5,11 @@ from pydantic import BaseModel
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
from core.plugin.manager.base import BasePluginManager
from core.plugin.impl.base import BasePluginClient
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
class PluginToolManager(BasePluginManager):
class PluginToolManager(BasePluginClient):
def fetch_tool_providers(self, tenant_id: str) -> list[PluginToolProviderEntity]:
"""
Fetch tool providers for the given tenant.

View File

@ -32,6 +32,7 @@ class LindormVectorStoreConfig(BaseModel):
username: Optional[str] = None
password: Optional[str] = None
using_ugc: Optional[bool] = False
request_timeout: Optional[float] = 1.0 # timeout units: s
@model_validator(mode="before")
@classmethod
@ -251,9 +252,9 @@ class LindormVectorStore(BaseVector):
query = default_vector_search_query(query_vector=query_vector, k=top_k, filters=filters, **kwargs)
try:
params = {}
params = {"timeout": self._client_config.request_timeout}
if self._using_ugc:
params["routing"] = self._routing
params["routing"] = self._routing # type: ignore
response = self._client.search(index=self._collection_name, body=query, params=params)
except Exception:
logger.exception(f"Error executing vector search, query: {query}")
@ -304,8 +305,8 @@ class LindormVectorStore(BaseVector):
routing=routing,
routing_field=self._routing_field,
)
response = self._client.search(index=self._collection_name, body=full_text_query)
params = {"timeout": self._client_config.request_timeout}
response = self._client.search(index=self._collection_name, body=full_text_query, params=params)
docs = []
for hit in response["hits"]["hits"]:
docs.append(
@ -554,6 +555,7 @@ class LindormVectorStoreFactory(AbstractVectorFactory):
username=dify_config.LINDORM_USERNAME,
password=dify_config.LINDORM_PASSWORD,
using_ugc=dify_config.USING_UGC_INDEX,
request_timeout=dify_config.LINDORM_QUERY_TIMEOUT,
)
using_ugc = dify_config.USING_UGC_INDEX
if using_ugc is None:

View File

@ -27,8 +27,8 @@ class MilvusConfig(BaseModel):
uri: str # Milvus server URI
token: Optional[str] = None # Optional token for authentication
user: str # Username for authentication
password: str # Password for authentication
user: Optional[str] = None # Username for authentication
password: Optional[str] = None # Password for authentication
batch_size: int = 100 # Batch size for operations
database: str = "default" # Database name
enable_hybrid_search: bool = False # Flag to enable hybrid search
@ -43,10 +43,11 @@ class MilvusConfig(BaseModel):
"""
if not values.get("uri"):
raise ValueError("config MILVUS_URI is required")
if not values.get("user"):
raise ValueError("config MILVUS_USER is required")
if not values.get("password"):
raise ValueError("config MILVUS_PASSWORD is required")
if not values.get("token"):
if not values.get("user"):
raise ValueError("config MILVUS_USER is required")
if not values.get("password"):
raise ValueError("config MILVUS_PASSWORD is required")
return values
def to_milvus_params(self):
@ -356,11 +357,14 @@ class MilvusVector(BaseVector):
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _init_client(self, config) -> MilvusClient:
def _init_client(self, config: MilvusConfig) -> MilvusClient:
"""
Initialize and return a Milvus client.
"""
client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
if config.token:
client = MilvusClient(uri=config.uri, token=config.token, db_name=config.database)
else:
client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
return client

View File

@ -1,10 +1,9 @@
import json
import logging
import ssl
from typing import Any, Optional
from typing import Any, Literal, Optional
from uuid import uuid4
from opensearchpy import OpenSearch, helpers
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator
@ -24,9 +23,12 @@ logger = logging.getLogger(__name__)
class OpenSearchConfig(BaseModel):
host: str
port: int
secure: bool = False
auth_method: Literal["basic", "aws_managed_iam"] = "basic"
user: Optional[str] = None
password: Optional[str] = None
secure: bool = False
aws_region: Optional[str] = None
aws_service: Optional[str] = None
@model_validator(mode="before")
@classmethod
@ -35,24 +37,40 @@ class OpenSearchConfig(BaseModel):
raise ValueError("config OPENSEARCH_HOST is required")
if not values.get("port"):
raise ValueError("config OPENSEARCH_PORT is required")
if values.get("auth_method") == "aws_managed_iam":
if not values.get("aws_region"):
raise ValueError("config OPENSEARCH_AWS_REGION is required for AWS_MANAGED_IAM auth method")
if not values.get("aws_service"):
raise ValueError("config OPENSEARCH_AWS_SERVICE is required for AWS_MANAGED_IAM auth method")
return values
def create_ssl_context(self) -> ssl.SSLContext:
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE # Disable Certificate Validation
return ssl_context
def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth:
import boto3 # type: ignore
return Urllib3AWSV4SignerAuth(
credentials=boto3.Session().get_credentials(),
region=self.aws_region,
service=self.aws_service, # type: ignore[arg-type]
)
def to_opensearch_params(self) -> dict[str, Any]:
params = {
"hosts": [{"host": self.host, "port": self.port}],
"use_ssl": self.secure,
"verify_certs": self.secure,
"connection_class": Urllib3HttpConnection,
"pool_maxsize": 20,
}
if self.user and self.password:
if self.auth_method == "basic":
logger.info("Using basic authentication for OpenSearch Vector DB")
params["http_auth"] = (self.user, self.password)
if self.secure:
params["ssl_context"] = self.create_ssl_context()
elif self.auth_method == "aws_managed_iam":
logger.info("Using AWS managed IAM role for OpenSearch Vector DB")
params["http_auth"] = self.create_aws_managed_iam_auth()
return params
@ -76,16 +94,23 @@ class OpenSearchVector(BaseVector):
action = {
"_op_type": "index",
"_index": self._collection_name.lower(),
"_id": uuid4().hex,
"_source": {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
Field.METADATA_KEY.value: documents[i].metadata,
},
}
# See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377
if self._client_config.aws_service not in ["aoss"]:
action["_id"] = uuid4().hex
actions.append(action)
helpers.bulk(self._client, actions)
helpers.bulk(
client=self._client,
actions=actions,
timeout=30,
max_retries=3,
)
def get_ids_by_metadata_field(self, key: str, value: str):
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}}
@ -234,6 +259,7 @@ class OpenSearchVector(BaseVector):
},
}
logger.info(f"Creating OpenSearch index {self._collection_name.lower()}")
self._client.indices.create(index=self._collection_name.lower(), body=index_body)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
@ -252,9 +278,12 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
open_search_config = OpenSearchConfig(
host=dify_config.OPENSEARCH_HOST or "localhost",
port=dify_config.OPENSEARCH_PORT,
secure=dify_config.OPENSEARCH_SECURE,
auth_method=dify_config.OPENSEARCH_AUTH_METHOD.value,
user=dify_config.OPENSEARCH_USER,
password=dify_config.OPENSEARCH_PASSWORD,
secure=dify_config.OPENSEARCH_SECURE,
aws_region=dify_config.OPENSEARCH_AWS_REGION,
aws_service=dify_config.OPENSEARCH_AWS_SERVICE,
)
return OpenSearchVector(collection_name=collection_name, config=open_search_config)

View File

@ -0,0 +1,243 @@
import json
import uuid
from contextlib import contextmanager
from typing import Any
import psycopg2.extras # type: ignore
import psycopg2.pool # type: ignore
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
class VastbaseVectorConfig(BaseModel):
host: str
port: int
user: str
password: str
database: str
min_connection: int
max_connection: int
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config VASTBASE_HOST is required")
if not values["port"]:
raise ValueError("config VASTBASE_PORT is required")
if not values["user"]:
raise ValueError("config VASTBASE_USER is required")
if not values["password"]:
raise ValueError("config VASTBASE_PASSWORD is required")
if not values["database"]:
raise ValueError("config VASTBASE_DATABASE is required")
if not values["min_connection"]:
raise ValueError("config VASTBASE_MIN_CONNECTION is required")
if not values["max_connection"]:
raise ValueError("config VASTBASE_MAX_CONNECTION is required")
if values["min_connection"] > values["max_connection"]:
raise ValueError("config VASTBASE_MIN_CONNECTION should less than VASTBASE_MAX_CONNECTION")
return values
SQL_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS {table_name} (
id UUID PRIMARY KEY,
text TEXT NOT NULL,
meta JSONB NOT NULL,
embedding floatvector({dimension}) NOT NULL
);
"""
SQL_CREATE_INDEX = """
CREATE INDEX IF NOT EXISTS embedding_cosine_v1_idx ON {table_name}
USING hnsw (embedding floatvector_cosine_ops) WITH (m = 16, ef_construction = 64);
"""
class VastbaseVector(BaseVector):
def __init__(self, collection_name: str, config: VastbaseVectorConfig):
super().__init__(collection_name)
self.pool = self._create_connection_pool(config)
self.table_name = f"embedding_{collection_name}"
def get_type(self) -> str:
return VectorType.VASTBASE
def _create_connection_pool(self, config: VastbaseVectorConfig):
return psycopg2.pool.SimpleConnectionPool(
config.min_connection,
config.max_connection,
host=config.host,
port=config.port,
user=config.user,
password=config.password,
database=config.database,
)
@contextmanager
def _get_cursor(self):
conn = self.pool.getconn()
cur = conn.cursor()
try:
yield cur
finally:
cur.close()
conn.commit()
self.pool.putconn(conn)
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection(dimension)
return self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
values = []
pks = []
for i, doc in enumerate(documents):
if doc.metadata is not None:
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
pks.append(doc_id)
values.append(
(
doc_id,
doc.page_content,
json.dumps(doc.metadata),
embeddings[i],
)
)
with self._get_cursor() as cur:
psycopg2.extras.execute_values(
cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values
)
return pks
def text_exists(self, id: str) -> bool:
with self._get_cursor() as cur:
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,))
return cur.fetchone() is not None
def get_by_ids(self, ids: list[str]) -> list[Document]:
with self._get_cursor() as cur:
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
docs = []
for record in cur:
docs.append(Document(page_content=record[1], metadata=record[0]))
return docs
def delete_by_ids(self, ids: list[str]) -> None:
# Avoiding crashes caused by performing delete operations on empty lists in certain scenarios
# Scenario 1: extract a document fails, resulting in a table not being created.
# Then clicking the retry button triggers a delete operation on an empty list.
if not ids:
return
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
def delete_by_metadata_field(self, key: str, value: str) -> None:
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""
Search the nearest neighbors to a vector.
:param query_vector: The input vector to search for similar items.
:param top_k: The number of nearest neighbors to return, default is 5.
:return: List of Documents that are nearest to the query vector.
"""
top_k = kwargs.get("top_k", 4)
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
with self._get_cursor() as cur:
cur.execute(
f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}"
f" ORDER BY distance LIMIT {top_k}",
(json.dumps(query_vector),),
)
docs = []
score_threshold = float(kwargs.get("score_threshold") or 0.0)
for record in cur:
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
if score > score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
with self._get_cursor() as cur:
cur.execute(
f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
FROM {self.table_name}
WHERE to_tsvector(text) @@ plainto_tsquery(%s)
ORDER BY score DESC
LIMIT {top_k}""",
# f"'{query}'" is required in order to account for whitespace in query
(f"'{query}'", f"'{query}'"),
)
docs = []
for record in cur:
metadata, text, score = record
metadata["score"] = score
docs.append(Document(page_content=text, metadata=metadata))
return docs
def delete(self) -> None:
with self._get_cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
def _create_collection(self, dimension: int):
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
with self._get_cursor() as cur:
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
# Vastbase 支持的向量维度取值范围为 [1,16000]
if dimension <= 16000:
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
redis_client.set(collection_exist_cache_key, 1, ex=3600)
class VastbaseVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> VastbaseVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.VASTBASE, collection_name))
return VastbaseVector(
collection_name=collection_name,
config=VastbaseVectorConfig(
host=dify_config.VASTBASE_HOST or "localhost",
port=dify_config.VASTBASE_PORT,
user=dify_config.VASTBASE_USER or "dify",
password=dify_config.VASTBASE_PASSWORD or "",
database=dify_config.VASTBASE_DATABASE or "dify",
min_connection=dify_config.VASTBASE_MIN_CONNECTION,
max_connection=dify_config.VASTBASE_MAX_CONNECTION,
),
)

View File

@ -74,6 +74,10 @@ class Vector:
from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory
return PGVectorFactory
case VectorType.VASTBASE:
from core.rag.datasource.vdb.pyvastbase.vastbase_vector import VastbaseVectorFactory
return VastbaseVectorFactory
case VectorType.PGVECTO_RS:
from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory

View File

@ -7,7 +7,9 @@ class VectorType(StrEnum):
MILVUS = "milvus"
MYSCALE = "myscale"
PGVECTOR = "pgvector"
VASTBASE = "vastbase"
PGVECTO_RS = "pgvecto-rs"
QDRANT = "qdrant"
RELYT = "relyt"
TIDB_VECTOR = "tidb_vector"

View File

@ -20,7 +20,7 @@ class WaterCrawlProvider:
}
if options.get("crawl_sub_pages", True):
spider_options["page_limit"] = options.get("limit", 1)
spider_options["max_depth"] = options.get("depth", 1)
spider_options["max_depth"] = options.get("max_depth", 1)
spider_options["include_paths"] = options.get("includes", "").split(",") if options.get("includes") else []
spider_options["exclude_paths"] = options.get("excludes", "").split(",") if options.get("excludes") else []

View File

@ -52,14 +52,16 @@ class RerankModelRunner(BaseRerankRunner):
rerank_documents = []
for result in rerank_result.docs:
# format document
rerank_document = Document(
page_content=result.text,
metadata=documents[result.index].metadata,
provider=documents[result.index].provider,
)
if rerank_document.metadata is not None:
rerank_document.metadata["score"] = result.score
rerank_documents.append(rerank_document)
if score_threshold is None or result.score >= score_threshold:
# format document
rerank_document = Document(
page_content=result.text,
metadata=documents[result.index].metadata,
provider=documents[result.index].provider,
)
if rerank_document.metadata is not None:
rerank_document.metadata["score"] = result.score
rerank_documents.append(rerank_document)
return rerank_documents
rerank_documents.sort(key=lambda x: x.metadata.get("score", 0.0), reverse=True)
return rerank_documents[:top_n] if top_n else rerank_documents

View File

@ -2,5 +2,5 @@
Repository implementations for data access.
This package contains concrete implementations of the repository interfaces
defined in the core.repository package.
defined in the core.workflow.repository package.
"""

View File

@ -11,9 +11,9 @@ from typing import Any
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.repository.repository_factory import RepositoryFactory
from core.repositories.workflow_node_execution import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.repository.repository_factory import RepositoryFactory
from extensions.ext_database import db
from repositories.workflow_node_execution import SQLAlchemyWorkflowNodeExecutionRepository
logger = logging.getLogger(__name__)

View File

@ -2,7 +2,7 @@
WorkflowNodeExecution repository implementations.
"""
from repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository
__all__ = [
"SQLAlchemyWorkflowNodeExecutionRepository",

View File

@ -10,7 +10,7 @@ from sqlalchemy import UnaryExpression, asc, delete, desc, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.repository.workflow_node_execution_repository import OrderConfig
from core.workflow.repository.workflow_node_execution_repository import OrderConfig
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)

View File

@ -35,8 +35,9 @@ class BuiltinToolProviderController(ToolProviderController):
provider_yaml["credentials_for_provider"][credential_name]["name"] = credential_name
credentials_schema = []
for credential in provider_yaml.get("credentials_for_provider", {}).values():
credentials_schema.append(credential)
for credential in provider_yaml.get("credentials_for_provider", {}):
credential_dict = provider_yaml.get("credentials_for_provider", {}).get(credential, {})
credentials_schema.append(credential_dict)
super().__init__(
entity=ToolProviderEntity(

View File

@ -1,6 +1,6 @@
from typing import Any
from core.plugin.manager.tool import PluginToolManager
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin, ToolProviderType

View File

@ -1,7 +1,7 @@
from collections.abc import Generator
from typing import Any, Optional
from core.plugin.manager.tool import PluginToolManager
from core.plugin.impl.tool import PluginToolManager
from core.plugin.utils.converter import convert_parameters_to_plugin_format
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime

View File

@ -246,7 +246,7 @@ class ToolEngine:
+ "you do not need to create it, just tell the user to check it now."
)
elif response.type == ToolInvokeMessage.MessageType.JSON:
result = json.dumps(
result += json.dumps(
cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False
)
else:

View File

@ -10,7 +10,7 @@ from yarl import URL
import contexts
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.manager.tool import PluginToolManager
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.plugin_tool.provider import PluginToolProviderController

View File

@ -7,8 +7,8 @@ from core.agent.plugin_entities import AgentStrategyParameter
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.plugin.manager.exc import PluginDaemonClientSideError
from core.plugin.manager.plugin import PluginInstallationManager
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.plugin.impl.plugin import PluginInstaller
from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
from core.tools.tool_manager import ToolManager
@ -297,7 +297,7 @@ class AgentNode(ToolNode):
Get agent strategy icon
:return:
"""
manager = PluginInstallationManager()
manager = PluginInstaller()
plugins = manager.list_plugins(self.tenant_id)
try:
current_plugin = next(

Some files were not shown because too many files have changed in this diff Show More